In [1]:
# Mount drive and access data
from google.colab import drive
from google.colab import files
drive.mount('/content/drive')
%cd /content/drive/My\ Drive

# Import dependencies and check TF version
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.io import loadmat
import random
import os
from time import time, sleep
import itertools
from sklearn.utils import shuffle
from tqdm.auto import tqdm
# !pip install tqdm
# Check the TensorFlow version: it must be 2.0 or higher
print(tf.__version__)


# Check that a GPU is attached to the runtime (does not need to be re-run at every execution)
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  # Stop the notebook here when the first GPU is not the reported device
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))


# Subject identifiers of the control group (CTL) and of the patient group (PD)
CTL = [890, 891, 892, 893, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 909, 910, 911, 912, 913, 914, 8060, 8070]
PD = [804, 805, 806, 807, 808, 809, 810, 811, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829]

# These values are available in the REST files, IMPORT_ME_REST.xlsx
# For each PD subject: the session number during which the subject was off medication
off_medication_session = {804: 2, 805: 2, 806: 1, 807: 1, 808: 1, 809: 2, 810: 2, 811: 2, 813: 1, 814: 2, 815: 2, 816: 1, 817: 1, 818: 2, 819: 1, 
                          820: 2, 821: 2, 822: 2, 823: 1, 824: 1, 825: 2, 826: 2, 827: 1, 828: 1, 829: 1} 


# ---------------------------------------------------------- Multiple functions ----------------------------------------------------------
# Setting different folders for the multiple pre-processing of datasets
def set_path(choice):
    """Return the folders and the epoch length associated with a dataset.

    Args:
        choice: dataset name, one of "dataset_zero", "dataset_one",
            "dataset_two", "dataset_three" or "dataset_five".

    Returns:
        A tuple (data_folder, model_folder, screen_folder, nb_point) where the
        three folders are relative paths derived from the dataset name and
        nb_point is the number of samples per epoch.

    Raises:
        ValueError: if `choice` is not one of the known dataset names.
    """
    known_datasets = ("dataset_zero", "dataset_one", "dataset_two", "dataset_three", "dataset_five")
    if choice not in known_datasets:
        # Fail with an explicit message instead of an UnboundLocalError on return
        raise ValueError("Unknown dataset choice {!r}; expected one of {}".format(choice, known_datasets))

    data_folder = "./" + choice + "/"
    model_folder = "./" + choice + "_models"
    screen_folder = "./" + choice + "_images"

    # Every dataset uses the same epoch window: 625 points from -250 ms to 998 ms
    nb_point = 625

    return data_folder, model_folder, screen_folder, nb_point

    
# Show an epoch, chan of an EEG
def show(data, epoch, chan, save=False, t=None):
    """Plot one epoch of one channel of an EEG recording.

    Args:
        data: array indexed as data[epoch, point, chan].
        epoch: index of the epoch to display.
        chan: index of the channel to display.
        save: when True, the figure is also written as a PNG into the
            module-level `screen_folder`.
        t: optional time axis (in ms) with one value per point. When omitted,
            it is rebuilt as a linear axis from -250 ms to 998 ms (the epoch
            window declared in set_path) with as many points as `data` has.
    """
    if t is None:
        # No module-level time axis exists, so derive it from the data length
        t = np.linspace(-250, 998, data.shape[1])

    plt.figure(figsize=(10,6))
    plt.plot(t, data[epoch,:,chan], label='Signal EEG')
    # Mark the stimulus onset (t = 0) and the zero-amplitude line
    plt.axvline(x=0, linewidth=2, color='k')
    plt.axhline(y=0, linewidth=0.4, color='k')
    plt.grid(color='k', linestyle='-', linewidth=.3)
    plt.xlabel('Temps [ms]', fontsize=12)
    plt.ylabel('Amplitude [uV]', fontsize=12)
    # The y axis is flipped so that negative amplitudes are drawn upwards
    plt.gca().invert_yaxis()
    name = 'EEG signal, chan : ' + str(chan+1) + ', epoch : ' + str(epoch+1)
    plt.title(name, fontsize=14)
    plt.legend(fontsize=12)
    if save:
        save_name = screen_folder + '/Normalized data epoch ' + str(epoch) + ' chan ' + str(chan) + '.png'
        plt.savefig(save_name)
        print('Images saved as :', save_name)


# Execution time decorator
def execution_time(function):
    """Decorator that prints how long each call of `function` takes.

    The duration is printed in seconds when it exceeds 10 ms and in
    milliseconds otherwise. The wrapped function's return value is passed
    through unchanged, and its name/docstring are preserved on the wrapper.
    """
    from functools import wraps

    @wraps(function)
    def my_function(*args, **kwargs):
        tic = time()
        result = function(*args, **kwargs)
        elapsed = time() - tic
        if elapsed > 1e-2:
            print("Temps d'éxecution de", function.__name__, ": {0:.3f} (s)".format(elapsed))
        else:
            print("Temps d'éxecution de", function.__name__, ": {0:.3f} (ms)".format(1000 * elapsed))
        return result
    return my_function


