In [7]:
import mne
import scipy.io as sp
import numpy as np
import random
import pandas as pd
import multiprocessing as mp
import concurrent.futures
from mne.decoding import CSP
import pymrmr
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier as RF
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import logging
from scipy.io import loadmat

In [3]:
# --- Recording geometry -----------------------------------------------------
# class_extractor sizes its output array as (n_epochs, num_channels, epoch_length).
num_channels = 64
epoch_length = 1000
# Passed to mne.filter.filter_data as the sampling frequency (presumably Hz).
sampling_freq = 250

# --- Feature extraction / selection ----------------------------------------
# Number of band-pass bands the feature extractor iterates over.
number_of_bands = 9
# Passed to CSP as its first argument inside calc_csp.
number_of_components = 10
number_of_selected_features = 10

# --- Experiment bookkeeping --------------------------------------------------
number_of_runs = 10
number_of_processes = 15

# Column layout for the results table.
column_names = [
    'participant',
    'class1',
    'class2',
    'running_time',
    'test_acc',
    'train_acc',
    'test_size',
    'train_size',
    'train_block',
    'test_block',
]

In [None]:
def calc_csp(x_train, y_train, x_test):
    """Fit a CSP decomposition on the training set and project both sets.

    The CSP object is built with the module-level ``number_of_components``
    and fitted on ``(x_train, y_train)`` only; the fitted object is then
    used to transform ``x_train`` and ``x_test``.

    Returns
    -------
    tuple
        ``(train_feat, test_feat)`` -- the transformed training and test data.
    """
    fitted = CSP(number_of_components).fit(x_train, y_train)
    # Transform train first, then test, both with the same fitted filters.
    return fitted.transform(x_train), fitted.transform(x_test)

