In [1]:
import numpy as np
import torch
from torch import nn
import torchvision.transforms as transform
import braindecode 
from braindecode.models import ShallowFBCSPNet
from braindecode.datasets import BaseDataset, BaseConcatDataset,create_from_X_y
from braindecode.models.util import to_dense_prediction_model, get_output_shape
import pandas as pd
import resampy
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from config import *
from dataset import *
from braindecode.preprocessing import create_fixed_length_windows
from mne import set_log_level
# Quiet MNE's logger. Per MNE's documentation, passing False selects the
# 'WARNING' level, so INFO-level messages are hidden.
set_log_level(False)

Tensorflow not install, you could not use those pipelines


In [2]:
# Preprocessing pipeline applied to every raw recording, in order. Each step
# maps (data, fs) -> (data, fs). Time is indexed on axis 1 (data[:, samples]);
# presumably data is channels x samples -- TODO confirm against DiagnosisSet.
def _cut_recording_edges(data, fs):
    """Drop `sec_to_cut` seconds from the start AND the end of the recording.

    The end index is computed explicitly. The former `data[:, n:-n]` slice
    silently returned an EMPTY array when `sec_to_cut == 0` (since -0 == 0).
    """
    n_cut = int(sec_to_cut * fs)
    return data[:, n_cut:data.shape[1] - n_cut], fs


def _truncate_recording(data, fs):
    """Keep at most `duration_recording_mins` minutes of signal."""
    return data[:, :int(duration_recording_mins * 60 * fs)], fs


def _clip_amplitude(data, fs):
    """Clip values to [-max_abs_val, max_abs_val]."""
    return np.clip(data, -max_abs_val, max_abs_val), fs


def _resample_recording(data, fs):
    """Resample along the time axis (axis 1) to the configured `sampling_freq`."""
    return (resampy.resample(data, fs, sampling_freq, axis=1, filter='kaiser_fast'),
            sampling_freq)


def _scale_recording(data, fs):
    """Divide the signal by the configured `divisor`."""
    return data / divisor, fs


preproc_functions = [_cut_recording_edges, _truncate_recording]
if max_abs_val is not None:
    preproc_functions.append(_clip_amplitude)
preproc_functions.append(_resample_recording)
if divisor is not None:
    preproc_functions.append(_scale_recording)


def _make_diagnosis_set(train_or_eval, preproc_functions):
    """Build a DiagnosisSet for the 'train' or 'eval' split with the shared config."""
    return DiagnosisSet(n_recordings=n_recordings,
                        max_recording_mins=max_recording_mins,
                        preproc_functions=preproc_functions,
                        data_folders=data_folders,
                        train_or_eval=train_or_eval,
                        sensor_types=sensor_types)


dataset = _make_diagnosis_set('train', preproc_functions)
if test_on_eval:
    test_dataset = _make_diagnosis_set('eval', preproc_functions)
del preproc_functions

In [3]:
# Materialise the training split: recordings (X) and their labels (y),
# presumably with preproc_functions already applied by DiagnosisSet.
# The held-out 'eval' split is only loaded when it is used for testing.
(X, y) = dataset.load()
if test_on_eval:
    (test_x, test_y) = test_dataset.load()

In [4]:
def create_set(X, y, inds):
    """Select a subset of trials by index.

    :param X: indexable collection of trials (e.g. a list of arrays). A new
        list is built from it, so X itself need not support fancy indexing.
    :param y: numpy array of labels, indexed directly with ``inds``.
    :param inds: integer indices to keep (applied to both X and y).
    :return: tuple ``(selected_trials_list, selected_labels)``.
    """
    subset_X = [X[idx] for idx in inds]
    return subset_X, y[inds]
