In [None]:
import time
import os
import mne
mne.set_log_level('ERROR')

from warnings import filterwarnings
filterwarnings('ignore')


from IPython.utils import io

import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


import torch
from torch.nn.functional import relu
from torch.utils.data import WeightedRandomSampler

from braindecode import EEGClassifier
from braindecode.training.losses import CroppedLoss
from braindecode.models import Deep4Net,ShallowFBCSPNet,EEGNetv4, TCN
from braindecode.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model, get_output_shape

from braindecode.datautil.windowers import create_fixed_length_windows
from braindecode.datautil.serialization import  load_concat_dataset

from braindecode.datasets import BaseConcatDataset
from braindecode.datautil.preprocess import preprocess, Preprocessor, exponential_moving_standardize


from braindecode.training import trial_preds_from_window_preds


from functools import partial 
from skorch.callbacks import LRScheduler, EarlyStopping,Checkpoint, EpochScoring
from skorch.helper import predefined_split

# Hyperparameters

In [None]:
# model-specific
# Architecture to train: one of 'BD-Deep4', 'BD-Shallow', 'BD-TCN', 'BD-EEGNet'
model_name = 'BD-Deep4'
assert model_name in {'BD-Deep4', 'BD-Shallow', 'BD-TCN', 'BD-EEGNet'}, model_name
drop_prob = 0.5
batch_size = 64
lr = 0.01
n_epochs = 35
weight_decay = 0.0005

# for all models
result_folder = '/home/results/TUAB/'
train_folder = '/home/data/preprocessed_TUAB/final_train/'
eval_folder = '/home/data/preprocessed_TUAB/final_eval/'
task_name = 'train_complete_set'
# Optional subset of training recording ids to load (None = no subset)
ids_to_load_train = None

seed = 20110407

# Binary task: non-pathological vs. pathological
n_classes = 2
# Number of EEG channels kept by the preprocessing pipeline (21 short channel names)
n_chans = 21

# Window length in samples (6000 samples = 60 s at the 100 Hz resampling rate)
input_window_samples = 6000

## Set random seeds for reproducibility

In [None]:
# Move models to the GPU in the model-definition cells.
cuda = True
# Seed the random number generators (via braindecode's helper) so that runs can
# be repeated; `set_random_seeds` is imported in the first cell.
set_random_seeds(seed=seed, cuda=cuda)

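# Model definition

Each of the following cells builds one architecture and assigns it to `model`; only the architecture named in `model_name` should be the one that ends up trained.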

In [None]:
import torch as th

# Let cuDNN benchmark and cache the fastest convolution algorithms for the fixed
# input window size. NOTE(review): benchmark mode can make GPU results
# non-deterministic, which works against the seeding above — disable it if
# bit-exact reruns matter more than speed.
cudnn_backend = th.backends.cudnn
cudnn_backend.benchmark = True

In [None]:
## BD-Deep4
# Every model-definition cell assigns `model`; the guard ensures only the
# architecture selected by `model_name` is built when running all cells.
if model_name == 'BD-Deep4':
    n_start_chans = 25
    final_conv_length = 1
    n_chan_factor = 2
    stride_before_pool = True
    # Filter counts grow by n_chan_factor per conv block: 25, 50, 100, 200.
    model = Deep4Net(
        n_chans, n_classes,
        n_filters_time=n_start_chans,
        n_filters_spat=n_start_chans,
        input_window_samples=input_window_samples,
        n_filters_2=int(n_start_chans * n_chan_factor),
        n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
        n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
        final_conv_length=final_conv_length,
        stride_before_pool=stride_before_pool,
        drop_prob=drop_prob)

    # Send model to GPU
    if cuda:
        model.cuda()
    # Convert to a dense-prediction model for cropped decoding
    # (the classifier cell uses `cropped=True`).
    to_dense_prediction_model(model)



In [None]:
## BD-Shallow
# Only built when selected, so "Run All" does not overwrite the chosen `model`.
if model_name == 'BD-Shallow':
    n_start_chans = 40
    final_conv_length = 25

    model = ShallowFBCSPNet(n_chans, n_classes,
                            input_window_samples=input_window_samples,
                            n_filters_time=n_start_chans,
                            n_filters_spat=n_start_chans,
                            final_conv_length=final_conv_length,
                            drop_prob=drop_prob)
    # Send model to GPU
    if cuda:
        model.cuda()

    # Convert to a dense-prediction model for cropped decoding.
    to_dense_prediction_model(model)

    # Number of predictions the dense model emits per input window (output axis 2).
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]


In [None]:
#BD-TCN
# Only built when selected, so "Run All" does not overwrite the chosen `model`.
if model_name == 'BD-TCN':
    n_chan_factor = 2
    stride_before_pool = True
    # NOTE(review): `l2_decay` and `gradient_clip` are not passed to the
    # classifier cell (it uses `weight_decay` and no gradient clipping);
    # wire them in there if they are meant to apply to TCN.
    l2_decay = 1.7491630095065614e-08
    gradient_clip = 0.25

    model = TCN(
        n_in_chans=n_chans, n_outputs=n_classes,
        n_filters=55,
        n_blocks=5,
        kernel_size=16,
        drop_prob=drop_prob,
        add_log_softmax=True)

    # Send model to GPU
    if cuda:
        model.cuda()

    # NOTE: unlike the other model cells there is no `to_dense_prediction_model`
    # call here — presumably TCN already emits one prediction per time step; confirm.
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]

In [None]:
#BD-EEGNet
# Only built when selected, so "Run All" does not overwrite the chosen `model`.
if model_name == 'BD-EEGNet':
    final_conv_length = 18
    model = EEGNetv4(
        n_chans, n_classes,
        input_window_samples=input_window_samples,
        final_conv_length=final_conv_length,
        drop_prob=drop_prob)
    # Send model to GPU
    if cuda:
        model.cuda()

    # Convert to a dense-prediction model for cropped decoding.
    to_dense_prediction_model(model)

    # Number of predictions the dense model emits per input window (output axis 2).
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]

## Data Loading

In [None]:
%%time 
from braindecode.datasets.tuh import TUHAbnormal
data_path = '/data/datasets/TUH/EEG/tuh_eeg_abnormal/v2.0.0/edf/'
dataset = TUHAbnormal(
    path=data_path,
    recording_ids=None,  # loads the n chronologically first recordings
    target_name=target_name,  # age, gender, pathology
    preload=False,
    add_physician_reports=False,
)

In [None]:


# Optionally keep only the first `n_recordings_to_load` recordings (e.g. for a
# quick debugging run); None keeps every recording because `[:None]` is a full slice.
n_recordings_to_load = None
dataset = BaseConcatDataset(dataset.datasets[:n_recordings_to_load])



In [None]:
%%time
from braindecode.preprocessing import preprocess, Preprocessor, scale as multiply
import numpy as np
from copy import deepcopy


whole_train_set = dataset.split('train')['True']
whole_eval_set = dataset.split('train')['False']

short_ch_names = sorted([
                'A1', 'A2', 'C3', 'C4', 'Cz', 'F3', 'F4', 'F7', 'F8',
                'Fp1', 'Fp2', 'Fz', 'O1', 'O2', 'P3', 'P4', 'Pz', 'T3',
                 'T4', 'T5', 'T6'
            ])
ar_ch_names = sorted([
    'EEG A1-REF', 'EEG A2-REF',
    'EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG C3-REF',
    'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF', 'EEG O1-REF', 'EEG O2-REF',
    'EEG F7-REF', 'EEG F8-REF', 'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF',
    'EEG T6-REF', 'EEG FZ-REF', 'EEG CZ-REF', 'EEG PZ-REF'])
le_ch_names = sorted([
    'EEG A1-LE', 'EEG A2-LE',
    'EEG FP1-LE', 'EEG FP2-LE', 'EEG F3-LE', 'EEG F4-LE', 'EEG C3-LE',
    'EEG C4-LE', 'EEG P3-LE', 'EEG P4-LE', 'EEG O1-LE', 'EEG O2-LE',
    'EEG F7-LE', 'EEG F8-LE', 'EEG T3-LE', 'EEG T4-LE', 'EEG T5-LE',
    'EEG T6-LE', 'EEG FZ-LE', 'EEG CZ-LE', 'EEG PZ-LE'])
assert len(short_ch_names) == len(ar_ch_names) == len(le_ch_names)
ar_ch_mapping = {ch_name: short_ch_name for ch_name, short_ch_name in zip(
    ar_ch_names, short_ch_names)}
le_ch_mapping = {ch_name: short_ch_name for ch_name, short_ch_name in zip(
    le_ch_names, short_ch_names)}
ch_mapping = {'ar': ar_ch_mapping, 'le': le_ch_mapping}



def custom_rename_channels(raw, mapping):
    """Rename channels with the mapping that matches the recording's reference.

    The reference suffix is read from the first channel name
    (``EEG 01-LE`` -> ``le``, ``EEG 01-REF`` -> ``ref``). It selects
    ``mapping['le']`` or ``mapping['ar']``, which ``raw.rename_channels`` then
    applies. A per-reference mapping is needed because mne fails when the
    mapping has keys that are not channels of ``raw``.

    Parameters
    ----------
    raw : object with ``ch_names`` and ``rename_channels`` (presumably mne Raw)
        Recording renamed in place.
    mapping : dict
        ``{'ar': {...}, 'le': {...}}`` old-name -> new-name dicts.

    Raises
    ------
    AssertionError
        If the suffix is neither ``le`` nor ``ref`` (case-insensitive).
    """
    suffix = raw.ch_names[0].rsplit('-', 1)[-1].lower()
    assert suffix in ['le', 'ref'], 'unexpected referencing'
    montage_key = {'le': 'le', 'ref': 'ar'}[suffix]
    raw.rename_channels(mapping[montage_key])


def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):
    """Crop ``raw`` to ``[tmin, tmax]`` seconds without failing on short recordings.

    By default mne fails if ``tmax`` exceeds the recording duration, so
    ``tmax`` is clamped to the time of the last sample. The cropped recording
    can therefore be shorter than ``tmax - tmin``.

    Parameters
    ----------
    raw : object with ``n_times``, ``info['sfreq']`` and ``crop`` (presumably mne Raw)
        Recording cropped in place.
    tmin : float
        Start of the kept segment in seconds.
    tmax : float | None
        End of the kept segment in seconds. ``None`` keeps everything up to the
        last sample. Previously ``None`` raised a TypeError from ``min``.
    include_tmax : bool
        Forwarded to ``raw.crop``.
    """
    last_sample_time = (raw.n_times - 1) / raw.info['sfreq']
    tmax = last_sample_time if tmax is None else min(last_sample_time, tmax)
    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)


n_max_minutes = 21
tmin = 1 * 60  # drop the first minute of every recording
tmax = n_max_minutes * 60  # keep data at most up to minute 21
sfreq = 100  # sampling rate (Hz) after resampling

preprocessors = [
    # Crop to [tmin, tmax) seconds; shorter recordings are kept up to their end.
    Preprocessor(custom_crop, tmin=tmin, tmax=tmax, include_tmax=False,
                 apply_on_array=False),

    # Map reference-specific names to short names, then keep exactly those
    # channels in a fixed order.
    Preprocessor(custom_rename_channels, mapping=ch_mapping,
                 apply_on_array=False),
    Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),

    # Scale by 1e6 (presumably V -> uV; TODO confirm raw units), then clip to +-800.
    Preprocessor(multiply, factor=1e6, apply_on_array=True),
    Preprocessor(np.clip, a_min=-800, a_max=800, apply_on_array=True),

    Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),

    Preprocessor('resample', sfreq=sfreq),
    Preprocessor('set_meas_date', meas_date=None),
]

# `preprocess` modifies the datasets in place, so each dataset must go through
# the pipeline exactly once. Running both variants (in-memory, then save) would
# crop, rescale and clip the training data twice.
save_preprocessed = True
if save_preprocessed:
    # Preprocess and write to disk so later sessions can use `load_concat_dataset`.
    # These directories must match `train_folder` / `eval_folder`.
    for split_ds, split_save_dir in (
            (whole_train_set, '/home/data/preprocessed_TUAB/final_train/'),
            (whole_eval_set, '/home/data/preprocessed_TUAB/final_eval/')):
        preprocess(
            concat_ds=split_ds,
            preprocessors=preprocessors,
            n_jobs=4,
            save_dir=split_save_dir,
        )
else:
    # In-memory only; evaluation data needs the same pipeline as training data.
    preprocess(whole_train_set, preprocessors)
    preprocess(whole_eval_set, preprocessors)


## Or: load an already preprocessed dataset from disk

In [None]:
# Load the preprocessed recordings lazily (preload=False).
# `ids_to_load_train` optionally restricts the *training* recordings. The
# evaluation set is always loaded in full, so re-running this cell after the
# subset cell does not silently shrink the test set.
whole_train_set = load_concat_dataset(train_folder, preload=False, ids_to_load=ids_to_load_train)

whole_eval_set = load_concat_dataset(eval_folder, preload=False, ids_to_load=None)

## Optional: load a smaller training subset

In [None]:
# Size of the random training subset; an index file for it must exist in
# ./indices/TUAB-Random/.
subset_size = 100
subset_index_file = (
    f"./indices/TUAB-Random/indices_seed0_TUAB-Random_trainsize_{subset_size}.pkl")

# NOTE: pickle.load can execute arbitrary code; only load trusted, locally
# generated index files.
with open(subset_index_file, 'rb') as f:
    ids_to_load_train = pickle.load(f)

task_name = 'TUAB_subset_' + str(subset_size)

whole_train_set = load_concat_dataset(train_folder, preload=False, ids_to_load=ids_to_load_train)

# Window creation

In [None]:
# How many predictions the selected model emits per input window (axis 2 of its
# output). The window cell uses it as the stride between consecutive windows.
model_output_shape = get_output_shape(model, n_chans, input_window_samples)
n_preds_per_input = model_output_shape[2]

In [None]:
# Windows are `input_window_samples` long and advance by `n_preds_per_input`
# samples, i.e. by the number of predictions the model makes per window.
common_window_kwargs = dict(
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=input_window_samples,
    window_stride_samples=n_preds_per_input,
)

# Training windows are preloaded into memory; the trailing window is dropped.
window_train_set = create_fixed_length_windows(
    whole_train_set,
    preload=True,
    drop_last_window=True,
    **common_window_kwargs,
)

# Evaluation windows are loaded lazily; the trailing window is kept.
window_eval_set = create_fixed_length_windows(
    whole_eval_set,
    preload=False,
    drop_last_window=False,
    **common_window_kwargs,
)


In [None]:
## Classifier definition and run training

# Cropped decoding: CroppedLoss applies nll_loss to the dense per-window outputs.
clf = EEGClassifier(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.nll_loss,
    optimizer=torch.optim.AdamW,
    # NOTE(review): the held-out eval set doubles as the per-epoch validation
    # set. It is only monitored here (no checkpoint/early-stopping callback uses it).
    train_split=predefined_split(window_eval_set),
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    # Follow the `cuda` flag instead of hard-coding the GPU.
    device='cuda' if cuda else 'cpu',
)

clf.fit(window_train_set, y=None, epochs=n_epochs)

In [None]:
## save trained classifier

In [None]:
# History columns produced by the "accuracy" callback and the predefined
# validation split.
results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy']
# Prefix of every result file. It is computed here because `task_name` may be
# reassigned by the subset cell.
result_path = os.path.join(result_folder, f'{model_name}_{task_name}_')
os.makedirs(result_folder, exist_ok=True)

history_df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns,
                          index=clf.history[:, 'epoch'])
history_df.to_pickle(result_path + 'df_history.pkl')
# Save the raw skorch history (torch pickle, hence .pt rather than .py).
torch.save(clf.history, result_path + 'clf_history.pt')

# Whole module and its state_dict, tagged with the seed.
torch.save(clf.module, result_path + f'model_{seed}.pt')
torch.save(clf.module.state_dict(), result_path + f'state_dict_{seed}.pt')

clf.save_params(f_params=result_path + 'model.pkl',
                f_optimizer=result_path + 'opt.pkl',
                f_history=result_path + 'history.json')

In [None]:
## evaluate performance test set and save prediction results

from sklearn.metrics import classification_report, confusion_matrix

# Window-level predictions plus the indices needed to map windows back to recordings.
pred_win = clf.predict_with_window_inds_and_ys(window_eval_set)

# Regroup window predictions per recording ("trial"). Averaging over axis 1
# gives one score per class per recording; axis 1 is presumably the prediction
# axis of each (n_classes, n_preds) array — confirm.
preds_per_trial = trial_preds_from_window_preds(
    pred_win['preds'], pred_win['i_window_in_trials'], pred_win['i_window_stops'])
mean_preds_per_trial = np.array(
    [np.mean(preds, axis=1) for preds in preds_per_trial])

y = window_eval_set.description['pathological']
assert len(y) == len(mean_preds_per_trial)

column0, column1 = "non-pathological", "pathological"
a_dict = {column0: mean_preds_per_trial[:, 0],
          column1: mean_preds_per_trial[:, 1],
          "true_pathological": y}

# Store per-recording predictions, tagged with the seed like the saved model
# files (the previous `model_number` was never defined).
pd.DataFrame.from_dict(a_dict).to_csv(
    result_path + "predictions_eval_" + str(seed) + ".csv")

# Predict "pathological" (1) when its mean score is at least the non-pathological one.
deep_preds = mean_preds_per_trial[:, 0] <= mean_preds_per_trial[:, 1]
class_label = y
class_preds = deep_preds.astype(int)

confusion_mat = confusion_matrix(class_label, class_preds)
print(classification_report(class_label, class_preds))
confusion_mat