# Set-up

## Installing libraries and libcudnn8

In [None]:
## only with GPU
# Environment set-up cell: installs the system cuDNN package and the Python
# dependencies used by the notebook, then lists the working directory.
# NOTE(review): FILEID is not referenced anywhere else in the visible notebook;
# it looks like a leftover from an earlier download step - confirm before removing.
FILEID = "1-bPsREsUCOiJHzIqi8DQrfSjTAf5VAW_"
# Pin cuDNN 8.1.1 built for CUDA 11.2 (held packages are allowed to change).
!apt-get install --allow-change-held-packages libcudnn8=8.1.1.33-1+cuda11.2 -y
# Dataset loaders (gcpds.databases) installed straight from GitHub, unpinned.
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.databases
# NOTE(review): mne and scikeras are installed without version pins, so results
# may not be reproducible across runs on different dates.
!pip install mne
!pip install scikeras[tensorflow]
# List the working directory contents as a sanity check.
!dir

In [None]:
# Clone the MI-EEG-ClassMeth repository; the following cells change into it and
# import MI_EEG_ClassMeth.FeatExtraction.TimeFrequencyRpr from there.
!git clone https://github.com/Marcos-L/MI-EEG-ClassMeth

In [None]:
cd MI-EEG-ClassMeth

In [None]:
from MI_EEG_ClassMeth.FeatExtraction import TimeFrequencyRpr

In [None]:
cd ../

## Import libraries

In [None]:
# gcpds datasets
from gcpds.databases import GIGA_MI_ME
from gcpds.databases.BCI_Competition_IV import Dataset_2a
import gcpds.databases as loaddb

# freq filter 
#from MI_EEG_ClassMeth.FeatExtraction import TimeFrequencyRpr

# general
import numpy as np
from scipy.signal import resample
import pickle
import warnings
import mne
from time import time
warnings.filterwarnings('ignore')

# tensorflow
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Dropout, Conv2D, AveragePooling2D, BatchNormalization, Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.layers import Layer
from tensorflow.keras.regularizers import L1L2

# scikeras
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV,StratifiedShuffleSplit
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score, cohen_kappa_score, roc_auc_score

## Define functions

In [None]:
def kappa(y_true, y_pred):
    """Cohen's kappa for 2-D (one-hot / per-class score) label arrays.

    Both arrays are collapsed to class indices by taking the argmax along
    axis 1 before being handed to ``cohen_kappa_score``.
    """
    true_labels = np.argmax(y_true, axis=1)
    pred_labels = np.argmax(y_pred, axis=1)
    return cohen_kappa_score(true_labels, pred_labels)

## GIGA Science dataset

