In [1]:
from pathlib import Path
from scipy.ndimage import zoom
from scipy.ndimage import find_objects
import torchio as tio
import os
import glob
import re
from configparser import ConfigParser
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from typing import Dict, Tuple
import matplotlib.pyplot as plt
from collections import deque
from sklearn.model_selection import KFold
import math
from Fdataset import ACDCDataset, PairwiseAugmentor

# 配置参数
CLASS_MAP = {'NOR':0, 'DCM':1, 'HCM':2, 'MINF':3, 'RV':4}
TARGET_SHAPE = (200, 200, 80)
TARGET_SPACING = 1.25  # mm
AUG_FACTOR = 1  


  from .autonotebook import tqdm as notebook_tqdm


import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class ConvBlock3D(nn.Module):
    """3D卷积块（带SE注意力）"""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, expansion=4, se_ratio=0.25):
        super().__init__()
        expanded = in_channels * expansion
        self.conv1 = nn.Conv3d(in_channels, expanded, 1, bias=False)
        self.bn1 = nn.BatchNorm3d(expanded)
        
        # 深度可分离卷积
        self.depthwise = nn.Conv3d(
            expanded, expanded, kernel_size, stride, 
            padding=kernel_size//2, groups=expanded, bias=False
        )
        self.bn2 = nn.BatchNorm3d(expanded)
        
        # SE注意力
        se_channels = max(1, int(expanded * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(expanded, se_channels, 1),
            nn.SiLU(),
            nn.Conv3d(se_channels, expanded, 1),
            nn.Sigmoid()
        )
        
        self.conv2 = nn.Conv3d(expanded, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_channels)
        self.act = nn.SiLU()
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm3d(out_channels)
            )

    def forward(self, x):
        residual = self.shortcut(x)
        x = self.act(self.bn1(self.conv1(x)))
        x = self.act(self.bn2(self.depthwise(x)))
        
        # 应用SE
        se = self.se(x)
        x = x * se
        
        x = self.bn3(self.conv2(x))
        return self.act(x + residual)

