In [1]:
import torch
import logging
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
from torch import autograd
from sklearn import model_selection
from torch.utils.data import Dataset, Subset, DataLoader

from strokes import StrokePatientsMIDataset
from strokesdict import STROKEPATIENTSMI_LOCATION_DICT
import scipy
from torcheeg.transforms import Select,BandSignal,Compose,ToTensor
from to import ToGrid
from typing import Callable, Dict, Union, List
import numpy as np
import soxr
from downsample import SetSamplingRate
from baseline import BaselineCorrection
from torcheeg.transforms import EEGTransform, Select,BandSignal,Compose,ToTensor

dataset = StrokePatientsMIDataset(root_path='../../mi_swin/subdataset',
                                  io_path='.torcheeg/datasets_1741312808642_RKI0H',
                        chunk_size=500,  # 1 second
                        overlap = 0,
                        offline_transform=Compose(
                                [BaselineCorrection(),
                                SetSamplingRate(origin_sampling_rate=500,target_sampling_rate=128),
                                BandSignal(sampling_rate=128,band_dict={'frequency_range':[8,40]})
                                ]),
                        online_transform=Compose(
                                # [ToTensor()]),
                                [ToGrid(STROKEPATIENTSMI_LOCATION_DICT),ToTensor()]),
                
                        label_transform=Select('label'),
                        num_worker=8
)
print(dataset[0][0].shape) #EEG shape(1,30,128)
print(dataset[0][1])  # label (int)
print(len(dataset))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import shutil

def delete_folder_if_exists(target_folder_name):
    # 获取父文件夹中的所有内容
    parent_folder = os.getcwd()
    for folder_name in os.listdir(parent_folder):
        folder_path = os.path.join(parent_folder, folder_name)

        # 检查是否是文件夹并且名称是否匹配
        if os.path.isdir(folder_path) and folder_name == target_folder_name:
            try:
                # 删除目标文件夹
                shutil.rmtree(folder_path)
                print(f"已删除文件夹: {folder_path}")
            except Exception as e:
                print(f"删除文件夹 {folder_path} 时出错: {e}")


In [3]:
def train_test_split(dataset, test_size=0.2, random_state=520, shuffle=True):
    n_samples = len(dataset)
    indices = np.arange(n_samples)
    train_index, test_index = model_selection.train_test_split(
        indices,
        test_size=test_size,
        random_state=random_state,
        shuffle=shuffle)

    trian_dataset = Subset(dataset, train_index)
    test_dataset = Subset(dataset, test_index)

    return trian_dataset, test_dataset

In [None]:
sub_dataset, test_dataset = train_test_split(dataset=dataset)
print(len(sub_dataset), len(test_dataset))

In [4]:
RECEIVED_PARAMS = {
    "c_lr": 0.00001,
    "g_lr": 0.00001,
    "d_lr": 0.00001,
    "weight_gp": 1.0,
    "weight_decay": 0.0005,
    "weight_ssl": 0.5
}

In [2]:
class Generator(nn.Module):
    def __init__(self, in_channels=128, out_channels=128):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels,
                      128,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.LeakyReLU())
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=5, stride=1, padding=2, bias=True),
            nn.LeakyReLU())
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=5, stride=1, padding=2, bias=True),
            nn.LeakyReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU())
        self.delayer1 = nn.Sequential(
            nn.ConvTranspose2d(16 + 32,
                               32,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True), nn.LeakyReLU())
        self.delayer2 = nn.Sequential(
            nn.ConvTranspose2d(32 + 64,
                               64,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True), nn.LeakyReLU())
        self.delayer3 = nn.Sequential(
            nn.ConvTranspose2d(64 + 128,
                               128,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True))

    def forward(self, x):
        #         x = channel_to_location(x)
        mask = (x.abs().sum(dim=1, keepdim=True) > 0).float()
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out = self.layer4(out3)
        out = self.delayer1(torch.cat([out, out3], dim=1))
        out = self.delayer2(torch.cat([out, out2], dim=1))
        out = self.delayer3(torch.cat([out, out1], dim=1))

        return out * mask


