In [12]:
%pip install -e git+https://github.com/UN-GCPDS/python-gcpds.MI_prediction.git#egg=MI_prediction

Obtaining MI_prediction from git+https://github.com/UN-GCPDS/python-gcpds.MI_prediction.git#egg=MI_prediction
  Updating ./src/mi-prediction clone
  Running command git fetch -q --tags
  Running command git reset --hard -q a95f4e7dfa4666c4c5cdfa244d912f7f12ab0f27
  Preparing metadata (setup.py) ... [?25ldone
Installing collected packages: MI_prediction
  Attempting uninstall: MI_prediction
    Found existing installation: MI-prediction 0.1
    Uninstalling MI-prediction-0.1:
      Successfully uninstalled MI-prediction-0.1
  Running setup.py develop for MI_prediction
Successfully installed MI_prediction-0.1
Note: you may need to restart the kernel to use updated packages.


In [1]:
from MI_prediction.Utils.Datasets import DataLoader, DataLoader_Rest, get_epochs, get_runs, get_labels
from MI_prediction.Datasets import Cho2017_Rest, BNCI2014001_Rest
from MI_prediction.Utils.Preprocess import filterbank_preprocessor, filterbank,FBCSP
from MI_prediction.Validation.Scores import get_scores_cv, get_scores
from braindecode.preprocessing.preprocess import exponential_moving_standardize, preprocess, Preprocessor, scale

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split,StratifiedKFold,cross_val_score,GridSearchCV

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from mne.preprocessing import compute_current_source_density

import time

# Lawhern2018 sklearn

## BCI2a

### Load data

In [11]:
dl = DataLoader(dataset_name="BNCI2014001")
subjects = np.arange(1,10)

### Preprocessing

In [11]:
fb = filterbank_preprocessor([(8.,15.),(15.,25.)])

preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
        Preprocessor(scale, factor=1e6, apply_on_array=True),  # Convert from V to uV
        ]

Ch_prep = [
        Preprocessor(compute_current_source_density,copy=False,apply_on_array=False)
        ]

windows = [
        (0,0)
        ]


band to filter: (8.0, 15.0) Hz
band to filter: (15.0, 25.0) Hz


In [18]:
fs = 250.

start = -1


win = 1



In [53]:
def create_windows(win = 1, start_offset = 0, end_offset = 0, duration = 4, overlap = 0.0):
    st_offsets = []
    ed_offsets = []

    cont = start_offset + win
    while start_offset + win <= duration + end_offset:
        print(start_offset,cont-duration)
        st_offsets.append(start_offset)
        ed_offsets.append(cont-duration)

        start_offset += win*(1-overlap)
        cont += win*(1-overlap)

    return st_offsets, ed_offsets


In [54]:
create_windows(win = 0.5, start_offset = 0, end_offset = 0, duration = 4, overlap = 0.5)

0 -3.5
0.25 -3.25
0.5 -3.0
0.75 -2.75
1.0 -2.5
1.25 -2.25
1.5 -2.0
1.75 -1.75
2.0 -1.5
2.25 -1.25
2.5 -1.0
2.75 -0.75
3.0 -0.5
3.25 -0.25
3.5 0.0


([0,
  0.25,
  0.5,
  0.75,
  1.0,
  1.25,
  1.5,
  1.75,
  2.0,
  2.25,
  2.5,
  2.75,
  3.0,
  3.25,
  3.5],
 [-3.5,
  -3.25,
  -3.0,
  -2.75,
  -2.5,
  -2.25,
  -2.0,
  -1.75,
  -1.5,
  -1.25,
  -1.0,
  -0.75,
  -0.5,
  -0.25,
  0.0])

In [4]:

acc = {}

