In [51]:
from sklearn.pipeline import FeatureUnion
from mne.decoding import Vectorizer
from pyriemann.estimation import ERPCovariances, XdawnCovariances, Covariances
from sklearn.feature_selection import SelectKBest
from mne.decoding import CSP
from data_loaders import load_data_labels_based_on_dataset
import pandas as pd
from scipy import signal
from share import datasets_basic_infos
import mne
from pyriemann.tangentspace import TangentSpace
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
import numpy as np

threshold_for_bug = 0.00000001  # could be any value, ex numpy.min

ROOT_VOTING_SYSTEM_PATH: str = "/Users/rosit/Documents/workprojects/bci_complete/voting_system_platform"

dataset_name = 'aguilera_gamified'
subject_id = 1

# Folders and paths
dataset_foldername = dataset_name + "_dataset"
computer_root_path = ROOT_VOTING_SYSTEM_PATH + "/Datasets/"
data_path = computer_root_path + dataset_foldername
print(data_path)
dataset_info: dict = datasets_basic_infos[dataset_name]
epochs, data, labels = load_data_labels_based_on_dataset(dataset_info, subject_id, data_path)
data[data < threshold_for_bug] = threshold_for_bug # To avoid the error "SVD did not convergence"

/Users/rosit/Documents/workprojects/bci_complete/voting_system_platform/Datasets/aguilera_gamified_dataset
EEG channel type selected for re-referencing
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 1.4e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 140.00 Hz
- Upper transition bandwidth: 35.00 Hz (-6 dB cutoff frequency: 157.50 Hz)
- Filter length: 3301 samples (6.602 s)


  raw.set_montage(mne.channels.read_custom_montage(channel_location))
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 59 - 61 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 59.35
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 59.10 Hz)
- Upper passband edge: 60.65 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 60.90 Hz)
- Filter length: 3301 samples (6.602 s)
Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 1651 samples

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Fitting ICA to data using 24 channels (please be patient, this may take a while)
Selecting by number: 15 components
Fitting ICA took 3.0s.
Effective window size : 4.096 (s)
Using EOG channels: FP1, FP2
EOG channel index for this subject is: [0 1]
Filtering the data to remove DC offset to help distinguish blinks from saccades
Selecting channel FP1 for blink detection
Setting up band-pass filter from 1 - 10 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed frequency-domain design (firwin2) method
- Hann window
- Lower passband edge: 1.00
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.75 Hz)
- Upper passband edge: 10.00 Hz
- Upper transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 10.25 Hz)
- Filter length: 5000 samples (10.000 s)
Now detecting blinks and generating corresponding events
Found 268 significant peaks
Number of EOG events detected: 268
Not setting metadata
268 ma

  0%|          | Creating augmented epochs : 0/24 [00:00<?,       ?it/s]

  0%|          | Computing thresholds ... : 0/24 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/121 [00:00<?,       ?it/s]

  0%|          | n_interp : 0/3 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/121 [00:00<?,       ?it/s]

  0%|          | Fold : 0/10 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/121 [00:00<?,       ?it/s]

  0%|          | Fold : 0/10 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/121 [00:00<?,       ?it/s]

  0%|          | Fold : 0/10 [00:00<?,       ?it/s]





Estimated consensus=0.50 and n_interpolate=1


  0%|          | Repairing epochs : 0/121 [00:00<?,       ?it/s]

Dropped 21 epochs: 16, 33, 40, 43, 49, 53, 64, 66, 73, 78, 84, 85, 92, 95, 96, 97, 104, 105, 114, 119, 120


In [54]:
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)


features: dict = {
    "Vectorizer": Pipeline([("Vectorizer", Vectorizer())]),
    "ERPcova": Pipeline([("ERPcova", ERPCovariances(estimator='oas')), ("ts", TangentSpace())]), # Add TangentSpace, otherwise the dimensions are not 2D.
    "XdawnCova": Pipeline([("XdawnCova", XdawnCovariances(estimator='oas')), ("ts", TangentSpace())]), # Add TangentSpace, otherwise the dimensions are not 2D.
    "CSP": Pipeline([("Vectorizer", CSP(n_components=4, reg=None, log=True, norm_trace=False))]),
    "Cova": Pipeline([("Cova", Covariances()), ("ts", TangentSpace())]), # Add TangentSpace, otherwise the dimensions are not 2D.
            }
