In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score
import pickle
from collections import defaultdict
from sklearn.utils import shuffle
import os
import cv2
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd




In [None]:
# load data, labels, patients from pkl files
import pickle
import os

def load_data(data_path, labels_path, patients_path):
    """Unpickle the data, label and patient-ID objects from three files.

    Args:
        data_path: Path to the pickled data object.
        labels_path: Path to the pickled labels object.
        patients_path: Path to the pickled patient-ID object.

    Returns:
        Tuple ``(data, labels, patients)`` holding the unpickled objects in
        argument order.

    Security note: ``pickle.load`` can execute arbitrary code, so only call
    this on files you produced yourself or otherwise trust.
    """
    loaded = []
    for path in (data_path, labels_path, patients_path):
        with open(path, 'rb') as handle:
            loaded.append(pickle.load(handle))
    return tuple(loaded)

# Load the three pickled objects from the working directory.
# NOTE(review): later cells index these with NumPy index arrays and compare
# `labels == int(...)`, so they are presumably equal-length ndarrays — confirm
# against how the .pkl files were written.
data,labels, patients = load_data('data.pkl', 'labels.pkl', 'patients.pkl')


<h1>Data Splitting</h1>

In [None]:

import numpy as np

def separate_data_by_class(labels,data, patients, class_no):
    """Select the samples whose label equals ``class_no``.

    Args:
        labels: ndarray of labels aligned with ``data`` and ``patients``.
        data: ndarray of samples indexed along axis 0.
        patients: ndarray of patient IDs.
        class_no: Class to select; converted with ``int()``.

    Returns:
        Tuple ``(indices, data_rows, label_rows, patient_rows)`` where
        ``indices`` are the positions (along axis 0) whose label matches.
    """
    target = int(class_no)
    matching = np.nonzero(labels == target)[0]
    return matching, data[matching], labels[matching], patients[matching]


def create_patient_folds(patients, n_splits=5, rng=None):
    """Randomly assign each unique patient ID to one of ``n_splits`` folds.

    The unique IDs are shuffled and cut with ``np.array_split`` into
    ``n_splits`` contiguous groups whose sizes differ by at most one (earlier
    groups receive the extra patients). Some groups are empty when there are
    fewer patients than folds.

    Args:
        patients: Iterable of patient IDs (duplicates allowed).
        n_splits: Number of folds.
        rng: Optional object with a ``shuffle`` method (e.g.
            ``np.random.default_rng(seed)``) used for a reproducible shuffle.
            When ``None`` (default) the global ``np.random`` state is used,
            exactly as before.

    Returns:
        dict mapping each unique patient ID to its fold index in
        ``range(n_splits)``.
    """
    all_unique_patients = np.unique(patients)
    print("No. of unique patients is:", len(all_unique_patients))
    if rng is None:
        np.random.shuffle(all_unique_patients)
    else:
        rng.shuffle(all_unique_patients)
    patient_splits = np.array_split(all_unique_patients, n_splits)

    return {
        patient: fold
        for fold, patient_group in enumerate(patient_splits)
        for patient in patient_group
    }

def split_data_by_patients_across_classes(eeg_data, labels, patients, patient_fold_mapping, n_splits=5):
    """Route every sample to the fold its patient is mapped to.

    Sample ``i`` goes to fold ``patient_fold_mapping[patients[i]]``. Within a
    fold, samples keep their original relative order. Only the first
    ``len(patients)`` rows of ``eeg_data`` and ``labels`` are used.

    Args:
        eeg_data: Array-like of samples indexed along axis 0.
        labels: Array-like of labels aligned with ``eeg_data``.
        patients: Sequence of patient IDs aligned with ``eeg_data``.
        patient_fold_mapping: dict patient ID -> fold index in
            ``range(n_splits)`` (e.g. the output of ``create_patient_folds``).
        n_splits: Number of folds.

    Returns:
        Three lists of length ``n_splits`` (data, labels, patients) of
        ndarrays. An empty fold keeps the trailing sample shape, e.g.
        ``(0, *eeg_data.shape[1:])``, so it can still be concatenated with
        non-empty folds.

    Raises:
        KeyError: if a patient in ``patients`` is missing from the mapping.
    """
    # Collect the row positions belonging to each fold.
    fold_members = [[] for _ in range(n_splits)]
    for i, patient_id in enumerate(patients):
        fold_members[patient_fold_mapping[patient_id]].append(i)

    eeg_array = np.asarray(eeg_data)
    labels_array = np.asarray(labels)
    patients_array = np.asarray(patients)

    eeg_data_splits, labels_splits, patients_splits = [], [], []
    for members in fold_members:
        # Integer fancy indexing preserves trailing dimensions even when the
        # fold is empty; np.array([]) would collapse to shape (0,) and break
        # the later np.concatenate of folds.
        idx = np.asarray(members, dtype=np.intp)
        eeg_data_splits.append(eeg_array[idx])
        labels_splits.append(labels_array[idx])
        patients_splits.append(patients_array[idx])

    return eeg_data_splits, labels_splits, patients_splits


def shuffle_fold(eeg_fold, labels_fold, patients_fold):
    """Apply one shared random permutation to three row-aligned arrays.

    The permutation is drawn from the global ``np.random`` state, and the
    same order is applied to all three arrays, so rows stay aligned.

    Returns:
        Tuple ``(eeg, labels, patients)`` of permuted copies.
    """
    order = np.arange(len(eeg_fold))
    np.random.shuffle(order)
    return tuple(arr[order] for arr in (eeg_fold, labels_fold, patients_fold))


# Split the full dataset by label value 0, 1 and 2. Each call returns the
# matching indices plus the data / label / patient rows at those indices.
c0_indices, data_c0, labels_c0, patients_c0 = separate_data_by_class(labels, data, patients, 0)