In [None]:
def class_extractor(number_of_epochs, class_1, class_2, data, labels):
    """Keep only the epochs whose label equals ``class_1`` or ``class_2``.

    Parameters
    ----------
    number_of_epochs : int
        How many leading epochs of ``data`` / ``labels`` to scan.
        Values larger than the number of available labels are clamped.
    class_1, class_2 : scalar
        Label values to keep; compared with ``==`` against ``labels[:, 0]``.
    data : array-like
        Epoch array indexed by epoch along its first axis.
    labels : array-like, 2-D
        Per-epoch labels; only column 0 is read.

    Returns
    -------
    dataset : numpy.ndarray
        float64 array of the selected epochs, in their original order, with
        shape ``(n_selected,) + data.shape[1:]``.
    Final_labels : numpy.ndarray
        Integer column vector of shape ``(n_selected, 1)`` holding the label
        of each selected epoch.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)

    # A negative count scans nothing (matches an empty range()).
    scan_count = max(int(number_of_epochs), 0)
    scanned_labels = labels[:scan_count, 0]

    # One mask drives both outputs, so they always have the same length and
    # never carry zero-filled padding rows (e.g. when class_1 == class_2, or
    # when fewer epochs are scanned than there are labels).
    keep = (scanned_labels == class_1) | (scanned_labels == class_2)

    # Output shape follows the input array rather than module-level constants.
    dataset = data[:scan_count][keep].astype(np.float64)
    Final_labels = scanned_labels[keep].astype(int).reshape(-1, 1)
    return dataset, Final_labels

In [None]:
def feature_extractor(dataset, labels, number_of_bands, test_data):
    """Build filter-bank CSP features for a training and a test set.

    For each of ``number_of_bands`` consecutive 4-unit-wide bands starting at
    4 (i.e. 4-8, 8-12, ...; units are those of ``sampling_freq``), both sets
    are band-pass filtered with ``mne.filter.filter_data`` and passed through
    ``calc_csp`` (fitted on the training data only). The per-band features
    are concatenated along axis 1.

    Parameters
    ----------
    dataset : numpy.ndarray
        Training epochs.
    labels : numpy.ndarray, 2-D
        Training labels; only column 0 is used.
    number_of_bands : int
        Number of bands to extract; must be at least 1.
    test_data : numpy.ndarray
        Test epochs, filtered and transformed with the training-fitted CSP.

    Returns
    -------
    tuple
        ``(train_features, test_features)`` with the bands stacked along axis 1.

    Raises
    ------
    ValueError
        If ``number_of_bands`` is less than 1 (there would be nothing to return).
    """
    if number_of_bands < 1:
        raise ValueError("number_of_bands must be at least 1")

    # Loop-invariant: quieten MNE's logger once rather than on every band.
    logging.getLogger('mne').setLevel(logging.WARNING)

    band_width = 4
    train_blocks = []
    test_blocks = []
    for band in range(1, number_of_bands + 1):
        low_cutoff = band * band_width
        high_cutoff = low_cutoff + band_width
        # filter_data copies its input by default (copy=True), so the inputs
        # are left untouched without making explicit per-band copies here.
        filtered_train = mne.filter.filter_data(dataset, sampling_freq, low_cutoff, high_cutoff, verbose = False, n_jobs = 4)
        filtered_test = mne.filter.filter_data(test_data, sampling_freq, low_cutoff, high_cutoff, verbose = False, n_jobs = 4)
        train_feats, test_feats = calc_csp(filtered_train, labels[:,0], filtered_test)
        train_blocks.append(train_feats)
        test_blocks.append(test_feats)

    # Single concatenation instead of re-allocating the result on every band.
    train_features = np.concatenate(train_blocks, axis = 1)
    test_features = np.concatenate(test_blocks, axis = 1)
    return train_features, test_features

In [None]:
def feature_selector(train_features, labels, number_of_selected_features):
    """Pick features with pymrmr's mRMR using the 'MID' method.

    A single table is assembled with the labels as the leading column(s)
    followed by the feature columns, the column names are converted to
    strings, and the table is handed to ``pymrmr.mRMR``.

    Returns
    -------
    list of int
        The names returned by ``pymrmr.mRMR`` converted to ``int``.
    """
    label_frame = pd.DataFrame(labels)
    feature_frame = pd.DataFrame(train_features)

    # Labels go first, features after.
    table = pd.concat([label_frame, feature_frame], axis = 1)
    table.columns = table.columns.astype(str)

    chosen = pymrmr.mRMR(table, 'MID', number_of_selected_features)
    return [int(name) for name in chosen]

In [43]:
def data_reader(path):
    """Load the ``'Data'`` variable of a MATLAB file into a DataFrame.

    Parameters
    ----------
    path : str
        Location of the ``.mat`` file, passed straight to ``loadmat``.

    Returns
    -------
    pandas.DataFrame
        ``pd.DataFrame`` built from the file's ``'Data'`` entry.
    """
    contents = loadmat(
        path,
        chars_as_strings=True,
        mat_dtype=True,
        squeeze_me=True,
        struct_as_record=False,
        verify_compressed_data_integrity=False,
        variable_names=None,
    )
    return pd.DataFrame(contents['Data'])

In [44]:
# Participant and block numbers used to build the file name P<P_NUM>B<B_NUM>.mat.
P_NUM = 1
B_NUM = 2
# NOTE(review): the directory hard-codes "P1" while the file name uses P_NUM;
# confirm whether the folder should follow P_NUM as well.
PATH = '../../Participants/P1/'
data = data_reader(f"{PATH}P{P_NUM}B{B_NUM}.mat")


In [46]:
# Preview the first rows of the loaded recording. The rendered output shows
# columns 0-64, the last of which holds string markers such as 'Begin'/'Feet'.
data.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,55,56,57,58,59,60,61,62,63,64
0,-6007.68457,12510.444336,7844.45166,-4246.978027,-3228.525146,-103.249702,8590.283203,-5608.668945,15386.517578,-11876.93457,...,-2211.25415,-7136.249023,43.948391,7693.819824,2708.186035,-10187.832031,10895.167969,7435.178711,15664.952148,Begin
1,-6012.343262,12505.867188,7834.443359,-4234.376953,-3221.28418,-58.525509,8586.755859,-5612.939453,15399.620117,-11870.067383,...,-2209.654785,-7135.818359,42.244194,7694.793945,2700.129883,-10194.820312,10893.699219,7429.57959,15661.667969,Feet
2,-6004.281738,12508.5,7836.693359,-4235.759766,-3225.619385,-89.742882,8589.950195,-5611.961426,15399.019531,-11871.919922,...,-2206.152344,-7133.556641,45.26141,7695.370605,2701.121582,-10193.835938,10892.120117,7430.616699,15662.744141,Feet
3,-6006.712891,12511.026367,7839.393555,-4254.118164,-3237.662842,-97.241196,8592.822266,-5606.171387,15382.796875,-11880.046875,...,-2205.669434,-7130.776367,46.148624,7693.93457,2701.351074,-10195.214844,10894.549805,7432.617188,15667.950195,Feet
4,-6015.308594,12510.445312,7840.519043,-4254.352539,-3237.874268,-128.30748,8591.375,-5594.76123,15381.959961,-11880.235352,...,-2208.33374,-7131.682617,45.585205,7698.09082,2702.700195,-10194.633789,10893.917969,7433.499512,15667.558594,Feet
