In [None]:
import torch
import torch.nn as nn
from strokes import StrokePatientsMIDataset, StrokePatientsMIProcessedDataset
from strokesdict import STROKEPATIENTSMI_LOCATION_DICT
from torcheeg.transforms import Select,BandSignal,Compose
from to import ToGrid, ToTensor
from downsample import SetSamplingRate
from baseline import BaselineCorrection

# Offline pipeline: baseline correction, resample 500 Hz -> 128 Hz, keep the 8-40 Hz band.
offline_pipeline = Compose([
    BaselineCorrection(),
    SetSamplingRate(origin_sampling_rate=500, target_sampling_rate=128),
    BandSignal(sampling_rate=128, band_dict={'frequency_range': [8, 40]}),
])
# Online pipeline: arrange channels with STROKEPATIENTSMI_LOCATION_DICT, then convert to a tensor.
online_pipeline = Compose([ToGrid(STROKEPATIENTSMI_LOCATION_DICT), ToTensor()])

dataset = StrokePatientsMIDataset(
    root_path='./subdataset',
    io_path='.torcheeg/dataset',
    chunk_size=500,  # 500 samples = 1 second at the original 500 Hz
    overlap=250,
    offline_transform=offline_pipeline,
    online_transform=online_pipeline,
    label_transform=Select('label'),
    num_worker=8,
)

# Sanity check on the first sample.
first_eeg = dataset[0][0]
print(first_eeg.shape)  # expected EEG shape: torch.Size([1, 128, 9, 9])
first_label = dataset[0][1]
print(first_label)      # label (int)
print(len(dataset))

In [None]:
from eegswintransformer import SwinTransformer

# Training configuration shared by the fold loop.
HYPERPARAMETERS = dict(
    seed=42,             # random seed for the run
    batch_size=12,       # DataLoader batch size (train and test)
    lr=1e-5,             # learning rate passed to ClassifierTrainer
    weight_decay=1e-4,   # weight decay passed to ClassifierTrainer
    num_epochs=20,       # max_epochs for trainer.fit
)
from torcheeg.model_selection import KFoldPerSubjectGroupbyTrial
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from classifier_loss import ClassifierTrainer

# 4-fold cross-validation splitter (per subject, grouped by trial, per the class name);
# fixed random_state so the folds are the same on every run.
k_fold = KFoldPerSubjectGroupbyTrial(
    n_splits=4,
    shuffle=True,
    random_state=42,
    split_path='.torcheeg/model_selection',
)

# Per-fold results appended by the training loop.
training_metrics, test_metrics, test_f1 = [], [], []

from pytorch_lightning import seed_everything

# Only one fold is trained: index 2, i.e. the third of the four folds.
# Its checkpoint name is swin_fold_{TARGET_FOLD + 1}.pth (the Grad-CAM cell loads swin_fold_3.pth).
TARGET_FOLD = 2


def _lookup_metric(result, name):
    """Return metric ``name`` from one ``trainer.test(...)`` result dict.

    Tries ``test_<name>`` first (the form used for accuracy), then the bare
    ``<name>``. The logging convention of the local
    ``classifier_loss.ClassifierTrainer`` is not visible here, so both are accepted.

    Raises:
        KeyError: if neither key is present; the message lists the available keys.
    """
    for key in (f"test_{name}", name):
        if key in result:
            return result[key]
    raise KeyError(f"metric {name!r} not found in test result; available keys: {sorted(result)}")