In [None]:
def load_GIGA(db: GIGA_MI_ME,
              sbj: int,
              fs: float, 
              f_bank: np.ndarray, 
              vwt: np.ndarray, 
              new_fs: float) -> "tuple[np.ndarray, np.ndarray]":
    """Load one GIGA subject, apply the time-frequency transform and resample.

    Parameters
    ----------
    db : GIGA_MI_ME
        Database handle; ``load_subject`` and ``get_data`` are called on it.
    sbj : int
        Subject identifier passed to ``db.load_subject``.
    fs : float
        Original sampling rate, forwarded to ``TimeFrequencyRpr`` and used to
        compute the resampled length.
    f_bank : np.ndarray
        Frequency bands forwarded to ``TimeFrequencyRpr``.
    vwt : np.ndarray
        Time windows forwarded to ``TimeFrequencyRpr``.
    new_fs : float
        Target sampling rate; no resampling is done when it equals ``fs``.

    Returns
    -------
    (X, y) : tuple
        ``X`` is the transformed (and possibly resampled) data, ``y`` the labels
        exactly as returned by ``db.get_data``.
    """
    # Channel order used to rearrange the recordings.
    eeg_ch_names = ['Fp1','Fpz','Fp2',
              'AF7','AF3','AFz','AF4','AF8',
              'F7','F5','F3','F1','Fz','F2','F4','F6','F8',
              'FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6','FT8',
              'T7','C5','C3','C1','Cz','C2','C4','C6','T8',
              'TP7','CP5','CP3','CP1','CPz','CP2','CP4','CP6','TP8',
              'P9','P7','P5','P3','P1','Pz','P2','P4','P6','P8','P10',
              'PO7','PO3','POz','PO4','PO8',
              'O1','Oz','O2',
              'Iz']

    # presumably format_channels_selectors returns 1-based indices, hence the
    # "- 1" before using them to index the array - TODO confirm against gcpds.
    index_eeg_chs = db.format_channels_selectors(channels = eeg_ch_names) - 1

    tf_repr = TimeFrequencyRpr(sfreq = fs, f_bank = f_bank, vwt = vwt)

    db.load_subject(sbj)
    X, y = db.get_data(classes = ['left hand mi', 'right hand mi']) #Load MI classes, all channels {EEG}, reject bad trials, uV
    X = X[:,index_eeg_chs,:] #spatial rearrangement
    # NOTE(review): np.squeeze drops every singleton axis, so a subject with a
    # single trial would also lose its leading axis - confirm this cannot happen.
    X = np.squeeze(tf_repr.transform(X))
    #Resampling
    if new_fs == fs:
        print('No resampling, since new sampling rate same.')
    else:
        print("Resampling from {:f} to {:f} Hz.".format(fs, new_fs))
        # New length keeps the same duration at the target rate (truncated to int).
        X = resample(X, int((X.shape[-1]/fs)*new_fs), axis = -1)
    return X, y

## BCI2a dataset

In [None]:
def load_BCICIV2a(db: Dataset_2a,
               sbj: int,
               fs: float, 
               f_bank: np.ndarray, 
               vwt: np.ndarray, 
               new_fs: float) -> "tuple[np.ndarray, np.ndarray]":
    """Load both sessions of one BCI IV 2a subject, transform, merge and resample.

    Parameters
    ----------
    db : Dataset_2a
        Database handle; ``load_subject`` and ``get_data`` are called on it.
    sbj : int
        Subject identifier passed to ``db.load_subject``.
    fs : float
        Original sampling rate, forwarded to ``TimeFrequencyRpr`` and used to
        compute the resampled length.
    f_bank : np.ndarray
        Frequency bands forwarded to ``TimeFrequencyRpr``.
    vwt : np.ndarray
        Time windows forwarded to ``TimeFrequencyRpr``.
    new_fs : float
        Target sampling rate; no resampling is done when it equals ``fs``.

    Returns
    -------
    (X, y) : tuple
        Training-session trials followed by evaluation-session trials,
        concatenated along axis 0, with the matching labels.
    """
    tf_repr = TimeFrequencyRpr(sfreq = fs, f_bank = f_bank, vwt = vwt)

    def _load_session(mode: str):
        """Load one session ('training' or 'evaluation') and transform it."""
        db.load_subject(sbj, mode = mode)
        X_s, y_s = db.get_data() #Load all classes, all channels {EEG, EOG}, reject bad trials
        X_s = X_s[:,:-3,:] # pick EEG channels (drops the last three channels)
        X_s = X_s*1e6 #uV
        return np.squeeze(tf_repr.transform(X_s)), y_s

    # Same order as before: training session first, then evaluation session.
    X_tr, y_tr = _load_session('training')
    X_ts, y_ts = _load_session('evaluation')
    # merge both training and evaluation sets  
    X = np.concatenate([X_tr, X_ts], axis = 0)
    y = np.concatenate([y_tr, y_ts], axis = 0)    
    # Resampling
    if new_fs == fs:
        print('No resampling, since new sampling rate same.')
    else:
        print("Resampling from {:f} to {:f} Hz.".format(fs, new_fs))
        # New length keeps the same duration at the target rate (truncated to int).
        X = resample(X, int((X.shape[-1]/fs)*new_fs), axis = -1)
    return X, y

