In [1]:
import torch
from torch.utils.data import Dataset

import torchaudio
import torchaudio.transforms

import torchvision

import sys, os

from pprint import pprint

from tqdm.autonotebook import tqdm

import json

import numpy as np

import matplotlib.pylab as plt
import seaborn as sns

import librosa
import librosa.display

import pandas as pd

from pathlib import Path

import gc

MANUAL_SEED = 69

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from datetime import date
from datetime import datetime

import os.path
from os import path
  
import json

import time

import copy

from matplotlib import pyplot as plt
plt.rcParams['figure.dpi'] = 200
plt.rcParams['savefig.dpi'] = 200

from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

from sklearn.model_selection import KFold

import random

from pprint import pformat

import math
import pathlib

from torchaudio_augmentations import * 

import wandb

In [2]:
!jupyter nbextension enable --py widgetsnbextension

os.environ["WANDB_NOTEBOOK_NAME"] = "GeNNus_CNN_mel_spectrogram"

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)
print(
  torch.cuda.get_device_name(device) if torch.cuda.is_available() else "cpu"
)

cuda
NVIDIA GeForce RTX 2070


In [4]:
def make_dir_if_absent(dir_path):
  
  if not os.path.exists(dir_path):
    os.makedirs(dir_path)
    

In [5]:
class ComposeTransform:
    def __init__(self, transforms, p_boosting_factors, epoch_steps):
        self.transforms = transforms
        self.p_boosting_factors = p_boosting_factors 
        self.epoch_steps = epoch_steps

        self.step_has_been_called_once = False

    def step(self, epoch):

        if self.epoch_steps is not None and epoch in self.epoch_steps:

            if not self.step_has_been_called_once:
                
                self.pbar_step_desc = tqdm(
                    total=0, position=4, bar_format='{desc}'
                )

                self.pbar_step_desc.set_description_str(
                    f"ComposeTransform.p = {-1}"
                )

                self.step_has_been_called_once = True
            
            for t, i in zip(self.transforms, range(len(self.transforms))):
                if "RandomApply" in t.__class__.__name__:

                    try:
                        t.p = self.p_boosting_factors[epoch]

                        # if i == 0:
                        self.pbar_step_desc.set_description_str(
                            f"ComposeTransform.p = {t.p}"
                        )
                
                        # print(self.epoch_steps[epoch], t.p)
                    
                    except Exception as e:

                        # print("Caught the following exception, continuing to whatever comes next!")
                        # print(e)

                        pass

    def get_p_at_epoch(self, epoch):
        try:
            return self.p_boosting_factors[epoch]
        except:
            return -0.01

    def __call__(self, audio_data):
        for t in self.transforms:
            audio_data = t(audio_data)
        return audio_data
    
    def __repr__(self):
        repr_list = []

        for t in self.transforms:
            repr_list.append(t.__repr__())

        repr_list.append(
            {
                "p_boosting_factors": self.p_boosting_factors,
                "epoch_steps": self.epoch_steps
            }
        )

        return str(repr_list)

In [6]:
class IdentityTransform:
  def __init__(self):
    pass

  def __call__(self, audio_data):
    return audio_data

  def __repr__(self):
    
    return str( { "transform_name": "IdentityTransform" } )

In [7]:
class StandardizeTransform:
  def __init__(self, mean, std):
    self.mean = mean
    self.std = std

  def __call__(self, audio_data):

    return (audio_data - self.mean) / self.std

  def __repr__(self):
    
    return str(
      {
        "transform_name": "StandardizeTransform",
        "mean": self.mean,
        "std": self.std 
      }
    )

In [8]:
class FMADataset(Dataset):

  def __init__(
    self, path, data_transforms_train, data_transforms_eval, data_type
  ):
    self.path = path
    self.data_type = data_type,
    self.stage = None
    self.data_transforms_train = data_transforms_train
    self.data_transforms_eval = data_transforms_eval
    self.data_paths = self._load_audio_list()

  def __len__(self):
    return len(self.data_paths)

  def __getitem__(self, idx):

    data = torch.load(self.data_paths[idx])

    if self.stage == "train" and self.data_transforms_train is not None:
      data = self.data_transforms_train(data)

    if (
      self.stage == "val" or self.stage == "test"
    ) and self.data_transforms_eval is not None:
      data = self.data_transforms_eval(data)

    label_one_hot = self._label_from_str_to_one_hot(
      self.data_paths[idx].split("/")[-2]
    )

    return data, label_one_hot

  
  def _label_from_str_to_one_hot(self, label_str: str): 
  
    if label_str == "Pop":
      return torch.tensor([1, 0, 0, 0, 0, 0]).float()
    
    if label_str == "Hip-Hop":
      return torch.tensor([0, 1, 0, 0, 0, 0]).float()
    
    if label_str == "Electronic":
      return torch.tensor([0, 0, 1, 0, 0, 0]).float()
    
    if label_str == "Rock":
      return torch.tensor([0, 0, 0, 1, 0, 0]).float()

    if label_str == "Folk":
      return torch.tensor([0, 0, 0, 0, 1, 0]).float()

    if label_str == "Jazz":
      return torch.tensor([0, 0, 0, 0, 0, 1]).float()
    
  
  def _load_audio_list(self):
    
    audio_path_list = []
    
    for path, subdirs, files in os.walk(self.path):
      for name in files:
          
        file_audio_path = os.path.join(path, name)
        
        audio_path_list.append(file_audio_path)

    return sorted(audio_path_list, reverse=True)
        
        

