In [6]:
import mne
import scipy.io as sp
import numpy as np
import random
import pandas as pd
import multiprocessing as mp
import concurrent.futures
from mne.decoding import CSP
import pymrmr
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier as RF
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import logging
from scipy.io import loadmat

In [3]:
num_channels = 64
epoch_length = 1000
sampling_freq = 250
number_of_runs = 10
# number_of_splits = 10
number_of_components = 10
number_of_selected_features = 10
number_of_processes = 15
number_of_bands = 9
# rf = pd.DataFrame()
column_names = ['participant', 'class1', 'class2','running_time','test_acc','train_acc','test_size','train_size','train_block','test_block']
# rf = rf.reindex(columns=column_names)

In [None]:
def calc_csp(x_train, y_train, x_test):
    
    csp = CSP(number_of_components)
    csp_fit = csp.fit(x_train, y_train)
    train_feat = csp_fit.transform(x_train)
    test_feat = csp_fit.transform(x_test)
    return train_feat, test_feat

In [None]:
def class_extractor(number_of_epochs, class_1, class_2, data, labels):
    size = sum(labels[:,0] == class_1) + sum(labels[:,0] == class_2)
    Final_labels = np.zeros((size,1)).astype(int)
    dataset = np.zeros((size,num_channels, epoch_length))
    index = 0
    for i in range(number_of_epochs):
        if labels[i,0] == class_1 or labels[i,0] == class_2:
            dataset[index,:,:] = data[i,:,:]
            Final_labels[index,0] = labels[i,0]
            index = index + 1
        else:
            continue
            
    return dataset, Final_labels

In [None]:
def feature_extractor(dataset, labels, number_of_bands, test_data):

    low_cutoff = 0
    
    for b in range(number_of_bands):
        logging.getLogger('mne').setLevel(logging.WARNING)
        low_cutoff += 4
        data = dataset.copy()
        data_test = test_data.copy()
        filtered_data = mne.filter.filter_data(data, sampling_freq, low_cutoff, low_cutoff + 4, verbose = False, n_jobs = 4)
        filtered_data_test = mne.filter.filter_data(test_data, sampling_freq, low_cutoff, low_cutoff + 4, verbose = False, n_jobs = 4)
        [train_feats, test_feats] = calc_csp(filtered_data, labels[:,0], filtered_data_test)
        if b == 0:
            train_features = train_feats
            test_features = test_feats
        else:
            train_features = np.concatenate((train_features, train_feats), axis = 1)
            test_features = np.concatenate((test_features, test_feats), axis = 1)
    
    return train_features, test_features

In [None]:
def feature_selector(train_features, labels, number_of_selected_features):
    X = pd.DataFrame(train_features)
    y = pd.DataFrame(labels)
    K = number_of_selected_features
    
    df = pd.concat([y,X], axis = 1)
    df.columns = df.columns.astype(str)
        
    selected_features = list(map(int, pymrmr.mRMR(df, 'MID', K)))
    return selected_features

In [7]:
def data_reader(path):

    mat = loadmat(path, chars_as_strings=True, mat_dtype=True, squeeze_me=True, struct_as_record=False, verify_compressed_data_integrity=False, variable_names=None)
    df = pd.DataFrame(mat['Data'])
    return df

In [8]:
PATH = '../../Participants/P1/'
P_NUM = 1
B_NUM = 2
CLASS_1 = "Feet"
CLASS_2 = "Rest"
data_tr = data_reader(PATH+'P'+str(P_NUM)+'B'+str(B_NUM)+'.mat')
data_te = data_reader(PATH+'P'+str(P_NUM)+'B'+str(4)+'.mat')






# for i in range(number_of_epochs):
#     data[i,:,:] = X[:, randomlist[i]*epoch_length:(randomlist[i] + 1)*epoch_length]
#     if (df['condition'][randomlist[i]*epoch_length] == 'Left'):
#         labels[i,0] = 0
#     elif(df['condition'][randomlist[i]*epoch_length] == 'Right'):
#         labels[i,0] = 1
#     elif(df['condition'][randomlist[i]*epoch_length] == 'Feet'):
#         labels[i,0] = 2
#     elif(df['condition'][randomlist[i]*epoch_length] == 'Tongue'):
#         labels[i,0] = 3
#     elif(df['condition'][randomlist[i]*epoch_length] == 'Mis'):
#         labels[i,0] = 4
#     elif(df['condition'][randomlist[i]*epoch_length] == 'Si'):
#         labels[i,0] = 5
#     else:
#         labels[i,0] = 6



# preprocessor(data_tr,data_te)

# X_tr_raw,X_te_raw,y_tr_raw,y_te_raw,number_of_epochs_tr,number_of_epochs_te = preprocessor(X_train,X_test,data1)
# [X_tr, y_tr] = class_extraction(number_of_epochs_tr, class_1, class_2, X_tr_raw, y_tr_raw)
# [X_te, y_te] = class_extraction(number_of_epochs_te, class_1, class_2, X_te_raw, y_te_raw) 
# print(X_te.shape,"X_te.shape")