c1_indices, data_c1, labels_c1, patients_c1 = separate_data_by_class(labels, data, patients, 1)

c2_indices, data_c2, labels_c2, patients_c2 = separate_data_by_class(labels, data, patients, 2)


# Sanity check: per-class data shape, label shape and number of distinct patients.
print(f"Class 0: {data_c0.shape}, {labels_c0.shape}, {len(np.unique(patients_c0))}")
print(f"Class 1: {data_c1.shape}, {labels_c1.shape}, {len(np.unique(patients_c1))}")
print(f"Class 2: {data_c2.shape}, {labels_c2.shape}, {len(np.unique(patients_c2))}")


# Set of patient IDs contributing samples to each class; the cells below
# partition patients by which combination of classes they appear in.
patients_set_c0 = set(patients_c0)
patients_set_c1 = set(patients_c1)
patients_set_c2 = set(patients_c2)

# Patients with samples in all three classes. The list order follows set
# iteration order (create_patient_folds re-sorts it via np.unique).
common_patients_all = patients_set_c0.intersection(patients_set_c1, patients_set_c2)
common_patients_all = list(common_patients_all)
print('No. of common patients between all classes: ', len(common_patients_all))

def extract_common_data(patients, data, labels, common_patient):
    """Keep only the samples whose patient ID is in ``common_patient``.

    Args:
        patients: ndarray of patient IDs aligned with ``data`` and ``labels``.
        data: ndarray of samples indexed along axis 0.
        labels: ndarray of labels aligned with ``data``.
        common_patient: Iterable of patient IDs to keep.

    Returns:
        ``(data, labels, patients)`` restricted to the selected rows, in their
        original order. If nothing matches, the arrays are empty along axis 0.
    """
    # A set gives O(1) membership tests; checking `in` against a list made
    # this loop O(len(patients) * len(common_patient)).
    wanted = set(common_patient)
    common_indices = np.asarray(
        [i for i, patient in enumerate(patients) if patient in wanted], dtype=np.intp
    )
    common_data = data[common_indices]
    common_labels = labels[common_indices]
    common_patients = patients[common_indices]
    return common_data, common_labels, common_patients

# Restrict each class to the patients that appear in all three classes.
common_data_c0, common_labels_c0, common_patients_c0 = extract_common_data(patients_c0, data_c0, labels_c0, common_patients_all)
common_data_c1, common_labels_c1, common_patients_c1 = extract_common_data(patients_c1, data_c1, labels_c1, common_patients_all)
common_data_c2, common_labels_c2, common_patients_c2 = extract_common_data(patients_c2, data_c2, labels_c2, common_patients_all)

print(np.unique(common_patients_c0).shape)

# One patient -> fold mapping shared by all three classes, so every sample of
# a given patient (whatever its class) lands in the same fold.
common_allpatient_fold_mapping = create_patient_folds(common_patients_all, n_splits=5)
eeg_012_c0, labels_012_c0, patients_012_c0 = split_data_by_patients_across_classes(
    common_data_c0, common_labels_c0, common_patients_c0, common_allpatient_fold_mapping, n_splits=5)

eeg_012_c1, labels_012_c1, patients_012_c1 = split_data_by_patients_across_classes(
    common_data_c1, common_labels_c1, common_patients_c1, common_allpatient_fold_mapping, n_splits=5)

eeg_012_c2, labels_012_c2, patients_012_c2 = split_data_by_patients_across_classes(
    common_data_c2, common_labels_c2, common_patients_c2, common_allpatient_fold_mapping, n_splits=5)

# Diagnostics: combined shapes per fold, then per-class shapes and patient counts.
for i in range(5):
    print(f"Fold {i+1}: EEG shape: {np.concatenate([eeg_012_c0[i], eeg_012_c1[i], eeg_012_c2[i]]).shape}, "
          f"Labels shape: {np.concatenate([labels_012_c0[i], labels_012_c1[i], labels_012_c2[i]]).shape}, "
          f"Patients shape: {np.concatenate([patients_012_c0[i], patients_012_c1[i], patients_012_c2[i]]).shape}")


for i in range(5):
    print(f"Fold {i+1}: EEG shape 0: {eeg_012_c0[i].shape}, "
          f"Labels shape 0: {labels_012_c0[i].shape}, "
          f"Patients shape 0: {patients_012_c0[i].shape}")
print("__________________________________\n")

for i in range(5):
    print(f"Fold {i+1}: EEG shape 1: {eeg_012_c1[i].shape}, "
      f"Labels shape 1: {labels_012_c1[i].shape}, "
      f"Patients shape 1: {patients_012_c1[i].shape}")

print("__________________________________\n")
for i in range(5):
    print(f"Fold {i+1}: EEG shape 2: {eeg_012_c2[i].shape}, "
      f"Labels shape 2: {labels_012_c2[i].shape}, "
      f"Patients shape 2: {patients_012_c2[i].shape}")
    
for i in range(5):
    print(f"No. of patients in class 2 is {len(np.unique(patients_012_c2[i]))}")
for i in range(5):
    print(f"No. of patients in class 1 is {len(np.unique(patients_012_c1[i]))}")
for i in range(5):
    print(f"No. of patients in class 0 is {len(np.unique(patients_012_c0[i]))}")
    
    # Find exclusive patients for each class
exclusive_patients_c0 = list(patients_set_c0 - set(common_patients_all))
exclusive_patients_c1 = list(patients_set_c1 - set(common_patients_all))
exclusive_patients_c2 = list(patients_set_c2 - set(common_patients_all))

print(len(exclusive_patients_c0))
print(len(exclusive_patients_c1))
print(len(exclusive_patients_c2))

# Patients with samples in classes 0 and 2 but not in all three (hence not in class 1).
exclusive_c2 = set(exclusive_patients_c2)
exclusive_c0 = set(exclusive_patients_c0)
common_c2_c0 = exclusive_c0.intersection(exclusive_c2)
common_patient_c2_c0 = list(common_c2_c0)