class EfficientNetV2_3D(nn.Module):
    """3D版EfficientNetV2（简化架构）"""
    def __init__(self, in_channels=1, num_classes=5):
        super().__init__()
        # 初始卷积层
        self.stem = nn.Sequential(
            nn.Conv3d(in_channels, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm3d(32),
            nn.SiLU()
        )
        
        # MBConv模块序列
        self.blocks = nn.Sequential(
            ConvBlock3D(32, 24, 3, stride=1, expansion=1),
            ConvBlock3D(24, 48, 3, stride=2, expansion=4),
            ConvBlock3D(48, 64, 3, stride=2, expansion=4),
            ConvBlock3D(64, 128, 3, stride=2, expansion=4),
            ConvBlock3D(128, 160, 3, stride=1, expansion=6),
            ConvBlock3D(160, 256, 3, stride=2, expansion=6)
        )
        
        # 头部
        self.head = nn.Sequential(
            nn.Conv3d(256, 1280, 1, bias=False),
            nn.BatchNorm3d(1280),
            nn.SiLU(),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten()
        )
        self.classifier = nn.Linear(1280, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.head(x)
        return self.classifier(x)

class PatchEmbedding3D(nn.Module):
    """3D图像分块嵌入"""
    def __init__(self, patch_size=16, in_chans=1, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv3d(
            in_chans, embed_dim, 
            kernel_size=patch_size, 
            stride=patch_size
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)  # [B, C, D, H, W]
        x = rearrange(x, 'b c d h w -> b (d h w) c')
        return self.norm(x)

class TransformerEncoder(nn.Module):
    """Transformer编码器层（3D）"""
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # 自注意力
        res = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x)
        x = res + x
        
        # MLP
        res = x
        x = self.norm2(x)
        x = self.mlp(x)
        return res + x

class VisionTransformer3D(nn.Module):
    """3D版Vision Transformer"""
    def __init__(self, in_chans=1, num_classes=5, 
                 patch_size=16, embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.patch_embed = PatchEmbedding3D(patch_size, in_chans, embed_dim)
        
        # 位置编码
        num_patches = (200//patch_size) * (200//patch_size) * (80//patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Transformer编码器
        self.blocks = nn.Sequential(*[
            TransformerEncoder(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        
        # 添加cls token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x += self.pos_embed
        
        # Transformer编码
        x = self.blocks(x)
        x = self.norm(x)
        
        # 分类头
        return self.head(x[:, 0])

class EnsembleModel(nn.Module):
    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        
    def forward(self, x):
        logits1 = self.model1(x)
        logits2 = self.model2(x)
        
        # 几何平均（概率空间）
        prob1 = F.softmax(logits1, dim=-1)
        prob2 = F.softmax(logits2, dim=-1)
        geom_prob = torch.sqrt(prob1 * prob2 + 1e-9)
        
        # 转换回logits形式（对数概率）
        return torch.log(geom_prob)  # CrossEntropyLoss需要logits输入

from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
import torch.multiprocessing as mp
import json

if __name__ == '__main__':  # 确保在主模块中设置
    mp.set_start_method('spawn', force=True)

start_fold = 0  # 可修改为需要开始的折数 (0-4)
results_file = '新-trans.json'
CUSTOM_PREFIX = "新-trans"

# 尝试加载已有的结果 - 添加空文件处理
fold_results = []
if os.path.exists(results_file):
    try:
        with open(results_file, 'r') as f:
            file_content = f.read().strip()
            if file_content:  # 检查文件是否非空
                fold_results = json.loads(file_content)
                print(f"Loaded existing results: {fold_results}")
            else:
                print("Results file exists but is empty. Starting fresh.")
    except json.JSONDecodeError:
        print("Warning: Results file contains invalid JSON. Starting fresh.")
        fold_results = []
else:
    print("No existing results file found. Starting fresh.")

# 训练流程修改
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds, all_labels = [], []
    
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device).to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        
        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())  # 使用extend代替append
        all_labels.extend(labels.cpu().numpy()) # 转换为numpy数组
    
    return running_loss/len(loader), accuracy_score(all_labels, all_preds)

def evaluate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device).to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            
            _, preds = torch.max(outputs, 1)
            val_preds.extend(preds.cpu().numpy())   # 修改为extend
            val_labels.extend(labels.cpu().numpy())  # 修改为extend
    
    return val_loss/len(loader.dataset), accuracy_score(val_labels, val_preds)


# 五折交叉验证修改版
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
all_cases = [d for d in Path('心力衰竭/database/training').glob('patient*') if d.is_dir()] + \
            [d for d in Path('心力衰竭/database/testing').glob('patient*') if d.is_dir()]
all_labels = []  # 存储每个病例的标签

# 收集每个病例的标签
for case in all_cases:
    # 创建临时数据集实例（不需要变换）
    _, label = ACDCDataset([case], phase='train')[0]
    all_labels.append(label)
fold_results = []

for fold, (train_val_idx, test_idx) in enumerate(kf.split(all_cases, all_labels)):
    print(f"\n=== Fold {fold+1}/5 ===")
    
    # 划分训练验证集和测试集
    train_val_cases = [all_cases[i] for i in train_val_idx]
    test_cases = [all_cases[i] for i in test_idx]
    
    # 从训练验证集中提取标签用于再分层
    train_val_labels = [all_labels[i] for i in train_val_idx]
    
    # 在训练验证集内部进行分层划分 (75%训练, 25%验证)
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
    for train_idx, val_idx in sss.split(train_val_cases, train_val_labels):
        train_cases = [train_val_cases[i] for i in train_idx]
        val_cases = [train_val_cases[i] for i in val_idx]
    
    # 创建数据集
    train_dataset = ACDCDataset(train_cases, phase='train')
    val_dataset = ACDCDataset(val_cases, phase='val')    # 从训练集划分的验证集
    test_dataset = ACDCDataset(test_cases, phase='val')  # 独立测试集
    
    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=3)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=1)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    
    # 模型初始化（保持原有实现不变）
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    effnet = EfficientNetV2_3D()
    vit = VisionTransformer3D(patch_size=20, depth=6, embed_dim=512, num_heads=8)  # 简化ViT
    model = EnsembleModel(effnet, vit).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=5)
    
    # 初始化跟踪变量
    best_acc = 0.0  # 只跟踪最佳准确率
    best_loss = 10
    best_model_path = f"{CUSTOM_PREFIX}_fold{fold}_best.pth"
    final_model_path = f"{CUSTOM_PREFIX}_fold{fold}_last.pth"
    
    for epoch in range(100):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)  # 使用新验证集
        
        scheduler.step(val_acc)
        
        # 动态保存最佳模型（只保留最佳准确率版本）
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
        
        
        # 早停判断（基于验证损失）
        if val_loss < best_loss:
            best_loss = val_loss
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= 10:
                print(f"Early stopping at epoch {epoch+1}")
                break
    
    model.load_state_dict(torch.load(best_model_path))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    fold_results.append(test_acc)  # 记录测试集准确率
    print(f"Fold {fold+1} Test Accuracy: {test_acc:.2%}")    
    
    # 确保最终模型保存（即使早停也保存最后达到的epoch）
    torch.save(model.state_dict(), final_model_path)

    with open(results_file, 'w') as f:
        json.dump(fold_results, f)
    print(f"\nCurrent 5-Fold CV Results: {fold_results}")
    print(f"Average Accuracy: {np.mean(fold_results):.2%} (±{np.std(fold_results):.2%})")