In [9]:
DATASET_SIZE = "s"
DATASET_TYPE = "mel_spectrogram"
DATASET_FOLDER = f"./data/{DATASET_TYPE}"

DATASET_NUM_SAMPLES_PER_SECOND = 8000
DATASET_NUM_CHANNELS = 1
N_FFT = 1024
WIN_LENGTH = None
HOP_LENGTH = 128
N_MELS = 128

DATASET_BASE_NAME = f"fma_{DATASET_SIZE}_resampled_{DATASET_NUM_SAMPLES_PER_SECOND}_rechanneled_{DATASET_NUM_CHANNELS}"
DATASET_NAME = f"{DATASET_BASE_NAME}_n_fft_{N_FFT}_win_length_{WIN_LENGTH}_hop_length_{HOP_LENGTH}_n_mels_{N_MELS}"


dataset_path = f"{DATASET_FOLDER}/{DATASET_NAME}"



SUMMARY_STATISTICS_PATH = f"./data/summary_statistics/{DATASET_BASE_NAME}/{DATASET_BASE_NAME}_summary_statistics.json"

In [10]:
summary_statistics_json = open(SUMMARY_STATISTICS_PATH)

summary_statistics_dict = json.load(summary_statistics_json)

In [11]:
identity_transform = IdentityTransform()

standardize_transform = StandardizeTransform(
  mean=summary_statistics_dict[f"{DATASET_TYPE}_mean"],
  std=summary_statistics_dict[f"{DATASET_TYPE}_std"]
)

In [12]:
RANDOM_CROP_SIZE_REDUCTION = 0.9


In [13]:
fma_data_transforms_train = ComposeTransform(
  transforms=[ 
    standardize_transform
    # ,
    # torchvision.transforms.RandomCrop(
    #   (
    #     int(128 * RANDOM_CROP_SIZE_REDUCTION), 
    #     int(1860 * RANDOM_CROP_SIZE_REDUCTION)
    #   )
    # )
    # ,
    # torchaudio.transforms.FrequencyMasking(
    #   freq_mask_param=FREQ_MASK_PARAM, iid_masks=IID_MASKS
    # )
  ], 
  p_boosting_factors=None, 
  epoch_steps=None
)

fma_data_transforms_eval = ComposeTransform(
  transforms=[standardize_transform],
  p_boosting_factors=None, 
  epoch_steps=None
)

In [14]:
fma_dataset = FMADataset(
  path=dataset_path, 
  data_transforms_train=fma_data_transforms_train,
  data_transforms_eval=fma_data_transforms_eval,
  data_type=DATASET_TYPE
)

In [15]:
TRAIN_VAL_PERCENTAGE = 0.9

full_size = len(fma_dataset)
train_val_size = int(TRAIN_VAL_PERCENTAGE * full_size)
test_size = full_size - train_val_size

In [16]:
generator=torch.Generator().manual_seed(MANUAL_SEED)

fma_dataset_train_val, fma_dataset_test = torch.utils.data.random_split(
  fma_dataset, [train_val_size, test_size], generator
)

In [17]:
TRAIN_PERCENTAGE = 0.8

full_size = train_val_size
train_size = int(TRAIN_PERCENTAGE * full_size)
val_size = full_size - train_size

In [18]:
fma_dataset_train, fma_dataset_val = torch.utils.data.random_split(
  fma_dataset_train_val, [train_size, val_size], generator
)

In [19]:
BATCH_SIZE = 16
NUM_WORKERS = 16

data_logs = {
  "data_type": DATASET_TYPE,
  "dataset_size": DATASET_SIZE,
  "batch_size": BATCH_SIZE,
  "num_samples_per_second": DATASET_NUM_SAMPLES_PER_SECOND,
  "num_channels": DATASET_NUM_CHANNELS,
  "data_transforms_train": fma_data_transforms_train
}

In [20]:
def count_num_trainable_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [21]:
def gen_train_id():
  return datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

In [22]:
def save_dict_to_disk(dict, full_path):

  make_dir_if_absent("/".join(full_path.split("/")[:-1]))

  with open(full_path, 'w') as fp:
    json.dump(dict, fp)

def load_dict_from_disk(full_path):
  
  with open(full_path) as json_file:
    dict_from_disk = json.load(json_file)
  
  return dict_from_disk

In [23]:
def store_ckp(
  model, optimizer, ckp_dir, ckp_name, epoch, loss_train, loss_val, loss_test
):

  model_copy = copy.deepcopy(model)
  
  full_path_pickle = f"{ckp_dir}/{ckp_name}_epoch_{epoch}.pth"
  
  make_dir_if_absent(dir_path="/".join(full_path_pickle.split('/')[:-1]))
  
  torch.save(
    {
      'epoch': epoch,
      'model_state_dict': model_copy.cpu().state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss_train': loss_train,
      'loss_val': loss_val,
      'loss_test': loss_test,
    }, 
    full_path_pickle
  )
  
  torch.save(
    model_copy.cpu(), 
    full_path_pickle
  )

In [24]:
def load_ckp(ckp_path, perform_loading_sanity_check):

  loaded_model = torch.load(ckp_path)

  if perform_loading_sanity_check:

    loaded_model.eval()

    sanity_check_out = loaded_model(torch.rand((16, 1, 238000)))

  return loaded_model