for i, (training_dataset, test_dataset) in enumerate(k_fold.split(dataset)):
    if i != TARGET_FOLD:
        continue

    # The seed was declared in HYPERPARAMETERS but never applied. Seed Python,
    # NumPy and torch so weight init and DataLoader shuffling are reproducible.
    seed_everything(HYPERPARAMETERS['seed'])

    model = SwinTransformer(patch_size=(8, 3, 3),
                            depths=(2, 6, 4),
                            num_heads=(3, 6, 8),
                            window_size=(3, 3, 3))
    trainer = ClassifierTrainer(model=model,
                                num_classes=2,
                                lr=HYPERPARAMETERS['lr'],
                                weight_decay=HYPERPARAMETERS['weight_decay'],
                                metrics=["accuracy", "f1score"],
                                accelerator="gpu")
    training_loader = DataLoader(training_dataset,
                                 batch_size=HYPERPARAMETERS['batch_size'],
                                 shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=HYPERPARAMETERS['batch_size'],
                             shuffle=False)

    # Early stopping on the training loss.
    # NOTE(review): with the default patience=50 and num_epochs=20 this callback
    # cannot trigger before max_epochs. Set HYPERPARAMETERS['patience'] below
    # num_epochs to make it effective.
    early_stopping_callback = EarlyStopping(
        monitor='train_loss',
        patience=HYPERPARAMETERS.get('patience', 50),
        mode='min',
        verbose=True,
    )
    # test_loader is passed as the validation loader, but limit_val_batches=0.0
    # disables validation, so the held-out fold is not used during training.
    trainer.fit(training_loader,
                test_loader,
                max_epochs=HYPERPARAMETERS['num_epochs'],
                callbacks=[early_stopping_callback],
                enable_model_summary=False,
                limit_val_batches=0.0)

    # Evaluate on the training fold (fit check) and on the held-out fold.
    training_result = trainer.test(training_loader,
                                   enable_progress_bar=True,
                                   enable_model_summary=True)[0]
    test_result = trainer.test(test_loader,
                               enable_progress_bar=True,
                               enable_model_summary=True)[0]
    training_metrics.append(_lookup_metric(training_result, "accuracy"))
    test_f1.append(_lookup_metric(test_result, "f1score"))
    test_metrics.append(_lookup_metric(test_result, "accuracy"))

    # Save the model parameters. The Grad-CAM cell loads swin_fold_3.pth, so
    # uncomment these lines to regenerate it from this run.
    # model_path = f"swin_fold_{i+1}.pth"
    # torch.save(model.state_dict(), model_path)
    # print(f"Model parameters saved to {model_path}")
    break
     

In [None]:
import torch
# Report the installed PyTorch version (useful when checkpoints or grad-cam behave unexpectedly).
torch_version = torch.__version__
print(torch_version)

In [None]:
import torch
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from swin_CAM import SwinTransformer  
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def reshape_transform(tensor, time=16, height=3, width=3):
    """Bring a Grad-CAM target-layer tensor into channel-first (B, C, T, H, W) layout.

    pytorch_grad_cam passes both the captured activations and the gradients
    through this function before computing the CAM.

    * 3-D input ``(B, seq_len, C)`` is reshaped to
      ``(B, seq_len // (height * width), height, width, C)`` and permuted to
      ``(B, C, T, H, W)``. This is the token output of a transformer block
      (e.g. ``model.layers[-1].blocks[-1].norm2``), and ``seq_len`` must be
      divisible by ``height * width``.
    * Any other input is returned unchanged. For the current target
      ``model.patch_embed.proj`` the input was noted as (B, 96, 16, 3, 3),
      i.e. already channel-first (TODO confirm for other checkpoints).

    Args:
        tensor: activation or gradient tensor captured at the target layer.
        time: kept for backward compatibility. The temporal extent is inferred
            from ``seq_len`` instead of being read from this argument.
        height: rows of the token grid (3 assumed for 3x3 patches of a 9x9 grid).
        width: columns of the token grid (3 assumed, as above).

    Returns:
        The tensor in (B, C, T, H, W) layout, or unchanged for non-3-D input.
    """
    if tensor.dim() != 3:
        return tensor
    batch, _seq_len, channels = tensor.shape
    grid = tensor.reshape(batch, -1, height, width, channels)
    return grid.permute(0, 4, 1, 2, 3)

def _to_2d_map(cam_map):
    """Average away leading axes until ``cam_map`` is 2-D, so ``imshow`` can draw it.

    A 2-D map is returned unchanged. If Grad-CAM yields a volume per sample
    (e.g. time x rows x cols for the 5-D EEG input; assumed, TODO confirm),
    the leading axes are averaged.
    """
    while cam_map.dim() > 2:
        cam_map = cam_map.mean(dim=0)
    return cam_map


