In [1]:
import os
import sys
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import seaborn as sns

import torch
from torch import nn
import torch.nn.functional as F

sys.path.append('..')
from config import Config
from utils import seed_everything
from train import load_data
from dataloader import get_dataloaders, get_datasets
from model.model import SpecCNN
from ext.kaggle_kl_div.kaggle_kl_div import score as kaggle_kl_div_score


class CFG(Config):
    """Notebook-local configuration layered on top of the project ``Config``.

    ``model_name`` selects which trained run is analysed: it is combined with
    ``CFG.project_name`` below to locate the run directory, its ``splits.csv``
    and the per-fold checkpoints. The remaining attributes are passed, via
    ``CFG``, to ``load_data``, ``get_datasets``/``get_dataloaders`` and
    ``SpecCNN`` further down in this notebook.
    """
    # Run identifier; used to build the model directory and checkpoint names.
    model_name = 'decent-galaxy-454'
    # Backbone name handed to SpecCNN(model_name=...).
    base_model = 'efficientnet_b0'   # resnet18/34/50, efficientnet_b0/b1/b2/b3/b4, efficientnet_v2_s, convnext_tiny, swin_t
    batch_size = 16
    # Training hyper-parameters. This notebook only runs inference, so these are
    # presumably kept to mirror the training configuration -- TODO confirm.
    epochs = 3
    base_lr = 1e-3
    scheduler_step_size = 2
    optimizer = 'Adan'
    scheduler = 'StepLR'
    loss = 'KLDivLoss'
    lr_gamma = 0.1
    sgd_momentum = 0.9
    random_erasing_p = 0
    freeze_epochs = 0
    spec_random_trial_num = 1
    eeg_random_trial_num = 1
    # Selects which branch of the if/elif chain at the end of this class runs.
    data_type = 'eeg_tf'  # 'spec', 'eeg_tf', 'spec+eeg_tf'
    eeg_tf_data = 'eeg_tf_data_globalnorm'

    # Augmentation
    # The three masking/erasing augmentations below have p=0.0.
    random_ch_erease_args = dict(p=0.0, eeg_ch_num=4, drop_ch_num=1)
    random_time_masking_args = dict(p=0.0, width_prop=0.1, erase_num=2)
    random_frequency_masking_args = dict(p=0.0, eeg_ch_num=4, bandwidth_prop=0.1, erase_num=1)
    use_mixup = False
    mixup_alpha = 2.0
    # NOTE(review): coarse dropout and time cropping have p=0.5 here; whether the
    # validation datasets apply them is decided inside get_datasets -- confirm.
    coarse_dropout_args = dict(p=0.5, max_holes=8, max_height=128, max_width=128)
    time_crop_p = 0.5
    time_crop_args = dict(max_trim=150)

    # Evaluated once, at class-definition time: derive the input channel count
    # and the trial-selection modes from data_type.
    if data_type == 'spec':
        in_channels = 1
        spec_trial_selection = 'first'
        eeg_trial_selection = 'all'
    elif data_type == 'eeg_tf':
        in_channels = 1
        spec_trial_selection = 'all'
        eeg_trial_selection = 'first'
    elif data_type == 'spec+eeg_tf':
        # NOTE(review): this branch does not set in_channels, yet SpecCNN is
        # built with CFG.in_channels below; it must then be inherited from
        # Config -- confirm.
        spec_trial_selection = 'all'
        eeg_trial_selection = 'first'


# Locate the artifacts of the run under analysis.
full_model_name = f'{CFG.project_name}-{CFG.model_name}'
model_dir = os.path.join(CFG.models_dir, full_model_name)
diag_dir = os.path.join(model_dir, 'diag')
# The diagnostics folder is only created when the run directory already exists.
if os.path.exists(model_dir):
    os.makedirs(diag_dir, exist_ok=True)

# Fold / split assignment stored next to the checkpoints.
splits_csv = os.path.join(model_dir, 'splits.csv')
df = pd.read_csv(splits_csv)

# One checkpoint per CV fold (file names are 1-based: cv1 ... cvN).
model_paths = [
    os.path.join(model_dir, f'{full_model_name}-cv{fold_no}_best.pt')
    for fold_no in range(1, CFG.cv_fold + 1)
]
for checkpoint in model_paths:
    assert os.path.exists(checkpoint), f'Model {checkpoint} does not exist'

