In [None]:
import torch
import torch.nn as nn
from strokes import StrokePatientsMIDataset, StrokePatientsMIProcessedDataset
from strokesdict import STROKEPATIENTSMI_LOCATION_DICT
from torcheeg.transforms import Select,BandSignal,Compose
from to import ToGrid, ToTensor
from downsample import SetSamplingRate
from baseline import BaselineCorrection

# Offline pipeline: applied to every chunk when the dataset cache is built.
offline_transform = Compose([
    BaselineCorrection(),
    # Resample from the 500 Hz recording rate down to 128 Hz.
    SetSamplingRate(origin_sampling_rate=500, target_sampling_rate=128),
    # Restrict the resampled signal to the 8-40 Hz range.
    BandSignal(sampling_rate=128, band_dict={'frequency_range': [8, 40]}),
])

# Online pipeline: applied each time a sample is read from the dataset.
online_transform = Compose([
    ToGrid(STROKEPATIENTSMI_LOCATION_DICT),
    ToTensor(),
])

dataset = StrokePatientsMIDataset(
    root_path='./subdataset',
    io_path='.torcheeg/datasets_1745478591849_L5nX8',
    chunk_size=500,  # 1 second at the 500 Hz recording rate
    overlap=250,     # consecutive chunks share half a window
    offline_transform=offline_transform,
    online_transform=online_transform,
    label_transform=Select('label'),
    num_worker=8,
)

# Sanity-check the first sample and the dataset size.
first_sample = dataset[0]
print(first_sample[0].shape)  # EEG shape, expected torch.Size([1, 128, 9, 9])
print(first_sample[1])        # label (int)
print(len(dataset))

In [None]:
from eegswintransformer import SwinTransformer

# Tunable settings for the experiment, gathered in one place.
HYPERPARAMETERS = dict(
    seed=42,            # random seed for the run
    batch_size=16,      # samples per DataLoader batch
    lr=1e-4,            # learning rate handed to the trainer
    weight_decay=1e-4,  # weight decay handed to the trainer
    num_epochs=50,      # upper bound on training epochs per fold
)
from torcheeg.model_selection import KFoldPerSubjectGroupbyTrial
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from classifier import ClassifierTrainer

# 4-fold cross-validation splitter. The shuffle seed is read from
# HYPERPARAMETERS so that the split and the rest of the run are controlled by
# a single configured value instead of a second hard-coded literal.
k_fold = KFoldPerSubjectGroupbyTrial(
    n_splits=4,
    shuffle=True,
    # NOTE(review): presumably points at a previously saved split; confirm
    # before enabling.
    # split_path='.torcheeg/model_selection_1741918660674_jp0WB',
    random_state=HYPERPARAMETERS['seed'])

# Accuracy collected once per fold, in the order the splitter yields folds.
training_metrics = []
test_metrics = []

# Apply the configured seed (previously declared but never used) so that
# weight initialisation and DataLoader shuffling are repeatable every time
# this cell is re-run.
torch.manual_seed(HYPERPARAMETERS['seed'])

for training_dataset, test_dataset in k_fold.split(dataset):
    # A fresh model is built for every fold so no weights carry over.
    model = SwinTransformer(patch_size=(8, 3, 3),
                            num_classes=2,
                            depths=(2, 6, 4),
                            num_heads=(3, 6, 8),
                            window_size=(3, 3, 3),
                            in_chans=1)  # T, W and H are downscaled together
    trainer = ClassifierTrainer(model=model,
                                num_classes=2,
                                lr=HYPERPARAMETERS['lr'],
                                weight_decay=HYPERPARAMETERS['weight_decay'],
                                metrics=["accuracy"],
                                accelerator="gpu")
    training_loader = DataLoader(training_dataset,
                                 batch_size=HYPERPARAMETERS['batch_size'],
                                 shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=HYPERPARAMETERS['batch_size'],
                             shuffle=False)

    # Early-stopping callback.
    # NOTE(review): patience (50) equals num_epochs (50), so within one fit
    # this can never stop on lack of improvement; at most it stops on a
    # non-finite loss. Lower the patience if real early stopping is wanted.
    early_stopping_callback = EarlyStopping(
        monitor='train_loss',
        patience=50,
        mode='min',
        verbose=True
    )

    # limit_val_batches=0.0 is presumably forwarded to the Lightning Trainer,
    # where it disables validation; test_loader is passed only as the
    # second positional loader -- TODO confirm against ClassifierTrainer.fit.
    trainer.fit(training_loader,
                test_loader,
                max_epochs=HYPERPARAMETERS['num_epochs'],
                callbacks=[early_stopping_callback],
                # enable_progress_bar=True,
                enable_model_summary=False,
                limit_val_batches=0.0)

    # Evaluate the fitted model on both splits; the first entry of the
    # returned list holds the metrics dict.
    training_result = trainer.test(training_loader,
                                   enable_progress_bar=True,
                                   enable_model_summary=True)[0]
    test_result = trainer.test(test_loader,
                               enable_progress_bar=True,
                               enable_model_summary=True)[0]
    training_metrics.append(training_result["test_accuracy"])
    test_metrics.append(test_result["test_accuracy"])
    
     

In [None]:
import numpy as np

# Summarise accuracies in consecutive groups of four folds. Each printed row
# is: group start index, train mean, train std, test mean, test std.
GROUP_SIZE = 4
for group_start in range(0, len(test_metrics), GROUP_SIZE):
    group_end = group_start + GROUP_SIZE
    train_scores = training_metrics[group_start:group_end]
    test_scores = test_metrics[group_start:group_end]
    row = [
        f"{group_start}",
        f"{np.mean(train_scores):.3f}",
        f"{np.std(train_scores):.3f}",
        f"{np.mean(test_scores):.3f}",
        f"{np.std(test_scores):.3f}",
    ]
    print("\t".join(row))


In [None]:
# Print the per-fold test accuracies four to a line. Every score is followed
# by a tab, and every line (including a shorter final one) ends with a newline.
SCORES_PER_ROW = 4
for row_start in range(0, len(test_metrics), SCORES_PER_ROW):
    row_scores = test_metrics[row_start:row_start + SCORES_PER_ROW]
    print("".join(f"{score:.3f}\t" for score in row_scores))