# 输出结果（保持原有实现不变）
if os.path.exists(results_file):
    with open(results_file, 'r') as f:
        final_results = json.load(f)
print("\n=== Final Results ===")
print(f"5-Fold CV Results: {final_results}")
print(f"Average Accuracy: {np.mean(final_results):.2%} (±{np.std(final_results):.2%})")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock3D(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super(SEBlock3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction_ratio, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y


class FeatureExtractor3D(nn.Module):
    def __init__(self, in_channels=1):
        super(FeatureExtractor3D, self).__init__()
        # 输入: (batch, 1, 80, 200, 200)
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2)  # 输出: (40, 100, 100)
        )
        self.se1 = SEBlock3D(64)
        
        self.conv2 = nn.Sequential(
            nn.Conv3d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2)  # 输出: (20, 50, 50)
        )
        self.se2 = SEBlock3D(32)
        
        self.conv3 = nn.Sequential(
            nn.Conv3d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2)  # 输出: (10, 25, 25)
        )
        
        # 展平后的特征维度: 32 * 10 * 25 * 25 = 200,000
        self.flatten = nn.Flatten()
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.se1(x)
        x = self.conv2(x)
        x = self.se2(x)
        x = self.conv3(x)
        return self.flatten(x)  # 输出: (batch, 200,000)


# 模型1：多二分类器 (每个类别独立)
class MultiBinaryClassifier(nn.Module):
    def __init__(self, input_size, num_classes=3):
        super(MultiBinaryClassifier, self).__init__()
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, 256),
                nn.ReLU(),
                nn.Linear(256, 1),
                nn.Sigmoid()
            ) for _ in range(num_classes)
        ])
    
    def forward(self, x):
        return torch.cat([cls(x) for cls in self.classifiers], dim=1)