In [11]:
block_order = ['Feet','Mis','Hand','Tongue']
tasks_time = [12,16,8,12,20,8,16,20]
SAMPLING_RATE = 250
df = data_tr.copy()

Begin_indexes = df[df.iloc[:, 64] == 'Begin'].index
End_indexes = df[df.iloc[:, 64] == 'End'].index
# print(indexes)
# print(Begin_indexes)
# print(End_indexes)
# print(df.iloc[1,64])
if(len(Begin_indexes)==len(End_indexes)):

    for i in range(len(Begin_indexes)):
        index = Begin_indexes[i]+1
        val = df.iloc[index,64]
        df.iloc[Begin_indexes[i],64] = "Begin" + "_" + str(val)
        df.iloc[End_indexes[i],64]   =  "End" + "_" + str(val)


In [48]:
idx = df[df.iloc[:, 64] == 'End_Tongue'].index
print(idx[0])

CLASS_1 = block_order[0]
Begin_trigger = "Begin" + "_" + CLASS_1
End_trigger = "End" + "_" + CLASS_1

Begin_idx = df[df.iloc[:, 64] == Begin_trigger].index
End_idx = df[df.iloc[:, 64] == End_trigger].index
print(Begin_idx[0],End_idx[0])

trial_df = df.iloc[Begin_idx[0]:End_idx[0]+1,:]
# trial_df.tail()

idxx = trial_df[trial_df.iloc[:, 64] == 'Rest'].index
idxx2 = trial_df[trial_df.iloc[:, 64] == 'Feet'].index
# print(idxx,len(idxx))
# print(idxx2,len(idxx2))

trial_df2 = trial_df.copy()

class_x = 'Feet'
class_y = 'Rest'
new_df = pd.DataFrame()
# sample_point = tasks_time[0]*SAMPLING_RATE
# if(trial_df2.iloc[sample_point+1,64] == class_x ):
#     temp_df = trial_df2.iloc[:sample_point,:]
#     next_task_idx = trial_df2[trial_df2.iloc[:, 64] == class_y].index
#     trial_df2.drop(trial_df2.index[0:next_task_idx[0]], inplace=True)
#     trial_df2.reset_index(drop=True, inplace=True)
#     new_df = pd.concat([new_df, temp_df], axis=0)

# sample_point = tasks_time[1]*SAMPLING_RATE
# if(trial_df2.iloc[sample_point+1,64] == class_y ):
#     temp_df2 = trial_df2.iloc[:sample_point,:]
#     next_task_idx = trial_df2[trial_df2.iloc[:, 64] == class_x].index
#     trial_df2.drop(trial_df2.index[0:next_task_idx[0]], inplace=True)
#     trial_df2.reset_index(drop=True, inplace=True)    

# new_df = pd.concat([temp_df, temp_df2], axis=0)
# new_df.reset_index(drop=True, inplace=True)
# new_df.tail()

for i in range(len(tasks_time)):
    sample_point = tasks_time[i]*SAMPLING_RATE
    if(trial_df2.iloc[sample_point+1,64] == class_x ):
         if(i==len(tasks_time)-1):
            temp_df = trial_df2.iloc[:sample_point,:]
            new_df = pd.concat([new_df, temp_df], axis=0)
            new_df.reset_index(drop=True, inplace=True)
         else:    
            temp_df = trial_df2.iloc[:sample_point,:]
            next_task_idx = trial_df2[trial_df2.iloc[:, 64] == class_y].index
            trial_df2.drop(trial_df2.index[0:next_task_idx[0]], inplace=True)
            trial_df2.reset_index(drop=True, inplace=True)
            new_df = pd.concat([new_df, temp_df], axis=0)
            new_df.reset_index(drop=True, inplace=True)
            class_x,class_y = class_y,class_x
new_df.head(23002)







