In [1]:
import os, sys
import scipy.io as sio
import warnings
warnings.filterwarnings('ignore')
import time
from joblib import Parallel, delayed

import numpy as np
from sklearn.model_selection import ShuffleSplit,LeaveOneOut
from toolbox1 import PreProcessing,acc_calculate,TRCA_train,TRCA_test
from toolbox1 import get_P,TDCA_train,TDCA_test
from toolbox1 import get_augment_fb_noiseAfter_ms
warnings.filterwarnings('ignore')

In [2]:
# main function
def benchmark_TRCAmssame(idx_num, n_train, t_task, n_Aug, n_Neig=None, data_dir=None):
    """Cross-validated filter-bank eTRCA accuracy for one subject.

    The subject's recording is loaded, split into five sub-bands, and evaluated
    with block-wise cross-validation. When ``n_Aug`` is non-zero, artificial
    trials produced by ``get_augment_fb_noiseAfter_ms`` are appended to the
    training trials of every sub-band before ``TRCA_train`` is called.

    Parameters
    ----------
    idx_num : int
        Zero-based subject index; subject ``idx_num`` is read from
        ``S{idx_num + 1:02d}.mat``.
    n_train : int
        Number of blocks (out of 6) used for training in each fold.
    t_task : float
        Length of the analysis window in seconds.
    n_Aug : int
        Number of augmented trials generated per event; ``0`` disables
        augmentation.
    n_Neig : int, optional
        Forwarded to the augmentation helper as ``nTemplates``.
        NOTE(review): the notebook uses 0 for "SAME" and 12/14 for "msSAME";
        the exact meaning is defined in ``toolbox1`` - confirm there.
    data_dir : str, optional
        Directory containing the subject ``.mat`` files. Defaults to the
        notebook-level ``filepath_ori`` so existing calls keep working.

    Returns
    -------
    float
        Average over folds of the value returned by ``acc_calculate``.
    """
    # ---- stimulus setting ----
    freq_tmp = np.arange(8, 16, 1)
    f_list = np.hstack((freq_tmp, freq_tmp + 0.2, freq_tmp + 0.4, freq_tmp + 0.6, freq_tmp + 0.8))
    target_order = np.argsort(f_list)
    f_list = f_list[target_order]
    # 40 phases (in the unit expected by the helper): 0, 0.5, 1, 1.5 repeated.
    phase_list = np.tile([0, 0.5, 1, 1.5], 10)
    subject_id = ['S' + '{:02d}'.format(idx_subject + 1) for idx_subject in range(35)]

    # ---- data loading ----
    if data_dir is None:
        # Fall back to the notebook-level setting (defined in the settings cell).
        data_dir = filepath_ori
    idx_subject = subject_id[idx_num]
    sfreq = 250
    filepath = os.path.join(data_dir, str(idx_subject) + '.mat')
    num_filter = 5
    # presumably 0.5 s is the cue period and 0.14 s the visual latency - TODO confirm
    preEEG = PreProcessing(filepath, t_begin=0.5, t_end=0.5 + 0.14 + t_task,
                           fs_down=250, chans=['POZ', 'PZ', 'PO3', 'PO5', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],
                           num_filter=num_filter)

    raw_data = preEEG.load_data()
    w_pass_2d = np.array([[5, 14, 22, 30, 38], [90, 90, 90, 90, 90]])  # 70
    w_stop_2d = np.array([[3, 12, 20, 28, 36], [92, 92, 92, 92, 92]])  # 72
    filtered_data = preEEG.filtered_data_iir111(w_pass_2d, w_stop_2d, raw_data)

    # Sort the event axis of every sub-band by stimulus frequency (ascending).
    bank_names = ['bank' + str(k) for k in range(1, num_filter + 1)]
    for bank in bank_names:
        filtered_data[bank] = filtered_data[bank][:, :, target_order, :]

    # ---- cross-validation parameters ----
    nBlock = 6
    nEvent = 40
    train_size = n_train
    n_splits = 6
    if train_size == nBlock - 1 or train_size == 1:
        kf = LeaveOneOut()
    else:
        kf = ShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=idx_num + 1)

    # Sample indices of the analysis window inside the loaded epoch.
    task_point = np.arange(int((0.14) * sfreq), int((0.14 + t_task) * sfreq))
    Nh_end = 5  # highest harmonic used by the augmentation

    acc_s = 0
    n_folds = 0  # count the folds actually run instead of trusting n_splits
    for train, test in kf.split(np.arange(nBlock)):
        if train_size == 1:
            # LeaveOneOut yields 5 train / 1 test; swap to train on a single block.
            train, test = test, train

        # train: spatial filters and templates for every sub-band
        train_w = dict()
        train_meantemp = dict()
        for idx_filter, bank in enumerate(bank_names, start=1):
            # n_channels * n_times * n_events * n_trials
            train_data = filtered_data[bank][:, :, :, train][:, task_point, :, :]

            if n_Aug == 0:
                trainData_pt = train_data.copy()
            else:
                # The trial-averaged templates do not depend on the event, so
                # compute them once per sub-band instead of once per event.
                mean_temp_all = np.mean(train_data, -1)
                data_augment = np.zeros(train_data.shape[:3] + (n_Aug,))
                for ievent in range(nEvent):
                    f = f_list[ievent]
                    # First harmonic whose frequency reaches 8 * idx_filter Hz.
                    Nh_start = next((ih for ih in range(1, Nh_end + 1)
                                     if ih * f >= 8 * idx_filter), None)
                    if Nh_start is None:
                        raise ValueError(
                            'no harmonic of {} Hz reaches sub-band {}'.format(f, idx_filter))
                    data_augment[:, :, ievent, :] = get_augment_fb_noiseAfter_ms(
                        fs=sfreq, f_list=f_list, phi_list=phase_list,
                        Nh_start=Nh_start, Nh_end=Nh_end,
                        ntrail_noise=n_Aug,
                        mean_temp_all=mean_temp_all,
                        iEvent=ievent, nTemplates=n_Neig)
                trainData_pt = np.concatenate((train_data, data_augment), axis=3)

            w, mean_temp = TRCA_train(trainData_pt)
            train_w[bank] = w
            train_meantemp[bank] = mean_temp

        # test: one row of predictions per held-out block
        predictAll = np.zeros((test.shape[0], nEvent), int)
        for row, isplit in enumerate(test):
            rrall = np.zeros((nEvent, nEvent))
            for idx_filter, bank in enumerate(bank_names, start=1):
                test_data = filtered_data[bank][:, :, :, isplit][:, task_point, :]
                rr = TRCA_test(test_data, train_w[bank], train_meantemp[bank], True)
                # Signed squared coefficients, weighted per sub-band.
                rrall += np.multiply(np.sign(rr), (rr ** 2)) * (idx_filter ** (-1.25) + 0.25)
            predictAll[row, :] = np.argmax(rrall, -1)
        acc_s = acc_calculate(predictAll) + acc_s
        n_folds += 1

    acc = acc_s / n_folds
    print('sub', idx_num + 1, ', acc = ', acc)
    return acc

