In [1]:
import os
import sys
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np

import torch
import torch.nn.functional as F

sys.path.append('..')
from config import Config
from utils import seed_everything
from train import load_data
from dataloader import get_dataloaders, get_datasets
from model.model import SpecCNN


class CFG(Config):
    """Notebook-local overrides of the project ``Config`` for inspecting run 'lunar-pig-237'."""

    # Run identifier; combined with Config.project_name to locate the saved models.
    model_name = 'lunar-pig-237'
    # Backbone options noted by the author:
    #   resnet18/34/50, efficientnet_b0/b1/b2/b3/b4, efficientnet_v2_s, convnext_tiny, swin_t
    base_model = 'efficientnet_b0'
    batch_size = 32
    data_type = 'eeg_tf'
    eeg_tf_data = 'eeg_tf_data_globalnorm'
    spec_trial_selection = 'all'
    eeg_trial_selection = 'all'
    coarse_dropout_args = {}
    pretrained = False

    # Channel count follows the input type: either single source is one
    # channel, the combined 'spec+eeg_tf' type is two. Any other data_type
    # leaves in_channels untouched here (whatever Config provides, if anything).
    if data_type in ('spec', 'eeg_tf'):
        in_channels = 1
    elif data_type == 'spec+eeg_tf':
        in_channels = 2


# All artifacts of this run live in <models_dir>/<project_name>-<model_name>/
full_model_name = f'{CFG.project_name}-{CFG.model_name}'
model_dir = os.path.join(CFG.models_dir, full_model_name)

# Load splits (the per-row split/fold assignment stored next to the checkpoints)
splits_csv = os.path.join(model_dir, 'splits.csv')
df = pd.read_csv(splits_csv)

# Load models: one "best" checkpoint per CV fold, files numbered from 1.
# The assert stops the cell at the first checkpoint that is missing on disk.
model_paths = []
for fold in range(CFG.cv_fold):
    path = os.path.join(model_dir, f'{full_model_name}-cv{fold+1}_best.pt')
    assert os.path.exists(path), f'Model {path} does not exist'
    model_paths += [path]

seed_everything(CFG.seed)

# Use the first GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Only the second value returned by load_data is needed in this notebook.
_, data = load_data(CFG)

print(model_paths)
display(df)

['/media/latlab/MR/projects/kaggle-hms/results/models/hms-lunar-pig-237/hms-lunar-pig-237-cv1_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-lunar-pig-237/hms-lunar-pig-237-cv2_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-lunar-pig-237/hms-lunar-pig-237-cv3_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-lunar-pig-237/hms-lunar-pig-237-cv4_best.pt', '/media/latlab/MR/projects/kaggle-hms/results/models/hms-lunar-pig-237/hms-lunar-pig-237-cv5_best.pt']


Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,split,fold
0,2277392603,0,0.0,924234,0,0.0,1978807404,30539,GPD,0.0,0.0000,0.454545,0.000,0.090909,0.454545,train,1
1,722738444,0,0.0,999431,0,0.0,557980729,56885,LRDA,0.0,0.0625,0.000000,0.875,0.000000,0.062500,train,1
2,387987538,0,0.0,1084844,0,0.0,4099147263,4264,LRDA,0.0,0.0000,0.000000,1.000,0.000000,0.000000,train,1
3,2175806584,0,0.0,1219001,0,0.0,1963161945,23435,Seizure,1.0,0.0000,0.000000,0.000,0.000000,0.000000,train,1
4,1626798710,0,0.0,1219001,2,74.0,3631726128,23435,Seizure,0.6,0.0000,0.400000,0.000,0.000000,0.000000,train,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85440,1502386617,0,0.0,2141520273,0,0.0,299113150,6770,Seizure,1.0,0.0000,0.000000,0.000,0.000000,0.000000,validation,5
85441,3643661974,0,0.0,2143211626,0,0.0,3196742703,5512,GRDA,0.0,0.0000,0.000000,0.000,0.583333,0.416667,validation,5
85442,1437016107,0,0.0,2145805074,0,0.0,1221972194,20588,Other,0.0,0.0000,0.000000,0.000,0.000000,1.000000,validation,5
85443,3813274317,0,0.0,2146166212,0,0.0,1079072938,3838,LRDA,0.0,0.0000,0.000000,0.500,0.000000,0.500000,validation,5


In [2]:
# Inspect one CV fold: rebuild its train/validation split, restore that fold's
# best checkpoint, and print class probabilities for every validation batch.
fold = 1  # 1-based, as in the 'fold' column of splits.csv and the '-cv{n}_best.pt' file names
df_fold = df[df['fold']==fold]
df_train = df_fold[df_fold['split']=='train']
df_validation = df_fold[df_fold['split']=='validation']

dataloaders = get_dataloaders(CFG, get_datasets(CFG, data, df_train=df_train, df_validation=df_validation))

model = SpecCNN(model_name=CFG.base_model, num_classes=len(CFG.TARGETS), in_channels=CFG.in_channels).to(device)
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint saved from a GPU still loads when this notebook runs on a CPU-only host.
# model_paths is 0-indexed while fold is 1-based, hence fold-1.
model.load_state_dict(torch.load(model_paths[fold-1], map_location=device))
model.eval()  # inference mode for layers such as dropout / batch norm

# Gradients are not needed for inference; no_grad avoids building the autograd graph.
with torch.no_grad():
    for b, (X, y) in enumerate(dataloaders['validation']):
        pred = model(X.to(device))
        # Softmax over the last dimension normalizes each row to sum to 1.
        pred = F.softmax(pred, dim=-1).cpu().numpy()
        print(pred)

[[8.5354024e-01 5.1177456e-03 1.4692729e-03 4.6250239e-02 1.1513055e-02
  8.2109503e-02]
 [1.8578983e-04 3.6018990e-02 2.2872368e-01 1.6388665e-03 2.6271862e-03
  7.3080552e-01]
 [1.9578801e-02 1.1721151e-02 1.0643171e-01 1.2782994e-02 4.9880186e-01
  3.5068357e-01]
 [7.0350778e-01 8.8278279e-03 6.0642429e-02 2.7463613e-03 1.4762084e-02
  2.0951343e-01]
 [2.3122707e-02 3.5200196e-03 2.3458954e-03 2.4987955e-02 7.5531232e-01
  1.9071114e-01]
 [1.8525608e-04 2.1193945e-04 2.0423161e-03 3.0488751e-03 5.7545257e-01
  4.1905898e-01]
 [4.1178926e-03 2.9184851e-03 2.1202110e-02 4.4209533e-03 2.7530575e-01
  6.9203484e-01]
 [6.4035686e-04 8.9577483e-03 3.1572089e-03 4.1167863e-02 1.1087233e-02
  9.3498951e-01]
 [1.8084295e-03 6.8560168e-03 8.2168601e-02 5.1658708e-03 1.6683181e-01
  7.3716933e-01]
 [1.0321295e-01 6.3114434e-01 1.9364553e-02 1.4783723e-02 3.1830817e-02
  1.9966362e-01]
 [9.9576521e-01 7.1403556e-05 6.4249290e-04 3.8645961e-04 5.1595410e-04
  2.6183915e-03]
 [9.6479470e-01 6.825