In [None]:
import os
import sys
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F

sys.path.append('..')
from config import Config
from utils import seed_everything
from train import load_data
from dataloader import get_dataloaders, get_datasets
from model.model import SpecCNN
from ext.kaggle_kl_div.kaggle_kl_div import score as kaggle_kl_div_score


class CFG(Config):
    """Evaluation settings for this notebook; anything not set here is inherited from `Config`."""

    # Run identifier; combined with `project_name` to locate the checkpoint directory.
    model_name = 'red-chrysanthemum-258'
    # Backbone architecture; it has to match the architecture the checkpoints were trained with.
    # Options: resnet18/34/50, efficientnet_b0/b1/b2/b3/b4, efficientnet_v2_s, convnext_tiny, swin_t
    base_model = 'efficientnet_b0'
    batch_size = 32

    # Input representation and trial selection.
    data_type = 'eeg_tf'
    eeg_tf_data = 'eeg_tf_data_globalnorm'
    spec_trial_selection = 'all'
    eeg_trial_selection = 'all'

    # Empty coarse-dropout args (presumably no such augmentation at evaluation time);
    # no pretrained backbone weights, since trained weights are loaded from checkpoints.
    coarse_dropout_args = {}
    pretrained = False

    # Number of model input channels implied by data_type. For any other data_type the
    # attribute is not set here and falls back to Config's value, if it defines one.
    if data_type in ('spec', 'eeg_tf'):
        in_channels = 1
    elif data_type == 'spec+eeg_tf':
        in_channels = 2


# Locate this run's artifacts: checkpoints and splits.csv are read from model_dir,
# diagnostic figures are written to model_dir/diag.
full_model_name = f'{CFG.project_name}-{CFG.model_name}'
model_dir = os.path.join(CFG.models_dir, full_model_name)
diag_dir = os.path.join(model_dir, 'diag')
# Fail fast with a clear message instead of silently skipping diag_dir creation and
# then hitting a FileNotFoundError on splits.csv below.
assert os.path.isdir(model_dir), f'Model directory {model_dir} does not exist'
os.makedirs(diag_dir, exist_ok=True)

# Per-row fold/split assignment; later cells filter on its 'fold' and 'split' columns.
df = pd.read_csv(os.path.join(model_dir, 'splits.csv'))

# One best checkpoint per CV fold; folds are numbered from 1 (same as the OOF loop).
model_paths = []
for fold in range(1, CFG.cv_fold + 1):
    path = os.path.join(model_dir, f'{full_model_name}-cv{fold}_best.pt')
    assert os.path.exists(path), f'Model {path} does not exist'
    model_paths.append(path)

seed_everything(CFG.seed)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load_data returns a pair; only the second element is used (passed to get_datasets later).
_, data = load_data(CFG)

print(model_paths)
display(df)

### Get out-of-fold (OOF) predictions

In [None]:
# Collect out-of-fold predictions: each fold's model predicts on its own validation split.
y_all = []
pred_all = []
for fold in tqdm(range(1, CFG.cv_fold + 1), desc='folds'):
    # Build this fold's dataloaders from the saved splits.
    df_fold = df[df['fold'] == fold]
    df_train = df_fold[df_fold['split'] == 'train']
    df_validation = df_fold[df_fold['split'] == 'validation']
    dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df_train, df_validation=df_validation))

    # Load this fold's best checkpoint. map_location lets weights saved from a GPU
    # load on a CPU-only machine (and vice versa).
    model = SpecCNN(model_name=CFG.base_model, num_classes=len(CFG.TARGETS), in_channels=CFG.in_channels)
    model.load_state_dict(torch.load(model_paths[fold - 1], map_location=device))
    model.to(device)
    model.eval()

    # Inference: softmax over the class dimension turns logits into probabilities.
    with torch.no_grad():
        for X, y in dataloaders['validation']:
            pred = F.softmax(model(X.to(device)), dim=-1).cpu().numpy()
            y_all.append(y.numpy())
            pred_all.append(pred)

y_all = np.concatenate(y_all)
pred_all = np.concatenate(pred_all)
assert len(y_all) == len(pred_all), f'{len(y_all)} labels vs {len(pred_all)} predictions'

# Hard labels for the confusion matrix (argmax over the class axis).
y_label = np.argmax(y_all, axis=1)
pred_label = np.argmax(pred_all, axis=1)

### Sample data

In [None]:
# First batch of the most recently built validation loader (the last CV fold when the
# notebook is run top to bottom).
X, y = next(iter(dataloaders['validation']))
sample_data = X[0].squeeze()

