In [None]:
from pathlib import Path
import random
import os
from importlib import reload

import warnings
# Ignore PyTorch's KLDivLoss warning
warnings.simplefilter("ignore", category=UserWarning, lineno=2949)
warnings.simplefilter("ignore", category=UserWarning, lineno=2943)
warnings.simplefilter("ignore", category=UserWarning, lineno=1603)

import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import torch
from tqdm.auto import tqdm

import utils.config

from utils.config import BASE_PATH, SPEC_DIR, DEVICE
from utils.data_handling import metadata_df

# Class names, listed in integer-label order (index i <-> label i).
class_names = ['Seizure', 'LPD', 'GPD', 'LRDA','GRDA', 'Other']
label2name = {idx: name for idx, name in enumerate(class_names)}   # int label -> class name
name2label = {name: idx for idx, name in label2name.items()}       # class name -> int label

metadata = metadata_df("train")

from utils import SpectrogramDataset

# Hold out a fraction of *spectrograms* (not rows) for validation, so that all
# rows sharing a spectrogram_id end up on the same side of the split.
valid_frac = 0.1
spectrogram_ids = metadata.spectrogram_id.unique()
num_unique_spectrograms = spectrogram_ids.shape[0]
valid_num = round(valid_frac * num_unique_spectrograms)
print(f"{num_unique_spectrograms} unique spectrograms, using {valid_num} for validation set.")

# Seeded generator -> the same validation spectrograms on every run.
rng = np.random.default_rng(seed=4)
valid_set = rng.choice(spectrogram_ids, size=valid_num, replace=False)
is_valid_row = metadata.spectrogram_id.isin(valid_set)
metadata_train, metadata_valid = metadata[~is_valid_row], metadata[is_valid_row]
print(f"{len(metadata_train)} training items, {len(metadata_valid)} validation items.")

# Smaller datasets built from each split.
# NOTE(review): n_items presumably caps / subsamples the number of items — confirm in SpectrogramDataset.
train_small = SpectrogramDataset(metadata_train, preloaded=True, random_state=4, n_items=5000)
valid_small = SpectrogramDataset(metadata_valid, preloaded=True, random_state=4, n_items=1000)
print(f"{len(train_small)} training items, {len(valid_small)} validation items.")

import torchvision
from torch.utils.data import DataLoader

import utils.models
from utils.training import Trainer
from utils.models import Spectrogram_EfficientNet
# DEVICE is already imported from utils.config in the first import block.

# Inference transforms bundled with the ImageNet-pretrained EfficientNet-B0 weights.
# NOTE(review): auto_tfms is not used in the cells visible here — confirm it is still needed.
auto_tfms = torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1.transforms()

# Full datasets over the spectrogram-level train/valid split.
# NOTE(review): normalize_targets presumably turns vote counts into probabilities — confirm.
train_dset = SpectrogramDataset(metadata_train, preloaded=True, normalize_targets=True)
valid_dset = SpectrogramDataset(metadata_valid, preloaded=True, normalize_targets=True)

batch_size = 16
train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dset, batch_size=batch_size)  # no shuffling for evaluation

# Build the model exactly once, so only a single copy is ever resident on DEVICE.
model = Spectrogram_EfficientNet(frozen=False).to(DEVICE)
optimizer_class = torch.optim.SGD
# "batchmean" divides the summed KL divergence by the batch size.
# NOTE(review): KLDivLoss expects its input as log-probabilities — confirm that
# Spectrogram_EfficientNet returns log-softmax outputs.
loss_fn = torch.nn.KLDivLoss(reduction="batchmean")

In [None]:
# Split the training rows by how many expert votes each row received:
#   dataset A - few experts   (total_votes <  EXPERT_VOTE_THRESHOLD)
#   dataset B - many experts  (total_votes >= EXPERT_VOTE_THRESHOLD)
# The two conditions are complementary, so rows with exactly
# EXPERT_VOTE_THRESHOLD votes land in B instead of vanishing from both sets.
VOTE_COLUMNS = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
EXPERT_VOTE_THRESHOLD = 8

total_experts = metadata_train[VOTE_COLUMNS].sum(axis=1)
# metadata_train is a boolean-indexed slice of `metadata`. Item assignment on it
# triggers SettingWithCopyWarning; .assign returns a new frame and keeps this cell idempotent.
metadata_train = metadata_train.assign(total_votes=total_experts)

metadata_train_A = metadata_train[metadata_train.total_votes < EXPERT_VOTE_THRESHOLD]
metadata_train_B = metadata_train[metadata_train.total_votes >= EXPERT_VOTE_THRESHOLD]
assert len(metadata_train_A) + len(metadata_train_B) == len(metadata_train), \
    "A/B split must cover every training row"
print(metadata_train_A.shape)
print(metadata_train_B.shape)

In [None]:
# Dataset A (few-expert rows) is used first. This cell only runs a learning-rate
# range test on it; the actual training run happens in the next cell.

from torch_lr_finder import LRFinder
train_dset = SpectrogramDataset(metadata_train_A, preloaded=True, normalize_targets=True)
valid_dset = SpectrogramDataset(metadata_valid, preloaded=True, normalize_targets=True)

batch_size = 16
train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dset, batch_size=batch_size)

# Fresh model, so the range test and the following training run start from the initial weights.
model = Spectrogram_EfficientNet(frozen=False).to(DEVICE)
optimizer_class = torch.optim.SGD
loss_fn = torch.nn.KLDivLoss(reduction="batchmean")

# Sweep the learning rate from 1e-7 up to end_lr=10 over 300 iterations
# (torch_lr_finder steps exponentially by default).
optimizer = optimizer_class(model.parameters(), lr=1e-7)
# Use the configured DEVICE instead of hard-coding "cuda", so this cell runs
# wherever the rest of the notebook does.
lr_finder = LRFinder(model, optimizer, loss_fn, device=DEVICE)
lr_finder.range_test(train_loader, end_lr=10, num_iter=300)

fig, ax = plt.subplots()
lr_finder.plot(ax=ax)
ax.set_title("LR range test — dataset A (few experts)")
# reset() restores the model and optimizer to their state before the range test.
lr_finder.reset();

In [None]:
# NOTE(review): presumably chosen by reading the LR range-test plot above — record the reasoning here.
lr = 1e-1

# Derive the run name from the actual settings so it cannot drift out of sync
# with optimizer_class / batch_size. For SGD and 16 this is the same string as before.
# NOTE(review): this run trains on dataset A only, but the name does not say so.
# If Trainer uses model_name for checkpoints/logs, it may collide with a
# full-dataset run of the same name — consider adding a dataset suffix.
model_name = f"spectrogram_efficientnet_b0_{optimizer_class.__name__}_batch_size_{batch_size}_coarse"

trainer = Trainer(model, train_loader, valid_loader,
                  optimizer=optimizer_class,
                  criterion=loss_fn,
                  lr=lr,
                  writer="auto",
                  model_name=model_name)

# NOTE(review): the positional arguments (10_000, 500, 500) are interpreted by
# Trainer.train_eval_loop — give them named constants once their meaning is confirmed.
trainer.train_eval_loop(10_000, 500, 500, save_period=10_000)
trainer.plot_metrics()