def benchmark_TDCAmssame(idx_num, n_train, t_task, n_Aug, n_Neig=None, data_dir=None):
    """Cross-validated filter-bank TDCA accuracy for one subject.

    The subject's recording is loaded, split into five sub-bands, and evaluated
    with block-wise cross-validation. When ``n_Aug`` is non-zero, artificial
    trials produced by ``get_augment_fb_noiseAfter_ms`` are appended to the
    training trials of every sub-band before ``TDCA_train`` is called.

    Parameters
    ----------
    idx_num : int
        Zero-based subject index; subject ``idx_num`` is read from
        ``S{idx_num + 1:02d}.mat``.
    n_train : int
        Number of blocks (out of 6) used for training in each fold.
    t_task : float
        Length of the analysis window in seconds.
    n_Aug : int
        Number of augmented trials generated per event; ``0`` disables
        augmentation.
    n_Neig : int, optional
        Forwarded to the augmentation helper as ``nTemplates``.
        NOTE(review): the notebook uses 0 for "SAME" and 12/14 for "msSAME";
        the exact meaning is defined in ``toolbox1`` - confirm there.
    data_dir : str, optional
        Directory containing the subject ``.mat`` files. Defaults to the
        notebook-level ``filepath_ori`` so existing calls keep working.

    Returns
    -------
    float
        Average over folds of the value returned by ``acc_calculate``.
    """
    # ---- stimulus setting ----
    freq_tmp = np.arange(8, 16, 1)
    f_list = np.hstack((freq_tmp, freq_tmp + 0.2, freq_tmp + 0.4, freq_tmp + 0.6, freq_tmp + 0.8))
    target_order = np.argsort(f_list)
    f_list = f_list[target_order]
    # 40 phases (in the unit expected by the helper): 0, 0.5, 1, 1.5 repeated.
    phase_list = np.tile([0, 0.5, 1, 1.5], 10)
    subject_id = ['S' + '{:02d}'.format(idx_subject + 1) for idx_subject in range(35)]

    # ---- TDCA parameters ----
    n_delay = 5  # delay points for TDCA; also sizes the extra samples loaded below

    # ---- data loading ----
    if data_dir is None:
        # Fall back to the notebook-level setting (defined in the settings cell).
        data_dir = filepath_ori
    idx_subject = subject_id[idx_num]
    sfreq = 250
    filepath = os.path.join(data_dir, str(idx_subject) + '.mat')
    num_filter = 5
    # The epoch is extended by n_delay samples so the training window can
    # include them. presumably 0.5 s is the cue period and 0.14 s the visual
    # latency - TODO confirm
    preEEG = PreProcessing(filepath, t_begin=0.5, t_end=0.5 + 0.14 + t_task + n_delay / sfreq,
                           fs_down=250, chans=['POZ', 'PZ', 'PO3', 'PO5', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],
                           num_filter=num_filter)

    raw_data = preEEG.load_data()
    w_pass_2d = np.array([[5, 14, 22, 30, 38], [90, 90, 90, 90, 90]])
    w_stop_2d = np.array([[3, 12, 20, 28, 36], [92, 92, 92, 92, 92]])
    filtered_data = preEEG.filtered_data_iir111(w_pass_2d, w_stop_2d, raw_data)

    # Sort the event axis of every sub-band by stimulus frequency (ascending).
    bank_names = ['bank' + str(k) for k in range(1, num_filter + 1)]
    for bank in bank_names:
        filtered_data[bank] = filtered_data[bank][:, :, target_order, :]

    # ---- cross-validation parameters ----
    nBlock = 6
    nEvent = 40
    train_size = n_train
    n_splits = 6
    if train_size == nBlock - 1 or train_size == 1:
        kf = LeaveOneOut()
    else:
        kf = ShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=idx_num + 1)

    sTime = t_task
    # Training uses n_delay extra samples at the end of the window; testing does not.
    train_point = np.arange(int((0.14) * sfreq), int((0.14 + sTime) * sfreq) + n_delay)
    test_point = np.arange(int((0.14) * sfreq), int((0.14 + sTime) * sfreq))
    # Projection matrix P of all classes
    P = get_P(f_list=f_list, Nh=5, sTime=sTime, sfreq=sfreq)
    Nh_end = 5  # highest harmonic used by the augmentation

    acc_s = 0
    n_folds = 0  # count the folds actually run instead of trusting n_splits
    for train, test in kf.split(np.arange(nBlock)):
        if train_size == 1:
            # LeaveOneOut yields 5 train / 1 test; swap to train on a single block.
            train, test = test, train

        # train: spatial filters and templates for every sub-band
        train_w = dict()
        train_meantemp = dict()
        for idx_filter, bank in enumerate(bank_names, start=1):
            # n_channels * n_times * n_events * n_trials
            train_data = filtered_data[bank][:, :, :, train][:, train_point, :, :]

            if n_Aug == 0:
                trainData_pt = train_data.copy()
            else:
                # The trial-averaged templates do not depend on the event, so
                # compute them once per sub-band instead of once per event.
                mean_temp_all = np.mean(train_data, -1)
                data_augment = np.zeros(train_data.shape[:3] + (n_Aug,))
                for ievent in range(nEvent):
                    f = f_list[ievent]
                    # First harmonic whose frequency reaches 8 * idx_filter Hz.
                    Nh_start = next((ih for ih in range(1, Nh_end + 1)
                                     if ih * f >= 8 * idx_filter), None)
                    if Nh_start is None:
                        raise ValueError(
                            'no harmonic of {} Hz reaches sub-band {}'.format(f, idx_filter))
                    data_augment[:, :, ievent, :] = get_augment_fb_noiseAfter_ms(
                        fs=sfreq, f_list=f_list, phi_list=phase_list,
                        Nh_start=Nh_start, Nh_end=Nh_end,
                        ntrail_noise=n_Aug,
                        mean_temp_all=mean_temp_all,
                        iEvent=ievent, nTemplates=n_Neig)
                trainData_pt = np.concatenate((train_data, data_augment), axis=3)

            w, mean_temp_TDCA = TDCA_train(trainData_pt, P=P, l=n_delay, Nk=8)
            train_w[bank] = w
            train_meantemp[bank] = mean_temp_TDCA

        # test: one row of predictions per held-out block
        predictAll = np.zeros((test.shape[0], nEvent), int)
        for row, isplit in enumerate(test):
            rrall = np.zeros((nEvent, nEvent))
            for idx_filter, bank in enumerate(bank_names, start=1):
                test_data = filtered_data[bank][:, :, :, isplit][:, test_point, :]
                rr = TDCA_test(test_data, train_w[bank], train_meantemp[bank],
                               P=P, l=n_delay)
                # Signed squared coefficients, weighted per sub-band.
                rrall += np.multiply(np.sign(rr), (rr ** 2)) * (idx_filter ** (-1.25) + 0.25)
            predictAll[row, :] = np.argmax(rrall, -1)
        acc_s = acc_calculate(predictAll) + acc_s
        n_folds += 1

    acc = acc_s / n_folds
    print('sub', idx_num + 1, ', acc = ', acc)
    return acc

