In [None]:
import os
import numpy as np
import pandas as pd
from collections import namedtuple
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedGroupKFold
from ext.kaggle_kl_div.kaggle_kl_div import score as kaggle_kl_div_score
import wandb

import torch
from torch import nn
import torch.optim as optim

from dataloader import get_dataloaders, get_datasets
from utils import seed_everything
from trainer import Trainer
from model.model import SpecCNN
from train import load_data, train_model
from utils import Config

class CFG(Config):
    """Experiment configuration (data selection, augmentation, model and optimisation settings).

    ``in_channels`` is derived from ``data_type`` at class-definition time; an
    unsupported ``data_type`` raises ``ValueError`` immediately.
    """
    seed = 42
    cv_fold = 5
    base_model = 'efficientnet_b0'   # resnet18/34/50, efficientnet_b0/b1/b2/b3/b4, efficientnet_v2_s, convnext_tiny, swin_t
    batch_size = 32
    epochs = 3
    base_lr = 1e-3
    affine_degrees = 0
    affine_translate = None
    affine_scale = None
    dataloader_num_workers = 8
    scheduler_step_size = 2
    optimizer = 'AdamW'
    scheduler = 'StepLR'
    loss = 'KLDivLoss'
    lr_gamma = 0.1
    sgd_momentum = 0.9
    color_jitter_args = dict(p=0.0, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
    random_erasing_p = 0
    freeze_epochs = 0
    spec_trial_selection = 'all' # 'all', 'first', 'mean_offset'
    eeg_trial_selection = 'all' # 'all', 'first', 'mean_offset'
    spec_random_trial_num = 1
    eeg_random_trial_num = 2
    TARGETS = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
    coarse_dropout_args = dict(p=0.5, max_holes=8, max_height=128, max_width=128)
    data_type = 'eeg_tf'  # 'spec', 'eeg_tf', 'spec+eeg_tf'
    eeg_tf_fname = 'eeg_tf_data_globalnorm'

    # Number of model input channels for each supported data_type.
    if data_type == 'spec':
        in_channels = 1
    elif data_type == 'eeg_tf':
        in_channels = 1
    elif data_type == 'spec+eeg_tf':
        in_channels = 2
    else:
        # Fail at definition time rather than later with a missing `in_channels` attribute.
        raise ValueError(f"Unknown data_type: {data_type!r}")

# Run flags
tags = ['torch', 'cv', 'best_epoch']
notes = ''
plot_samples = True        # plot one batch of training samples in the next cell
train_final_model = False  # retrain on the full dataset after cross-validation
use_wandb = False
one_fold = True            # stop cross-validation after the first fold

# Wandb
if use_wandb:
    # Never hardcode the API key in the notebook: read it from the WANDB_API_KEY
    # environment variable (wandb also falls back to ~/.netrc when key is None).
    # NOTE(review): a key was previously committed here in plain text — revoke it.
    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    run = wandb.init(project='kaggle-hms', config=CFG.get_dict(), tags=tags, notes=notes)
else:
    # Lightweight stand-in exposing the only run attribute used below (`run.name`).
    WandbRun = namedtuple('WandbRun', 'name')
    run = WandbRun('debug')

# Label encoder/decoder
encode = {'seizure_vote': 0, 'lpd_vote': 1, 'gpd_vote': 2, 'lrda_vote': 3, 'grda_vote': 4, 'other_vote': 5}
decode = {v: k for k, v in encode.items()}

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Paths (the project root can be overridden with the KAGGLE_HMS_ROOT environment variable)
root = os.environ.get('KAGGLE_HMS_ROOT', '/media/latlab/MR/projects/kaggle-hms')
data_dir = os.path.join(root, 'data')
results_dir = os.path.join(root, 'results')
train_eeg_dir = os.path.join(data_dir, 'train_eegs')
train_spectrogram_dir = os.path.join(data_dir, 'train_spectrograms')

# Seed
seed_everything(CFG.seed)

# Load data
df, data = load_data(CFG)

print('Train shape:', df.shape)
display(df.head())

In [None]:
# Show training data: plot the first training batch, titled with each sample's target vector
if plot_samples:
    dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df, df_validation=df))
    n_cols = 6
    with torch.no_grad():
        for X, y in dataloaders['train']:
            n_rows = int(np.ceil(len(X) / n_cols))
            plt.figure(figsize=(np.ceil(len(X) / 2), 20))
            for i in range(len(X)):
                plt.subplot(n_rows, n_cols, i + 1)
                # X[i] is channel-first (the original permute(1, 2, 0) relies on that).
                # imshow only accepts 2-D, RGB or RGBA arrays, so lay the channels side
                # by side; a single channel gives a plain (H, W) image, and 2-channel
                # inputs (data_type='spec+eeg_tf') no longer crash.
                img_data = np.concatenate(list(X[i].cpu().numpy()), axis=1)
                # Fixed symmetric colour range because tensors contain negative values.
                plt.imshow(img_data, vmin=-1.5, vmax=1.5, cmap='RdBu_r')
                tars = '[' + ', '.join(f'{s:0.2f}' for s in y[i].cpu().numpy()) + ']'
                plt.title(tars, fontdict={'fontsize': 8})
            plt.show()
            break  # only the first batch is plotted

In [None]:
# Cross-validation: folds are grouped by patient_id and stratified by expert_consensus.
skf = StratifiedGroupKFold(n_splits=CFG.cv_fold, random_state=CFG.seed, shuffle=True)
metric_list = []  # trainer.best_metric per fold (logged to wandb as kl_div_cv<k>)
targets = []      # out-of-fold targets (trainer.test_y) per fold
oof_preds = []    # out-of-fold predictions (trainer.test_pred) per fold
for cv, (train_index, valid_index) in enumerate(skf.split(X=np.zeros(len(df)), y=df['expert_consensus'], groups=df['patient_id'])):
    print(f"Cross-validation fold {cv+1}/{CFG.cv_fold}")
    df_train = df.iloc[train_index]
    df_validation = df.iloc[valid_index]
    run_name = f'{run.name}-cv{cv+1}'
    # NOTE(review): the 'ubc-ocean-' prefix looks carried over from another project; it is
    # kept unchanged so existing checkpoints and any loaders that expect it still match.
    state_filename = os.path.join(results_dir, 'models', f'ubc-ocean-{run_name}.pt')
    # Only the first fold passes wandb_log=True to train_model.
    wandb_log = use_wandb and cv == 0
    trainer = train_model(CFG, data, df_train, df_validation, state_filename, wandb_log=wandb_log)
    metric_list.append(trainer.best_metric)
    if use_wandb:
        wandb.log({f'kl_div_cv{cv+1}': trainer.best_metric})

    # Get OOF predictions. This must happen before the one_fold early exit, otherwise
    # a single-fold run would leave `targets` and `oof_preds` empty.
    targets.append(trainer.test_y)
    oof_preds.append(trainer.test_pred)

    if one_fold:
        break

if use_wandb:
    wandb.log({'mean_kl_div': np.mean(metric_list)})
    wandb.finish()

In [None]:
# Final training on the full dataset: `df` is passed as both the training and the
# validation frame, with validation disabled and wandb logging off for train_model.
if train_final_model:
    state_filename = os.path.join(results_dir, 'models', 'ubc-ocean-{}.pt'.format(run.name))
    trainer = train_model(
        CFG, data, df, df, state_filename,
        wandb_log=False,
        validate=False,
    )
    if use_wandb:
        wandb.finish()

In [None]:
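# # Confusion matrix
# # NOTE(review): stale snippet — get_tiles_datasets, train_image_dir and model are not defined in this
# # notebook, and the hard-label CrossEntropyLoss / balanced-accuracy evaluation does not match
# # CFG.loss = 'KLDivLoss' used above. Update before re-enabling.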
# import warnings
# from sklearn.metrics import balanced_accuracy_score, confusion_matrix
# loss_fn = nn.CrossEntropyLoss(reduction='none')
# datasets = get_tiles_datasets(CFG, train_image_dir, df_train, df_validation[df_validation['is_tma']==True])
# dataloaders = get_dataloaders(CFG, datasets)
# y_list = []
# pred_list = []
# loss_list = []
# metric = 0
# with torch.no_grad():
#     for X, y in dataloaders['validation']:
#         X, y = X.to(device), y.to(device)
#         outputs = model(X)
#         _, preds = torch.max(outputs, 1)
#         loss = loss_fn(outputs, y)
#         y_list.append(y.cpu().numpy())
#         pred_list.append(preds.cpu().numpy())
#         loss_list.append(loss.cpu().numpy())
#         with warnings.catch_warnings():
#             warnings.simplefilter('ignore', category=UserWarning)
#             metric += balanced_accuracy_score(y.cpu().numpy(), preds.cpu().numpy())
# metric /= len(dataloaders['validation'])
# y_list = np.concatenate(y_list)
# pred_list = np.concatenate(pred_list)
# loss_list = np.concatenate(loss_list)

# from ext.pretty_confusion_matrix import pp_matrix
# cm = confusion_matrix(y_list, pred_list)
# df_cm = pd.DataFrame(cm, index=encode.keys(), columns=encode.keys())
# pp_matrix(df_cm, pred_val_axis='x', cmap='Oranges', figsize=(8, 8))

In [None]:
# # Top k losses
# k = 10
# topk_loss_idx = list(loss_list.argsort()[-k:])
# with torch.no_grad():
#     for b, (X, y) in enumerate(dataloaders['validation']):
#         for bi in range(len(X)):
#             i = b * CFG.batch_size + bi
#             if i not in topk_loss_idx:
#                 continue
#             plt.figure()
#             plt.imshow(X[bi].permute(1, 2, 0))
#             plt.title(f'loss: {loss_list[i]:.4f}, label: {decode[y_list[i]]}, pred: {decode[pred_list[i]]}')