In [1]:
import torch
import logging
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
from torch import autograd
from sklearn import model_selection
from torch.utils.data import Dataset, Subset, DataLoader

from strokes import StrokePatientsMIDataset
from strokesdict import STROKEPATIENTSMI_LOCATION_DICT
import scipy
from torcheeg.transforms import Select,BandSignal,Compose,ToTensor
from to import ToGrid
from typing import Callable, Dict, Union, List
import numpy as np
import soxr
from downsample import SetSamplingRate
from baseline import BaselineCorrection
from torcheeg.transforms import EEGTransform, Select,BandSignal,Compose,ToTensor

# Build the stroke-patient motor-imagery dataset.
# Offline (cached) pipeline: baseline correction -> resample 500 Hz -> 128 Hz
# -> keep the 8-40 Hz band. Online pipeline: place electrodes on a 2-D grid
# and convert to a tensor.
# NOTE(review): chunk_size=500 presumably counts samples of the original
# 500 Hz recording (i.e. 1 second) -- confirm against StrokePatientsMIDataset.
dataset = StrokePatientsMIDataset(
    root_path='../../mi_swin/subdataset',
    # io_path='.torcheeg/datasets_1739252099423_RDowH',  # reuse an existing cache
    chunk_size=500,
    overlap=0,
    offline_transform=Compose([
        BaselineCorrection(),
        SetSamplingRate(origin_sampling_rate=500, target_sampling_rate=128),
        BandSignal(sampling_rate=128, band_dict={'frequency_range': [8, 40]}),
    ]),
    online_transform=Compose([
        ToGrid(STROKEPATIENTSMI_LOCATION_DICT),
        ToTensor(),
    ]),
    label_transform=Select('label'),
    num_worker=8,
)

# Load the first sample once instead of indexing the dataset twice.
first_sample = dataset[0]
print(first_sample[0].shape)  # observed torch.Size([128, 9, 9]); presumably (time samples, grid rows, grid cols)
print(first_sample[1])        # label (int)
print(len(dataset))

  from .autonotebook import tqdm as notebook_tqdm