## main

In [3]:
# setting
# Number of subject files (S01.mat ... S35.mat) processed by every experiment.
n_subjects = 35
# Numbers of training blocks to evaluate.
Train_size_list = [1,2,3,4,5]
# Augmented trials per event; entry j is used together with Train_size_list[j].
nAug_list=[3,4,5,5,7]
# Analysis window lengths in seconds.
t_task_list = [0.5]
# nNeig_list = [i*2 for i in range(21)]
# Data directory; override with the BENCH_DATA_DIR environment variable
# instead of editing the notebook.
filepath_ori = os.environ.get('BENCH_DATA_DIR', '/mnt/Bench')

# The experiment cells index nAug_list by position in Train_size_list,
# so a length mismatch would fail late (or silently pair the wrong values).
if len(nAug_list) != len(Train_size_list):
    raise ValueError('nAug_list and Train_size_list must have the same length')

### eTRCA

In [4]:
# ---- eTRCA baseline: no augmented trials (n_Aug=0) ----
print('eTRCA withoutSAME is executing...')
# Single-subject check:
# acc = benchmark_TRCAmssame(idx_num=14, n_train = 3 ,t_task=0.5, n_Aug=0)

# acc_all[subject, train-size index, window index]
acc_all = np.zeros((n_subjects, len(Train_size_list), len(t_task_list)))
run_subject = delayed(benchmark_TRCAmssame)
for time_idx, window_len in enumerate(t_task_list):
    for size_idx, n_train_blocks in enumerate(Train_size_list):
        tic = time.time()
        # One job per subject, spread over all available cores.
        jobs = (
            run_subject(subj, n_train=n_train_blocks, t_task=window_len, n_Aug=0)
            for subj in range(n_subjects)
        )
        acc = np.array(Parallel(n_jobs=-1)(jobs))
        acc_all[:, size_idx, time_idx] = acc
        elapsed = time.time() - tic
        print('n_times=', window_len, ';n_trian=', n_train_blocks,
              ';mean_acc=', np.mean(acc), '; running time=', elapsed)
