In [None]:
import os
import cv2
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.utils import resample
import matplotlib.pyplot as plt
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import Counter
from torch.utils.data.sampler import WeightedRandomSampler

class Config:
    """Paths and hyper-parameters for the liver experiment.

    Training and test images live in separate folders; the validation set is
    carved out of the training folder (``val_size``).
    """

    # ---- data locations ----
    train_data_path: str = "C:/Users/YK/Desktop/liver_train"  # training images
    test_data_path: str = "C:/Users/YK/Desktop/liver_test"    # held-out test images
    mask_data_path: str = "C:/Users/YK/Desktop/liver-mask_pngs"
    output_dir: str = "./outputs(LF-liver)"

    # ---- reproducibility ----
    seed: int = 42

    # ---- input and batching ----
    img_size: tuple = (224, 224)
    batch_size: int = 16
    num_workers: int = 0  # kept at 0 for Windows

    # ---- optimisation ----
    num_epochs: int = 120
    lr: float = 1e-4
    weight_decay: float = 1e-5

    # ---- task and evaluation ----
    num_classes: int = 3
    val_size: float = 0.1    # validation fraction only; the test set is a separate folder
    n_bootstrap: int = 1000  # bootstrap resamples used for AUC confidence intervals

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU plus every
    CUDA device). It also forces deterministic cuDNN kernels with autotuning
    disabled and exports ``PYTHONHASHSEED``. The environment variable only
    affects interpreters started after this call (e.g. spawned worker
    processes), not the current process's hash seed.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(Config.seed)

# Output tree: checkpoints under models/, reports and per-split arrays under results/.
for _subdir in ("models", "results", "results/train", "results/test"):
    os.makedirs(f"{Config.output_dir}/{_subdir}", exist_ok=True)

