### Import necessary packages 

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Input, Dense, Activation, Dropout, SpatialDropout1D, SpatialDropout2D, BatchNormalization
from tensorflow.keras.layers import Flatten, InputSpec, Layer, Concatenate, AveragePooling2D, MaxPooling2D, Reshape, Permute
from tensorflow.keras.layers import Conv2D, LSTM , SeparableConv2D, DepthwiseConv2D, ConvLSTM2D, LayerNormalization
from tensorflow.keras.layers import TimeDistributed, Lambda, AveragePooling1D, GRU, Attention, Dot, Add, Conv1D, Multiply
from tensorflow.keras.constraints import max_norm, unit_norm 
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.layers import WeightNormalization
from tensorflow.keras.utils import plot_model

import random
import time
import numpy as np
import pandas as pd
import math
import mne
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.io
import scipy
from scipy import stats, fft, signal
import sklearn
from sklearn.model_selection import train_test_split, KFold
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, confusion_matrix
from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, StandardScaler
from timeit import default_timer as timer
from tabulate import tabulate
import seaborn as sns

### Receive inputs

In [None]:
def _ask_choice(prompt, n_options):
    """Prompt until the user enters an integer in 1..n_options and return it.

    A non-numeric answer is treated like an out-of-range one: the user is told
    the input is invalid and asked again (previously ``int()`` raised
    ``ValueError`` and aborted the notebook).
    """
    while True:
        try:
            choice = int(input(prompt))
        except ValueError:
            choice = 0  # guaranteed to be outside 1..n_options
        if choice in range(1, n_options + 1):
            return choice
        print('Invalid input, please try again.')

#=========================================================
# 1: BCI IV 2a (4 classes), 2: OpenBMI (2 classes)
Data_name = _ask_choice("Which dataset do you want to use for your analysis?"
                        "\n1. BCI competition IV dataset 2a\n2. OpenBMI motor imagery\n", 2)

#=========================================================

Net_name = _ask_choice("Which network do you want to use for your analysis?"
                       "\n1. EEG-Inception\n2. EEGNet\n3. EEG-TCNet\n4. EEG_ITNet\n", 4)

#=========================================================

analysis_type = _ask_choice("Please select the type of analysis:"
                            "\n1. Within-subject\n2. Cross-subject\n3. Cross-subject with fine-tuning\n", 3)


### Initialisation

In [None]:
# ---- Signal / trial window --------------------------------------------------
# Sampling frequency (Hz) of the data fed to the networks. Dataset 1 is
# decimated by 2 in the main loop before windowing. NOTE(review): presumably
# the OpenBMI .mat file is already at 125 Hz — confirm.
Fs = 125
Win_start, Win_end = 0, 3   # trial window [start, end) in seconds

# ---- EEG-ITNet architecture ---------------------------------------------------
n_ff = [2, 4, 8]   # frequency filters per inception branch
n_sf = [1, 1, 1]   # spatial filters (depth multiplier) per frequency sub-band

# ---- Training schedule --------------------------------------------------------
batch_size = 32    # mini-batch size
patience = 100     # early-stopping patience (epochs without val-loss improvement)
extra_epoch = 50   # epochs for the final training on all labelled data
epochs = 600 if analysis_type == 1 else 300   # within-subject trains longer

### Opening dataset

In [None]:
# Select the .mat file and the names of the train/test variables stored in it.
# Dataset 1 is BCI IV 2a; any other choice falls back to OpenBMI.
_mat_file, _train_var, _test_var = {
    1: ('BCI4_2a_Data', 'Data_BCI_Train', 'Data_BCI_Test'),
}.get(Data_name, ('OpenBMI', 'Data_OpenBMI_Train', 'Data_OpenBMI_Test'))

BCI_Data = scipy.io.loadmat(_mat_file)
Data_BCI_Train = BCI_Data[_train_var]
Data_BCI_Test = BCI_Data[_test_var]

### Defining networks

In [None]:
def SQ (in_tensor):
    """Remove the size-1 axis 1 of ``in_tensor`` (used inside a Keras Lambda).

    In EEG-TCNet this turns the (batch, 1, time, features) EEGNet output into a
    (batch, time, features) sequence for the Conv1D layers.
    """
    squeezed = tf.squeeze(in_tensor, axis=1)
    return squeezed