def _plot_average_cam(cams, title, empty_message):
    """Plot the mean of ``cams`` along dim 0, or print ``empty_message`` if there are none."""
    if cams.size(0) == 0:
        print(empty_message)
        return
    avg_map = _to_2d_map(cams.mean(dim=0))
    fig, ax = plt.subplots()
    image = ax.imshow(avg_map.cpu().numpy(), cmap='jet')
    ax.set_title(title)
    ax.axis("off")
    fig.colorbar(image, ax=ax)
    plt.show()
    plt.close(fig)


# Load the trained model. swin_CAM.SwinTransformer is loaded from a checkpoint
# trained with eegswintransformer.SwinTransformer. strict=False tolerates key
# mismatches, so they are reported below instead of being silently ignored.
model = SwinTransformer(patch_size=(8, 3, 3),
                        depths=(2, 6, 4),
                        num_heads=(3, 6, 8),
                        window_size=(3, 3, 3))

model_path = "swin_fold_3.pth"  # checkpoint of fold index 2 (i + 1 == 3) from the training cell
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
# The model is moved to the GPU below when one is available.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(model_path, map_location="cpu")
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"WARNING: checkpoint/model key mismatch for {model_path}")
    print(f"  missing keys:    {incompatible.missing_keys}")
    print(f"  unexpected keys: {incompatible.unexpected_keys}")
model.eval()

# Grad-CAM target: the patch-embedding projection (presumably a Conv3d whose
# output is already (B, C, T, H, W); TODO confirm).
# Token-sequence alternative: model.layers[-1].blocks[-1].norm2
target_layers = [model.patch_embed.proj]

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = model.to(device)

# Analyse only the first batch of the held-out fold. test_loader comes from the
# training cell; batches are expected as (batch, 1, 128, 9, 9) per the dataset cell.
first_batch = next(iter(test_loader), None)
if first_batch is None:
    raise RuntimeError("test_loader is empty - run the training cell first.")
eeg_input, label = first_batch
eeg_input = eeg_input.to(device)
label = label.to(device)

# Plain forward pass for accuracy; no autograd graph is needed here.
with torch.no_grad():
    result = model(eeg_input)
preds = torch.argmax(result, dim=1)
correct = (preds == label).sum().item()
total = label.size(0)
acc = correct / total
print(f"Accuracy: {acc:.4f}")

# targets=None uses grad-cam's default target (the top-scoring class per sample).
# To force class 1 for every sample: [ClassifierOutputTarget(1)] * eeg_input.shape[0]
targets = None
# The context manager releases the forward/backward hooks GradCAM registers on the model.
with GradCAM(model=model,
             target_layers=target_layers,
             reshape_transform=reshape_transform) as cam:
    grayscale_cam = cam(input_tensor=eeg_input, targets=targets)
print(grayscale_cam.shape)

# grayscale_cam is a CPU numpy array, so cam_tensor is on the CPU.
cam_tensor = torch.tensor(grayscale_cam)
print(cam_tensor.shape)

# The masks must live on the CPU as well: indexing a CPU tensor with a CUDA
# boolean mask raises a RuntimeError.
correct_mask = (preds == label).cpu()
wrong_mask = ~correct_mask

correct_cam = cam_tensor[correct_mask]
wrong_cam = cam_tensor[wrong_mask]

_plot_average_cam(correct_cam, "Average Grad-CAM (Correct Predictions)", "No correct predictions.")
_plot_average_cam(wrong_cam, "Average Grad-CAM (Wrong Predictions)", "No wrong predictions.")

In [None]:
# Per-sample Grad-CAM maps for the batch analysed in the previous cell
# (uses cam_tensor and correct_mask defined there).
for i, (sample, is_correct) in enumerate(zip(cam_tensor, correct_mask)):
    sample_map = sample
    # imshow needs a 2-D array. Average away any leading axes (e.g. a time axis)
    # if Grad-CAM produced a volume per sample.
    while sample_map.dim() > 2:
        sample_map = sample_map.mean(dim=0)

    fig, ax = plt.subplots()
    image = ax.imshow(sample_map.cpu().numpy(), cmap='jet')
    # bool(...) shows True/False rather than a tensor repr such as "tensor(True, device='cuda:0')".
    ax.set_title(f"Grad-CAM {i}, {bool(is_correct)}")
    ax.axis("off")
    fig.colorbar(image, ax=ax)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop