In [None]:
import numpy as np
import torch
from torch import nn
import torchvision.transforms as transform
import braindecode 
from braindecode.models import *
from braindecode.models.modules import Expression
from braindecode.models.functions import squeeze_final_output
from braindecode.datasets import BaseDataset, BaseConcatDataset,create_from_X_y
from braindecode.models.util import to_dense_prediction_model, get_output_shape
import pandas as pd
import resampy
from skorch.callbacks import Checkpoint
from skorch.helper import predefined_split
from config import *
from dataset import *
from braindecode.preprocessing import create_fixed_length_windows
from mne import set_log_level
set_log_level(False)

In [None]:
# Per-recording preprocessing pipeline. Every entry takes (data, fs) and
# returns (data, fs); the list is handed to DiagnosisSet below.
# NOTE(review): all slicing/resampling happens along axis 1, presumably the
# time axis of a (channels, samples) array — confirm against DiagnosisSet.
preproc_functions = []
# Trim `sec_to_cut` seconds from both ends. The end index is computed
# explicitly: a slice end of `-int(...)` becomes `-0 == 0` when sec_to_cut is
# 0 and would silently return an empty array.
preproc_functions.append(
    lambda data, fs: (
        data[:, int(sec_to_cut * fs):data.shape[1] - int(sec_to_cut * fs)],
        fs))
# Keep at most `duration_recording_mins` minutes of what is left.
preproc_functions.append(
    lambda data, fs: (data[:, :int(duration_recording_mins * 60 * fs)], fs))
# Optional amplitude clipping to [-max_abs_val, max_abs_val].
if max_abs_val is not None:
    preproc_functions.append(
        lambda data, fs: (np.clip(data, -max_abs_val, max_abs_val), fs))
# Resample to the configured rate; from here on fs == sampling_freq.
preproc_functions.append(
    lambda data, fs: (
        resampy.resample(data, fs, sampling_freq, axis=1, filter='kaiser_fast'),
        sampling_freq))
# Optional global rescaling of the signal.
if divisor is not None:
    preproc_functions.append(lambda data, fs: (data / divisor, fs))

# Training recordings.
dataset = DiagnosisSet(n_recordings=n_recordings,
                       max_recording_mins=max_recording_mins,
                       preproc_functions=preproc_functions,
                       data_folders=data_folders,
                       train_or_eval='train',
                       sensor_types=sensor_types)
# Evaluation recordings are only prepared when the run targets the eval split.
if test_on_eval:
    test_dataset = DiagnosisSet(n_recordings=n_recordings,
                                max_recording_mins=max_recording_mins,
                                preproc_functions=preproc_functions,
                                data_folders=data_folders,
                                train_or_eval='eval',
                                sensor_types=sensor_types)
# The list has already been passed to the dataset constructors; drop the
# notebook-level name.
del preproc_functions

In [3]:
# Load the training recordings and their labels via DiagnosisSet.load().
X,y=dataset.load()
# The evaluation recordings exist only when test_on_eval is set (see the cell above).
if test_on_eval:
    test_x,test_y=test_dataset.load()

In [4]:
def create_set(X, y, inds):
    """Select the entries of ``X`` and ``y`` at the positions in ``inds``.

    :param X: indexable sequence (a list of recordings in this notebook).
    :param y: numpy array of labels, indexed with ``inds`` in one step.
    :param inds: iterable of integer positions (a numpy index array here).
    :return: tuple ``(selected_X, selected_y)`` where ``selected_X`` is a new
        list and ``selected_y`` is ``y[inds]``.
    """
    selected_X = [X[idx] for idx in inds]
    selected_y = y[inds]
    return selected_X, selected_y
#Use of TrainValidTestSplitter is not necessary in newer versions of braindecode
class TrainValidSplitter(object):
    def __init__(self, n_folds, i_valid_fold, shuffle):
        self.n_folds = n_folds
        self.i_valid_fold = i_valid_fold
        self.rng = np.random.RandomState(39483948)
        self.shuffle = shuffle

    def split(self, X, y):
        if len(X) < self.n_folds:
            raise ValueError("Less Trials: {:d} than folds: {:d}".format(
                len(X), self.n_folds
            ))
        indices=np.arange(len(y))
        #Compared to paper, the valid set will be unbalanced
        batch_size=len(X)//self.n_folds
        if self.shuffle:
            self.rng.shuffle(indices)
        valid_inds=indices[self.i_valid_fold*batch_size:(self.i_valid_fold+1)*batch_size]
        train_inds = np.setdiff1d(indices,valid_inds)
        train_set = create_set(X, y, train_inds)
        valid_set = create_set(X, y, valid_inds)
        return train_set, valid_set

In [5]:
# Without a separate evaluation set, hold out one fold of the training
# recordings for validation. The flag is tested by truthiness, like the cells
# that build and load the datasets, so any falsy value takes this branch.
# Note that the held-out fold index comes from the config value `i_test_fold`.
if not test_on_eval:
    splitter = TrainValidSplitter(n_folds, i_test_fold, True)
    # Rebind X, y to the training part; valid_X, valid_y hold the held-out fold.
    (X, y), (valid_X, valid_y) = splitter.split(X, y)