# Channel normalization
def chan_normalize(data, *args):
    """Standardise every channel of `data` in place.

    Args:
        data: array indexed as data[epoch, point, chan]; it is modified in place.
        *args: optionally, one sequence of [mean, std] pairs (one per channel)
            to apply instead of statistics computed from `data`.

    Returns:
        (data, mean_std) when no statistics are supplied, where mean_std is the
        list of [mean, std] pairs that were computed; otherwise only `data`.
    """
    channel_ids = range(data.shape[2])

    # Statistics supplied by the caller: apply them as they are
    if len(args) != 0:
        supplied = args[0]
        for chan in tqdm(channel_ids):
            channel = data[:, :, chan]  # view on the caller's array
            channel -= supplied[chan][0]
            channel /= supplied[chan][1]
        return data

    # No statistics supplied: compute them channel by channel and keep them,
    # so that the same transformation can later be applied to new data
    mean_std = []
    for chan in tqdm(channel_ids):
        channel = data[:, :, chan]  # view on the caller's array
        mean = np.mean(channel)
        channel -= mean
        # The standard deviation is measured after the mean has been removed
        std = np.std(channel)
        channel /= std
        mean_std.append([mean, std])
    return data, mean_std


# Normalize data
def normalize(data):
    """Standardise each epoch of `data` in place, channel by channel.

    For every epoch, the mean over the point axis is subtracted from each
    channel and the result is divided by the channel's standard deviation.

    Args:
        data: array indexed as data[epoch, point, chan]; modified in place.

    Returns:
        The same `data` object.
    """
    for epoch_view in data:
        # epoch_view is a (point, chan) view: in-place operations update `data`
        epoch_view -= epoch_view.mean(axis=0)
        epoch_view /= epoch_view.std(axis=0)
    return data


# Display a confusion matrix
def CM_display(confusion_matrix, normalize=True, save=False):
    """Draw a 2x2 confusion matrix as an annotated heat map.

    Args:
        confusion_matrix: a single matrix (rows: true label, columns: predicted
            label) or a 3-D stack of matrices. A stack is reduced to its mean
            along the first axis and the standard deviation is shown next to
            each cell.
        normalize: when True, cells are expressed as percentages of their row.
        save: when True, the figure is written to
            '<screen_folder>/Confusion_matrix.png' (module-level screen_folder).
    """
    class_labels = ['MP', 'CTL']

    # A stack of matrices is summarised by its mean and its spread
    spread = None
    if np.ndim(confusion_matrix) == 3:
        spread = np.std(confusion_matrix, axis=0)
        confusion_matrix = np.mean(confusion_matrix, axis=0)

    # Accuracy is computed on the raw counts, before any row normalisation
    accuracy = np.trace(confusion_matrix) / float(np.sum(confusion_matrix))
    misclass = 1 - accuracy

    plt.figure(figsize=(8, 8))

    positions = np.arange(len(class_labels))
    plt.xticks(positions, class_labels, rotation=45, fontsize=14)
    plt.yticks(positions, class_labels, fontsize=14)

    if normalize:
        row_totals = confusion_matrix.sum(axis=1)[:, np.newaxis]
        confusion_matrix = 100 * confusion_matrix.astype('float') / row_totals
        if spread is not None:
            spread = 100 * spread.astype('float') / row_totals
        plt.imshow(confusion_matrix, cmap=plt.get_cmap('Blues'))
        plt.title('Normalized confusion matrix', fontsize=16)
        plt.clim(0, 100)
    else:
        plt.imshow(confusion_matrix, cmap=plt.get_cmap('Blues'))
        plt.title('Confusion matrix', fontsize=16)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=14)

    # Cells darker than this threshold get white text so they stay readable
    if normalize:
        thresh = confusion_matrix.max()/1.5
    else:
        thresh = confusion_matrix.max()/2

    # Pick the annotation template once: with/without spread, percent/raw
    if spread is not None:
        cell_template = "{:.1f} % ± {:.1f} %" if normalize else "{:} ± {:}"
    else:
        cell_template = "{:.1f} %" if normalize else "{:}"

    for i, j in np.ndindex(confusion_matrix.shape[0], confusion_matrix.shape[1]):
        if spread is not None:
            cell_text = cell_template.format(confusion_matrix[i, j], spread[i, j])
        else:
            cell_text = cell_template.format(confusion_matrix[i, j])
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if confusion_matrix[i, j] > thresh else "black",
                 fontsize=14)

    plt.ylabel('True label', fontsize=16)
    plt.xlabel('Predicted label\naccuracy={:.1f}; misclass={:.1f}'.format(100*accuracy, 100*misclass), fontsize=16)

    if save:
        save_name = screen_folder + '/Confusion_matrix.png'
        plt.savefig(save_name)
    plt.show()