seed_everything(CFG.seed)
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# load_data returns a pair; only its second element is used in this notebook.
_, data = load_data(CFG)

print(model_paths)
display(df)

['/media/latlab/MR/projects/kaggle-hms/results/models/hms-decent-galaxy-454/hms-decent-galaxy-454-cv1_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-decent-galaxy-454/hms-decent-galaxy-454-cv2_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-decent-galaxy-454/hms-decent-galaxy-454-cv3_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-decent-galaxy-454/hms-decent-galaxy-454-cv4_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-decent-galaxy-454/hms-decent-galaxy-454-cv5_best.pt']


Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,rater_num,split,fold
0,2277392603,0,0.0,924234,0,0.0,1978807404,30539,GPD,0.0,0.0000,0.454545,0.000,0.090909,0.454545,11,train,1
1,722738444,0,0.0,999431,0,0.0,557980729,56885,LRDA,0.0,0.0625,0.000000,0.875,0.000000,0.062500,16,train,1
2,387987538,0,0.0,1084844,0,0.0,4099147263,4264,LRDA,0.0,0.0000,0.000000,1.000,0.000000,0.000000,3,train,1
3,2175806584,0,0.0,1219001,0,0.0,1963161945,23435,Seizure,1.0,0.0000,0.000000,0.000,0.000000,0.000000,3,train,1
4,1626798710,0,0.0,1219001,2,74.0,3631726128,23435,Seizure,0.6,0.0000,0.400000,0.000,0.000000,0.000000,5,train,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85440,1502386617,0,0.0,2141520273,0,0.0,299113150,6770,Seizure,1.0,0.0000,0.000000,0.000,0.000000,0.000000,3,validation,5
85441,3643661974,0,0.0,2143211626,0,0.0,3196742703,5512,GRDA,0.0,0.0000,0.000000,0.000,0.583333,0.416667,12,validation,5
85442,1437016107,0,0.0,2145805074,0,0.0,1221972194,20588,Other,0.0,0.0000,0.000000,0.000,0.000000,1.000000,2,validation,5
85443,3813274317,0,0.0,2146166212,0,0.0,1079072938,3838,LRDA,0.0,0.0000,0.000000,0.500,0.000000,0.500000,4,validation,5


### Get OOF predictions

In [None]:
# Out-of-fold inference: for every CV fold, load that fold's checkpoint and
# predict on the fold's validation rows. Results are gathered per batch and
# concatenated once after the loop.
y_batches = []
pred_batches = []
validation_frames = []
for fold in tqdm(range(1, CFG.cv_fold+1)):
    # Get data
    df_fold = df[df['fold']==fold]
    df_train = df_fold[df_fold['split']=='train']
    df_validation = df_fold[df_fold['split']=='validation']
    validation_frames.append(df_validation)
    dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df_train, df_validation=df_validation))

    # Load model. map_location lets checkpoints saved on a GPU be restored on
    # whichever device this notebook is running on (including CPU-only hosts).
    model = SpecCNN(model_name=CFG.base_model, num_classes=len(CFG.TARGETS), in_channels=CFG.in_channels).to(device)
    model.load_state_dict(torch.load(model_paths[fold-1], map_location=device))
    model.eval()

    # Inference: softmax turns the raw model outputs into class probabilities.
    with torch.no_grad():
        for X, y in dataloaders['validation']:
            pred = F.softmax(model(X.to(device)), dim=-1).cpu().numpy()
            y_batches.append(y.numpy())
            pred_batches.append(pred)

# Validation rows of all folds, in the same fold order as the predictions.
# NOTE(review): later cells attach per-sample values to df_test column-wise, which
# assumes the validation loader yields exactly one sample per row, in row order --
# confirm against get_datasets/get_dataloaders.
df_test = pd.concat(validation_frames)
y_all = np.concatenate(y_batches)
pred_all = np.concatenate(pred_batches)

# Hard labels: index of the largest target / predicted probability per sample.
y_label = np.argmax(y_all, axis=1)
pred_label = np.argmax(pred_all, axis=1)

In [None]:
# Squared prediction error averaged over all samples, one bar per target column.
mse_per_target = np.square(pred_all - y_all).mean(axis=0)
plt.bar(CFG.TARGETS, mse_per_target)
plt.title('MSE');

