In [1]:
import numpy as np
from pyriemann.estimation import Covariances as Cov
from pyriemann.spatialfilters import CSP
from scipy.signal import butter, filtfilt, sosfiltfilt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.pipeline import make_pipeline

from loaddata import Dataset_Left_Right_MI

In [2]:
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5, axis=-1):
    """Apply a zero-phase Butterworth band-pass filter.

    Parameters
    ----------
    data : array_like
        Signal(s) to filter; filtering runs along ``axis``.
    lowcut, highcut : float
        Pass-band edges in Hz. Must satisfy ``0 < lowcut < highcut < fs / 2``.
    fs : float
        Sampling rate in Hz.
    order : int, optional
        Butterworth design order (default 5). The filter is applied forward
        and backward, so the effective magnitude response is squared.
    axis : int, optional
        Axis of ``data`` that holds the time samples (default -1, the last axis).

    Returns
    -------
    ndarray
        Filtered array with the same shape as ``data``.

    Raises
    ------
    ValueError
        If the band edges are not ordered inside ``(0, fs / 2)``.
    """
    nyq = 0.5 * fs
    if not 0 < lowcut < highcut < nyq:
        raise ValueError(
            f"band edges must satisfy 0 < lowcut < highcut < fs/2 "
            f"(got lowcut={lowcut}, highcut={highcut}, fs={fs})"
        )
    # Second-order sections avoid the numerical problems that (b, a) polynomial
    # coefficients can have for higher-order band-pass designs.
    sos = butter(order, [lowcut / nyq, highcut / nyq], btype='band', output='sos')
    # Forward-backward filtering -> zero phase shift.
    return sosfiltfilt(sos, data, axis=axis)

In [6]:
# ---- configuration --------------------------------------------------------
fs = 160                  # sampling rate (Hz) passed to the loader and to the band-pass filter
path = 'datasets'         # where the data set is stored
LOWCUT, HIGHCUT = 5, 32   # band-pass edges (Hz) applied to each subject's data before CSP
N_SPLITS = 10             # number of stratified CV folds
RANDOM_STATE = 42         # fixed fold-shuffling seed for reproducible splits
N_JOBS = 10               # parallel workers used by cross_validate

# ---- load data ------------------------------------------------------------
# NOTE(review): fmin/fmax/tmin/tmax semantics live in loaddata.Dataset_Left_Right_MI;
# presumably a 1-79 Hz pre-filter and a 0-4 s epoch window at rate `fs` — confirm there.
dataset = Dataset_Left_Right_MI('Lee2019_MI', fs, fmin=1, fmax=79, tmin=0, tmax=4, path=path)
subjects = dataset.subject_list

# ---- pipeline -------------------------------------------------------------
# Covariance estimation -> CSP with 2 spatial filters -> LDA with automatic shrinkage.
lda = LDA(shrinkage='auto', solver='lsqr')
csp = make_pipeline(Cov(estimator='cov'), CSP(nfilter=2), lda)

# Shuffled stratified folds keep the class balance in each fold.
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)

# ---- per-subject evaluation -----------------------------------------------
Acc = []  # mean cross-validated accuracy of each subject
for subject in subjects:
    data, y = dataset.get_data([subject])
    # Filters along the last axis; assumes get_data puts time samples last — confirm.
    data = butter_bandpass_filter(data, LOWCUT, HIGHCUT, fs)

    acc = cross_validate(csp, data, y, cv=kf, scoring='accuracy', n_jobs=N_JOBS)
    scores = acc["test_score"]
    Acc.append(scores.mean())
    print(f'S{subject} CSP accuracy: {scores.mean():.3f} +/- {scores.std():.3f}')

# ---- summary: mean and spread of the per-subject accuracies ----------------
print(f'Mean accuracy: {np.mean(Acc):.3f} +/- {np.std(Acc):.3f}')

S1 CSP accuracy: 0.475 +/- 0.087
S2 CSP accuracy: 0.590 +/- 0.153
S3 CSP accuracy: 0.870 +/- 0.135
S4 CSP accuracy: 0.410 +/- 0.073
S5 CSP accuracy: 0.530 +/- 0.119
S6 CSP accuracy: 0.770 +/- 0.093
S7 CSP accuracy: 0.595 +/- 0.091
S8 CSP accuracy: 0.415 +/- 0.103
S9 CSP accuracy: 0.580 +/- 0.142
S10 CSP accuracy: 0.505 +/- 0.108
S11 CSP accuracy: 0.425 +/- 0.112
S12 CSP accuracy: 0.610 +/- 0.153
S13 CSP accuracy: 0.450 +/- 0.116
S14 CSP accuracy: 0.495 +/- 0.144
S15 CSP accuracy: 0.465 +/- 0.123
S16 CSP accuracy: 0.560 +/- 0.083
S17 CSP accuracy: 0.400 +/- 0.112
S18 CSP accuracy: 0.715 +/- 0.125
S19 CSP accuracy: 0.635 +/- 0.114
S20 CSP accuracy: 0.530 +/- 0.075
S21 CSP accuracy: 0.975 +/- 0.040
S22 CSP accuracy: 0.610 +/- 0.094
S23 CSP accuracy: 0.565 +/- 0.090
S24 CSP accuracy: 0.635 +/- 0.090
S25 CSP accuracy: 0.725 +/- 0.105
S26 CSP accuracy: 0.545 +/- 0.129
S27 CSP accuracy: 0.480 +/- 0.081
S28 CSP accuracy: 0.715 +/- 0.125
S29 CSP accuracy: 0.845 +/- 0.035
S30 CSP accuracy: 0.535