<a href="https://colab.research.google.com/github/UN-GCPDS/python-gcpds.EEG_Tensorflow_models/blob/main/Examples/BCI2a/mtvae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.EEG_Tensorflow_models.git >/dev/null

In [None]:
from EEG_Tensorflow_models.Utils.LoadData import load_dataset
from EEG_Tensorflow_models.Utils.Callbacks import get_callbacks
from EEG_Tensorflow_models.Utils.TrainingModels import get_loss,get_model,get_optimizer,train_model_cv
from EEG_Tensorflow_models.Models import MTVAE

In [None]:
import numpy as np
import pickle

# Exp 1: Schirrmeister2017

In [None]:
# Experiment 1 (schirrmeister2017): train one MTVAE per subject of BNCI2014001
# and pickle the accumulated accuracies/histories after every subject, so
# partial results survive an interrupted run.
opt_args = {'lr': 0.001,'beta_1': 0.9}

# Two-headed loss: reconstruction ('mse') and classification (cross-entropy),
# weighted 2.5 : 1 via `loss_weights` below.
loss = get_loss(['mse','CategoricalCrossentropy'])

metrics = ['accuracy']

# Presumably paired, in order, with the entries of `call_args` built per subject.
callbacks_names = {'early_stopping_train':'early_stopping','checkpoint_train':'checkpoint',
                'Threshold_valid':'Threshold','checkpoint_valid':'checkpoint',
                'early_stopping_valid':'early_stopping'}


Experiment = 'schirrmeister2017'
model_name = 'MTVAE'

subjects = np.arange(1,10)
Acc = []
History = []
Subject = []

for sbj in subjects:
    print('Subject: {:d} of {:d}'.format(sbj,len(subjects)))

    X_train,y_train,X_valid,y_valid,fs = load_dataset(dataset_name="BNCI2014001", subject_id=sbj)
    # Drop the last sample along axis 2 (the axis passed as 'Samples' below).
    # NOTE(review): presumably to fit the model's temporal architecture — confirm.
    X_train = X_train[:,:,0:-1,:]
    X_valid = X_valid[:,:,0:-1,:]

    model_args = {'nb_classes':4,'Chans':X_train.shape[1],'Samples':X_train.shape[2],'dropoutRate':0.5}
    model = get_model(model_name,model_args)

    # Fresh optimizer per subject: sharing one instance would carry Adam's step
    # count and moment estimates over from the previous subject's model (and
    # newer Keras optimizers reject variables they were not built with).
    optimizer = get_optimizer('Adam',opt_args)

    call_args = [{'monitor':'val_Classif_accuracy','patience':100,'min_delta':0.001,'mode':'max','verbose':1,'restore_best_weights':False},
                {'filepath':Experiment+'checkpoint_sbj_'+str(sbj),'save_format':'tf','monitor':'val_Classif_accuracy','verbose':1,'save_weights_only':True,'save_best_only':True},
                {'threshold':None},
                {'filepath':Experiment+'checkpoint_2_sbj_'+str(sbj),'save_format':'tf','monitor':'val_Classif_accuracy','verbose':1,'save_weights_only':True,'save_best_only':True},
                {'monitor':'val_Classif_accuracy','patience':None,'min_delta':0.001,'mode':'max','verbose':1,'restore_best_weights':True}]

    callbacks = get_callbacks(callbacks_names,call_args)

    cv = train_model_cv(model,optimizer,loss,metrics,callbacks=callbacks,loss_weights=[2.5,1])

    history = cv.fit_validation(X_train,y = y_train,X_val=X_valid,y_val=y_valid,batch_size=30,epochs=1000,verbose=1,val_mode=Experiment,autoencoder=True)
    acc = cv.get_accuracy()
    print('Subject accuracy: {:f}'.format(acc))
    Acc.append(acc)
    # Store this subject's history (not the accumulating `History` list itself).
    # Keras History objects may hold a model reference and fail to pickle, so
    # keep only their `.history` metric dict when present. The list branch
    # guards the case of one history per fold (unverified from here).
    if isinstance(history, (list, tuple)):
        history_record = [getattr(h, 'history', h) for h in history]
    else:
        history_record = getattr(history, 'history', history)
    History.append(history_record)
    Subject.append(sbj)

    results = {}
    results['subject'] = Subject
    results['history'] = History
    results['accuracy'] = Acc

    # Rewritten after every subject so partial results are kept.
    with open('Results_BCI2a_'+Experiment+'_'+model_name+'.p','wb') as handle:
        pickle.dump(results,handle)

    del cv,model,X_train,y_train,X_valid,y_valid,fs