# 模型2：标准多类分类器
class MultiClassClassifier(nn.Module):
    def __init__(self, input_size, num_classes=3):
        super(MultiClassClassifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    
    def forward(self, x):
        return self.net(x)

# 模型3：贝叶斯分类器 (使用MC Dropout模拟不确定性)
class BayesianClassifier(nn.Module):
    def __init__(self, input_size, num_classes=3):
        super(BayesianClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.dropout1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(256, 64)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc3 = nn.Linear(64, num_classes)
    
    def forward(self, x, sample=True):
        x = F.relu(self.fc1(x))
        if sample: x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        if sample: x = self.dropout2(x)
        return self.fc3(x)


class VS_BEAM_3D(nn.Module):
    def __init__(self, in_channels=1, num_classes=5):
        super(VS_BEAM_3D, self).__init__()
        self.feature_extractor = FeatureExtractor3D(in_channels)
        input_size = 200000  # 32 * 10 * 25 * 25
        
        self.model1 = MultiBinaryClassifier(input_size, num_classes)
        self.model2 = MultiClassClassifier(input_size, num_classes)
        self.model3 = BayesianClassifier(input_size, num_classes)
        
    def forward(self, x):
        features = self.feature_extractor(x)
        
        # 三个模型的预测
        out1 = self.model1(features)  # [batch, num_classes]
        out2 = self.model2(features)  # [batch, num_classes]
        out3 = self.model3(features)  # [batch, num_classes]
        
        # 投票机制
        combined = torch.stack([out1, out2, out3], dim=2)  # [batch, classes, 3]
        votes = torch.argmax(combined, dim=1)  # [batch, 3]
        final = torch.mode(votes, dim=1).values
        
        return {
            'model1': out1,
            'model2': out2,
            'model3': out3,
            'final': final
        }


In [3]:
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
import torch.multiprocessing as mp
import json

if __name__ == '__main__':  # 确保在主模块中设置
    mp.set_start_method('spawn', force=True)

start_fold = 0  # 可修改为需要开始的折数 (0-4)
results_file = '新-classi.json'
CUSTOM_PREFIX = "新-classi"

# 尝试加载已有的结果 - 添加空文件处理
fold_results = []
if os.path.exists(results_file):
    try:
        with open(results_file, 'r') as f:
            file_content = f.read().strip()
            if file_content:  # 检查文件是否非空
                fold_results = json.loads(file_content)
                print(f"Loaded existing results: {fold_results}")
            else:
                print("Results file exists but is empty. Starting fresh.")
    except json.JSONDecodeError:
        print("Warning: Results file contains invalid JSON. Starting fresh.")
        fold_results = []
else:
    print("No existing results file found. Starting fresh.")

# 训练流程修改
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds, all_labels = [], []
    
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss1 = criterion(outputs['model1'], labels)
        loss2 = criterion(outputs['model2'], labels)
        loss3 = criterion(outputs['model3'], labels)
        loss = loss1 + loss2 + loss3
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        
        running_loss += loss.item()
        preds = outputs['final']
        all_preds.extend(preds.cpu().numpy())  # 使用extend代替append
        all_labels.extend(labels.cpu().numpy()) # 转换为numpy数组
    
    return running_loss/len(loader), accuracy_score(all_labels, all_preds)

def evaluate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            loss1 = criterion(outputs['model1'], labels)
            loss2 = criterion(outputs['model2'], labels)
            loss3 = criterion(outputs['model3'], labels)
            loss = loss1 + loss2 + loss3
            val_loss += loss.item() * inputs.size(0)
            
            preds = outputs['final']
            val_preds.extend(preds.cpu().numpy())   # 修改为extend
            val_labels.extend(labels.cpu().numpy())  # 修改为extend
    
    return val_loss/len(loader.dataset), accuracy_score(val_labels, val_preds)


# 五折交叉验证修改版
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
all_cases = [d for d in Path('心力衰竭/database/training').glob('patient*') if d.is_dir()] + \
            [d for d in Path('心力衰竭/database/testing').glob('patient*') if d.is_dir()]
all_labels = []  # 存储每个病例的标签

# 收集每个病例的标签
for case in all_cases:
    # 创建临时数据集实例（不需要变换）
    _, label = ACDCDataset([case], phase='train')[0]
    all_labels.append(label)
fold_results = []

for fold, (train_val_idx, test_idx) in enumerate(kf.split(all_cases, all_labels)):
    print(f"\n=== Fold {fold+1}/5 ===")
    
    # 划分训练验证集和测试集
    train_val_cases = [all_cases[i] for i in train_val_idx]
    test_cases = [all_cases[i] for i in test_idx]
    
    # 从训练验证集中提取标签用于再分层
    train_val_labels = [all_labels[i] for i in train_val_idx]
    
    # 在训练验证集内部进行分层划分 (75%训练, 25%验证)
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
    for train_idx, val_idx in sss.split(train_val_cases, train_val_labels):
        train_cases = [train_val_cases[i] for i in train_idx]
        val_cases = [train_val_cases[i] for i in val_idx]
    
    # 创建数据集
    train_dataset = ACDCDataset(train_cases, phase='train')
    val_dataset = ACDCDataset(val_cases, phase='val')    # 从训练集划分的验证集
    test_dataset = ACDCDataset(test_cases, phase='val')  # 独立测试集
    
    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=3)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=1)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    
    # 模型初始化（保持原有实现不变）
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VS_BEAM_3D().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=5)
    
    # 初始化跟踪变量
    best_acc = 0.0  # 只跟踪最佳准确率
    best_loss = 1000
    best_model_path = f"{CUSTOM_PREFIX}_fold{fold}_best.pth"
    final_model_path = f"{CUSTOM_PREFIX}_fold{fold}_last.pth"
    
    for epoch in range(100):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)  # 使用新验证集
        
        scheduler.step(val_acc)
        
        # 动态保存最佳模型（只保留最佳准确率版本）
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
        
        
        # 早停判断（基于验证损失）
        if val_loss < best_loss:
            best_loss = val_loss
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= 10:
                print(f"Early stopping at epoch {epoch+1}")
                break
    
    model.load_state_dict(torch.load(best_model_path))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    fold_results.append(test_acc)  # 记录测试集准确率
    print(f"Fold {fold+1} Test Accuracy: {test_acc:.2%}")    
    
    # 确保最终模型保存（即使早停也保存最后达到的epoch）
    torch.save(model.state_dict(), final_model_path)

    with open(results_file, 'w') as f:
        json.dump(fold_results, f)
    print(f"\nCurrent 5-Fold CV Results: {fold_results}")
    print(f"Average Accuracy: {np.mean(fold_results):.2%} (±{np.std(fold_results):.2%})")

