In [1]:
from pathlib import Path
from scipy.ndimage import zoom
from scipy.ndimage import find_objects
import torchio as tio
import os
import glob
import re
from configparser import ConfigParser
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from typing import Dict, Tuple
import matplotlib.pyplot as plt
from collections import deque
from sklearn.model_selection import KFold
import math
from dataset import ACDCDataset, PairwiseAugmentor

# 配置参数
CLASS_MAP = {'NOR':0, 'DCM':1, 'HCM':2, 'MINF':3, 'RV':4}
TARGET_SHAPE = (200, 200, 80)
TARGET_SPACING = 1.25  # mm
AUG_FACTOR = 1  


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class PrimaryCaps3D(nn.Module):
    def __init__(self, in_channels=256, caps_dim=4, kernel_size=3, stride=2):
        super().__init__()
        self.caps_dim = caps_dim
        self.out_caps = in_channels // caps_dim  # 主胶囊数量
        self.conv = nn.Conv3d(in_channels, self.out_caps * caps_dim, 
                            kernel_size, stride=stride, padding=1)
        
    def forward(self, x):
        batch_size = x.size(0)
        out = self.conv(x)  # [B, out_caps*caps_dim, D, H, W]
        
        # 重塑维度：分离胶囊维度和特征维度
        out = out.view(batch_size, self.out_caps, self.caps_dim,
                      out.size(-3), out.size(-2), out.size(-1))  # [B, out_caps, caps_dim, D, H, W]
        
        # 调整维度顺序并展平空间维度
        out = out.permute(0, 3, 4, 5, 1, 2).contiguous()  # [B, D, H, W, out_caps, caps_dim]
        spatial_flatten = out.size(1)*out.size(2)*out.size(3)  # p = D*H*W
        return out.view(batch_size, spatial_flatten, self.out_caps, self.caps_dim)  # [B, p, i, m]

class ConvCaps3D(nn.Module):
    def __init__(self, in_caps, out_caps, in_dim, out_dim, num_routing=3):
        super().__init__()
        self.num_routing = num_routing
        self.W = nn.Parameter(torch.Tensor(in_caps, out_caps, in_dim, out_dim))  # [i, j, m, n]
        nn.init.orthogonal_(self.W)

    def dynamic_routing(self, u):
        """动态路由机制
        Args:
            u: 输入胶囊 [B, p, i, m]
        Returns:
            v: 输出胶囊 [B, j, n]
        """
        batch_size, p = u.size(0), u.size(1)
        device = u.device
        
        # 计算预测向量 (公式1)
        u_hat = torch.einsum('bpim,ijmn->bpijn', u, self.W)  # [B, p, j, n]
        
        # 初始化路由logits (公式2)
        b = torch.zeros(batch_size, p, self.W.size(0), self.W.size(1)).to(device)  # [B, p, i, j]
        
        for _ in range(self.num_routing):
            # 空间位置维度独立计算耦合系数
            c = F.softmax(b, dim=-1)  # [B, p, i, j]
            
            # 加权求和 (公式3)
            s = torch.einsum('bpij,bpijn->bjn', c, u_hat)  # [B, j, n]
            
            # 非线性压缩 (公式4)
            v = self.squash(s)  # [B, j, n]
            
            # 路由协议更新 (仅在训练时更新)
            if self.training and _ < self.num_routing-1:
                delta_b = torch.einsum('bjn,bpijn->bpij', v, u_hat)
                b = b + delta_b
        
        return v

    def squash(self, input_tensor):
        norm = torch.norm(input_tensor, dim=-1, keepdim=True)  # [B, j, 1]
        scale = norm**2 / (1 + norm**2)  # 缩放系数
        return scale * input_tensor / (norm + 1e-8)  # [B, j, n]

    def forward(self, x):
        return self.dynamic_routing(x)