print(len(common_patient_c2_c0))


# "c02" holds those patients' class-0 samples, "c22" their class-2 samples.
common_data_c02, common_labels_c02, common_patients_c02 = extract_common_data(patients_c0, data_c0, labels_c0, common_patient_c2_c0)
common_data_c22, common_labels_c22, common_patients_c22 = extract_common_data(patients_c2, data_c2, labels_c2, common_patient_c2_c0)

print(np.unique(common_patients_c02).shape)
print(np.unique(common_patients_c22).shape)

# Shared mapping so each patient's class-0 and class-2 samples land in the same fold.
c20_mapping = create_patient_folds(common_patient_c2_c0, n_splits=5)

common_eeg_splits_c02, common_labels_splits_c02, common_patients_splits_c02 = split_data_by_patients_across_classes(
    common_data_c02, common_labels_c02, common_patients_c02, c20_mapping, n_splits=5)

common_eeg_splits_c22, common_labels_splits_c22, common_patients_splits_c22 = split_data_by_patients_across_classes(
    common_data_c22, common_labels_c22, common_patients_c22, c20_mapping, n_splits=5)


# Check the shape of each fold for verification
for i in range(5):
    print(f"Fold {i+1}: EEG shape: {np.concatenate([common_eeg_splits_c02[i], common_eeg_splits_c22[i]]).shape}, "
          f"Labels shape: {np.concatenate([common_labels_splits_c02[i], common_labels_splits_c22[i]]).shape}, "
          f"Patients shape: {np.concatenate([common_patients_splits_c02[i], common_patients_splits_c22[i]]).shape}")

for i in range(5):
    print(f"Fold {i+1}: EEG shape 0: {common_eeg_splits_c02[i].shape}, "
          f"Labels shape 0: {common_labels_splits_c02[i].shape}, "
          f"Patients shape 0: {common_patients_splits_c02[i].shape}")
print("__________________________________\n")

# The c22 splits hold class-2 samples, so they are reported as class 2.
for i in range(5):
    print(f"Fold {i+1}: EEG shape 2: {common_eeg_splits_c22[i].shape}, "
          f"Labels shape 2: {common_labels_splits_c22[i].shape}, "
          f"Patients shape 2: {common_patients_splits_c22[i].shape}")

print("__________________________________\n")

# The c02 splits hold class-0 samples, so they are reported as class 0.
for i in range(5):
    print(f"No. of patients in class 0 is {len(np.unique(common_patients_splits_c02[i]))}")
for i in range(5):
    print(f"No. of patients in class 2 is {len(np.unique(common_patients_splits_c22[i]))}")

# Per class, drop patients already grouped (all-three and classes-0&2).
exclusive_patients_c0_c2 = list(patients_set_c0 - set(common_patients_all) - set(common_c2_c0))
exclusive_patients_c1_c2 = list(patients_set_c1 - set(common_patients_all)- set(common_c2_c0))
exclusive_patients_c2_c2 = list(patients_set_c2 - set(common_patients_all)- set(common_c2_c0))

print(len(exclusive_patients_c0_c2))
print(len(exclusive_patients_c1_c2))
print(len(exclusive_patients_c2_c2))

# Patients with samples in classes 1 and 2 but not in class 0 (the all-three
# group was already removed from both sets).
exclusive_c2 = set(exclusive_patients_c2_c2)
exclusive_c1 = set(exclusive_patients_c1_c2)
common_c2_c1 = exclusive_c1.intersection(exclusive_c2)
common_patient_c2_c1 = list(common_c2_c1)

print(len(common_patient_c2_c1))



# "c12" holds those patients' class-1 samples, "1_c22" their class-2 samples.
common_data_c12, common_labels_c12, common_patients_c12 = extract_common_data(patients_c1, data_c1, labels_c1, common_patient_c2_c1)
common_data_1_c22, common_labels_1_c22, common_patients_1_c22 = extract_common_data(patients_c2, data_c2, labels_c2, common_patient_c2_c1)

print(np.unique(common_patients_c12).shape)
print(np.unique(common_patients_1_c22).shape)

# Shared mapping so each patient's class-1 and class-2 samples land in the same fold.
c21_mapping = create_patient_folds(common_patient_c2_c1, n_splits=5)

common_eeg_splits_c12, common_labels_splits_c12, common_patients_splits_c12 = split_data_by_patients_across_classes(
    common_data_c12, common_labels_c12, common_patients_c12, c21_mapping, n_splits=5)

common_eeg_splits_1_c22, common_labels_splits_1_c22, common_patients_splits_1_c22 = split_data_by_patients_across_classes(
    common_data_1_c22, common_labels_1_c22, common_patients_1_c22, c21_mapping, n_splits=5)


# Check the shape of each fold for verification
for i in range(5):
    print(f"Fold {i+1}: EEG shape: {np.concatenate([common_eeg_splits_c12[i], common_eeg_splits_1_c22[i]]).shape}, "
          f"Labels shape: {np.concatenate([common_labels_splits_c12[i], common_labels_splits_1_c22[i]]).shape}, "
          f"Patients shape: {np.concatenate([common_patients_splits_c12[i], common_patients_splits_1_c22[i]]).shape}")


for i in range(5):
    print(f"No. of patients in class 1 is {len(np.unique(common_patients_splits_c12[i]))}")
for i in range(5):
    print(f"No. of patients in class 2 is {len(np.unique(common_patients_splits_1_c22[i]))}")
    

