In [1]:
import torch
import torch.nn as nn
from strokes import StrokePatientsMIDataset, StrokePatientsMIProcessedDataset
from strokesdict import STROKEPATIENTSMI_LOCATION_DICT
from torcheeg.transforms import Select,BandSignal,Compose
from to import ToGrid, ToTensor
from downsample import SetSamplingRate
from baseline import BaselineCorrection

# Build the stroke-patient motor-imagery dataset. Offline transforms are applied
# at preprocessing time and their results cached under `io_path` (see the
# "reading cache" log line); online transforms are presumably applied on every
# sample access.
dataset = StrokePatientsMIDataset(
    root_path='./subdataset',
    io_path='.torcheeg/dataset',
    chunk_size=500,  # 500 samples = 1 second at the original 500 Hz rate
    overlap=250,     # 250 of 500 samples shared between consecutive chunks
    offline_transform=Compose([
        BaselineCorrection(),
        SetSamplingRate(origin_sampling_rate=500, target_sampling_rate=128),
        BandSignal(sampling_rate=128, band_dict={'frequency_range': [8, 40]}),
    ]),
    online_transform=Compose([ToGrid(STROKEPATIENTSMI_LOCATION_DICT), ToTensor()]),
    label_transform=Select('label'),
    num_worker=8,
)

# Fetch the first sample once (each `dataset[idx]` re-runs the online
# transforms) and sanity-check its EEG shape and label.
sample = dataset[0]
print(sample[0].shape)  # EEG shape: torch.Size([1, 128, 9, 9])
print(sample[1])        # label (int)
print(len(dataset))

  from .autonotebook import tqdm as notebook_tqdm
[2025-05-15 16:19:23] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from .torcheeg/dataset.


torch.Size([1, 128, 9, 9])
0
240


In [2]:
from eegswintransformer import SwinTransformer

# Training configuration shared by the k-fold loop below.
HYPERPARAMETERS = dict(
    seed=42,            # random seed for the run
    batch_size=12,      # samples per DataLoader batch
    lr=1e-5,            # optimizer learning rate
    weight_decay=1e-4,  # optimizer weight decay
    num_epochs=20,      # passed to trainer.fit as max_epochs
)
from torcheeg.model_selection import KFoldPerSubjectGroupbyTrial
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from classifier_loss import ClassifierTrainer

# Per-subject, trial-grouped 4-fold split; the split is cached in `split_path`
# and reused on later runs (see the "use existing split" log line).
k_fold = KFoldPerSubjectGroupbyTrial(
    n_splits=4,
    shuffle=True,
    split_path='.torcheeg/model_selection',
    random_state=42)

# Index of the single fold trained in this run (0 -> saved as swin_fold_1.pth).
FOLD_TO_RUN = 0

training_metrics = []
test_metrics = []

for i, (training_dataset, test_dataset) in enumerate(k_fold.split(dataset)):
    if i != FOLD_TO_RUN:
        continue

    # HYPERPARAMETERS['seed'] was declared but never applied; seed the global
    # torch RNG before weight initialisation and DataLoader shuffling so the
    # run is repeatable.
    torch.manual_seed(HYPERPARAMETERS['seed'])

    model = SwinTransformer(patch_size=(8, 3, 3),
                            depths=(2, 6, 4),
                            num_heads=(3, 6, 8),
                            window_size=(3, 3, 3))
    trainer = ClassifierTrainer(model=model,
                                num_classes=2,
                                lr=HYPERPARAMETERS['lr'],
                                weight_decay=HYPERPARAMETERS['weight_decay'],
                                metrics=["accuracy"],
                                accelerator="gpu")
    training_loader = DataLoader(training_dataset,
                                 batch_size=HYPERPARAMETERS['batch_size'],
                                 shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=HYPERPARAMETERS['batch_size'],
                             shuffle=False)

    # Early-stopping callback on the training loss.
    # NOTE(review): patience (50) exceeds num_epochs (20), so this callback can
    # never stop training early; it only logs train_loss improvements.
    early_stopping_callback = EarlyStopping(
        monitor='train_loss',
        patience=50,
        mode='min',
        verbose=True
    )
    # limit_val_batches=0.0 disables validation, so test_loader (passed here in
    # the validation-loader position) is not consumed during fitting.
    trainer.fit(training_loader,
                test_loader,
                max_epochs=HYPERPARAMETERS['num_epochs'],
                callbacks=[early_stopping_callback],
                enable_model_summary=False,
                limit_val_batches=0.0)

    # Evaluate on the training split and on the held-out split.
    training_result = trainer.test(training_loader,
                                   enable_progress_bar=True,
                                   enable_model_summary=True)[0]
    test_result = trainer.test(test_loader,
                               enable_progress_bar=True,
                               enable_model_summary=True)[0]
    training_metrics.append(training_result["test_accuracy"])
    test_metrics.append(test_result["test_accuracy"])

    # Save the trained model parameters (file name uses the 1-based fold number).
    model_path = f"swin_fold_{i+1}.pth"
    torch.save(model.state_dict(), model_path)
    # Message reads "Model parameters saved to <path>" (literal appears mis-encoded).
    print(f"‚úÖ Ê®°ÂûãÂèÇÊï∞Â∑≤‰øùÂ≠òÂà∞ {model_path}")
    break
     