# Exp 2: Schirrmeister2017 legal


In [None]:
# Experiment 2 (schirrmeister2017_legal): same per-subject MTVAE training as
# Experiment 1, but validated with the 'schirrmeister2017_legal' mode of
# `fit_validation`. Results are pickled after every subject.
opt_args = {'lr': 0.001,'beta_1': 0.9}

# Two-headed loss: reconstruction ('mse') and classification (cross-entropy),
# weighted 2.5 : 1 via `loss_weights` below.
loss = get_loss(['mse','CategoricalCrossentropy'])

metrics = ['accuracy']

# Presumably paired, in order, with the entries of `call_args` built per subject.
callbacks_names = {'early_stopping_train':'early_stopping','checkpoint_train':'checkpoint',
                'Threshold_valid':'Threshold','checkpoint_valid':'checkpoint',
                'early_stopping_valid':'early_stopping'}


Experiment = 'schirrmeister2017_legal'
model_name = 'MTVAE'

subjects = np.arange(1,10)
Acc = []
History = []
Subject = []
for sbj in subjects:
    print('Subject: {:d} of {:d}'.format(sbj,len(subjects)))

    X_train,y_train,X_valid,y_valid,fs = load_dataset(dataset_name="BNCI2014001", subject_id=sbj)
    # Drop the last sample along axis 2 (the axis passed as 'Samples' below).
    # NOTE(review): presumably to fit the model's temporal architecture — confirm.
    X_train = X_train[:,:,0:-1,:]
    X_valid = X_valid[:,:,0:-1,:]

    model_args = {'nb_classes':4,'Chans':X_train.shape[1],'Samples':X_train.shape[2],'dropoutRate':0.5}
    model = get_model(model_name,model_args)

    # Fresh optimizer per subject: sharing one instance would carry Adam's step
    # count and moment estimates over from the previous subject's model (and
    # newer Keras optimizers reject variables they were not built with).
    optimizer = get_optimizer('Adam',opt_args)

    call_args = [{'monitor':'val_Classif_accuracy','patience':100,'min_delta':0.001,'mode':'max','verbose':1,'restore_best_weights':False},
                {'filepath':Experiment+'checkpoint_sbj_'+str(sbj),'save_format':'tf','monitor':'val_Classif_accuracy','verbose':1,'save_weights_only':True,'save_best_only':True},
                {'threshold':None},
                {'filepath':Experiment+'checkpoint_2_sbj_'+str(sbj),'save_format':'tf','monitor':'val_Classif_accuracy','verbose':1,'save_weights_only':True,'save_best_only':True},
                {'monitor':'val_Classif_accuracy','patience':None,'min_delta':0.001,'mode':'max','verbose':1,'restore_best_weights':True}]

    callbacks = get_callbacks(callbacks_names,call_args)

    cv = train_model_cv(model,optimizer,loss,metrics,callbacks=callbacks,loss_weights=[2.5,1])

    history = cv.fit_validation(X_train,y = y_train,X_val=X_valid,y_val=y_valid,batch_size=30,epochs=1000,verbose=1,val_mode=Experiment,autoencoder=True)
    acc = cv.get_accuracy()
    print('Subject accuracy: {:f}'.format(acc))
    Acc.append(acc)
    # Store this subject's history (not the accumulating `History` list itself).
    # Keras History objects may hold a model reference and fail to pickle, so
    # keep only their `.history` metric dict when present. The list branch
    # guards the case of one history per fold (unverified from here).
    if isinstance(history, (list, tuple)):
        history_record = [getattr(h, 'history', h) for h in history]
    else:
        history_record = getattr(history, 'history', history)
    History.append(history_record)
    Subject.append(sbj)

    results = {}
    results['subject'] = Subject
    results['history'] = History
    results['accuracy'] = Acc

    # Rewritten after every subject so partial results are kept.
    with open('Results_BCI2a_'+Experiment+'_'+model_name+'.p','wb') as handle:
        pickle.dump(results,handle)

    del cv,model,X_train,y_train,X_valid,y_valid,fs