# sio.savemat(r'bench_eTRCA_acc.mat', {'acc': acc_all})

eTRCA withoutSAME is executing...
n_times= 0.5 ;n_trian= 1 ;mean_acc= 0.10002380952380951 ; running time= 122.80278158187866
n_times= 0.5 ;n_trian= 2 ;mean_acc= 0.5757440476190476 ; running time= 102.6675329208374
n_times= 0.5 ;n_trian= 3 ;mean_acc= 0.6994047619047619 ; running time= 91.52503418922424
n_times= 0.5 ;n_trian= 4 ;mean_acc= 0.7591666666666667 ; running time= 66.55765581130981
n_times= 0.5 ;n_trian= 5 ;mean_acc= 0.7929761904761906 ; running time= 49.36817502975464


### eTRCA(w/SAME)

In [5]:
# ---- eTRCA with SAME: augmented trials, n_Neig=0 ----
print('eTRCA with SAME is executing...')
# Single-subject check:
# acc = benchmark_TRCAmssame(idx_num=14, n_train = 3 ,t_task=0.5, n_Aug=nAug_list[2],n_Neig=0)

# acc_all[subject, train-size index, window index]
acc_all = np.zeros((n_subjects, len(Train_size_list), len(t_task_list)))
run_subject = delayed(benchmark_TRCAmssame)
for time_idx, window_len in enumerate(t_task_list):
    for size_idx, n_train_blocks in enumerate(Train_size_list):
        tic = time.time()
        # One job per subject; the augmentation count is paired with the train size by position.
        jobs = (
            run_subject(subj, n_train=n_train_blocks, t_task=window_len,
                        n_Aug=nAug_list[size_idx], n_Neig=0)
            for subj in range(n_subjects)
        )
        acc = np.array(Parallel(n_jobs=-1)(jobs))
        acc_all[:, size_idx, time_idx] = acc
        elapsed = time.time() - tic
        print('n_times=', window_len, ';n_trian=', n_train_blocks,
              ';mean_acc=', np.mean(acc), '; running time=', elapsed)