frequency_ranges: dict = {
    "complete": [0, 140],
    "delta": [0, 3],
    "theta": [3, 7],
    "alpha": [7, 13],
    "beta 1": [13, 16],
    "beta 2": [16, 20],
    "beta 3": [20, 35],
    "gamma": [35, 140],
}


features_df = pd.DataFrame()
for feature_name, feature_method in features.items():
    for frequency_bandwidth_name, frequency_bandwidth in frequency_ranges.items():
        print(frequency_bandwidth)
        iir_params = dict(order=8, ftype="butter")
        filt = mne.filter.create_filter(
            x_train, dataset_info['sample_rate'], l_freq=frequency_bandwidth[0], h_freq=frequency_bandwidth[1], method="iir", iir_params=iir_params, verbose=True
        )
        filtered = signal.sosfiltfilt(filt["sos"], x_train)
        
        X_features = feature_method.fit_transform(filtered, y_train)
        print("Combined space has", X_features.shape[1], "features")
        column_name = [f'{frequency_bandwidth_name}_{feature_name}_{num}' for num in range(0, X_features.shape[1])]
        temp_features_df = pd.DataFrame(X_features, columns=column_name)
        features_df = pd.concat([features_df, temp_features_df], axis=1)

[0, 140]
Setting up low-pass filter at 1.4e+02 Hz

IIR filter parameters
---------------------
Butterworth lowpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoff at 140.00 Hz: -6.02 dB


ValueError: Found array with dim 3. MinMaxScaler expected <= 2.

# Select characteristics

In [41]:
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import f_classif

features_df.shape

(84, 197228)

In [42]:
X_SelectKBest = SelectKBest(f_classif, k=10000)
X_new = X_SelectKBest.fit_transform(features_df, y_train)
X_new.shape

(84, 10000)

In [43]:
columns_list = X_SelectKBest.get_feature_names_out()
columns_list

array(['complete_Vectorizer_967', 'complete_Vectorizer_971',
       'complete_Vectorizer_1008', ..., 'complete_Cova_138',
       'complete_Cova_234', 'complete_Cova_281'], dtype=object)

In [45]:
X_new_df = pd.DataFrame(X_new, columns=columns_list)
#X_new_df.to_csv(ROOT_VOTING_SYSTEM_PATH + f"/Results/features_Vectorizer_and_CSP_{dataset_name}.csv", index=False)
X_new_df

Unnamed: 0,complete_Vectorizer_967,complete_Vectorizer_971,complete_Vectorizer_1008,complete_Vectorizer_1009,complete_Vectorizer_1028,complete_Vectorizer_1093,complete_Vectorizer_1414,complete_Vectorizer_1462,complete_Vectorizer_1463,complete_Vectorizer_1464,...,beta 3_CSP_0,beta 3_CSP_2,beta 3_CSP_3,complete_Cova_12,complete_Cova_14,complete_Cova_18,complete_Cova_35,complete_Cova_138,complete_Cova_234,complete_Cova_281
0,0.388619,0.174608,0.238878,1.252703e-01,0.015921,0.000014,3.063442e+00,1.533769e+00,1.332792e+00,1.085726e+00,...,-1.969238,-0.554704,-0.080776,0.085728,-0.229480,-0.105350,-0.204235,-0.256018,0.935951,0.030557
1,-0.001111,-0.000818,-0.000094,6.198250e-06,-0.000725,-0.010537,1.298304e-02,4.996582e-03,1.146699e-03,-4.204899e-03,...,-2.153030,-0.485715,-1.003320,-0.203215,0.071245,0.019106,0.001073,-0.081053,-0.123643,0.460803
2,0.269030,0.053549,0.325844,4.076216e-01,0.335259,0.251690,7.627464e-01,1.709923e-03,-5.759997e-02,1.640157e-01,...,-0.564486,3.312458,-0.272012,-0.076677,0.310052,0.038013,0.042998,-0.082986,2.906098,-0.153944
3,-0.000034,-0.000015,0.000001,9.012330e-07,-0.000083,-0.000154,9.999999e-09,1.000974e-08,1.000112e-08,9.985120e-09,...,-1.652073,-1.380654,0.404118,0.248350,0.070930,-0.059967,0.259227,-0.138710,0.136986,0.030523
4,-0.007818,-0.005944,-0.000493,1.117791e-03,0.032792,0.017865,2.078882e-06,-1.559640e-02,-5.037876e-02,4.972813e-02,...,-2.203559,0.008424,-0.509605,0.443731,-0.051345,-0.033492,-0.365832,-0.148214,0.578689,-0.148915
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,-0.021693,0.000086,0.226316,2.829980e-01,0.233483,0.328561,5.478628e-01,3.826861e-01,4.731290e-01,9.772583e-01,...,-2.104256,-1.165099,0.352150,0.107375,0.012635,0.063652,-0.149026,-0.086537,-0.564139,0.164412
80,0.000774,-0.006757,0.083303,6.410674e-02,-0.000113,0.006393,8.979927e-02,1.121637e-02,-2.587531e-02,-1.160922e-02,...,-1.244687,-0.733689,1.059212,0.288433,0.029791,-0.028178,0.055336,0.071149,-0.196173,0.070820
81,-0.007383,0.003680,-0.004548,3.846708e-03,-0.000468,0.000082,1.652811e+00,1.391509e+00,1.328599e+00,8.247640e-01,...,-0.781909,-0.378260,0.032145,0.493051,0.176135,-0.025063,0.736571,-0.051002,-0.011910,-0.032785
82,-0.007086,0.009441,0.107235,1.160800e-01,-0.019969,0.056075,1.000000e-08,1.000006e-08,1.000037e-08,9.999736e-09,...,-1.580850,-0.301780,0.375330,-0.247301,-0.109771,0.027605,0.013604,-0.182694,0.798616,-0.189218