# Hand-rolled single train/valid split; newer braindecode versions make a
# helper like this unnecessary.
class TrainValidSplitter(object):
    """Split trials into a training set and one validation fold.

    Trial indices are optionally shuffled with a fixed-seed RNG and cut into
    ``n_folds`` contiguous folds of ``len(X) // n_folds`` trials. Fold
    ``i_valid_fold`` is used for validation. Every other trial, including the
    ``len(X) % n_folds`` leftovers, goes to training. Unlike the paper, the
    validation fold is not class-balanced.
    """

    def __init__(self, n_folds, i_valid_fold, shuffle):
        """
        :param n_folds: number of folds the trials are divided into.
        :param i_valid_fold: 0-based index of the fold used for validation.
        :param shuffle: whether to shuffle the trial indices before folding.
        """
        self.n_folds = n_folds
        self.i_valid_fold = i_valid_fold
        # Fixed seed so a fresh splitter always yields the same split.
        self.rng = np.random.RandomState(39483948)
        self.shuffle = shuffle

    def split(self, X, y):
        """Return ``(train_set, valid_set)``, each an ``(X_list, y_array)`` pair.

        :param X: indexable collection of trials.
        :param y: numpy array of labels, one per trial in ``X``.
        :raises ValueError: if ``X`` and ``y`` differ in length, there are
            fewer trials than folds, or ``i_valid_fold`` is outside
            ``[0, n_folds)``. Previously an out-of-range fold silently produced
            an empty validation set.

        Note: with ``shuffle=True`` each call consumes the instance's RNG, so
        calling ``split`` twice on the same instance yields different folds.
        """
        if len(X) != len(y):
            raise ValueError("X has {:d} trials but y has {:d} labels".format(
                len(X), len(y)
            ))
        if len(X) < self.n_folds:
            raise ValueError("Less Trials: {:d} than folds: {:d}".format(
                len(X), self.n_folds
            ))
        if not 0 <= self.i_valid_fold < self.n_folds:
            raise ValueError("i_valid_fold {} not in [0, {})".format(
                self.i_valid_fold, self.n_folds
            ))
        indices = np.arange(len(y))
        fold_size = len(X) // self.n_folds
        if self.shuffle:
            self.rng.shuffle(indices)
        start = self.i_valid_fold * fold_size
        valid_inds = indices[start:start + fold_size]
        train_inds = np.setdiff1d(indices, valid_inds)
        train_set = create_set(X, y, train_inds)
        valid_set = create_set(X, y, valid_inds)
        return train_set, valid_set

In [5]:
# Without the held-out 'eval' recordings, carve a validation fold (index
# i_test_fold of n_folds, fixed-seed shuffle) out of the training recordings.
if not test_on_eval:
    splitter = TrainValidSplitter(n_folds, i_test_fold, shuffle=True)
    (X, y), (valid_X, valid_y) = splitter.split(X, y)

In [6]:
# Channel names handed to create_from_X_y; their count must match the number
# of rows in each recording and the model's n_chans input channels.
ch_names = ['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1', 'FP2',
            'FZ', 'O1', 'O2', 'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']
assert len(ch_names) == n_chans, "ch_names must list exactly n_chans channels"
# 20 s stride between window starts: a 1-sample stride is far too slow to
# train on, and a 20 s stride keeps the number of windows manageable.
# int() because window_stride_samples is a sample count.
stride = int(sampling_freq * 20)
window_kwargs = dict(sfreq=sampling_freq,
                     drop_last_window=True,
                     ch_names=ch_names,
                     window_size_samples=input_time_length,
                     window_stride_samples=stride)
train_set = create_from_X_y(X, y, **window_kwargs)
if test_on_eval:
    test_set = create_from_X_y(test_x, test_y, **window_kwargs)
    del test_x, test_y
else:
    valid_set = create_from_X_y(valid_X, valid_y, **window_kwargs)
    del valid_X, valid_y
del ch_names, X, y, window_kwargs

In [7]:
n_classes = 2  # two target classes
# final_conv_length='auto' sizes the classifier conv so that one window of
# input_time_length samples yields a single prediction of n_classes scores.
model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    input_window_samples=input_time_length,
    n_filters_time=n_start_chans,
    n_filters_spat=n_start_chans,
    final_conv_length='auto',
)
if cuda:
    model = model.cuda()  # nn.Module.cuda() moves parameters in place and returns self
# Optimiser hyper-parameters for the shallow net.
# NOTE(review): for any other model_name these names are not set here and
# presumably must come from `config`, otherwise the classifier cell fails.
if model_name == "shallow":
    optimizer_lr, optimizer_weight_decay = 6.25e-5, 0