# Remaining ungrouped patients per class after removing the classes-1&2 group as well.
exclusive_patients_c0_c2_c1 = list(patients_set_c0 - set(common_patients_all) - set(common_c2_c0) - set(common_c2_c1))
exclusive_patients_c1_c2_c1 = list(patients_set_c1 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1))
exclusive_patients_c2_c2_c1 = list(patients_set_c2 - set(common_patients_all)- set(common_c2_c0)- set(common_c2_c1))

print(len(exclusive_patients_c0_c2_c1))
print(len(exclusive_patients_c1_c2_c1))
print(len(exclusive_patients_c2_c2_c1))

# Patients with samples in class 2 only: every other group involving class 2
# (all three, 0&2, 1&2) has already been removed.
only_c2 = list(set(exclusive_patients_c2_c2_c1))
data_only_c2, labels_only_c2,patients_only_c2 = extract_common_data(patients_c2, data_c2, labels_c2, only_c2)
print(np.unique(patients_only_c2).shape)

c2_mapping = create_patient_folds(only_c2, n_splits=5)

# Pass labels_only_c2: it is row-aligned with data_only_c2 / patients_only_c2,
# whereas the unfiltered labels_c2 is not.
eeg_splits_c2, labels_splits_c2, patients_splits_c2 = split_data_by_patients_across_classes(
    data_only_c2, labels_only_c2, patients_only_c2, c2_mapping, n_splits=5)

for i in range(5):
    print(f"No. of patients in class 2 is {len(np.unique(patients_splits_c2[i]))}")
    
    # Find exclusive patients for each class
exclusive_patients_all_c02 = list(patients_set_c0 - set(common_patients_all) - set(common_c2_c0) - set(common_c2_c1)- set(only_c2))
exclusive_patients_all_c12 = list(patients_set_c1 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2))
exclusive_patients_all_c22 = list(patients_set_c2 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2))

print(len(exclusive_patients_all_c02))
print(len(exclusive_patients_all_c12))
print(len(exclusive_patients_all_c22))


# Patients with samples in classes 0 and 1 but not in class 2 (every group
# involving class 2 has already been removed from both sets).
exc_c1_c11 = set(exclusive_patients_all_c12)
exc_c1_c01 = set(exclusive_patients_all_c02)
common_patient_c1_c0 = exc_c1_c01.intersection(exc_c1_c11)
common_patient_c1_c0 = list(common_patient_c1_c0)

print(len(common_patient_c1_c0))

# "c1_c01" holds those patients' class-0 samples, "c1_c11" their class-1 samples.
common_data_c1_c01, common_labels_c1_c01, common_patients_c1_c01 = extract_common_data(patients_c0, data_c0, labels_c0, common_patient_c1_c0)
common_data_c1_c11, common_labels_c1_c11, common_patients_c1_c11 = extract_common_data(patients_c1, data_c1, labels_c1, common_patient_c1_c0)

print(np.unique(common_patients_c1_c01).shape)

# Shared mapping so each patient's class-0 and class-1 samples land in the same fold.
mapping_c1_c0 = create_patient_folds(common_patient_c1_c0, n_splits=5)

common_eeg_splits_c1_c01, common_labels_splits_c1_c01, common_patients_splits_c1_c01= split_data_by_patients_across_classes(
    common_data_c1_c01, common_labels_c1_c01, common_patients_c1_c01, mapping_c1_c0, n_splits=5)

common_eeg_splits_c1_c11, common_labels_splits_c1_c11, common_patients_splits_c1_c11 = split_data_by_patients_across_classes(
    common_data_c1_c11, common_labels_c1_c11, common_patients_c1_c11, mapping_c1_c0, n_splits=5)

for i in range(5):
    print(f"Fold {i+1}: EEG shape: {np.concatenate([common_eeg_splits_c1_c01[i], common_eeg_splits_c1_c11[i]]).shape}, "
          f"Labels shape: {np.concatenate([common_labels_splits_c1_c01[i], common_labels_splits_c1_c11[i]]).shape}, "
          f"Patients shape: {np.concatenate([common_patients_splits_c1_c01[i], common_patients_splits_c1_c11[i]]).shape}")

for i in range(5):
    print(f"Fold {i+1}: EEG shape 0: {common_eeg_splits_c1_c01[i].shape}, "
          f"Labels shape 0: {common_labels_splits_c1_c01[i].shape}, "
          f"Patients shape 0: {common_patients_splits_c1_c01[i].shape}")
print("__________________________________\n")

for i in range(5):
    print(f"Fold {i+1}: EEG shape 1: {common_eeg_splits_c1_c11[i].shape}, "
          f"Labels shape 1: {common_labels_splits_c1_c11[i].shape}, "
          f"Patients shape 1: {common_patients_splits_c1_c11[i].shape}")

print("__________________________________\n")

# c1_c01 holds class-0 samples and c1_c11 class-1 samples; report them as such.
for i in range(5):
    print(f"No. of patients in class 0 is {len(np.unique(common_patients_splits_c1_c01[i]))}")
for i in range(5):
    print(f"No. of patients in class 1 is {len(np.unique(common_patients_splits_c1_c11[i]))}")

# Remaining ungrouped patients per class after removing the classes-0&1 group.
exclusive_patients_c1_c01 = list(patients_set_c0 - set(common_patients_all) - set(common_c2_c0) - set(common_c2_c1)- set(only_c2) - set(common_patient_c1_c0))
exclusive_patients_c1_c11 = list(patients_set_c1 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2)- set(common_patient_c1_c0))
exclusive_patients_c1_c21 = list(patients_set_c2 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2)- set(common_patient_c1_c0))

print(len(exclusive_patients_c1_c01))
print(len(exclusive_patients_c1_c11))
print(len(exclusive_patients_c1_c21))



print("CLASS 1")

# Patients with samples in class 1 only: every other group involving class 1
# has already been removed.
only_c1 = list(set(exclusive_patients_c1_c11))
data_only_c1, labels_only_c1,patients_only_c1 = extract_common_data(patients_c1, data_c1, labels_c1, only_c1)
print(np.unique(patients_only_c1).shape)