In [25]:
def get_num_correct_preds(outputs, labels):
  
  output_pred_ind = torch.argmax(outputs, dim=1)
  labels_ind = torch.argmax(labels, dim=1)
  
  matching_mask = (output_pred_ind == labels_ind).float()
  
  num_correct_preds = matching_mask.sum()
  
  return num_correct_preds

In [26]:
def train_model(
  model, optimizer, criterion,
  batch_size, train_dl, val_dl, test_dl, 
  num_epochs, 
  lr_scheduler,
  device, 
  print_freq, ckp_freq, 
  ckp_dir, ckp_name,
  should_close_tqdm_prog_bars_when_done,
  limit_num_batches_train,
  limit_num_batches_val,
  limit_num_batches_test,
  weight_decay
):

  train_id = gen_train_id()
  
  training_logs = {
    "train_id": train_id,
    "accuracies": {},
    "losses": {}
  }

  model = model.to(device)

  wandb.watch(model)
  
  pbar_epochs = tqdm(range(num_epochs), colour="#9400d3", position=1)
  pbar_batches_train = tqdm(
    iter(train_dl), colour="#4169e1", leave=False, position=4
  )
  pbar_batches_val = tqdm(
    iter(val_dl), colour="#008080", leave=False, position=5
  )
  pbar_best_epoch_desc = tqdm(
    total=0, position=2, bar_format='{desc}', colour="green"
  )
  pbar_epoch_desc = tqdm(
    total=0, position=3, bar_format='{desc}', colour="#9400d3"
  )
  pbar_weight_decay_desc = tqdm(
    total=0, position=4, bar_format='{desc}', colour="#9400d3"
  )
  
  training_start_time = time.time()

  best_loss_val = np.Inf

  for epoch in range(num_epochs):

    running_loss_train = 0.0
    running_loss_val   = 0.0
    running_loss_test  = -1.0
    
    num_correct_preds_train = 0.0
    num_preds_train = 0.0
    accuracy_train = 0.0
    
    num_correct_preds_val = 0.0
    num_preds_val = 0.0
    accuracy_val = 0.0
    
    num_correct_preds_test = 0.0
    num_preds_test = 0.000000001
    accuracy_test = 0.0

    num_batches_train = 0
    num_batches_val = 0
    num_batches_test = 0

    for g in optimizer.param_groups:
      g['weight_decay'] = weight_decay[epoch]
    
    pbar_weight_decay_desc.set_description(f"Weight decay: {weight_decay[epoch]}")
        
    ## BEGIN training step
    
    model.train()

    train_dl.dataset.dataset.stage="train"
    train_dl.dataset.dataset.data_transforms_train.step(epoch=epoch)
    
    pbar_batches_train.reset()
    pbar_batches_val.reset()
    
    pbar_epochs.set_description(f"epoch {epoch}")
    pbar_batches_train.set_description(f"epoch {epoch}")
    pbar_batches_val.set_description  (f"epoch {epoch}")
    
    for batch_x, batch_y in iter(train_dl):

      if num_batches_train <= limit_num_batches_train:

        inputs, labels = batch_x, batch_y
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()

        outputs = model(inputs)
        outputs = outputs.squeeze(-1)
        
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()

        running_loss_train += loss.item() * batch_x.shape[0]
        
        num_correct_preds_train += get_num_correct_preds(outputs, labels)
        num_preds_train += outputs.shape[0]
      
      num_batches_train += 1
      
      pbar_batches_train.update(1)
      
    
    ## END training step
    
    ## BEGIN validation step
    
    with torch.no_grad():
      
      model.eval()

      train_dl.dataset.dataset.stage="val"
      
      for batch_x, batch_y in iter(val_dl):

        if num_batches_val <= limit_num_batches_val:

          inputs, labels = batch_x, batch_y
          inputs, labels = inputs.to(device), labels.to(device)
          
          outputs = model(inputs)
          outputs = outputs.squeeze(-1)
          
          loss = criterion(outputs, labels)
          
          running_loss_val += loss.item() * batch_x.shape[0]
          
          num_correct_preds_val += get_num_correct_preds(outputs, labels)
          num_preds_val += outputs.shape[0]
        
        num_batches_val += 1

        pbar_batches_val.update(1)
        
    ## END validation step

    if lr_scheduler is not None:
      if "ReduceLROnPlateau" in lr_scheduler.__class__.__name__:
        lr_scheduler.step(running_loss_train) 
      else:
        lr_scheduler.step()
    
    ## BEGIN test step
    
    if (epoch + 1 == num_epochs):
      
      pbar_batches_test = tqdm(
        iter(test_dl), colour="#808000", leave=False,
      )
      pbar_batches_test.set_description  (f"epoch {epoch}")
    
      with torch.no_grad():
        
        model.eval()

        train_dl.dataset.dataset.stage="test"
        
        for batch_x, batch_y in iter(test_dl):

          if num_batches_test <= limit_num_batches_test:

            inputs, labels = batch_x, batch_y
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            outputs = outputs.squeeze(-1)
            
            loss = criterion(outputs, labels)
            
            running_loss_test += loss.item() * batch_x.shape[0]
            
            num_correct_preds_test += get_num_correct_preds(outputs, labels)
            num_preds_test += outputs.shape[0]
          
          num_batches_test += 1
          
          pbar_batches_test.update(1)
        
    ## END test step
    
    accuracy_train = num_correct_preds_train / num_preds_train
    accuracy_val = num_correct_preds_val / num_preds_val
    accuracy_test = num_correct_preds_test / num_preds_test
    
    training_logs["accuracies"][str(epoch)] = {
      "accuracy_train": accuracy_train.cpu().item(),
      "accuracy_val": accuracy_val.cpu().item(),
    }
    training_logs["losses"][str(epoch)] = {
      "loss_train": running_loss_train,
      "loss_val": running_loss_val,
    }

    
    # if ((epoch + 1) % print_freq == 0):
    if running_loss_val < best_loss_val:  
      pbar_best_epoch_desc.set_description_str(
        f"[best] epoch: {(str(epoch + 1)).zfill(3)}, " + 
        f"train loss: {str(round(running_loss_train, 2)).zfill(6)}, train acc: {str(round(accuracy_train.cpu().item(), 2))}, " + 
        f"val loss  : {str(round(running_loss_val, 2)).zfill(6)} , val acc  : {str(round(accuracy_val.cpu().item(), 2))}, " + 
        f"val loss delta change: {round(best_loss_val - running_loss_val, 4)}"
      )

      best_loss_val = running_loss_val

    pbar_epoch_desc.set_description_str(
      f"[curr] epoch: {(str(epoch + 1)).zfill(3)}, " + 
      f"train loss: {str(round(running_loss_train, 2)).zfill(6)}, train acc: {str(round(accuracy_train.cpu().item(), 2))}, " + 
      f"val loss  : {str(round(running_loss_val, 2)).zfill(6)} , val acc  : {str(round(accuracy_val.cpu().item(), 2))}"
      
    )

    pbar_epochs.update(1)

    wandb.log(
      {
        "epoch": epoch, 
        "loss/train": round(running_loss_train, 2),
        "loss/val": round(running_loss_val, 2),
        "acc/train": round(accuracy_train.cpu().item(), 2),
        "acc/val": round(accuracy_val.cpu().item(), 2),
        "transform_p": train_dl.dataset.dataset.data_transforms_train.get_p_at_epoch(epoch)
      }
    )
    
    if ((epoch + 1) == num_epochs):
      tqdm.write(
        f"epoch: {(str(epoch + 1)).zfill(3)}\n" + 
        f"  train loss: {str(round(running_loss_train, 2)).zfill(6)}, train acc: {str(round(accuracy_train.cpu().item(), 2))}\n" + 
        f"  val loss  : {str(round(running_loss_val, 2)).zfill(6)}, val acc  : {str(round(accuracy_val.cpu().item(), 2))}\n" + 
        f"  test loss : {round(running_loss_test, 2)} , test acc: {round(accuracy_test.cpu().item(), 2)}\n"
      )
      
      training_logs["accuracies"][str(epoch)][
        "accuracy_test"
      ] = accuracy_test.cpu().item()
      
      training_logs["losses"][str(epoch)][
        "loss_test"
      ] = running_loss_test
      
    if (ckp_freq != None and (epoch + 1) % ckp_freq == 0):
      
      store_ckp(
        model=model, optimizer=optimizer, 
        ckp_dir=ckp_dir, ckp_name=ckp_name, epoch=epoch, 
        loss_train=running_loss_train, 
        loss_val=running_loss_val, 
        loss_test=running_loss_test
      )
  
  training_end_time = time.time()

  training_logs["training_time_secs"] = training_end_time - training_start_time

  if (should_close_tqdm_prog_bars_when_done):
    pbar_epochs.container.close()
    pbar_batches_train.close()
    pbar_batches_val.close()
    pbar_batches_test.close()

  wandb.finish()
  
  return training_logs

