In [15]:
import os
import shutil
import traceback
from pathlib import Path

import wandb
import optuna
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from IPython.display import clear_output
from torchaudio.datasets import SPEECHCOMMANDS

from melbanks import LogMelFilterBanks

In [2]:
# Folder that SPEECHCOMMANDS('data', download=True) extracts into.
DATASET_DIR = Path('data') / 'SpeechCommands'

In [3]:
# Labels kept for this binary yes/no task; every other label folder is deleted.
KEPT_LABELS = ('yes', 'no')

# Download once, then prune everything except the kept labels. Pure pathlib
# replaces the previous `!rm` / `!sed -i` magics: GNU sed's `\|` alternation and
# `-i` syntax do not work with BSD/macOS sed.
if not DATASET_DIR.exists():
    SPEECHCOMMANDS('data', download=True)
    # The archive is only needed for extraction; drop it to free disk space.
    archive_path = DATASET_DIR.parent / 'speech_commands_v0.02.tar.gz'
    if archive_path.exists():
        archive_path.unlink()
    extracted_dir = DATASET_DIR / 'speech_commands_v0.02'
    for label_dir in extracted_dir.glob('*'):
        if label_dir.is_dir() and label_dir.name not in KEPT_LABELS:
            shutil.rmtree(label_dir)
    # Keep only entries like "yes/<file>.wav" in the split lists so they agree
    # with the pruned directory tree.
    for split_name in ('testing_list.txt', 'validation_list.txt'):
        split_path = extracted_dir / split_name
        kept_lines = [
            line
            for line in split_path.read_text(encoding='utf-8').splitlines(keepends=True)
            if line.split('/', 1)[0] in KEPT_LABELS
        ]
        split_path.write_text(''.join(kept_lines), encoding='utf-8')

In [30]:
# One SPEECHCOMMANDS view per split, all rooted at the same 'data' directory.
train_dataset, val_dataset, test_dataset = (
    SPEECHCOMMANDS('data', subset=split_name)
    for split_name in ('training', 'validation', 'testing')
)

In [13]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [25]:
# Survey the log-mel output shapes across validation clips; the spread of frame
# counts shows the clips differ in length and must be padded before batching.
# A dedicated LogMelFilterBanks instance keeps this cell independent of the
# later `mel` cell, so it also works under Restart & Run All.
mel_probe = LogMelFilterBanks()
{mel_probe(waveform).shape for waveform, *_ in val_dataset}

{torch.Size([1, 80, 43]),
 torch.Size([1, 80, 47]),
 torch.Size([1, 80, 52]),
 torch.Size([1, 80, 56]),
 torch.Size([1, 80, 60]),
 torch.Size([1, 80, 61]),
 torch.Size([1, 80, 65]),
 torch.Size([1, 80, 66]),
 torch.Size([1, 80, 69]),
 torch.Size([1, 80, 70]),
 torch.Size([1, 80, 73]),
 torch.Size([1, 80, 75]),
 torch.Size([1, 80, 76]),
 torch.Size([1, 80, 77]),
 torch.Size([1, 80, 79]),
 torch.Size([1, 80, 82]),
 torch.Size([1, 80, 84]),
 torch.Size([1, 80, 86]),
 torch.Size([1, 80, 89]),
 torch.Size([1, 80, 90]),
 torch.Size([1, 80, 93]),
 torch.Size([1, 80, 94]),
 torch.Size([1, 80, 98]),
 torch.Size([1, 80, 99]),
 torch.Size([1, 80, 101])}

In [16]:
# Class order defines the integer targets: 'no' -> 0, 'yes' -> 1.
LABELS = ('no', 'yes')
LABEL_TO_INDEX = {label: index for index, label in enumerate(LABELS)}


def pad_collate(batch, min_length=16000):
    """Batch SPEECHCOMMANDS items: zero-pad waveforms and map labels to indices.

    The default collate cannot stack clips of different lengths (see the
    frame-count survey above) and leaves labels as strings.

    Args:
        batch: list of dataset items. Assumes torchaudio's SPEECHCOMMANDS layout
            ``(waveform, sample_rate, label, speaker_id, utterance_number)`` with
            ``waveform`` shaped ``(channels, time)`` — confirm against torchaudio.
        min_length: every batch is padded to at least this many samples, so
            batches share one time length and short clips still survive M5's
            strided conv/pool stack. 16000 presumably equals one second at the
            dataset's 16 kHz rate — TODO confirm.

    Returns:
        ``(waveforms, labels)``: a float tensor ``(batch, channels, time)``
        zero-padded at the end, and a long tensor ``(batch,)`` of indices into
        ``LABELS``.

    Raises:
        KeyError: if an item's label is not in ``LABELS``.
    """
    waveforms = [item[0] for item in batch]
    target_length = max(min_length, max(w.shape[-1] for w in waveforms))
    padded = torch.stack([F.pad(w, (0, target_length - w.shape[-1])) for w in waveforms])
    labels = torch.tensor([LABEL_TO_INDEX[item[2]] for item in batch], dtype=torch.long)
    return padded, labels


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=pad_collate)