[2025-05-15 16:19:25] INFO (torcheeg/MainThread) 📊 | Detected existing split of train and test set, use existing split from .torcheeg/model_selection.
[2025-05-15 16:19:25] INFO (torcheeg/MainThread) 💡 | If the dataset is re-generated, you need to re-generate the split of the dataset instead of using the previous split.
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Epoch 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:02<00:00,  6.66it/s, loss=0.863, train_loss=0.893, train_accuracy=0.500]

Metric train_loss improved. New best score: 0.893


Epoch 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:02<00:00,  6.65it/s, loss=0.863, train_loss=0.893, train_accuracy=0.500]

  rank_zero_warn(
  rank_zero_warn(
[2025-05-15 16:19:32] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.863 train_accuracy: 0.519 



Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.07it/s, loss=0.635, train_loss=0.490, train_accuracy=1.000]

Metric train_loss improved by 0.403 >= min_delta = 0.0. New best score: 0.490


Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.06it/s, loss=0.635, train_loss=0.490, train_accuracy=1.000]

[2025-05-15 16:19:33] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.536 train_accuracy: 0.869 



Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.11it/s, loss=0.443, train_loss=0.335, train_accuracy=1.000]

Metric train_loss improved by 0.155 >= min_delta = 0.0. New best score: 0.335


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.10it/s, loss=0.443, train_loss=0.335, train_accuracy=1.000]

[2025-05-15 16:19:35] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.381 train_accuracy: 0.969 



Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.03it/s, loss=0.339, train_loss=0.266, train_accuracy=1.000]

Metric train_loss improved by 0.069 >= min_delta = 0.0. New best score: 0.266


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.02it/s, loss=0.339, train_loss=0.266, train_accuracy=1.000]

[2025-05-15 16:19:36] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.316 train_accuracy: 0.994 



Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.33it/s, loss=0.284, train_loss=0.229, train_accuracy=1.000]

Metric train_loss improved by 0.037 >= min_delta = 0.0. New best score: 0.229


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.32it/s, loss=0.284, train_loss=0.229, train_accuracy=1.000]

[2025-05-15 16:19:37] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.279 train_accuracy: 1.000 



Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.38it/s, loss=0.263, train_loss=0.188, train_accuracy=1.000]

Metric train_loss improved by 0.041 >= min_delta = 0.0. New best score: 0.188


Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.37it/s, loss=0.263, train_loss=0.188, train_accuracy=1.000]

[2025-05-15 16:19:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.256 train_accuracy: 1.000 



Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.38it/s, loss=0.247, train_loss=0.232, train_accuracy=1.000]

[2025-05-15 16:19:40] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.247 train_accuracy: 1.000 



Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.25it/s, loss=0.245, train_loss=0.261, train_accuracy=1.000]