class Caps3DNet(nn.Module):
    """改进后的三维胶囊网络"""
    def __init__(self, in_channels=1, num_classes=5):
        super().__init__()
        
        # 图像特征提取
        self.conv3d = nn.Sequential(
            nn.Conv3d(in_channels, 256, kernel_size=9, stride=3),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.Conv3d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.Dropout3d(0.2)
        )
        
        # 初级胶囊层 (输出[B, p=6*6*6=216, i=32, m=8])
        self.primary_caps = PrimaryCaps3D(in_channels=256, caps_dim=4)
        
        # 卷积胶囊层 (i=32 -> j=64)
        self.conv_caps1 = ConvCaps3D(
            in_caps=32,
            out_caps=64,
            in_dim=8,
            out_dim=16
        )
        
        # 数字胶囊层 (j=64 -> num_classes)
        self.digit_caps = ConvCaps3D(
            in_caps=64,
            out_caps=num_classes,
            in_dim=4,
            out_dim=8
        )
        
        # EF特征分支
        self.ef_fc = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # 分类器
        self.classifier = nn.Linear(num_classes*8 + 32, num_classes)

    def forward(self, x, ef):
        batch_size = x.size(0)
        
        # 图像特征处理
        x = self.conv3d(x)  # [B, 256, 6,6,6]
        x = self.primary_caps(x)  # [B, p=216, i=32, m=8]
        
        # 胶囊层处理
        caps_output = self.digit_caps(x)  # [B, num_classes, 16]
        
        # 展平胶囊输出
        img_features = caps_output.view(batch_size, -1)  # [B, num_classes*16]
        
        # EF特征处理
        ef_features = self.ef_fc(ef.unsqueeze(1))  # [B, 32]
        
        # 特征融合
        combined = torch.cat([img_features, ef_features], dim=1)  # [B, num_classes*16+32]
        
        return self.classifier(combined)

In [3]:

from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
import torch.multiprocessing as mp
import json

if __name__ == '__main__':  # 确保在主模块中设置
    mp.set_start_method('spawn', force=True)

start_fold = 0  # 可修改为需要开始的折数 (0-4)
results_file = '新-门控.json'
CUSTOM_PREFIX = "新-门控"

# 尝试加载已有的结果 - 添加空文件处理
fold_results = []
if os.path.exists(results_file):
    try:
        with open(results_file, 'r') as f:
            file_content = f.read().strip()
            if file_content:  # 检查文件是否非空
                fold_results = json.loads(file_content)
                print(f"Loaded existing results: {fold_results}")
            else:
                print("Results file exists but is empty. Starting fresh.")
    except json.JSONDecodeError:
        print("Warning: Results file contains invalid JSON. Starting fresh.")
        fold_results = []
else:
    print("No existing results file found. Starting fresh.")

# 训练流程修改
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds, all_labels = [], []
    
    for inputs, labels, ef in loader:
        inputs, labels, ef = inputs.to(device), labels.to(device), ef.float().to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs, ef)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        
        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())  # 使用extend代替append
        all_labels.extend(labels.cpu().numpy()) # 转换为numpy数组
    
    return running_loss/len(loader), accuracy_score(all_labels, all_preds)

def evaluate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    
    with torch.no_grad():
        for inputs, labels, ef in loader:
            inputs, labels, ef = inputs.to(device), labels.to(device), ef.float().to(device)
            
            outputs = model(inputs, ef)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            
            _, preds = torch.max(outputs, 1)
            val_preds.extend(preds.cpu().numpy())   # 修改为extend
            val_labels.extend(labels.cpu().numpy())  # 修改为extend
    
    return val_loss/len(loader.dataset), accuracy_score(val_labels, val_preds)


# 五折交叉验证修改版
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
all_cases = [d for d in Path('心力衰竭/database/training').glob('patient*') if d.is_dir()] + \
            [d for d in Path('心力衰竭/database/testing').glob('patient*') if d.is_dir()]
all_labels = []  # 存储每个病例的标签

# 收集每个病例的标签
for case in all_cases:
    # 创建临时数据集实例（不需要变换）
    _, label, _ = ACDCDataset([case], phase='train')[0]
    all_labels.append(label)
fold_results = []

