In [1]:
from strokesinit import StrokePatientsMIProcessedDataset, StrokePatientsMIDataset
import scipy
from strokesdict import STROKEPATIENTSMI_LOCATION_DICT
from torcheeg.transforms import Select,BandSignal,Compose,Lambda,ToTensor,BandDifferentialEntropy,Resize
from to import ToGrid
from typing import Callable, Dict, Union, List
import numpy as np
import soxr
from downsample import SetSamplingRate
from baseline import BaselineCorrection
from base_transform import EEGTransform
    

# Transforms handed to `offline_transform`: baseline correction, then
# SetSamplingRate(500, 484) (presumably a 500 -> 484 resample -- TODO confirm
# the argument order against downsample.py), then a single 8-40 band filter
# configured for the 484 rate.
offline_pipeline = Compose([
    BaselineCorrection(),
    SetSamplingRate(500, 484),
    BandSignal(sampling_rate=484, band_dict={'frequency_range': [8, 40]}),
])

# Transform handed to `online_transform`: only converts the chunk to a tensor.
online_pipeline = Compose([ToTensor()])

dataset = StrokePatientsMIDataset(
    root_path='./subdataset',
    # Uncomment to point at an existing cache folder instead of re-processing:
    # io_path='.torcheeg/datasets_1735798885866_IhPTv',
    chunk_size=500,  # chunk length before the offline transforms (1 second per the original note)
    overlap=0,
    offline_transform=offline_pipeline,
    online_transform=online_pipeline,
    label_transform=Select('label'),
    num_worker=8,
)

# Sanity checks. The recorded run printed torch.Size([1, 30, 484]), i.e. the
# chunks hold 484 samples after the offline transforms, not 500.
print(dataset[0][0].shape)
print(dataset[0][1])  # label of the first chunk (printed as 0 in the recorded run)
print(len(dataset))   # number of chunks (320 in the recorded run)

  from .autonotebook import tqdm as notebook_tqdm