In [None]:

# Evaluation loaders reuse the training loader's collate_fn and batch size, so
# validation/test batches are preprocessed exactly like training batches.
# No shuffling: accuracy does not depend on order, and a fixed order keeps runs comparable.
val_loader = DataLoader(val_dataset, batch_size=train_loader.batch_size,
                        collate_fn=train_loader.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=train_loader.batch_size,
                         collate_fn=train_loader.collate_fn)

In [8]:
class M5(nn.Module):
    """1-D CNN: four conv -> batch-norm -> ReLU -> max-pool(4) stages, global average pooling, linear head.

    Args:
        n_input: channels of the input signal.
        n_output: number of classes.
        stride: stride of the first convolution (kernel size 80).
        n_channel: width of stages 1-2; stages 3-4 use ``2 * n_channel``.

    ``forward`` maps a ``(batch, n_input, time)`` tensor to log-probabilities
    shaped ``(batch, 1, n_output)``.
    """

    def __init__(self, n_input=1, n_output=2, stride=16, n_channel=32):
        super().__init__()
        wide = 2 * n_channel
        # (in_channels, out_channels, kernel_size, stride) for conv stages 1..4;
        # only the first stage uses a long, strided kernel.
        stage_specs = (
            (n_input, n_channel, 80, stride),
            (n_channel, n_channel, 3, 1),
            (n_channel, wide, 3, 1),
            (wide, wide, 3, 1),
        )
        for index, (c_in, c_out, kernel, step) in enumerate(stage_specs, start=1):
            # Registering conv, bn, pool per stage, in this order, keeps the module
            # tree and state_dict keys (conv1.*, bn1.*, ...) stable.
            setattr(self, f'conv{index}', nn.Conv1d(c_in, c_out, kernel_size=kernel, stride=step))
            setattr(self, f'bn{index}', nn.BatchNorm1d(c_out))
            setattr(self, f'pool{index}', nn.MaxPool1d(4))
        self.fc1 = nn.Linear(wide, n_output)

    def forward(self, x):
        for index in range(1, 5):
            conv = getattr(self, f'conv{index}')
            norm = getattr(self, f'bn{index}')
            pool = getattr(self, f'pool{index}')
            x = pool(F.relu(norm(conv(x))))
        # Average over the whole remaining time axis -> (batch, channels, 1).
        x = F.avg_pool1d(x, kernel_size=x.size(-1))
        # Channels last so the linear head acts on them -> (batch, 1, n_output).
        x = self.fc1(x.transpose(1, 2))
        return F.log_softmax(x, dim=-1)

In [10]:
# Bind the default log-mel front end to `mel` and display its repr in one step.
(mel := LogMelFilterBanks())

LogMelFilterBanks()

In [14]:
# nn.Module.to moves the parameters in place and returns the same module.
model = M5().to(device)
print(model)