# Calculate F1 score from a confusion matrix
def calculate_F1_score(confusion_matrix):
    """Compute and print the F1 score of class 0 from a 2x2 confusion matrix.

    The matrix is read with rows as true labels and columns as predicted
    labels (the layout produced by tf.math.confusion_matrix and drawn by
    CM_display), class 0 being the positive class.

    Args:
        confusion_matrix: 2x2 array-like of counts.

    Returns:
        The F1 score as a float in [0, 1]; 0.0 when there is no true positive,
        false positive or false negative at all (instead of dividing by zero).
    """
    TP = confusion_matrix[0][0] # True positive : positive predicted positive
    FN = confusion_matrix[0][1] # False negative : positive predicted negative
    FP = confusion_matrix[1][0] # False positive : negative predicted positive

    # F1 = 2 / (1/precision + 1/recall) = 2*TP / (2*TP + FP + FN); this form
    # avoids the divisions by zero of the precision/recall form when TP == 0
    denominator = float(2 * TP + FP + FN)
    F1 = float(2 * TP / denominator) if denominator > 0 else 0.0
    print("The F1_score is {:.1f}\n\n".format(100 * F1))
    return F1


# Display Confidence Interval
def plot_mean_and_CI(mean, lb, ub, legend, color_mean=None, color_shading=None):
    """Draw a curve together with a shaded band on the current figure.

    Args:
        mean: 1-D array of values to draw as a line.
        lb, ub: lower and upper bounds of the shaded band, one per value.
        legend: label of the line in the legend.
        color_mean: colour/format passed to plt.plot for the line.
        color_shading: colour of the shaded band.
    """
    x_positions = range(mean.shape[0])

    # Shaded band first, so that the line is drawn on top of it
    plt.fill_between(x_positions, ub, lb, color=color_shading, alpha=.3)
    plt.plot(mean, color_mean, label=legend)
    plt.legend(fontsize=14)


# Display accuracy with confidence interval
def accuracy_display(history, save=False, CI_display=True):
    """Plot the train and test accuracy per epoch, averaged over the runs.

    Args:
        history: array indexed as history[run, epoch, k] with k = 0 for the
            train accuracy and k = 1 for the test accuracy (fractions).
        save: when True, the figure is written to
            '<screen_folder>/Training_test_accuracy.png' (module-level screen_folder).
        CI_display: when True, a band of +/- one standard deviation over the
            runs is shaded around each mean curve.
    """
    epoch_ids = range(history.shape[1])

    def per_epoch_percent(statistic, column):
        # One value per epoch, computed over the runs, expressed in percent
        values = [statistic(history[:, i, column], dtype=float) for i in epoch_ids]
        return np.array(values, dtype=float) * 100

    train_mean = per_epoch_percent(np.mean, 0)
    test_mean = per_epoch_percent(np.mean, 1)
    train_std = per_epoch_percent(np.std, 0)
    test_std = per_epoch_percent(np.std, 1)

    plt.figure(figsize=(12,8))
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=12)

    curves = ((train_mean, train_std, 'C0', 'Train'),
              (test_mean, test_std, 'C3', 'Test'))
    for curve_mean, curve_std, colour, label in curves:
        if CI_display:
            lower, upper = curve_mean-curve_std, curve_mean+curve_std
        else:
            # Degenerate band: only the mean curve is visible
            lower, upper = curve_mean, curve_mean
        plot_mean_and_CI(curve_mean, lower, upper, color_mean=colour, color_shading=colour, legend=label)

    plt.grid(color='k', linestyle='-', linewidth=.3)
    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel('Accuracy [%]', fontsize=14)
    plt.title('Train and test accuracy', fontsize=16)
    axes = plt.gca()
    axes.spines['top'].set_position('zero')
    axes.spines['right'].set_position('zero')
    plt.autoscale(enable=True, axis='x', tight=True)
    plt.ylim(0, 100)
    plt.yticks(np.arange(0, 110, step=10))
    plt.xticks(np.arange(0, train_mean.shape[0], step=5))
    if save:
        save_name = screen_folder + '/Training_test_accuracy.png'
        plt.savefig(save_name)
    plt.show()