## Define the model (Gaussian functional connectivity network)

In [None]:
class GFC(Layer):
    """Gaussian functional connectivity layer.

    For an input laid out as (N, C, T, F) the layer computes, per sample and
    per filter, the pairwise squared Euclidean distances ``D`` between the C
    channel time series and returns the Gaussian kernel matrix

        A = exp(-D / (2 * 10**gammad * sigma**2))        # (N, F, C, C)

    where ``sigma`` is the 50th percentile of the square-rooted strictly
    upper-triangular entries of ``D`` and ``gammad`` is a trainable scalar.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        # Trainable scalar exponent of the bandwidth scale; zero init means
        # the scale starts at 10**0 = 1.
        self.gammad = self.add_weight(name = 'gammad',
                                shape = (),
                                initializer = 'zeros',
                                trainable = True)
        super().build(batch_input_shape)

    def call(self, X): 
        X = tf.transpose(X, perm  = (0, 3, 1, 2)) #(N, F, C, T)
        R = tf.reduce_sum(tf.math.multiply(X, X), axis = -1, keepdims = True) #(N, F, C, 1)
        D  = R - 2*tf.matmul(X, X, transpose_b = True) + tf.transpose(R, perm = (0, 1, 3, 2)) #(N, F, C, C)
        # The expanded form ||x||^2 - 2<x,y> + ||y||^2 can come out slightly
        # negative through floating-point cancellation; a squared distance is
        # never negative, so clamp it.
        D = tf.maximum(D, 0.)

        ones = tf.ones_like(D[0,0,...]) #(C, C)
        mask_a = tf.linalg.band_part(ones, 0, -1) #Upper triangular matrix of 0s and 1s (C, C)
        mask_b = tf.linalg.band_part(ones, 0, 0)  #Diagonal matrix of 0s and 1s (C, C)
        mask = tf.cast(mask_a - mask_b, dtype=tf.bool) #Make a bool mask (C, C)
        triu = tf.expand_dims(tf.boolean_mask(D, mask, axis = 2), axis = -1) #(N, F, C*(C-1)/2, 1)
        # Floor the sqrt argument: sqrt has an unbounded derivative at 0, which
        # turns into NaN gradients, and it also keeps sigma strictly positive
        # so the division below cannot be by zero.
        triu = tf.maximum(triu, tf.keras.backend.epsilon())
        sigma = tfp.stats.percentile(tf.math.sqrt(triu), 50, axis = 2, keepdims = True) #(N, F, 1, 1)

        A = tf.math.exp(-1/(2*tf.pow(10., self.gammad)*tf.math.square(sigma))*D) #(N, F, C, C)
        # boolean_mask loses static shape information; restore it from D.
        A.set_shape(D.shape)
        return A

    def compute_output_shape(self, batch_input_shape):
        # Accept a TensorShape as well as a plain tuple/list.
        N, C, T, F = tf.TensorShape(batch_input_shape).as_list()
        return tf.TensorShape([N, F, C, C])

    def get_config(self):
        # No constructor arguments beyond the base Layer ones.
        base_config = super().get_config()
        return {**base_config}


class get_triu(Layer):
    """Keep only the strictly upper-triangular entries of each (C, C) matrix.

    Input is expected as (N, F, C, C); output is (N, F, C*(C-1)/2, 1).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        # Nothing to create: the layer has no weights.
        super().build(batch_input_shape)

    def call(self, X):
        n_batch, n_filt, _, n_chan = X.shape
        template = tf.ones_like(X[0, 0, ...])               # (C, C) of ones
        upper_with_diag = tf.linalg.band_part(template, 0, -1)
        diag_only = tf.linalg.band_part(template, 0, 0)
        # Upper triangle minus the diagonal -> strictly upper-triangular selector.
        strict_upper = tf.cast(upper_with_diag - diag_only, dtype=tf.bool)
        selected = tf.boolean_mask(X, strict_upper, axis=2)  # (N, F, C*(C-1)/2)
        triu = tf.expand_dims(selected, axis=-1)             # (N, F, C*(C-1)/2, 1)

        # boolean_mask drops the static shape, so state it explicitly.
        n_pairs = int(n_chan*(n_chan-1)/2)
        triu.set_shape([n_batch, n_filt, n_pairs, 1])
        return triu

    def compute_output_shape(self, batch_input_shape):
        n_batch, n_filt, _, n_chan = batch_input_shape.as_list()
        n_pairs = int(n_chan*(n_chan-1)/2)
        return tf.TensorShape([n_batch, n_filt, n_pairs, 1])

    def get_config(self):
        # Only the base Layer configuration is needed to rebuild this layer.
        return {**super().get_config()}
    
    