In [27]:
def plot_loss_curves(stats):
  epochs = stats["training_logs"]["losses"].keys()
  
  loss_train = [
    j["loss_train"] for j in stats["training_logs"]["losses"].values()
  ]
  
  loss_val = [j["loss_val"] for j in stats["training_logs"]["losses"].values()]

  sns.lineplot(
    x=epochs,
    y=loss_train,
    legend="full",
    label="train loss"
  )

  sns.lineplot(
    x=epochs,
    y=loss_val,
    legend="full",
    label="val loss"
  )

## ResNet

In [28]:
class Block(nn.Module):
    
    def __init__(
        self, in_channels, out_channels, dropout_conv_p, dropout_fc_p, 
        identity_downsample=None, stride=1
    ):
        super(Block, self).__init__()

        self.dropout_conv_p = dropout_conv_p
        self.dropout_fc_p = dropout_fc_p


        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout_conv1 = nn.Dropout2d(p=self.dropout_conv_p)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout_conv2 = nn.Dropout2d(p=self.dropout_conv_p)


        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

        self.dropout_fc = nn.Dropout2d(p=self.dropout_fc_p)
        
    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout_conv1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        x += identity
        x = self.relu(x)
        x = self.dropout_conv2(x)

        return x

class ResNet_18(nn.Module):
    
    def __init__(
        self, image_channels, num_classes, num_filter_divider, 
        dropout_conv_p, dropout_fc_p
    ):
        
        super(ResNet_18, self).__init__()
        
        self.dropout_conv_p = dropout_conv_p
        self.dropout_fc_p = dropout_fc_p
        
        self.in_channels = int(64 / num_filter_divider)
        self.conv1 = nn.Conv2d(image_channels, int(64 / num_filter_divider), kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(int(64 / num_filter_divider))
        self.dropout_conv1 = nn.Dropout2d(p=self.dropout_conv_p)
        
        self.relu = nn.ReLU()
        
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        #resnet layers
        self.layer1 = self.__make_layer(int(64 / num_filter_divider), int(64 / num_filter_divider), stride=1)
        self.layer2 = self.__make_layer(int(64 / num_filter_divider), int(128 / num_filter_divider), stride=2)
        self.layer3 = self.__make_layer(int(128 / num_filter_divider), int(256 / num_filter_divider), stride=2)
        self.layer4 = self.__make_layer(int(256 / num_filter_divider), int(512 / num_filter_divider), stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.dropout_fc = nn.Dropout2d(p=self.dropout_fc_p)

        self.fc = nn.Linear(int(512 / num_filter_divider), num_classes)
        
    def __make_layer(
        self, in_channels, out_channels, stride
    ):
        
        identity_downsample = None
        if stride != 1:
            identity_downsample = self.identity_downsample(in_channels, out_channels)
            
        return nn.Sequential(
            Block(
                in_channels, out_channels, 
                identity_downsample=identity_downsample, stride=stride,
                dropout_conv_p=self.dropout_conv_p, 
                dropout_fc_p=self.dropout_fc_p
            ), 
            Block(
                out_channels, out_channels, dropout_conv_p=self.dropout_conv_p, 
                dropout_fc_p=self.dropout_fc_p
            )
        )
        
    def forward(self, x):
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout_conv1(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x 
    
    def identity_downsample(self, in_channels, out_channels):
        
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1), 
            nn.BatchNorm2d(out_channels)
        )

    def get_model_setup(self):
        return "ResNet18"

In [29]:
def ResNet18(num_classes, channels, num_filter_divider, dropout_conv_p, dropout_fc_p):
    return ResNet_18(
        num_classes=num_classes, image_channels=channels, 
        num_filter_divider=num_filter_divider,
        dropout_conv_p=dropout_conv_p, dropout_fc_p=dropout_fc_p
    )

## K-fold cross validation for hyperparameter search

In [30]:
NUM_FILTERS_DIVIDER = 8

K_FOLD_CV_DROPOUT_P_CONV = 0.05
K_FOLD_CV_DROPOUT_P_LINEAR = 0.05

In [31]:
def resnet18_factory():
  return ResNet18(
    num_classes=6, channels=1, num_filter_divider=NUM_FILTERS_DIVIDER,
    dropout_conv_p=K_FOLD_CV_DROPOUT_P_CONV,
    dropout_fc_p=K_FOLD_CV_DROPOUT_P_LINEAR
  )

In [32]:
def optimizer_factory(optimizer_name, model, lr, momentum, weight_decay):

  if optimizer_name == "SGD":
    optimizer = optim.SGD(
      model.parameters(), 
      lr=lr, 
      momentum=momentum,
      nesterov=True,
      weight_decay=weight_decay
    )

    optimizer_config = {
    "lr": lr, 
    "momentum": momentum, 
    "weight_decay": weight_decay,
    "nesterov": True
  }  

  elif optimizer_name == "Adam":

    optimizer = optim.Adam(
      model.parameters(),
      lr=lr,
      weight_decay=weight_decay
    )

    optimizer_config = {
    "lr": lr, 
    "weight_decay": weight_decay
  }  
    

  return optimizer, optimizer_config


In [33]:
def step_lr_factory(optimizer, step_size, gamma, last_epoch, verbose):
  return torch.optim.lr_scheduler.StepLR(
    optimizer, step_size, gamma, last_epoch, verbose
  )

In [34]:
def reduce_lr_on_plateau_factory(
  optimizer, mode='min', factor=0.1, patience=3, threshold=0.0001, 
  threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False
):
  return torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode=mode, factor=factor, patience=patience, 
    threshold=threshold, threshold_mode=threshold_mode, cooldown=cooldown, 
    min_lr=min_lr, eps=eps, verbose=verbose
  )