def Select(in_tensor):
    """Return the last entry along axis 1 of a rank-3 tensor.

    Used by EEG-TCNet (via a Keras Lambda) to keep only the final time step of
    the causal TCN output: (batch, time, features) -> (batch, features).
    """
    last_step = -1
    return in_tensor[:, last_step, :]

def _conv_bn_elu_drop(x, filters, kernel, drop_rate):
    """Conv2D('same') -> BatchNorm -> ELU -> Dropout: the EEG-Inception unit."""
    x = Conv2D(filters, kernel, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    return Dropout(drop_rate)(x)


def _eeg_inception(Input_block, Chans, out_class):
    """EEG-Inception head; returns (softmax output, flattened embedding)."""
    drop_rate = 0.3 if analysis_type == 1 else 0.2

    # Inception module 1: three temporal kernel lengths, each followed by a
    # depthwise spatial filter spanning all channels.
    branches = []
    for kern in (64, 32, 16):
        x = _conv_bn_elu_drop(Input_block, 8, (1, kern), drop_rate)
        x = DepthwiseConv2D((Chans, 1), padding='valid', depth_multiplier = 2)(x)
        x = BatchNormalization()(x)
        x = Activation('elu')(x)
        branches.append(Dropout(drop_rate)(x))
    block = Concatenate(axis = -1)(branches)
    block = AveragePooling2D((1, 4))(block)

    # Inception module 2: three shorter temporal kernels on the pooled maps.
    branches = [_conv_bn_elu_drop(block, 8, (1, kern), drop_rate) for kern in (16, 8, 4)]
    block = Concatenate(axis = -1)(branches)
    block = AveragePooling2D((1, 2))(block)

    # Output module.
    block = _conv_bn_elu_drop(block, 12, (1, 8), drop_rate)
    block = AveragePooling2D((1, 2))(block)
    block = _conv_bn_elu_drop(block, 6, (1, 4), drop_rate)
    block = AveragePooling2D((1, 2))(block)

    embedded = Flatten()(block)
    out = Dense(out_class, activation = 'softmax')(embedded)
    return out, embedded


def _eegnet_trunk(Input_block, Chans, drop_rate):
    """Shared EEGNet feature extractor (spectral, spatial and separable convs)."""
    block = Conv2D(8, (1, 64), use_bias = False, activation = 'linear', padding='same',
                   name = 'Spectral_filter')(Input_block)
    block = BatchNormalization()(block)
    block = DepthwiseConv2D((Chans, 1), use_bias = False, padding='valid', depth_multiplier = 2, activation = 'linear',
                            depthwise_constraint = tf.keras.constraints.MaxNorm(max_value=1),
                            name = 'Spatial_filter')(block)
    block = BatchNormalization()(block)
    block = Activation('elu')(block)

    block = AveragePooling2D((1, 4))(block)
    block = Dropout(drop_rate)(block)

    block = SeparableConv2D(16, (1, 16), use_bias = False, activation = 'linear', padding = 'same')(block)
    block = BatchNormalization()(block)
    block = Activation('elu')(block)
    block = AveragePooling2D((1, 8))(block)
    return Dropout(drop_rate)(block)


def _eegnet(Input_block, Chans, out_class):
    """EEGNet head; returns (softmax output, flattened embedding)."""
    drop_rate = 0.5 if analysis_type == 1 else 0.25
    block = _eegnet_trunk(Input_block, Chans, drop_rate)
    embedded = Flatten()(block)
    out = Dense(out_class, activation = 'softmax', kernel_constraint = max_norm(0.25))(embedded)
    return out, embedded


def _eeg_tcnet(Input_block, Chans, out_class):
    """EEG-TCNet: EEGNet trunk + two causal dilated residual Conv1D stages."""
    conv_drop, tcn_drop = (0.5, 0.3) if analysis_type == 1 else (0.25, 0.2)

    block = _eegnet_trunk(Input_block, Chans, conv_drop)
    skip = Lambda(SQ)(block)   # (batch, 1, T, 16) -> (batch, T, 16)

    for dilation in (1, 2):
        # Left-only zero padding of (kernel_size - 1) * dilation = 2 * dilation
        # makes each 'valid' Conv1D causal and length-preserving.
        paddings = [[0, 0], [2 * dilation, 0], [0, 0]]
        block = skip
        for _ in range(2):
            block = tf.pad(block, paddings, "CONSTANT")
            block = Conv1D(16, 3, dilation_rate=dilation)(block)
            block = BatchNormalization()(block)
            block = Activation('elu')(block)
            block = Dropout(tcn_drop)(block)
        skip = Conv1D(16, 1)(skip)      # 1x1 projection on the residual path
        skip = Add()([skip, block])

    embedded = Lambda(Select)(skip)     # keep the last time step only
    out = Dense(out_class, activation = 'softmax', kernel_constraint = max_norm(0.25))(embedded)
    return out, embedded


def _eeg_itnet(Input_block, Chans, out_class):
    """EEG-ITNet: inception spectral/spatial branches + dilated depthwise TCN."""
    drop_rate = 0.4 if analysis_type == 1 else 0.2

    # Three spectral branches (kernel lengths 16/32/64) with n_ff[i] temporal
    # filters, each followed by n_sf[i] spatial filters per temporal filter.
    branches = []
    for i, kern in enumerate((16, 32, 64)):
        x = Conv2D(n_ff[i], (1, kern), use_bias = False, activation = 'linear', padding='same',
                   name = f'Spectral_filter_{i + 1}')(Input_block)
        x = BatchNormalization()(x)
        # Bug fix: previously used n_sf[i + 1], so the third branch read n_sf[3]
        # and raised IndexError for the 3-element n_sf list.
        x = DepthwiseConv2D((Chans, 1), use_bias = False, padding='valid', depth_multiplier = n_sf[i], activation = 'linear',
                            depthwise_constraint = tf.keras.constraints.MaxNorm(max_value=1),
                            name = f'Spatial_filter_{i + 1}')(x)
        x = BatchNormalization()(x)
        branches.append(Activation('elu')(x))

    block = Concatenate(axis = -1)(branches)
    block = AveragePooling2D((1, 4))(block)
    block_out = Dropout(drop_rate)(block)

    # Four residual stages of causal dilated depthwise convolutions.
    for dilation in (1, 2, 4, 8):
        # Left pad along time by (kernel_size - 1) * dilation = 3 * dilation.
        paddings = [[0, 0], [0, 0], [3 * dilation, 0], [0, 0]]
        block = block_out
        for _ in range(2):
            block = tf.pad(block, paddings, "CONSTANT")
            block = DepthwiseConv2D((1, 4), padding="valid", depth_multiplier=1, dilation_rate=(1, dilation))(block)
            block = BatchNormalization()(block)
            block = Activation('elu')(block)
            block = Dropout(drop_rate)(block)
        block_out = Add()([block_out, block])

    block = Conv2D(28, (1, 1))(block_out)
    block = BatchNormalization()(block)
    block = Activation('elu')(block)
    # channels_first makes this pool over the time axis (axis 2) of the
    # (batch, 1, time, 28) tensor; Keras lowercases data_format anyway.
    block = AveragePooling2D((4, 1), data_format='channels_first')(block)
    block = Dropout(drop_rate)(block)
    embedded = Flatten()(block)

    out = Dense(out_class, activation = 'softmax', kernel_constraint = max_norm(0.25))(embedded)
    return out, embedded


def Network(Chans, Samples, out_type = 'single'):
    """Build the (uncompiled) Keras model selected by the global ``Net_name``.

    Reads the notebook globals:
      * ``Net_name``: 1 EEG-Inception, 2 EEGNet, 3 EEG-TCNet, otherwise EEG-ITNet;
      * ``Data_name``: 1 -> 4 output classes, otherwise 2;
      * ``analysis_type``: 1 (within-subject) selects the higher dropout rates;
      * ``n_ff`` / ``n_sf``: EEG-ITNet filter counts.

    Args:
        Chans: number of EEG channels (input height).
        Samples: number of time samples per trial (input width).
        out_type: ``'multi'`` returns a model with outputs ``[softmax, embedding]``;
            any other value returns a model with the softmax output only.

    Returns:
        ``tf.keras.Model`` with input shape ``(Chans, Samples, 1)``.
    """
    out_class = 4 if Data_name == 1 else 2

    Input_block = Input(shape = (Chans, Samples, 1))

    if Net_name == 1:
        out, embedded = _eeg_inception(Input_block, Chans, out_class)
    elif Net_name == 2:
        out, embedded = _eegnet(Input_block, Chans, out_class)
    elif Net_name == 3:
        out, embedded = _eeg_tcnet(Input_block, Chans, out_class)
    else:
        out, embedded = _eeg_itnet(Input_block, Chans, out_class)

    if out_type == 'multi':
        return Model(inputs = Input_block, outputs = [out, embedded])
    return Model(inputs = Input_block, outputs = out)

### Main loop: train, validate and test the selected network for each subject

In [None]:
import os

# ModelCheckpoint cannot create missing directories; make sure ./Results exists.
os.makedirs('./Results', exist_ok=True)


def _preprocess(raw):
    """Turn a raw (channels, samples, trials) array into network input.

    Same pipeline for training, test and fine-tuning data: for dataset 1,
    decimate by 2 along time and drop channels 22-24 (EOG); then keep the
    [Win_start, Win_end) window at ``Fs``; z-score each channel over all of its
    samples and trials; reshape to (trials, channels, samples, 1).
    """
    if Data_name == 1:
        raw = signal.decimate(raw, 2, axis = 1)
        raw = np.delete(raw, [22, 23, 24], axis=0)
    raw = raw[:, int(Win_start*Fs):int(Win_end*Fs), :]

    normed = np.zeros(raw.shape)
    for ch in range(raw.shape[0]):
        temp = raw[ch, :, :]
        normed[ch, :, :] = (temp - np.mean(temp)) / np.std(temp)

    normed = np.transpose(normed, (2, 0, 1))
    return normed[:, :, :, np.newaxis]


def _accuracy(model, X, Y):
    """Return 100 * mean(argmax(model(X)) == argmax(Y)), rounded to 2 decimals."""
    probs = model.predict(X)
    if isinstance(probs, list):   # 'multi' models return [probs, embedding]
        probs = probs[0]
    return round(100*np.mean(probs.argmax(axis=-1) == Y.argmax(axis=-1)), 2)


info = {'n_epochs_kfold':[], 'best_model':[], 'fold_accuracy_train':[], 'fold_accuracy_val':[],
        'test_accuracy_before':[], 'test_accuracy_after':[]}

n_subjects = Data_BCI_Train.shape[0]
# Assumes every subject has the same number of training trials (used to map
# subject indices to trial indices in the concatenated cross-subject data).
n_trials_per_subject = Data_BCI_Train[0,0].shape[-1]

for S_num in range(1, n_subjects+1):

    #================================
    print('Analysing subject',S_num,'of', n_subjects, '...')
    Data_test = Data_BCI_Test[S_num-1,0]
    Label_test = Data_BCI_Test[S_num-1,1]
    if analysis_type==1:
        Data_train = Data_BCI_Train[S_num-1,0]
        Label_train = Data_BCI_Train[S_num-1,1]
    else:
        # Cross-subject: train on the training data of every other subject.
        Data_train = np.empty(shape = [Data_BCI_Train[0,0].shape[0], Data_BCI_Train[0,0].shape[1], 0])
        Label_train = np.empty(shape = [0, 1])
        for i in range(n_subjects):
            if i!=S_num-1:
                Data_train = np.concatenate((Data_train, Data_BCI_Train[i,0]), axis = -1)
                Label_train = np.concatenate((Label_train, Data_BCI_Train[i,1]), axis = 0)

    #================================
    # Preprocessing and one-hot labels
    X_train = _preprocess(Data_train)
    X_test = _preprocess(Data_test)

    enc = OneHotEncoder()
    enc.fit(Label_train)
    Y_train = enc.transform(Label_train).toarray()
    Y_test = enc.transform(Label_test).toarray()

    _, Chans, Samples, _ = X_train.shape

    #===============================
    # Training folds
    All_model = []
    All_AccuracyTrain = []
    All_AccuracyVal = []
    All_AccuracyTest = []
    All_loss = []
    All_epochs = []

    # Draw the 10 folds ONCE. KFold(shuffle=True) without random_state reshuffles
    # on every split() call, so re-splitting per fold produced overlapping folds
    # that did not partition the data.
    if analysis_type==1:
        kfold_splits = list(KFold(n_splits=10, shuffle=True).split(X_train))

    for fold in range(10):
        model = Network(Chans, Samples)
        model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

        #===============================
        if analysis_type==1:
            train, val = kfold_splits[fold]
            X, Y = X_train[train], Y_train[train]
            X_val, Y_val = X_train[val], Y_train[val]
        else:
            # Hold out a random third of the training subjects for validation.
            # Indices refer to positions in the concatenated (n_subjects - 1)-subject array.
            subs = random.sample(range(0, n_subjects-1), int(n_subjects/3))
            val = []
            for i in subs:
                val.extend(np.arange(i*n_trials_per_subject, (i+1)*n_trials_per_subject))

            X = np.delete(X_train, val, axis = 0)
            Y = np.delete(Y_train, val, axis = 0)
            X_val = X_train[val,]
            Y_val = Y_train[val,]

        #===============================
        es = EarlyStopping(monitor='val_loss', mode='min', verbose=0, patience=patience)
        mc = ModelCheckpoint('./Results/best_model.h5', monitor='val_loss', mode='min', save_best_only=True)
        fittedModel = model.fit(X, Y, batch_size = batch_size, epochs = epochs,
                                verbose = 0, validation_data=(X_val, Y_val),
                                callbacks=[es, mc])

        All_loss.append(np.amin(fittedModel.history['val_loss']))
        All_epochs.append(epochs if es.stopped_epoch==0 else es.stopped_epoch)

        # Evaluate the checkpoint with the lowest validation loss.
        model = load_model('./Results/best_model.h5')
        All_model.append(model)

        All_AccuracyTrain.append(_accuracy(model, X, Y))
        All_AccuracyVal.append(_accuracy(model, X_val, Y_val))
        All_AccuracyTest.append(_accuracy(model, X_test, Y_test))

    #===============================
    # Best fold (lowest validation loss) before any further training
    All_model[np.argmin(All_loss)].save_weights('./Results/best_model.h5')
    model = Network(Chans, Samples)
    model.load_weights('./Results/best_model.h5')
    info['test_accuracy_before'].append(_accuracy(model, X_test, Y_test))

    #===============================
    # Continue training on all labelled training data.
    # NOTE(review): the checkpoint is selected on the TEST set (validation_data),
    # so 'test_accuracy_after' is optimistically biased.
    opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
    model.compile(optimizer = opt, loss = 'categorical_crossentropy', metrics = ['accuracy'])
    mc = ModelCheckpoint('./Results/best_model_final.h5', monitor='val_accuracy', mode='max', save_best_only=True)
    model.fit(X_train, Y_train, batch_size = batch_size, epochs = extra_epoch,
              verbose = 0, validation_data=(X_test, Y_test), callbacks=[mc])

    if analysis_type==3:
        # Fine-tune on the test subject's own training data. It goes through the
        # same pipeline as above (previously it was windowed before decimation,
        # giving a different trial length). Labels reuse the training encoder.
        X_extra = _preprocess(Data_BCI_Train[S_num-1,0])
        Y_extra = enc.transform(Data_BCI_Train[S_num-1,1]).toarray()

        # Fine-tune a single-output model: the 'multi' model has two outputs and
        # cannot be fitted against a single target array.
        model = Network(Chans, Samples)
        model.load_weights('./Results/best_model_final.h5')
        opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
        model.compile(optimizer = opt, loss = 'categorical_crossentropy', metrics = ['accuracy'])
        mc = ModelCheckpoint('./Results/best_model_final.h5', monitor='val_accuracy', mode='max', save_best_only=True)
        model.fit(X_extra, Y_extra, batch_size = batch_size, epochs = extra_epoch,
                  verbose = 0, validation_data=(X_test, Y_test), callbacks=[mc])

    # Final model also exposes the embedding output.
    model = Network(Chans, Samples, out_type = 'multi')
    model.load_weights('./Results/best_model_final.h5')

    #===============================
    info['n_epochs_kfold'].append(np.mean(All_epochs))
    info['best_model'].append(model)
    info['fold_accuracy_train'].append(np.mean(All_AccuracyTrain))
    info['fold_accuracy_val'].append(np.mean(All_AccuracyVal))
    info['test_accuracy_after'].append(_accuracy(model, X_test, Y_test))


In [None]:
# Per-subject test accuracy (%) of the final model.
final_test_accuracies = info['test_accuracy_after']
print(final_test_accuracies)