In [None]:
# Per-sample KL divergence KL(y || pred).
#
# nn.KLDivLoss expects its *input* as log-probabilities and its target as
# probabilities. pred_all holds softmax probabilities, so the log is taken here;
# feeding raw probabilities yields values that are not KL divergences (and can be
# negative). reduction='none' followed by a sum over the class axis gives the
# divergence of each sample, instead of an average over the class entries.
KL_EPS = 1e-15  # lower clamp on predictions so log() never sees an exact zero
loss_fn = nn.KLDivLoss(reduction='none')

pred_tensor = torch.as_tensor(pred_all, dtype=torch.float64)
target_tensor = torch.as_tensor(y_all, dtype=torch.float64)
log_pred = torch.log(pred_tensor.clamp_min(KL_EPS))

# One vectorised call over all samples; result has one value per row.
loss_all = loss_fn(log_pred, target_tensor).sum(dim=1).tolist()

# Attach loss, targets and predictions to the validation rows. Column assignment
# raises if the number of samples differs from the number of rows in df_test.
df_test['loss'] = loss_all
df_test['y'] = y_all.tolist()
df_test['pred'] = pred_all.tolist()

In [None]:
def _round_all(values, ndigits=2):
    """Return a list with every element of ``values`` rounded to ``ndigits`` decimals."""
    return [round(v, ndigits) for v in values]


# Worst predictions first, with a fresh 0..n-1 index.
df_test_loss_sorted = df_test.sort_values('loss', ascending=False).reset_index(drop=True)
# Shorten the per-class vectors so they stay readable when the frame is displayed.
for vector_col in ('pred', 'y'):
    df_test_loss_sorted[vector_col] = df_test_loss_sorted[vector_col].apply(_round_all)

In [None]:
# Distribution of the per-sample loss over all validation rows (100 bins).
plt.hist(df_test_loss_sorted['loss'], bins=100);

### Rater numbers

In [None]:
# Scatter plot with a fitted regression line: does the per-sample loss depend on
# the number of raters?
sns.lmplot(data=df_test, x='rater_num', y='loss')

# Disabled experiment: per-consensus breakdown of the same relationship.
# # sns.lmplot(df_test, x='rater_num', y='loss', row='expert_consensus')

# for cons in df_test['expert_consensus'].unique():
#     df_test_cons = df_test[df_test['expert_consensus']==cons]
#     plt.figure(figsize=(10, 10))
#     sns.jointplot(df_test_cons, x='rater_num', y='loss', kind='hist')
#     plt.suptitle(f'{cons}')
#     plt.ylim(-0.31, 0.01)
#     plt.xlim(0, 28)

In [None]:
# Show the inputs of the first batch(es) of the loss-sorted validation frame,
# each titled with its target vector.
# NOTE(review): this shows the highest-loss samples only if the validation loader
# keeps the row order of df_test_loss_sorted -- confirm in get_dataloaders.
# df_train is reused from the most recent per-fold cell; it is only passed through
# to get_datasets here.
MAX_BATCHES = 1  # number of batches (= figures) to draw
GRID_COLS = 4    # images per row in each figure

dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df_train, df_validation=df_test_loss_sorted))
with torch.no_grad():
    for b, (X, y) in enumerate(dataloaders['validation']):
        plt.figure(figsize=(15, 15))
        grid_rows = int(np.ceil(len(X) / GRID_COLS))
        for i in range(len(X)):
            plt.subplot(grid_rows, GRID_COLS, i+1)
            # Move the leading axis of the sample to the end for imshow.
            img_data = X[i].permute(1, 2, 0).cpu().numpy()
            # Fixed symmetric colour range, since the tensors contain negative values.
            plt.imshow(img_data, vmin=-3, vmax=3, cmap='RdBu_r')
            t = y[i].cpu().numpy()
            tars = '[' + ', '.join(f'{s:0.2f}' for s in t) + ']'
            plt.title(tars, fontdict={'fontsize': 8})
        # Stop before fetching another batch once enough figures were drawn.
        if b + 1 >= MAX_BATCHES:
            break

### Sample data

In [None]:
# Take the first batch of the current validation loader and stop.
with torch.no_grad():
    for X, y in dataloaders['validation']:
        break