def GFC_triu_net_avg(nb_classes: int,
          Chans: int,
          Samples: int,
          l1: int = 0, 
          l2: int = 0, 
          dropoutRate: float = 0.5,
          filters: int = 1, 
          maxnorm: float = 2.0,
          maxnorm_last_layer: float = 0.5,
          kernel_time_1: int = 20,
          strid_filter_time_1: int = 1,
          bias_spatial: bool = False) -> Model:
    """Build the (uncompiled) GFC network with filter-averaged connectivity.

    Pipeline: temporal Conv2D -> BatchNorm -> ELU -> GFC -> get_triu ->
    average over the filter axis -> BatchNorm -> ELU -> Flatten -> Dropout ->
    Dense logits -> softmax.

    Parameters
    ----------
    nb_classes : number of output units of the final Dense layer.
    Chans, Samples : input size; the model input is (Chans, Samples, 1).
    l1, l2 : L1/L2 regularisation factors on the Dense kernel.
    dropoutRate : rate of the Dropout layer before the Dense layer.
    filters : number of temporal convolution filters.
    maxnorm, maxnorm_last_layer : max-norm constraints on the Conv2D and
        Dense kernels respectively.
    kernel_time_1, strid_filter_time_1 : temporal kernel length and stride.
    bias_spatial : whether the Conv2D layer uses a bias term.

    Returns
    -------
    Model mapping the 'Input' tensor to the 'output' softmax activations.
    """
    net_input = Input((Chans, Samples, 1), name='Input')

    # Temporal filtering along the sample axis only (kernel height 1).
    temporal_conv = Conv2D(filters,
                           (1, kernel_time_1),
                           strides=(1, strid_filter_time_1),
                           use_bias=bias_spatial,
                           kernel_constraint=max_norm(maxnorm, axis=(0, 1, 2)))
    features = temporal_conv(net_input)
    features = BatchNormalization(epsilon=1e-05, momentum=0.1)(features)
    features = Activation('elu')(features)

    # Channel-by-channel Gaussian connectivity, reduced to its upper triangle.
    connectivity = GFC()(features)
    connectivity = get_triu()(connectivity)

    # Pool across the whole filter axis (axis 1), leaving one value per pair.
    filter_average = AveragePooling2D(pool_size=(connectivity.shape[1], 1),
                                      strides=(1, 1))
    connectivity = filter_average(connectivity)
    connectivity = BatchNormalization(epsilon=1e-05, momentum=0.1)(connectivity)
    connectivity = Activation('elu')(connectivity)

    # Classification head.
    flat = Flatten()(connectivity)
    flat = Dropout(dropoutRate)(flat)
    logits_layer = Dense(nb_classes,
                         kernel_regularizer=L1L2(l1=l1, l2=l2),
                         name='logits',
                         kernel_constraint=max_norm(maxnorm_last_layer))
    logits = logits_layer(flat)
    probabilities = Activation('softmax', name='output')(logits)

    return Model(inputs=net_input, outputs=probabilities)

# Experiment

## Experiment configuration 