def dft(data, fe=500, lim=None, show=False, save=False):
    """Compute the single-sided amplitude spectrum of every epoch and channel.

    Args:
        data: array indexed as data[epoch, point, chan]; the transform is taken
            along the point axis.
        fe: sampling frequency used to build the frequency axis of the plot.
        lim: optional pair (start, stop) of bin indices; when given, only the
            bins start..stop-1 are kept.
        show: when True, the spectrum of one randomly chosen epoch and channel
            is plotted.
        save: when True (and show is True), the plot is saved into the
            module-level `screen_folder` and downloaded.

    Returns:
        Array indexed as [epoch, bin, chan] holding the first int(N/2 - 1)
        amplitude bins (N = number of points), restricted to `lim` if given.
    """
    N = np.shape(data)[1]
    n_kept = int(N/2-1)  # number of bins kept for the single-sided spectrum
    f = np.linspace(0, fe, N, endpoint=False)[0:n_kept]
    fft = abs(np.fft.fft(data, axis=1))/N
    # Every bin except the DC one is doubled to account for the discarded half
    fft[:,1:n_kept,:] = 2*fft[:,1:n_kept,:]
    fft = fft[:,0:n_kept,:]

    # `is not None` also works when lim is a numpy array
    if lim is not None:
        f = f[lim[0]:lim[1]]
        fft = fft[:,lim[0]:lim[1],:]

    if show:
        # Display a random epoch/channel; randint's upper bound is exclusive,
        # so the bounds are taken directly from the spectrum's shape
        chan = np.random.randint(fft.shape[2])
        epoch = np.random.randint(fft.shape[0])
        plt.figure(figsize=(10,6))
        plt.rc('xtick', labelsize=12)
        plt.rc('ytick', labelsize=12)
        plt.autoscale(enable=True, axis='x', tight=True)
        plt.plot(f, fft[epoch,:,chan])
        #plt.xlim(0, 250)
        plt.ylim(bottom=0)
        plt.xlabel('Frequence [Hz]', fontsize=14)
        plt.ylabel('Amplitude', fontsize=14)
        title = 'Transformation de Fourier, epoch : ' + str(epoch) + ', chan : ' + str(chan)
        plt.title(title, fontsize=16)
        plt.grid()
        if save:
            save_name = screen_folder + '/' + title + '.png'
            plt.savefig(save_name)
            files.download(save_name)
        plt.show()
    return fft

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive
2.2.0
Found GPU at: /device:GPU:0


In [0]:
# Read the data from the drive.
# Everything is stored in one dictionary (one entry per subject) so that the
# train and test sets can be re-drawn quickly without reading the files again.
def read_data(data_folder, data_type_choice, limits, is_on_medication=False, stimuli_type=[200,201,202]):
    """Load every recording of `data_folder` into a dictionary keyed by subject.

    File names are expected to start with '<subject number>_<session>_'.
    Subjects numbered above 850 are always loaded; for the others, only the
    session matching `is_on_medication` (looked up in the module-level
    `off_medication_session`) is loaded. Normalisation is applied when the
    module-level `do_normalisation` flag is true.

    Args:
        data_folder: folder containing the .mat files.
        data_type_choice: 'EEG' (raw epochs), 'DFT' (spectrum), 'DFT_lim'
            (spectrum restricted to `limits`), 'DFT_mean' or 'DFT_lim_mean'
            (same spectra averaged over the epochs of the subject).
        limits: pair of bin indices used by the '*_lim*' choices.
        is_on_medication: selects the on-medication session instead of the
            off-medication one.
        stimuli_type: stimulus codes of the epochs to keep (never mutated here).

    Returns:
        dict mapping subject number to {'feature': array, 'target': one-hot
        labels}; [1, 0] for subjects numbered below 850, [0, 1] otherwise.

    Raises:
        ValueError: if `data_type_choice` is not one of the supported values.
    """
    valid_choices = ('EEG', 'DFT', 'DFT_lim', 'DFT_mean', 'DFT_lim_mean')
    if data_type_choice not in valid_choices:
        # Fail fast: an unknown choice used to produce an empty dictionary
        raise ValueError("Unknown data_type_choice {!r}; expected one of {}".format(data_type_choice, valid_choices))

    use_limits = data_type_choice in ('DFT_lim', 'DFT_lim_mean')
    average_epochs = data_type_choice in ('DFT_mean', 'DFT_lim_mean')

    data_dict = {}
    print('Reading data')
    # for each file in the given folder
    for file_name in tqdm(os.listdir(data_folder)):
        number = int(file_name.split('_')[0])
        session = int(file_name.split('_')[1])

        # Subjects above 850 have no medication session; for the others keep
        # the off-medication session unless is_on_medication is requested
        if not (number > 850 or ((off_medication_session[number]==session) != is_on_medication)):
            continue

        data_temp = loadmat(data_folder + '/' + file_name)
        EEG = np.swapaxes(data_temp['EEG']['data'][0][0], 0, 2) # format [epoch][point][chan]
        stimuli = data_temp['EEG']['epoch'][0][0][0]

        # selecting stimulus type
        indices = [index for index, element in enumerate(stimuli) if element in stimuli_type]
        EEG = np.take(EEG, indices, axis=0)
        target = [[1, 0]] if int(number) < 850 else [[0, 1]]  # One hot encoding
        targets = np.tile(target, (len(indices), 1))

        if data_type_choice == 'EEG':
            feature = EEG
        else:
            feature = dft(EEG, lim=limits) if use_limits else dft(EEG)
            if average_epochs:
                # One averaged spectrum per subject, kept 3-D: (1, bin, chan)
                feature = np.expand_dims(np.mean(feature, axis=0), axis=0)

        if do_normalisation:
            feature = normalize(feature)

        # Averaged features hold a single sample, hence a single label
        data_dict[number] = {'feature' : feature, 'target' : target if average_epochs else targets}

    return data_dict


