In [None]:
# PyTorch is only used to detect whether a CUDA device is available.
import torch

# Import the NeuroDecKit toolbox components used in this notebook.
from neurodeckit import (
    Dataset_Left_Right_MI,
    DL_Classifier,
    TL_Classifier,
    CrossSessionEvaluator)

# --- Configuration: values a reader is most likely to change ---------------
# Root folder of the locally downloaded public datasets; edit for your machine.
DATA_DIR = 'E:/工作进展/公开数据集/datasets'
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA instead of raising.
# NOTE(review): assumes DL_Classifier accepts 'cpu' as a device string — confirm.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Dataset: Pan2023, subject S1 -------------------------------------------
dataset = Dataset_Left_Right_MI(
    'Pan2023', fmin=8, fmax=30, tmin=0, tmax=4, path=DATA_DIR)
data, label, info = dataset.get_data([1])

# --- Experiment scenario: cross-session evaluation, scored with AUC ---------
experiment = CrossSessionEvaluator(info, metrics='auc')
domain_tags = experiment.get_domain_tags()
# The last domain tag is used as the target domain for every pipeline below.
target_domain = domain_tags[-1]

# --- Pipelines: CSP and MSVTNet inside the transfer-learning framework ------
# Both pipelines share the same target domain and dpa_method='RA'; they differ
# only in the stage that follows (fee_method='CSP' vs. the MSVTNet network
# passed as ete_method).
csp = TL_Classifier(target_domain=target_domain, dpa_method='RA', fee_method='CSP')
net = DL_Classifier(model_name='MSVTNet', device=DEVICE)
msvtnet = TL_Classifier(target_domain=target_domain, dpa_method='RA', ete_method=net)
pipelines = {'CSP': csp, 'MSVTNet': msvtnet}

# --- Run the evaluation -----------------------------------------------------
results = experiment.evaluate(data, label, pipelines, target_domain)

# --- Show the results -------------------------------------------------------
print(results)

  epoch    train_acc    train_loss    valid_acc    valid_loss    cp      lr     dur
-------  -----------  ------------  -----------  ------------  ----  ------  ------
      1       [36m0.4896[0m        [32m0.7340[0m       [35m0.5417[0m        [31m0.7455[0m     +  0.0010  0.1589
      2       [36m0.5625[0m        [32m0.7256[0m       0.5417        [31m0.6804[0m     +  0.0010  0.1451
      3       0.5104        0.7512       [35m0.6667[0m        [31m0.6430[0m     +  0.0010  0.1773
      4       0.5625        [32m0.7144[0m       0.6667        [31m0.6245[0m     +  0.0010  0.2007
      5       0.5625        [32m0.7037[0m       [35m0.7083[0m        [31m0.6037[0m     +  0.0010  0.1998
      6       [36m0.5729[0m        0.7119       0.7083        [31m0.5761[0m     +  0.0010  0.1531
      7       [36m0.6667[0m        [32m0.6982[0m       [35m0.7500[0m        [31m0.5697[0m     +  0.0010  0.1711
      8       0.5625        0.7261       0.7500        [31m0.55

In [5]:
# Report the mean AUC of each pipeline, one value per line: CSP first, then MSVTNet.
for pipeline_name in ('CSP', 'MSVTNet'):
    print(results[pipeline_name]['auc']['mean'])

0.9247222222222222
0.7536111111111111


In [7]:
# Inspect the last domain tag reported by the evaluator; the experiment cell
# picks this same position (index -1) of get_domain_tags() as target_domain.
experiment.get_domain_tags()[-1]

'S1_Sess1'