# 输出结果（保持原有实现不变）
if os.path.exists(results_file):
    with open(results_file, 'r') as f:
        final_results = json.load(f)
print("\n=== Final Results ===")
print(f"5-Fold CV Results: {final_results}")
print(f"Average Accuracy: {np.mean(final_results):.2%} (±{np.std(final_results):.2%})")

No existing results file found. Starting fresh.

=== Fold 1/5 ===
Early stopping at epoch 11


  model.load_state_dict(torch.load(best_model_path))


Fold 1 Test Accuracy: 55.00%

Current 5-Fold CV Results: [0.55]
Average Accuracy: 55.00% (±0.00%)

=== Fold 2/5 ===
Early stopping at epoch 13


  model.load_state_dict(torch.load(best_model_path))


Fold 2 Test Accuracy: 50.00%

Current 5-Fold CV Results: [0.55, 0.5]
Average Accuracy: 52.50% (±2.50%)

=== Fold 3/5 ===
Early stopping at epoch 12


  model.load_state_dict(torch.load(best_model_path))


Fold 3 Test Accuracy: 61.67%

Current 5-Fold CV Results: [0.55, 0.5, 0.6166666666666667]
Average Accuracy: 55.56% (±4.78%)

=== Fold 4/5 ===
Early stopping at epoch 13


  model.load_state_dict(torch.load(best_model_path))


Fold 4 Test Accuracy: 58.33%

Current 5-Fold CV Results: [0.55, 0.5, 0.6166666666666667, 0.5833333333333334]
Average Accuracy: 56.25% (±4.31%)

=== Fold 5/5 ===
Early stopping at epoch 11


  model.load_state_dict(torch.load(best_model_path))


Fold 5 Test Accuracy: 43.33%

Current 5-Fold CV Results: [0.55, 0.5, 0.6166666666666667, 0.5833333333333334, 0.43333333333333335]
Average Accuracy: 53.67% (±6.45%)

=== Final Results ===
5-Fold CV Results: [0.55, 0.5, 0.6166666666666667, 0.5833333333333334, 0.43333333333333335]
Average Accuracy: 53.67% (±6.45%)