# Function that splits train and test set and corrupt data if needed
def train_test_split(test_type, do_corruption, *args):
    """Split the module-level `data` dictionary into train and test arrays.

    Args:
        test_type: 'individual' (one subject forms the test set) or 'group'
            (several randomly drawn subjects form the test set).
        do_corruption: in 'group' mode, swap the one-hot labels of a share
            (module-level `corruption_percent`) of the training subjects.
        *args: args[0] is the test subject's ID in 'individual' mode, or the
            number of subjects drawn from each of PD and CTL in 'group' mode.

    Returns:
        (train_features, train_targets, test_features, test_targets) as arrays.

    Raises:
        ValueError: if `test_type` is not 'individual' or 'group'.
        KeyError: in 'individual' mode, if the requested subject is not in `data`.
    """
    if test_type not in ('individual', 'group'):
        raise ValueError("Unknown test_type {!r}; expected 'individual' or 'group'".format(test_type))

    train_features = []
    train_targets = []
    test_features = []
    test_targets = []

    # Shuffle data for a better training
    shuffled_keys = shuffle(list(data.keys()))

    # Individual train type (just 1 person in the test set)
    if test_type == 'individual':
        if args[0] not in data:
            raise KeyError("Subject {!r} is not present in the loaded data".format(args[0]))
        for key in shuffled_keys:
            if key != args[0]: # args[0] = person ID
                train_features.extend(data[key]['feature'])
                train_targets.extend(data[key]['target'])
            else:
                test_features = data[key]['feature']
                test_targets = data[key]['target']

    # Group train type (multiple persons are in test set and the others are in the train set)
    else:
        # take random samples for test set
        PD_test = random.sample(PD, args[0])
        CTL_test = random.sample(CTL, args[0])
        test_keys = [*PD_test, *CTL_test]

        corrupted = []
        if do_corruption:
            # Same number of corrupted subjects in each class
            to_corrupt_number = int((len(data.keys())-2*args[0])*corruption_percent/100 /2) 
            PD_corrupted = random.sample([item for item in PD if item not in PD_test], to_corrupt_number)
            CTL_corrupted = random.sample([item for item in CTL if item not in CTL_test], to_corrupt_number)
            corrupted = [*PD_corrupted, *CTL_corrupted]
            print('{:}% of the data are corrupted, corrupted = {:}'.format(corruption_percent, corrupted))

        for key in shuffled_keys:
            if key in test_keys:
                test_features.extend(data[key]['feature'])
                test_targets.extend(data[key]['target'])
            else:
                # Every non-test subject goes to the train set, corrupted or not
                train_features.extend(data[key]['feature'])
                if key in corrupted:
                    # [1,1] - one_hot swaps the two classes
                    train_targets.extend([1,1] - np.array(data[key]['target']))
                else:
                    train_targets.extend(data[key]['target'])

    train_features = np.array(train_features)
    train_targets  = np.array(train_targets)
    test_features = np.array(test_features)
    test_targets = np.array(test_targets)
    
    # Display sets shape
    print('\nfeatures shape :\n\tTrain =', np.shape(train_features), '---- Test =', np.shape(test_features))
    print('targets shape :\n\tTrain =', np.shape(train_targets), '---- Test =', np.shape(test_targets))
    if test_type == 'group':
        print('Selection :', PD_test, CTL_test, '\n')
    else:
        print('Patient test :', [args[0]])

    return train_features, train_targets, test_features, test_targets

