In [None]:
from moabb.datasets import BNCI2014001
from moabb.paradigms import MotorImagery
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from models.EEGNetAttention import EEGNnetAttention, EEGNnetAttentionModel

In [None]:
# Classifiers to benchmark.
# Maps a run name (also used in the output CSV filename) to a tuple of
# (model instance, dict of training hyper-parameters read by the main loop).
clfs = dict()
clfs['eegnetAttention_TT_32'] = (
    EEGNnetAttention(res=[True, True], temporal_conv_size=32),
    {'lr': 0.001},
)

In [None]:
# --- Experiment configuration -------------------------------------------------
# Each inner list is one run of the pipeline (here: one subject per run).
subjects = [[1], [2], [3], [4], [5], [6], [7], [8], [9]]
# Each inner list is a group of sessions pooled together within one run.
# NOTE(review): session names depend on the installed MOABB version — confirm.
sessions = [['session_T', 'session_E']]
events = ['left_hand', 'right_hand', 'feet', 'tongue']
# Integer class index assigned to each event name.
events_dict = {'left_hand':0, 'right_hand':1, 'feet':2, 'tongue':3}
n_classes = len(events)
tmin, tmax = 0.5, 2.5          # epoch limits (s) passed to the paradigm
fmin, fmax = 4, 40             # band-pass edges passed to the paradigm
fLoop = None                   # optional list of (fmin, fmax) bands; None -> single band
baseline=None
channels = None
resample = 128
# Sampling rate of the epochs: the resampled rate, or 250 when not resampling.
rate = resample if resample is not None else 250
window_size = 2.0              # crop length (s)
step_window = 0.2              # NOTE(review): not used in the visible cells
window_train_start = np.array([0.5])   # crop start times (s) for training
window_test_start =  np.array([0.5])   # crop start times (s) for evaluation

# Convert seconds to sample indices relative to the epoch start (tmin).
# Round before casting: a bare int()/astype(int) truncates, so e.g.
# 250*(0.7-0.5) = 49.999... would become 49 instead of 50.
window_size_map = int(round(rate*window_size))
window_train_start_map = np.round(rate*(window_train_start-tmin)).astype(int)
window_test_start_map = np.round(rate*(window_test_start-tmin)).astype(int)


# Stratified 5-fold cross-validation; shuffling is seeded so the fold
# assignment is reproducible between runs.
skf = StratifiedKFold(
    random_state=42,
    shuffle=True,
    n_splits=5,
)

# MOABB dataset object handed to paradigm.get_data in the loaders.
dataset = BNCI2014001()

# Epoching/filtering configuration used by the data loaders.
paradigm = MotorImagery(
    resample=resample,
    channels=channels,
    baseline=baseline,
    fmin=fmin,
    fmax=fmax,
    tmin=tmin,
    tmax=tmax,
    n_classes=len(events),
    events=events,
)

In [None]:
def crop_train(data, labels, window_size_map, start_window_map):
    """Cut fixed-length windows out of every trial and stack them.

    Parameters
    ----------
    data : np.ndarray
        Trials indexed as ``data[trial, channel, sample, ...]``; windows are
        taken along axis 2 (any trailing axes, e.g. bands, are kept).
    labels : array-like
        One label per trial.
    window_size_map : int
        Window length in samples.
    start_window_map : iterable of int
        Window start offsets in samples; every trial is cropped once per start.

    Returns
    -------
    X : np.ndarray
        Crops stacked start-major (all trials for the first start, then all
        trials for the next, ...), shape
        ``(len(starts) * n_trials, data.shape[1], window_size_map, ...)``.
    y : np.ndarray
        ``labels`` repeated once per start, aligned with ``X``.

    Raises
    ------
    ValueError
        If no start is given or a window does not fit inside the trials.
    """
    starts = [int(s) for s in start_window_map]
    if not starts:
        raise ValueError("start_window_map must contain at least one start")

    n_samples = data.shape[2]
    for s in starts:
        # Slicing past the end would silently return a shorter window.
        if s < 0 or s + window_size_map > n_samples:
            raise ValueError(
                f"window starting at sample {s} with length {window_size_map} "
                f"does not fit in trials of {n_samples} samples"
            )

    # Concatenate the crops directly; no intermediate stacked copy is needed.
    X = np.concatenate([data[:, :, s:s + window_size_map] for s in starts], axis=0)
    y = np.concatenate([np.asarray(labels)] * len(starts), axis=None)

    return X, y


