In [1]:
from utils import *
from dataset import *
from constants import *
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier

from dataset import Participant

#### Select participant and session (session only for plot purposes)

In [2]:
participant = Participant('s12', data_path=DATA_PATH_NOTEBOOK)

In [3]:
print(f'Number of sessions: {len(participant.sessions)}')
print(f'Number of channels: {len(participant.channels)}')
print(f'Number of relevant channels: {len(participant.relevant_channels)}')
print(f'The relevant channels are located in the following locations:')
print([participant.channels_locations[i] for i in [channel.idx for channel in participant.relevant_channels]])

Number of sessions: 2
Number of channels: 186
Number of relevant channels: 29
The relevant channels are located in the following locations:
[np.str_('insula'), np.str_('postcentral'), np.str_('postcentral'), np.str_('supramarginal'), np.str_('supramarginal'), np.str_('postcentral'), np.str_('superiorfrontal'), np.str_('insula'), np.str_('insula'), np.str_('insula'), np.str_('WM_insula'), np.str_('posteriorcingulate'), np.str_('posteriorcingulate'), np.str_('posteriorcingulate'), np.str_('posteriorcingulate'), np.str_('WM_paracentral'), np.str_('precentral'), np.str_('precentral'), np.str_('precentral'), np.str_('posteriorcingulate'), np.str_('posteriorcingulate'), np.str_('posteriorcingulate'), np.str_('paracentral'), np.str_('insula'), np.str_('insula'), np.str_('insula'), np.str_('superiorparietal'), np.str_('superiorparietal'), np.str_('superiorparietal')]


#### Get participant's features

In [4]:
features = participant.get_features_all_sessions_ExObs()

100%|██████████| 29/29 [00:01<00:00, 15.47it/s]
100%|██████████| 29/29 [00:01<00:00, 18.87it/s]
100%|██████████| 29/29 [00:01<00:00, 21.98it/s]
100%|██████████| 29/29 [00:01<00:00, 22.12it/s]
100%|██████████| 29/29 [00:02<00:00, 11.02it/s]
100%|██████████| 29/29 [00:01<00:00, 17.76it/s]
100%|██████████| 29/29 [00:01<00:00, 21.14it/s]
100%|██████████| 29/29 [00:02<00:00, 12.40it/s]
100%|██████████| 29/29 [00:01<00:00, 18.25it/s]
100%|██████████| 29/29 [00:01<00:00, 16.46it/s]
100%|██████████| 29/29 [00:01<00:00, 21.73it/s]]
100%|██████████| 29/29 [00:01<00:00, 16.65it/s]]
100%|██████████| 29/29 [00:01<00:00, 20.53it/s]]
100%|██████████| 29/29 [00:01<00:00, 20.61it/s]]
100%|██████████| 29/29 [00:01<00:00, 15.38it/s]]
100%|██████████| 29/29 [00:01<00:00, 20.69it/s]]
100%|██████████| 29/29 [00:01<00:00, 14.66it/s]]
100%|██████████| 29/29 [00:02<00:00, 13.89it/s]]
100%|██████████| 29/29 [00:01<00:00, 24.67it/s]]
100%|██████████| 29/29 [00:01<00:00, 23.57it/s]]
100%|██████████| 29/29 [00:01<

Compute PSD for each baselines and activities, then mean it

In [5]:
print(f'The dataset contains {features.shape[0]} samples and {features.shape[1]} features.')

The dataset contains 257 samples and 2089 features.


## Train a model (SVM)

#### Start without any dimensionality reduction

In [6]:
X = features.drop('label', axis=1)
y = features['label']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Normalize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Train SVM
parameters = {'C': [0.1, 1, 10, 100, 1000], 'kernel': ['linear', 'rbf', 'sigmoid']}
svm = SVC()
clf = GridSearchCV(svm, parameters)
clf.fit(X_train, y_train)
print(clf.best_params_)

# Test SVM
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

{'C': 0.1, 'kernel': 'linear'}
Accuracy: 1.00


#### With PCA

In [7]:
X = features.drop('label', axis=1)
y = features['label']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Normalize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

pca = PCA(n_components=100)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

# Train SVM
parameters = {'C': [0.1, 1, 10, 100, 1000], 'kernel': ['linear', 'rbf', 'sigmoid']}
svm = SVC()
clf = GridSearchCV(svm, parameters)
clf.fit(X_train_pca, y_train)
print(clf.best_params_)

# Test SVM
y_pred = clf.predict(X_test_pca)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

{'C': 0.1, 'kernel': 'linear'}
Accuracy: 0.98


## Train a model (Random Forest)

In [8]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Normalize features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Train Random Forest
n_estimators = [10, 50, 90, 130]
max_depth = [10, 25, 50]
param_grid = {'n_estimators': n_estimators, 'max_depth': max_depth}

rf = RandomForestClassifier() 
clf = GridSearchCV(rf, param_grid)
clf.fit(X_train, y_train)
print(clf.best_params_)

# Test Random Forest
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")


  _data = np.array(data, dtype=dtype, copy=copy,


{'max_depth': 25, 'n_estimators': 50}
Accuracy: 0.98


## Train model with one frequency band

In [9]:
for band, (high, low) in FREQ_BANDS.items() :
    print('**********************************************************')
    print('**********************************************************')
    print(f'Band: {band}')
    
    freq_band = {band: (high, low)}
    features = participant.get_features_all_sessions_ExObs(freq_band=freq_band)
    
    X = features.drop('label', axis=1)
    y = features['label']

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    
    # Normalize features
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    
    # SVM
    print('**********************************************************')
    print('SVM')
    
    # Train SVM
    parameters = {'C': [0.1, 1, 10, 100, 1000], 'kernel': ['linear', 'rbf', 'sigmoid']}
    svm = SVC()
    clf = GridSearchCV(svm, parameters)
    clf.fit(X_train, y_train)
    print(clf.best_params_)

    # Test SVM
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.2f}")
    
    # PCA SVM
    print('**********************************************************')
    print('PCA SVM')
    
    pca = PCA(n_components=100)
    X_train_pca = pca.fit_transform(X_train)
    X_test_pca = pca.transform(X_test)

    # Train SVM
    parameters = {'C': [0.1, 1, 10, 100, 1000], 'kernel': ['linear', 'rbf', 'sigmoid']}
    svm = SVC()
    clf = GridSearchCV(svm, parameters)
    clf.fit(X_train_pca, y_train)
    print(clf.best_params_)

    # Test SVM
    y_pred = clf.predict(X_test_pca)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.2f}")
    
    # RANDOM FOREST
    print('**********************************************************')
    print('RANDOM FOREST')
    
    # Train Random Forest
    n_estimators = [10, 50, 90, 130]
    max_depth = [10, 25, 50]
    param_grid = {'n_estimators': n_estimators, 'max_depth': max_depth}

    rf = RandomForestClassifier() 
    clf = GridSearchCV(rf, param_grid)
    clf.fit(X_train, y_train)
    print(clf.best_params_)

    # Test Random Forest
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.2f}")
    


