In [1]:
from torcheeg import transforms

# Feature-extraction pipeline: eight band-wise feature extractors applied to
# the same five frequency bands, concatenated (presumably 8 x 5 = 40 features
# per channel, matching the (1, 16, 40) reshape in `test`), then min-max
# normalised and converted to a tensor.
SAMPLING_RATE = 200  # Hz; must match the recordings and the training pipeline

# Frequency bands in Hz as (low, high), shared by every band transform below
# so they cannot drift apart.
BAND_DICT = {
    "delta": (1, 4),
    "theta": (4, 8),
    "alpha": (8, 13),
    "beta": (13, 30),
    "gamma": (30, 44)
}

transform = transforms.Compose([
    transforms.Concatenate([
        transforms.BandDifferentialEntropy(sampling_rate=SAMPLING_RATE, band_dict=BAND_DICT),
        transforms.BandPowerSpectralDensity(sampling_rate=SAMPLING_RATE, band_dict=BAND_DICT),
        transforms.BandMeanAbsoluteDeviation(sampling_rate=SAMPLING_RATE, band_dict=BAND_DICT),
        transforms.BandDetrendedFluctuationAnalysis(sampling_rate=SAMPLING_RATE, band_dict=BAND_DICT),
        transforms.BandHiguchiFractalDimension(sampling_rate=SAMPLING_RATE, band_dict=BAND_DICT),
        # NOTE(review): the two BandHjorth transforms are not given
        # sampling_rate, so they use torcheeg's default (believed to be
        # 128 Hz) instead of SAMPLING_RATE. Confirm this against the torcheeg
        # docs. It is left unchanged on purpose: the saved checkpoints were
        # presumably trained on features computed exactly this way, so fixing
        # it here alone would create a train/test feature mismatch.
        transforms.BandHjorth(mode='mobility', band_dict=BAND_DICT),
        transforms.BandHjorth(mode='complexity', band_dict=BAND_DICT),
        transforms.BandBinPower(sampling_rate=SAMPLING_RATE, band_dict=BAND_DICT)
    ]),
    transforms.MinMaxNormalize(),
    transforms.ToTensor()
])

In [2]:
import scipy.io as sio
import torch


def test(data_path, model_path, n_segments=12, segment_len=2000):
    """Classify one recording window by window and return per-window labels.

    Loads the ``EEG_ECClean`` variable from the MATLAB file at ``data_path``
    and takes its ``data`` field. It then cuts ``n_segments`` consecutive
    windows of ``segment_len`` samples, extracts features for each window with
    the notebook-level ``transform``, and classifies each window with the model.

    The model is loaded and switched to eval mode once per call, not once
    per window.

    Args:
        data_path: path to a ``*Clean.mat`` file containing an
            ``EEG_ECClean`` struct with a ``data`` field.
        model_path: path to a checkpoint saved as a whole module (the loaded
            object is called directly and has ``.eval()``). Alternatively, an
            already-loaded ``torch.nn.Module`` can be passed, so callers looping
            over many files can avoid reloading it each time.
        n_segments: number of windows to classify (default 12).
        segment_len: window length in samples (default 2000; presumably
            10 s if the data is at the 200 Hz ``transform`` assumes).
            ``test(path, model, n_segments=1, segment_len=24000)`` classifies
            the first 24000 samples as a single window.

    Returns:
        list[int]: predicted class index (argmax over the model output) for
        each window, in time order.
    """
    if isinstance(model_path, torch.nn.Module):
        model = model_path
    else:
        # map_location: checkpoints trained on GPU must be mapped to CPU.
        # torch.load unpickles arbitrary objects, so only load trusted files.
        # NOTE(review): newer PyTorch releases default to weights_only=True,
        # which may refuse whole-module checkpoints — confirm for your version.
        model = torch.load(model_path, map_location=torch.device('cpu'))
    model.eval()

    sample = sio.loadmat(data_path)['EEG_ECClean']
    # MATLAB struct field -> 2-D array; sliced below as [channels, samples]
    # (channel-first layout assumed — TODO confirm).
    eeg = sample["data"][0][0]

    result = []
    for i in range(n_segments):
        eeg_data = eeg[:, i * segment_len:(i + 1) * segment_len]

        x = transform(eeg=eeg_data)['eeg']
        # Add the batch dimension; also fails loudly if the feature count
        # is not 16 x 40.
        x = torch.reshape(x, (1, 16, 40))

        with torch.no_grad():
            output = model(x)
            result.append(output.argmax(1)[0].tolist())
    return result

# 批量测试 (Batch evaluation)

In [3]:
import os

# Test-set folders and their class names (zipped pairwise below).
# NOTE(review): absolute, user-specific Windows paths; consider deriving them
# from a single configurable data-root constant. The HC entry points at the
# "HC_backup\BadSub" sub-folder, while MDD/BD use the top-level "*_backup"
# folders. Confirm that this is intentional.
folders = [r'C:\Users\bugs_\PycharmProjects\eegProject\data\Test_EEG\HC_backup\BadSub',
           r'C:\Users\bugs_\PycharmProjects\eegProject\data\Test_EEG\MDD_backup',
           r'C:\Users\bugs_\PycharmProjects\eegProject\data\Test_EEG\BD_backup']
class_names = ['HC', 'MDD', 'BD']
# zip() would silently drop unmatched entries.
assert len(folders) == len(class_names), "each folder needs exactly one class name"

# {'filename': [(path to *Clean.mat file, class name), ...]}
data_file = {'filename': []}

for folder, class_name in zip(folders, class_names):
    # os.listdir returns entries in arbitrary order; sort them so the
    # evaluation (and its printed output) is reproducible across runs/machines.
    file_list = [os.path.join(folder, name)
                 for name in sorted(os.listdir(folder))
                 if name.endswith('Clean.mat')]

    data_file['filename'].extend((path, class_name) for path in file_list)

In [7]:
# Evaluate one checkpoint on every MDD recording.
# - Each file's label is the majority vote over its per-window predictions
#   from `test`.
# - `total` counts the MDD files evaluated.
# - `calculate` counts those whose vote equals the MDD label.
model_path = r'C:\Users\bugs_\PycharmProjects\eegProject\torcheegProj\DGCNN\models\dataAug12_feature40_MinMaxNorm_DGCNN16_2_shuffle111_batch64_epoch1000_lr5e-4\DGCNN_16_2_999.pth'
label_dic = {'HC': 0, 'MDD': 1, "BD": 2}

total = 0
calculate = 0  # correctly classified MDD files

for sub, label in data_file['filename']:
    if label != 'MDD':
        continue  # only the MDD class is evaluated in this cell
    total += 1
    output = test(sub, model_path)
    # Majority vote; ties go to the label predicted first.
    maxTimes = max(output, key=output.count)
    print(maxTimes)
    calculate += int(maxTimes == label_dic['MDD'])

2
2
0
1
2
2
2
2
1
2
1
1
2
1
2
2
1
1
1
1
2
2
1
1
1
1
0
0
2
2
2
2
1
1
1
2
2
1
2
1
2
1
1
1
1
1
1
2
1
2
1
2
1
2
2
0
2
1
1
2
1
1
2
0
0
0
1
0
2
1
0
0
1
0
0
0
1
0
2
0
0
1
1
2
1
2


In [8]:
# Number of MDD recordings evaluated by the batch loop above.
total

86

In [9]:
# Majority-vote accuracy on the evaluated MDD recordings.
# If no MDD file was found (total == 0), report NaN instead of raising
# ZeroDivisionError.
correct = calculate / total if total else float('nan')
correct

0.4418604651162791

# 单个测试 (Single-recording test)

In [5]:
# Classify a single HC recording; `result_list` receives one predicted label
# per window.
# NOTE(review): this uses the "origin" checkpoint, not the one evaluated in the
# batch cell. The file name may also contain a personal identifier; consider
# anonymising it before sharing the notebook.
result_list = test(
    data_path=r'C:\Users\bugs_\PycharmProjects\eegProject\data\Test_EEG\HC\X20211029yinqichuan_EEG_ECClean.mat',
    model_path=r'C:\Users\bugs_\PycharmProjects\eegProject\torcheegProj\DGCNN\models\origin\DGCNN_16_2_137.pth',
)

In [6]:
from collections import Counter

# Tally the per-window predictions and report the majority label.
# NOTE(review): the saved output below (Counter({0: 1})) holds a single
# prediction, whereas `test` classifies 12 windows. It therefore looks stale,
# probably produced by an earlier one-window version of `test`; re-run this
# cell to refresh it.
result = Counter(result_list)
print(result)
# Majority vote; scanning result_list in order means ties go to the label
# predicted first.
maxTimes = max(result_list, key=result.get)
print('maxTimes', maxTimes)

Counter({0: 1})
maxTimes 0