In [0]:
# Prediction model (CNN 1-D)
class Mymodel():
    """1-D convolutional classifier trained with a custom loop.

    The instance keeps the Keras model, the optimizer, the loss and the
    train/test accuracy metrics. After `learn`, `history` holds one
    [train accuracy, test accuracy] row per epoch (preceded by an initial
    [0, 0] row) and `confusion_matrix` holds the matrix of the last epoch.
    """

    def __init__(self, learning_rate=0.001, input_shape=[750, 60], output_shape=2, batch_size=32, show_summary=False, threshold=0.5):
        """Build the model and its training objects.

        Args:
            learning_rate: learning rate of the Adam optimizer.
            input_shape: shape of one sample, passed to the first layer.
            output_shape: number of units of the softmax output layer.
            batch_size: unused here; the batch size is given to `learn`.
            show_summary: print the Keras model summary after creation.
            threshold: decision threshold of the accuracy metrics.
        """
        # learning and model parameters
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.loss_object = tf.keras.losses.CategoricalCrossentropy() #self.loss_object = tf.keras.losses.BinaryCrossentropy()
        self.opt = tf.keras.optimizers.Adam(learning_rate) #self.opt = tf.keras.optimizers.SGD(self.learning_rate)

        # metrics and other parameters
        self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='Train accuracy', threshold=threshold) #self.AUC = tf.keras.metrics.AUC()
        self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='Test accuracy', threshold=threshold)
        self.show_summary = show_summary
        self.history = np.asarray([[0,0]], dtype=float)
        self.confusion_matrix = []
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')

        self.model = self.create_model()


    def create_model(self):
        """Create and compile the convolutional network; returns the Keras model."""
        # Base model
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Conv1D(filters=256, kernel_size=4, strides=1, input_shape=self.input_shape, activation='relu'))
        model.add(tf.keras.layers.Conv1D(filters=128, kernel_size=8, strides=2, activation='relu'))
        model.add(tf.keras.layers.Conv1D(filters=64, kernel_size=16, strides=4, activation='relu'))
        model.add(tf.keras.layers.MaxPooling1D(2))
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(32, activation='relu'))
        model.add(tf.keras.layers.Dense(16, activation='relu'))
        model.add(tf.keras.layers.Dense(self.output_shape, activation='softmax'))
        # Other models (alternative architectures kept for reference, not executed)
        """
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Conv1D(filters=256, kernel_size=4, strides=1, input_shape=self.input_shape, activation='relu'))
        model.add(tf.keras.layers.Conv1D(filters=128, kernel_size=8, strides=2, activation='relu'))
        model.add(tf.keras.layers.Conv1D(filters=64, kernel_size=16, strides=4, activation='relu'))
        model.add(tf.keras.layers.Dropout(0.5))
        model.add(tf.keras.layers.MaxPooling1D(2))
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(64, activation='relu'))
        model.add(tf.keras.layers.Dropout(0.5))
        model.add(tf.keras.layers.Dense(32, activation='relu'))
        model.add(tf.keras.layers.Dense(self.output_shape, activation='softmax'))
        """
        """
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Conv1D(filters=256, kernel_size=4, strides=1, input_shape=self.input_shape, activation='relu'))
        model.add(tf.keras.layers.MaxPooling1D(2))
        model.add(tf.keras.layers.Conv1D(filters=128, kernel_size=8, strides=2, activation='relu'))
        model.add(tf.keras.layers.MaxPooling1D(4))
        #model.add(tf.keras.layers.Conv1D(filters=64, kernel_size=16, strides=4, activation='relu'))
        #model.add(tf.keras.layers.Dropout(0.5))
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(128, activation='relu'))
        #model.add(tf.keras.layers.Dropout(0.5))
        model.add(tf.keras.layers.Dense(64, activation='relu'))
        model.add(tf.keras.layers.Dense(self.output_shape, activation='softmax'))
        """
        # Keras documents `metrics` as a list of metrics
        model.compile(optimizer=self.opt, loss=self.loss_object, metrics=[self.train_accuracy])
        if self.show_summary:
            model.summary()
        return model


    def batch_generator(self, iterable, batch_size):
        """Yield consecutive slices of `iterable` holding at most `batch_size` items."""
        l = len(iterable)
        for dn in range(0, l, batch_size):
            yield iterable[dn:min(dn + batch_size, l)]


    def learn(self, features_train, targets_train, features_test, targets_test, batch_size, epoch, do_regulation=False):        
        """Train for `epoch` epochs and evaluate on the test set after each one.

        Args:
            features_train, targets_train: training samples and one-hot labels.
            features_test, targets_test: test samples and one-hot labels.
            batch_size: number of training samples per gradient step.
            epoch: number of epochs to run.
            do_regulation: when True, the targets of every batch are first
                rewritten by the module-level `controller` object.

        The confusion matrix of the last epoch is stored in
        `self.confusion_matrix`; it is left untouched when `epoch` is 0.
        """
        test_prediction = None
        tic = time()
        for epoch_index in range(epoch):
            for batch in self.batch_generator(range(features_train.shape[0]), batch_size):
                if do_regulation:
                    train_prediction = self.model(features_train[batch])
                    target_updated = controller.update_target(train_prediction, targets_train[batch], epoch_index)
                    self.train(features_train[batch], target_updated)
                    #print("Loss: %s" % self.train_loss.result().numpy())
                    controller.get_loss(self.train_loss)
                else:
                    self.train(features_train[batch], targets_train[batch]) 
                    #print("Loss: %s" % self.train_loss.result().numpy())
                    self.train_loss.reset_states()

            test_prediction = self.test(features_test, targets_test)
            self.update_accuracy()
            
            toc = time()
            # Report every 10th epoch only
            if epoch_index%10==9:
                print('Epoch :{:3d}, Train accuracy : {:2.1f}%  /  Test accuracy : {:2.1f}%  /  Elapsed time : {:.1f} (s)'
                .format(epoch_index+1, 100*self.history[-1,0], 100*self.history[-1,1], toc-tic))
            tic = time()

        # Without any epoch there is no prediction to build a matrix from
        if test_prediction is not None:
            self.calculate_confusion_matrix(test_prediction, targets_test)


    # TF 2.X graph mode decorator
    @tf.function
    def train(self, features, targets):
        """Run one gradient step on a batch and update the train loss/accuracy metrics."""
        with tf.GradientTape() as tape:
            predictions = self.model(features)
            loss = self.loss_object(targets, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.opt.apply_gradients(zip(gradients, self.model.trainable_variables))
        self.train_loss(loss)
        self.train_accuracy.update_state(targets, predictions)
        

    # Saving accuracies in a memory and resetting them
    def update_accuracy(self):
        """Append the current [train, test] accuracies to `history` and reset both metrics."""
        self.history = np.append(self.history, [[self.train_accuracy.result().numpy(), self.test_accuracy.result().numpy()]], axis=0)
        self.train_accuracy.reset_states()
        self.test_accuracy.reset_states()


    def test(self, features_test, targets_test):
        """Predict the whole test set by batches of at most 300 samples.

        Updates the test accuracy metric and returns the predictions as one
        array with one row per test sample.
        """
        # max(1, ...) keeps the batch step valid when the test set is empty
        batch_size = max(1, min(300, features_test.shape[0]))

        # Collect the batches and concatenate once instead of growing an array
        batch_predictions = [np.asarray(self.model(features_test[batch]))
                             for batch in self.batch_generator(range(features_test.shape[0]), batch_size)]
        if batch_predictions:
            predictions = np.concatenate(batch_predictions, axis=0)
        else:
            predictions = np.empty([0, self.output_shape], dtype=np.float32)

        self.test_accuracy.update_state(targets_test, predictions)
        return predictions


    def calculate_confusion_matrix(self, predictions, targets_test):
        """Store the 2x2 confusion matrix (rows: true class, columns: predicted class)."""
        # 0 for PD, 1 for CTL  
        new_target = [0 if tuple(target)==(1,0) else 1 for target in targets_test]
        new_prediction = [0 if prediction[0]>0.5 else 1 for prediction in predictions]
        # num_classes keeps the matrix 2x2 even when only one class shows up
        self.confusion_matrix = tf.math.confusion_matrix(new_target, new_prediction, num_classes=2)

    
    def load_model(self, load_name='model'):
        """Load weights from '<model_folder>/<load_name>_weights.h5' if possible.

        A failure is reported on the console and otherwise ignored, so a
        missing file does not stop the notebook.
        """
        try:
            self.model.load_weights(model_folder + '/' + load_name + '_weights.h5')
            print('...Model loaded...')
        except Exception as error:
            # Best effort on purpose, but say why and let KeyboardInterrupt through
            print('...No model to load...')
            print('Reason :', error)


    def save_model(self, save_name='model'):
        """Save the full model and its weights into the module-level `model_folder`.

        The full model is always written to 'my_model.h5'; only the weights
        file name depends on `save_name`.
        """
        self.model.save(model_folder + '/my_model.h5')
        self.model.save_weights(model_folder + '/' + save_name + '_weights.h5')
        print('...Model saved...')


class control():
    """Controller that blends the model's predictions into the training targets.

    The blending weight `betta` is derived from the most recent loss value:
    betta = min(exp(-alpha * loss), betta_max).
    """

    def __init__(self, loss_memory_size=10):
        """
        Args:
            loss_memory_size: maximum number of loss values kept in memory.
        """
        self.loss_memory_size = loss_memory_size
        self.loss_memory = [1.]
        self.betta_max = 0.7
        self.alpha = 50 
        self.betta = 0  #[*np.linspace(0,1, 15), *85*[1]]

    def update_target(self, prediction, target, epoch):
        """Return the targets blended with the predictions, as an integer array.

        Args:
            prediction: model output for the batch.
            target: labels of the batch.
            epoch: current epoch index (not used by the computation).
        """
        self.update_betta()
        blended = self.betta*prediction + (1-self.betta)*target
        # NOTE(review): the cast to int truncates every blended value below 1
        # to 0 after the 2-decimal rounding - confirm this is the intended
        # way of producing hard labels.
        target_updated = np.array(np.round(blended, 2), dtype=int)
        return target_updated
    
    def get_loss(self, loss):
        """Record the current value of the metric `loss`, keeping only the latest values."""
        self.loss_memory.append(loss.result().numpy())
        # Enforce loss_memory_size (at least one value is kept, since
        # update_betta reads the most recent one)
        overflow = len(self.loss_memory) - max(1, self.loss_memory_size)
        if overflow > 0:
            del self.loss_memory[:overflow]

    def update_betta(self):
        """Recompute `betta` from the last recorded loss, saturated at `betta_max`."""
        self.betta = min(1*np.exp(-self.alpha * self.loss_memory[-1]), self.betta_max) #saturer

In [0]:
# Data reading
# Experiment configuration: these module-level names are read as globals by the
# functions defined in the cells above (read_data, train_test_split, ...)
dataset_choice = 'dataset_three'        # dataset name
testing_case = 'group'                  # 'group' / 'individual'
data_type_choice = 'DFT_mean'           # choose : 'EEG' 'DFT' 'DFT_lim' 'DFT_mean' 'DFT_lim_mean'
stimuli_type = [202]                    # 201 : Standard / 200 : Target / 202 : Novel (any combination of the three is possible)
corruption_percent = 30                 # Percentage of corrupted data in train set : 5 10 20 30 40 50 70 90 100
limits = [0, 80]                        # In the DFT case, keep only a few features (not all DFT bins), e.g. [0, 80]
fs = 500                                # Sampling frequency
do_normalisation = True                 # Normalize data bool
do_corruption = True                    # Corrupt data bool (swap label)
do_regulation = False                   # Let the controller do regulation (still experimental)

# Execution parameters
repeat_training = 30                    # Number of times the training and test are repeated (we only have 50 subjects)
epoch = 50                              # Number of training epochs per run
batch_size = 4                          # Number of samples per gradient step
test_size = 5                           # Number of persons from each of PD and CTL to take in the test set (5 means a test set of 10)
learning_rate = 0.001


# Set paths and read data
data_folder, model_folder, screen_folder, nb_point = set_path(dataset_choice)
data = read_data(data_folder, data_type_choice, stimuli_type=stimuli_type, limits=limits)
# For the DFT-based features, print the frequency range matching the index limits
if data_type_choice.split('_')[0]=='DFT':
    print('Index limits = {:} --> freq = ({:.1f}, {:.1f})'.format(limits, limits[0]*fs/nb_point, limits[1]*fs/nb_point))

In [0]:
# Metrics containers
# Metrics containers: one entry per training run
accuracy = []
F1 = []
confusion_matrix = []

if testing_case == 'group':
    for run in range(repeat_training):
        print('Run number :', run+1)
        # Training (a new random split and a new model for every run)
        features_train, targets_train, features_test, targets_test = train_test_split(testing_case, do_corruption, test_size)
        controller = control() 
        model = Mymodel(learning_rate, input_shape=features_train.shape[1:])
        model.learn(features_train, targets_train, features_test, targets_test, batch_size, epoch, do_regulation)
        # Display result 
        F1.append(calculate_F1_score(np.array(model.confusion_matrix)))
        confusion_matrix.append(model.confusion_matrix)
        accuracy.append(model.history)

    print('\n\n', 50*'-' , 'RESULTS', 50*'-')
    print('The F1_score is : {:.1f}% ± {:.1f}%'.format(100*np.mean(F1), 100*np.std(F1)))
    CM_display(np.array(confusion_matrix), save=True, normalize=True)
    accuracy_display(np.array(accuracy), save=True)

elif testing_case == 'individual':
    for patient in [*PD, *CTL]:
        # The split is fixed for a given test subject; only the model changes
        features_train, targets_train, features_test, targets_test = train_test_split(testing_case, do_corruption, patient)

        for run in range(repeat_training):
            print('Run number :', run+1)
            # Training 
            # A controller must exist because Mymodel.learn reads it when do_regulation is on
            controller = control()
            # The default decision threshold (0.5) is used, as in the group case
            model = Mymodel(learning_rate, input_shape=features_train.shape[1:])
            model.learn(features_train, targets_train, features_test, targets_test, batch_size, epoch, do_regulation)
            accuracy.append(model.history)

    accuracy_display(np.array(accuracy), save=True, CI_display=False)

In [0]:
# Download the saved images
# Download the saved figures; a figure that was not produced (e.g. no confusion
# matrix in the 'individual' testing case) is skipped instead of raising
for image_name in ('Confusion_matrix.png', 'Training_test_accuracy.png'):
    image_path = screen_folder + '/' + image_name
    if os.path.exists(image_path):
        files.download(image_path)
    else:
        print('Image not found, download skipped :', image_path)