# Exp 3: Schirrmeister2021

In [None]:
# Experiment 3 (schirrmeister2021): per-subject MTVAE training with only a
# validation checkpoint and early stopping (patience 100, best weights
# restored). Results are pickled after every subject.
opt_args = {'lr': 0.001,'beta_1': 0.9}

# Two-headed loss: reconstruction ('mse') and classification (cross-entropy),
# weighted 2.5 : 1 via `loss_weights` below.
loss = get_loss(['mse','CategoricalCrossentropy'])

metrics = ['accuracy']

# Presumably paired, in order, with the entries of `call_args` built per subject.
callbacks_names = {'checkpoint_valid':'checkpoint',
                   'early_stopping_valid':'early_stopping'}


Experiment = 'schirrmeister2021'
model_name = 'MTVAE'

subjects = np.arange(1,10)
Acc = []
History = []
Subject = []
for sbj in subjects:
    print('Subject: {:d} of {:d}'.format(sbj,len(subjects)))

    X_train,y_train,X_valid,y_valid,fs = load_dataset(dataset_name="BNCI2014001", subject_id=sbj)
    # Drop the last sample along axis 2 (the axis passed as 'Samples' below).
    # NOTE(review): presumably to fit the model's temporal architecture — confirm.
    X_train = X_train[:,:,0:-1,:]
    X_valid = X_valid[:,:,0:-1,:]

    model_args = {'nb_classes':4,'Chans':X_train.shape[1],'Samples':X_train.shape[2],'dropoutRate':0.5}
    model = get_model(model_name,model_args)

    # Fresh optimizer per subject: sharing one instance would carry Adam's step
    # count and moment estimates over from the previous subject's model (and
    # newer Keras optimizers reject variables they were not built with).
    optimizer = get_optimizer('Adam',opt_args)

    call_args = [
            {'filepath':Experiment+'checkpoint_'+str(sbj),
            'save_format':'tf',
            'monitor':'val_Classif_accuracy',
            'verbose':1,
            'save_weights_only':True,
            'save_best_only':True},
            {'monitor':'val_Classif_accuracy',
            'patience':100,
            'min_delta':0.001,
            'mode':'max',
            'verbose':1,
            'restore_best_weights':True}]

    callbacks = get_callbacks(callbacks_names,call_args)

    cv = train_model_cv(model,optimizer,loss,metrics,callbacks=callbacks,loss_weights=[2.5,1])

    history = cv.fit_validation(X_train,y_train,X_val=X_valid,y_val=y_valid,batch_size=30,epochs=1000,verbose=1,val_mode=Experiment,autoencoder=True)
    acc = cv.get_accuracy()
    print('Subject accuracy: {:f}'.format(acc))
    Acc.append(acc)
    # Store this subject's history (not the accumulating `History` list itself).
    # Keras History objects may hold a model reference and fail to pickle, so
    # keep only their `.history` metric dict when present. The list branch
    # guards the case of one history per fold (unverified from here).
    if isinstance(history, (list, tuple)):
        history_record = [getattr(h, 'history', h) for h in history]
    else:
        history_record = getattr(history, 'history', history)
    History.append(history_record)
    Subject.append(sbj)

    results = {}
    results['subject'] = Subject
    results['history'] = History
    results['accuracy'] = Acc

    # Rewritten after every subject so partial results are kept.
    with open('Results_BCI2a_'+Experiment+'_'+model_name+'.p','wb') as handle:
        pickle.dump(results,handle)

    del cv,model,X_train,y_train,X_valid,y_valid,fs