**********************************************************
**********************************************************
Band: Delta


100%|██████████| 29/29 [00:00<00:00, 35.26it/s]
100%|██████████| 29/29 [00:00<00:00, 46.72it/s]
100%|██████████| 29/29 [00:00<00:00, 51.01it/s]
100%|██████████| 29/29 [00:00<00:00, 43.87it/s]
100%|██████████| 29/29 [00:01<00:00, 23.67it/s]
100%|██████████| 29/29 [00:00<00:00, 47.46it/s]
100%|██████████| 29/29 [00:00<00:00, 54.22it/s]
100%|██████████| 29/29 [00:01<00:00, 27.33it/s]
100%|██████████| 29/29 [00:00<00:00, 41.55it/s]
100%|██████████| 29/29 [00:00<00:00, 34.73it/s]
100%|██████████| 29/29 [00:00<00:00, 61.77it/s]]
100%|██████████| 29/29 [00:00<00:00, 44.49it/s]]
100%|██████████| 29/29 [00:00<00:00, 49.11it/s]]
100%|██████████| 29/29 [00:00<00:00, 53.16it/s]]
100%|██████████| 29/29 [00:00<00:00, 34.48it/s]]
100%|██████████| 29/29 [00:00<00:00, 50.23it/s]]
100%|██████████| 29/29 [00:00<00:00, 33.22it/s]]
100%|██████████| 29/29 [00:00<00:00, 30.41it/s]]
100%|██████████| 29/29 [00:00<00:00, 53.77it/s]]
100%|██████████| 29/29 [00:00<00:00, 58.79it/s]]
100%|██████████| 29/29 [00:00<

**********************************************************
SVM
{'C': 0.1, 'kernel': 'sigmoid'}
Accuracy: 0.87
**********************************************************
PCA SVM
{'C': 0.1, 'kernel': 'sigmoid'}
Accuracy: 0.87
**********************************************************
RANDOM FOREST
{'max_depth': 10, 'n_estimators': 50}
Accuracy: 0.90
**********************************************************
**********************************************************
Band: Theta


100%|██████████| 29/29 [00:00<00:00, 99.96it/s] 
100%|██████████| 29/29 [00:00<00:00, 120.92it/s]
100%|██████████| 29/29 [00:00<00:00, 135.59it/s]
100%|██████████| 29/29 [00:00<00:00, 134.45it/s]
100%|██████████| 29/29 [00:00<00:00, 65.55it/s]
100%|██████████| 29/29 [00:00<00:00, 111.79it/s]
100%|██████████| 29/29 [00:00<00:00, 124.76it/s]
100%|██████████| 29/29 [00:00<00:00, 73.76it/s]
100%|██████████| 29/29 [00:00<00:00, 109.37it/s]
100%|██████████| 29/29 [00:00<00:00, 98.88it/s]
100%|██████████| 29/29 [00:00<00:00, 136.42it/s]
100%|██████████| 29/29 [00:00<00:00, 113.62it/s]
100%|██████████| 29/29 [00:00<00:00, 123.89it/s]
100%|██████████| 29/29 [00:00<00:00, 132.49it/s]
100%|██████████| 29/29 [00:00<00:00, 91.69it/s]]
100%|██████████| 29/29 [00:00<00:00, 119.26it/s]
100%|██████████| 29/29 [00:00<00:00, 89.80it/s]]
100%|██████████| 29/29 [00:00<00:00, 86.30it/s]]
100%|██████████| 29/29 [00:00<00:00, 142.55it/s]
100%|██████████| 29/29 [00:00<00:00, 133.02it/s]
100%|██████████| 29/29 

