In [1]:
import os

import numpy as np
from pyriemann.estimation import Covariances as Cov
from pyriemann.tangentspace import TangentSpace as TS
from scipy.signal import butter, filtfilt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.pipeline import make_pipeline

from loaddata import Dataset_Left_Right_MI

In [2]:
# butterworth bandpass filter
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5, axis=-1):
    """Zero-phase Butterworth band-pass filter.

    Parameters
    ----------
    data : array_like
        Signal(s) to filter; filtering is applied along ``axis``.
    lowcut, highcut : float
        Pass-band edges in Hz; must satisfy ``0 < lowcut < highcut < fs / 2``.
    fs : float
        Sampling rate in Hz.
    order : int, optional
        Order of the Butterworth design (default 5). ``filtfilt`` runs the
        filter forward and backward, so the magnitude response is squared.
    axis : int, optional
        Axis of ``data`` along which to filter (default -1, the last axis).

    Returns
    -------
    numpy.ndarray
        Filtered data with the same shape as ``data``.

    Raises
    ------
    ValueError
        If the cut-off frequencies are not ordered and strictly inside (0, fs/2).
    """
    nyq = 0.5 * fs
    # Fail early with a readable message instead of an opaque error from
    # `butter` (or a meaningless filter when the band edges are swapped).
    if not 0 < lowcut < highcut < nyq:
        raise ValueError(
            f'need 0 < lowcut < highcut < fs/2, got lowcut={lowcut}, '
            f'highcut={highcut}, fs={fs}')
    b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
    # Forward-backward filtering: zero phase, so no group delay is introduced.
    return filtfilt(b, a, data, axis=axis)

In [5]:
# load data
fs = 160  # sampling rate (Hz) requested from the loader; also used by the band-pass filter
# Dataset location: set the MI_DATASET_PATH environment variable on another
# machine instead of editing this machine-specific default.
path = os.environ.get('MI_DATASET_PATH', r'E:\工作进展\小论文2023会议\数据处理python\datasets')
dataset = Dataset_Left_Right_MI('Lee2019_MI', fs, fmin=1, fmax=79, tmin=0, tmax=4, path=path)
subjects = dataset.subject_list

# analysis settings (named so they can be tuned in one place)
BAND_LOW, BAND_HIGH = 5, 32  # pass band (Hz) applied before covariance estimation
N_SPLITS = 10                # folds of the stratified cross-validation
SEED = 42                    # fixes the fold shuffling so scores are reproducible
N_JOBS = 10                  # parallel workers used by cross_validate

# create pipelines: trial covariance -> tangent-space vectors -> shrinkage LDA
lda = LDA(shrinkage='auto', solver='lsqr')
ts = make_pipeline(Cov(estimator='cov'), TS(), lda)

# cross-validation
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

Acc = []  # mean CV accuracy per subject, in the order of `subjects`
for subject in subjects:
    data, y = dataset.get_data([subject])
    # Filtered along the last axis; presumably data is (trials, channels, samples),
    # the layout pyriemann's Covariances expects -- TODO confirm against loaddata.
    # Each series is filtered independently, so doing this before the CV split
    # does not leak information across folds.
    data = butter_bandpass_filter(data, BAND_LOW, BAND_HIGH, fs)

    acc = cross_validate(ts, data, y, cv=kf, scoring='accuracy', n_jobs=N_JOBS)
    scores = acc['test_score']
    Acc.append(scores.mean())
    print(f'S{subject} TS accuracy: {scores.mean():.3f} +/- {scores.std():.3f}')

# display results (spread is the population std across subjects, ddof=0)
print(f'Mean accuracy: {np.mean(Acc):.3f} +/- {np.std(Acc):.3f}')


S1 TS accuracy: 0.695 +/- 0.240
S2 TS accuracy: 0.915 +/- 0.105
S3 TS accuracy: 0.645 +/- 0.186
S4 TS accuracy: 0.605 +/- 0.169
S5 TS accuracy: 0.290 +/- 0.218
S6 TS accuracy: 0.550 +/- 0.152
S7 TS accuracy: 0.980 +/- 0.060
S8 TS accuracy: 0.475 +/- 0.136
S9 TS accuracy: 0.495 +/- 0.199
S10 TS accuracy: 0.600 +/- 0.177
S11 TS accuracy: 0.630 +/- 0.183
S12 TS accuracy: 0.735 +/- 0.215
S13 TS accuracy: 0.715 +/- 0.198
S14 TS accuracy: 0.495 +/- 0.221
S15 TS accuracy: 0.705 +/- 0.205
S16 TS accuracy: 0.450 +/- 0.227
S17 TS accuracy: 0.440 +/- 0.251
S18 TS accuracy: 0.380 +/- 0.218
S19 TS accuracy: 0.475 +/- 0.237
S20 TS accuracy: 0.720 +/- 0.187
S21 TS accuracy: 0.465 +/- 0.216
S22 TS accuracy: 0.885 +/- 0.116
S23 TS accuracy: 0.545 +/- 0.172
S24 TS accuracy: 0.480 +/- 0.182
S25 TS accuracy: 0.580 +/- 0.138
S26 TS accuracy: 0.670 +/- 0.232
S27 TS accuracy: 0.560 +/- 0.173
S28 TS accuracy: 0.360 +/- 0.148
S29 TS accuracy: 0.960 +/- 0.080
S30 TS accuracy: 0.450 +/- 0.244
S31 TS accuracy: 0.