c1_mapping = create_patient_folds(only_c1, n_splits=5)

# Pass labels_only_c1: it is row-aligned with data_only_c1 / patients_only_c1,
# whereas the unfiltered labels_c1 is not.
eeg_splits_c1, labels_splits_c1, patients_splits_c1 = split_data_by_patients_across_classes(
    data_only_c1, labels_only_c1, patients_only_c1, c1_mapping, n_splits=5)

for i in range(5):
    print(f"No. of patients in class 1 is {len(np.unique(patients_splits_c1[i]))}")
    
exclusive_patients_all1_c01 = list(patients_set_c0 - set(common_patients_all) - set(common_c2_c0) - set(common_c2_c1)- set(only_c2) - set(common_patient_c1_c0)- set(only_c1))
exclusive_patients_all1_c11 = list(patients_set_c1 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2)- set(common_patient_c1_c0)- set(only_c1))
exclusive_patients_all1_c21 = list(patients_set_c2 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2)- set(common_patient_c1_c0)- set(only_c1))

print(len(exclusive_patients_all1_c01))
print(len(exclusive_patients_all1_c11))
print(len(exclusive_patients_all1_c21))

print("CLASS 0")
# Patients with samples in class 0 only: every other group involving class 0
# has already been removed.
only_c0 = list(set(exclusive_patients_all1_c01))
data_only_c0, labels_only_c0,patients_only_c0 = extract_common_data(patients_c0, data_c0, labels_c0, only_c0)
print(np.unique(patients_only_c0).shape)

c0_mapping = create_patient_folds(only_c0, n_splits=5)

# Pass labels_only_c0: it is row-aligned with data_only_c0 / patients_only_c0,
# whereas the unfiltered labels_c0 is not.
eeg_splits_c0, labels_splits_c0, patients_splits_c0 = split_data_by_patients_across_classes(
    data_only_c0, labels_only_c0, patients_only_c0, c0_mapping, n_splits=5)

for i in range(5):
    print(f"No. of patients in class 0 is {len(np.unique(patients_splits_c0[i]))}")

# The seven groups above (all three classes, 0&2, 1&2, 2 only, 0&1, 1 only,
# 0 only) cover every non-empty combination of classes, so no patient should
# be left over. A leftover patient would silently be missing from every fold.
final_c0 = list(patients_set_c0 - set(common_patients_all) - set(common_c2_c0) - set(common_c2_c1)- set(only_c2) - set(common_patient_c1_c0)- set(only_c1) - set(only_c0))
final_c1 = list(patients_set_c1 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2)- set(common_patient_c1_c0)- set(only_c1)- set(only_c0))
final_c2 = list(patients_set_c2 - set(common_patients_all)- set(common_c2_c0) - set(common_c2_c1)- set(only_c2)- set(common_patient_c1_c0)- set(only_c1)- set(only_c0))

print(len(final_c0))
print(len(final_c1))
print(len(final_c2))
assert not (final_c0 or final_c1 or final_c2), "Some patients were not assigned to any fold group"

# Final fold k (0-based) concatenates split k of every patient group, in the
# order listed below. For the classes-1&2 group (common_*_c12 / common_*_1_c22)
# and the class-0-only group (*_splits_c0) the splits are taken in reverse
# (split 4 - k, i.e. split -1 - k). This is presumably so that np.array_split's
# larger leading splits do not all land in the same final fold — TODO confirm.
_fold_group_reversed = [False, False, False, False, False, True, True, False, False, False, False, True]


def _assemble_fold(groups, k):
    """Concatenate split k (split -1-k for reversed groups) of every group."""
    return np.concatenate([
        splits[-1 - k] if reverse else splits[k]
        for splits, reverse in zip(groups, _fold_group_reversed)
    ])


_eeg_groups = [eeg_012_c0, eeg_012_c1, eeg_012_c2,
               common_eeg_splits_c02, common_eeg_splits_c22,
               common_eeg_splits_c12, common_eeg_splits_1_c22,
               eeg_splits_c2,
               common_eeg_splits_c1_c01, common_eeg_splits_c1_c11,
               eeg_splits_c1, eeg_splits_c0]
_label_groups = [labels_012_c0, labels_012_c1, labels_012_c2,
                 common_labels_splits_c02, common_labels_splits_c22,
                 common_labels_splits_c12, common_labels_splits_1_c22,
                 labels_splits_c2,
                 common_labels_splits_c1_c01, common_labels_splits_c1_c11,
                 labels_splits_c1, labels_splits_c0]
_patient_groups = [patients_012_c0, patients_012_c1, patients_012_c2,
                   common_patients_splits_c02, common_patients_splits_c22,
                   common_patients_splits_c12, common_patients_splits_1_c22,
                   patients_splits_c2,
                   common_patients_splits_c1_c01, common_patients_splits_c1_c11,
                   patients_splits_c1, patients_splits_c0]

eeg_fold_1, eeg_fold_2, eeg_fold_3, eeg_fold_4, eeg_fold_5 = (
    _assemble_fold(_eeg_groups, k) for k in range(5))
labels_fold_1, labels_fold_2, labels_fold_3, labels_fold_4, labels_fold_5 = (
    _assemble_fold(_label_groups, k) for k in range(5))
patients_fold_1, patients_fold_2, patients_fold_3, patients_fold_4, patients_fold_5 = (
    _assemble_fold(_patient_groups, k) for k in range(5))

eeg_folds = [eeg_fold_1, eeg_fold_2, eeg_fold_3, eeg_fold_4, eeg_fold_5]
labels_folds = [labels_fold_1, labels_fold_2, labels_fold_3, labels_fold_4, labels_fold_5]
patients_folds = [patients_fold_1, patients_fold_2, patients_fold_3, patients_fold_4, patients_fold_5]