# sio.savemat(r'bench_eTRCA_SAME_acc.mat', {'acc': acc_all})

eTRCA with SAME is executing...
n_times= 0.5 ;n_trian= 1 ;mean_acc= 0.6005952380952381 ; running time= 129.91607546806335
n_times= 0.5 ;n_trian= 2 ;mean_acc= 0.7397321428571428 ; running time= 110.95805764198303
n_times= 0.5 ;n_trian= 3 ;mean_acc= 0.7946031746031746 ; running time= 95.50911235809326
n_times= 0.5 ;n_trian= 4 ;mean_acc= 0.8216071428571426 ; running time= 75.88654851913452
n_times= 0.5 ;n_trian= 5 ;mean_acc= 0.8420238095238096 ; running time= 68.34845638275146


### eTRCA(w/msSAME)

In [6]:
# ---- eTRCA with msSAME: augmented trials, n_Neig=12 ----
print('eTRCA with msSAME is executing...')
# Single-subject check:
# acc = benchmark_TRCAmssame(idx_num=14, n_train = 3 ,t_task=0.5, n_Aug=nAug_list[2],n_Neig=12)

# acc_all[subject, train-size index, window index]
acc_all = np.zeros((n_subjects, len(Train_size_list), len(t_task_list)))
run_subject = delayed(benchmark_TRCAmssame)
for time_idx, window_len in enumerate(t_task_list):
    for size_idx, n_train_blocks in enumerate(Train_size_list):
        tic = time.time()
        # One job per subject; the augmentation count is paired with the train size by position.
        jobs = (
            run_subject(subj, n_train=n_train_blocks, t_task=window_len,
                        n_Aug=nAug_list[size_idx], n_Neig=12)
            for subj in range(n_subjects)
        )
        acc = np.array(Parallel(n_jobs=-1)(jobs))
        acc_all[:, size_idx, time_idx] = acc
        elapsed = time.time() - tic
        print('n_times=', window_len, ';n_trian=', n_train_blocks,
              ';mean_acc=', np.mean(acc), '; running time=', elapsed)