In [6]:
ch_names = ['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1', 'FP2',
            'FZ', 'O1', 'O2', 'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
# Window stride in samples: one window starts every 50 seconds of signal,
# because a 1-second stride takes too long to train on.
# NOTE(review): an earlier note described this as a 20-second stride while the
# code uses 50 — confirm which value is intended.
stride = sampling_freq * 50
# Windowing options shared by every split so they cannot drift apart.
window_kwargs = dict(sfreq=sampling_freq,
                     drop_last_window=True,
                     ch_names=ch_names,
                     window_size_samples=input_time_length,
                     window_stride_samples=stride)
train_set = create_from_X_y(X, y, **window_kwargs)
# Exactly one of test_set / valid_set is built, selected by the flag's
# truthiness (consistent with how the datasets were created and loaded).
if test_on_eval:
    test_set = create_from_X_y(test_x, test_y, **window_kwargs)
    del test_x, test_y
else:
    valid_set = create_from_X_y(valid_X, valid_y, **window_kwargs)
    del valid_X, valid_y
# The raw arrays are no longer needed once the windowed datasets exist.
del ch_names, window_kwargs, X, y

In [7]:
n_classes = 2
if model_name=="shallow":
    optimizer_lr = 0.0000625
    optimizer_weight_decay = 0
    #The final conv length is auto to ensure that output will give two values for single EEG window
    model = ShallowFBCSPNet(n_chans,
                            n_classes,
                            n_filters_time=n_start_chans,
                            n_filters_spat=n_start_chans,
                            input_window_samples=input_time_length,
                            final_conv_length='auto',)
elif model_name=="hybrid":
    optimizer_lr = init_lr
    optimizer_weight_decay = 0
    model = HybridNet(n_chans, n_classes, input_window_samples=input_time_length,)
    # Probe the network with a dummy batch shaped like the real windows
    # (n_chans channels, input_time_length samples) to find how many time
    # steps reach the final convolution.
    probe = torch.ones(size=(2, n_chans, input_time_length))
    with torch.no_grad():
        out_length = model(probe).shape[2]
    # Replace the final convolution with one whose kernel spans all remaining
    # time steps, so each window yields a single prediction per class.
    model.final_conv = nn.Conv2d(model.final_conv.in_channels, n_classes,
                                 (out_length, 1), bias=True,)
    # NOTE(review): torch.squeeze drops every size-1 dimension, so a batch
    # containing a single window would lose its batch dimension as well.
    model = nn.Sequential(model, Expression(torch.squeeze))
    with torch.no_grad():
        print(model(probe).shape)
    # The probe only exists in this branch, so it is cleaned up here; an
    # unconditional `del` would raise NameError for the shallow model.
    del probe
else:
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected 'shallow' or 'hybrid'")
if cuda:
    model.cuda()

torch.Size([2, 2])


In [8]:
device = 'cuda' if cuda else 'cpu'


def monitor(net):
    """Return True when the last epoch set a new best validation accuracy or F1."""
    return any(net.history[-1, ('valid_accuracy_best', 'valid_f1_best')])


# Saves parameters, optimizer state and history under model/ whenever
# `monitor` fires.
cp = Checkpoint(monitor=monitor, dirname='model',
                f_params=f'{model_name}best_param.pkl',
                f_optimizer=f'{model_name}best_opt.pkl',
                f_history=f'{model_name}best_history.json')

# The per-epoch "valid_*" metrics are computed on the evaluation windows when
# test_on_eval is set, otherwise on the held-out validation fold.
# NOTE(review): with test_on_eval the checkpoint above is therefore selected
# on the test windows — confirm this is intended.
eval_split = test_set if test_on_eval else valid_set
classifier = braindecode.EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(eval_split),
    optimizer__lr=optimizer_lr,
    optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    callbacks=["accuracy", "f1", cp],
    # Keep existing weights across repeated fit() calls so training can resume.
    warm_start=True,
    )
classifier.initialize()
# The classifier holds the module and the split from here on.
del model, eval_split