# Exp 4: 4-fold CV (Lawhern2018)

In [None]:
# Experiment 4 (lawhern2018): per-subject MTVAE training validated with the
# 'lawhern2018' mode of `fit_validation`, using one best-only checkpoint per
# fold (four in total). Results are pickled after every subject.
opt_args = {'lr': 0.001,'beta_1': 0.9}

# Two-headed loss: reconstruction ('mse') and classification (cross-entropy),
# weighted 2.5 : 1 via `loss_weights` below.
loss = get_loss(['mse','CategoricalCrossentropy'])

metrics = ['accuracy']

# Presumably paired, in order, with the four checkpoint configs in `call_args`.
callbacks_names = {'checkpoint_train1':'checkpoint','checkpoint_train2':'checkpoint','checkpoint_train3':'checkpoint','checkpoint_train4':'checkpoint'}


Experiment = 'lawhern2018'
model_name = 'MTVAE'


subjects = np.arange(1,10)
Acc = []
History = []
Subject = []
for sbj in subjects:
    print('Subject: {:d} of {:d}'.format(sbj,len(subjects)))

    X_train,y_train,X_valid,y_valid,fs = load_dataset(dataset_name="BNCI2014001", subject_id=sbj)
    # Drop the last sample along axis 2 (the axis passed as 'Samples' below).
    # NOTE(review): presumably to fit the model's temporal architecture — confirm.
    X_train = X_train[:,:,0:-1,:]
    X_valid = X_valid[:,:,0:-1,:]

    model_args = {'nb_classes':4,'Chans':X_train.shape[1],'Samples':X_train.shape[2],'dropoutRate':0.5}
    model = get_model(model_name,model_args)

    # Fresh optimizer per subject: sharing one instance would carry Adam's step
    # count and moment estimates over from the previous subject's model (and
    # newer Keras optimizers reject variables they were not built with).
    optimizer = get_optimizer('Adam',opt_args)

    # One best-only weights checkpoint per fold; only the filepath differs.
    call_args = [
            {'filepath':Experiment+'checkpoint'+str(fold)+'_'+str(sbj),
            'save_format':'tf',
            'monitor':'val_Classif_accuracy',
            'verbose':1,
            'save_weights_only':True,
            'save_best_only':True}
            for fold in range(1,5)]

    callbacks = get_callbacks(callbacks_names,call_args)

    cv = train_model_cv(model,optimizer,loss,metrics,callbacks=callbacks,loss_weights=[2.5,1])

    history = cv.fit_validation(X_train,y_train,X_val=X_valid,y_val=y_valid,batch_size=100,epochs=1000,verbose=1,val_mode=Experiment,autoencoder=True)
    acc = cv.get_accuracy()
    print('Subject accuracy: {:f}'.format(acc))
    Acc.append(acc)
    # Store this subject's history (not the accumulating `History` list itself).
    # Keras History objects may hold a model reference and fail to pickle, so
    # keep only their `.history` metric dict when present. The list branch
    # guards the case of one history per fold (unverified from here).
    if isinstance(history, (list, tuple)):
        history_record = [getattr(h, 'history', h) for h in history]
    else:
        history_record = getattr(history, 'history', history)
    History.append(history_record)
    Subject.append(sbj)

    results = {}
    results['subject'] = Subject
    results['history'] = History
    results['accuracy'] = Acc

    # Rewritten after every subject so partial results are kept.
    with open('Results_BCI2a_'+Experiment+'_'+model_name+'.p','wb') as handle:
        pickle.dump(results,handle)

    del cv,model,X_train,y_train,X_valid,y_valid,fs