# Cross-Session on Multiple Datasets
Author: LC.Pan  
Date: 2024-06-24  

In [None]:
# 公共工具库
import os, time
import json
import numpy as np
import itertools
import multiprocessing as mp
from joblib import Parallel, delayed, parallel_backend
from sklearn.model_selection import RepeatedStratifiedKFold
from contextlib import redirect_stdout, redirect_stderr
import torch

# 私有工具库
from loaddata import Dataset_Left_Right_MI
from deep_learning.dl_classifier import DL_Classifier
from pre_processing.preprocessing import Pre_Processing
from transfer_learning.tl_classifier import TL_Classifier
from transfer_learning import TLSplitter, encode_datasets

Loading Dataset

In [None]:
# Dataset configuration.
dataset_name = 'Pan2023'
fs = 250  # presumably the sampling rate in Hz -- confirm against Dataset_Left_Right_MI
freqband = [8, 30]  # passed below as fmin / fmax
datapath = r'E:\工作进展\小论文2023会议\数据处理python\datasets'

# Build the dataset object; the two entries of `freqband` become fmin and fmax.
dataset = Dataset_Left_Right_MI(
    dataset_name,
    fs,
    fmin=freqband[0],
    fmax=freqband[1],
    tmin=0,
    tmax=4,
    path=datapath,
)

# Disabled sketch for iterating over every subject of the dataset:
# for sub in dataset.subjects:
#     print(f"Subject {sub}...")
#     data = dataset.get_data()

# Fetch data, labels and the accompanying metadata for the selected subject list.
sub = [1]
data, label, info = dataset.get_data(sub)

In [None]:
# Distinct session labels, in order of first appearance in `info`.
session_values = info['session'].unique()
print('the session values are:',session_values)

# One entry per session: the list of `info` index labels belonging to it.
# The resulting Series is indexed by the session label itself.
session_indices = info.groupby('session').apply(lambda x: x.index.tolist())

# Map session label -> index list using the groupby result's own index.
# `groupby` orders groups by sorted key while `unique()` keeps order of first
# appearance, so zipping the two could pair a session label with another
# session's trial indices; keying by the group label avoids that.
session_index_dict = session_indices.to_dict()

# Collect data and labels of the first two sessions (in order of appearance),
# one list element per session.
# NOTE(review): the index labels of `info` are used directly to index `data`
# and `label`; this assumes `info` has a positional (0..n-1) index -- confirm.
Data, Label=[], []
for session in session_values[:2]:
    Data.append(data[session_index_dict[session]])
    Label.append(label[session_index_dict[session]])

# Merge the per-session lists into a single data set, encoded labels and domain list.
X, y_enc, domain =encode_datasets(Data, Label)
print(X.shape, y_enc.shape, len(domain))
print(domain)

# The last entry of `domain` is used as the target domain
# (presumably the second selected session -- confirm against encode_datasets).
target_domain = domain[-1]

Setting Up Transfer-Learning-Based Cross-Session Cross-Validation Splits

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit

# Value assigned to the inner splitter's `train_size`; 0 instead switches the
# TL splitter to its no-calibration mode.
train_size = 30

# Inner shuffle-split (10 repetitions, fixed seed) wrapped by the
# transfer-learning splitter for the chosen target domain.
cv = StratifiedShuffleSplit(n_splits=10, random_state=42)
tl_cv = TLSplitter(target_domain=target_domain, cv=cv, no_calibration=False)

if train_size != 0:
    tl_cv.cv.train_size = train_size
else:
    tl_cv.no_calibration = True

# Sanity check: print the train / test sizes of every generated split.
for train, test in tl_cv.split(X, y_enc):
    print(len(train), len(test))

Create Pipelines

In [None]:
from joblib import Memory

# joblib cache location; the Memory object is handed to the classifier below.
cachedir = '../my_cache_directory'
memory = Memory(cachedir, verbose=0)

# Pre-processing stage; its `process` method is passed to the classifier as `pre_est`.
preprocess = Pre_Processing(
    fs_new=160,
    fs_old=250,
    n_channels=None,
    start_time=0.5,
    end_time=3.5,
    lowcut=None,
    highcut=None,
)

# Transfer-learning classifier configured with the method names below and
# bound to the target domain selected earlier.
Model = TL_Classifier(
    dpa_method='EA',
    fee_method='CSP',
    fes_method='MIC-K',
    clf_method='SVM',
    pre_est=preprocess.process,
    memory=memory,
    target_domain=target_domain,
)

Evaluating cross-session performance

In [None]:
from sklearn.model_selection import cross_val_score, cross_validate

# Run the cross-session evaluation over the TL splitter's folds,
# using 10 parallel jobs.
scores = cross_validate(
    Model,
    X,
    y_enc,
    cv=tl_cv,
    n_jobs=10,
)

In [None]:
# Pull the per-fold fit times, score times and test scores out of the
# cross_validate result in one step.
train_time, test_time, test_score = (
    scores[key] for key in ('fit_time', 'score_time', 'test_score')
)

# Report mean timings, then mean and standard deviation of the test score.
print('train time: %.3f s, test time: %.3f s' % (train_time.mean(), test_time.mean()))
print('test score: %.3f +/- %.3f' % (test_score.mean(), test_score.std()))