115078
0 28760


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,55,56,57,58,59,60,61,62,63,64
0,-6007.68457,12510.444336,7844.45166,-4246.978027,-3228.525146,-103.249702,8590.283203,-5608.668945,15386.517578,-11876.93457,...,-2211.25415,-7136.249023,43.948391,7693.819824,2708.186035,-10187.832031,10895.167969,7435.178711,15664.952148,Begin_Feet
1,-6012.343262,12505.867188,7834.443359,-4234.376953,-3221.28418,-58.525509,8586.755859,-5612.939453,15399.620117,-11870.067383,...,-2209.654785,-7135.818359,42.244194,7694.793945,2700.129883,-10194.820312,10893.699219,7429.57959,15661.667969,Feet
2,-6004.281738,12508.5,7836.693359,-4235.759766,-3225.619385,-89.742882,8589.950195,-5611.961426,15399.019531,-11871.919922,...,-2206.152344,-7133.556641,45.26141,7695.370605,2701.121582,-10193.835938,10892.120117,7430.616699,15662.744141,Feet
3,-6006.712891,12511.026367,7839.393555,-4254.118164,-3237.662842,-97.241196,8592.822266,-5606.171387,15382.796875,-11880.046875,...,-2205.669434,-7130.776367,46.148624,7693.93457,2701.351074,-10195.214844,10894.549805,7432.617188,15667.950195,Feet
4,-6015.308594,12510.445312,7840.519043,-4254.352539,-3237.874268,-128.30748,8591.375,-5594.76123,15381.959961,-11880.235352,...,-2208.33374,-7131.682617,45.585205,7698.09082,2702.700195,-10194.633789,10893.917969,7433.499512,15667.558594,Feet
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22997,-7017.963867,12475.336914,7695.091797,-4297.781738,-3122.813721,-0.738643,8541.771484,-5528.883789,15218.045898,-11725.635742,...,-1862.664429,-7456.255859,55.565693,7705.930176,2801.706787,-10274.302734,11030.19043,7357.914062,15704.285156,Feet
22998,-7020.932129,12474.482422,7693.085938,-4290.768555,-3120.47998,1.002384,8539.151367,-5529.188477,15219.15625,-11724.03125,...,-1862.194946,-7457.035645,54.928406,7708.891113,2803.579102,-10274.628906,11026.12793,7357.099609,15701.990234,Feet
22999,-7029.184082,12473.793945,7692.151855,-4293.062988,-3118.178223,0.633743,8539.469727,-5524.329102,15220.978516,-11723.253906,...,-1868.17688,-7456.041016,55.182117,7709.443359,2802.268555,-10278.207031,11025.816406,7354.458496,15701.234375,Feet
23000,-7037.529297,12468.173828,7688.389648,-4294.895508,-3122.076172,-0.09983,8535.668945,-5514.596191,15223.802734,-11716.105469,...,-1865.614136,-7457.402832,51.82243,7706.48877,2788.309326,-10281.602539,11023.239258,7347.61084,15699.083984,Rest


In [70]:
df.head(86277)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,55,56,57,58,59,60,61,62,63,64
0,-6007.68457,12510.444336,7844.45166,-4246.978027,-3228.525146,-103.249702,8590.283203,-5608.668945,15386.517578,-11876.93457,...,-2211.25415,-7136.249023,43.948391,7693.819824,2708.186035,-10187.832031,10895.167969,7435.178711,15664.952148,Begin_Feet
1,-6012.343262,12505.867188,7834.443359,-4234.376953,-3221.28418,-58.525509,8586.755859,-5612.939453,15399.620117,-11870.067383,...,-2209.654785,-7135.818359,42.244194,7694.793945,2700.129883,-10194.820312,10893.699219,7429.57959,15661.667969,Feet
2,-6004.281738,12508.5,7836.693359,-4235.759766,-3225.619385,-89.742882,8589.950195,-5611.961426,15399.019531,-11871.919922,...,-2206.152344,-7133.556641,45.26141,7695.370605,2701.121582,-10193.835938,10892.120117,7430.616699,15662.744141,Feet
3,-6006.712891,12511.026367,7839.393555,-4254.118164,-3237.662842,-97.241196,8592.822266,-5606.171387,15382.796875,-11880.046875,...,-2205.669434,-7130.776367,46.148624,7693.93457,2701.351074,-10195.214844,10894.549805,7432.617188,15667.950195,Feet
4,-6015.308594,12510.445312,7840.519043,-4254.352539,-3237.874268,-128.30748,8591.375,-5594.76123,15381.959961,-11880.235352,...,-2208.33374,-7131.682617,45.585205,7698.09082,2702.700195,-10194.633789,10893.917969,7433.499512,15667.558594,Feet
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
86272,-7947.979492,12310.938477,7300.790527,-4155.20166,-3016.442627,143.286392,8194.839844,-5476.609375,14971.160156,-11585.59082,...,-1267.154785,-8254.25,-468.513611,7737.567383,2609.070557,-10310.354492,11188.512695,7139.75,15612.225586,Rest
86273,-7956.728516,12311.724609,7300.694336,-4157.084473,-3019.311035,146.885925,8196.986328,-5477.582031,14968.926758,-11585.797852,...,-1267.393799,-8251.016602,-463.748169,7735.005371,2611.488525,-10310.40625,11187.453125,7139.99707,15614.25,Rest
86274,-7955.460449,12314.120117,7301.058594,-4155.061523,-3020.093262,138.646683,8201.264648,-5476.589355,14971.043945,-11588.185547,...,-1264.507202,-8247.374023,-463.993561,7735.862793,2611.591064,-10316.472656,11188.148438,7143.306152,15617.170898,Rest
86275,-7945.963867,12318.460938,7304.883789,-4156.449707,-3019.900879,141.031204,8205.243164,-5473.765137,14969.541992,-11589.478516,...,-1268.526611,-8244.618164,-459.132599,7734.895508,2614.547363,-10317.243164,11189.810547,7147.561035,15620.617188,End_Hand