# sio.savemat(r'bench_eTRCA_msSAME_acc.mat', {'acc': acc_all})

eTRCA with msSAME is executing...
n_times= 0.5 ;n_trian= 1 ;mean_acc= 0.7166666666666668 ; running time= 141.05468153953552
n_times= 0.5 ;n_trian= 2 ;mean_acc= 0.7786904761904762 ; running time= 117.51873183250427
n_times= 0.5 ;n_trian= 3 ;mean_acc= 0.8083730158730158 ; running time= 96.20761823654175
n_times= 0.5 ;n_trian= 4 ;mean_acc= 0.8241666666666667 ; running time= 96.23084568977356
n_times= 0.5 ;n_trian= 5 ;mean_acc= 0.8354761904761904 ; running time= 73.35704922676086


### TDCA

In [7]:
# ---- TDCA baseline: no augmented trials (n_Aug=0) ----
print('TDCA withoutSAME is executing...')
# Single-subject check:
# acc = benchmark_TDCAmssame(idx_num=14, n_train = 3 ,t_task=0.5, n_Aug=0)

# acc_all[subject, train-size index, window index]
acc_all = np.zeros((n_subjects, len(Train_size_list), len(t_task_list)))
run_subject = delayed(benchmark_TDCAmssame)
for time_idx, window_len in enumerate(t_task_list):
    for size_idx, n_train_blocks in enumerate(Train_size_list):
        tic = time.time()
        # One job per subject, spread over all available cores.
        jobs = (
            run_subject(subj, n_train=n_train_blocks, t_task=window_len, n_Aug=0)
            for subj in range(n_subjects)
        )
        acc = np.array(Parallel(n_jobs=-1)(jobs))
        acc_all[:, size_idx, time_idx] = acc
        elapsed = time.time() - tic
        print('n_times=', window_len, ';n_trian=', n_train_blocks,
              ';mean_acc=', np.mean(acc), '; running time=', elapsed)
# sio.savemat(r'bench_TDCA_acc.mat', {'acc': acc_all})

TDCA withoutSAME is executing...
n_times= 0.5 ;n_trian= 1 ;mean_acc= 0.16814285714285715 ; running time= 1787.8970699310303
n_times= 0.5 ;n_trian= 2 ;mean_acc= 0.7411607142857143 ; running time= 1497.8953742980957
n_times= 0.5 ;n_trian= 3 ;mean_acc= 0.8076190476190476 ; running time= 1117.550154209137
n_times= 0.5 ;n_trian= 4 ;mean_acc= 0.8401190476190475 ; running time= 771.7164487838745
n_times= 0.5 ;n_trian= 5 ;mean_acc= 0.8570238095238095 ; running time= 444.0344326496124


### TDCA(w/SAME)

In [8]:
# ---- TDCA with SAME: augmented trials, n_Neig=0 ----
print('TDCA with SAME is executing...')
# Single-subject check:
# acc = benchmark_TDCAmssame(idx_num=14, n_train = 3 ,t_task=0.5, n_Aug=nAug_list[2],n_Neig=0)