# Classification of train

In [46]:
from data_utils import get_best_classificator_and_test_accuracy, ClfSwitcher
classifier, acc = get_best_classificator_and_test_accuracy(X_new_df, y_train, Pipeline([('clf', ClfSwitcher())]))



Best Test Score: 
0.9880952380952381


In [47]:
print(acc)

0.9880952380952381


# Repeat for test

In [48]:
features_test_df = pd.DataFrame()

for feature_name, feature_method in features.items():
    #combined_features = FeatureUnion([])
    for frequency_bandwidth_name, frequency_bandwidth in frequency_ranges.items():
        print(frequency_bandwidth)
        iir_params = dict(order=8, ftype="butter")
        filt = mne.filter.create_filter(
            x_test, dataset_info['sample_rate'], l_freq=frequency_bandwidth[0], h_freq=frequency_bandwidth[1], method="iir", iir_params=iir_params, verbose=True
        )
        filtered = signal.sosfiltfilt(filt["sos"], x_test)
        
        X_features = feature_method.transform(filtered)
        print("Combined space has", X_features.shape[1], "features")
        column_name = [f'{frequency_bandwidth_name}_{feature_name}_{num}' for num in range(0, X_features.shape[1])]
        temp_features_df = pd.DataFrame(X_features, columns=column_name)
        features_test_df = pd.concat([features_test_df, temp_features_df], axis=1)

[0, 140]
Setting up low-pass filter at 1.4e+02 Hz

IIR filter parameters
---------------------
Butterworth lowpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoff at 140.00 Hz: -6.02 dB

Combined space has 16824 features
[0, 3]
Setting up low-pass filter at 3 Hz

IIR filter parameters
---------------------
Butterworth lowpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoff at 3.00 Hz: -6.02 dB

Combined space has 16824 features
[3, 7]
Setting up band-pass filter from 3 - 7 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 32 (effective, after forward-backward)
- Cutoffs at 3.00, 7.00 Hz: -6.02, -6.02 dB

Combined space has 16824 features
[7, 13]
Setting up band-pass filter from 7 - 13 Hz

IIR filter parameters
---------------------
Butterworth

  eigvals = operator(eigvals)
  eigvals = operator(eigvals)


Combined space has 300 features
[20, 35]
Setting up band-pass filter from 20 - 35 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 32 (effective, after forward-backward)
- Cutoffs at 20.00, 35.00 Hz: -6.02, -6.02 dB

Combined space has 300 features
[35, 140]
Setting up band-pass filter from 35 - 1.4e+02 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 32 (effective, after forward-backward)
- Cutoffs at 35.00, 140.00 Hz: -6.02, -6.02 dB

Combined space has 300 features


In [49]:
from sklearn.metrics import classification_report
y_pred = classifier.predict(features_test_df[columns_list])
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.40      0.33      0.36         6
           1       0.25      0.20      0.22         5
           2       0.08      0.50      0.14         2
           3       0.00      0.00      0.00         8

    accuracy                           0.19        21
   macro avg       0.18      0.26      0.18        21
weighted avg       0.18      0.19      0.17        21


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