for s in subjects:
    print ("Subject: {}".format(s))
    tic = time.time()

    dl.load_data(subject_ids=[s])
    ds_f = filterbank(dl, preprocess=preprocessors, filters=fb, channels_prep=[])
    trials = [tr.get_trials(start_offset=[0], end_offset=[0]) for tr in ds_f]

    acc[str(s)] = {}

    for nw in range(len(windows)):
        epochs = [get_epochs(epoch['win_'+str(nw)].split('session')['session_T']) for epoch in trials]
        
        X = np.concatenate([np.expand_dims(trial[0],axis=3) for trial in epochs], axis=-1,dtype=np.float64)
        y = epochs[0][1]

        X, y = get_labels(X,y,[0,1])

        skf = StratifiedKFold(n_splits=4)

        fbcsp = FBCSP(norm_trace=False)
        lda = LinearDiscriminantAnalysis()

        param_grid = {
                'fbcsp__n_components': [4,6,8]
                }

        clf = Pipeline([('fbcsp', fbcsp), ('lda', lda)])
        cvs = GridSearchCV(clf,param_grid=param_grid,cv=skf,scoring='accuracy',n_jobs=-1, verbose=1)
        cvs.fit(X,y)

        mu, sig = get_scores_cv(cvs)
        mdl = cvs.best_estimator_
        mdl.fit(X,y)

        epochs_ts = [get_epochs(epoch['win_'+str(nw)].split('session')['session_E']) for epoch in trials]
        
        X_test = np.concatenate([np.expand_dims(trial[0],axis=3) for trial in epochs_ts], axis=-1,dtype=np.float64)
        y_test = epochs_ts[0][1]

        X_test, y_test = get_labels(X_test,y_test,[0,1])

        ypred = mdl.predict(X_test)
        acc_ts = accuracy_score(y_test,ypred)

        acc[str(s)]['win_'+str(nw)]={
                'acc_train': mu,
                'std_train': sig,
                'acc_test': acc_ts,
                'model': mdl
                }
        
        print("Accuracy train: {}  Accuracy test: {}  elapsed: {}".format(mu,acc_ts, time.time()-tic))


Subject: 1
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
48 events found
Event IDs: [1 2 3 4]
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samp



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 15.00 Hz
- Upper transition bandwidth: 3.75 Hz (-6 dB cutoff frequency: 16.88 Hz)
- Filter length: 413 samples (1.652 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 15 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 

In [5]:
get_scores(acc,'train')

Unnamed: 0,subject,window,accuracy,std
0,1,win_0,0.923611,0.041084
1,2,win_0,0.902778,0.041667
2,3,win_0,0.916667,0.085617
3,4,win_0,0.680556,0.041667
4,5,win_0,0.645833,0.053341
5,6,win_0,0.743056,0.053341
6,7,win_0,0.972222,0.027778
7,8,win_0,0.770833,0.098943
8,9,win_0,0.763889,0.082168


In [6]:
get_scores(acc,'test')

Unnamed: 0,subject,window,accuracy,std
0,1,win_0,0.965278,
1,2,win_0,0.819444,
2,3,win_0,0.888889,
3,4,win_0,0.791667,
4,5,win_0,0.604167,
5,6,win_0,0.784722,
6,7,win_0,0.798611,
7,8,win_0,0.791667,
8,9,win_0,0.854167,


In [17]:
acc

{'1': 0.9236111111111112,
 '2': 0.923611111111111,
 '3': 0.9791666666666666,
 '4': 0.6319444444444444,
 '5': 0.6458333333333334,
 '6': 0.7291666666666667,
 '7': 0.9652777777777779,
 '8': 0.7777777777777779,
 '9': 0.7638888888888888}

In [51]:
acc #(8,12) (12,38)

{'1': 0.923611111111111,
 '2': 0.9166666666666666,
 '3': 0.923611111111111,
 '4': 0.736111111111111,
 '5': 0.6388888888888888,
 '6': 0.7291666666666666,
 '7': 0.9722222222222223,
 '8': 0.7847222222222221,
 '9': 0.763888888888889}

In [42]:
acc # Con laplaciano

{'1': 0.9027777777777778,
 '2': 0.9097222222222222,
 '3': 0.9305555555555556,
 '4': 0.6805555555555556,
 '5': 0.5416666666666666,
 '6': 0.6388888888888888,
 '7': 0.9513888888888888,
 '8': 0.8680555555555556,
 '9': 0.7083333333333333}

In [39]:
acc # sin laplaciano

{'1': 0.9097222222222222,
 '2': 0.8958333333333334,
 '3': 0.9444444444444444,
 '4': 0.7013888888888888,
 '5': 0.5277777777777778,
 '6': 0.6319444444444444,
 '7': 0.9444444444444444,
 '8': 0.8680555555555556,
 '9': 0.7569444444444444}