# acc_all[subject, train-size index, window index]
acc_all = np.zeros((n_subjects, len(Train_size_list), len(t_task_list)))
run_subject = delayed(benchmark_TDCAmssame)
for time_idx, window_len in enumerate(t_task_list):
    for size_idx, n_train_blocks in enumerate(Train_size_list):
        tic = time.time()
        # One job per subject; the augmentation count is paired with the train size by position.
        jobs = (
            run_subject(subj, n_train=n_train_blocks, t_task=window_len,
                        n_Aug=nAug_list[size_idx], n_Neig=0)
            for subj in range(n_subjects)
        )
        acc = np.array(Parallel(n_jobs=-1)(jobs))
        acc_all[:, size_idx, time_idx] = acc
        elapsed = time.time() - tic
        print('n_times=', window_len, ';n_trian=', n_train_blocks,
              ';mean_acc=', np.mean(acc), '; running time=', elapsed)
# sio.savemat(r'bench_TDCA_SAME_acc.mat', {'acc': acc_all})

TDCA with SAME is executing...
n_times= 0.5 ;n_trian= 1 ;mean_acc= 0.6491428571428572 ; running time= 1807.162671327591
n_times= 0.5 ;n_trian= 2 ;mean_acc= 0.7789583333333334 ; running time= 1477.968709230423
n_times= 0.5 ;n_trian= 3 ;mean_acc= 0.8252380952380953 ; running time= 1136.2217450141907
n_times= 0.5 ;n_trian= 4 ;mean_acc= 0.8511309523809524 ; running time= 779.7826581001282
n_times= 0.5 ;n_trian= 5 ;mean_acc= 0.8642857142857142 ; running time= 492.77216386795044


### TDCA(w/msSAME)

In [9]:
# ---- TDCA with msSAME: augmented trials, n_Neig=14 ----
print('TDCA with msSAME is executing...')
# Single-subject check:
# acc = benchmark_TDCAmssame(idx_num=14, n_train = 3 ,t_task=0.5, n_Aug=nAug_list[2],n_Neig=14)

# acc_all[subject, train-size index, window index]
acc_all = np.zeros((n_subjects, len(Train_size_list), len(t_task_list)))
run_subject = delayed(benchmark_TDCAmssame)
for time_idx, window_len in enumerate(t_task_list):
    for size_idx, n_train_blocks in enumerate(Train_size_list):
        tic = time.time()
        # One job per subject; the augmentation count is paired with the train size by position.
        jobs = (
            run_subject(subj, n_train=n_train_blocks, t_task=window_len,
                        n_Aug=nAug_list[size_idx], n_Neig=14)
            for subj in range(n_subjects)
        )
        acc = np.array(Parallel(n_jobs=-1)(jobs))
        acc_all[:, size_idx, time_idx] = acc
        elapsed = time.time() - tic
        print('n_times=', window_len, ';n_trian=', n_train_blocks,
              ';mean_acc=', np.mean(acc), '; running time=', elapsed)
# sio.savemat(r'bench_TDCA_msSAME_acc.mat', {'acc': acc_all})

TDCA with msSAME is executing...
n_times= 0.5 ;n_trian= 1 ;mean_acc= 0.7691666666666666 ; running time= 1693.8621637821198
n_times= 0.5 ;n_trian= 2 ;mean_acc= 0.8282440476190476 ; running time= 1398.0316352844238
n_times= 0.5 ;n_trian= 3 ;mean_acc= 0.8539682539682539 ; running time= 1093.1004683971405
n_times= 0.5 ;n_trian= 4 ;mean_acc= 0.8688095238095238 ; running time= 782.0254456996918
n_times= 0.5 ;n_trian= 5 ;mean_acc= 0.8752380952380953 ; running time= 496.7567539215088