In [None]:
# --- Reproducibility and cross-validation settings ---
seed = 23
folds = 5
epochs_train = 500

# --- Dataset selector: 'GIGA_MI_ME' or 'BCICIV2a' ---
data_set = 'GIGA_MI_ME'

# --- Names used to build output folders and file names ---
save_folder = 'GFC_triu_avg_128Hz'
Experiment = f'{data_set}__{folds}_fold'
model_name = f'{save_folder}{epochs_train}_epoch'


In [None]:
import os

# Output folder for weights and pickled results. makedirs with exist_ok=True
# keeps the cell re-runnable: it does nothing when the folder already exists.
PATH=f'./{save_folder}/'
os.makedirs(PATH, exist_ok=True)

## Run experiment

In [None]:
# Per-subject experiment: grid-search the network hyper-parameters with
# stratified shuffle-split CV, retrain a model with the best parameters, then
# save its weights together with the full cross-validation results.
tf.random.set_seed(seed)

# Dataset-specific settings: subject ids, database handle, number of classes,
# loader function and the arguments it is called with.
if data_set=='GIGA_MI_ME':
    subjects = np.arange(52)+1
    subjects = np.delete(subjects,[28,33]) # drops positions 28 and 33, i.e. subject ids 29 and 34
    db = GIGA_MI_ME('/kaggle/input/giga-science-gcpds/GIGA_MI_ME/') # update path
    num_class = 2
    load_fn = load_GIGA
    load_args = dict(db = db,
                fs = db.metadata['sampling_rate'],
                f_bank = np.asarray([[4., 40.]]),
                vwt = np.asarray([[2.5, 5]]),
                new_fs = 128.0)    
elif data_set=='BCICIV2a':
    subjects = np.arange(9)+1
    db = Dataset_2a('../input/bciciv2a/BCI_CIV_2a/') # update path
    num_class = 4
    load_fn = load_BCICIV2a
    load_args = dict(db = db,
                fs = db.metadata['sampling_rate'],
                f_bank = np.asarray([[4., 40.]]),
                vwt = np.asarray([[2.5, 6]]),
                new_fs = 128.0)
else:
    # Fail immediately: continuing would only hit a NameError on load_args.
    raise ValueError(f"Select a valid dataset: got {data_set!r}, expected 'GIGA_MI_ME' or 'BCICIV2a'")