[2025-01-16 16:05:35] INFO (torcheeg/MainThread) 🔍 | Processing EEG data. Processed EEG data has been cached to [92m.torcheeg/datasets_1737014735074_lmTAi[0m.
[2025-01-16 16:05:35] INFO (torcheeg/MainThread) ⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]: 100%|██████████| 2/2 [00:00<00:00, 205.77it/s]

[RECORD ./subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 0it [00:00, ?it/s][A
[RECORD ./subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 1it [00:00,  4.45it/s][A
[RECORD ./subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 2it [00:00,  5.78it/s][A
[RECORD ./subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 12it [00:00, 36.41it/s][A
[RECORD ./subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 22it [00:00, 56.40it/s][A
[RECORD ./subdataset/sourc

torch.Size([1, 30, 484])
0
320


In [2]:
# Central training configuration read by the cells below.
HYPERPARAMETERS = dict(
    seed=42,          # random seed
    batch_size=32,    # DataLoader batch size
    lr=1e-4,          # learning rate passed to the trainer
    weight_decay=0,   # weight decay passed to the trainer
    num_epochs=50,    # upper bound on training epochs per fold
)

In [3]:
from eca import ECA
from model import SwinTransformer
import torch
import torch.nn as nn

class ECASwinTransformerModel(nn.Module):
    """ECA block followed by a Swin Transformer two-class classifier.

    Each input chunk is reshaped to ``(batch, num_channels, NUM_SAMPLES)``,
    passed through the ECA module, folded into a square
    ``GRID_SIDE x GRID_SIDE`` "image" per channel and classified by the
    Swin Transformer.

    Args:
        num_channels: Number of EEG channels per chunk. It sizes both the
            reshapes in ``forward`` and the ``in_chans`` of the Swin
            Transformer (previously the value was accepted but ignored and
            30 was hard-coded).
    """

    # Samples per channel the model expects; it must equal GRID_SIDE ** 2 so
    # that one channel can be folded into a square grid.
    NUM_SAMPLES = 484
    GRID_SIDE = 22

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        self.eca = ECA(
            input_size=self.NUM_SAMPLES,
            num_attention_heads=1,
            hidden_size=self.NUM_SAMPLES,
            hidden_dropout_prob=0.5,
        )
        self.swin_transformer = SwinTransformer(
            in_chans=num_channels,
            num_classes=2,
            embed_dim=192,
            depths=(2, 2, 18, 2),
            num_heads=(6, 12, 24, 48),
        )

    def forward(self, x):
        """Classify a batch of chunks.

        Args:
            x: Tensor holding ``num_channels * NUM_SAMPLES`` values per batch
                element (e.g. ``(batch, 1, num_channels, NUM_SAMPLES)``).

        Returns:
            Whatever the wrapped Swin Transformer returns for the folded
            ``(batch, num_channels, GRID_SIDE, GRID_SIDE)`` input.
        """
        batch_size = x.size(0)
        # reshape (rather than view) also accepts non-contiguous tensors.
        x = x.reshape(batch_size, self.num_channels, self.NUM_SAMPLES)
        x = self.eca(x)
        # Fold every channel's samples into a square grid for the Swin stage.
        x = x.reshape(batch_size, self.num_channels, self.GRID_SIDE, self.GRID_SIDE)
        return self.swin_transformer(x)

In [4]:
import os
import shutil

def delete_folder_if_exists(target_folder_name):
    """Remove the directory called ``target_folder_name`` from the current working directory.

    Only direct children of ``os.getcwd()`` are considered, and only entries
    that are directories with exactly that name. Success and failure are
    reported with ``print``; errors raised during deletion are caught and
    not re-raised. Returns ``None``.
    """
    working_dir = os.getcwd()
    for entry_name in os.listdir(working_dir):
        # Skip everything that does not carry the requested name.
        if entry_name != target_folder_name:
            continue

        candidate = os.path.join(working_dir, entry_name)
        # A regular file with the same name is left alone.
        if not os.path.isdir(candidate):
            continue

        try:
            # Delete the directory together with all of its contents.
            shutil.rmtree(candidate)
            print(f"已删除文件夹: {candidate}")
        except Exception as err:
            print(f"删除文件夹 {candidate} 时出错: {err}")


In [None]:
from torcheeg.model_selection import KFoldPerSubject
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from model import SwinTransformer
from torcheeg.trainers import ClassifierTrainer

# Seed the global NumPy and PyTorch generators before any model is built so
# that weight initialisation, dropout and DataLoader shuffling repeat across
# runs. HYPERPARAMETERS['seed'] was previously declared but never used.
np.random.seed(HYPERPARAMETERS['seed'])
torch.manual_seed(HYPERPARAMETERS['seed'])

# 5-fold split; the split seed comes from the shared config instead of a
# duplicated literal. The recorded run produced 10 folds in total.
k_fold = KFoldPerSubject(
    n_splits=5,
    shuffle=True,
    random_state=HYPERPARAMETERS['seed'])

training_metrics = []  # not populated; training accuracy is not evaluated here
test_metrics = []      # test accuracy of every fold, in split order

for training_dataset, test_dataset in k_fold.split(dataset):
    # Start every fold with a clean Lightning log directory.
    delete_folder_if_exists(target_folder_name='lightning_logs')

    # A fresh model and trainer per fold so no weights leak between folds.
    model = ECASwinTransformerModel(num_channels=30)
    trainer = ClassifierTrainer(model=model,
                                num_classes=2,
                                lr=HYPERPARAMETERS['lr'],
                                weight_decay=HYPERPARAMETERS['weight_decay'],
                                metrics=["accuracy"],
                                accelerator="gpu")

    training_loader = DataLoader(training_dataset,
                                 batch_size=HYPERPARAMETERS['batch_size'],
                                 shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=HYPERPARAMETERS['batch_size'],
                             shuffle=False)

    # Early-stopping callback: stop once the monitored 'train_loss' has not
    # improved (decreased) for 8 consecutive checks.
    early_stopping_callback = EarlyStopping(
        monitor='train_loss',
        patience=8,
        mode='min',
        verbose=True
    )

    # test_loader fills the validation-loader slot, but limit_val_batches=0.0
    # asks Lightning to run no validation batches, so the held-out fold is not
    # consumed during fitting.
    trainer.fit(training_loader,
                test_loader,
                max_epochs=HYPERPARAMETERS['num_epochs'],
                callbacks=[early_stopping_callback],
                enable_progress_bar=True,
                enable_model_summary=False,
                limit_val_batches=0.0)

    # Evaluate on the held-out fold and keep its accuracy.
    test_result = trainer.test(test_loader,
                               enable_progress_bar=True,
                               enable_model_summary=True)[0]
    test_metrics.append(test_result["test_accuracy"])

# Mean and standard deviation of the test accuracy across all folds.
print({
    "test_metric_avg": np.mean(test_metrics),
    "test_metric_std": np.std(test_metrics)
})

In [6]:
# Print one "<fold index> <test accuracy>" line per evaluated fold.
for fold_index in range(len(test_metrics)):
    print(fold_index, test_metrics[fold_index])

0 0.78125
1 0.53125
2 0.78125
3 0.71875
4 0.75
5 0.625
6 0.59375
7 0.40625
8 0.4375
9 0.6875


In [8]:
import numpy as np
# Summary over all folds: number of folds plus mean and standard deviation
# of the per-fold test accuracy.
fold_summary = {
    "len": len(test_metrics),
    "test_metric_avg": np.mean(test_metrics),
    "test_metric_std": np.std(test_metrics),
}
print(fold_summary)

{'len': 10, 'test_metric_avg': 0.63125, 'test_metric_std': 0.13020416659999787}