[2025-05-15 16:19:41] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.243 train_accuracy: 1.000 



Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.46it/s, loss=0.241, train_loss=0.258, train_accuracy=1.000]

[2025-05-15 16:19:43] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.239 train_accuracy: 1.000 



Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.18it/s, loss=0.235, train_loss=0.193, train_accuracy=1.000]

[2025-05-15 16:19:44] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.233 train_accuracy: 1.000 



Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.12it/s, loss=0.232, train_loss=0.223, train_accuracy=1.000]

[2025-05-15 16:19:46] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.232 train_accuracy: 1.000 



Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.14it/s, loss=0.236, train_loss=0.250, train_accuracy=1.000]

[2025-05-15 16:19:47] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.232 train_accuracy: 1.000 



Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.23it/s, loss=0.229, train_loss=0.190, train_accuracy=1.000]

[2025-05-15 16:19:48] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.228 train_accuracy: 1.000 



Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.13it/s, loss=0.231, train_loss=0.306, train_accuracy=1.000]

[2025-05-15 16:19:50] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.233 train_accuracy: 1.000 



Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.34it/s, loss=0.234, train_loss=0.249, train_accuracy=1.000]

[2025-05-15 16:19:51] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.230 train_accuracy: 1.000 



Epoch 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.12it/s, loss=0.227, train_loss=0.220, train_accuracy=1.000]

[2025-05-15 16:19:52] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.228 train_accuracy: 1.000 



Epoch 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.21it/s, loss=0.235, train_loss=0.285, train_accuracy=1.000]

[2025-05-15 16:19:54] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.231 train_accuracy: 1.000 



Epoch 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.02it/s, loss=0.234, train_loss=0.207, train_accuracy=1.000]

[2025-05-15 16:19:55] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.228 train_accuracy: 1.000 



Epoch 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.27it/s, loss=0.228, train_loss=0.265, train_accuracy=1.000]

[2025-05-15 16:19:57] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.231 train_accuracy: 1.000 



Epoch 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.05it/s, loss=0.23, train_loss=0.230, train_accuracy=1.000] 

[2025-05-15 16:19:58] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.230 train_accuracy: 1.000 

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:01<00:00, 10.02it/s, loss=0.23, train_loss=0.230, train_accuracy=1.000]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(


Testing DataLoader 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:00<00:00, 73.54it/s]

  rank_zero_warn(
  rank_zero_warn(
[2025-05-15 16:19:59] INFO (torcheeg/MainThread) 
[Test] test_loss: 0.002 test_accuracy: 1.000 



Testing DataLoader 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:00<00:00, 71.69it/s]
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
       Test metric             DataLoader 0
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
      test_accuracy                 1.0
        test_loss          0.001956828171387315
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:00<00:00, 71.01it/s]

[2025-05-15 16:20:00] INFO (torcheeg/MainThread) 
[Test] test_loss: 0.688 test_accuracy: 0.712 



Testing DataLoader 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7/7 [00:00<00:00, 67.65it/s]
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
       Test metric             DataLoader 0
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
      test_accuracy         0.7124999761581421
        test_loss           0.6878145337104797
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚

In [None]:
import torch
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from swin_CAM import SwinTransformer  
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Assumes the target layer outputs (batch, 36, 1536), i.e. 6×6 patches with 1536-dim features
# def reshape_transform(tensor, time=4, height=3, width=3):
#     # Input shape: (B, 36, 1536)
#     B, seq_len, C = tensor.size()

#     result = tensor.reshape(B, time, height, width, C)
#     # result = torch.mean(result, dim=1)
#     result = result.permute(0, 4, 1, 2, 3)  # (B, C, H, W)
#     return result
def reshape_transform(tensor, time=16, height=3, width=3):
    # ËæìÂÖ•ÂΩ¢Áä∂: (B, 96, 16, 3, 3Ôºâ
    print(tensor.shape) 
    # B, seq_len, C = tensor.size()

    # result = tensor.reshape(B, time, height, width, C)
    # result = torch.mean(result, dim=1)
    # result = result.permute(0, 4, 1, 2, 3)  # (B, C, H, W)
    return tensor