In [None]:
def load_Xy(dataset, subjects, sessions, fmin, fmax):
    """Load band-pass filtered epochs and integer labels for some subjects/sessions.

    Parameters
    ----------
    dataset : MOABB dataset
        Passed to the module-level ``paradigm.get_data``.
    subjects : list
        Subject ids to load (e.g. ``[1]``).
    sessions : list of str
        Session names to keep; trials from other sessions are dropped.
    fmin, fmax : float
        Band-pass edges to apply.

    Returns
    -------
    X : np.ndarray
        Epochs returned by ``paradigm.get_data``, restricted to ``sessions``.
    y : np.ndarray
        Labels mapped to integers through ``events_dict``.

    Raises
    ------
    ValueError
        If no trial belongs to any of ``sessions``.

    Notes
    -----
    Mutates the module-level ``paradigm`` so it filters with [fmin, fmax].
    """
    # MOABB's MotorImagery turns fmin/fmax into `filters` in __init__ and uses
    # `filters` when processing; assigning fmin/fmax afterwards does not change
    # the band. Update `filters` so the requested band is actually applied
    # (fmin/fmax kept in case other code reads them).
    # NOTE(review): verify against the installed MOABB version.
    paradigm.fmin = fmin
    paradigm.fmax = fmax
    paradigm.filters = [[fmin, fmax]]
    X, y, metadata = paradigm.get_data(dataset=dataset, subjects=subjects)
    X = np.asarray(X)
    y = np.array([events_dict[label] for label in y])

    # Keep only the requested sessions.
    # NOTE(review): session names differ across MOABB versions
    # (e.g. 'session_T' vs '0train'); a mismatch selects nothing.
    keep = np.isin(np.asarray(metadata['session']), np.asarray(sessions))
    if not keep.any():
        raise ValueError(
            f"no trials found for sessions {sessions}; available sessions: "
            f"{sorted(set(np.asarray(metadata['session']).tolist()))}"
        )

    return X[keep], y[keep]

In [None]:
def load_Xy_with_bands(dataset, subjects, sessions, fLoop):
    """Load the same trials once per frequency band and stack the bands.

    Parameters
    ----------
    dataset, subjects, sessions
        Forwarded unchanged to ``load_Xy``.
    fLoop : iterable of (fmin, fmax)
        Frequency bands to load, in order.

    Returns
    -------
    X : np.ndarray
        Per-band arrays from ``load_Xy`` stacked on a new last axis.
    y : np.ndarray
        Labels (identical for every band).

    Raises
    ------
    ValueError
        If ``fLoop`` is empty or the bands yield different labels.
    """
    bands = list(fLoop)
    if not bands:
        raise ValueError("fLoop must contain at least one (fmin, fmax) band")

    X_per_band = []
    y = None
    for band_fmin, band_fmax in bands:
        X_band, y_band = load_Xy(dataset, subjects, sessions, band_fmin, band_fmax)
        # Every band reloads the same trials, so labels must line up across bands.
        if y is not None and not np.array_equal(y, y_band):
            raise ValueError("labels differ between frequency bands")
        y = y_band
        X_per_band.append(X_band)

    return np.stack(X_per_band, axis=-1), y

In [None]:
import copy

# Cross-validated training/evaluation of every classifier in `clfs`.
# For each subject run and session group, trials are split with `skf`. A model
# is trained on cropped training windows and then scored trial by trial for
# every test window start. Per-trial class probabilities are written to
# results_<clf_name>.csv.
for clf_name, clf_data in clfs.items():
    base_clf, metadata = clf_data
    results = []
    proba = []
    for sub in subjects:
        # Subject label for logs/results: a bare id for single-subject runs.
        sub_label = sub[0] if len(sub)==1 else sub
        for sec in sessions:

            if fLoop is None:
                X, y = load_Xy(dataset, sub, sec, fmin, fmax)
            else:
                X, y = load_Xy_with_bands(dataset, sub, sec, fLoop)

            fold_number = 0
            acc_sum = 0
            for train_index, test_index in skf.split(X, y):

                fold_number += 1
                print(sub_label, sec, fold_number, '                                  ')

                # Start every fold from a fresh copy of the untrained model so
                # weights fitted on earlier folds (whose training data includes
                # this fold's test trials) cannot leak in, whether or not
                # `fit` re-initialises the model itself.
                clf = copy.deepcopy(base_clf)

                X_train, X_test = X[train_index], X[test_index]
                y_train, y_test = y[train_index], y[test_index]
                X_train, y_train = crop_train(X_train, y_train, window_size_map, window_train_start_map)
                X_track, y_track = crop_train(X_test, y_test, window_size_map, window_train_start_map)
                # NOTE(review): the held-out fold is passed as `track`; if fit
                # uses it for early stopping/model selection, the returned
                # accuracy is optimistically biased — confirm.
                acc = clf.fit(X_train, y_train, lr=metadata['lr'], iterations=1000, batchsize=64, track=(X_track, y_track))
                acc_sum += acc

                for start_wind_map in window_test_start_map:
                    # Window the whole test set once per start, not once per trial.
                    X_test_windowed = X_test[:, :, start_wind_map:start_wind_map+window_size_map]
                    window_start_s = np.round(tmin+start_wind_map/rate, decimals=2)
                    for trial in range(len(y_test)):
                        results.append( [sub_label, sec, fold_number, trial, window_start_s, y_test[trial]] )
                        proba.append( np.round(*clf.predict_proba([X_test_windowed[trial]]), decimals=4) )
            # Average over the folds actually run rather than a hard-coded 5.
            print(f"resultados: {acc_sum/fold_number}\n")

    results = pd.DataFrame(results, columns=['subjects', 'session', 'fold_number', 'trial', 'window_test_start', 'label'])
    proba = pd.DataFrame(proba, columns=['proba_' + event for event in events])
    results = pd.concat([results, proba], axis=1)

    results.to_csv('results_'+clf_name+'.csv', index=False)