fig, axes = plt.subplots(1, 2, figsize=(15, 10))
axes[0].imshow(sample_data)
axes[0].set_title('Sample input (auto color scale)')
# Symmetric diverging scale; the [-3, 3] clip presumably assumes standardized inputs — TODO confirm.
axes[1].imshow(sample_data, cmap='RdBu_r', vmin=-3, vmax=3)
axes[1].set_title('Sample input (RdBu_r, clipped to [-3, 3])')
plt.show()

### Metric: Kaggle KL divergence

In [None]:
# kaggle_kl_div_score needs a row-id column in both frames (named via row_id_column_name);
# use positional ids so labels and predictions line up row by row.
y_all_df = pd.DataFrame(y_all).assign(id=np.arange(len(y_all)))
pred_all_df = pd.DataFrame(pred_all).assign(id=np.arange(len(pred_all)))

metric = kaggle_kl_div_score(
    solution=y_all_df,
    submission=pred_all_df,
    row_id_column_name='id',
)
print(f'Kaggle KL Divergence: {metric:.6f}')

### Confusion matrix

In [None]:
from unittest import mock

from sklearn.metrics import confusion_matrix
from ext.pretty_confusion_matrix import pp_matrix

# Pin the label set so the matrix is always len(TARGETS) x len(TARGETS), even when some
# class never occurs in y_label or pred_label (otherwise the DataFrame below fails).
cm = confusion_matrix(y_label, pred_label, labels=np.arange(len(CFG.TARGETS)))
df_cm = pd.DataFrame(cm, index=CFG.TARGETS, columns=CFG.TARGETS)

# NOTE(review): upstream pretty-confusion-matrix's pp_matrix ends with plt.show(), which
# closes the figure under the inline backend, so a savefig afterwards writes a blank PNG.
# Suppress that show, save the figure, then display it. Harmless if the vendored copy
# does not call plt.show().
with mock.patch.object(plt, 'show'):
    pp_matrix(df_cm, pred_val_axis='x', cmap='rocket_r', figsize=(8, 8))
fig = plt.gcf()
fig.savefig(os.path.join(diag_dir, 'confusion_matrix.png'))
plt.show()

### GradCAM

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

fold = 1  # which CV fold's model and validation split to explain

# Build this fold's dataloaders from the saved splits.
df_fold = df[df['fold'] == fold]
df_train = df_fold[df_fold['split'] == 'train']
df_validation = df_fold[df_fold['split'] == 'validation']
dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df_train, df_validation=df_validation))

# Load this fold's best checkpoint; map_location lets GPU-saved weights load on CPU.
model = SpecCNN(model_name=CFG.base_model, num_classes=len(CFG.TARGETS), in_channels=CFG.in_channels)
model.load_state_dict(torch.load(model_paths[fold - 1], map_location=device))
model.to(device)
model.eval();

# Layer whose activations/gradients GradCAM uses. `conv_head` presumably exists only on the
# efficientnet_* backbones — TODO confirm/adjust for other CFG.base_model choices.
target_layers = [model.model.conv_head]

In [None]:
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Index into CFG.TARGETS of the class to explain (the original one-hot
# [1, 0, 0, 0, 0, 0] selected class 0).
cam_target_class = 0

cam = GradCAM(model=model, target_layers=target_layers)

# First validation batch of the fold loaded above, plus class probabilities for it.
X, y = next(iter(dataloaders['validation']))
with torch.no_grad():
    pred = F.softmax(model(X.to(device)), dim=-1).cpu().numpy()

# pytorch-grad-cam calls each target on the model output, so targets must be callables
# such as ClassifierOutputTarget (a plain one-hot list is not callable). Explain only the
# first sample; the input is moved to the model's device explicitly because older
# pytorch-grad-cam versions do not do that themselves.
grayscale_cam = cam(input_tensor=X[:1].to(device), targets=[ClassifierOutputTarget(cam_target_class)])
grayscale_cam = grayscale_cam[0, :]

# show_cam_on_image needs an HxWx3 float32 image in [0, 1]. Use the same sample that was
# explained (not `sample_data`, which was drawn from a different loader), min-max scale
# it and replicate it across the three colour channels.
cam_image = X[0].squeeze().cpu().numpy().astype(np.float32)
if cam_image.ndim == 3:
    # assumes channel-first multi-channel input (e.g. 'spec+eeg_tf'); show channel 0 — TODO confirm
    cam_image = cam_image[0]
cam_image = (cam_image - cam_image.min()) / (cam_image.max() - cam_image.min() + 1e-8)
cam_image = np.repeat(cam_image[..., None], 3, axis=-1)
visualization = show_cam_on_image(cam_image, grayscale_cam, use_rgb=True)

fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(visualization)
ax.set_title(f'GradCAM for {CFG.TARGETS[cam_target_class]} '
             f'(predicted p={pred[0, cam_target_class]:.3f})')
ax.axis('off')
plt.show()