In [9]:
# Runs on the eval split use a separate set of checkpoint files (suffix "II").
path = f'{model_name}II' if test_on_eval else f'{model_name}'
try:
    classifier.load_params(
        f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
    print("Paramters Loaded")
except FileNotFoundError:
    # A missing file just means there is nothing to resume from. Any other
    # failure (corrupt file, architecture mismatch) is left to propagate:
    # silently continuing would train from fresh weights and then overwrite
    # the saved files in the save cell below.
    print(f"No saved parameters found at model/{path}_*; continuing with the current weights")

#weights=torch.load(path)
#classifier.module_.load_state_dict(weights["model"])
#classifier.optimizer_.load_state_dict(weights["optimizer"])

Paramters Loaded


In [None]:
# Display the classifier's training history (the per-epoch records that
# load_params restores from the history file and fit() appends to).
classifier.history_

In [10]:
# Train for one epoch on the windowed training set. y=None because train_set
# was built from X and y together. The classifier was created with
# warm_start=True, so re-running this cell continues from the current weights.
classifier.fit(train_set,y=None,epochs=1)

  epoch    train_accuracy    train_f1    train_loss    valid_accuracy    valid_f1    valid_loss    cp         dur
-------  ----------------  ----------  ------------  ----------------  ----------  ------------  ----  ----------
     16            0.8757      0.6476        8.8817            0.6456      0.5209        7.1183        10708.8133


<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=Sequential(
    (0): HybridNet(
      (reduced_deep_model): Sequential(
        (ensuredims): Ensure4d()
        (dimshuffle): Expression(expression=transpose_time_to_spat) 
        (conv_time): Conv2d(1, 20, kernel_size=(10, 1), stride=(1, 1))
        (conv_spat): Conv2d(20, 30, kernel_size=(1, 21), stride=(1, 1), bias=False)
        (bnorm): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv_nonlin): Expression(expression=elu) 
        (pool): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=0, dilation=(1, 1), ceil_mode=False)
        (pool_nonlin): Expression(expression=identity) 
        (drop_2): Dropout(p=0.5, inplace=False)
        (conv_2): Conv2d(30, 40, kernel_size=(10, 1), stride=(1, 1), dilation=(3, 1), bias=False)
        (bnorm_2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (nonlin_2): Expression(expression=elu

In [11]:
# Persist module parameters, optimizer state and training history to the same
# model/{path}_* files that the load cell reads, so a later session can resume.
classifier.save_params(
    f_params=f'model/{path}_param.pkl', f_optimizer=f'model/{path}_opt.pkl', f_history=f'model/{path}_history.json')
#torch.save({"model":classifier.module_.state_dict(),"optimizer":classifier.optimizer_.state_dict()}, path)

In [None]:
# Window-level accuracy and F1 on whichever held-out set this run uses.
eval_set = test_set if test_on_eval else valid_set
pred_labels = classifier.predict(eval_set)
# Each dataset item is indexed at position 1 for its label.
actual_labels = np.array([window[1] for window in eval_set])
del eval_set
accuracy = np.mean(pred_labels == actual_labels)
print(f"Accuracy:{accuracy}")
# Labels are 0/1, so the element-wise product counts true positives.
tp = np.sum(pred_labels * actual_labels)
n_predicted_positive = np.sum(pred_labels)
n_actual_positive = np.sum(actual_labels)
# Guard every ratio: with no positive predictions (or no positive labels) the
# plain divisions would produce NaN; such cases are reported as 0 instead.
precision = tp / n_predicted_positive if n_predicted_positive > 0 else 0.0
recall = tp / n_actual_positive if n_actual_positive > 0 else 0.0
if precision + recall > 0:
    f1 = 2 * precision * recall / (precision + recall)
else:
    f1 = 0.0
print(f"F1-Score:{f1}")

In [14]:
#Test the model on proper test set according to paper
if test_on_eval:
    # Drop the coarse-stride datasets before building the dense one. Only a
    # missing name is tolerated (e.g. when this cell is re-run); anything else
    # should surface instead of being swallowed.
    try:
        del train_set,test_set
    except NameError:
        pass
    # The arrays were deleted after the first windowing step, so load them again.
    test_x,test_y=test_dataset.load()
    ch_names=['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1','FP2', 'FZ', 'O1', 'O2','P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
    #Stride between windows is set to sampling frequency as written in paper
    # i.e. consecutive windows start one second apart.
    test_set=create_from_X_y(test_x,test_y,sfreq=sampling_freq,drop_last_window=True,ch_names=ch_names,window_size_samples=input_time_length,
                        window_stride_samples=sampling_freq)
    del test_x,test_y

In [15]:
# Window-level accuracy and F1 on the densely windowed test set.
if test_on_eval:
    pred_labels = classifier.predict(test_set)
    # Each dataset item is indexed at position 1 for its label.
    actual_labels = np.array([window[1] for window in test_set])
    accuracy = np.mean(pred_labels == actual_labels)
    print(f"Accuracy:{accuracy}")
    # Labels are 0/1, so the element-wise product counts true positives.
    tp = np.sum(pred_labels * actual_labels)
    n_predicted_positive = np.sum(pred_labels)
    n_actual_positive = np.sum(actual_labels)
    # Guard every ratio: with no positive predictions (or no positive labels)
    # the plain divisions would produce NaN; such cases are reported as 0.
    precision = tp / n_predicted_positive if n_predicted_positive > 0 else 0.0
    recall = tp / n_actual_positive if n_actual_positive > 0 else 0.0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    print(f"F1-Score:{f1}")

Accuracy:0.6644824393913094
F1-Score:0.557845174326617