**********************************************************
SVM
{'C': 1, 'kernel': 'sigmoid'}
Accuracy: 0.96
**********************************************************
PCA SVM
{'C': 1, 'kernel': 'sigmoid'}
Accuracy: 0.96
**********************************************************
RANDOM FOREST
{'max_depth': 50, 'n_estimators': 10}
Accuracy: 0.98
**********************************************************
**********************************************************
Band: Alpha


100%|██████████| 29/29 [00:00<00:00, 85.80it/s]
100%|██████████| 29/29 [00:00<00:00, 79.75it/s]
100%|██████████| 29/29 [00:00<00:00, 88.62it/s]
100%|██████████| 29/29 [00:00<00:00, 91.31it/s] 
100%|██████████| 29/29 [00:00<00:00, 49.62it/s]
100%|██████████| 29/29 [00:00<00:00, 74.08it/s]
100%|██████████| 29/29 [00:00<00:00, 86.99it/s]
100%|██████████| 29/29 [00:00<00:00, 55.62it/s]
100%|██████████| 29/29 [00:00<00:00, 76.26it/s]
100%|██████████| 29/29 [00:00<00:00, 85.35it/s]
100%|██████████| 29/29 [00:00<00:00, 115.93it/s]
100%|██████████| 29/29 [00:00<00:00, 102.66it/s]
100%|██████████| 29/29 [00:00<00:00, 107.70it/s]
100%|██████████| 29/29 [00:00<00:00, 111.38it/s]
100%|██████████| 29/29 [00:00<00:00, 78.93it/s]]
100%|██████████| 29/29 [00:00<00:00, 88.76it/s]]
100%|██████████| 29/29 [00:00<00:00, 64.81it/s]]
100%|██████████| 29/29 [00:00<00:00, 72.20it/s]]
100%|██████████| 29/29 [00:00<00:00, 117.95it/s]
100%|██████████| 29/29 [00:00<00:00, 108.42it/s]
100%|██████████| 29/29 [00:00

**********************************************************
SVM
{'C': 1, 'kernel': 'rbf'}
Accuracy: 0.94
**********************************************************
PCA SVM
{'C': 1, 'kernel': 'rbf'}
Accuracy: 0.94
**********************************************************
RANDOM FOREST
{'max_depth': 10, 'n_estimators': 90}
Accuracy: 0.98
**********************************************************
**********************************************************
Band: Beta


100%|██████████| 29/29 [00:00<00:00, 96.13it/s]
100%|██████████| 29/29 [00:00<00:00, 111.59it/s]
100%|██████████| 29/29 [00:00<00:00, 121.28it/s]
100%|██████████| 29/29 [00:00<00:00, 123.03it/s]
100%|██████████| 29/29 [00:00<00:00, 65.58it/s]
100%|██████████| 29/29 [00:00<00:00, 103.59it/s]
100%|██████████| 29/29 [00:00<00:00, 112.59it/s]
100%|██████████| 29/29 [00:00<00:00, 72.05it/s]
100%|██████████| 29/29 [00:00<00:00, 103.69it/s]
100%|██████████| 29/29 [00:00<00:00, 89.15it/s]
100%|██████████| 29/29 [00:00<00:00, 129.10it/s]
100%|██████████| 29/29 [00:00<00:00, 112.32it/s]
100%|██████████| 29/29 [00:00<00:00, 116.18it/s]
100%|██████████| 29/29 [00:00<00:00, 113.67it/s]
100%|██████████| 29/29 [00:00<00:00, 85.23it/s]]
100%|██████████| 29/29 [00:00<00:00, 115.04it/s]
100%|██████████| 29/29 [00:00<00:00, 85.24it/s]]
100%|██████████| 29/29 [00:00<00:00, 80.45it/s]]
100%|██████████| 29/29 [00:00<00:00, 131.86it/s]
100%|██████████| 29/29 [00:00<00:00, 123.50it/s]
100%|██████████| 29/29 [

**********************************************************
SVM
{'C': 10, 'kernel': 'rbf'}
Accuracy: 0.94
**********************************************************
PCA SVM
{'C': 10, 'kernel': 'rbf'}
Accuracy: 0.94
**********************************************************
RANDOM FOREST
{'max_depth': 10, 'n_estimators': 50}
Accuracy: 0.96
**********************************************************
**********************************************************
Band: Gamma


100%|██████████| 29/29 [00:00<00:00, 82.86it/s]
100%|██████████| 29/29 [00:00<00:00, 100.41it/s]
100%|██████████| 29/29 [00:00<00:00, 104.12it/s]
100%|██████████| 29/29 [00:00<00:00, 103.27it/s]
100%|██████████| 29/29 [00:00<00:00, 62.55it/s]
100%|██████████| 29/29 [00:00<00:00, 108.82it/s]
100%|██████████| 29/29 [00:00<00:00, 126.04it/s]
100%|██████████| 29/29 [00:00<00:00, 69.98it/s]
100%|██████████| 29/29 [00:00<00:00, 107.65it/s]
100%|██████████| 29/29 [00:00<00:00, 90.75it/s]
100%|██████████| 29/29 [00:00<00:00, 130.32it/s]
100%|██████████| 29/29 [00:00<00:00, 105.32it/s]
100%|██████████| 29/29 [00:00<00:00, 121.41it/s]
100%|██████████| 29/29 [00:00<00:00, 125.27it/s]
100%|██████████| 29/29 [00:00<00:00, 85.19it/s]]
100%|██████████| 29/29 [00:00<00:00, 117.55it/s]
100%|██████████| 29/29 [00:00<00:00, 88.04it/s]]
100%|██████████| 29/29 [00:00<00:00, 79.26it/s]]
100%|██████████| 29/29 [00:00<00:00, 137.04it/s]
100%|██████████| 29/29 [00:00<00:00, 133.88it/s]
100%|██████████| 29/29 [

**********************************************************
SVM
{'C': 0.1, 'kernel': 'linear'}
Accuracy: 1.00
**********************************************************
PCA SVM
{'C': 0.1, 'kernel': 'linear'}
Accuracy: 1.00
**********************************************************
RANDOM FOREST
{'max_depth': 10, 'n_estimators': 10}
Accuracy: 1.00
**********************************************************
**********************************************************
Band: HighGamma


100%|██████████| 29/29 [00:00<00:00, 90.88it/s]
100%|██████████| 29/29 [00:00<00:00, 109.47it/s]
100%|██████████| 29/29 [00:00<00:00, 131.26it/s]
100%|██████████| 29/29 [00:00<00:00, 132.34it/s]
100%|██████████| 29/29 [00:00<00:00, 67.38it/s]
100%|██████████| 29/29 [00:00<00:00, 96.55it/s]
100%|██████████| 29/29 [00:00<00:00, 129.29it/s]
100%|██████████| 29/29 [00:00<00:00, 74.69it/s]
100%|██████████| 29/29 [00:00<00:00, 100.49it/s]
100%|██████████| 29/29 [00:00<00:00, 98.72it/s]
100%|██████████| 29/29 [00:00<00:00, 132.68it/s]
100%|██████████| 29/29 [00:00<00:00, 117.12it/s]
100%|██████████| 29/29 [00:00<00:00, 112.11it/s]
100%|██████████| 29/29 [00:00<00:00, 129.28it/s]
100%|██████████| 29/29 [00:00<00:00, 90.55it/s]]
100%|██████████| 29/29 [00:00<00:00, 113.91it/s]
100%|██████████| 29/29 [00:00<00:00, 86.99it/s]]
100%|██████████| 29/29 [00:00<00:00, 86.18it/s]]
100%|██████████| 29/29 [00:00<00:00, 136.34it/s]
100%|██████████| 29/29 [00:00<00:00, 121.69it/s]
100%|██████████| 29/29 [0

**********************************************************
SVM
{'C': 0.1, 'kernel': 'linear'}
Accuracy: 0.98
**********************************************************
PCA SVM
{'C': 10, 'kernel': 'rbf'}
Accuracy: 0.98
**********************************************************
RANDOM FOREST
{'max_depth': 50, 'n_estimators': 130}
Accuracy: 0.98