for fold, (train_val_idx, test_idx) in enumerate(kf.split(all_cases, all_labels)):
    print(f"\n=== Fold {fold+1}/5 ===")
    
    # 划分训练验证集和测试集
    train_val_cases = [all_cases[i] for i in train_val_idx]
    test_cases = [all_cases[i] for i in test_idx]
    
    # 从训练验证集中提取标签用于再分层
    train_val_labels = [all_labels[i] for i in train_val_idx]
    
    # 在训练验证集内部进行分层划分 (75%训练, 25%验证)
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
    for train_idx, val_idx in sss.split(train_val_cases, train_val_labels):
        train_cases = [train_val_cases[i] for i in train_idx]
        val_cases = [train_val_cases[i] for i in val_idx]
    
    # 创建数据集
    train_dataset = ACDCDataset(train_cases, phase='train')
    val_dataset = ACDCDataset(val_cases, phase='val')    # 从训练集划分的验证集
    test_dataset = ACDCDataset(test_cases, phase='val')  # 独立测试集
    
    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=3)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=1)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    
    # 模型初始化（保持原有实现不变）
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Caps3DNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=5)
    
    # 初始化跟踪变量
    best_acc = 0.0  # 只跟踪最佳准确率
    best_loss = 10
    best_model_path = f"{CUSTOM_PREFIX}_fold{fold}_best.pth"
    final_model_path = f"{CUSTOM_PREFIX}_fold{fold}_last.pth"
    
    for epoch in range(100):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)  # 使用新验证集
        
        scheduler.step(val_acc)
        
        # 动态保存最佳模型（只保留最佳准确率版本）
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
        
        
        # 早停判断（基于验证损失）
        if val_loss < best_loss:
            best_loss = val_loss
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= 10:
                print(f"Early stopping at epoch {epoch+1}")
                break
    
    model.load_state_dict(torch.load(best_model_path))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    fold_results.append(test_acc)  # 记录测试集准确率
    print(f"Fold {fold+1} Test Accuracy: {test_acc:.2%}")    
    
    # 确保最终模型保存（即使早停也保存最后达到的epoch）
    torch.save(model.state_dict(), final_model_path)

    with open(results_file, 'w') as f:
        json.dump(fold_results, f)
    print(f"\nCurrent 5-Fold CV Results: {fold_results}")
    print(f"Average Accuracy: {np.mean(fold_results):.2%} (±{np.std(fold_results):.2%})")

# 输出结果（保持原有实现不变）
if os.path.exists(results_file):
    with open(results_file, 'r') as f:
        final_results = json.load(f)
print("\n=== Final Results ===")
print(f"5-Fold CV Results: {final_results}")
print(f"Average Accuracy: {np.mean(final_results):.2%} (±{np.std(final_results):.2%})")

No existing results file found. Starting fresh.

=== Fold 1/5 ===
Early stopping at epoch 80


  model.load_state_dict(torch.load(best_model_path))


Fold 1 Test Accuracy: 66.67%

Current 5-Fold CV Results: [0.6666666666666666]
Average Accuracy: 66.67% (±0.00%)

=== Fold 2/5 ===
Early stopping at epoch 23


  model.load_state_dict(torch.load(best_model_path))


Fold 2 Test Accuracy: 53.33%

Current 5-Fold CV Results: [0.6666666666666666, 0.5333333333333333]
Average Accuracy: 60.00% (±6.67%)

=== Fold 3/5 ===
Early stopping at epoch 33


  model.load_state_dict(torch.load(best_model_path))


Fold 3 Test Accuracy: 51.67%

Current 5-Fold CV Results: [0.6666666666666666, 0.5333333333333333, 0.5166666666666667]
Average Accuracy: 57.22% (±6.71%)

=== Fold 4/5 ===
Early stopping at epoch 31


  model.load_state_dict(torch.load(best_model_path))


Fold 4 Test Accuracy: 58.33%

Current 5-Fold CV Results: [0.6666666666666666, 0.5333333333333333, 0.5166666666666667, 0.5833333333333334]
Average Accuracy: 57.50% (±5.83%)

=== Fold 5/5 ===
Early stopping at epoch 36


  model.load_state_dict(torch.load(best_model_path))


Fold 5 Test Accuracy: 56.67%

Current 5-Fold CV Results: [0.6666666666666666, 0.5333333333333333, 0.5166666666666667, 0.5833333333333334, 0.5666666666666667]
Average Accuracy: 57.33% (±5.23%)

=== Final Results ===
5-Fold CV Results: [0.6666666666666666, 0.5333333333333333, 0.5166666666666667, 0.5833333333333334, 0.5666666666666667]
Average Accuracy: 57.33% (±5.23%)