In [9]:
device = 'cuda' if cuda else 'cpu'
# Data scored after every epoch: the held-out eval windows when test_on_eval,
# otherwise the validation fold split off the training recordings.
# NOTE(review): monitoring on the eval set is fine for reporting, but picking
# a checkpoint by its score there would leak test information.
monitor_set = test_set if test_on_eval else valid_set
classifier = braindecode.EEGClassifier(
    model,
    # The model ends in LogSoftmax, so NLLLoss gives the cross-entropy.
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(monitor_set),
    optimizer__lr=optimizer_lr,
    optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    device=device,
    callbacks=["accuracy", "f1"],
    # Keep the current (possibly loaded) weights when fit() is called again
    # instead of re-initialising the module.
    warm_start=True,
)
classifier.initialize()

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=ShallowFBCSPNet(
    (ensuredims): Ensure4d()
    (dimshuffle): Expression(expression=transpose_time_to_spat) 
    (conv_time): Conv2d(1, 25, kernel_size=(25, 1), stride=(1, 1))
    (conv_spat): Conv2d(25, 25, kernel_size=(1, 21), stride=(1, 1), bias=False)
    (bnorm): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin_exp): Expression(expression=square) 
    (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
    (pool_nonlin_exp): Expression(expression=safe_log) 
    (drop): Dropout(p=0.5, inplace=False)
    (conv_classifier): Conv2d(25, 2, kernel_size=(394, 1), stride=(1, 1))
    (softmax): LogSoftmax(dim=1)
    (squeeze): Expression(expression=squeeze_final_output) 
  ),
)

In [10]:
# Checkpoint to resume from: 'shallowII.pt' for the eval-split run and
# 'shallow.pt' for the train/valid-fold run. `path` is reused when saving.
path = 'model/shallowII.pt' if test_on_eval else 'model/shallow.pt'
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# torch.load unpickles the file, so only load trusted checkpoints.
weights = torch.load(path, map_location=device)
classifier.module_.load_state_dict(weights["model"])
classifier.optimizer_.load_state_dict(weights["optimizer"])

In [None]:
# Continue training from the loaded weights for one more epoch. Train on the
# TRAINING windows; the monitoring set is already attached via train_split.
# Fitting on test_set trained on the held-out data, and raised NameError
# whenever test_on_eval was False.
classifier.fit(train_set, y=None, epochs=1)

In [None]:
# Persist model and optimiser state together so training can be resumed
# from the same checkpoint path that was loaded above.
torch.save(
    {
        "model": classifier.module_.state_dict(),
        "optimizer": classifier.optimizer_.state_dict(),
    },
    path,
)

In [11]:
# Load only the model weights of the best checkpoint for inference.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): how 'shallowbest.pt' was selected is not shown here; if it was
# chosen by its eval-set score, the test metrics below are optimistic.
weights = torch.load('model/shallowbest.pt', map_location=device)
classifier.module_.load_state_dict(weights["model"])

<All keys matched successfully>

In [12]:
# Evaluate on the held-out eval windows when available, otherwise on the
# validation fold (test_set only exists when test_on_eval is set).
eval_set = test_set if test_on_eval else valid_set
pred_labels = classifier.predict(eval_set)
# Each dataset item is a tuple whose second element is the target label.
actual_labels = np.array([sample[1] for sample in eval_set])

In [13]:
# Window-level metrics with label 1 as the positive class. Counting via
# `== 1` matches the former `pred * actual` products for 0/1 labels and stays
# correct for any other label encoding.
accuracy = np.mean(pred_labels == actual_labels)
print(f"Accuracy:{accuracy}")
tp = np.sum((pred_labels == 1) & (actual_labels == 1))
n_pred_pos = np.sum(pred_labels == 1)
n_actual_pos = np.sum(actual_labels == 1)
# Guard the divisions: with no predicted or actual positives the old code
# produced nan plus a RuntimeWarning; report 0.0 instead.
precision = tp / n_pred_pos if n_pred_pos else 0.0
recall = tp / n_actual_pos if n_actual_pos else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
print(f"F1-Score:{f1}")

Accuracy:0.713443727196183
F1-Score:0.6208689194207203