In [35]:
K_FOLD_CV_NUM_FOLDS = 5
K_FOLD_CV_LIMIT_NUM_FOLDS = 1

K_FOLD_CV_BATCH_SIZE = 32
K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TRAIN = 0.3
K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_VAL = K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TRAIN * 0.2
K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TEST = K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TRAIN * 0.1
# K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TRAIN = 1
# K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_VAL = 1
# K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TEST = 1

K_FOLD_CV_NUM_EPOCHS = 360
# K_FOLD_CV_PRINT_FREQ = int(K_FOLD_CV_NUM_EPOCHS * 0.1)
K_FOLD_CV_PRINT_FREQ = 1

K_FOLD_CV_CKP_FREQ = int(K_FOLD_CV_NUM_EPOCHS * 0.1)

K_FOLD_CV_LOGS_FOLDER = f"./k_fold_cv/cnn/{DATASET_TYPE}"

K_FOLD_CV_CKP_FOLDER = K_FOLD_CV_LOGS_FOLDER

K_FOLD_CV_SHOULD_CLOSE_TQDM_PROG_BARS_WHEN_DONE=True

In [36]:
LR = 0.001
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-6
OPTIMIZER_NAME = "Adam"

# k_fold_cv_weight_decay = np.ones((K_FOLD_CV_NUM_EPOCHS)) * WEIGHT_DECAY
k_fold_cv_weight_decay = np.hstack(
  (
    np.linspace(1e-4, 1e-4, 120),
    np.linspace(1e-2, 1e-2, 240)
  )
)