# Store the fold data as float16 to reduce memory; the eeg_fold_* names keep
# their original dtype.
# NOTE(review): float16 saturates above ~65504 and keeps ~3 significant
# digits — confirm this suits the EEG value range.
eeg_folds = [fold.astype(np.float16) for fold in eeg_folds]

<h1>Data Balancer and Early Stopping</h1>

In [None]:
def data_balancer(data, labels, factor):
    """Randomly undersample classes 0, 1 and 2 to an equal count.

    Each class contributes ``min(class counts) // factor`` samples, drawn
    without replacement from the global ``np.random`` state (class 0 first,
    then 1, then 2). The pooled indices are then shuffled.

    Args:
        data: Array indexed along axis 0, aligned with ``labels``.
        labels: ndarray of integer class labels.
        factor: Divisor applied to the smallest class count (1 keeps as many
            samples as the smallest class has).

    Returns:
        ``(balanced_data, balanced_labels)`` in the shuffled order.
    """
    class_ids = (0, 1, 2)
    class_counts = [np.sum(labels == cls) for cls in class_ids]
    quota = min(class_counts) // factor

    picked = [
        np.random.choice(np.where(labels == cls)[0], quota, replace=False)
        for cls in class_ids
    ]
    order = np.concatenate(picked)
    np.random.shuffle(order)

    return data[order], labels[order]

In [None]:
class EarlyStopping:
    """Stop training after ``patience`` epochs without a new best validation loss.

    Also keeps a snapshot of the model weights from the best epoch so they can
    be restored with ``load_best_model``.
    """

    def __init__(self, patience=5):
        """
        Args:
            patience (int): Number of consecutive epochs whose validation loss
                does not improve on the best seen so far before stopping.
        """
        self.patience = patience
        self.best_loss = None
        self.counter = 0
        self.early_stop = False
        self.best_model_state = None

    def __call__(self, val_loss, model):
        """Record this epoch's validation loss and update the stopping state.

        A strictly lower loss becomes the new best. The counter is reset and a
        deep copy of ``model.state_dict()`` is stored. Otherwise the counter is
        incremented, and ``early_stop`` is set once it reaches ``patience``.

        Args:
            val_loss (float): Current epoch's validation loss.
            model (torch.nn.Module): The model being trained.
        """
        import copy

        if self.best_loss is None or val_loss < self.best_loss:
            # Improvement detected
            self.best_loss = val_loss
            # state_dict() returns tensors that share storage with the live
            # parameters. Without a copy, later optimizer steps would overwrite
            # the "best" snapshot in place.
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            # No improvement (equal or higher loss)
            self.counter += 1
            if self.counter >= self.patience:
                print(f"Divergence detected. Stopping training after {self.counter} epochs.")
                self.early_stop = True

    def load_best_model(self, model):
        """Restore the model to the state with the lowest validation loss.

        Args:
            model (torch.nn.Module): The model to restore.

        Raises:
            RuntimeError: if no validation loss has been recorded yet.
        """
        if self.best_model_state is None:
            raise RuntimeError("EarlyStopping has no saved model state; call it with a validation loss first.")
        model.load_state_dict(self.best_model_state)




In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Seed torch's global RNG so that model initialisation below is repeatable.
torch.manual_seed(42)

# Run on CUDA when a GPU is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# DEBUG_MODE switches on the per-layer shape prints in EEGNet.forward.
# debug_mode_flag is set to the same value; nothing visible here reads it.
DEBUG_MODE = debug_mode_flag = False