def count_parameters(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    trainable_sizes = (param.numel() for param in model.parameters() if param.requires_grad)
    return sum(trainable_sizes)


n = count_parameters(model)
print("Number of parameters: %s" % n)

M5(
  (conv1): Conv1d(1, 32, kernel_size=(80,), stride=(16,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=2, bias=True)
)
Number

In [None]:
class ModelTrainer:
    """Train a classifier with early stopping, checkpointing, wandb logging and optional Optuna pruning.

    Args:
        train_loader: DataLoader yielding ``(inputs, labels, ...)`` training batches.
        test_loader: DataLoader yielding ``(inputs, labels, ...)`` batches scored by
            :meth:`validate` for model selection — pass the *validation* loader here.
        classes: class names, stored for reference only.
        metric_name: key of the validation metric to maximise. Defaults to
            ``'val_acc'``, the only metric :meth:`validate` produces.
        project_name: wandb project name.
        checkpoint_folder: directory for best-model checkpoints (created on demand).
    """

    class EarlyStopping(Exception):
        """Raised by :meth:`track_progress` when ``patience`` epochs pass without improvement."""

    def __init__(self, train_loader, test_loader, classes,
                 metric_name='val_acc',
                 project_name='itmo-dsp',
                 checkpoint_folder='./models'):
        self.project_name = project_name
        self.checkpoint_folder = checkpoint_folder
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.true_val_labels = np.hstack([batch[1].numpy() for batch in test_loader])
        # Up to 10 min-max-scaled inputs from the first evaluation batch, kept for
        # inspection. Shape-agnostic, so audio (C, T) inputs work as well as images.
        first_inputs = next(iter(test_loader))[0][:10].numpy()
        self.samples = [self._min_max_scale(sample) for sample in first_inputs]
        self.best_metric = 0.0
        self.best_model_path = None
        self.classes = classes
        self.test_dataset_size = len(self.test_loader.dataset)
        self.train_dataset_size = len(self.train_loader.dataset)
        self.metric_name = metric_name

    @staticmethod
    def _min_max_scale(array):
        """Rescale ``array`` to [0, 1]; constant arrays map to zeros instead of NaN."""
        low, high = array.min(), array.max()
        if high == low:
            return np.zeros_like(array)
        return (array - low) / (high - low)

    def validate(self, model, device):
        """Return ``{'val_acc': float}``, the accuracy of ``model`` on ``self.test_loader``."""
        model.eval()
        n_correct = 0
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                # flatten(1) turns a (batch, 1, classes) output such as M5's into
                # (batch, classes); an argmax over dim 1 of the unflattened tensor
                # would reduce the singleton axis and always predict class 0.
                preds = model(inputs).flatten(start_dim=1).argmax(dim=1)
                n_correct += (preds == labels).sum().item()
        return {'val_acc': n_correct / max(self.test_dataset_size, 1)}

    def save_model_locally(self, model):
        """Write ``model.state_dict()`` to ``self.best_model_path``, creating its folder if needed."""
        Path(self.best_model_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), self.best_model_path)

    def track_progress(self, metrics, patience, trial, epoch, epochs_without_improvement):
        """Update the best score and patience counter, and report to Optuna.

        Returns:
            The updated count of consecutive epochs without improvement (0 when
            this epoch set a new best, which also sets ``self.best_model_path``).

        Raises:
            optuna.TrialPruned: if ``trial`` is given and Optuna decides to prune.
            ModelTrainer.EarlyStopping: once ``patience`` (at least 1) consecutive
                epochs have passed without improvement.
        """
        current = metrics[self.metric_name]
        if current > self.best_metric:
            epochs_without_improvement = 0
            self.best_metric = current
            self.best_model_path = Path(self.checkpoint_folder) / f'{self.metric_name}_{self.best_metric:.4f}.pth'
        else:
            epochs_without_improvement += 1
        if trial is not None:
            trial.report(current, epoch)
            if trial.should_prune():
                print('[OPTUNA] Run pruned')
                raise optuna.TrialPruned()
        if epochs_without_improvement >= max(patience, 1):
            print('Early stopping')
            # Distinct from TrialPruned: an early-stopped trial is a completed
            # trial whose best score is valid, not a pruned one.
            raise self.EarlyStopping()
        return epochs_without_improvement

    def log_metrics(self, metrics, epoch, progress_bar, n_epochs, run=None):
        """Show ``metrics`` on the progress bar, send them to the wandb ``run`` if given, and print a summary."""
        progress_bar.set_postfix(metrics)
        if run is not None:
            run.log(metrics)
        print(f'Epoch {epoch + 1}/{n_epochs}:')
        print(f'Best {self.metric_name}: {self.best_metric:.4f}')

    def train_epoch(self, model, optimizer, loss_fn, device):
        """Run one optimisation pass over ``self.train_loader``.

        Returns:
            ``(model, {'train_loss': float})`` where the loss is the per-sample mean
            over the epoch, assuming ``loss_fn`` averages over each batch
            (``reduction='mean'``) — TODO confirm for custom losses.
        """
        model.train()
        total_loss = 0.0
        for inputs, labels in self.train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs).flatten(start_dim=1)
            batch_loss = loss_fn(outputs, labels)
            batch_loss.backward()
            optimizer.step()
            # Weight each batch mean by its size so a short last batch counts correctly.
            total_loss += batch_loss.item() * labels.size(0)
        return model, {'train_loss': total_loss / max(self.train_dataset_size, 1)}

    def train(self, model, loss_fn, optim_class, optim_args, device,
              n_epochs=100, patience=10, trial=None, run_config=None):
        """Train ``model`` for up to ``n_epochs``, checkpointing every new best score.

        Args:
            optim_class: optimizer class, built as ``optim_class(model.parameters(), **optim_args)``;
                ``optim_args`` must contain ``'lr'`` (used in the wandb run name).
            trial: optional Optuna trial used for intermediate reports and pruning.
            run_config: optional dict stored as the wandb run config.

        Returns:
            The model as it stands at the end of training (not reloaded from the best checkpoint).

        Raises:
            optuna.TrialPruned: re-raised so Optuna records the trial as pruned.
            Other exceptions and KeyboardInterrupt are printed, not raised, so the
            notebook keeps running.
        """
        run = wandb.init(project=self.project_name,
                         name=f'{model.__class__.__name__}_{optim_class.__name__}lr{np.round(optim_args["lr"], 5)}',
                         config=run_config)
        model = model.to(device)
        optimizer = optim_class(model.parameters(), **optim_args)
        # Reset per-run state: the trainer is reused across Optuna trials, and a
        # best_metric left over from an earlier trial would suppress checkpointing
        # and make the objective report another trial's score.
        self.best_metric = 0.0
        self.best_model_path = None
        progress_bar = tqdm(range(n_epochs), desc='Training', leave=True)
        epochs_without_improvement = 0
        try:
            for epoch in progress_bar:
                model, train_metrics = self.train_epoch(model, optimizer, loss_fn, device)
                val_metrics = self.validate(model, device)
                metrics = {**train_metrics, **val_metrics, 'epoch': epoch}
                epochs_without_improvement = self.track_progress(metrics, patience, trial, epoch, epochs_without_improvement)
                if epochs_without_improvement == 0:
                    self.save_model_locally(model)
                self.log_metrics(metrics, epoch, progress_bar, n_epochs, run=run)
                clear_output(wait=True)
        except self.EarlyStopping:
            pass
        except optuna.TrialPruned:
            raise
        except (Exception, KeyboardInterrupt):
            print(traceback.format_exc())
        finally:
            progress_bar.close()
            wandb.finish()
        return model