In [37]:
# LR_SCHEDULER_TYPE = "step"
# LR_SCHEDULER_TYPE = "reduce_on_plateau"
LR_SCHEDULER_TYPE = None

In [38]:
LR_SCHEDULER_STEP_SIZE = 30
LR_SCHEDULER_GAMMA = 0.02
LR_SCHEDULER_LAST_EPOCH = -1
LR_SCHEDULER_VERBOSE = False

In [39]:
REDUCE_LR_ON_PLATEAU_FACTOR = 0.06

In [40]:
cv_models = [
  resnet18_factory() for _ in range(0, K_FOLD_CV_NUM_FOLDS)
]

cv_criterions = [nn.CrossEntropyLoss() for _ in range(0, K_FOLD_CV_NUM_FOLDS)]

cv_opts = [
  optimizer_factory(
    optimizer_name=OPTIMIZER_NAME,
    model=cv_models[i],
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
  ) for i in range(0, K_FOLD_CV_NUM_FOLDS)
]

cv_optimizers = [opt for opt, _ in cv_opts]
cv_optimizers_configs = [opt_conf for _, opt_conf in cv_opts]

cv_lr_schedulers = []
if LR_SCHEDULER_TYPE is not None:
  for i in range(0, len(cv_optimizers)):
    
    if LR_SCHEDULER_TYPE == "step":
      lr_scheduler = step_lr_factory(
        optimizer=cv_optimizers[i],
        step_size=LR_SCHEDULER_STEP_SIZE,
        gamma=LR_SCHEDULER_GAMMA,
        last_epoch=LR_SCHEDULER_LAST_EPOCH,
        verbose=LR_SCHEDULER_VERBOSE
      )
    
    elif LR_SCHEDULER_TYPE == "reduce_on_plateau":
      lr_scheduler = reduce_lr_on_plateau_factory(
        optimizer=cv_optimizers[i],
        factor=REDUCE_LR_ON_PLATEAU_FACTOR
      )
    
    cv_lr_schedulers.append(lr_scheduler)

cv_lr_schedulers_config = {
  "lr_scheduler_type": LR_SCHEDULER_TYPE,
  "lr_scheduler_step_size": LR_SCHEDULER_STEP_SIZE,
  "lr_scheduler_gamma": LR_SCHEDULER_GAMMA,
  "lr_scheduler_last_epoch": LR_SCHEDULER_LAST_EPOCH,
}

cv_train_dls = []
cv_val_dls = []
cv_test_dls = []

In [41]:
k_fold = KFold(n_splits=K_FOLD_CV_NUM_FOLDS, shuffle=True)

cv_train_idxs = []

for fold, (train_idxs, val_idxs) in enumerate(k_fold.split(fma_dataset_train_val)):

  train_subsampler = torch.utils.data.SubsetRandomSampler(train_idxs)
  val_subsampler = torch.utils.data.SubsetRandomSampler(val_idxs)

  cv_train_dls.append(
    torch.utils.data.DataLoader(
      fma_dataset_train_val, batch_size=K_FOLD_CV_BATCH_SIZE, sampler=train_subsampler
    )
  )
  cv_train_idxs.append(train_idxs)

  cv_val_dls.append(
    torch.utils.data.DataLoader(
      fma_dataset_train_val, batch_size=K_FOLD_CV_BATCH_SIZE, sampler=val_subsampler
    )
  )

  cv_test_dls.append(
    torch.utils.data.DataLoader(
      fma_dataset_test, batch_size=K_FOLD_CV_BATCH_SIZE
    )
  )

k_fold_cv_limit_num_batches_train = int(
  len(list(cv_train_dls[0])) * K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TRAIN
) + 1

k_fold_cv_limit_num_batches_val = int(
  len(list(cv_val_dls[0])) * K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_VAL
) + 1

k_fold_cv_limit_num_batches_test = int(
  len(list(cv_test_dls[0])) * K_FOLD_CV_LIMIT_NUM_BATCHES_PERCENTAGE_TEST
) + 1

cv_data_logs = {
  "data_type": DATASET_TYPE,
  "dataset_size": DATASET_SIZE,
  "batch_size": K_FOLD_CV_BATCH_SIZE,
  "num_samples_per_second": DATASET_NUM_SAMPLES_PER_SECOND,
  "num_channels": DATASET_NUM_CHANNELS,
  "train_transforms": fma_data_transforms_train.__repr__()
}