t=time()
# NOTE(review): the loop runs over a hard-coded pair of subjects; `subjects`
# computed above is not iterated here.
for sbj in [5,6]:#subjects[0:1]:
    load_args['sbj'] = sbj
    X_train, y_train = load_fn(**load_args)
    X_train = X_train[..., np.newaxis]  # trailing singleton axis expected by the model input
    Y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_class)
    # build model
    clf = KerasClassifier(
        GFC_triu_net_avg,
        random_state=seed,
        
        #model hyperparameters
        nb_classes=num_class, 
        Chans = X_train.shape[1], 
        Samples = X_train.shape[2],
        dropoutRate=0.5,
        l1 = 0, l2 = 0,
        filters=2, maxnorm=2.0,maxnorm_last_layer=0.5,
        kernel_time_1=25,strid_filter_time_1= 1,
        bias_spatial = False,

        #model config
        verbose=0,
        batch_size=500, #full batch        
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer="adam",
        # scikeras routes parameters as "<destination>__<name>". The previous
        # spelling (optimizer_learning__rate) never reached the optimizer, so
        # the grid search ran with Adam's default learning rate while the
        # final model below used 0.1.
        optimizer__learning_rate=0.1,
        metrics = ['accuracy'],
        epochs = epochs_train
    )
    # search params
    param_grid =  {
                'filters':[2,4], #2,4
                'kernel_time_1':[25,50], #25,50
                }
    
    #Gridsearch
    scoring = {"AUC": 'roc_auc', "Accuracy": make_scorer(accuracy_score),'Kappa':make_scorer(kappa)}
    
    SSS = StratifiedShuffleSplit(n_splits = folds, test_size = 0.2, random_state = seed)
    # Materialise the folds so the same index arrays can be reused below.
    splits = list(SSS.split(X_train, Y_train))
    
    cv = GridSearchCV(clf,param_grid,cv=SSS,
                         verbose=1,n_jobs=1,
                         scoring=scoring,
                         refit=False,
                            )
    # find best params with gridsearch
    cv.fit(X_train,Y_train)
    cv_results = cv.cv_results_
    best_index = np.argmax(cv_results['mean_test_Accuracy'])
    best_params = cv_results['params'][best_index] 
    
    # NOTE(review): NumPy's global RNG is not seeded in this cell, so this
    # re-seeding looks non-reproducible between runs - confirm this is intended.
    tf.random.set_seed(np.random.randint(1000))
    
    best_model = GFC_triu_net_avg(
        #model hyperparameters
        nb_classes=num_class, 
        Chans = X_train.shape[1], 
        Samples = X_train.shape[2],
        dropoutRate=0.5,
        l1 = 0, l2 = 0,
        maxnorm=2.0,maxnorm_last_layer=0.5,
        strid_filter_time_1= 1,
        bias_spatial = False, **best_params)
    
    best_model.compile(      
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer= tf.keras.optimizers.Adam(learning_rate=0.1),
        metrics = ['categorical_accuracy'],
        )
    
    # NOTE(review): best_index is a position in the hyper-parameter grid, not a
    # fold number; it is used here to pick a CV split. The modulo keeps the
    # previous choice whenever it was valid and avoids an IndexError when the
    # grid has more candidates than folds. Confirm which fold is intended.
    train_idx, val_idx = splits[best_index % len(splits)]
    # NOTE(review): the final fit uses a fixed 28 epochs rather than
    # epochs_train - confirm this is deliberate.
    best_model.fit(X_train[train_idx], Y_train[train_idx],
                   validation_data = (X_train[val_idx], Y_train[val_idx]),
                   callbacks = [tf.keras.callbacks.TerminateOnNaN()],
                   verbose=1,
                   batch_size=500, #full batch 
                   epochs = 28
                  )
    # best score
    print('Subject',sbj,'Accuracy',np.max(cv_results['mean_test_Accuracy']),'elapsed time',time()-t)
    print('---------')
    print(f'Best Score: {max(cv.cv_results_["mean_test_Accuracy"])}; Mean Test: {np.mean(cv.cv_results_["mean_test_Accuracy"])}')
    print('---------')
    # NOTE(review): best_model was fitted on one split only, so the validation
    # indices of the other splits can overlap its training data; these
    # accuracies are not clean hold-out estimates.
    for train, val in splits:
        acc = tf.keras.metrics.CategoricalAccuracy()
        acc.update_state(Y_train[val], best_model.predict(X_train[val]))
        print(f'Acc: {acc.result()}')
    
    # Store the selected grid index alongside the CV results before pickling.
    cv.cv_results_['best_index_']=best_index
    
    #########
    best_model.save_weights(f'{PATH}Model_{Experiment}_{model_name}_sujeto_{sbj}_'+data_set+'_4_40_weights.h5')
    with open(PATH+'Results_'+Experiment+'_'+model_name+'_sujeto_'+str(sbj)+'_'+data_set+'_4_40.p','wb') as f:
        pickle.dump(cv.cv_results_,f)     

In [None]:
# Clean-up cell: removes helper files from the working directory.
# NOTE(review): the repository cloned earlier in this notebook is named
# "MI-EEG-ClassMeth" (hyphens), and nothing in the visible cells creates
# ./MI_EEG_ClassMeth, ./__MACOSX or the .zip file - these commands look like
# leftovers from an earlier zip-based download and will likely only print
# "No such file or directory". Confirm the intended targets.
!rm -r ./MI_EEG_ClassMeth
!rm -r ./__MACOSX
!rm ./MI_EEG_ClassMeth.zip