In [None]:
def objective(trainer, trial, model_class, loss_fn, device, n_epochs=100, optim_class=optim.Adam):
    """Optuna objective: train a fresh model with a sampled learning rate and return its best score.

    Args:
        trainer: a ModelTrainer; ``trainer.train`` runs the epochs (reporting to
            ``trial`` for pruning) and ``trainer.best_metric`` holds the best score afterwards.
        trial: the Optuna trial that supplies the learning rate.
        model_class: zero-argument callable returning a new, untrained model
            (e.g. the ``M5`` class itself, not an instance).
        loss_fn: loss passed through to ``trainer.train``.
        device: torch device to train on.
        n_epochs: maximum number of epochs for this trial.
        optim_class: optimizer class; the learning rate is its only tuned argument.

    Returns:
        ``trainer.best_metric`` as a plain float.
    """
    model = model_class()
    lr = trial.suggest_float('lr', low=1e-5, high=1e-1, log=True)
    trainer.train(
        model=model,
        loss_fn=loss_fn,
        optim_class=optim_class,
        optim_args={'lr': lr},
        device=device,
        n_epochs=n_epochs,
        trial=trial,
    )
    return float(trainer.best_metric)

In [27]:
# Model selection (early stopping, checkpointing, Optuna) uses the validation
# split; the test split stays untouched for the final evaluation.
# Class names listed in label-index order ('no' -> 0, 'yes' -> 1).
trainer = ModelTrainer(train_loader, val_loader, classes=['no', 'yes'], metric_name='val_acc')
# M5 returns log-probabilities. CrossEntropyLoss re-applies log_softmax, which
# leaves log-probabilities unchanged, so this equals NLLLoss on M5's output.
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:

# Tune the learning rate: each trial trains a fresh M5 via objective() and
# returns its best validation accuracy, so the study maximises.
study = optuna.create_study(direction='maximize')
study.optimize(
    lambda trial: objective(trainer, trial, model_class=M5, loss_fn=loss_fn,
                            device=device, n_epochs=100),
    n_trials=20,
)
study.best_params