class ResidualConv2d(nn.Module):  # 貌似并未使用该函数
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=bias), nn.SELU(),
            nn.Conv2d(out_channels,
                      out_channels,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=bias))
        self.res = nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0,
                             bias=bias)

    def forward(self, x):
        return self.conv(x) + self.res(x)

# 识别情感需要分析不同空间尺度下的EEG信号，故引入了包含三种不同尺寸滤波器的InceptionConv2d来提取多尺度特征图
class InceptionConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.conv5x5 = nn.Conv2d(in_channels,
                                 out_channels,
                                 kernel_size=5,
                                 stride=1,
                                 padding=2,
                                 bias=bias)
        self.conv3x3 = nn.Conv2d(in_channels,
                                 out_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 bias=bias)
        self.conv1x1 = nn.Conv2d(in_channels,
                                 out_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0,
                                 bias=bias)

    def forward(self, x):
        return self.conv5x5(x) + self.conv3x3(x) + self.conv1x1(x)


class SeparableConv2d(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True):
        super().__init__()
        # 不同时间点的二维数据分别进行卷积
        self.depth = nn.Conv2d(in_channels,
                               in_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               groups=in_channels,
                               bias=bias)
        # 单个eeg通道跨时间点进行卷积
        self.point = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=1,
                               stride=stride,
                               padding=0,
                               bias=bias)

    def forward(self, x):
        x = self.depth(x)
        x = self.point(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, num_classes, in_channels=4):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Conv2d(in_channels,
                                256,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=True)
        self.layer2 = nn.Conv2d(256,
                                128,
                                kernel_size=5,
                                stride=1,
                                padding=2,
                                bias=True)
        self.layer3 = nn.Conv2d(128,
                                64,
                                kernel_size=5,
                                stride=1,
                                padding=2,
                                bias=True)
        self.layer4 = SeparableConv2d(64,
                                      32,
                                      kernel_size=5,
                                      stride=1,
                                      padding=2,
                                      bias=True)
        self.layer5 = InceptionConv2d(32, 16)

        self.drop = nn.Sequential(nn.SELU())
        self.fc1 = nn.Sequential(nn.Linear(9 * 9 * 16, 1024, bias=True),
                                 nn.SELU()) # stroke MI dataset 的网格为 7*5
        self.fc2 = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        out = self.layer1(x)
        out = self.drop(out)
        out = self.layer2(out)
        out = self.drop(out)
        out = self.layer3(out)
        out = self.drop(out)
        out = self.layer4(out)
        out = self.drop(out)
        out = self.layer5(out)
        out = self.drop(out)
        out = out.view(out.size(0), -1) # (batch_size, num_features)
        out = self.fc1(out)
        out = self.fc2(out)
        return out


In [None]:
from model import SwinTransformerGenerator
import torch
# Instantiate the generator model
g_model = SwinTransformerGenerator(in_chans=128,
                                     patch_size=2,
                                     window_size=3,
                                     embed_dim=96,
                                     depths=(2, 2, 4, 2),
                                     num_heads=(2, 2, 4, 6)
                                     )

In [3]:
g_model = Generator(in_channels=128, out_channels=128)
# d_model = Discriminator(in_channels=128, num_classes=2)
# 模拟输入
input_tensor = torch.randn(2,1, 128, 9, 9)  # [batch, 1, 128, 7, 5]
output = g_model(input_tensor)
# ans = d_model(output)
print("Output shape:", output.shape)
# print("ans shape:", ans.shape)


torch.Size([2, 128, 9, 9])
Output shape: torch.Size([2, 1, 128, 9, 9])


In [7]:
def random_mask(data, min_r=0.0, max_r=0.5):
    # batch_size*channel_num*time_step
    data = data.clone()
    mask = torch.rand(*data.shape[:2], # 随机生成mask值，(batch_size, 128, 1, 1)
                      *([1] * (len(data.shape) - 2)),
                      device=data.device)
    # ratio = np.random.beta(1.0, 1.0, size=(data.shape[0], 1, 1, 1))
    # ratio = torch.tensor(ratio, device=mask.device).clamp(max=0.5)
    ratio = torch.rand(size=(data.shape[0], 1, 1, 1),
                       device=mask.device) * (max_r - min_r) + min_r # 随机生成1个阈值 (batch_size, 1, 1, 1)
    mask = mask < ratio # mask值低于阈值的，被置零
    mask = mask.expand_as(data) # (batch_size, 128, 1, 1) -> (batch_size, 128, 7, 5)
    data[mask] = 0.0
    return data, ratio


def gradient_penalty(model, real, fake):
    device = real.device
    real = real.data
    fake = fake.data
    alpha = torch.rand(real.size(0), *([1] * (len(real.shape) - 1))).to(device)
    inputs = alpha * real + ((1 - alpha) * fake)
    inputs.requires_grad_()
    outputs = model(inputs)

    gradient = autograd.grad(outputs=outputs,
                             inputs=inputs,
                             grad_outputs=torch.ones_like(outputs).to(device),
                             create_graph=True,
                             retain_graph=True,
                             only_inputs=True)[0]

    gradient = gradient.flatten(1)
    return ((gradient.norm(2, dim=1) - 1)**2).mean()


class Trainer():
    def __init__(self, g_model, d_model, trainer_kwargs={'max_epochs': 10}):
        super().__init__()
        self.g_model = g_model.cuda()
        self.d_model = d_model.cuda()

        self._loss_fn_ce = nn.CrossEntropyLoss()
        self._loss_fn_mse = nn.MSELoss()
        self._optimizer_g_model = torch.optim.Adam(
            g_model.parameters(),
            lr=RECEIVED_PARAMS['g_lr'],
            weight_decay=RECEIVED_PARAMS['weight_decay'])
        self._optimizer_d_model = torch.optim.Adam(
            d_model.parameters(),
            lr=RECEIVED_PARAMS['d_lr'],
            weight_decay=RECEIVED_PARAMS['weight_decay'])

        self._trainer_kwargs = trainer_kwargs

        eeg_dataset = dataset
        train_dataset, val_dataset = train_test_split(eeg_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=16,
                                      shuffle=True,
                                      drop_last=False)

        self._train_dataloader = train_dataloader

    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
        _, predict = torch.max(input.data, 1)
        correct = predict.eq(target.data).cpu().sum().item()
        return correct / input.size(0)

    def training_step_g_model(self, batch, batch_idx, augment_fn=random_mask):
        self._optimizer_g_model.zero_grad()

        for p in self.d_model.parameters():
            p.requires_grad = False

        x, y = batch
        x, y = x.cuda(), y.cuda()

        aug_x, ratio = random_mask(x)
        pred_x = self.g_model(aug_x)
        loss = -self.d_model(pred_x).mean()

        loss.backward()
        self._optimizer_g_model.step()

        return loss

    def training_step_d_model(self, batch, batch_idx, augment_fn=random_mask):
        self._optimizer_d_model.zero_grad()

        for p in self.d_model.parameters():
            p.requires_grad = True

        x, y = batch
        x, y = x.cuda(), y.cuda()

        aug_x, ratio = random_mask(x)
        pred_x = self.g_model(aug_x).detach()

        loss = self.d_model(pred_x).mean() - self.d_model(x).mean()
        loss += RECEIVED_PARAMS['weight_gp'] * gradient_penalty(
            self.d_model, x, pred_x)

        if batch_idx % 5 == 0:
            loss.backward()
            self._optimizer_d_model.step()

        return loss

    def _train(self, epoch_idx=-1):
        """
        单独显示每个 epoch 的训练进度条，并动态更新 G 和 D 的损失。
        """
        pbar = tqdm(total=len(self._train_dataloader), desc=f"[TRAIN] Epoch {epoch_idx}")
        for i, batch in enumerate(self._train_dataloader):
            # 获取 D 模型的损失
            loss_d_model = self.training_step_d_model(batch, i)
            # 获取 G 模型的损失
            loss_g_model = self.training_step_g_model(batch, i)

            # 更新进度条
            pbar.update(1)
            pbar.set_postfix(
                ordered_dict={
                    'loss_g_model': f'{loss_g_model.item():.3f}',
                    'loss_d_model': f'{loss_d_model.item():.3f}'
                }
            )
        pbar.close()
        
    def fit(self) -> None:
        for i in range(self._trainer_kwargs['max_epochs']):
            self._train(i + 1)

    def save(self, param_path):
        torch.save(
            {
                'g_model': self.g_model.state_dict(),
                'd_model': self.d_model.state_dict()
            }, param_path)



In [8]:
class Classifier(nn.Module):
    def __init__(self, num_classes, in_channels=4):
        super(Classifier, self).__init__()
        self.layer1 = nn.Conv2d(in_channels,
                                256,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=True)
        self.layer2 = nn.Conv2d(256,
                                128,
                                kernel_size=5,
                                stride=1,
                                padding=2,
                                bias=True)
        self.layer3 = nn.Conv2d(128,
                                64,
                                kernel_size=5,
                                stride=1,
                                padding=2,
                                bias=True)
        self.layer4 = SeparableConv2d(64,
                                      32,
                                      kernel_size=5,
                                      stride=1,
                                      padding=2,
                                      bias=True)
        self.layer5 = InceptionConv2d(32, 16)
        self.drop = nn.Sequential(nn.Dropout(), nn.SELU())
        self.fc1 = nn.Sequential(nn.Linear(9 * 9 * 16, 1024, bias=True),
                                 nn.SELU())
        self.fc2 = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        out = self.layer1(x)
        out = self.drop(out)
        out = self.layer2(out)
        out = self.drop(out)
        out = self.layer3(out)
        out = self.drop(out)
        out = self.layer4(out)
        out = self.drop(out)
        out = self.layer5(out)
        out = self.drop(out)
        out = out.view(out.size(0), -1)
        feat = self.fc1(out)
        out = self.fc2(feat)
        return out, feat


c_model = Classifier(num_classes=2, in_channels=128)


In [7]:
from model import SwinTransformer_D
c_model = SwinTransformer_D(in_chans=128,
                            num_classes=2,
                            embed_dim=96,
                            depths=(2, 2, 4, 2),
                            num_heads=(2, 2, 4, 6),
                            visual_mode=True
                            )

In [None]:
from classifier import  ClassifierTrainer

train_dataset, val_dataset = train_test_split(dataset)
trainer = ClassifierTrainer(model=c_model,
                            num_classes=2,
                            lr=RECEIVED_PARAMS['c_lr'],
                            weight_decay=1e-5,
                            metrics=["accuracy"],
                            accelerator="gpu")
train_dataloader = DataLoader(train_dataset,
                              batch_size=16,
                              shuffle=True,
                              drop_last=False)
val_dataloader = DataLoader(val_dataset,
                            batch_size=16,
                            shuffle=False,
                            drop_last=False)

trainer.fit(train_dataloader,
            val_dataloader,
            max_epochs=300,
            enable_model_summary=False,
            limit_val_batches=0.0)
trainer.save('./parameters/' + 'cross_validation_backbone.pth')

In [None]:
val_dataloader = DataLoader(val_dataset,
                            batch_size=16,
                            shuffle=False,
                            drop_last=False)

test_result = trainer.test(val_dataloader,
                            enable_progress_bar=True,
                            enable_model_summary=True)[0]
# training_metrics.append(training_result["test_accuracy"])
print(test_result["test_accuracy"])

In [None]:
class CTrainer():
    def __init__(self, c_model, trainer_kwargs={'max_epochs': 10}):
        super().__init__()
        self.c_model = c_model.cuda()

        self._loss_fn_ce = nn.CrossEntropyLoss()
        self._optimizer_c_model = torch.optim.Adam(c_model.parameters(),
                                                   lr=RECEIVED_PARAMS['c_lr'],
                                                   weight_decay=0.0005)
        self._trainer_kwargs = trainer_kwargs

        eeg_dataset = dataset
        train_dataset, val_dataset = train_test_split(eeg_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=16,
                                      shuffle=True,
                                      drop_last=False)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=16,
                                    shuffle=False,
                                    drop_last=False)

        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
        _, predict = torch.max(input.data, 1)
        correct = predict.eq(target.data).cpu().sum().item()
        return correct / input.size(0)

    def training_step_c_model(self, batch, batch_idx):
        for p in self.c_model.parameters():
            p.requires_grad = True

        self._optimizer_c_model.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_hat, x_feat = self.c_model(x)
        loss = self._loss_fn_ce(y_hat, y)

        loss.backward()
        self._optimizer_c_model.step()

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = self.validation_step_before_model(batch, batch_idx)
        y_hat, x_feat = self.c_model(x)
        return (y_hat.detach().cpu(), y.detach().cpu())

    def validation_step_before_model(self, batch, batch_idx):
        x, y = batch
        x, y = x.cuda(), y.cuda()
        return x, y

    def validation_epoch_end(self, outputs):
        # We might need dict metrics in future?
        y_hat, y = zip(*outputs)
        y_hat = torch.cat(y_hat, dim=0)
        y = torch.cat(y, dim=0)
        avg_acc = self._accuracy(y_hat, y)
        print("acc:", avg_acc)
        # logger.info('[VAL] Average ACC at epoch end is {}'.format(avg_acc))
        # return {'val_acc': avg_acc}

    def _validate(self, epoch_idx=-1):
        validation_outputs = []
        for i, batch in enumerate(self._val_dataloader):
            validation_outputs.append(self.validation_step(batch, i))
        return self.validation_epoch_end(validation_outputs)

    def _train(self, epoch_idx=-1):
        """
        单独显示每个 epoch 的训练进度条。
        """
        pbar = tqdm(total=len(self._train_dataloader), desc=f"[TRAIN] Epoch {epoch_idx}")
        for i, batch in enumerate(self._train_dataloader):
            loss_c_model = self.training_step_c_model(batch, i)
            pbar.update(1)
            # 更新进度条的后缀信息
            pbar.set_postfix(ordered_dict={'loss_c_model': f'{loss_c_model.item():.3f}'})
        pbar.close()

    def fit(self) -> None:
        """
        按照每个 epoch 单独创建训练和验证进度条。
        """
        for epoch_idx in range(self._trainer_kwargs['max_epochs']):
            self._train(epoch_idx + 1)
            self._validate(epoch_idx + 1)

        # logger.info('[VAL] Final ACC at experiment end is {}'.format(
        #     self._validate()['val_acc']))

    def save(self, param_path):
        torch.save({
            'c_model': self.c_model.state_dict(),
        }, param_path)


trainer = CTrainer(c_model, trainer_kwargs={'max_epochs': 80})
trainer.fit()
trainer.save('./parameters/' + 'cross_validation_backbone.pth')

In [9]:
from model import SwinTransformer, SwinTransformer_D

c_model = SwinTransformer_D(in_chans=128,
                            num_classes=2,
                            embed_dim=96,
                            depths=(2, 2, 4, 2),
                            num_heads=(2, 2, 4, 6),
                            visual_mode=True
                            )
d_model = SwinTransformer(in_chans=128,
                          num_classes=2,
                          embed_dim=96,
                          depths=(2, 2, 4, 2),
                          num_heads=(2, 2, 4, 6),
                          )

In [None]:
# 模拟输入
import torch
input_tensor = torch.randn(2, 128, 9, 9)  # [batch, 1, 128, 7, 5]
output = d_model(input_tensor)

print("Input shape:", input_tensor.shape)
print(output)

In [None]:
import os
import torch.nn.functional as F

class GCTrainer():
    def __init__(self, c_model, g_model, trainer_kwargs={'max_epochs': 10}):
        super().__init__()
        self.c_model = c_model.cuda()
        self.g_model = g_model.cuda()

        self._loss_fn_ce = nn.CrossEntropyLoss()
        self._loss_fn_mse = nn.MSELoss()
        self._optimizer_c_model = torch.optim.Adam(c_model.parameters(),
                                                   lr=RECEIVED_PARAMS['c_lr'],
                                                   weight_decay=0.0005)

        self._trainer_kwargs = trainer_kwargs

        eeg_dataset = dataset
        train_dataset, val_dataset = train_test_split(eeg_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=16,
                                      shuffle=True,
                                      drop_last=False)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=16,
                                    shuffle=False,
                                    drop_last=False)

        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
        _, predict = torch.max(input.data, 1)
        correct = predict.eq(target.data).cpu().sum().item()
        return correct / input.size(0)

    def training_step_c_model(self, batch, batch_idx):
        for p in self.c_model.parameters():
            p.requires_grad = True

        self._optimizer_c_model.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_hat, x_feat = self.c_model(x)
        loss = self._loss_fn_ce(y_hat, y)

        aug_x, ratio = random_mask(x)
        aug_x = self.g_model(aug_x).detach()
        aug_y_hat, aug_x_feat = self.c_model(aug_x)

        loss += RECEIVED_PARAMS['weight_ssl'] * (
            (1 - ratio).squeeze() * F.mse_loss(
                x_feat, aug_x_feat, reduction='none').mean(dim=-1)).mean()

        loss.backward()
        self._optimizer_c_model.step()

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = self.validation_step_before_model(batch, batch_idx)
        y_hat, x_feat = self.c_model(x)
        return (y_hat.detach().cpu(), y.detach().cpu())

    def validation_step_before_model(self, batch, batch_idx):
        x, y = batch
        x, y = x.cuda(), y.cuda()
        return x, y

    def validation_epoch_end(self, outputs):
        # We might need dict metrics in future?
        y_hat, y = zip(*outputs)
        y_hat = torch.cat(y_hat, dim=0)
        y = torch.cat(y, dim=0)
        avg_acc = self._accuracy(y_hat, y)
        # logger.info('[VAL] Average ACC at epoch end is {}'.format(avg_acc))
        return {'val_acc': avg_acc}

    def _validate(self, epoch_idx=-1):
        validation_outputs = []
        for i, batch in enumerate(self._val_dataloader):
            validation_outputs.append(self.validation_step(batch, i))
        return self.validation_epoch_end(validation_outputs)

    def _train(self, epoch_idx=-1):
        """
        单独显示每个 epoch 的训练进度条。
        """
        pbar = tqdm(total=len(self._train_dataloader), desc=f"[TRAIN] Epoch {epoch_idx}")
        for i, batch in enumerate(self._train_dataloader):
            loss_c_model = self.training_step_c_model(batch, i)
            pbar.update(1)
            # 更新进度条的后缀信息
            pbar.set_postfix(ordered_dict={'loss_c_model': f'{loss_c_model.item():.3f}'})
        pbar.close()

    def fit(self) -> None:
        """
        按照每个 epoch 单独创建训练和验证进度条。
        """
        for epoch_idx in range(self._trainer_kwargs['max_epochs']):
            self._train(epoch_idx + 1)
                # 验证过程并获取验证结果
            val_metrics = self._validate(epoch_idx + 1)
            val_acc = val_metrics['val_acc']
            
            # 打印验证准确率
            print(f"[EPOCH {epoch_idx + 1}] Validation Accuracy: {val_acc:.3f}")

    def save(self, param_path):
        torch.save({
            'c_model': trainer.c_model.state_dict(),
        }, param_path)

    def load(self):
        gan_model_state_dict = torch.load(
            './parameters/cross_validation_proposed_pretrain.pth')
        self.g_model.load_state_dict(gan_model_state_dict['g_model'])

        if os.path.exists('./parameters/cross_validation_backbone' + '.pth'):
            c_model_state_dict = torch.load(
                './parameters/cross_validation_backbone'  +
                '.pth')
            self.c_model.load_state_dict(c_model_state_dict['c_model'])


trainer = GCTrainer(c_model,
                  g_model,
                  trainer_kwargs={'max_epochs': 100})
trainer.load()
trainer.fit()
trainer.save('./parameters/' +  'cross_validation_finetune.pth')

In [None]:
print(c_model)

In [14]:
g_model = Generator(in_channels=128, out_channels=128)

In [None]:
trainer = Trainer(g_model,
                  d_model,
                  trainer_kwargs={'max_epochs': 200})
trainer.fit()
trainer.save('./parameters/' + 'cross_validation_proposed_pretrain.pth')