class EEGNet(nn.Module):
    """Compact three-layer CNN for multi-channel EEG classification.

    Input is ``(batch, 1, num_timepoints, num_channels)`` (as in the demo cell).
    Output is ``(batch, num_classes)`` sigmoid scores.

    Args:
        num_classes: Number of output units.
        num_channels: Number of EEG channels (width of the first conv kernel).
        num_timepoints: Samples per channel; determines the input size of ``fc1``.
    """

    def __init__(self, num_classes = 3,num_channels = 20 , num_timepoints = 5120):
        super(EEGNet, self).__init__()
        self.T = num_timepoints

        # Layer 1: a (1 x num_channels) kernel mixes all channels at each time step.
        self.conv1 = nn.Conv2d(1, 16, (1, num_channels), padding = 0)
        # BatchNorm2d's 2nd positional argument is `eps`, so `BatchNorm2d(16, False)`
        # meant eps=0, i.e. division by zero for a zero-variance channel. Use
        # PyTorch's default eps explicitly. The affine parameters, and so the
        # state_dict keys, are unchanged.
        self.batchnorm1 = nn.BatchNorm2d(16, eps=1e-5)

        # Layer 2
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4, eps=1e-5)
        self.pooling2 = nn.MaxPool2d(2, 4)

        # Layer 3
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4, eps=1e-5)
        self.pooling3 = nn.MaxPool2d((2, 4))

        # FC layer. The final feature map is 4 filters x height 2 x a width that
        # depends on num_timepoints. For 5120 time points this gives 2560, the
        # previously hard-coded size.
        self.fc1 = nn.Linear(self._flat_features(num_timepoints), num_classes)

    @staticmethod
    def _flat_features(num_timepoints):
        """Number of flattened features fed to fc1 for a given input length."""
        # padding1 (+33) then conv2 (k=32):  T + 2
        # pooling2 (k=2, s=4):              (T + 2 - 2) // 4 + 1 = T // 4 + 1
        # padding2 (+3) then conv3 (k=4):   width unchanged
        # pooling3 (k=s=4):                 (T // 4 + 1) // 4
        return 4 * 2 * ((num_timepoints // 4 + 1) // 4)

    def forward(self, x):
        if DEBUG_MODE: print(f"Input shape: {x.shape}")

        # Dropout uses training=self.training so that model.eval() disables it;
        # F.dropout's default (training=True) kept it active during evaluation.

        # Layer 1
        x = F.elu(self.conv1(x))
        if DEBUG_MODE: print(f"After conv1: {x.shape}")
        x = self.batchnorm1(x)
        if DEBUG_MODE: print(f"After batchnorm1: {x.shape}")
        x = F.dropout(x, 0.25, training=self.training)
        if DEBUG_MODE: print(f"After dropout1: {x.shape}")
        # Move the 16 conv1 filters onto the height axis for the 2-D convs below.
        x = x.permute(0, 3, 1, 2)
        if DEBUG_MODE: print(f"After permute: {x.shape}")

        # Layer 2
        x = self.padding1(x)
        if DEBUG_MODE: print(f"After padding1: {x.shape}")
        x = F.elu(self.conv2(x))
        if DEBUG_MODE: print(f"After conv2: {x.shape}")
        x = self.batchnorm2(x)
        if DEBUG_MODE: print(f"After batchnorm2: {x.shape}")
        x = F.dropout(x, 0.25, training=self.training)
        if DEBUG_MODE: print(f"After dropout2: {x.shape}")
        x = self.pooling2(x)
        if DEBUG_MODE: print(f"After pooling2: {x.shape}")

        # Layer 3
        x = self.padding2(x)
        if DEBUG_MODE: print(f"After padding2: {x.shape}")
        x = F.elu(self.conv3(x))
        if DEBUG_MODE: print(f"After conv3: {x.shape}")
        x = self.batchnorm3(x)
        if DEBUG_MODE: print(f"After batchnorm3: {x.shape}")
        x = F.dropout(x, 0.25, training=self.training)
        if DEBUG_MODE: print(f"After dropout3: {x.shape}")
        x = self.pooling3(x)
        if DEBUG_MODE: print(f"After pooling3: {x.shape}")

        # FC Layer
        x = x.reshape(-1, 4*2*x.size(3))
        if DEBUG_MODE: print(f"After flattening: {x.shape}")
        # NOTE(review): if this output feeds nn.CrossEntropyLoss, that loss
        # expects raw logits; consider returning self.fc1(x) without the sigmoid.
        x = torch.sigmoid(self.fc1(x))

        return x


In [None]:
# Smoke test: one forward pass on random input of shape
# (batch=32, 1, time points=5120, channels=20).
model = EEGNet()


demo_input = torch.randn(32, 1, 5120,20)

# Run forward pass
output = model(demo_input)

print("Output shape:", output.shape)
print("Output:", output)
assert output.shape == (32, 3)
from torchinfo import summary
# Model summary. It must be the cell's final expression: in Jupyter, torchinfo's
# summary is only rendered when it is the value returned by the cell.
summary(model, input_size=(32, 1, 5120, 20), col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"], row_settings=["var_names"])


In [None]:
# Seed both RNGs used from here on: torch (model init, DataLoader shuffling)
# and NumPy (fold order below, data_balancer's sampling). NumPy was previously
# unseeded, so fold order and balanced subsets changed on every run.
torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

epochs = 100


# Random order of the 5 patient-disjoint folds; val_fold_indices is the same
# order rotated by one position.
fold_indices = np.arange(5)
fold_indices = np.random.permutation(fold_indices)
val_fold_indices = np.roll(fold_indices, 1)


num_classes = 3


In [None]:
import optuna
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import balanced_accuracy_score

# import warnings
# warnings.filterwarnings("ignore", category=UserWarning, message=".*step is already reported.*")


def _make_loader(data, labels, batch_size, shuffle, device):
    """Wrap in-memory arrays in a DataLoader yielding (float32 inputs, int64 labels).

    Args:
        data: array-like of model inputs for one split.
        labels: array-like of integer class labels aligned with ``data``.
        batch_size: mini-batch size.
        shuffle: reshuffle every epoch (use True for training only).
        device: target device; host memory is pinned only when it is CUDA.
    """
    dataset = TensorDataset(torch.tensor(data, dtype=torch.float32),
                            torch.tensor(labels, dtype=torch.long))
    # num_workers=0: the tensors already live in memory, so worker processes add no
    # throughput, and they have been seen to die ("DataLoader worker ... exited
    # unexpectedly") when several Optuna trials run concurrently.
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=0, pin_memory=device.type == "cuda")


def _split_fold(test_fold_idx):
    """Assign the outer folds for one CV round.

    The test fold is ``fold_indices[test_fold_idx]``; the validation fold is chosen
    from the remaining folds by rotating with ``test_fold_idx``; all others train.

    Returns:
        ((train_data, train_labels), (val_data, val_labels), (test_data, test_labels))
    """
    n_folds = len(fold_indices)
    test_fold = fold_indices[test_fold_idx]
    remaining_folds = [fold_indices[i] for i in range(n_folds) if i != test_fold_idx]
    val_fold = remaining_folds[test_fold_idx % len(remaining_folds)]
    train_folds = [fold for fold in remaining_folds if fold != val_fold]

    train = (np.concatenate([eeg_folds[j] for j in train_folds]),
             np.concatenate([labels_folds[j] for j in train_folds]))
    val = (eeg_folds[val_fold], labels_folds[val_fold])
    test = (eeg_folds[test_fold], labels_folds[test_fold])
    return train, val, test


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` in training mode."""
    model.train()
    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()


def _mean_batch_loss(model, loader, criterion, device):
    """Return the average of per-batch losses over ``loader`` (eval mode, no grad)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            total += criterion(model(inputs), targets).item()
    return total / len(loader)


def _predict(model, loader, device):
    """Return (true_labels, predicted_labels) lists, predicting the highest-scoring class."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(device, non_blocking=True))
            _, preds = torch.max(outputs, 1)
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(targets.numpy())
    return y_true, y_pred


def objective(trial, max_epochs=30, patience=10):
    """Optuna objective: cross-validated EEGNet training for one hyperparameter set.

    For every outer fold, one fold is held out as test, one of the remaining folds
    is used for validation (early stopping and the selection score), and the rest
    are used for training.

    Args:
        trial: the Optuna trial supplying hyperparameter suggestions.
        max_epochs: upper bound on training epochs per fold.
        patience: early-stopping patience (epochs without validation improvement).

    Returns:
        Mean balanced accuracy over the validation folds. Test-fold balanced
        accuracies are stored in ``trial.user_attrs["fold_accuracies"]`` but are
        deliberately NOT optimised, so hyperparameters are never chosen on test data.

    Raises:
        optuna.TrialPruned: when the pruner stops the trial after some fold.
    """
    # Searched hyperparameters; the remaining ones are fixed for this study.
    learning_rate = trial.suggest_float("learning_rate", 1e-8, 1e-2, log=True)
    optimizer_name = "Adam"
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [10, 16, 32])
    label_smoothing = 0.3
    factor = 1  # passed through to data_balancer

    print(f"Trial {trial.number}: learning_rate={learning_rate}, optimizer_name={optimizer_name}, weight_decay={weight_decay}, batch_size={batch_size}, label_smoothing={label_smoothing}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer_cls = {"Adam": optim.Adam, "AdamW": optim.AdamW, "SGD": optim.SGD}
    val_accuracies = []   # selection metric
    fold_accuracies = []  # held-out test metric (recorded only)

    for test_fold_idx in range(len(fold_indices)):
        (train_data, train_labels), (val_data, val_labels), (test_data, test_labels) = _split_fold(test_fold_idx)
        balanced_train_data, balanced_train_labels = data_balancer(train_data, train_labels, factor=factor)

        train_loader = _make_loader(balanced_train_data, balanced_train_labels, batch_size, True, device)
        val_loader = _make_loader(val_data, val_labels, batch_size, False, device)
        test_loader = _make_loader(test_data, test_labels, batch_size, False, device)

        model = EEGNet(num_classes=num_classes, num_channels=20, num_timepoints=5120).to(device)
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        optimizer = optimizer_cls[optimizer_name](model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        early_stopping = EarlyStopping(patience=patience)

        for epoch in range(max_epochs):
            _train_one_epoch(model, train_loader, criterion, optimizer, device)
            val_loss = _mean_batch_loss(model, val_loader, criterion, device)
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print(f"Early stopping at epoch {epoch}")
                break

        early_stopping.load_best_model(model)

        val_acc = float(balanced_accuracy_score(*_predict(model, val_loader, device)))
        test_acc = float(balanced_accuracy_score(*_predict(model, test_loader, device)))
        val_accuracies.append(val_acc)
        fold_accuracies.append(test_acc)
        print(f"Trial {trial.number}, Fold {test_fold_idx+1}: Val Accuracy = {val_acc:.4f}, Test Accuracy = {test_acc:.4f}")

        del model
        torch.cuda.empty_cache()

        # Store copies so attrs are complete even if the trial is pruned below.
        trial.set_user_attr("fold_accuracies", list(fold_accuracies))
        trial.set_user_attr("val_fold_accuracies", list(val_accuracies))

        # One intermediate value per fold, at distinct steps, so the pruner can
        # stop unpromising trials before all folds have been trained.
        trial.report(float(np.mean(val_accuracies)), step=test_fold_idx)
        if trial.should_prune():
            raise optuna.TrialPruned()

    mean_accuracy = float(np.mean(val_accuracies))
    print(f"Trial {trial.number}: Mean Val Accuracy = {mean_accuracy:.4f}, Mean Test Accuracy = {np.mean(fold_accuracies):.4f}, Fold Test Accuracies = {fold_accuracies}")

    return mean_accuracy

# Start Optuna Study
# - Seeded sampler for repeatable suggestions (only approximately repeatable with
#   n_jobs > 1, because concurrent trials finish in nondeterministic order).
# - MedianPruner ignores every step below n_warmup_steps, so this value must be
#   smaller than the number of intermediate reports per trial or pruning never fires.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=42),
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=1),
)

study.optimize(objective, n_trials=300, n_jobs=3)

# Best result
print("Best hyperparameters:", study.best_params)
print("Best value:", study.best_value)
print("Best trial fold accuracies:", study.best_trial.user_attrs.get("fold_accuracies"))


[I 2025-05-20 21:30:01,632] A new study created in memory with name: no-name-b7caaa91-ba18-46ec-b10b-637faa3b9b9f


hidden_size: 1024, num_layers: 4, learning_rate: 0.00018671377065131162, optimizer_name: Adam, weight_decay: 1.4865027075569687e-06, batch_size: 16, label_smoothing: 0.3, factor: 1hidden_size: 128, num_layers: 1, learning_rate: 4.520190118492533e-06, optimizer_name: Adam, weight_decay: 0.0003089835017969355, batch_size: 32, label_smoothing: 0.3, factor: 1
hidden_size: 1024, num_layers: 3, learning_rate: 3.5092887725468008e-06, optimizer_name: Adam, weight_decay: 4.3742946042079865e-06, batch_size: 16, label_smoothing: 0.3, factor: 1



[W 2025-05-20 21:30:22,210] Trial 2 failed with parameters: {'hidden_size': 128, 'num_layers': 1, 'dropout': 0.4, 'learning_rate': 4.520190118492533e-06, 'weight_decay': 0.0003089835017969355, 'batch_size': 32} because of the following error: RuntimeError('DataLoader worker (pid(s) 1374) exited unexpectedly').
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1251, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/queues.py", line 114, in get
    raise Empty
_queue.Empty

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "<ipython-input-9-1476bc4554bc>", line 82, in objective
    for