class SWEDataset(Dataset):
    """Dataset of masked, contrast-enhanced images whose class comes from the filename.

    For each filename the item is built as follows:

    - The raw image ``{raw_path}/{filename}`` is read (BGR -> RGB).
    - The grayscale mask ``{mask_path}/{filename}`` is read.
    - Pixels where the mask is zero are blanked.
    - The image is Gaussian-blurred (5x5).
    - CLAHE is applied to the L channel in LAB space.
    - The result is resized to ``Config.img_size``.

    The label is the filename prefix before the first underscore (e.g.
    ``"F2_001.png"`` -> ``"F2"``), mapped through ``label_map``.

    Args:
        raw_path: Directory containing the raw images.
        mask_path: Directory containing masks stored under the same filenames.
        filenames: Image filenames, relative to both directories.
        transform: Optional albumentations-style pipeline, called as
            ``transform(image=...)`` and returning a dict with an ``'image'``
            entry. When ``None``, ToTensor plus ImageNet mean/std
            normalisation is applied.

    Raises (from ``__getitem__``):
        ValueError: unknown label prefix, or mask size differs from the image.
        FileNotFoundError: the image or its mask cannot be read.
    """

    def __init__(self, raw_path, mask_path, filenames, transform=None):
        self.raw_path = raw_path
        self.mask_path = mask_path
        self.filenames = filenames
        self.transform = transform
        # Filename prefix -> class index; adjust to match the actual labels.
        self.label_map = {'F2': 0, 'F3': 1, 'F4': 2}

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _imread(path, flags):
        """Read an image with OpenCV, raising instead of returning ``None`` on failure."""
        img = cv2.imread(path, flags)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None, which
            # would otherwise surface later as an opaque cv2.error in cvtColor.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return img

    def __getitem__(self, idx):
        filename = self.filenames[idx]

        # Resolve the label first so a bad filename fails before any image I/O.
        label_key = filename.split('_')[0]
        if label_key not in self.label_map:
            raise ValueError(f"Unknown label '{label_key}' in filename: {filename}")
        label = self.label_map[label_key]

        # Load the raw image and its mask.
        raw_img = cv2.cvtColor(
            self._imread(f"{self.raw_path}/{filename}", cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
        )
        mask = self._imread(f"{self.mask_path}/{filename}", cv2.IMREAD_GRAYSCALE)
        if mask.shape[:2] != raw_img.shape[:2]:
            raise ValueError(
                f"Mask size {mask.shape[:2]} does not match image size "
                f"{raw_img.shape[:2]} for {filename}"
            )

        # Keep only the pixels where the mask is non-zero.
        masked_img = cv2.bitwise_and(raw_img, raw_img, mask=mask)

        # Denoise, then boost local contrast on the lightness channel only
        # (CLAHE on L in LAB space leaves the a/b chroma channels untouched).
        blurred = cv2.GaussianBlur(masked_img, (5, 5), 0)
        lab = cv2.cvtColor(blurred, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        # Created per item: cv2.CLAHE objects cannot be pickled, so keeping one on
        # the dataset would break DataLoader workers that pickle the dataset.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced_img = cv2.cvtColor(cv2.merge((clahe.apply(l), a, b)), cv2.COLOR_LAB2RGB)

        resized_img = cv2.resize(enhanced_img, Config.img_size)

        if self.transform:
            final_img = self.transform(image=resized_img)['image']
        else:
            # Default conversion: HWC uint8 -> CHW float tensor, ImageNet normalisation.
            final_img = transforms.ToTensor()(resized_img)
            final_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(final_img)

        return final_img, label

# ImageNet statistics, matching the pretrained EfficientNet backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time augmentation (albumentations).
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    # hue=0.0 (was 0.2): a hue jitter of 0.2 rotates colours by up to ±0.2 of the
    # full hue circle. Assuming these SWE images use a colour map in which hue
    # encodes the measured quantity (TODO confirm with the imaging protocol), hue
    # jitter would corrupt the signal the classifier relies on. Only brightness,
    # contrast and saturation are jittered.
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.0),
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

# Validation/test pipeline: normalisation and tensor conversion only.
val_transform = A.Compose([
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
# Single source of truth for the label vocabulary: reuse the dataset's own map.
_label_map = SWEDataset(Config.train_data_path, Config.mask_data_path, []).label_map


def _list_labelled_images(directory):
    """Return image filenames in *directory* whose prefix is a known label, sorted.

    Sorting makes the order independent of the filesystem's listing order, so
    the seeded train/val split below is reproducible across machines.
    Extensions are matched case-insensitively (e.g. ``.PNG`` is accepted).
    """
    return sorted(
        f for f in os.listdir(directory)
        if f.lower().endswith(_IMAGE_EXTENSIONS) and f.split('_')[0] in _label_map
    )


# Training and test images come from separate folders.
train_files = _list_labelled_images(Config.train_data_path)
test_files = _list_labelled_images(Config.test_data_path)

if not train_files:
    raise ValueError("No valid image files found in the training data path")
if not test_files:
    raise ValueError("No valid image files found in the test data path")

# Carve a stratified validation set out of the training files.
train_labels = [f.split('_')[0] for f in train_files]
train_files, val_files = train_test_split(
    train_files, test_size=Config.val_size, random_state=Config.seed, stratify=train_labels
)

print(f"Dataset sizes: Train={len(train_files)}, Val={len(val_files)}, Test={len(test_files)}")

# Class balancing: weight every training sample by the inverse frequency of its
# class and draw len(train_files) samples per epoch with replacement.
train_labels = [_label_map[f.split('_')[0]] for f in train_files]
class_counts = Counter(train_labels)
print(f"Class distribution in training set: {class_counts}")

class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[label] for label in train_labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

# Validation images come from the training folder; test images from the test folder.
train_dataset = SWEDataset(Config.train_data_path, Config.mask_data_path, train_files, transform=train_transform)
val_dataset = SWEDataset(Config.train_data_path, Config.mask_data_path, val_files, transform=val_transform)
test_dataset = SWEDataset(Config.test_data_path, Config.mask_data_path, test_files, transform=val_transform)

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, sampler=sampler,
                          num_workers=Config.num_workers, pin_memory=_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False,
                        num_workers=Config.num_workers, pin_memory=_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False,
                         num_workers=Config.num_workers, pin_memory=_pin_memory)

class SWEClassifier(nn.Module):
    """EfficientNet-B4 (torchvision default pretrained weights) with a new MLP head.

    The stock classifier is replaced by
    Dropout(0.5) -> Linear(in_features, 512) -> ReLU -> Dropout(0.3) -> Linear(512, num_classes).
    ``in_features`` is read from the original head's linear layer.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT)
        head_in = self.backbone.classifier[1].in_features
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(head_in, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the image batch ``x``."""
        return self.backbone(x)

def calculate_metrics(labels, preds, probs, class_idx, num_classes, n_bootstrap=None, seed=None):
    """Compute one-vs-rest diagnostic metrics and a bootstrap 95% CI for the AUC.

    Args:
        labels: True class index per sample.
        preds: Predicted class index per sample.
        probs: Per-sample class probabilities; ``np.asarray(probs)[:, class_idx]``
            must give the score of ``class_idx`` for every sample.
        class_idx: Class treated as positive; every other class is negative.
        num_classes: Unused; kept so existing call sites keep working.
        n_bootstrap: Number of bootstrap resamples; defaults to ``Config.n_bootstrap``.
        seed: Seed for the bootstrap RNG; defaults to ``Config.seed``.

    Returns:
        dict with keys ``sensitivity``, ``specificity``, ``ppv``, ``npv``,
        ``lr_pos``, ``lr_neg``, ``auc``, ``auc_ci_lower`` and ``auc_ci_upper``.
        Likelihood ratios are ``inf`` when their denominator is zero. If the
        true labels are all positive or all negative for ``class_idx``, every
        value is 0.0 because ROC-based metrics are undefined.
    """
    if n_bootstrap is None:
        n_bootstrap = Config.n_bootstrap
    if seed is None:
        seed = Config.seed

    # One-vs-rest binarisation.
    binary_labels = (np.asarray(labels) == class_idx).astype(int)
    binary_preds = (np.asarray(preds) == class_idx).astype(int)

    # With a single binary class in the ground truth nothing below is defined.
    if len(np.unique(binary_labels)) < 2:
        return {
            'sensitivity': 0.0,
            'specificity': 0.0,
            'ppv': 0.0,
            'npv': 0.0,
            'lr_pos': 0.0,
            'lr_neg': 0.0,
            'auc': 0.0,
            'auc_ci_lower': 0.0,
            'auc_ci_upper': 0.0
        }

    binary_probs = np.asarray(probs)[:, class_idx]

    # labels=[0, 1] pins the matrix to 2x2 regardless of which classes appear,
    # so it always unpacks as tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(binary_labels, binary_preds, labels=[0, 1]).ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0
    lr_pos = sensitivity / (1 - specificity) if (1 - specificity) > 0 else float('inf')
    lr_neg = (1 - sensitivity) / specificity if specificity > 0 else float('inf')

    try:
        auc_score = roc_auc_score(binary_labels, binary_probs)

        # Percentile bootstrap: resample sample indices with replacement and
        # recompute the AUC, skipping resamples that contain only one class.
        rng = np.random.RandomState(seed)
        n_samples = len(binary_probs)
        bootstrapped_scores = []
        for _ in range(n_bootstrap):
            indices = rng.randint(0, n_samples, n_samples)
            if len(np.unique(binary_labels[indices])) < 2:
                continue
            bootstrapped_scores.append(roc_auc_score(binary_labels[indices], binary_probs[indices]))

        if bootstrapped_scores:
            sorted_scores = np.sort(bootstrapped_scores)
            confidence_lower = sorted_scores[int(0.025 * len(sorted_scores))]
            confidence_upper = sorted_scores[int(0.975 * len(sorted_scores))]
        else:
            confidence_lower = confidence_upper = auc_score
    except ValueError as e:
        # roc_auc_score signals invalid input (e.g. NaN scores) with ValueError;
        # anything else is a real bug and is allowed to propagate.
        print(f"Error calculating AUC for class {class_idx}: {str(e)}")
        auc_score = 0.0
        confidence_lower = 0.0
        confidence_upper = 0.0

    return {
        'sensitivity': sensitivity,
        'specificity': specificity,
        'ppv': ppv,
        'npv': npv,
        'lr_pos': lr_pos,
        'lr_neg': lr_neg,
        'auc': auc_score,
        'auc_ci_lower': confidence_lower,
        'auc_ci_upper': confidence_upper
    }

def evaluate_model_full(model, data_loader, dataset_name, device, num_classes):
    """Run inference over ``data_loader``, print per-class metrics and save a ROC plot.

    Prints overall accuracy, the confusion matrix and, for every class
    (one-vs-rest), the metrics returned by ``calculate_metrics``. The ROC curves
    are written to ``{Config.output_dir}/results/{dataset_name.lower()}_roc_curve.tif``.

    Returns:
        ``(metrics, accuracy, cm, labels, probs)``. ``metrics`` maps class index
        to its metric dict. ``cm`` is the ``num_classes`` x ``num_classes``
        confusion matrix. ``labels`` and ``probs`` are NumPy arrays of true
        labels and softmax probabilities.
    """
    model.eval()

    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for batch_x, batch_y in tqdm(data_loader, desc=f"Evaluating {dataset_name}"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)

            prob = torch.softmax(logits, dim=1)
            pred = torch.max(logits, 1)[1]

            y_true.extend(batch_y.cpu().numpy())
            y_pred.extend(pred.cpu().numpy())
            y_prob.extend(prob.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))

    print(f"\n{dataset_name} Accuracy: {accuracy:.4f}")
    print(f"{dataset_name} Confusion Matrix:")
    print(cm)

    metrics = {}
    for cls in range(num_classes):
        m = calculate_metrics(y_true, y_pred, y_prob, cls, num_classes)
        metrics[cls] = m

        print(f"\n{dataset_name} Class {cls} Metrics:")
        print(f"Sensitivity (Recall): {m['sensitivity']:.4f}")
        print(f"Specificity: {m['specificity']:.4f}")
        print(f"PPV (Precision): {m['ppv']:.4f}")
        print(f"NPV: {m['npv']:.4f}")
        print(f"LR+: {m['lr_pos']:.4f}")
        print(f"LR-: {m['lr_neg']:.4f}")
        print(f"AUC: {m['auc']:.4f} (95% CI: {m['auc_ci_lower']:.4f}-{m['auc_ci_upper']:.4f})")

    # One-vs-rest ROC curve per class on a single figure.
    labels_arr = np.array(y_true)
    probs_arr = np.array(y_prob)
    palette = ['blue', 'red', 'green', 'purple', 'orange']

    plt.figure(figsize=(10, 8))
    for cls in range(num_classes):
        is_pos = (labels_arr == cls).astype(int)
        cls_scores = probs_arr[:, cls]

        # roc_curve needs both positives and negatives.
        if len(np.unique(is_pos)) < 2:
            print(f"Skipping ROC for class {cls} - only one class present")
            continue

        fpr, tpr, _ = roc_curve(is_pos, cls_scores)
        plt.plot(fpr, tpr, color=palette[cls % len(palette)],
                 lw=2, label=f'Class {cls} (AUC = {auc(fpr, tpr):.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('1 - Specificity')
    plt.ylabel('Sensitivity')
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.savefig(f"{Config.output_dir}/results/{dataset_name.lower()}_roc_curve.tif", dpi=300, bbox_inches='tight')
    plt.close()

    return metrics, accuracy, cm, labels_arr, probs_arr

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = SWEClassifier(num_classes=Config.num_classes)
model = model.to(device)

# Print the model architecture.
print("Model architecture:")
print(model)

# Loss, optimizer and learning-rate schedule.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)
# mode='max': the LR is halved once the monitored metric has not increased for
# 10 epochs. The former `verbose=True` argument is deprecated for LR schedulers
# in recent PyTorch releases (it only printed LR changes). Read
# optimizer.param_groups[0]['lr'] to inspect the current LR instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=10
)

def _plot_training_curves(train_loss, val_loss, train_acc, val_acc):
    """Save per-epoch loss (top) and accuracy (bottom) curves to results/training_curve.tif."""
    plt.figure(figsize=(12, 10))

    plt.subplot(2, 1, 1)
    plt.plot(train_loss, label='Train Loss')
    plt.plot(val_loss, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.plot(train_acc, label='Train Accuracy')
    plt.plot(val_acc, label='Val Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{Config.output_dir}/results/training_curve.tif", dpi=300, bbox_inches='tight')
    plt.close()


def train_model(patience=20):
    """Train the global ``model`` with early stopping, keeping the best checkpoint.

    Each epoch trains on ``train_loader`` and evaluates on ``val_loader``. It
    then steps ``scheduler`` with the validation accuracy. Whenever validation
    accuracy improves, the epoch, model and optimizer state are saved to
    ``{Config.output_dir}/models/best_model.pth``. Training stops early after
    ``patience`` consecutive epochs without improvement, then the
    loss/accuracy curves are plotted.

    Args:
        patience: Epochs without a validation-accuracy improvement before stopping.

    Returns:
        ``(train_loss, train_acc, val_loss, val_acc)`` per-epoch history lists.
    """
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Otherwise a run whose validation accuracy never exceeds 0
    # leaves no best_model.pth for the caller to load.
    best_acc = -1.0
    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    no_improvement_epochs = 0  # consecutive epochs without a new best

    for epoch in range(Config.num_epochs):
        # ---- training ----
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        loop = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{Config.num_epochs}')
        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item() * inputs.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct / total)

        # Normalise by the samples actually drawn this epoch (decided by the
        # sampler), not by the dataset length.
        epoch_train_loss = running_loss / total
        epoch_train_acc = correct / total
        train_loss.append(epoch_train_loss)
        train_acc.append(epoch_train_acc)

        # ---- validation ----
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                val_running_loss += loss.item() * inputs.size(0)

        epoch_val_loss = val_running_loss / val_total
        epoch_val_acc = val_correct / val_total
        val_loss.append(epoch_val_loss)
        val_acc.append(epoch_val_acc)

        scheduler.step(epoch_val_acc)

        # Keep the checkpoint with the best validation accuracy.
        if epoch_val_acc > best_acc:
            best_acc = epoch_val_acc
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': epoch_val_acc,
            }, f"{Config.output_dir}/models/best_model.pth")
            no_improvement_epochs = 0
            print(f"New best model saved with val accuracy: {best_acc:.4f}")
        else:
            no_improvement_epochs += 1

        print(f"Epoch {epoch + 1}/{Config.num_epochs} | "
              f"Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f} | "
              f"Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

        if no_improvement_epochs >= patience:
            print(f"Early stopping at epoch {epoch + 1} due to no improvement in validation accuracy for {patience} epochs.")
            break

    _plot_training_curves(train_loss, val_loss, train_acc, val_acc)

    return train_loss, train_acc, val_loss, val_acc


# Train, reload the best checkpoint, evaluate train/test sets and persist the results.
if __name__ == "__main__":
    print("开始训练模型...")
    train_loss, train_acc, val_loss, val_acc = train_model()

    print("\n加载最佳模型...")
    # map_location lets a checkpoint written on GPU be restored on the current device.
    checkpoint = torch.load(f"{Config.output_dir}/models/best_model.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Evaluate the training split with the deterministic validation transform,
    # in fixed order. `train_loader` draws samples with replacement through the
    # class-balancing WeightedRandomSampler and applies random augmentation.
    # Metrics on it would neither cover each training image exactly once nor be
    # reproducible.
    train_eval_loader = DataLoader(
        SWEDataset(Config.train_data_path, Config.mask_data_path, train_files, transform=val_transform),
        batch_size=Config.batch_size, shuffle=False,
        num_workers=Config.num_workers, pin_memory=torch.cuda.is_available(),
    )

    print("\n评估训练集性能...")
    train_metrics, train_accuracy, train_cm, train_labels, train_probs = evaluate_model_full(
        model, train_eval_loader, "Training Set", device, Config.num_classes)

    print("\n评估测试集性能...")
    test_metrics, test_accuracy, test_cm, test_labels, test_probs = evaluate_model_full(
        model, test_loader, "Test Set", device, Config.num_classes)

    # Text summary. Written as UTF-8 because it contains Chinese text, which the
    # platform default encoding (e.g. cp1252 on Windows) may not be able to encode.
    with open(f"{Config.output_dir}/results/metrics_summary.txt", 'w', encoding='utf-8') as f:
        f.write("模型评估结果汇总\n")
        f.write("="*50 + "\n")
        f.write(f"训练集准确率: {train_accuracy:.4f}\n")
        f.write(f"测试集准确率: {test_accuracy:.4f}\n\n")

        f.write("训练集混淆矩阵:\n")
        f.write(f"{train_cm}\n\n")

        f.write("测试集混淆矩阵:\n")
        f.write(f"{test_cm}\n\n")

        for i in range(Config.num_classes):
            # Class index i corresponds to fibrosis stage F{i+2} (see SWEDataset.label_map).
            f.write(f"类别 {i} (F{i+2}) 指标\n")
            f.write("-"*30 + "\n")

            # Same metric block for each split; only the header differs.
            for header, split_metrics in (("训练集:\n", train_metrics), ("测试集:\n", test_metrics)):
                m = split_metrics[i]
                f.write(header)
                f.write(f"Sensitivity: {m['sensitivity']:.4f}\n")
                f.write(f"Specificity: {m['specificity']:.4f}\n")
                f.write(f"PPV: {m['ppv']:.4f}\n")
                f.write(f"NPV: {m['npv']:.4f}\n")
                f.write(f"LR+: {m['lr_pos']:.4f}\n")
                f.write(f"LR-: {m['lr_neg']:.4f}\n")
                f.write(f"AUC: {m['auc']:.4f} (95% CI: {m['auc_ci_lower']:.4f}-{m['auc_ci_upper']:.4f})\n\n")

    # One-vs-rest arrays per class (1 = this class, 0 = any other) for later ROC plotting.
    for i in range(Config.num_classes):
        train_binary_labels = (train_labels == i).astype(int)
        test_binary_labels = (test_labels == i).astype(int)

        np.save(f"{Config.output_dir}/results/train/F{i+2}_binary_true.npy", train_binary_labels)
        np.save(f"{Config.output_dir}/results/train/F{i+2}_proba.npy", train_probs[:, i])

        np.save(f"{Config.output_dir}/results/test/F{i+2}_binary_true.npy", test_binary_labels)
        np.save(f"{Config.output_dir}/results/test/F{i+2}_proba.npy", test_probs[:, i])

    print("\n模型训练和评估完成!")
    print(f"结果已保存至: {Config.output_dir}/results")

In [None]:
import os
import cv2
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.utils import resample
import matplotlib.pyplot as plt
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import Counter
from torch.utils.data.sampler import WeightedRandomSampler

class Config:
    """Paths and hyper-parameters for the spleen experiment.

    Training and test images live in separate folders; the validation set is
    carved out of the training folder (``val_size``).
    """

    # ---- data locations ----
    train_data_path: str = "C:/Users/YK/Desktop/spleen_train"  # training images
    test_data_path: str = "C:/Users/YK/Desktop/spleen_test"    # held-out test images
    mask_data_path: str = "C:/Users/YK/Desktop/spleen-mask_pngs"
    output_dir: str = "./outputs(LF-spleen)"

    # ---- reproducibility ----
    seed: int = 42

    # ---- input and batching ----
    img_size: tuple = (224, 224)
    batch_size: int = 16
    num_workers: int = 0  # kept at 0 for Windows

    # ---- optimisation ----
    num_epochs: int = 120
    lr: float = 1e-4
    weight_decay: float = 1e-5

    # ---- task and evaluation ----
    num_classes: int = 3
    val_size: float = 0.1    # validation fraction only; the test set is a separate folder
    n_bootstrap: int = 1000  # bootstrap resamples used for AUC confidence intervals

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU plus every
    CUDA device). It also forces deterministic cuDNN kernels with autotuning
    disabled and exports ``PYTHONHASHSEED``. The environment variable only
    affects interpreters started after this call (e.g. spawned worker
    processes), not the current process's hash seed.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(Config.seed)

# Output tree: checkpoints under models/, reports and per-split arrays under results/.
for _subdir in ("models", "results", "results/train", "results/test"):
    os.makedirs(f"{Config.output_dir}/{_subdir}", exist_ok=True)

class SWEDataset(Dataset):
    """Dataset of masked, contrast-enhanced images whose class comes from the filename.

    For each filename the item is built as follows:

    - The raw image ``{raw_path}/{filename}`` is read (BGR -> RGB).
    - The grayscale mask ``{mask_path}/{filename}`` is read.
    - Pixels where the mask is zero are blanked.
    - The image is Gaussian-blurred (5x5).
    - CLAHE is applied to the L channel in LAB space.
    - The result is resized to ``Config.img_size``.

    The label is the filename prefix before the first underscore (e.g.
    ``"F2_001.png"`` -> ``"F2"``), mapped through ``label_map``.

    Args:
        raw_path: Directory containing the raw images.
        mask_path: Directory containing masks stored under the same filenames.
        filenames: Image filenames, relative to both directories.
        transform: Optional albumentations-style pipeline, called as
            ``transform(image=...)`` and returning a dict with an ``'image'``
            entry. When ``None``, ToTensor plus ImageNet mean/std
            normalisation is applied.

    Raises (from ``__getitem__``):
        ValueError: unknown label prefix, or mask size differs from the image.
        FileNotFoundError: the image or its mask cannot be read.
    """

    def __init__(self, raw_path, mask_path, filenames, transform=None):
        self.raw_path = raw_path
        self.mask_path = mask_path
        self.filenames = filenames
        self.transform = transform
        # Filename prefix -> class index; adjust to match the actual labels.
        self.label_map = {'F2': 0, 'F3': 1, 'F4': 2}

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _imread(path, flags):
        """Read an image with OpenCV, raising instead of returning ``None`` on failure."""
        img = cv2.imread(path, flags)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None, which
            # would otherwise surface later as an opaque cv2.error in cvtColor.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return img

    def __getitem__(self, idx):
        filename = self.filenames[idx]

        # Resolve the label first so a bad filename fails before any image I/O.
        label_key = filename.split('_')[0]
        if label_key not in self.label_map:
            raise ValueError(f"Unknown label '{label_key}' in filename: {filename}")
        label = self.label_map[label_key]

        # Load the raw image and its mask.
        raw_img = cv2.cvtColor(
            self._imread(f"{self.raw_path}/{filename}", cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
        )
        mask = self._imread(f"{self.mask_path}/{filename}", cv2.IMREAD_GRAYSCALE)
        if mask.shape[:2] != raw_img.shape[:2]:
            raise ValueError(
                f"Mask size {mask.shape[:2]} does not match image size "
                f"{raw_img.shape[:2]} for {filename}"
            )

        # Keep only the pixels where the mask is non-zero.
        masked_img = cv2.bitwise_and(raw_img, raw_img, mask=mask)

        # Denoise, then boost local contrast on the lightness channel only
        # (CLAHE on L in LAB space leaves the a/b chroma channels untouched).
        blurred = cv2.GaussianBlur(masked_img, (5, 5), 0)
        lab = cv2.cvtColor(blurred, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        # Created per item: cv2.CLAHE objects cannot be pickled, so keeping one on
        # the dataset would break DataLoader workers that pickle the dataset.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced_img = cv2.cvtColor(cv2.merge((clahe.apply(l), a, b)), cv2.COLOR_LAB2RGB)

        resized_img = cv2.resize(enhanced_img, Config.img_size)

        if self.transform:
            final_img = self.transform(image=resized_img)['image']
        else:
            # Default conversion: HWC uint8 -> CHW float tensor, ImageNet normalisation.
            final_img = transforms.ToTensor()(resized_img)
            final_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(final_img)

        return final_img, label

# ImageNet statistics, matching the pretrained EfficientNet backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time augmentation (albumentations).
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    # hue=0.0 (was 0.2): a hue jitter of 0.2 rotates colours by up to ±0.2 of the
    # full hue circle. Assuming these SWE images use a colour map in which hue
    # encodes the measured quantity (TODO confirm with the imaging protocol), hue
    # jitter would corrupt the signal the classifier relies on. Only brightness,
    # contrast and saturation are jittered.
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.0),
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

# Validation/test pipeline: normalisation and tensor conversion only.
val_transform = A.Compose([
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])

_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
# Single source of truth for the label vocabulary: reuse the dataset's own map.
_label_map = SWEDataset(Config.train_data_path, Config.mask_data_path, []).label_map


def _list_labelled_images(directory):
    """Return image filenames in *directory* whose prefix is a known label, sorted.

    Sorting makes the order independent of the filesystem's listing order, so
    the seeded train/val split below is reproducible across machines.
    Extensions are matched case-insensitively (e.g. ``.PNG`` is accepted).
    """
    return sorted(
        f for f in os.listdir(directory)
        if f.lower().endswith(_IMAGE_EXTENSIONS) and f.split('_')[0] in _label_map
    )


# Training and test images come from separate folders.
train_files = _list_labelled_images(Config.train_data_path)
test_files = _list_labelled_images(Config.test_data_path)

if not train_files:
    raise ValueError("No valid image files found in the training data path")
if not test_files:
    raise ValueError("No valid image files found in the test data path")

# Carve a stratified validation set out of the training files.
train_labels = [f.split('_')[0] for f in train_files]
train_files, val_files = train_test_split(
    train_files, test_size=Config.val_size, random_state=Config.seed, stratify=train_labels
)

print(f"Dataset sizes: Train={len(train_files)}, Val={len(val_files)}, Test={len(test_files)}")

# Class balancing: weight every training sample by the inverse frequency of its
# class and draw len(train_files) samples per epoch with replacement.
train_labels = [_label_map[f.split('_')[0]] for f in train_files]
class_counts = Counter(train_labels)
print(f"Class distribution in training set: {class_counts}")

class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[label] for label in train_labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

# Validation images come from the training folder; test images from the test folder.
train_dataset = SWEDataset(Config.train_data_path, Config.mask_data_path, train_files, transform=train_transform)
val_dataset = SWEDataset(Config.train_data_path, Config.mask_data_path, val_files, transform=val_transform)
test_dataset = SWEDataset(Config.test_data_path, Config.mask_data_path, test_files, transform=val_transform)

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, sampler=sampler,
                          num_workers=Config.num_workers, pin_memory=_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False,
                        num_workers=Config.num_workers, pin_memory=_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False,
                         num_workers=Config.num_workers, pin_memory=_pin_memory)

class SWEClassifier(nn.Module):
    """EfficientNet-B4 (torchvision default pretrained weights) with a new MLP head.

    The stock classifier is replaced by
    Dropout(0.5) -> Linear(in_features, 512) -> ReLU -> Dropout(0.3) -> Linear(512, num_classes).
    ``in_features`` is read from the original head's linear layer.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT)
        head_in = self.backbone.classifier[1].in_features
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(head_in, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the image batch ``x``."""
        return self.backbone(x)

def calculate_metrics(labels, preds, probs, class_idx, num_classes, n_bootstrap=None, seed=None):
    """Compute one-vs-rest diagnostic metrics and a bootstrap 95% CI for the AUC.

    Args:
        labels: True class index per sample.
        preds: Predicted class index per sample.
        probs: Per-sample class probabilities; ``np.asarray(probs)[:, class_idx]``
            must give the score of ``class_idx`` for every sample.
        class_idx: Class treated as positive; every other class is negative.
        num_classes: Unused; kept so existing call sites keep working.
        n_bootstrap: Number of bootstrap resamples; defaults to ``Config.n_bootstrap``.
        seed: Seed for the bootstrap RNG; defaults to ``Config.seed``.

    Returns:
        dict with keys ``sensitivity``, ``specificity``, ``ppv``, ``npv``,
        ``lr_pos``, ``lr_neg``, ``auc``, ``auc_ci_lower`` and ``auc_ci_upper``.
        Likelihood ratios are ``inf`` when their denominator is zero. If the
        true labels are all positive or all negative for ``class_idx``, every
        value is 0.0 because ROC-based metrics are undefined.
    """
    if n_bootstrap is None:
        n_bootstrap = Config.n_bootstrap
    if seed is None:
        seed = Config.seed

    # One-vs-rest binarisation.
    binary_labels = (np.asarray(labels) == class_idx).astype(int)
    binary_preds = (np.asarray(preds) == class_idx).astype(int)

    # With a single binary class in the ground truth nothing below is defined.
    if len(np.unique(binary_labels)) < 2:
        return {
            'sensitivity': 0.0,
            'specificity': 0.0,
            'ppv': 0.0,
            'npv': 0.0,
            'lr_pos': 0.0,
            'lr_neg': 0.0,
            'auc': 0.0,
            'auc_ci_lower': 0.0,
            'auc_ci_upper': 0.0
        }

    binary_probs = np.asarray(probs)[:, class_idx]

    # labels=[0, 1] pins the matrix to 2x2 regardless of which classes appear,
    # so it always unpacks as tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(binary_labels, binary_preds, labels=[0, 1]).ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0
    lr_pos = sensitivity / (1 - specificity) if (1 - specificity) > 0 else float('inf')
    lr_neg = (1 - sensitivity) / specificity if specificity > 0 else float('inf')

    try:
        auc_score = roc_auc_score(binary_labels, binary_probs)

        # Percentile bootstrap: resample sample indices with replacement and
        # recompute the AUC, skipping resamples that contain only one class.
        rng = np.random.RandomState(seed)
        n_samples = len(binary_probs)
        bootstrapped_scores = []
        for _ in range(n_bootstrap):
            indices = rng.randint(0, n_samples, n_samples)
            if len(np.unique(binary_labels[indices])) < 2:
                continue
            bootstrapped_scores.append(roc_auc_score(binary_labels[indices], binary_probs[indices]))

        if bootstrapped_scores:
            sorted_scores = np.sort(bootstrapped_scores)
            confidence_lower = sorted_scores[int(0.025 * len(sorted_scores))]
            confidence_upper = sorted_scores[int(0.975 * len(sorted_scores))]
        else:
            confidence_lower = confidence_upper = auc_score
    except ValueError as e:
        # roc_auc_score signals invalid input (e.g. NaN scores) with ValueError;
        # anything else is a real bug and is allowed to propagate.
        print(f"Error calculating AUC for class {class_idx}: {str(e)}")
        auc_score = 0.0
        confidence_lower = 0.0
        confidence_upper = 0.0

    return {
        'sensitivity': sensitivity,
        'specificity': specificity,
        'ppv': ppv,
        'npv': npv,
        'lr_pos': lr_pos,
        'lr_neg': lr_neg,
        'auc': auc_score,
        'auc_ci_lower': confidence_lower,
        'auc_ci_upper': confidence_upper
    }

def evaluate_model_full(model, data_loader, dataset_name, device, num_classes):
    """Run inference over ``data_loader``, print per-class metrics and save a ROC plot.

    Prints overall accuracy, the confusion matrix and, for every class
    (one-vs-rest), the metrics returned by ``calculate_metrics``. The ROC curves
    are written to ``{Config.output_dir}/results/{dataset_name.lower()}_roc_curve.tif``.

    Returns:
        ``(metrics, accuracy, cm, labels, probs)``. ``metrics`` maps class index
        to its metric dict. ``cm`` is the ``num_classes`` x ``num_classes``
        confusion matrix. ``labels`` and ``probs`` are NumPy arrays of true
        labels and softmax probabilities.
    """
    model.eval()

    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for batch_x, batch_y in tqdm(data_loader, desc=f"Evaluating {dataset_name}"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)

            prob = torch.softmax(logits, dim=1)
            pred = torch.max(logits, 1)[1]

            y_true.extend(batch_y.cpu().numpy())
            y_pred.extend(pred.cpu().numpy())
            y_prob.extend(prob.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))

    print(f"\n{dataset_name} Accuracy: {accuracy:.4f}")
    print(f"{dataset_name} Confusion Matrix:")
    print(cm)

    metrics = {}
    for cls in range(num_classes):
        m = calculate_metrics(y_true, y_pred, y_prob, cls, num_classes)
        metrics[cls] = m

        print(f"\n{dataset_name} Class {cls} Metrics:")
        print(f"Sensitivity (Recall): {m['sensitivity']:.4f}")
        print(f"Specificity: {m['specificity']:.4f}")
        print(f"PPV (Precision): {m['ppv']:.4f}")
        print(f"NPV: {m['npv']:.4f}")
        print(f"LR+: {m['lr_pos']:.4f}")
        print(f"LR-: {m['lr_neg']:.4f}")
        print(f"AUC: {m['auc']:.4f} (95% CI: {m['auc_ci_lower']:.4f}-{m['auc_ci_upper']:.4f})")

    # One-vs-rest ROC curve per class on a single figure.
    labels_arr = np.array(y_true)
    probs_arr = np.array(y_prob)
    palette = ['blue', 'red', 'green', 'purple', 'orange']

    plt.figure(figsize=(10, 8))
    for cls in range(num_classes):
        is_pos = (labels_arr == cls).astype(int)
        cls_scores = probs_arr[:, cls]

        # roc_curve needs both positives and negatives.
        if len(np.unique(is_pos)) < 2:
            print(f"Skipping ROC for class {cls} - only one class present")
            continue

        fpr, tpr, _ = roc_curve(is_pos, cls_scores)
        plt.plot(fpr, tpr, color=palette[cls % len(palette)],
                 lw=2, label=f'Class {cls} (AUC = {auc(fpr, tpr):.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('1 - Specificity')
    plt.ylabel('Sensitivity')
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.savefig(f"{Config.output_dir}/results/{dataset_name.lower()}_roc_curve.tif", dpi=300, bbox_inches='tight')
    plt.close()

    return metrics, accuracy, cm, labels_arr, probs_arr

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = SWEClassifier(num_classes=Config.num_classes)
model = model.to(device)

# Print the model architecture.
print("Model architecture:")
print(model)

# Loss, optimizer and learning-rate schedule.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)
# mode='max': the LR is halved once the monitored metric has not increased for
# 10 epochs. The former `verbose=True` argument is deprecated for LR schedulers
# in recent PyTorch releases (it only printed LR changes). Read
# optimizer.param_groups[0]['lr'] to inspect the current LR instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=10
)