In [42]:
def perform_k_fold_cv(
  cv_id,

  cv_num_folds,
  cv_models, cv_optimizers, cv_criterions,
  batch_size, 
  cv_train_dls, cv_train_idxs, cv_val_dls, cv_test_dls, 
  cv_num_epochs, 
  cv_lr_schedulers,
  cv_device, 
  cv_print_freq, cv_ckp_freq, 
  cv_ckp_dir,
  cv_should_close_tqdm_prog_bars_when_done,
  cv_limit_num_folds,
  cv_limit_num_batches_train,
  cv_limit_num_batches_val,
  cv_limit_num_batches_test,
  cv_weight_decay

):

  cv_training_logs = {}

  pbar_folds = tqdm(range(cv_num_folds), colour="#b22222")

  for fold in pbar_folds:
    pbar_folds.set_description(f"fold {fold}")

    cv_ckp_fold_dir = f"{cv_ckp_dir}/fold_{fold}"

    if fold < cv_limit_num_folds:

      wandb_config ={
        "k_folds_cv_num_folds": K_FOLD_CV_NUM_FOLDS,
        "k_folds_cv_fold": fold,
        "data_logs": cv_data_logs,
        "optimizer_config": cv_optimizers_configs[0], # all the same, one is enough
        "model_setup": cv_models[fold].get_model_setup(), # all the same, one is enough
        "lr_scheduler_configr": cv_lr_schedulers_config,
        "num_filters_divider": NUM_FILTERS_DIVIDER,
        "dropout_p_conv": K_FOLD_CV_DROPOUT_P_CONV,
        "dropout_p_linear": K_FOLD_CV_DROPOUT_P_LINEAR,
        "k_fold_cv_weight_decay": cv_weight_decay
      }

      wandb.init(
        project="GeNNus_CNN_mel_spectrogram", entity="filetto-di-salmone",
        config=wandb_config
      )

      training_log = train_model(
        model=cv_models[fold], 
        optimizer=cv_optimizers[fold], criterion=cv_criterions[fold],
        batch_size=batch_size,
        train_dl=cv_train_dls[fold], val_dl=cv_val_dls[fold], test_dl=cv_test_dls[fold],
        num_epochs=cv_num_epochs, 
        lr_scheduler=cv_lr_schedulers[fold] if len(cv_lr_schedulers) > 0 else None,
        device=cv_device,
        print_freq=cv_print_freq, ckp_freq=cv_ckp_freq, 
        ckp_dir=cv_ckp_fold_dir, ckp_name=f"{cv_id}_fold_{fold}",
        should_close_tqdm_prog_bars_when_done=cv_should_close_tqdm_prog_bars_when_done,
        limit_num_batches_train=cv_limit_num_batches_train,
        limit_num_batches_val=cv_limit_num_batches_val,
        limit_num_batches_test=cv_limit_num_batches_test,
        weight_decay=cv_weight_decay,
      )

    cv_training_logs[str(fold)] = training_log

    pbar_folds.update(1)

  return cv_training_logs
  
  

In [43]:
def print_k_fold_cv_curves(
  k_fold_cv_stats, curve_name_singular, curve_name_plural, plot_title, 
  limit_to_n_folds=None
):
  
  num_folds = len(list(k_fold_cv_stats["training_logs"].keys()))
  num_folds_og = num_folds
  
  if limit_to_n_folds is not None and limit_to_n_folds <= num_folds:
    num_folds = limit_to_n_folds
  

  for fold in range(0, num_folds):

    fold_str = str(fold)

    losses_dict = k_fold_cv_stats["training_logs"][fold_str][curve_name_plural]

    epochs = losses_dict.keys()
    
    curve_train = [ j[f"{curve_name_singular}_train"] for j in losses_dict.values() ]

    curve_val = [ j[f"{curve_name_singular}_val"] for j in losses_dict.values() ]

    fig, ax = plt.subplots()

    sns.lineplot(
      x=epochs,
      y=curve_train,
      legend="full",
      label=f"train {curve_name_singular}", 
      ax=ax
    )

    sns.lineplot(
      x=epochs,
      y=curve_val,
      legend="full",
      label=f"val {curve_name_singular}", 
      ax=ax
    )

    fig.suptitle(plot_title)
    ax.set_title(f"{num_folds_og}-fold CV, fold {fold + 1}")
    
    epochs_as_list = list(epochs)
    PERCENTAGE_X_AXIS_TICKS = 0.2
    epoch_axis_display = epochs_as_list[
      ::int(len(epochs_as_list) * PERCENTAGE_X_AXIS_TICKS)
    ]
    ax.set_xticklabels(epoch_axis_display)
    ax.set_xticks(epoch_axis_display)

In [44]:
def print_avg_fold_cv_curves(
  k_fold_cv_stats, curve_name_singular, curve_name_plural, plot_title
):
  
  num_folds = len(list(k_fold_cv_stats["training_logs"].keys()))

  curves_train = []
  curves_val = []

  for fold in range(0, num_folds):

    fold_str = str(fold)

    losses_dict = k_fold_cv_stats["training_logs"][fold_str][curve_name_plural]

    epochs = losses_dict.keys()
    
    curves_train.append(
      [ j[f"{curve_name_singular}_train"] for j in losses_dict.values() ]
    )

    curves_val.append(
      [ j[f"{curve_name_singular}_val"] for j in losses_dict.values() ]
    )
  
  curves_train_np = np.asarray(curves_train)
  curves_val_np = np.asarray(curves_val)

  curve_train = np.average(curves_train_np, axis=0)
  curve_val = np.average(curves_val_np, axis=0)

  fig, ax = plt.subplots()

  sns.lineplot(
    x=epochs,
    y=curve_train,
    legend="full",
    label=f"train {curve_name_singular}", 
    ax=ax
  )

  sns.lineplot(
    x=epochs,
    y=curve_val,
    legend="full",
    label=f"val {curve_name_singular}", 
    ax=ax
  )

  fig.suptitle(plot_title)
  
  epochs_as_list = list(epochs)
  PERCENTAGE_X_AXIS_TICKS = 0.2
  epoch_axis_display = epochs_as_list[
    ::int(len(epochs_as_list) * PERCENTAGE_X_AXIS_TICKS)
  ]
  ax.set_xticklabels(epoch_axis_display)
  ax.set_xticks(epoch_axis_display)

  fig, ax =plt.subplots(2, 1)
  sns.lineplot(
    x=epochs,y=curve_train,legend="full",
    label=f"train {curve_name_singular}", ax=ax[0]
  )
  sns.lineplot(
    x=epochs,y=curve_val,legend="full",
    label=f"val {curve_name_singular}", ax=ax[1], color="orange"
  )

  ax[0].set_xticklabels(epoch_axis_display)
  ax[1].set_xticklabels(epoch_axis_display)
  ax[0].set_xticks(epoch_axis_display)
  ax[1].set_xticks(epoch_axis_display)
  fig.show()