# Load the model with the same hyper-parameters used for training above.
model = SwinTransformer(patch_size=(8, 3, 3),
                        depths=(2, 6, 4),
                        num_heads=(3, 6, 8),
                        window_size=(3, 3, 3))

# `test_loader` comes from fold index 0 of the training cell, whose weights are
# saved as "swin_fold_1.pth". Loading another fold's checkpoint would run
# Grad-CAM with a model that may have been trained on these very samples.
model_path = "swin_fold_1.pth"  # model weights file path
# map_location="cpu" lets the checkpoint load on machines without CUDA; the
# model is moved to the GPU below when one is available.
load_report = model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
# strict=False silently skips mismatched keys; report them so a partially loaded
# model does not go unnoticed.
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"Missing keys: {load_report.missing_keys}")
    print(f"Unexpected keys: {load_report.unexpected_keys}")
model.eval()
# target_layers = [model.layers[-1].blocks[-1].norm2]  # alternative target layer
target_layers = [model.patch_embed.proj]  # replace with your actual target layer

# Use the GPU when available.
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()


def _plot_average_cam(cams, title, empty_message):
    """Plot the mean of a stack of CAM maps, or print a message if the stack is empty.

    Args:
        cams: tensor of CAM maps stacked along dim 0.
        title: figure title.
        empty_message: printed instead of plotting when ``cams`` has no entries.

    NOTE(review): plt.imshow needs a 2-D map; if GradCAM returns volumetric
    CAMs for this target layer, the mean will have extra dims — confirm the
    printed grayscale_cam shape.
    """
    if cams.size(0) == 0:
        print(empty_message)
        return
    avg_cam = cams.mean(dim=0)
    plt.imshow(avg_cam.cpu().numpy(), cmap='jet')
    plt.title(title)
    plt.axis("off")
    plt.colorbar()
    plt.show()


# Only the first batch of the (unshuffled) test loader is analysed.
eeg_input, label = next(iter(test_loader))
if use_cuda:
    eeg_input = eeg_input.cuda()
    label = label.cuda()

# Plain forward pass for accuracy only; no gradients are needed here.
with torch.no_grad():
    result = model(eeg_input)
preds = torch.argmax(result, dim=1)
correct = (preds == label).sum().item()
total = label.size(0)
acc = correct / total
print(f"Accuracy: {acc:.4f}")

cam = GradCAM(model=model,
              target_layers=target_layers,
              reshape_transform=reshape_transform)

# Optional class targets; None lets GradCAM use each sample's top-scoring class.
targets = None
# targets = [ClassifierOutputTarget(1)] * eeg_input.shape[0]  # force class 1 for every sample

grayscale_cam = cam(input_tensor=eeg_input, targets=targets)
print(grayscale_cam.shape)
cam_tensor = torch.tensor(grayscale_cam)  # CPU tensor built from GradCAM's array output
print(cam_tensor.shape)

# Masks of correct / wrong predictions. They must be on the CPU like
# `cam_tensor`: indexing a CPU tensor with a CUDA boolean mask raises.
correct_mask = (preds == label).cpu()
wrong_mask = ~correct_mask

correct_cam = cam_tensor[correct_mask]
wrong_cam = cam_tensor[wrong_mask]

_plot_average_cam(correct_cam, "Average Grad-CAM (Correct Predictions)", "No correct predictions.")
_plot_average_cam(wrong_cam, "Average Grad-CAM (Wrong Predictions)", "No wrong predictions.")

In [None]:
# Show the Grad-CAM map of every sample in the analysed batch, titled with its
# index and whether its prediction was correct.
for i, sample in enumerate(cam_tensor):
    fig, ax = plt.subplots()
    image = ax.imshow(sample.cpu().numpy(), cmap='jet')
    # bool(...) renders "True"/"False" instead of a tensor repr with device info.
    ax.set_title(f"Grad-CAM {i}, {bool(correct_mask[i])}")
    ax.axis("off")
    fig.colorbar(image, ax=ax)
    plt.show()
    # Free the figure so many per-sample plots do not accumulate in memory.
    plt.close(fig)