# EEG classification

## Imports

In [4]:
import sys
sys.path.append("../itershap")

import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

from itershap import IterSHAP

## Get EEG data

In [8]:
# Column labels for the EEG feature table, in column order:
#   * 20 individually named time-domain / FFT features,
#   * 8 wavelet statistics for each of the numbered groups 1..14,
#   * 42 features named AR1..AR42.
# The strings are used verbatim as DataFrame column labels, so their spelling
# (including 'Coeffiecient') is intentionally left untouched.
eeg_columns = [
    'Coeffiecient of Variation',
    'Mean of Vertex to Vertex Slope',
    'Variance of Vertex to Vertex Slope',
    'Hjorth_Activity',
    'Hjorth_Mobility',
    'Hjorth_Complexity',
    'Kurtosis',
    '2nd Difference Mean',
    '2nd Difference Max',
    'Skewness',
    '1st Difference Mean',
    '1st Difference Max',
    'FFT Delta MaxPower',
    'FFT Theta MaxPower',
    'FFT Alpha MaxPower',
    'FFT Beta MaxPower',
    'Delta/Theta',
    'Delta/Alpha',
    'Theta/Alpha',
    '(Delta+Theta)/Alpha',
    # '<group>Wavelet <statistic>' for groups 1..14, statistics in fixed order.
    *(
        f'{group}Wavelet {statistic}'
        for group in range(1, 15)
        for statistic in (
            'Approximate Mean',
            'Approximate Std Deviation',
            'Approximate Energy',
            'Detailed Mean',
            'Detailed Std Deviation',
            'Detailed Energy',
            'Approximate Entropy',
            'Detailed Entropy',
        )
    ),
    # 'AR1' .. 'AR42'
    *(f'AR{index}' for index in range(1, 43)),
]


# Data source: https://github.com/ivishalanand/Cognitive-Mental-Workload-Estimation-Using-ML
def get_eeg_data(PERC_DP_USED, filepath="../data/eeg/Normalizedfeatures.csv"):
    """Load the EEG feature table from a CSV file and optionally subsample it.

    The first line of the file is treated as a header and skipped. Every
    other non-blank line must be a comma-separated list of numbers, where the
    last value is the class label and all preceding values are features.

    Parameters
    ----------
    PERC_DP_USED : float
        Fraction of the data points to keep. Values below 1.0 are passed to
        ``train_test_split`` as ``train_size`` (with ``random_state=20``);
        otherwise all rows are returned.
    filepath : str, optional
        Location of the CSV file. Defaults to the path originally hard-coded
        in this notebook.

    Returns
    -------
    X : pandas.DataFrame
        Feature values, with ``eeg_columns`` as column labels.
    y : numpy.ndarray
        Integer labels, computed as ``round(label - 1)`` from the last CSV
        column.

    Raises
    ------
    ValueError
        If a non-blank data line contains a value that ``float`` cannot parse.
    """
    features = []
    labels = []
    # The context manager closes the file even if a row fails to parse.
    with open(filepath) as f:
        f.readline()  # Skip the header row; column labels come from eeg_columns.
        for line in f:
            line = line.rstrip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            values = [float(field) for field in line.split(',')]
            features.append(values[:-1])
            labels.append(values[-1])

    X = pd.DataFrame(np.asarray(features), columns=eeg_columns)
    # Shift the labels down by one and convert them to Python ints.
    y = np.asarray([round(k - 1) for k in labels])
    if PERC_DP_USED < 1.0:
        # If PERC_DP_USED >= 1.0, all data points are used for the model.
        X, _, y, _ = train_test_split(X, y, train_size=PERC_DP_USED, random_state=20)

    return X, y

# get_eeg_data(1.00)

## Run IterSHAP

In [None]:
# Fraction of the data points requested from get_eeg_data
PERC_DATA_USED = 0.01

# Load data from the data folder
X, y = get_eeg_data(PERC_DATA_USED)

# Keep an independent copy to test model performance without feature selection.
# X.copy() is used because pd.DataFrame(X) does not copy the underlying data
# of an existing DataFrame by default.
X_without_fs = X.copy()

# Check the current shape of the dataset
print(X.shape)

# Create and fit IterSHAP using a RandomForestClassifier (default)
itershap_fs = IterSHAP()
itershap_fs.fit(X, y)

# Transform the input data to only include selected features and print its shape
X = itershap_fs.transform()
print(X.shape)

## Run model with and without feature selection

#### Without feature selection

In [None]:
# Hold out 25% of the rows for testing; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X_without_fs, y, test_size=0.25, random_state=20)

# Seed the forest as well so repeated runs report the same accuracy.
clf = RandomForestClassifier(random_state=20)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# accuracy_score takes (y_true, y_pred), in that order.
accuracy = accuracy_score(y_test, y_pred)
print(accuracy)

#### With feature selection

In [None]:
# Hold out 25% of the rows for testing; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=20)

# Seed the forest as well so repeated runs report the same accuracy.
clf = RandomForestClassifier(random_state=20)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# accuracy_score takes (y_true, y_pred), in that order.
accuracy = accuracy_score(y_test, y_pred)
print(accuracy)