In [45]:
K_FOLD_CV_RUN = True

K_FOLD_CV_PRINT_LOSS_CURVES_ACROSS_FOLDS = False
K_FOLD_CV_PRINT_ACC_CURVES_ACROSS_FOLDS  = False

K_FOLD_CV_PRINT_LOSS_CURVES_AVG_FOLDS = False
K_FOLD_CV_PRINT_ACC_CURVES_AVG_FOLDS  = False

K_FOLD_CV_PLOTS_LIMIT_TO_N_FOLDS = 1

In [46]:
if K_FOLD_CV_RUN:

  k_fold_cv_id = gen_train_id()

  print(
    f"Tot num trainable params: {count_num_trainable_parameters(cv_models[0])}"
  )

  print(
    f"k_fold_cv_limit_num_batches_train: {k_fold_cv_limit_num_batches_train}\n"
    f"k_fold_cv_limit_num_batches_val: {k_fold_cv_limit_num_batches_val}\n"
    f"k_fold_cv_limit_num_batches_test: {k_fold_cv_limit_num_batches_test}"
  )

  k_fold_cv_training_logs = perform_k_fold_cv(
    cv_id=k_fold_cv_id,
    cv_num_folds=K_FOLD_CV_NUM_FOLDS,
    cv_models=cv_models, cv_optimizers=cv_optimizers, cv_criterions=cv_criterions,
    batch_size=K_FOLD_CV_BATCH_SIZE, 
    cv_train_dls=cv_train_dls, cv_train_idxs=cv_train_idxs, 
    cv_val_dls=cv_val_dls, cv_test_dls=cv_test_dls,
    cv_num_epochs=K_FOLD_CV_NUM_EPOCHS, 
    cv_lr_schedulers=cv_lr_schedulers,
    cv_device=device, 
    cv_print_freq=K_FOLD_CV_PRINT_FREQ, cv_ckp_freq=K_FOLD_CV_CKP_FREQ, 
    cv_ckp_dir=f"{K_FOLD_CV_CKP_FOLDER}/{k_fold_cv_id}",
    cv_should_close_tqdm_prog_bars_when_done=K_FOLD_CV_SHOULD_CLOSE_TQDM_PROG_BARS_WHEN_DONE,
    cv_limit_num_folds=K_FOLD_CV_LIMIT_NUM_FOLDS,
    cv_limit_num_batches_train=k_fold_cv_limit_num_batches_train,
    cv_limit_num_batches_val=k_fold_cv_limit_num_batches_val,
    cv_limit_num_batches_test=k_fold_cv_limit_num_batches_test,
    cv_weight_decay=k_fold_cv_weight_decay
  )

  k_fold_cv_stats = {
    "k_fold_cv_id": k_fold_cv_id,
    "stats_type": "k_fold_cross_validation",
    "k_folds_cv_num_folds": K_FOLD_CV_NUM_FOLDS,
    "data_logs": cv_data_logs,
    "optimizer_config": cv_optimizers_configs[0], # all the same, one is enough
    "model_setup": cv_models[0].get_model_setup(), # all the same, one is enough
    "training_logs": k_fold_cv_training_logs,
    "lr_scheduler_configr": cv_lr_schedulers_config
  }

Tot num trainable params: 198422
k_fold_cv_limit_num_batches_train: 19
k_fold_cv_limit_num_batches_val: 1
k_fold_cv_limit_num_batches_test: 1


  0%|          | 0/5 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33mdansolombrino[0m ([33mfiletto-di-salmone[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/360 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]







  0%|          | 0/9 [00:00<?, ?it/s]

epoch: 360
  train loss: 252.87, train acc: 0.87
  val loss  : 114.46, val acc  : 0.58
  test loss : 121.06 , test acc: 0.48



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc/train,▁▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████
acc/val,▁▃▅▄▅▅▅▅▇▅▆▅▆▇▅▆▇▇▄▄▅▇▇▆▆▇▆▅▅▇▇▅▅▇▅▅▆▆█▇
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss/train,██▇▇▆▆▆▆▆▅▆▅▅▅▄▄▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁
loss/val,▅▄▃▃▃▂▃▃▁▃▂▂▃▁▃▃▂▁▃▄▄▃▂▅▃▄▄▄▆▃▅▄▆▄▇▆█▅▃▅
transform_p,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc/train,0.87
acc/val,0.58
epoch,359.0
loss/train,252.87
loss/val,114.46
transform_p,-0.01