# Training loop for the single-input model of this cell.
def train_model():
    """Train the module-level ``model`` and keep the checkpoint with the best validation accuracy.

    Uses the module-level ``model``, ``criterion``, ``optimizer``, ``scheduler``,
    ``train_loader``, ``val_loader``, ``device`` and ``Config``. Each epoch does one
    optimisation pass over ``train_loader``, evaluates on ``val_loader``, steps the
    plateau scheduler with validation accuracy, and saves
    ``<output_dir>/models/best_model.pth`` whenever validation accuracy improves.
    Training stops early after ``patience`` epochs without improvement; the loss and
    accuracy curves are then written to ``<output_dir>/results/training_curve.tif``.

    Returns:
        (train_loss, train_acc, val_loss, val_acc): per-epoch lists of floats.
    """
    # Start below any reachable accuracy so the first epoch always writes a checkpoint.
    # With 0.0, a run whose validation accuracy never rises above 0 saves nothing, and the
    # later torch.load would fail or silently pick up a stale file from an earlier run.
    best_acc = -1.0
    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    patience = 20  # early-stopping patience (epochs without a val-accuracy improvement)
    no_improvement_epochs = 0  # consecutive epochs without improvement

    for epoch in range(Config.num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        loop = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{Config.num_epochs}')
        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # CrossEntropyLoss() defaults to the batch mean; weight by batch size to get an
            # exact per-sample epoch mean.
            running_loss += loss.item() * inputs.size(0)

            # Update the progress bar.
            loop.set_postfix(loss=loss.item(), acc=correct/total)

        # Normalise by the number of samples actually drawn (a sampler may draw a different
        # count than len(dataset)).
        epoch_train_loss = running_loss / total
        epoch_train_acc = correct / total
        train_loss.append(epoch_train_loss)
        train_acc.append(epoch_train_acc)

        # ---- validation pass ----
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                val_running_loss += loss.item() * inputs.size(0)

        epoch_val_loss = val_running_loss / val_total
        epoch_val_acc = val_correct / val_total
        val_loss.append(epoch_val_loss)
        val_acc.append(epoch_val_acc)

        # The scheduler was built with mode='max', so it is fed validation accuracy.
        scheduler.step(epoch_val_acc)

        # Keep the checkpoint with the best validation accuracy.
        if epoch_val_acc > best_acc:
            best_acc = epoch_val_acc
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': epoch_val_acc,
            }, f"{Config.output_dir}/models/best_model.pth")
            no_improvement_epochs = 0
            print(f"New best model saved with val accuracy: {best_acc:.4f}")
        else:
            no_improvement_epochs += 1

        print(f"Epoch {epoch + 1}/{Config.num_epochs} | "
              f"Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f} | "
              f"Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

        # Early stopping.
        if no_improvement_epochs >= patience:
            print(f"Early stopping at epoch {epoch + 1} due to no improvement in validation accuracy for {patience} epochs.")
            break

    # ---- training curves ----
    fig = plt.figure(figsize=(12, 10))
    try:
        plt.subplot(2, 1, 1)
        plt.plot(train_loss, label='Train Loss')
        plt.plot(val_loss, label='Val Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.subplot(2, 1, 2)
        plt.plot(train_acc, label='Train Accuracy')
        plt.plot(val_acc, label='Val Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(f"{Config.output_dir}/results/training_curve.tif", dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails (figures otherwise accumulate in a notebook).
        plt.close(fig)

    return train_loss, train_acc, val_loss, val_acc


# Entry point: train, reload the best checkpoint, evaluate on train/test and save results.
if __name__ == "__main__":
    import copy  # used to clone the training dataset for non-augmented evaluation

    print("开始训练模型...")
    train_loss, train_acc, val_loss, val_acc = train_model()

    print("\n加载最佳模型...")
    # map_location keeps loading working when the active device differs from the saving one.
    checkpoint = torch.load(f"{Config.output_dir}/models/best_model.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Evaluate the training split on a sequential, non-augmented loader instead of train_loader
    # itself. If train_loader draws with a (weighted, with-replacement) sampler and applies
    # training-time augmentation, the reported metrics and the saved train arrays would be
    # computed on duplicated/missing, randomly perturbed samples.
    # NOTE(review): assumes test_loader.dataset.transform is the deterministic eval transform — confirm.
    train_eval_dataset = copy.copy(train_loader.dataset)
    if hasattr(train_eval_dataset, 'transform') and hasattr(test_loader.dataset, 'transform'):
        train_eval_dataset.transform = test_loader.dataset.transform
    train_eval_loader = DataLoader(train_eval_dataset, batch_size=Config.batch_size,
                                   shuffle=False, num_workers=Config.num_workers)

    print("\n评估训练集性能...")
    train_metrics, train_accuracy, train_cm, train_labels, train_probs = evaluate_model_full(
        model, train_eval_loader, "Training Set", device, Config.num_classes)

    print("\n评估测试集性能...")
    test_metrics, test_accuracy, test_cm, test_labels, test_probs = evaluate_model_full(
        model, test_loader, "Test Set", device, Config.num_classes)

    # Write a human-readable summary. The text contains Chinese, so use UTF-8 explicitly rather
    # than the platform's locale encoding.
    with open(f"{Config.output_dir}/results/metrics_summary.txt", 'w', encoding='utf-8') as f:
        f.write("模型评估结果汇总\n")
        f.write("="*50 + "\n")
        f.write(f"训练集准确率: {train_accuracy:.4f}\n")
        f.write(f"测试集准确率: {test_accuracy:.4f}\n\n")

        f.write("训练集混淆矩阵:\n")
        f.write(f"{train_cm}\n\n")

        f.write("测试集混淆矩阵:\n")
        f.write(f"{test_cm}\n\n")

        for i in range(Config.num_classes):
            f.write(f"类别 {i} (F{i+2}) 指标\n")
            f.write("-"*30 + "\n")

            # Same block of per-class metrics for the training split, then the test split.
            for split_title, split_metrics in (("训练集:", train_metrics), ("测试集:", test_metrics)):
                m = split_metrics[i]
                f.write(f"{split_title}\n")
                f.write(f"Sensitivity: {m['sensitivity']:.4f}\n")
                f.write(f"Specificity: {m['specificity']:.4f}\n")
                f.write(f"PPV: {m['ppv']:.4f}\n")
                f.write(f"NPV: {m['npv']:.4f}\n")
                f.write(f"LR+: {m['lr_pos']:.4f}\n")
                f.write(f"LR-: {m['lr_neg']:.4f}\n")
                f.write(f"AUC: {m['auc']:.4f} (95% CI: {m['auc_ci_lower']:.4f}-{m['auc_ci_upper']:.4f})\n\n")

    # Save one-vs-rest arrays per class (1 = this class, 0 = any other) for external ROC plotting.
    for i in range(Config.num_classes):
        train_binary_labels = (train_labels == i).astype(int)
        test_binary_labels = (test_labels == i).astype(int)

        np.save(f"{Config.output_dir}/results/train/F{i+2}_binary_true.npy", train_binary_labels)
        np.save(f"{Config.output_dir}/results/train/F{i+2}_proba.npy", train_probs[:, i])

        np.save(f"{Config.output_dir}/results/test/F{i+2}_binary_true.npy", test_binary_labels)
        np.save(f"{Config.output_dir}/results/test/F{i+2}_proba.npy", test_probs[:, i])

    print("\n模型训练和评估完成!")
    print(f"结果已保存至: {Config.output_dir}/results")

In [None]:
import os
import cv2
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.utils import resample
import matplotlib.pyplot as plt
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import Counter
from torch.utils.data.sampler import WeightedRandomSampler
from torch.nn import functional as F

class Config:
    """Paths and hyper-parameters for the combined liver + spleen classifier."""

    # ---- training-set image and mask directories ----
    train_liver_raw_data_path = "C:/Users/YK/Desktop/train_liver"
    train_liver_mask_data_path = "C:/Users/YK/Desktop/train_liver-mask_pngs"
    train_spleen_raw_data_path = "C:/Users/YK/Desktop/train_spleen"
    train_spleen_mask_data_path = "C:/Users/YK/Desktop/train_spleen-mask_pngs"

    # ---- test-set image and mask directories ----
    test_liver_raw_data_path = "C:/Users/YK/Desktop/test_liver"
    test_liver_mask_data_path = "C:/Users/YK/Desktop/test_liver-mask_pngs"
    test_spleen_raw_data_path = "C:/Users/YK/Desktop/test_spleen"
    test_spleen_mask_data_path = "C:/Users/YK/Desktop/test_spleen-mask_pngs"

    # ---- outputs and reproducibility ----
    output_dir = "./outputs(LF-combine)"
    seed = 42

    # ---- data loading ----
    img_size = (224, 224)
    batch_size = 16
    num_workers = 0  # 0 is the safe choice on Windows

    # ---- optimisation ----
    num_epochs = 120
    lr = 1e-4
    weight_decay = 1e-5

    # ---- task and evaluation ----
    num_classes = 3
    val_size = 0.1  # fraction of the training files held out for validation
    n_bootstrap = 1000  # bootstrap resamples for the AUC confidence interval

# Seed every random number generator used here so runs are repeatable.
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs; force deterministic cuDNN.

    Note: assigning ``PYTHONHASHSEED`` at runtime only affects child processes started
    afterwards; the current interpreter's string-hash seed is fixed at start-up.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Give up cuDNN autotuning in exchange for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(Config.seed)

# Build the output directory tree; exist_ok makes re-running the cell harmless.
for _subdir in ("models", "results", "results/train", "results/test", "grad_cam"):
    os.makedirs(f"{Config.output_dir}/{_subdir}", exist_ok=True)

# Paired liver/spleen image dataset.
class SWEDataset(Dataset):
    """Dataset yielding a liver image, a spleen image and a fibrosis-stage label per filename.

    The same ``filename`` is looked up in both organ directories. Each image is masked
    with its organ mask, Gaussian-blurred, CLAHE-enhanced on the L channel in LAB space,
    resized to ``Config.img_size`` and converted to a tensor. Conversion uses
    ``transform`` when given (called albumentations-style as ``transform(image=img)['image']``);
    otherwise it uses ToTensor plus ImageNet mean/std normalisation. An organ whose raw
    image cannot be read is replaced by an all-zero image and reported through the
    ``*_exists`` flags.

    Items are ``(liver_img, spleen_img, label, liver_exists, spleen_exists)``.

    Note: ``transform`` is called separately for the two organs, so random augmentations
    are drawn independently for the liver and spleen images.
    """

    def __init__(self, liver_raw_path, liver_mask_path, spleen_raw_path, spleen_mask_path, filenames, transform=None):
        self.liver_raw_path = liver_raw_path
        self.liver_mask_path = liver_mask_path
        self.spleen_raw_path = spleen_raw_path
        self.spleen_mask_path = spleen_mask_path
        self.filenames = filenames
        self.transform = transform
        self.label_map = {'F2': 0, 'F3': 1, 'F4': 2}  # adjust to the actual label prefixes

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _load_masked_image(raw_dir, mask_dir, filename):
        """Load and preprocess one organ image.

        Returns:
            (image, exists): an RGB uint8 array resized to ``Config.img_size``, and whether
            the raw image could be read. An unreadable raw image yields zeros and ``False``.
        """
        raw_img = cv2.imread(f"{raw_dir}/{filename}")
        if raw_img is None:
            # cv2.resize takes dsize=(width, height) and returns (height, width, 3); build the
            # placeholder with the same shape so batches still collate for non-square sizes.
            width, height = Config.img_size
            return np.zeros((height, width, 3), dtype=np.uint8), False

        raw_img = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(f"{mask_dir}/{filename}", cv2.IMREAD_GRAYSCALE)
        if mask is None:
            import warnings
            # Previously this fell through silently; with mask=None bitwise_and leaves the image unmasked.
            warnings.warn(f"Mask not found for {filename} in {mask_dir}; using the unmasked image.")
        masked = cv2.bitwise_and(raw_img, raw_img, mask=mask)

        # Denoise, then boost local contrast on the lightness channel only.
        blurred = cv2.GaussianBlur(masked, (5, 5), 0)
        l_chan, a_chan, b_chan = cv2.split(cv2.cvtColor(blurred, cv2.COLOR_RGB2LAB))
        # Created per call: cv2 CLAHE objects cannot be pickled into DataLoader workers.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = cv2.cvtColor(cv2.merge((clahe.apply(l_chan), a_chan, b_chan)), cv2.COLOR_LAB2RGB)

        return cv2.resize(enhanced, Config.img_size), True

    def _to_tensor(self, img):
        """Apply ``self.transform`` if set, else ToTensor + ImageNet normalisation."""
        if self.transform:
            return self.transform(image=img)['image']
        img = transforms.ToTensor()(img)
        return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)

    def __getitem__(self, idx):
        filename = self.filenames[idx]

        liver_img, liver_exists = self._load_masked_image(
            self.liver_raw_path, self.liver_mask_path, filename)
        spleen_img, spleen_exists = self._load_masked_image(
            self.spleen_raw_path, self.spleen_mask_path, filename)

        # Liver first, then spleen: keeps the same order of random-augmentation draws.
        liver_img = self._to_tensor(liver_img)
        spleen_img = self._to_tensor(spleen_img)

        # The label is the filename prefix before the first underscore, e.g. "F3_001.png" -> "F3".
        label = self.label_map[filename.split('_')[0]]

        return liver_img, spleen_img, label, liver_exists, spleen_exists

# Data augmentation (albumentations).
# Training pipeline: random flips / 90-degree rotations, photometric jitter, then ImageNet
# mean/std normalisation and conversion to a PyTorch tensor.
# NOTE(review): if the inputs are colour-coded SWE stiffness maps, `hue=0.2` in ColorJitter
# shifts the colour scale itself and may corrupt the diagnostic signal; brightness/contrast
# are also jittered twice (RandomBrightnessContrast + ColorJitter). Confirm this is intended.
train_transform = A.Compose(transforms=[
    # geometric augmentations
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    # photometric augmentations
    A.RandomBrightnessContrast(p=0.2),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    # ImageNet statistics (same values as SWEDataset's no-transform fallback)
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Evaluation pipeline: deterministic, normalisation and tensor conversion only.
val_transform = A.Compose(transforms=[
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Collect labelled image filenames, split train/validation and build the data loaders.
def _list_labeled_images(directory, valid_labels=('F2', 'F3', 'F4')):
    """Return the sorted image filenames in *directory* whose label prefix is in *valid_labels*.

    The label prefix is the text before the first underscore (e.g. "F3_001.png" -> "F3").
    Sorting makes the seeded train/validation split reproducible, because os.listdir order
    depends on the filesystem. The extension check is case-insensitive, so files such as
    "F3_001.PNG" are not silently dropped.
    """
    return sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
        and name.split('_')[0] in valid_labels
    )


# NOTE(review): samples are enumerated from the *liver* directories only, so a case with a
# spleen image but no liver image never enters the dataset (SWEDataset's missing-liver branch
# is then only reached for unreadable liver files). Confirm this is intended.
train_files = _list_labeled_images(Config.train_liver_raw_data_path)
test_files = _list_labeled_images(Config.test_liver_raw_data_path)

if not train_files:
    raise ValueError("No valid image files found in the training liver raw data path")

if not test_files:
    raise ValueError("No valid image files found in the test liver raw data path")

# Hold out a stratified validation split from the training files.
train_labels = [f.split('_')[0] for f in train_files]
train_files, val_files = train_test_split(train_files, test_size=Config.val_size,
                                          random_state=Config.seed, stratify=train_labels)

print(f"Dataset sizes: Train={len(train_files)}, Val={len(val_files)}, Test={len(test_files)}")

# Datasets: train and validation share the training directories; test has its own.
train_dataset = SWEDataset(Config.train_liver_raw_data_path, Config.train_liver_mask_data_path,
                           Config.train_spleen_raw_data_path, Config.train_spleen_mask_data_path,
                           train_files, transform=train_transform)
val_dataset = SWEDataset(Config.train_liver_raw_data_path, Config.train_liver_mask_data_path,
                         Config.train_spleen_raw_data_path, Config.train_spleen_mask_data_path,
                         val_files, transform=val_transform)
test_dataset = SWEDataset(Config.test_liver_raw_data_path, Config.test_liver_mask_data_path,
                          Config.test_spleen_raw_data_path, Config.test_spleen_mask_data_path,
                          test_files, transform=val_transform)

# Class imbalance: sample each training item with probability inversely proportional to its
# class frequency (with replacement, one epoch = len(train_files) draws).
train_labels = [train_dataset.label_map[f.split('_')[0]] for f in train_files]
class_counts = Counter(train_labels)
print(f"Class distribution in training set: {class_counts}")

class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[label] for label in train_labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

# Pinned memory only speeds up host->GPU copies; PyTorch warns when it is requested without one.
_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, sampler=sampler,
                          num_workers=Config.num_workers, pin_memory=_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False,
                        num_workers=Config.num_workers, pin_memory=_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False,
                         num_workers=Config.num_workers, pin_memory=_pin_memory)

# Squeeze-and-Excitation block.
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools each channel, passes the channel descriptor through a bottleneck
    MLP (channel -> channel // reduction -> channel) ending in a sigmoid, and rescales the
    input channels by the resulting weights in [0, 1]. Expects a 4-D ``(N, C, H, W)`` input.
    The submodule names ``avg_pool`` and ``fc`` are part of the state_dict keys; keep them
    so that saved checkpoints still load.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)

# Model: two EfficientNet-B4 feature branches (liver, spleen) fused by an MLP head.
class SWEClassifier(nn.Module):
    """Dual-branch classifier over a liver image and a spleen image.

    Each branch runs EfficientNet-B4 ``features``, global average pooling and an SE gate,
    producing a 1792-d embedding. An organ's embedding is multiplied by its ``*_exists``
    flag (zeroed when absent); the two embeddings are then concatenated and classified by a
    dropout MLP into ``num_classes`` logits.

    NOTE(review): ``liver_backbone``/``spleen_backbone`` keep the whole torchvision model,
    so its unused classification head stays registered (and in the state_dict). Do not remove
    it without migrating existing checkpoints.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # Liver branch.
        self.liver_backbone = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT)
        # Use the complete convolutional feature extractor (including its last stage).
        self.liver_features = self.liver_backbone.features
        self.liver_avgpool = nn.AdaptiveAvgPool2d(1)
        self.liver_se = SELayer(1792)  # EfficientNet-B4's final feature map has 1792 channels

        # Spleen branch.
        self.spleen_backbone = models.efficientnet_b4(weights=models.EfficientNet_B4_Weights.DEFAULT)
        self.spleen_features = self.spleen_backbone.features
        self.spleen_avgpool = nn.AdaptiveAvgPool2d(1)
        self.spleen_se = SELayer(1792)

        # Fusion head.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(1792 * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @staticmethod
    def _encode(features, pool, se, x):
        """Backbone features -> global pool -> SE gate -> flattened (N, C) embedding."""
        return torch.flatten(se(pool(features(x))), 1)

    def forward(self, liver_x, spleen_x, liver_exists, spleen_exists):
        liver_emb = self._encode(self.liver_features, self.liver_avgpool, self.liver_se, liver_x)
        spleen_emb = self._encode(self.spleen_features, self.spleen_avgpool, self.spleen_se, spleen_x)

        # Per-sample 0/1 weights: an organ flagged as missing contributes a zero embedding.
        liver_emb = liver_emb * liver_exists.float().unsqueeze(1).to(liver_emb.device)
        spleen_emb = spleen_emb * spleen_exists.float().unsqueeze(1).to(spleen_emb.device)

        return self.classifier(torch.cat((liver_emb, spleen_emb), dim=1))

# 计算评估指标
def calculate_metrics(labels, preds, probs, class_idx, num_classes):
    # 转换为二分类问题
    binary_labels = (np.array(labels) == class_idx).astype(int)
    binary_preds = (np.array(preds) == class_idx).astype(int)
    
    # 如果所有样本都属于同一类别，则无法计算ROC
    if len(np.unique(binary_labels)) < 2:
        return {
            'sensitivity': 0.0,
            'specificity': 0.0,
            'ppv': 0.0,
            'npv': 0.0,
            'lr_pos': 0.0,
            'lr_neg': 0.0,
            'auc': 0.0,
            'auc_ci_lower': 0.0,
            'auc_ci_upper': 0.0
        }
    
    binary_probs = np.array(probs)[:, class_idx]
    
    # 计算混淆矩阵元素
    cm = confusion_matrix(binary_labels, binary_preds)
    if cm.size == 1:  # 只有一个类别的情况
        if binary_labels.sum() == 0:  # 所有样本都是负类
            tn = cm[0, 0]
            fp = 0
            fn = 0
            tp = 0
        else:  # 所有样本都是正类
            tn = 0
            fp = 0
            fn = 0
            tp = cm[0, 0]
    else:
        tn, fp, fn, tp = cm.ravel()
    
    # 计算各项指标
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0
    lr_pos = sensitivity / (1 - specificity) if (1 - specificity) > 0 else float('inf')
    lr_neg = (1 - sensitivity) / specificity if specificity > 0 else float('inf')
    
    # 计算AUC及其95%置信区间
    try:
        auc_score = roc_auc_score(binary_labels, binary_probs)
        
        # 使用bootstrap方法计算95%置信区间
        n_bootstraps = Config.n_bootstrap
        bootstrapped_scores = []
        
        rng = np.random.RandomState(Config.seed)
        for i in range(n_bootstraps):
            # 采样索引
            indices = rng.randint(0, len(binary_probs), len(binary_probs))
            
            # 检查引导样本中是否有两个类别
            if len(np.unique(binary_labels[indices])) < 2:
                continue
            
            score = roc_auc_score(binary_labels[indices], binary_probs[indices])
            bootstrapped_scores.append(score)
        
        if bootstrapped_scores:
            sorted_scores = np.array(bootstrapped_scores)
            sorted_scores.sort()
            confidence_lower = sorted_scores[int(0.025 * len(sorted_scores))]
            confidence_upper = sorted_scores[int(0.975 * len(sorted_scores))]
        else:
            confidence_lower = confidence_upper = auc_score
    except Exception as e:
        print(f"Error calculating AUC for class {class_idx}: {str(e)}")
        auc_score = 0.0
        confidence_lower = 0.0
        confidence_upper = 0.0
    
    return {
        'sensitivity': sensitivity,
        'specificity': specificity,
        'ppv': ppv,
        'npv': npv,
        'lr_pos': lr_pos,
        'lr_neg': lr_neg,
        'auc': auc_score,
        'auc_ci_lower': confidence_lower,
        'auc_ci_upper': confidence_upper
    }

# Evaluate a model on one split and report all metrics.
def _collect_predictions(model, data_loader, dataset_name, device):
    """Run *model* over *data_loader* in eval mode without gradients.

    Returns (y_true, y_pred, y_prob) as NumPy arrays: true labels, argmax predictions
    and softmax probabilities (one row per sample).
    """
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for liver_inputs, spleen_inputs, labels, liver_exists, spleen_exists in tqdm(
                data_loader, desc=f"Evaluating {dataset_name}"):
            outputs = model(liver_inputs.to(device), spleen_inputs.to(device),
                            liver_exists.to(device), spleen_exists.to(device))
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(outputs.argmax(dim=1).cpu().numpy())
            y_prob.extend(torch.softmax(outputs, dim=1).cpu().numpy())
    return np.array(y_true), np.array(y_pred), np.array(y_prob)


def _print_class_metrics(dataset_name, class_idx, class_metrics):
    """Print the per-class metric block returned by calculate_metrics."""
    print(f"\n{dataset_name} Class {class_idx} Metrics:")
    print(f"Sensitivity (Recall): {class_metrics['sensitivity']:.4f}")
    print(f"Specificity: {class_metrics['specificity']:.4f}")
    print(f"PPV (Precision): {class_metrics['ppv']:.4f}")
    print(f"NPV: {class_metrics['npv']:.4f}")
    print(f"LR+: {class_metrics['lr_pos']:.4f}")
    print(f"LR-: {class_metrics['lr_neg']:.4f}")
    print(f"AUC: {class_metrics['auc']:.4f} (95% CI: {class_metrics['auc_ci_lower']:.4f}-{class_metrics['auc_ci_upper']:.4f})")


def _plot_roc_curves(y_true, y_prob, dataset_name, num_classes):
    """Save one-vs-rest ROC curves to ``<output_dir>/results/<dataset_name lowercased>_roc_curve.tif``.

    Classes that are all-positive or all-negative in *y_true* are skipped (ROC undefined).
    """
    colors = ['blue', 'red', 'green', 'purple', 'orange']
    fig = plt.figure(figsize=(10, 8))
    try:
        for i in range(num_classes):
            binary_labels = (y_true == i).astype(int)
            if len(np.unique(binary_labels)) < 2:
                print(f"Skipping ROC for class {i} - only one class present")
                continue
            fpr, tpr, _ = roc_curve(binary_labels, y_prob[:, i])
            plt.plot(fpr, tpr, color=colors[i % len(colors)],
                     lw=2, label=f'Class {i}')

        plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance diagonal
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('1 - Specificity')
        plt.ylabel('Sensitivity')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.savefig(f"{Config.output_dir}/results/{dataset_name.lower()}_roc_curve.tif", dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def evaluate_model_full(model, data_loader, dataset_name, device, num_classes):
    """Evaluate *model* on *data_loader* and report overall and per-class metrics.

    Prints accuracy, the confusion matrix and per-class one-vs-rest metrics (see
    ``calculate_metrics``), and saves ROC curves for the split.

    Returns:
        (metrics, accuracy, cm, labels, probs): ``metrics`` maps class index to its metric
        dict; ``cm`` is the ``num_classes x num_classes`` confusion matrix; ``labels`` and
        ``probs`` are the true labels and softmax probabilities as NumPy arrays.
    """
    y_true, y_pred, y_prob = _collect_predictions(model, data_loader, dataset_name, device)

    accuracy = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))

    print(f"\n{dataset_name} Accuracy: {accuracy:.4f}")
    print(f"{dataset_name} Confusion Matrix:")
    print(cm)

    metrics = {}
    for i in range(num_classes):
        metrics[i] = calculate_metrics(y_true, y_pred, y_prob, i, num_classes)
        _print_class_metrics(dataset_name, i, metrics[i])

    _plot_roc_curves(y_true, y_prob, dataset_name, num_classes)

    return metrics, accuracy, cm, y_true, y_prob

# Select the compute device (CUDA when available).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = SWEClassifier(num_classes=Config.num_classes)
model = model.to(device)

# Print the model structure.
print("Model architecture:")
print(model)

# Loss, optimiser and LR schedule.
# mode='max': the monitored metric passed to scheduler.step() should increase (e.g. accuracy);
# the LR is halved after 10 epochs without improvement.
# `verbose` is omitted: it is deprecated for LR schedulers since PyTorch 2.2. Use
# optimizer.param_groups[0]['lr'] (or scheduler.get_last_lr()) to monitor the LR instead.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=10
)

# Training helpers and training loop.
def _run_epoch(loader, train, desc=None):
    """Run one pass of the module-level `model` over `loader`.

    Args:
        loader: DataLoader yielding (liver_inputs, spleen_inputs, labels,
            liver_exists, spleen_exists) batches; every element is moved to
            the module-level `device`.
        train: if True, put the model in train mode and take one optimizer
            step per batch; otherwise use eval mode with gradients disabled.
        desc: optional tqdm description; when given, a progress bar shows the
            latest batch loss and the running accuracy.

    Returns:
        (mean loss per sample, accuracy), both over the samples actually seen.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.train(train)  # model.train(False) is the same as model.eval()
    running_loss, correct, total = 0.0, 0, 0
    progress = tqdm(loader, desc=desc) if desc is not None else None

    with torch.set_grad_enabled(train):
        for batch in (progress if progress is not None else loader):
            liver_inputs, spleen_inputs, labels, liver_exists, spleen_exists = (
                t.to(device) for t in batch
            )
            if train:
                optimizer.zero_grad()
            outputs = model(liver_inputs, spleen_inputs, liver_exists, spleen_exists)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            batch_loss = loss.item()
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            # criterion uses mean reduction by default, so re-weight by batch size.
            running_loss += batch_loss * batch_size
            if progress is not None:
                progress.set_postfix(loss=batch_loss, acc=correct / total)

    if total == 0:
        raise ValueError("DataLoader yielded no samples; cannot compute loss/accuracy.")
    # Normalize by the samples actually seen rather than len(loader.dataset):
    # a WeightedRandomSampler or drop_last=True makes those two numbers differ.
    return running_loss / total, correct / total


def _save_training_curves(train_loss, val_loss, train_acc, val_acc):
    """Save per-epoch loss (top) and accuracy (bottom) curves to training_curve.tif."""
    plt.figure(figsize=(12, 10))

    plt.subplot(2, 1, 1)
    plt.plot(train_loss, label='Train Loss')
    plt.plot(val_loss, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.plot(train_acc, label='Train Accuracy')
    plt.plot(val_acc, label='Val Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{Config.output_dir}/results/training_curve.tif", dpi=300, bbox_inches='tight')
    plt.close()


def train_model(patience=30):
    """Train the module-level `model` with early stopping on validation accuracy.

    Uses the module-level `train_loader`, `val_loader`, `optimizer`,
    `scheduler`, `criterion` and `device`. Whenever validation accuracy
    improves, a checkpoint (epoch, model/optimizer state, val_acc) is written
    to `{Config.output_dir}/models/best_model.pth`. The first epoch always
    writes one, so the file exists once training has run at least one epoch.
    The loss and accuracy curves are saved to `results/training_curve.tif`.

    Args:
        patience: number of consecutive epochs without a validation-accuracy
            improvement after which training stops early.

    Returns:
        (train_loss, train_acc, val_loss, val_acc): per-epoch lists.
    """
    best_acc = float('-inf')  # -inf so the first epoch always saves a checkpoint
    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    no_improvement_epochs = 0  # consecutive epochs without improvement
    checkpoint_path = f"{Config.output_dir}/models/best_model.pth"

    for epoch in range(Config.num_epochs):
        epoch_train_loss, epoch_train_acc = _run_epoch(
            train_loader, train=True, desc=f'Epoch {epoch + 1}/{Config.num_epochs}'
        )
        epoch_val_loss, epoch_val_acc = _run_epoch(val_loader, train=False)

        train_loss.append(epoch_train_loss)
        train_acc.append(epoch_train_acc)
        val_loss.append(epoch_val_loss)
        val_acc.append(epoch_val_acc)

        # The scheduler was created with mode='max', so it receives accuracy.
        scheduler.step(epoch_val_acc)

        if epoch_val_acc > best_acc:
            best_acc = epoch_val_acc
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': epoch_val_acc,
            }, checkpoint_path)
            no_improvement_epochs = 0
            print(f"New best model saved with val accuracy: {best_acc:.4f}")
        else:
            no_improvement_epochs += 1

        print(f"Epoch {epoch + 1}/{Config.num_epochs} | "
              f"Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f} | "
              f"Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

        if no_improvement_epochs >= patience:
            print(f"Early stopping at epoch {epoch + 1} due to no improvement in validation accuracy for {patience} epochs.")
            break

    _save_training_curves(train_loss, val_loss, train_acc, val_acc)

    return train_loss, train_acc, val_loss, val_acc


# Run training, then evaluate the best checkpoint on the training and test sets.
if __name__ == "__main__":
    print("开始训练模型...")
    train_loss, train_acc, val_loss, val_acc = train_model()

    print("\n加载最佳模型...")
    # map_location lets the checkpoint load even if it was written on another device.
    checkpoint = torch.load(f"{Config.output_dir}/models/best_model.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # NOTE(review): this reuses train_loader; if that loader applies augmentation
    # or a resampling sampler, these "training set" metrics reflect that rather
    # than the raw training images. Confirm this is intended.
    print("\n评估训练集性能...")
    train_metrics, train_accuracy, train_cm, train_labels, train_probs = evaluate_model_full(
        model, train_loader, "Training Set", device, Config.num_classes)

    print("\n评估测试集性能...")
    test_metrics, test_accuracy, test_cm, test_labels, test_probs = evaluate_model_full(
        model, test_loader, "Test Set", device, Config.num_classes)

    # (display label, metrics-dict key) for the plain 4-decimal rows of the summary.
    summary_metric_rows = [
        ("Sensitivity", "sensitivity"),
        ("Specificity", "specificity"),
        ("PPV", "ppv"),
        ("NPV", "npv"),
        ("LR+", "lr_pos"),
        ("LR-", "lr_neg"),
    ]

    # Write the text summary. Explicit UTF-8: the file contains Chinese text, and
    # the locale default codec (e.g. cp1252) may be unable to encode it.
    with open(f"{Config.output_dir}/results/metrics_summary.txt", 'w', encoding='utf-8') as f:
        f.write("模型评估结果汇总\n")
        f.write("="*50 + "\n")
        f.write(f"训练集准确率: {train_accuracy:.4f}\n")
        f.write(f"测试集准确率: {test_accuracy:.4f}\n\n")

        f.write("训练集混淆矩阵:\n")
        f.write(f"{train_cm}\n\n")

        f.write("测试集混淆矩阵:\n")
        f.write(f"{test_cm}\n\n")

        for i in range(Config.num_classes):
            # Class index i corresponds to fibrosis stage F{i+2} (label_map F2->0, F3->1, F4->2).
            f.write(f"类别 {i} (F{i+2}) 指标\n")
            f.write("-"*30 + "\n")

            for split_title, split_metrics in (("训练集:", train_metrics), ("测试集:", test_metrics)):
                class_metrics = split_metrics[i]
                f.write(f"{split_title}\n")
                for display_name, key in summary_metric_rows:
                    f.write(f"{display_name}: {class_metrics[key]:.4f}\n")
                f.write(f"AUC: {class_metrics['auc']:.4f} (95% CI: {class_metrics['auc_ci_lower']:.4f}-{class_metrics['auc_ci_upper']:.4f})\n\n")

    # Save per-class one-vs-rest labels and probabilities (for external ROC plotting).
    for i in range(Config.num_classes):
        # 1 for samples of class i, 0 for every other class.
        train_binary_labels = (train_labels == i).astype(int)
        test_binary_labels = (test_labels == i).astype(int)

        np.save(f"{Config.output_dir}/results/train/F{i+2}_binary_true.npy", train_binary_labels)
        np.save(f"{Config.output_dir}/results/train/F{i+2}_proba.npy", train_probs[:, i])

        np.save(f"{Config.output_dir}/results/test/F{i+2}_binary_true.npy", test_binary_labels)
        np.save(f"{Config.output_dir}/results/test/F{i+2}_proba.npy", test_probs[:, i])

    print("\n模型训练和评估完成!")
    print(f"结果已保存至: {Config.output_dir}/results")
    