In [None]:
# Reload edited modules (e.g. the project's src/ package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# cuBLAS needs a fixed workspace configuration for torch.use_deterministic_algorithms(True)
# to accept CUDA matmuls; it is set here, before torch is imported, so it is in place
# before any CUDA context exists.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

In [None]:
import random

import numpy as np
import torch


# Seed every RNG the pipeline may draw from so a Restart & Run All reproduces the results.
seed = 3407
_ = torch.manual_seed(seed)  # torch RNG on all devices (CPU and CUDA)
# np.random.default_rng creates an independent Generator and does NOT seed numpy's legacy
# global state; seed that too in case src/ code calls np.random.* directly.
np.random.seed(seed)
rng = np.random.default_rng(seed)  # explicit Generator, kept for code that prefers it
random.seed(seed)  # Python's stdlib RNG
# Raise instead of silently diverging when an op has no deterministic implementation
# (on CUDA this relies on CUBLAS_WORKSPACE_CONFIG being set earlier).
torch.use_deterministic_algorithms(True)

In [None]:
from src.augmentations import DefaultWaveAugmentations, DefaultWave2Spec
from src.configs import DefaultConfig
from src.data_utils import Collator, get_sampler, SpeechCommandsDataset
from src.metrics import count_FA_FR, get_au_fa_fr
from src.models import CRNN
from src.train_utils import count_parameters, train

In [None]:
# Root directory of the Speech Commands data (relative to the notebook's working directory);
# used both to build the dataset and by the train-time augmentations.
DATA_DIR = "data/speech_commands"

In [None]:
# Full labelled table built from DATA_DIR for the configured keyword(s); it is split
# into train/val subsets in the next cell.
dataset = SpeechCommandsDataset(path2dir=DATA_DIR, keywords=DefaultConfig.keyword)

In [None]:
# Random train/val split of the full metadata table.
# A dedicated generator seeded with `seed` makes this cell idempotent: re-running it
# reproduces the same split instead of drawing a fresh permutation from the global
# torch RNG (on a fresh kernel it yields the same permutation as the global RNG would).
# NOTE(review): the split is per-file; if the CSV carries speaker ids, a speaker-disjoint
# split would avoid train/val leakage — TODO confirm against SpeechCommandsDataset.csv.
TRAIN_FRACTION = 0.8

split_generator = torch.Generator().manual_seed(seed)
indexes = torch.randperm(len(dataset), generator=split_generator).numpy()
n_train = int(len(dataset) * TRAIN_FRACTION)
train_indexes = indexes[:n_train]
val_indexes = indexes[n_train:]
assert len(train_indexes) > 0 and len(val_indexes) > 0, "empty train or val split"

train_df = dataset.csv.iloc[train_indexes].reset_index(drop=True)
val_df = dataset.csv.iloc[val_indexes].reset_index(drop=True)
# Only the training subset gets waveform augmentations.
train_set = SpeechCommandsDataset(csv=train_df, transform=DefaultWaveAugmentations(DATA_DIR))
val_set = SpeechCommandsDataset(csv=val_df)

In [None]:
# Training batches are drawn through `get_sampler` (presumably a class-balancing weighted
# sampler built from the label column — confirm in src.data_utils).
# The validation loader deliberately gets NO sampler: validation metrics must be computed
# on every held-out example exactly once, in the true class proportions, not on a
# resampled stream.
NUM_WORKERS = 2

train_sampler = get_sampler(train_set.csv['label'].values)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=DefaultConfig.batch_size,
    shuffle=False,  # ordering comes from the sampler; DataLoader rejects shuffle=True with one
    collate_fn=Collator(),
    sampler=train_sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=DefaultConfig.batch_size,
    shuffle=False,  # fixed, sequential pass over the whole validation set
    collate_fn=Collator(),
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
# Waveform -> spectrogram front-ends: the first is built with is_train=True (presumably
# enabling train-only augmentation — confirm in src.augmentations), the second with
# is_train=False for validation.
train_wave2spec, val_wave2spec = (
    DefaultWave2Spec(is_train=training, config=DefaultConfig)
    for training in (True, False)
)

In [None]:
# Instantiate the config, build the CRNN on the configured device, show its structure,
# and attach an Adam optimizer using the config's learning rate and weight decay.
config = DefaultConfig()
model = CRNN(config)
model = model.to(config.device)

print(model)

opt = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

In [None]:
# Parameter count of the model as reported by src.train_utils.count_parameters.
n_params = count_parameters(model)
n_params

In [None]:
# Run training for config.num_epochs. `train` returns (presumably) the best validation
# score and the corresponding best model — confirm against src.train_utils.train.
loaders = (train_loader, val_loader)
spec_transforms = (train_wave2spec, val_wave2spec)
best_score, best_model = train(
    config.num_epochs, model, opt, loaders, spec_transforms, config.device, make_plots=True
)

In [None]:
# Persist the best model returned by `train`.
# NOTE(review): if `best_model` is a full nn.Module, torch.save pickles the whole object and
# loading requires the same source tree; saving its state_dict is more portable — TODO confirm.
checkpoint_path = 'baseline.pth'
torch.save(best_model, checkpoint_path)