# Min-max scale the whole batch in place: shift the minimum to 0, then divide by
# the maximum of the shifted batch.
X.sub_(X.min())
X.div_(X.max())

# First sample of the batch with singleton dimensions removed.
sample_data = X[0].squeeze()

# Same image twice: default colormap (left) and diverging colormap on [0, 1] (right).
fig, (ax_default, ax_diverging) = plt.subplots(1, 2, figsize=(15, 10))
ax_default.imshow(sample_data)
ax_diverging.imshow(sample_data, cmap='RdBu_r', vmin=0, vmax=1)

### Metric

In [None]:
def _with_row_id(values):
    """Wrap a 2-D array in a DataFrame and append an ``id`` column numbered 0..n-1."""
    frame = pd.DataFrame(values)
    frame['id'] = np.arange(len(frame))
    return frame


# The scorer takes a solution frame and a submission frame that share a row-id column.
y_all_df = _with_row_id(y_all)
pred_all_df = _with_row_id(pred_all)

metric = kaggle_kl_div_score(submission=pred_all_df, solution=y_all_df, row_id_column_name='id')
print(f'Kaggle KL Divergence: {metric:.6f}')

### Confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix
from ext.pretty_confusion_matrix import pp_matrix

# Confusion matrix of the arg-max labels (true labels as rows, predictions as
# columns). Passing `labels` pins the matrix to one row/column per entry of
# CFG.TARGETS; without it scikit-learn only keeps classes that actually occur in
# y_label or pred_label, and building the labelled DataFrame fails whenever a
# class is absent from both.
class_ids = np.arange(len(CFG.TARGETS))
cm = confusion_matrix(y_label, pred_label, labels=class_ids)
df_cm = pd.DataFrame(cm, index=CFG.TARGETS, columns=CFG.TARGETS)
pp_matrix(df_cm, pred_val_axis='x', cmap='rocket_r', figsize=(8, 8))
# Store the figure with the other diagnostics of this run.
plt.savefig(os.path.join(diag_dir, 'confusion_matrix.png'))

### GradCAM

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Grad-CAM maps for one validation sample of one fold, one figure per target class.
fold = 1
sample_idx = 0  # position of the visualised sample within the fold's validation inputs

# Get data
df_fold = df[df['fold']==fold]
df_train = df_fold[df_fold['split']=='train']
df_validation = df_fold[df_fold['split']=='validation']
dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df_train, df_validation=df_validation))

# Load model. map_location lets GPU-saved checkpoints load on the current device.
model = SpecCNN(model_name=CFG.base_model, num_classes=len(CFG.TARGETS), in_channels=CFG.in_channels).to(device)
model.load_state_dict(torch.load(model_paths[fold-1], map_location=device))
model.eval();

# NOTE(review): `conf_head`-style attribute access is backbone specific; `conv_head`
# is assumed to exist on the wrapped model for the configured base_model -- confirm
# when switching backbones.
target_layers = [model.model.conv_head]

cam = GradCAM(model=model, target_layers=target_layers)

# Gather the validation inputs of this fold into a single tensor.
input_batches = []
with torch.no_grad():
    for X, y in dataloaders['validation']:
        input_batches.append(X)
all_X = torch.cat(input_batches)

# The CAM and the background image must come from the SAME input, so both are
# derived from this one-sample batch.
cam_input = all_X[sample_idx:sample_idx+1].to(device)

# Background image for the overlay: average over the sample's leading axis
# (assumed to be the channel axis -- TODO confirm), min-max scale to [0, 1] and
# replicate to three channels, which is the float RGB image show_cam_on_image takes.
ch = cam_input[0].mean(dim=0).cpu().numpy().astype(np.float32)
ch = ch - ch.min()
peak = ch.max()
if peak > 0:
    ch = ch / peak
sample_image = np.stack((ch, ch, ch), axis=-1)

for i in range(len(CFG.TARGETS)):
    # cam(...) returns one 2-D map per input sample; take the map of our single sample.
    grayscale_cam = cam(input_tensor=cam_input, targets=[ClassifierOutputTarget(i)])[0]

    visualization = show_cam_on_image(sample_image, grayscale_cam, use_rgb=True)
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(grayscale_cam)
    plt.subplot(1, 2, 2)
    plt.imshow(visualization)
    plt.title(CFG.TARGETS[i])