[2025-02-11 15:26:28] INFO (torcheeg/MainThread) 🔍 | Processing EEG data. Processed EEG data has been cached to [92m.torcheeg/datasets_1739258788329_UGp5r[0m.
[2025-02-11 15:26:28] INFO (torcheeg/MainThread) ⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]: 100%|██████████| 1/1 [00:00<00:00, 89.59it/s]

[RECORD ../../mi_swin/subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 0it [00:00, ?it/s][A
[RECORD ../../mi_swin/subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 1it [00:00,  4.48it/s][A
[RECORD ../../mi_swin/subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 2it [00:00,  5.89it/s][A
[RECORD ../../mi_swin/subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.mat]: 12it [00:00, 36.75it/s][A
[RECORD ../../mi_swin/subdataset/sourcedata/sub-45/sub-45_task-motor-imagery_eeg.

torch.Size([128, 9, 9])
0
160


In [2]:
import os
import shutil

def delete_folder_if_exists(target_folder_name, parent_folder=None):
    """Delete a direct sub-folder of ``parent_folder`` if it exists.

    Only an immediate child whose name equals ``target_folder_name`` exactly
    is removed. ``os.listdir`` never yields names containing path separators,
    nor ``.``/``..``, so this cannot reach outside ``parent_folder``.

    Args:
        target_folder_name: Name of the sub-folder to delete.
        parent_folder: Folder to look in. Defaults to the current working
            directory (the previously hard-coded behaviour).

    Returns:
        True if the folder was deleted, False otherwise. Errors raised while
        deleting are reported with ``print`` and not re-raised (best effort).
    """
    if parent_folder is None:
        parent_folder = os.getcwd()

    # Exact-name membership keeps the original matching rules (direct
    # children only, exact spelling) without a manual scan loop.
    if target_folder_name not in os.listdir(parent_folder):
        return False

    folder_path = os.path.join(parent_folder, target_folder_name)
    if not os.path.isdir(folder_path):
        return False

    try:
        shutil.rmtree(folder_path)
        print(f"已删除文件夹: {folder_path}")
    except Exception as e:  # best effort: report and keep the notebook running
        print(f"删除文件夹 {folder_path} 时出错: {e}")
        return False
    return True


In [3]:
def train_test_split(dataset, test_size=0.2, random_state=520, shuffle=True):
    """Split ``dataset`` into train/test ``Subset`` views by sample index.

    The index partition is delegated to
    ``sklearn.model_selection.train_test_split``, so it is reproducible for a
    fixed ``random_state``.

    Returns:
        (train_subset, test_subset): both wrap the original ``dataset``.
    """
    all_indices = np.arange(len(dataset))
    split_kwargs = dict(test_size=test_size,
                        random_state=random_state,
                        shuffle=shuffle)
    train_idx, test_idx = model_selection.train_test_split(all_indices, **split_kwargs)
    return Subset(dataset, train_idx), Subset(dataset, test_idx)

In [4]:
# Hold out 20% of the whole dataset; the remaining part (`sub_dataset`) is
# what the classifier-training cell splits again into train/val.
sub_dataset, test_dataset = train_test_split(dataset=dataset)
print(*(len(part) for part in (sub_dataset, test_dataset)))

256 64


In [4]:
# Hyper-parameters shared by the training cells below.
RECEIVED_PARAMS = dict(
    c_lr=1e-5,          # classifier learning rate
    g_lr=1e-5,          # generator learning rate
    d_lr=1e-5,          # discriminator (critic) learning rate
    weight_gp=1.0,      # gradient-penalty weight in the critic loss
    weight_decay=5e-4,  # Adam weight decay for the generator and critic
    weight_ssl=0.5,     # NOTE(review): not referenced by the visible cells
)

In [5]:
class Generator(nn.Module):
    """U-Net-like convolutional generator that reconstructs grid-shaped input.

    Four conv encoder layers (stride 1, "same" padding, so no spatial
    down-sampling) are followed by three transposed-conv decoder layers that
    concatenate the matching encoder activations (skip connections).
    Grid positions whose input is zero across all channels are forced to zero
    in the output (presumably grid cells without an electrode).

    Args:
        in_channels: Channels of the input grid tensor.
        out_channels: Channels of the output. This argument used to be ignored
            (the last layer was hard-coded to 128); the default of 128 keeps
            the previous behaviour.
    """

    def __init__(self, in_channels=4, out_channels=128):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU())
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=5, stride=1, padding=2, bias=True),
            nn.LeakyReLU())
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=5, stride=1, padding=2, bias=True),
            nn.LeakyReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU())
        # Decoder inputs are [previous output, skip activation] concatenated
        # on the channel axis, hence the summed input channel counts.
        self.delayer1 = nn.Sequential(
            nn.ConvTranspose2d(16 + 32, 32, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU())
        self.delayer2 = nn.Sequential(
            nn.ConvTranspose2d(32 + 64, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU())
        self.delayer3 = nn.Sequential(
            nn.ConvTranspose2d(64 + 128, out_channels, kernel_size=3, stride=1, padding=1, bias=True))

    def forward(self, x):
        """Return the reconstruction, zeroed wherever ``x`` is zero in every channel."""
        # (batch, 1, H, W) mask: 1 where any channel of the input is non-zero.
        mask = (x.abs().sum(dim=1, keepdim=True) > 0).float()
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = self.delayer1(torch.cat([out, out3], dim=1))
        out = self.delayer2(torch.cat([out, out2], dim=1))
        out = self.delayer3(torch.cat([out, out1], dim=1))
        return out * mask


class ResidualConv2d(nn.Module):
    """Two 3x3 convolutions (SELU between them) plus a 1x1 projection shortcut.

    Spatial size is preserved (stride 1, "same" padding).
    NOTE(review): this block does not appear to be used by the models in this
    notebook.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        same3x3 = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **same3x3),
            nn.SELU(),
            nn.Conv2d(out_channels, out_channels, **same3x3),
        )
        # 1x1 projection so the shortcut has ``out_channels`` channels.
        self.res = nn.Conv2d(in_channels, out_channels,
                             kernel_size=1, stride=1, padding=0, bias=bias)

    def forward(self, x):
        shortcut = self.res(x)
        return self.conv(x) + shortcut

# Multi-scale block: EEG patterns span different spatial extents on the
# electrode grid, so three kernel sizes (5x5, 3x3, 1x1) are applied in
# parallel and their outputs summed. (The original comment motivated this
# with emotion recognition; here it is reused for motor imagery.)
class InceptionConv2d(nn.Module):
    """Sum of parallel 5x5, 3x3 and 1x1 convolutions with "same" padding."""

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        # (attribute name, kernel size); padding k // 2 keeps H and W unchanged.
        for name, k in (('conv5x5', 5), ('conv3x3', 3), ('conv1x1', 1)):
            setattr(self, name, nn.Conv2d(in_channels, out_channels,
                                          kernel_size=k, stride=1,
                                          padding=k // 2, bias=bias))

    def forward(self, x):
        wide = self.conv5x5(x)
        mid = self.conv3x3(x)
        narrow = self.conv1x1(x)
        return wide + mid + narrow


class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A depthwise ``kernel_size`` conv (one filter per input channel,
    ``groups=in_channels``) followed by a 1x1 pointwise conv that mixes
    channels. ``stride`` is applied once, in the depthwise step. The pointwise
    step always uses stride 1; it previously reused ``stride``, which
    down-sampled by ``stride**2`` whenever ``stride > 1``. With the default
    ``stride=1`` the output is unchanged.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True):
        super().__init__()
        # Depthwise: each input channel is convolved spatially on its own.
        # NOTE(review): with the grid input used here the channel axis
        # presumably holds time samples, so each time slice is filtered over
        # the electrode grid.
        self.depth = nn.Conv2d(in_channels,
                               in_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               groups=in_channels,
                               bias=bias)
        # Pointwise: 1x1 conv mixing information across channels at each
        # spatial position. Stride 1 so the spatial size is set by ``depth`` only.
        self.point = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=bias)

    def forward(self, x):
        return self.point(self.depth(x))


class Discriminator(nn.Module):
    """Critic (WGAN-GP style, as used by the trainer) over grid-shaped EEG.

    Five convolutional stages that all preserve spatial size (stride 1,
    "same" padding), each followed by SELU, then two fully-connected layers
    producing ``num_classes`` scores per sample.

    Args:
        num_classes: Number of output scores per sample.
        in_channels: Channels of the input grid tensor.
        grid_size: (height, width) of the electrode grid, used to size
            ``fc1``. Defaults to (9, 9), the previously hard-coded value.
    """

    def __init__(self, num_classes, in_channels=4, grid_size=(9, 9)):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.layer2 = nn.Conv2d(256, 128, kernel_size=5, stride=1, padding=2, bias=True)
        self.layer3 = nn.Conv2d(128, 64, kernel_size=5, stride=1, padding=2, bias=True)
        self.layer4 = SeparableConv2d(64, 32, kernel_size=5, stride=1, padding=2, bias=True)
        self.layer5 = InceptionConv2d(32, 16)

        # Despite its name this is only an activation: no dropout is applied.
        self.drop = nn.Sequential(nn.SELU())
        grid_h, grid_w = grid_size
        # Spatial size is preserved by every conv stage, so the flattened
        # feature count is 16 channels * grid height * grid width.
        self.fc1 = nn.Sequential(nn.Linear(grid_h * grid_w * 16, 1024, bias=True),
                                 nn.SELU())
        self.fc2 = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        out = x
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            out = self.drop(layer(out))
        out = out.flatten(1)  # (batch_size, num_features)
        out = self.fc1(out)
        return self.fc2(out)


In [5]:
from model import SwinTransformerGenerator
import torch
# Instantiate the generator model
# Swin-Transformer generator from the project-local ``model`` module.
# NOTE(review): ``g_model`` is reassigned to the convolutional ``Generator``
# in the smoke-test cell that follows, so this instance is discarded there.
g_model = SwinTransformerGenerator(
    in_chans=128,
    embed_dim=96,
    depths=(2, 2, 4, 2),
    num_heads=(2, 2, 4, 6),
    patch_size=2,
    window_size=3,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
# Smoke test: push a random batch through generator -> critic and check shapes.
g_model = Generator(in_channels=128, out_channels=128)
d_model = Discriminator(in_channels=128, num_classes=2)

input_tensor = torch.randn(2, 128, 9, 9)  # (batch, channels, grid H, grid W)
with torch.no_grad():  # shape check only; no autograd graph needed
    output = g_model(input_tensor)
    ans = d_model(output)

# Previously printed input_tensor.shape under the "Output shape" label.
print("Output shape:", output.shape)
print("ans shape:", ans.shape)
assert output.shape == input_tensor.shape
assert ans.shape == (input_tensor.shape[0], 2)


Output shape: torch.Size([2, 128, 9, 9])
ans shape: torch.Size([2, 2])


In [6]:
def random_mask(data, min_r=0.0, max_r=0.5):
    """Randomly zero out whole channels of each sample.

    For every sample a masking ratio ``r`` is drawn uniformly from
    ``[min_r, max_r)``. Each channel (dim 1) is then zeroed with probability
    ``r``, across all of its trailing dims.

    Args:
        data: Tensor shaped ``(batch, channels, *rest)``; any number of
            trailing dims (including none) is supported. Not modified.
        min_r: Lower bound of the per-sample masking ratio.
        max_r: Upper bound of the per-sample masking ratio.

    Returns:
        (masked, ratio): a masked copy of ``data`` and the per-sample ratio,
        shaped ``(batch, 1, ..., 1)`` with the same rank as ``data``. For 4-D
        grid input this is ``(batch, 1, 1, 1)`` as before.
    """
    data = data.clone()
    # One uniform score per (sample, channel), broadcast over trailing dims.
    trailing_ones = (1,) * (data.dim() - 2)
    scores = torch.rand(*data.shape[:2], *trailing_ones, device=data.device)
    # Per-sample ratio with the same rank as ``data``. It used to be
    # hard-coded to 4-D, which broke broadcasting for 2-D/3-D inputs.
    ratio = torch.rand(data.shape[0], *((1,) * (data.dim() - 1)),
                       device=scores.device) * (max_r - min_r) + min_r
    # Channels whose score falls below the sample's ratio are zeroed.
    mask = (scores < ratio).expand_as(data)
    data[mask] = 0.0
    return data, ratio


def gradient_penalty(model, real, fake):
    """WGAN-GP gradient penalty.

    Samples points on straight lines between paired ``real`` and ``fake``
    samples, evaluates ``model`` there, and penalises the deviation of the
    per-sample input-gradient L2 norm from 1.

    Args:
        model: Critic mapping a batch to scores. If it returns several scores
            per sample, the gradient of their sum is used.
        real: Tensor ``(batch, ...)``; treated as a constant (detached).
        fake: Tensor with the same shape as ``real``; also detached.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1)^2)``. The graph is kept
        (``create_graph=True``) so it can be back-propagated into ``model``.
    """
    real = real.detach()
    fake = fake.detach()
    # One interpolation coefficient per sample, broadcast over the remaining
    # dims, created directly on the target device/dtype (no host->device copy).
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)),
                       device=real.device, dtype=real.dtype)
    inputs = (alpha * real + (1 - alpha) * fake).requires_grad_()
    outputs = model(inputs)

    gradient, = autograd.grad(outputs=outputs,
                              inputs=inputs,
                              grad_outputs=torch.ones_like(outputs),
                              create_graph=True,
                              retain_graph=True)
    per_sample_norm = gradient.flatten(1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()


class Trainer():
    """Adversarial pre-training loop for a mask-filling generator and a critic.

    On each batch the input is corrupted with ``augment_fn`` (``random_mask``
    by default) and the generator reconstructs it. The critic (``d_model``)
    scores real vs. reconstructed samples with a gradient penalty weighted by
    ``RECEIVED_PARAMS['weight_gp']``. Both models are moved to CUDA.
    """

    def __init__(self, g_model, d_model, trainer_kwargs=None, d_update_interval=5):
        """
        Args:
            g_model: Generator; called on the augmented batch.
            d_model: Critic; called on real and generated batches.
            trainer_kwargs: Dict with at least ``'max_epochs'``. Defaults to
                ``{'max_epochs': 10}``, built fresh per instance instead of
                being a shared mutable default argument.
            d_update_interval: The critic is back-propagated and stepped only
                on batches whose index is a multiple of this. The default of 5
                reproduces the previous hard-coded behaviour.
        """
        super().__init__()
        if trainer_kwargs is None:
            trainer_kwargs = {'max_epochs': 10}
        self.g_model = g_model.cuda()
        self.d_model = d_model.cuda()

        self._loss_fn_ce = nn.CrossEntropyLoss()
        self._loss_fn_mse = nn.MSELoss()
        self._optimizer_g_model = torch.optim.Adam(
            g_model.parameters(),
            lr=RECEIVED_PARAMS['g_lr'],
            weight_decay=RECEIVED_PARAMS['weight_decay'])
        self._optimizer_d_model = torch.optim.Adam(
            d_model.parameters(),
            lr=RECEIVED_PARAMS['d_lr'],
            weight_decay=RECEIVED_PARAMS['weight_decay'])

        self._trainer_kwargs = trainer_kwargs
        self._d_update_interval = d_update_interval

        # NOTE(review): splits the notebook-level ``dataset`` with the default
        # ``train_test_split`` arguments; the held-out part is unused here.
        train_dataset = train_test_split(dataset)[0]
        self._train_dataloader = DataLoader(train_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            drop_last=False)

    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
        """Fraction of rows of ``input`` whose argmax along dim 1 equals ``target``."""
        _, predict = torch.max(input.data, 1)
        correct = predict.eq(target.data).cpu().sum().item()
        return correct / input.size(0)

    def training_step_g_model(self, batch, batch_idx, augment_fn=random_mask):
        """One generator update: maximise the (frozen) critic's score of reconstructions.

        Args:
            batch: ``(x, y)`` from the dataloader; only ``x`` is used.
            batch_idx: Unused; kept for a uniform step signature.
            augment_fn: Corruption applied to ``x`` before reconstruction;
                must return ``(augmented, extra)``. This argument is now
                honoured; it used to be ignored in favour of ``random_mask``.

        Returns:
            The generator loss ``-mean(critic(G(augment_fn(x))))``.
        """
        self._optimizer_g_model.zero_grad()

        # Freeze the critic so this backward pass only yields generator grads.
        for p in self.d_model.parameters():
            p.requires_grad = False

        x = batch[0].cuda()
        aug_x, _ratio = augment_fn(x)
        pred_x = self.g_model(aug_x)
        loss = -self.d_model(pred_x).mean()

        loss.backward()
        self._optimizer_g_model.step()
        return loss

    def training_step_d_model(self, batch, batch_idx, augment_fn=random_mask):
        """Critic loss on real vs. reconstructed samples, plus gradient penalty.

        The loss is computed on every batch (so it can be logged). The critic
        is back-propagated and stepped only when ``batch_idx`` is a multiple
        of ``d_update_interval``.
        NOTE(review): this updates the critic less often than the generator
        (which steps every batch); WGAN-GP usually does the opposite. Kept for
        compatibility -- confirm it is intended.

        Args:
            batch: ``(x, y)`` from the dataloader; only ``x`` is used.
            batch_idx: Index of the batch within the epoch.
            augment_fn: Corruption applied to ``x`` before reconstruction.

        Returns:
            The critic loss ``mean(D(fake)) - mean(D(real)) + weight_gp * GP``.
        """
        self._optimizer_d_model.zero_grad()

        # Re-enable critic gradients (the generator step froze them).
        for p in self.d_model.parameters():
            p.requires_grad = True

        x = batch[0].cuda()
        aug_x, _ratio = augment_fn(x)
        # The generator is not updated here, so skip building its graph.
        with torch.no_grad():
            pred_x = self.g_model(aug_x)

        loss = self.d_model(pred_x).mean() - self.d_model(x).mean()
        loss += RECEIVED_PARAMS['weight_gp'] * gradient_penalty(
            self.d_model, x, pred_x)

        if batch_idx % self._d_update_interval == 0:
            loss.backward()
            self._optimizer_d_model.step()

        return loss

    def _train(self, epoch_idx=-1):
        """Run one epoch, showing a progress bar with the latest G and D losses."""
        # Context manager closes the bar even if a step raises.
        with tqdm(total=len(self._train_dataloader),
                  desc=f"[TRAIN] Epoch {epoch_idx}") as pbar:
            for i, batch in enumerate(self._train_dataloader):
                # Critic first, then generator, on the same batch.
                loss_d_model = self.training_step_d_model(batch, i)
                loss_g_model = self.training_step_g_model(batch, i)

                pbar.update(1)
                pbar.set_postfix(
                    ordered_dict={
                        'loss_g_model': f'{loss_g_model.item():.3f}',
                        'loss_d_model': f'{loss_d_model.item():.3f}'
                    }
                )

    def fit(self) -> None:
        """Train for ``trainer_kwargs['max_epochs']`` epochs (epochs are numbered from 1)."""
        for i in range(self._trainer_kwargs['max_epochs']):
            self._train(i + 1)

    def save(self, param_path):
        """Save both models' state dicts to ``param_path`` under keys 'g_model' / 'd_model'."""
        torch.save(
            {
                'g_model': self.g_model.state_dict(),
                'd_model': self.d_model.state_dict()
            }, param_path)



In [8]:
class Classifier(nn.Module):
    """Convolutional classifier over grid-shaped EEG that also returns features.

    Five size-preserving convolutional stages, each followed by
    ``Dropout`` (default p=0.5) and SELU, then ``fc1`` (1024-d features) and
    ``fc2`` (logits).

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input grid tensor.
        grid_size: (height, width) of the electrode grid, used to size
            ``fc1``. Defaults to (9, 9), the previously hard-coded value.
    """

    def __init__(self, num_classes, in_channels=4, grid_size=(9, 9)):
        super(Classifier, self).__init__()
        self.layer1 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.layer2 = nn.Conv2d(256, 128, kernel_size=5, stride=1, padding=2, bias=True)
        self.layer3 = nn.Conv2d(128, 64, kernel_size=5, stride=1, padding=2, bias=True)
        self.layer4 = SeparableConv2d(64, 32, kernel_size=5, stride=1, padding=2, bias=True)
        self.layer5 = InceptionConv2d(32, 16)
        self.drop = nn.Sequential(nn.Dropout(), nn.SELU())
        grid_h, grid_w = grid_size
        # Spatial size is preserved by every conv stage, so the flattened
        # feature count is 16 channels * grid height * grid width.
        self.fc1 = nn.Sequential(nn.Linear(grid_h * grid_w * 16, 1024, bias=True),
                                 nn.SELU())
        self.fc2 = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        """Return ``(logits, features)``; features are the 1024-d ``fc1`` output."""
        out = x
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            out = self.drop(layer(out))
        out = out.flatten(1)
        feat = self.fc1(out)
        out = self.fc2(feat)
        return out, feat


c_model = Classifier(num_classes=2, in_channels=128)


In [7]:
from model import SwinTransformer_D
# Use the project's Swin-Transformer classifier; this replaces the CNN
# ``c_model`` built in the cell above.
c_model = SwinTransformer_D(
    num_classes=2,
    in_chans=128,
    embed_dim=96,
    num_heads=(2, 2, 4, 6),
    depths=(2, 2, 4, 2),
    visual_mode=True,
)

In [30]:
from classifier import  ClassifierTrainer

# Train the classifier on a train/val split of ``sub_dataset``; the held-out
# ``test_dataset`` is not used here.
train_dataset, val_dataset = train_test_split(sub_dataset)
trainer = ClassifierTrainer(model=c_model,
                            num_classes=2,
                            lr=RECEIVED_PARAMS['c_lr'],
                            weight_decay=1e-5,  # NOTE(review): differs from RECEIVED_PARAMS['weight_decay']
                            metrics=["accuracy"],
                            accelerator="gpu")
train_dataloader = DataLoader(train_dataset,
                              batch_size=16,
                              shuffle=True,
                              drop_last=False)
val_dataloader = DataLoader(val_dataset,
                            batch_size=16,
                            shuffle=False,
                            drop_last=False)

# limit_val_batches=0.0 disables validation during fit; the val split is
# evaluated separately with trainer.test in the next cell.
trainer.fit(train_dataloader,
            val_dataloader,
            max_epochs=30,
            enable_model_summary=False,
            limit_val_batches=0.0)

# Create the output folder first; file writers generally do not create
# missing parent directories.
save_dir = './parameters/'
Path(save_dir).mkdir(parents=True, exist_ok=True)
trainer.save(save_dir + 'cross_validation_backbone.pth')

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Epoch 0: 100%|██████████| 13/13 [00:00<00:00, 24.67it/s, loss=0.79, train_loss=0.740, train_accuracy=0.333] 

  rank_zero_warn(
  rank_zero_warn(
[2025-02-11 14:33:21] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.790 train_accuracy: 0.500 



Epoch 1: 100%|██████████| 13/13 [00:00<00:00, 24.68it/s, loss=0.673, train_loss=0.619, train_accuracy=0.667]

[2025-02-11 14:33:22] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.646 train_accuracy: 0.598 



Epoch 2: 100%|██████████| 13/13 [00:00<00:00, 25.02it/s, loss=0.607, train_loss=0.599, train_accuracy=0.667]

[2025-02-11 14:33:23] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.585 train_accuracy: 0.735 



Epoch 3: 100%|██████████| 13/13 [00:00<00:00, 24.94it/s, loss=0.536, train_loss=0.559, train_accuracy=0.750]

[2025-02-11 14:33:24] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.521 train_accuracy: 0.779 



Epoch 4: 100%|██████████| 13/13 [00:00<00:00, 25.00it/s, loss=0.47, train_loss=0.525, train_accuracy=0.583] 

[2025-02-11 14:33:25] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.454 train_accuracy: 0.838 



Epoch 5: 100%|██████████| 13/13 [00:00<00:00, 25.36it/s, loss=0.417, train_loss=0.252, train_accuracy=0.917]

[2025-02-11 14:33:26] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.387 train_accuracy: 0.814 



Epoch 6: 100%|██████████| 13/13 [00:00<00:00, 25.51it/s, loss=0.327, train_loss=0.548, train_accuracy=0.667]

[2025-02-11 14:33:27] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.302 train_accuracy: 0.882 



Epoch 7: 100%|██████████| 13/13 [00:00<00:00, 25.63it/s, loss=0.238, train_loss=0.450, train_accuracy=0.833]

[2025-02-11 14:33:28] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.213 train_accuracy: 0.926 



Epoch 8: 100%|██████████| 13/13 [00:00<00:00, 25.69it/s, loss=0.179, train_loss=0.250, train_accuracy=0.833] 

[2025-02-11 14:33:29] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.165 train_accuracy: 0.941 



Epoch 9: 100%|██████████| 13/13 [00:00<00:00, 25.71it/s, loss=0.132, train_loss=0.171, train_accuracy=0.917] 

[2025-02-11 14:33:30] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.107 train_accuracy: 0.975 



Epoch 10: 100%|██████████| 13/13 [00:00<00:00, 25.73it/s, loss=0.0849, train_loss=0.114, train_accuracy=0.917] 

[2025-02-11 14:33:31] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.068 train_accuracy: 0.995 



Epoch 11: 100%|██████████| 13/13 [00:00<00:00, 24.16it/s, loss=0.0518, train_loss=0.0488, train_accuracy=1.000]

[2025-02-11 14:33:32] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.035 train_accuracy: 1.000 



Epoch 12: 100%|██████████| 13/13 [00:00<00:00, 24.53it/s, loss=0.0279, train_loss=0.0094, train_accuracy=1.000]

[2025-02-11 14:33:33] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.025 train_accuracy: 1.000 



Epoch 13: 100%|██████████| 13/13 [00:00<00:00, 24.66it/s, loss=0.0212, train_loss=0.0164, train_accuracy=1.000]

[2025-02-11 14:33:34] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.021 train_accuracy: 0.995 



Epoch 14: 100%|██████████| 13/13 [00:00<00:00, 24.63it/s, loss=0.0205, train_loss=0.0284, train_accuracy=1.000] 

[2025-02-11 14:33:35] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.017 train_accuracy: 1.000 



Epoch 15: 100%|██████████| 13/13 [00:00<00:00, 24.97it/s, loss=0.0121, train_loss=0.00458, train_accuracy=1.000]

[2025-02-11 14:33:36] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.010 train_accuracy: 1.000 



Epoch 16: 100%|██████████| 13/13 [00:00<00:00, 25.15it/s, loss=0.00865, train_loss=0.00407, train_accuracy=1.000]

[2025-02-11 14:33:37] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.007 train_accuracy: 1.000 



Epoch 17: 100%|██████████| 13/13 [00:00<00:00, 25.49it/s, loss=0.00718, train_loss=0.022, train_accuracy=1.000]  

[2025-02-11 14:33:38] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.007 train_accuracy: 1.000 



Epoch 18: 100%|██████████| 13/13 [00:00<00:00, 23.68it/s, loss=0.00802, train_loss=0.00277, train_accuracy=1.000]

[2025-02-11 14:33:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.008 train_accuracy: 1.000 



Epoch 19: 100%|██████████| 13/13 [00:00<00:00, 24.70it/s, loss=0.0065, train_loss=0.0025, train_accuracy=1.000]  

[2025-02-11 14:33:39] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.006 train_accuracy: 1.000 



Epoch 20: 100%|██████████| 13/13 [00:00<00:00, 24.63it/s, loss=0.00497, train_loss=0.00675, train_accuracy=1.000] 

[2025-02-11 14:33:40] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.005 train_accuracy: 1.000 



Epoch 21: 100%|██████████| 13/13 [00:00<00:00, 25.34it/s, loss=0.00494, train_loss=0.00439, train_accuracy=1.000]

[2025-02-11 14:33:41] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.005 train_accuracy: 1.000 



Epoch 22: 100%|██████████| 13/13 [00:00<00:00, 25.13it/s, loss=0.00409, train_loss=0.00302, train_accuracy=1.000]

[2025-02-11 14:33:42] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.004 train_accuracy: 1.000 



Epoch 23: 100%|██████████| 13/13 [00:00<00:00, 25.44it/s, loss=0.00437, train_loss=0.00228, train_accuracy=1.000] 

[2025-02-11 14:33:43] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.004 train_accuracy: 1.000 



Epoch 24: 100%|██████████| 13/13 [00:00<00:00, 24.99it/s, loss=0.0046, train_loss=0.00303, train_accuracy=1.000] 

[2025-02-11 14:33:44] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.005 train_accuracy: 1.000 



Epoch 25: 100%|██████████| 13/13 [00:00<00:00, 25.16it/s, loss=0.00483, train_loss=0.000775, train_accuracy=1.000]

[2025-02-11 14:33:45] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.004 train_accuracy: 1.000 



Epoch 26: 100%|██████████| 13/13 [00:00<00:00, 24.72it/s, loss=0.00274, train_loss=0.00158, train_accuracy=1.000] 

[2025-02-11 14:33:46] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.003 train_accuracy: 1.000 



Epoch 27: 100%|██████████| 13/13 [00:00<00:00, 25.42it/s, loss=0.00278, train_loss=0.00537, train_accuracy=1.000] 

[2025-02-11 14:33:47] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.003 train_accuracy: 1.000 



Epoch 28: 100%|██████████| 13/13 [00:00<00:00, 24.74it/s, loss=0.00206, train_loss=0.000387, train_accuracy=1.000]

[2025-02-11 14:33:48] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.001 train_accuracy: 1.000 



Epoch 29: 100%|██████████| 13/13 [00:00<00:00, 24.91it/s, loss=0.0012, train_loss=0.00254, train_accuracy=1.000]  

[2025-02-11 14:33:49] INFO (torcheeg/MainThread) 
[Train] train_loss: 0.001 train_accuracy: 1.000 

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 13/13 [00:00<00:00, 13.42it/s, loss=0.0012, train_loss=0.00254, train_accuracy=1.000]


In [31]:
# Evaluate the trained classifier on the validation split.
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, drop_last=False)

test_results = trainer.test(val_dataloader,
                            enable_progress_bar=True,
                            enable_model_summary=True)
test_result = test_results[0]
print(test_result["test_accuracy"])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 94.00it/s] 

[2025-02-11 14:33:53] INFO (torcheeg/MainThread) 
[Test] test_loss: 2.402 test_accuracy: 0.577 



Testing DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 83.63it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.5769230723381042
        test_loss           2.4023923873901367
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
0.5769230723381042


In [8]:
class CTrainer():
    """Supervised trainer for the classifier backbone ``c_model``.

    The notebook-level ``dataset`` is split with ``train_test_split`` into a
    shuffled training loader and an ordered validation loader (batch size 16).
    ``fit`` trains with cross-entropy + Adam and prints the validation accuracy
    after every epoch.

    Args:
        c_model: network whose forward returns a ``(logits, features)`` pair;
            it is moved to CUDA.
        trainer_kwargs: options dict; only ``'max_epochs'`` is read.
    """

    def __init__(self, c_model, trainer_kwargs={'max_epochs': 10}):
        super().__init__()
        self.c_model = c_model.cuda()

        self._loss_fn_ce = nn.CrossEntropyLoss()
        # Learning rate comes from the notebook-level RECEIVED_PARAMS dict.
        self._optimizer_c_model = torch.optim.Adam(c_model.parameters(),
                                                   lr=RECEIVED_PARAMS['c_lr'],
                                                   weight_decay=0.0005)
        # Only read, never mutated, so sharing the default dict is harmless.
        self._trainer_kwargs = trainer_kwargs

        train_dataset, val_dataset = train_test_split(dataset)
        self._train_dataloader = DataLoader(train_dataset,
                                            batch_size=16,
                                            shuffle=True,
                                            drop_last=False)
        self._val_dataloader = DataLoader(val_dataset,
                                          batch_size=16,
                                          shuffle=False,
                                          drop_last=False)

    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
        """Fraction of rows of ``input`` whose argmax over dim 1 equals ``target``."""
        _, predict = torch.max(input.data, 1)
        correct = predict.eq(target.data).cpu().sum().item()
        return correct / input.size(0)

    def training_step_c_model(self, batch, batch_idx):
        """Run one optimisation step on ``batch = (x, y)``.

        Returns:
            The cross-entropy loss tensor of this batch.
        """
        # Make sure every classifier parameter is trainable.
        for p in self.c_model.parameters():
            p.requires_grad = True

        self._optimizer_c_model.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_hat, _ = self.c_model(x)
        loss = self._loss_fn_ce(y_hat, y)

        loss.backward()
        self._optimizer_c_model.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """Forward one validation batch; returns ``(logits, labels)`` on the CPU."""
        x, y = self.validation_step_before_model(batch, batch_idx)
        y_hat, _ = self.c_model(x)
        return (y_hat.detach().cpu(), y.detach().cpu())

    def validation_step_before_model(self, batch, batch_idx):
        """Move an ``(x, y)`` batch to CUDA."""
        x, y = batch
        x, y = x.cuda(), y.cuda()
        return x, y

    def validation_epoch_end(self, outputs):
        """Concatenate per-batch ``(logits, labels)`` pairs, print and return accuracy.

        Returns:
            dict: ``{'val_acc': float}``.
        """
        y_hat, y = zip(*outputs)
        y_hat = torch.cat(y_hat, dim=0)
        y = torch.cat(y, dim=0)
        avg_acc = self._accuracy(y_hat, y)
        print("acc:", avg_acc)
        return {'val_acc': avg_acc}

    def _validate(self, epoch_idx=-1):
        """Evaluate on the whole validation loader.

        The model is switched to eval mode and autograd is disabled during the
        pass. The previous train/eval mode is restored afterwards, even on error.

        Returns:
            dict: ``{'val_acc': float}`` from ``validation_epoch_end``.
        """
        was_training = self.c_model.training
        self.c_model.eval()
        validation_outputs = []
        try:
            with torch.no_grad():
                for i, batch in enumerate(self._val_dataloader):
                    validation_outputs.append(self.validation_step(batch, i))
        finally:
            self.c_model.train(was_training)
        return self.validation_epoch_end(validation_outputs)

    def _train(self, epoch_idx=-1):
        """Train for one epoch, showing a dedicated progress bar for it."""
        self.c_model.train()
        with tqdm(total=len(self._train_dataloader),
                  desc=f"[TRAIN] Epoch {epoch_idx}") as pbar:
            for i, batch in enumerate(self._train_dataloader):
                loss_c_model = self.training_step_c_model(batch, i)
                pbar.update(1)
                # Show the latest batch loss in the progress-bar suffix.
                pbar.set_postfix(ordered_dict={'loss_c_model': f'{loss_c_model.item():.3f}'})

    def fit(self) -> None:
        """Alternate one training epoch and one validation pass, ``max_epochs`` times."""
        for epoch_idx in range(self._trainer_kwargs['max_epochs']):
            self._train(epoch_idx + 1)
            self._validate(epoch_idx + 1)

    def save(self, param_path):
        """Save the classifier weights as ``{'c_model': state_dict}`` to ``param_path``."""
        torch.save({
            'c_model': self.c_model.state_dict(),
        }, param_path)


# Supervised warm-up of the classifier backbone; GCTrainer.load() later picks up
# this checkpoint if it exists.
trainer = CTrainer(c_model, trainer_kwargs=dict(max_epochs=80))
trainer.fit()
trainer.save('./parameters/' + 'cross_validation_backbone.pth')

[TRAIN] Epoch 1: 100%|██████████| 8/8 [00:00<00:00,  8.35it/s, loss_c_model=0.684]


acc: 0.625


[TRAIN] Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 27.80it/s, loss_c_model=0.523]


acc: 0.53125


[TRAIN] Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 27.93it/s, loss_c_model=0.612]


acc: 0.84375


[TRAIN] Epoch 4: 100%|██████████| 8/8 [00:00<00:00, 27.85it/s, loss_c_model=0.365]


acc: 0.84375


[TRAIN] Epoch 5: 100%|██████████| 8/8 [00:00<00:00, 28.14it/s, loss_c_model=0.331]


acc: 0.84375


[TRAIN] Epoch 6: 100%|██████████| 8/8 [00:00<00:00, 28.62it/s, loss_c_model=0.327]


acc: 0.84375


[TRAIN] Epoch 7: 100%|██████████| 8/8 [00:00<00:00, 28.81it/s, loss_c_model=0.244]


acc: 0.8125


[TRAIN] Epoch 8: 100%|██████████| 8/8 [00:00<00:00, 28.16it/s, loss_c_model=0.089]


acc: 0.84375


[TRAIN] Epoch 9: 100%|██████████| 8/8 [00:00<00:00, 28.83it/s, loss_c_model=0.186]


acc: 0.78125


[TRAIN] Epoch 10: 100%|██████████| 8/8 [00:00<00:00, 28.22it/s, loss_c_model=0.064]


acc: 0.8125


[TRAIN] Epoch 11: 100%|██████████| 8/8 [00:00<00:00, 28.61it/s, loss_c_model=0.038]


acc: 0.78125


[TRAIN] Epoch 12: 100%|██████████| 8/8 [00:00<00:00, 28.01it/s, loss_c_model=0.015]


acc: 0.8125


[TRAIN] Epoch 13: 100%|██████████| 8/8 [00:00<00:00, 28.27it/s, loss_c_model=0.004]


acc: 0.8125


[TRAIN] Epoch 14: 100%|██████████| 8/8 [00:00<00:00, 27.33it/s, loss_c_model=0.008]


acc: 0.8125


[TRAIN] Epoch 15: 100%|██████████| 8/8 [00:00<00:00, 28.44it/s, loss_c_model=0.005]


acc: 0.8125


[TRAIN] Epoch 16: 100%|██████████| 8/8 [00:00<00:00, 28.77it/s, loss_c_model=0.007]


acc: 0.8125


[TRAIN] Epoch 17: 100%|██████████| 8/8 [00:00<00:00, 28.37it/s, loss_c_model=0.009]


acc: 0.8125


[TRAIN] Epoch 18: 100%|██████████| 8/8 [00:00<00:00, 28.36it/s, loss_c_model=0.004]


acc: 0.8125


[TRAIN] Epoch 19: 100%|██████████| 8/8 [00:00<00:00, 27.39it/s, loss_c_model=0.007]


acc: 0.8125


[TRAIN] Epoch 20: 100%|██████████| 8/8 [00:00<00:00, 28.37it/s, loss_c_model=0.002]


acc: 0.8125


[TRAIN] Epoch 21: 100%|██████████| 8/8 [00:00<00:00, 28.51it/s, loss_c_model=0.002]


acc: 0.8125


[TRAIN] Epoch 22: 100%|██████████| 8/8 [00:00<00:00, 28.51it/s, loss_c_model=0.010]


acc: 0.8125


[TRAIN] Epoch 23: 100%|██████████| 8/8 [00:00<00:00, 28.77it/s, loss_c_model=0.003]


acc: 0.8125


[TRAIN] Epoch 24: 100%|██████████| 8/8 [00:00<00:00, 28.56it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 25: 100%|██████████| 8/8 [00:00<00:00, 28.82it/s, loss_c_model=0.003]


acc: 0.8125


[TRAIN] Epoch 26: 100%|██████████| 8/8 [00:00<00:00, 28.79it/s, loss_c_model=0.002]


acc: 0.8125


[TRAIN] Epoch 27: 100%|██████████| 8/8 [00:00<00:00, 28.70it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 28: 100%|██████████| 8/8 [00:00<00:00, 28.77it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 29: 100%|██████████| 8/8 [00:00<00:00, 28.93it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 30: 100%|██████████| 8/8 [00:00<00:00, 28.72it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 31: 100%|██████████| 8/8 [00:00<00:00, 28.96it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 32: 100%|██████████| 8/8 [00:00<00:00, 29.09it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 33: 100%|██████████| 8/8 [00:00<00:00, 28.89it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 34: 100%|██████████| 8/8 [00:00<00:00, 28.63it/s, loss_c_model=0.002]


acc: 0.8125


[TRAIN] Epoch 35: 100%|██████████| 8/8 [00:00<00:00, 28.97it/s, loss_c_model=0.003]


acc: 0.8125


[TRAIN] Epoch 36: 100%|██████████| 8/8 [00:00<00:00, 28.98it/s, loss_c_model=0.003]


acc: 0.8125


[TRAIN] Epoch 37: 100%|██████████| 8/8 [00:00<00:00, 28.55it/s, loss_c_model=0.003]


acc: 0.8125


[TRAIN] Epoch 38: 100%|██████████| 8/8 [00:00<00:00, 28.98it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 39: 100%|██████████| 8/8 [00:00<00:00, 29.00it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 40: 100%|██████████| 8/8 [00:00<00:00, 27.79it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 41: 100%|██████████| 8/8 [00:00<00:00, 28.98it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 42: 100%|██████████| 8/8 [00:00<00:00, 29.02it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 43: 100%|██████████| 8/8 [00:00<00:00, 29.01it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 44: 100%|██████████| 8/8 [00:00<00:00, 28.88it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 45: 100%|██████████| 8/8 [00:00<00:00, 29.00it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 46: 100%|██████████| 8/8 [00:00<00:00, 28.72it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 47: 100%|██████████| 8/8 [00:00<00:00, 28.93it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 48: 100%|██████████| 8/8 [00:00<00:00, 29.03it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 49: 100%|██████████| 8/8 [00:00<00:00, 29.05it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 50: 100%|██████████| 8/8 [00:00<00:00, 29.07it/s, loss_c_model=0.002]


acc: 0.8125


[TRAIN] Epoch 51: 100%|██████████| 8/8 [00:00<00:00, 28.78it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 52: 100%|██████████| 8/8 [00:00<00:00, 28.90it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 53: 100%|██████████| 8/8 [00:00<00:00, 28.10it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 54: 100%|██████████| 8/8 [00:00<00:00, 28.96it/s, loss_c_model=0.003]


acc: 0.8125


[TRAIN] Epoch 55: 100%|██████████| 8/8 [00:00<00:00, 29.08it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 56: 100%|██████████| 8/8 [00:00<00:00, 28.95it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 57: 100%|██████████| 8/8 [00:00<00:00, 29.08it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 58: 100%|██████████| 8/8 [00:00<00:00, 28.35it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 59: 100%|██████████| 8/8 [00:00<00:00, 28.94it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 60: 100%|██████████| 8/8 [00:00<00:00, 28.73it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 61: 100%|██████████| 8/8 [00:00<00:00, 28.06it/s, loss_c_model=0.002]


acc: 0.8125


[TRAIN] Epoch 62: 100%|██████████| 8/8 [00:00<00:00, 29.01it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 63: 100%|██████████| 8/8 [00:00<00:00, 29.23it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 64: 100%|██████████| 8/8 [00:00<00:00, 28.79it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 65: 100%|██████████| 8/8 [00:00<00:00, 28.47it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 66: 100%|██████████| 8/8 [00:00<00:00, 28.79it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 67: 100%|██████████| 8/8 [00:00<00:00, 28.82it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 68: 100%|██████████| 8/8 [00:00<00:00, 27.76it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 69: 100%|██████████| 8/8 [00:00<00:00, 27.64it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 70: 100%|██████████| 8/8 [00:00<00:00, 28.34it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 71: 100%|██████████| 8/8 [00:00<00:00, 27.87it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 72: 100%|██████████| 8/8 [00:00<00:00, 28.01it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 73: 100%|██████████| 8/8 [00:00<00:00, 28.62it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 74: 100%|██████████| 8/8 [00:00<00:00, 28.77it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 75: 100%|██████████| 8/8 [00:00<00:00, 27.78it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 76: 100%|██████████| 8/8 [00:00<00:00, 28.77it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 77: 100%|██████████| 8/8 [00:00<00:00, 28.93it/s, loss_c_model=0.001]


acc: 0.8125


[TRAIN] Epoch 78: 100%|██████████| 8/8 [00:00<00:00, 28.53it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 79: 100%|██████████| 8/8 [00:00<00:00, 28.52it/s, loss_c_model=0.000]


acc: 0.8125


[TRAIN] Epoch 80: 100%|██████████| 8/8 [00:00<00:00, 28.95it/s, loss_c_model=0.000]


acc: 0.8125


In [9]:
from model import SwinTransformer, SwinTransformer_D

# Both networks share the same Swin hyper-parameters; only the classifier
# variant (SwinTransformer_D) is built with visual_mode enabled.
_swin_kwargs = dict(
    in_chans=128,
    num_classes=2,
    embed_dim=96,
    depths=(2, 2, 4, 2),
    num_heads=(2, 2, 4, 6),
)
c_model = SwinTransformer_D(**_swin_kwargs, visual_mode=True)
d_model = SwinTransformer(**_swin_kwargs)

In [11]:
# Simulated input: smoke-test the forward pass of d_model on a random batch.
import torch
# (batch=2, channels=in_chans=128, height=9, width=9); the 9x9 spatial size is
# assumed to match the ToGrid output — TODO confirm against the dataset.
input_tensor = torch.randn(2, 128, 9, 9)
with torch.no_grad():  # inference only; no need to build an autograd graph
    output = d_model(input_tensor)
# d_model was built with num_classes=2, so expect one logit pair per sample.
assert output.shape == (2, 2), output.shape

print("Input shape:", input_tensor.shape)
print(output)

Input shape: torch.Size([2, 128, 9, 9])
tensor([[ 0.0711, -0.0519],
        [-0.0337, -0.1579]], grad_fn=<AddmmBackward0>)


In [15]:
import os
import torch.nn.functional as F

class GCTrainer():
    """Fine-tunes ``c_model`` with a generator-based feature-consistency term.

    On top of cross-entropy, each training step adds an MSE between the
    features of the clean batch and those of a masked batch (``random_mask``)
    reconstructed by ``g_model``. The term is weighted by
    ``RECEIVED_PARAMS['weight_ssl']`` and by ``1 - ratio`` from ``random_mask``.
    Only ``c_model`` is optimised; ``g_model`` is used as a fixed augmenter
    whose weights come from ``load``.

    Args:
        c_model: classifier whose forward returns a ``(logits, features)`` pair.
        g_model: generator applied to masked inputs.
        trainer_kwargs: options dict; only ``'max_epochs'`` is read.
    """

    def __init__(self, c_model, g_model, trainer_kwargs={'max_epochs': 10}):
        super().__init__()
        self.c_model = c_model.cuda()
        self.g_model = g_model.cuda()

        self._loss_fn_ce = nn.CrossEntropyLoss()
        self._loss_fn_mse = nn.MSELoss()
        # Learning rate comes from the notebook-level RECEIVED_PARAMS dict.
        self._optimizer_c_model = torch.optim.Adam(c_model.parameters(),
                                                   lr=RECEIVED_PARAMS['c_lr'],
                                                   weight_decay=0.0005)

        # Only read, never mutated, so sharing the default dict is harmless.
        self._trainer_kwargs = trainer_kwargs

        train_dataset, val_dataset = train_test_split(dataset)
        self._train_dataloader = DataLoader(train_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            drop_last=False)
        self._val_dataloader = DataLoader(val_dataset,
                                          batch_size=64,
                                          shuffle=False,
                                          drop_last=False)

    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
        """Fraction of rows of ``input`` whose argmax over dim 1 equals ``target``."""
        _, predict = torch.max(input.data, 1)
        correct = predict.eq(target.data).cpu().sum().item()
        return correct / input.size(0)

    def training_step_c_model(self, batch, batch_idx):
        """Run one optimisation step: cross-entropy plus a weighted consistency loss.

        Returns:
            The total loss tensor of this batch.
        """
        # Make sure every classifier parameter is trainable.
        for p in self.c_model.parameters():
            p.requires_grad = True

        self._optimizer_c_model.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_hat, x_feat = self.c_model(x)
        loss = self._loss_fn_ce(y_hat, y)

        # g_model is not optimised here, so its forward needs no autograd graph
        # (same values as the former `.detach()`, without the memory cost).
        aug_x, ratio = random_mask(x)
        with torch.no_grad():
            aug_x = self.g_model(aug_x)
        _, aug_x_feat = self.c_model(aug_x)

        # Feature MSE averaged over the last dim (assumed (batch, feat_dim) —
        # TODO confirm), scaled by (1 - ratio) and then averaged.
        feat_mse = F.mse_loss(x_feat, aug_x_feat, reduction='none').mean(dim=-1)
        loss = loss + RECEIVED_PARAMS['weight_ssl'] * (
            (1 - ratio).squeeze() * feat_mse).mean()

        loss.backward()
        self._optimizer_c_model.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """Forward one validation batch; returns ``(logits, labels)`` on the CPU."""
        x, y = self.validation_step_before_model(batch, batch_idx)
        y_hat, _ = self.c_model(x)
        return (y_hat.detach().cpu(), y.detach().cpu())

    def validation_step_before_model(self, batch, batch_idx):
        """Move an ``(x, y)`` batch to CUDA."""
        x, y = batch
        x, y = x.cuda(), y.cuda()
        return x, y

    def validation_epoch_end(self, outputs):
        """Concatenate per-batch ``(logits, labels)`` pairs and compute accuracy.

        Returns:
            dict: ``{'val_acc': float}``.
        """
        y_hat, y = zip(*outputs)
        y_hat = torch.cat(y_hat, dim=0)
        y = torch.cat(y, dim=0)
        avg_acc = self._accuracy(y_hat, y)
        return {'val_acc': avg_acc}

    def _validate(self, epoch_idx=-1):
        """Evaluate on the whole validation loader.

        The classifier is switched to eval mode and autograd is disabled during
        the pass. The previous train/eval mode is restored afterwards, even on error.

        Returns:
            dict: ``{'val_acc': float}`` from ``validation_epoch_end``.
        """
        was_training = self.c_model.training
        self.c_model.eval()
        validation_outputs = []
        try:
            with torch.no_grad():
                for i, batch in enumerate(self._val_dataloader):
                    validation_outputs.append(self.validation_step(batch, i))
        finally:
            self.c_model.train(was_training)
        return self.validation_epoch_end(validation_outputs)

    def _train(self, epoch_idx=-1):
        """Train for one epoch, showing a dedicated progress bar for it."""
        # Only c_model is switched; g_model's mode is left untouched because it
        # may rely on train-mode layers (e.g. dropout) for augmentation — unverified.
        self.c_model.train()
        with tqdm(total=len(self._train_dataloader),
                  desc=f"[TRAIN] Epoch {epoch_idx}") as pbar:
            for i, batch in enumerate(self._train_dataloader):
                loss_c_model = self.training_step_c_model(batch, i)
                pbar.update(1)
                # Show the latest batch loss in the progress-bar suffix.
                pbar.set_postfix(ordered_dict={'loss_c_model': f'{loss_c_model.item():.3f}'})

    def fit(self) -> None:
        """Alternate one training epoch and one validation pass, ``max_epochs`` times."""
        for epoch_idx in range(self._trainer_kwargs['max_epochs']):
            self._train(epoch_idx + 1)
            # Run validation and report its accuracy.
            val_metrics = self._validate(epoch_idx + 1)
            val_acc = val_metrics['val_acc']
            print(f"[EPOCH {epoch_idx + 1}] Validation Accuracy: {val_acc:.3f}")

    def save(self, param_path):
        """Save this trainer's classifier weights as ``{'c_model': state_dict}``."""
        # Use self, not the notebook-global `trainer`, which may be unset or a
        # different trainer instance.
        torch.save({
            'c_model': self.c_model.state_dict(),
        }, param_path)

    def load(self,
             g_path='./parameters/cross_validation_proposed_pretrain.pth',
             c_path='./parameters/cross_validation_backbone.pth'):
        """Load pretrained weights into the generator and, optionally, the classifier.

        Args:
            g_path: checkpoint holding a ``'g_model'`` state dict (must exist).
            c_path: checkpoint holding a ``'c_model'`` state dict; skipped when
                the file does not exist.
        """
        gan_model_state_dict = torch.load(g_path)
        self.g_model.load_state_dict(gan_model_state_dict['g_model'])

        if os.path.exists(c_path):
            c_model_state_dict = torch.load(c_path)
            self.c_model.load_state_dict(c_model_state_dict['c_model'])


# Fine-tune the classifier with the generator-based consistency loss.
trainer = GCTrainer(c_model, g_model, trainer_kwargs=dict(max_epochs=100))
trainer.load()  # pretrained generator, plus backbone weights when available
trainer.fit()
trainer.save('./parameters/' + 'cross_validation_finetune.pth')

[TRAIN] Epoch 1: 100%|██████████| 2/2 [00:00<00:00, 12.38it/s, loss_c_model=0.647]


[EPOCH 1] Validation Accuracy: 0.812


[TRAIN] Epoch 2: 100%|██████████| 2/2 [00:00<00:00, 12.76it/s, loss_c_model=0.508]


[EPOCH 2] Validation Accuracy: 0.875


[TRAIN] Epoch 3: 100%|██████████| 2/2 [00:00<00:00, 12.59it/s, loss_c_model=0.396]


[EPOCH 3] Validation Accuracy: 0.875


[TRAIN] Epoch 4: 100%|██████████| 2/2 [00:00<00:00, 12.79it/s, loss_c_model=0.324]


[EPOCH 4] Validation Accuracy: 0.875


[TRAIN] Epoch 5: 100%|██████████| 2/2 [00:00<00:00, 13.01it/s, loss_c_model=0.254]


[EPOCH 5] Validation Accuracy: 0.844


[TRAIN] Epoch 6: 100%|██████████| 2/2 [00:00<00:00, 13.04it/s, loss_c_model=0.190]


[EPOCH 6] Validation Accuracy: 0.812


[TRAIN] Epoch 7: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s, loss_c_model=0.159]


[EPOCH 7] Validation Accuracy: 0.781


[TRAIN] Epoch 8: 100%|██████████| 2/2 [00:00<00:00, 13.09it/s, loss_c_model=0.140]


[EPOCH 8] Validation Accuracy: 0.812


[TRAIN] Epoch 9: 100%|██████████| 2/2 [00:00<00:00, 12.86it/s, loss_c_model=0.132]


[EPOCH 9] Validation Accuracy: 0.812


[TRAIN] Epoch 10: 100%|██████████| 2/2 [00:00<00:00, 12.81it/s, loss_c_model=0.141]


[EPOCH 10] Validation Accuracy: 0.812


[TRAIN] Epoch 11: 100%|██████████| 2/2 [00:00<00:00, 12.35it/s, loss_c_model=0.134]


[EPOCH 11] Validation Accuracy: 0.812


[TRAIN] Epoch 12: 100%|██████████| 2/2 [00:00<00:00, 12.82it/s, loss_c_model=0.112]


[EPOCH 12] Validation Accuracy: 0.781


[TRAIN] Epoch 13: 100%|██████████| 2/2 [00:00<00:00, 12.97it/s, loss_c_model=0.111]


[EPOCH 13] Validation Accuracy: 0.781


[TRAIN] Epoch 14: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s, loss_c_model=0.102]


[EPOCH 14] Validation Accuracy: 0.812


[TRAIN] Epoch 15: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s, loss_c_model=0.105]


[EPOCH 15] Validation Accuracy: 0.812


[TRAIN] Epoch 16: 100%|██████████| 2/2 [00:00<00:00, 12.56it/s, loss_c_model=0.087]


[EPOCH 16] Validation Accuracy: 0.812


[TRAIN] Epoch 17: 100%|██████████| 2/2 [00:00<00:00, 12.86it/s, loss_c_model=0.088]


[EPOCH 17] Validation Accuracy: 0.812


[TRAIN] Epoch 18: 100%|██████████| 2/2 [00:00<00:00, 13.01it/s, loss_c_model=0.079]


[EPOCH 18] Validation Accuracy: 0.781


[TRAIN] Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 13.04it/s, loss_c_model=0.073]


[EPOCH 19] Validation Accuracy: 0.812


[TRAIN] Epoch 20: 100%|██████████| 2/2 [00:00<00:00, 13.00it/s, loss_c_model=0.078]


[EPOCH 20] Validation Accuracy: 0.812


[TRAIN] Epoch 21: 100%|██████████| 2/2 [00:00<00:00, 13.00it/s, loss_c_model=0.070]


[EPOCH 21] Validation Accuracy: 0.812


[TRAIN] Epoch 22: 100%|██████████| 2/2 [00:00<00:00, 12.88it/s, loss_c_model=0.084]


[EPOCH 22] Validation Accuracy: 0.812


[TRAIN] Epoch 23: 100%|██████████| 2/2 [00:00<00:00, 12.50it/s, loss_c_model=0.066]


[EPOCH 23] Validation Accuracy: 0.812


[TRAIN] Epoch 24: 100%|██████████| 2/2 [00:00<00:00, 12.60it/s, loss_c_model=0.065]


[EPOCH 24] Validation Accuracy: 0.812


[TRAIN] Epoch 25: 100%|██████████| 2/2 [00:00<00:00, 12.17it/s, loss_c_model=0.069]


[EPOCH 25] Validation Accuracy: 0.812


[TRAIN] Epoch 26: 100%|██████████| 2/2 [00:00<00:00, 12.48it/s, loss_c_model=0.063]


[EPOCH 26] Validation Accuracy: 0.812


[TRAIN] Epoch 27: 100%|██████████| 2/2 [00:00<00:00, 13.00it/s, loss_c_model=0.062]


[EPOCH 27] Validation Accuracy: 0.812


[TRAIN] Epoch 28: 100%|██████████| 2/2 [00:00<00:00, 12.25it/s, loss_c_model=0.067]


[EPOCH 28] Validation Accuracy: 0.781


[TRAIN] Epoch 29: 100%|██████████| 2/2 [00:00<00:00, 12.75it/s, loss_c_model=0.060]


[EPOCH 29] Validation Accuracy: 0.812


[TRAIN] Epoch 30: 100%|██████████| 2/2 [00:00<00:00, 13.03it/s, loss_c_model=0.063]


[EPOCH 30] Validation Accuracy: 0.812


[TRAIN] Epoch 31: 100%|██████████| 2/2 [00:00<00:00, 13.01it/s, loss_c_model=0.061]


[EPOCH 31] Validation Accuracy: 0.812


[TRAIN] Epoch 32: 100%|██████████| 2/2 [00:00<00:00, 13.04it/s, loss_c_model=0.056]


[EPOCH 32] Validation Accuracy: 0.812


[TRAIN] Epoch 33: 100%|██████████| 2/2 [00:00<00:00, 12.98it/s, loss_c_model=0.059]


[EPOCH 33] Validation Accuracy: 0.812


[TRAIN] Epoch 34: 100%|██████████| 2/2 [00:00<00:00, 12.29it/s, loss_c_model=0.055]


[EPOCH 34] Validation Accuracy: 0.812


[TRAIN] Epoch 35: 100%|██████████| 2/2 [00:00<00:00, 12.41it/s, loss_c_model=0.057]


[EPOCH 35] Validation Accuracy: 0.812


[TRAIN] Epoch 36: 100%|██████████| 2/2 [00:00<00:00, 12.47it/s, loss_c_model=0.055]


[EPOCH 36] Validation Accuracy: 0.781


[TRAIN] Epoch 37: 100%|██████████| 2/2 [00:00<00:00, 12.96it/s, loss_c_model=0.055]


[EPOCH 37] Validation Accuracy: 0.812


[TRAIN] Epoch 38: 100%|██████████| 2/2 [00:00<00:00, 12.94it/s, loss_c_model=0.056]


[EPOCH 38] Validation Accuracy: 0.812


[TRAIN] Epoch 39: 100%|██████████| 2/2 [00:00<00:00, 12.65it/s, loss_c_model=0.050]


[EPOCH 39] Validation Accuracy: 0.812


[TRAIN] Epoch 40: 100%|██████████| 2/2 [00:00<00:00, 12.57it/s, loss_c_model=0.051]


[EPOCH 40] Validation Accuracy: 0.812


[TRAIN] Epoch 41: 100%|██████████| 2/2 [00:00<00:00, 12.86it/s, loss_c_model=0.049]


[EPOCH 41] Validation Accuracy: 0.812


[TRAIN] Epoch 42: 100%|██████████| 2/2 [00:00<00:00, 12.89it/s, loss_c_model=0.049]


[EPOCH 42] Validation Accuracy: 0.812


[TRAIN] Epoch 43: 100%|██████████| 2/2 [00:00<00:00, 12.89it/s, loss_c_model=0.048]


[EPOCH 43] Validation Accuracy: 0.812


[TRAIN] Epoch 44: 100%|██████████| 2/2 [00:00<00:00, 12.96it/s, loss_c_model=0.046]


[EPOCH 44] Validation Accuracy: 0.812


[TRAIN] Epoch 45: 100%|██████████| 2/2 [00:00<00:00, 12.35it/s, loss_c_model=0.048]


[EPOCH 45] Validation Accuracy: 0.812


[TRAIN] Epoch 46: 100%|██████████| 2/2 [00:00<00:00, 12.81it/s, loss_c_model=0.047]


[EPOCH 46] Validation Accuracy: 0.812


[TRAIN] Epoch 47: 100%|██████████| 2/2 [00:00<00:00, 12.87it/s, loss_c_model=0.049]


[EPOCH 47] Validation Accuracy: 0.781


[TRAIN] Epoch 48: 100%|██████████| 2/2 [00:00<00:00, 12.90it/s, loss_c_model=0.043]


[EPOCH 48] Validation Accuracy: 0.781


[TRAIN] Epoch 49: 100%|██████████| 2/2 [00:00<00:00, 12.57it/s, loss_c_model=0.048]


[EPOCH 49] Validation Accuracy: 0.844


[TRAIN] Epoch 50: 100%|██████████| 2/2 [00:00<00:00, 12.27it/s, loss_c_model=0.047]


[EPOCH 50] Validation Accuracy: 0.812


[TRAIN] Epoch 51: 100%|██████████| 2/2 [00:00<00:00, 12.33it/s, loss_c_model=0.044]


[EPOCH 51] Validation Accuracy: 0.812


[TRAIN] Epoch 52: 100%|██████████| 2/2 [00:00<00:00, 11.32it/s, loss_c_model=0.042]


[EPOCH 52] Validation Accuracy: 0.781


[TRAIN] Epoch 53: 100%|██████████| 2/2 [00:00<00:00, 12.35it/s, loss_c_model=0.041]


[EPOCH 53] Validation Accuracy: 0.750


[TRAIN] Epoch 54: 100%|██████████| 2/2 [00:00<00:00, 12.88it/s, loss_c_model=0.042]


[EPOCH 54] Validation Accuracy: 0.781


[TRAIN] Epoch 55: 100%|██████████| 2/2 [00:00<00:00, 12.93it/s, loss_c_model=0.043]


[EPOCH 55] Validation Accuracy: 0.812


[TRAIN] Epoch 56: 100%|██████████| 2/2 [00:00<00:00, 12.74it/s, loss_c_model=0.042]


[EPOCH 56] Validation Accuracy: 0.812


[TRAIN] Epoch 57: 100%|██████████| 2/2 [00:00<00:00, 12.81it/s, loss_c_model=0.041]


[EPOCH 57] Validation Accuracy: 0.812


[TRAIN] Epoch 58: 100%|██████████| 2/2 [00:00<00:00, 12.91it/s, loss_c_model=0.043]


[EPOCH 58] Validation Accuracy: 0.781


[TRAIN] Epoch 59: 100%|██████████| 2/2 [00:00<00:00, 12.89it/s, loss_c_model=0.043]


[EPOCH 59] Validation Accuracy: 0.781


[TRAIN] Epoch 60: 100%|██████████| 2/2 [00:00<00:00, 13.03it/s, loss_c_model=0.040]


[EPOCH 60] Validation Accuracy: 0.750


[TRAIN] Epoch 61: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s, loss_c_model=0.042]


[EPOCH 61] Validation Accuracy: 0.781


[TRAIN] Epoch 62: 100%|██████████| 2/2 [00:00<00:00, 13.04it/s, loss_c_model=0.041]


[EPOCH 62] Validation Accuracy: 0.812


[TRAIN] Epoch 63: 100%|██████████| 2/2 [00:00<00:00, 12.91it/s, loss_c_model=0.043]


[EPOCH 63] Validation Accuracy: 0.781


[TRAIN] Epoch 64: 100%|██████████| 2/2 [00:00<00:00, 12.64it/s, loss_c_model=0.038]


[EPOCH 64] Validation Accuracy: 0.812


[TRAIN] Epoch 65: 100%|██████████| 2/2 [00:00<00:00, 12.99it/s, loss_c_model=0.038]


[EPOCH 65] Validation Accuracy: 0.781


[TRAIN] Epoch 66: 100%|██████████| 2/2 [00:00<00:00, 12.97it/s, loss_c_model=0.039]


[EPOCH 66] Validation Accuracy: 0.750


[TRAIN] Epoch 67: 100%|██████████| 2/2 [00:00<00:00, 12.55it/s, loss_c_model=0.041]


[EPOCH 67] Validation Accuracy: 0.781


[TRAIN] Epoch 68: 100%|██████████| 2/2 [00:00<00:00, 12.84it/s, loss_c_model=0.041]


[EPOCH 68] Validation Accuracy: 0.750


[TRAIN] Epoch 69: 100%|██████████| 2/2 [00:00<00:00, 13.01it/s, loss_c_model=0.038]


[EPOCH 69] Validation Accuracy: 0.781


[TRAIN] Epoch 70: 100%|██████████| 2/2 [00:00<00:00, 12.93it/s, loss_c_model=0.039]


[EPOCH 70] Validation Accuracy: 0.812


[TRAIN] Epoch 71: 100%|██████████| 2/2 [00:00<00:00, 13.00it/s, loss_c_model=0.038]


[EPOCH 71] Validation Accuracy: 0.750


[TRAIN] Epoch 72: 100%|██████████| 2/2 [00:00<00:00, 13.04it/s, loss_c_model=0.038]


[EPOCH 72] Validation Accuracy: 0.750


[TRAIN] Epoch 73: 100%|██████████| 2/2 [00:00<00:00, 12.87it/s, loss_c_model=0.038]


[EPOCH 73] Validation Accuracy: 0.781


[TRAIN] Epoch 74: 100%|██████████| 2/2 [00:00<00:00, 12.95it/s, loss_c_model=0.036]


[EPOCH 74] Validation Accuracy: 0.781


[TRAIN] Epoch 75: 100%|██████████| 2/2 [00:00<00:00, 13.02it/s, loss_c_model=0.035]


[EPOCH 75] Validation Accuracy: 0.750


[TRAIN] Epoch 76: 100%|██████████| 2/2 [00:00<00:00, 12.61it/s, loss_c_model=0.036]


[EPOCH 76] Validation Accuracy: 0.750


[TRAIN] Epoch 77: 100%|██████████| 2/2 [00:00<00:00, 12.96it/s, loss_c_model=0.035]


[EPOCH 77] Validation Accuracy: 0.719


[TRAIN] Epoch 78: 100%|██████████| 2/2 [00:00<00:00, 12.89it/s, loss_c_model=0.036]


[EPOCH 78] Validation Accuracy: 0.719


[TRAIN] Epoch 79: 100%|██████████| 2/2 [00:00<00:00, 12.68it/s, loss_c_model=0.035]


[EPOCH 79] Validation Accuracy: 0.750


[TRAIN] Epoch 80: 100%|██████████| 2/2 [00:00<00:00, 12.97it/s, loss_c_model=0.037]


[EPOCH 80] Validation Accuracy: 0.750


[TRAIN] Epoch 81: 100%|██████████| 2/2 [00:00<00:00, 12.91it/s, loss_c_model=0.036]


[EPOCH 81] Validation Accuracy: 0.781


[TRAIN] Epoch 82: 100%|██████████| 2/2 [00:00<00:00, 12.93it/s, loss_c_model=0.035]


[EPOCH 82] Validation Accuracy: 0.750


[TRAIN] Epoch 83: 100%|██████████| 2/2 [00:00<00:00, 12.77it/s, loss_c_model=0.035]


[EPOCH 83] Validation Accuracy: 0.719


[TRAIN] Epoch 84: 100%|██████████| 2/2 [00:00<00:00, 12.79it/s, loss_c_model=0.034]


[EPOCH 84] Validation Accuracy: 0.719


[TRAIN] Epoch 85: 100%|██████████| 2/2 [00:00<00:00, 12.97it/s, loss_c_model=0.035]


[EPOCH 85] Validation Accuracy: 0.719


[TRAIN] Epoch 86: 100%|██████████| 2/2 [00:00<00:00, 13.00it/s, loss_c_model=0.032]


[EPOCH 86] Validation Accuracy: 0.719


[TRAIN] Epoch 87: 100%|██████████| 2/2 [00:00<00:00, 13.01it/s, loss_c_model=0.035]


[EPOCH 87] Validation Accuracy: 0.750


[TRAIN] Epoch 88: 100%|██████████| 2/2 [00:00<00:00, 12.66it/s, loss_c_model=0.035]


[EPOCH 88] Validation Accuracy: 0.750


[TRAIN] Epoch 89: 100%|██████████| 2/2 [00:00<00:00, 13.04it/s, loss_c_model=0.033]


[EPOCH 89] Validation Accuracy: 0.719


[TRAIN] Epoch 90: 100%|██████████| 2/2 [00:00<00:00, 12.81it/s, loss_c_model=0.032]


[EPOCH 90] Validation Accuracy: 0.750


[TRAIN] Epoch 91: 100%|██████████| 2/2 [00:00<00:00, 12.93it/s, loss_c_model=0.038]


[EPOCH 91] Validation Accuracy: 0.719


[TRAIN] Epoch 92: 100%|██████████| 2/2 [00:00<00:00, 13.04it/s, loss_c_model=0.037]


[EPOCH 92] Validation Accuracy: 0.750


[TRAIN] Epoch 93: 100%|██████████| 2/2 [00:00<00:00, 13.06it/s, loss_c_model=0.033]


[EPOCH 93] Validation Accuracy: 0.750


[TRAIN] Epoch 94: 100%|██████████| 2/2 [00:00<00:00, 13.03it/s, loss_c_model=0.034]


[EPOCH 94] Validation Accuracy: 0.812


[TRAIN] Epoch 95: 100%|██████████| 2/2 [00:00<00:00, 13.05it/s, loss_c_model=0.035]


[EPOCH 95] Validation Accuracy: 0.719


[TRAIN] Epoch 96: 100%|██████████| 2/2 [00:00<00:00, 12.92it/s, loss_c_model=0.033]


[EPOCH 96] Validation Accuracy: 0.719


[TRAIN] Epoch 97: 100%|██████████| 2/2 [00:00<00:00, 13.05it/s, loss_c_model=0.032]


[EPOCH 97] Validation Accuracy: 0.719


[TRAIN] Epoch 98: 100%|██████████| 2/2 [00:00<00:00, 13.02it/s, loss_c_model=0.034]


[EPOCH 98] Validation Accuracy: 0.750


[TRAIN] Epoch 99: 100%|██████████| 2/2 [00:00<00:00, 13.07it/s, loss_c_model=0.033]


[EPOCH 99] Validation Accuracy: 0.719


[TRAIN] Epoch 100: 100%|██████████| 2/2 [00:00<00:00, 12.47it/s, loss_c_model=0.033]


[EPOCH 100] Validation Accuracy: 0.719


In [10]:
print(f"{d_model}")  # dump the module tree of the Swin model

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(128, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0,

In [14]:
# Generator with 128 input and 128 output channels; GCTrainer applies it to masked inputs.
g_model = Generator(out_channels=128, in_channels=128)

In [14]:
# Pre-train the generator/discriminator pair; the checkpoint is read back by GCTrainer.load().
trainer = Trainer(g_model, d_model, trainer_kwargs=dict(max_epochs=200))
trainer.fit()
trainer.save('./parameters/' + 'cross_validation_proposed_pretrain.pth')

[TRAIN] Epoch 1: 100%|██████████| 2/2 [00:00<00:00,  6.62it/s, loss_g_model=-3.194, loss_d_model=0.451]
[TRAIN] Epoch 2: 100%|██████████| 2/2 [00:00<00:00,  6.83it/s, loss_g_model=-3.115, loss_d_model=0.369]
[TRAIN] Epoch 3: 100%|██████████| 2/2 [00:00<00:00,  6.59it/s, loss_g_model=-3.049, loss_d_model=0.309]
[TRAIN] Epoch 4: 100%|██████████| 2/2 [00:00<00:00,  6.74it/s, loss_g_model=-2.993, loss_d_model=0.308]
[TRAIN] Epoch 5: 100%|██████████| 2/2 [00:00<00:00,  6.61it/s, loss_g_model=-2.951, loss_d_model=0.324]
[TRAIN] Epoch 6: 100%|██████████| 2/2 [00:00<00:00,  6.71it/s, loss_g_model=-2.909, loss_d_model=0.319]
[TRAIN] Epoch 7: 100%|██████████| 2/2 [00:00<00:00,  6.80it/s, loss_g_model=-2.909, loss_d_model=0.379]
[TRAIN] Epoch 8: 100%|██████████| 2/2 [00:00<00:00,  6.71it/s, loss_g_model=-2.919, loss_d_model=0.399]
[TRAIN] Epoch 9: 100%|██████████| 2/2 [00:00<00:00,  6.86it/s, loss_g_model=-2.886, loss_d_model=0.462]
[TRAIN] Epoch 10: 100%|██████████| 2/2 [00:00<00:00,  6.80it/s, 