In [5]:
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
import Graph_EEGresnet
from tensorflow.keras.callbacks import ModelCheckpoint

#### Example code for training our model

* This example follows the subject-dependent paradigm on the Physionet dataset.
* The input data should be organized as `(num_examples, num_electrodes, num_time_samples, 1)`.


In [4]:
# Load the Physionet recordings. The .npy file holds a single pickled Python
# object, which `.item()` unwraps; split_data expects it to behave like a dict
# whose values carry 'X' and 'y' arrays.
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the
# file — only load archives that come from a trusted source.
_physionet_array = np.load('data/Physionet.npy', allow_pickle=True)
dataset = _physionet_array.item()
del _physionet_array

### The dataset and split code should be replaced for different dataset

def split_data(fold, dataset):
    X = [dataset[i]['X'] for i in dataset.keys() if i not in [87, 88, 91, 99, 103]]
    y = [dataset[i]['y'] for i in dataset.keys() if i not in [87, 88, 91, 99, 103]]
    X = np.concatenate(X)
    y = np.concatenate(y)

    n = 0
    kf = KFold(n_splits=5, random_state=2022, shuffle=True)
    for train_index, test_index in kf.split(X):
        if n == fold:
            X_train, y_train = X[train_index], y[train_index]
            X_test, y_test = X[test_index], y[test_index]
            X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.125, random_state=2022)
            break
        n = n + 1
    
    return X_train, X_val, X_test, y_train, y_val, y_test

In [None]:
MODEL_NAME = 'Ours_test'
for fold in range(5):
    X_train, X_val, X_test, y_train, y_val, y_test = split_data(fold, dataset)
    model = Graph_EEGresnet.irfanet(960, 64, 128, 4, '', 0.075)
    ## Graph_EEGresnet.irfanet(num_of_time_samples, num_of_eletrodes, filter_len, num_of_classes)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss={'convloss':Graph_EEGresnet.mycrossentropy_wrapper(0.075), 'graph_loss':Graph_EEGresnet.mycrossentropy_wrapper(0.075), 'fused_loss':Graph_EEGresnet.mycrossentropy_wrapper(0.075)}, 
                    loss_weights = {'convloss':1, 'graph_loss':1, 'fused_loss': 3}, metrics={'convloss':'accuracy', 'graph_loss': 'accuracy', 'fused_loss': 'accuracy'})

    if not os.path.exists('model/%s'%MODEL_NAME):
        os.mkdir('model/%s'%MODEL_NAME)
    if not os.path.exists('model/%s/%s'%(MODEL_NAME, str(fold))):
        os.mkdir('model/%s/%s'%(MODEL_NAME, str(fold)))
    filepath='model/%s/%s'%(MODEL_NAME, str(fold))+"/weights.hdf5"

    checkpoint = \
        ModelCheckpoint(filepath, monitor='val_fused_loss_accuracy', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', overwrite=True, period=1)

    history = model.fit(X_train, y_train,
        batch_size=32,
        epochs=200,
        validation_data=(X_val, y_val),
        shuffle=True, verbose=1, callbacks=[checkpoint])