# Transformer-based 脑肿瘤分割 - 训练Notebook

本notebook用于训练基于Transformer的医学图像分割模型。

## 支持的模型:
1. SwinUNet - 基于Swin Transformer的UNet（本notebook中尚未实现；选择 'swinunet' 时会回退为UNETR）
2. Swin-UNETR - MONAI的Swin-UNETR (3D)
3. UNETR - MONAI的UNETR (3D)

## 功能:
1. 从3D体积中提取数据
2. 加载和训练Transformer-based模型
3. 模型评估和可视化
4. 与之前的模型进行对比


## 1. 安装依赖和挂载Google Drive


In [None]:
# Mount Google Drive at /content/drive (the data and model paths used later live under it)
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Install the required packages (IPython %pip magics; -q keeps the output quiet)
%pip install monai[all] -q
%pip install nibabel -q
%pip install timm -q  # Swin Transformer
%pip install einops -q
%pip install matplotlib -q


## 2. 导入库和配置参数


In [None]:
# 导入必要的库
import os
import glob
import re
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from sklearn.model_selection import train_test_split
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
from datetime import datetime

# MONAI导入
import monai
from monai.networks.nets import UNETR, SwinUNETR
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, CropForegroundd, Resized, ToTensord,
    Compose, RandRotate90d, RandFlipd, RandShiftIntensityd, MapTransform
)
from monai.data import Dataset as MonaiDataset, DataLoader as MonaiDataLoader
from monai.utils import set_determinism
from monai.inferers import sliding_window_inference

# Select the compute device: CUDA when available, CPU otherwise
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"使用设备: {device}")
if _has_cuda:
    # Report the name and total memory (in GiB) of GPU 0
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"显存: {_gpu_props.total_memory / 1024**3:.2f} GB")

# Print the library versions in use
for _lib_label, _lib in (("MONAI", monai), ("PyTorch", torch)):
    print(f"{_lib_label} version: {_lib.__version__}")


In [None]:
# Data and output path configuration
DRIVE_DATA_PATH = "/content/drive/MyDrive/data-brain-2024"
MODEL_SAVE_PATH = "/content/drive/MyDrive/brain-tumor-models-transformer"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Training hyperparameters
IMG_SIZE = 128  # edge length of the cubic 3D input (can be adjusted to fit GPU memory)
BATCH_SIZE = 2  # 3D data needs a small batch size
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
VAL_INTERVAL = 1
NUM_CLASSES = 4  # background + 3 tumor classes

# Model selection: 'unetr', 'swin_unetr', 'swinunet'
MODEL_TYPE = 'unetr'  # may be changed to 'swin_unetr' or 'swinunet'

# Set the random seed
set_determinism(seed=42)

print(f"数据路径: {DRIVE_DATA_PATH}")
print(f"模型保存路径: {MODEL_SAVE_PATH}")
print(f"图像尺寸: {IMG_SIZE}x{IMG_SIZE}x{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"学习率: {LEARNING_RATE}")
print(f"模型类型: {MODEL_TYPE}")


## 3. 数据加载和预处理


In [None]:
# Custom transform: merge a list of per-modality arrays into one multi-channel array
class ConcatModalitiesd(MapTransform):
    """Concatenate list-valued entries along axis 0.

    For every configured key whose value is a Python ``list``, the list items
    are joined with ``np.concatenate(..., axis=0)``. Values of any other type
    are passed through unchanged.

    NOTE(review): if the upstream loader already stacks the listed files into
    a single array, this transform is a no-op -- confirm.
    """

    def __init__(self, keys, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        # Work on a shallow copy so the caller's mapping is not modified.
        sample = dict(data)
        for name in self.key_iterator(sample):
            value = sample[name]
            if not isinstance(value, list):
                continue
            sample[name] = np.concatenate(value, axis=0)
        return sample

class RemapLabeld(MapTransform):
    """Replace the label value 4 with 3 for every configured key.

    ``torch.Tensor`` values are remapped with ``torch.where`` (dtype and
    device preserved); ``numpy.ndarray`` values with ``np.where``. Values of
    any other type are left untouched.

    NOTE(review): if the segmentation files can contain 3 and 4 as distinct
    classes, this mapping merges them -- confirm against the dataset's label
    definition.
    """

    def __init__(self, keys, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        # Work on a shallow copy so the caller's mapping is not modified.
        sample = dict(data)
        for name in self.key_iterator(sample):
            value = sample[name]
            if isinstance(value, torch.Tensor):
                replacement = torch.tensor(3, dtype=value.dtype, device=value.device)
                sample[name] = torch.where(value == 4, replacement, value)
            elif isinstance(value, np.ndarray):
                sample[name] = np.where(value == 4, 3, value)
        return sample

def get_patient_groups(data_path):
    """Index the BraTS NIfTI files in ``data_path`` by patient and sequence.

    Only files directly inside ``data_path`` that match ``*.nii`` and whose
    name starts with ``BraTS-GLI-<patient>-<sequence>-<modality>.nii``
    (modality one of ``t1n``, ``t2f``, ``t2w``, ``t1c``, ``seg``) are used.

    Args:
        data_path: Directory to scan (not recursive).

    Returns:
        dict: ``{patient_id: {sequence_id: {modality: file_path}}}`` containing
        only the sequences that provide ``t2f``, ``t1c`` and ``seg``. Entries
        are inserted in sorted file-path order. The dict is empty when nothing
        matches (including when the directory does not exist).
    """
    name_pattern = re.compile(r'BraTS-GLI-(\d+)-(\d+)-(t1n|t2f|t2w|t1c|seg)\.nii')
    required_modalities = ('t2f', 't1c', 'seg')

    patient_groups = defaultdict(lambda: defaultdict(dict))
    # glob returns files in a filesystem-dependent order; sort so the resulting
    # patient order (and anything derived from it) is reproducible across runs.
    for file_path in sorted(glob.glob(os.path.join(data_path, "*.nii"))):
        match = name_pattern.match(os.path.basename(file_path))
        if match is None:
            continue
        patient_id, sequence_id, modality = match.groups()
        patient_groups[patient_id][sequence_id][modality] = file_path

    # Keep only the sequences that have every modality needed for training.
    complete_patients = {}
    for patient_id, sequences in patient_groups.items():
        for seq_id, modalities in sequences.items():
            if all(name in modalities for name in required_modalities):
                complete_patients.setdefault(patient_id, {})[seq_id] = modalities

    return complete_patients

def prepare_monai_data_list(patient_groups, patient_ids):
    """Build the list of MONAI-style sample dictionaries for the given patients.

    Args:
        patient_groups: Mapping ``{patient_id: {sequence_id: {modality: path}}}``.
        patient_ids: Patient IDs to include, in output order. IDs missing from
            ``patient_groups`` are skipped.

    Returns:
        list[dict]: One entry per sequence that has ``t2f``, ``t1c`` and
        ``seg``, with keys ``"image"`` (``[t2f_path, t1c_path]``), ``"label"``
        (seg path), ``"patient_id"`` and ``"sequence_id"``.
    """
    needed = ('t2f', 't1c', 'seg')
    samples = []

    for pid in patient_ids:
        if pid not in patient_groups:
            continue
        for sid, files in patient_groups[pid].items():
            # Skip sequences lacking any modality required for a sample.
            if any(name not in files for name in needed):
                continue
            samples.append({
                "image": [files['t2f'], files['t1c']],
                "label": files['seg'],
                "patient_id": pid,
                "sequence_id": sid,
            })

    return samples

# Scan the data directory and collect the IDs of all patients with usable data
all_patient_groups = get_patient_groups(DRIVE_DATA_PATH)
patient_ids = list(all_patient_groups)

_num_patients = len(patient_ids)
print(f"找到 {_num_patients} 个患者")
print(f"前5个患者ID: {patient_ids[:5]}")


In [None]:
# Split by patient ID (70% train / 15% val / 15% test) to avoid data leakage
all_data_list = prepare_monai_data_list(all_patient_groups, patient_ids)

# Sort the de-duplicated IDs before splitting. Iteration order of a set of
# strings changes between Python processes (hash randomization), and
# train_test_split only reproduces a partition for the same input order.
# Without sorting, resuming from a checkpoint in a new session could move
# test patients into the training set.
unique_patient_ids = sorted({item['patient_id'] for item in all_data_list})
train_patients, temp_patients = train_test_split(
    unique_patient_ids, test_size=0.3, random_state=42
)
val_patients, test_patients = train_test_split(
    temp_patients, test_size=0.5, random_state=42
)


def _samples_for(patients):
    """Return the samples in all_data_list that belong to the given patients."""
    wanted = set(patients)  # O(1) membership instead of scanning a list
    return [item for item in all_data_list if item['patient_id'] in wanted]


# Assign every sample to the subset of its patient
train_data_list = _samples_for(train_patients)
val_data_list = _samples_for(val_patients)
test_data_list = _samples_for(test_patients)

print(f"训练集: {len(train_data_list)} 个样本 ({len(train_patients)} 个患者)")
print(f"验证集: {len(val_data_list)} 个样本 ({len(val_patients)} 个患者)")
print(f"测试集: {len(test_data_list)} 个样本 ({len(test_patients)} 个患者)")
print(f"\n训练患者示例: {train_patients[:5]}")
print(f"验证患者示例: {val_patients[:3]}")
print(f"测试患者示例: {test_patients[:3]}")


In [None]:
def _build_transforms(augment):
    """Assemble the dictionary-transform pipeline for the "image"/"label" keys.

    Args:
        augment: When true, the random rotate-90 / flip / intensity-shift
            transforms are inserted just before tensor conversion (training).
            When false, the pipeline contains no random transforms
            (validation and test).

    Returns:
        Compose: The assembled pipeline.
    """
    both = ("image", "label")
    steps = [
        LoadImaged(keys=both, image_only=False),
        EnsureChannelFirstd(keys=both),
        ConcatModalitiesd(keys=["image"]),
        RemapLabeld(keys=["label"]),
        Orientationd(keys=both, axcodes="RAS"),
        Spacingd(keys=both, pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        # Clip image intensities to [0, 1000] and rescale them to [0, 1]
        ScaleIntensityRanged(
            keys=["image"],
            a_min=0.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=both, source_key="image"),
        Resized(
            keys=both,
            spatial_size=(IMG_SIZE, IMG_SIZE, IMG_SIZE),
            mode=("trilinear", "nearest"),
        ),
    ]
    if augment:
        # Data augmentation
        steps.extend([
            RandRotate90d(keys=both, prob=0.5, spatial_axes=(0, 1)),
            RandFlipd(keys=both, prob=0.5),
            RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        ])
    steps.append(ToTensord(keys=both))
    return Compose(steps)


# Training transforms (with augmentation)
train_transforms = _build_transforms(augment=True)

# Validation/test transforms (no augmentation)
val_transforms = _build_transforms(augment=False)

print("数据变换定义完成")


In [None]:
# Create the datasets: augmentation only for training, plain transforms otherwise
train_dataset = MonaiDataset(data=train_data_list, transform=train_transforms)
val_dataset = MonaiDataset(data=val_data_list, transform=val_transforms)
test_dataset = MonaiDataset(data=test_data_list, transform=val_transforms)

# Loader options shared by all three splits
_common_loader_args = {
    "num_workers": 0,
    "pin_memory": torch.cuda.is_available(),
}
# Evaluation loaders go through the data one sample at a time, in order
_eval_loader_args = {"batch_size": 1, "shuffle": False, **_common_loader_args}

train_loader = MonaiDataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, **_common_loader_args
)
val_loader = MonaiDataLoader(val_dataset, **_eval_loader_args)
test_loader = MonaiDataLoader(test_dataset, **_eval_loader_args)

for _label, _count in (
    ("训练集大小", len(train_dataset)),
    ("验证集大小", len(val_dataset)),
    ("测试集大小", len(test_dataset)),
    ("训练批次数", len(train_loader)),
    ("验证批次数", len(val_loader)),
):
    print(f"{_label}: {_count}")


## 5. 创建模型


In [None]:
# Build the model selected by MODEL_TYPE, then the loss, optimizer, scheduler and metric
import inspect  # local import: only needed for the SwinUNETR constructor check below

in_channels = 2  # FLAIR + T1CE
out_channels = NUM_CLASSES


def _build_unetr():
    """Return the UNETR configuration shared by 'unetr' and the 'swinunet' fallback."""
    return UNETR(
        in_channels=in_channels,
        out_channels=out_channels,
        img_size=(IMG_SIZE, IMG_SIZE, IMG_SIZE),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )


def _build_swin_unetr():
    """Return a SwinUNETR, passing ``img_size`` only if the constructor declares it."""
    kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "feature_size": 48,
        "use_checkpoint": False,
    }
    # NOTE(review): newer MONAI releases deprecate the `img_size` argument of
    # SwinUNETR and it appears to be removed in the latest ones -- confirm
    # against the installed version. Checking the signature keeps this cell
    # working on MONAI versions with and without the argument.
    if "img_size" in inspect.signature(SwinUNETR.__init__).parameters:
        kwargs["img_size"] = (IMG_SIZE, IMG_SIZE, IMG_SIZE)
    return SwinUNETR(**kwargs)


if MODEL_TYPE == 'unetr':
    model = _build_unetr()
    print("创建UNETR模型")
elif MODEL_TYPE == 'swin_unetr':
    model = _build_swin_unetr()
    print("创建Swin-UNETR模型")
elif MODEL_TYPE == 'swinunet':
    # SwinUNet is not implemented here; UNETR is used in its place.
    print("SwinUNet需要单独实现，暂时使用UNETR代替")
    model = _build_unetr()
else:
    raise ValueError(f"不支持的模型类型: {MODEL_TYPE}")

model = model.to(device)

# Count the model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("模型创建完成")
print(f"总参数数量: {total_params:,}")
print(f"可训练参数: {trainable_params:,}")

# Loss function, optimizer and learning-rate scheduler
loss_function = DiceCELoss(include_background=False, to_onehot_y=True, softmax=True)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6)

# Evaluation metric. "mean_batch" averages over samples only, so aggregate()
# yields one Dice value per foreground class. The training and test cells
# print these as per-class scores and take .mean() for the overall score;
# reduction="mean" would collapse everything into a single number.
dice_metric = DiceMetric(include_background=False, reduction="mean_batch")

print("\n损失函数: Dice + CrossEntropy Loss")
print(f"优化器: AdamW (lr={LEARNING_RATE}, weight_decay=1e-4)")
print("学习率调度器: ReduceLROnPlateau")


## 6. 训练和验证函数


In [None]:
# Run one training epoch
def train_epoch(model, loader, optimizer, loss_function, device, epoch):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimize; switched to training mode.
        loader: Iterable of batches, each a mapping with "image" and "label".
        optimizer: Optimizer stepped once per batch.
        loss_function: Callable ``(outputs, labels) -> loss``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (used only for the progress-bar label).

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    num_batches = len(loader)
    running_loss = 0

    progress = tqdm(loader, desc=f"Epoch {epoch+1} Training")
    for step, batch in enumerate(progress, start=1):
        inputs = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        loss = loss_function(model(inputs), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Log the current batch loss every 10 steps
        if step % 10 == 0:
            print(f"  Step {step}/{num_batches}, Loss: {batch_loss:.4f}")

    return running_loss / num_batches

# Validation
def val_epoch(model, loader, loss_function, dice_metric, device):
    """Evaluate ``model`` on ``loader`` and return the mean loss and Dice.

    Inference uses sliding windows of size ``IMG_SIZE``^3 with 50% overlap
    and Gaussian blending.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of batches, each a mapping with "image" and "label".
        loss_function: Callable ``(outputs, labels) -> loss``.
        dice_metric: Metric object; it is reset here, fed once per batch, and
            aggregated at the end.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: ``(val_loss, dice_scores)`` where ``val_loss`` is the summed
        batch loss divided by ``len(loader)`` and ``dice_scores`` is the
        value returned by ``dice_metric.aggregate()``.
    """
    model.eval()
    val_loss = 0
    dice_metric.reset()

    with torch.no_grad():
        for batch_data in tqdm(loader, desc="Validation"):
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

            # Sliding-window inference (suitable for large volumes)
            outputs = sliding_window_inference(
                inputs=inputs,
                roi_size=(IMG_SIZE, IMG_SIZE, IMG_SIZE),
                sw_batch_size=1,
                predictor=model,
                overlap=0.5,
                mode="gaussian",
            )

            # The loss is still computed on the raw logits
            loss = loss_function(outputs, labels)
            val_loss += loss.item()

            # The Dice metric expects binarized one-hot predictions. Passing
            # softmax probabilities would produce a "soft" Dice that does not
            # reflect the actual segmentation, so take the argmax class per
            # voxel and one-hot encode it.
            pred_classes = outputs.argmax(dim=1, keepdim=True)
            pred_one_hot = torch.zeros_like(outputs)
            pred_one_hot.scatter_(1, pred_classes, 1)

            # One-hot encode the integer label map along the channel axis
            labels_one_hot = torch.zeros_like(outputs)
            labels_one_hot.scatter_(1, labels.long(), 1)

            dice_metric(y_pred=pred_one_hot, y=labels_one_hot)

    val_loss /= len(loader)
    dice_scores = dice_metric.aggregate()

    return val_loss, dice_scores

print("训练和验证函数定义完成")


## 7. 开始训练


In [None]:
# Training history; these defaults apply when no checkpoint is found
train_losses, val_losses, val_dice_scores = [], [], []
best_val_loss, best_dice_score = float('inf'), 0.0
start_epoch = 0

# Look for the latest checkpoint of the selected model type
checkpoint_path = os.path.join(MODEL_SAVE_PATH, f"checkpoint_latest_{MODEL_TYPE}.pth")

if not os.path.exists(checkpoint_path):
    print("未找到检查点，从头开始训练")
else:
    print(f"找到检查点: {checkpoint_path}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoint files that this notebook wrote itself.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Restore the model, optimizer and scheduler states
    for _target, _key in (
        (model, 'model_state_dict'),
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
    ):
        _target.load_state_dict(checkpoint[_key])

    # Continue with the epoch after the one stored in the checkpoint
    start_epoch = checkpoint['epoch'] + 1

    # Restore the recorded history; missing entries fall back to the defaults
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    val_dice_scores = checkpoint.get('val_dice_scores', [])
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    best_dice_score = checkpoint.get('best_dice_score', 0.0)
    print(f"从epoch {start_epoch} 恢复训练")

print(f"\n开始训练，共 {NUM_EPOCHS} 个epoch")
print("=" * 60)


In [None]:
def _save_atomically(state, path):
    """Write ``state`` with torch.save to a temporary sibling file, then rename it.

    An interrupted write (e.g. a disconnected session) therefore never leaves
    a truncated file at ``path`` that would break the next resume or load.
    """
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


# Training loop
for epoch in range(start_epoch, NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 60)

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, loss_function, device, epoch)
    train_losses.append(train_loss)

    # Validate every VAL_INTERVAL epochs
    if (epoch + 1) % VAL_INTERVAL == 0:
        val_loss, dice_scores = val_epoch(model, val_loader, loss_function, dice_metric, device)
        mean_dice = dice_scores.mean().item()
        val_losses.append(val_loss)
        val_dice_scores.append(mean_dice)

        print("\n验证结果:")
        print(f"  Loss: {val_loss:.4f}")
        print(f"  Dice系数: {mean_dice:.4f}")
        print(f"  Dice (各类别): {dice_scores.cpu().numpy()}")

        # Update the learning rate from the validation loss
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"  当前学习率: {current_lr:.6f}")

        # Save the best model (lowest validation loss so far)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = os.path.join(MODEL_SAVE_PATH, f"best_model_{MODEL_TYPE}.pth")
            _save_atomically({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
                'dice_score': mean_dice,
            }, best_model_path)
            print(f"  保存最佳模型 (Loss: {val_loss:.4f}, Dice: {mean_dice:.4f})")

        # The best Dice is tracked independently of the best-loss model
        if mean_dice > best_dice_score:
            best_dice_score = mean_dice

    # Save the resume checkpoint after every epoch
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_dice_scores': val_dice_scores,
        'best_val_loss': best_val_loss,
        'best_dice_score': best_dice_score,
    }
    _save_atomically(checkpoint, checkpoint_path)

    print(f"训练Loss: {train_loss:.4f}")

print("\n" + "=" * 60)
print("训练完成！")
print(f"最佳验证Loss: {best_val_loss:.4f}")
print(f"最佳Dice系数: {best_dice_score:.4f}")


## 8. 可视化训练历史


In [None]:
# Plot the training curves
def plot_training_history(train_losses, val_losses, val_dice_scores, save_path=None):
    """Plot the loss curves and the validation Dice curve side by side.

    Args:
        train_losses: One training loss per epoch (plotted at x = 0, 1, 2, ...).
        val_losses: One validation loss per validation run.
        val_dice_scores: One mean Dice score per validation run.
        save_path: If given, the figure is also written to this path.
    """
    def _validation_epochs(count):
        # Validation runs when (epoch + 1) % VAL_INTERVAL == 0, i.e. at the
        # zero-based epochs VAL_INTERVAL-1, 2*VAL_INTERVAL-1, ... Using these
        # positions keeps the validation points aligned with the training
        # curve when VAL_INTERVAL > 1.
        return [(i + 1) * VAL_INTERVAL - 1 for i in range(count)]

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    loss_ax, dice_ax = axes

    # Loss curves
    loss_ax.plot(train_losses, label='Training Loss', color='blue')
    if val_losses:
        loss_ax.plot(_validation_epochs(len(val_losses)), val_losses,
                     label='Validation Loss', color='red', marker='o')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Dice curve (the panel stays empty when no validation has been run)
    if val_dice_scores:
        dice_ax.plot(_validation_epochs(len(val_dice_scores)), val_dice_scores,
                     label='Validation Dice Score', color='green', marker='o')
        dice_ax.set_xlabel('Epoch')
        dice_ax.set_ylabel('Dice Score')
        dice_ax.set_title('Validation Dice Score')
        dice_ax.legend()
        dice_ax.grid(True)
        dice_ax.set_ylim([0, 1])

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

plot_training_history(train_losses, val_losses, val_dice_scores,
                     save_path=os.path.join(MODEL_SAVE_PATH, f"training_history_{MODEL_TYPE}.png"))


## 9. 加载最佳模型并在测试集上评估


In [None]:
# Load the best model if one was saved; otherwise keep the current weights
best_model_path = os.path.join(MODEL_SAVE_PATH, f"best_model_{MODEL_TYPE}.pth")
if not os.path.exists(best_model_path):
    print("未找到最佳模型，使用当前模型")
else:
    # NOTE: weights_only=False unpickles arbitrary objects; only load files
    # that this notebook wrote itself.
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"加载最佳模型 (Epoch {checkpoint['epoch']}, Dice: {checkpoint['dice_score']:.4f})")

# Evaluate on the test set
print("\n在测试集上评估...")
test_loss, test_dice_scores = val_epoch(model, test_loader, loss_function, dice_metric, device)
_test_dice_mean = test_dice_scores.mean().item()
_test_dice_array = test_dice_scores.cpu().numpy()

print("\n测试集结果:")
print(f"  Loss: {test_loss:.4f}")
print(f"  平均Dice系数: {_test_dice_mean:.4f}")
print(f"  Dice系数 (各类别): {_test_dice_array}")

# Save the test results as JSON next to the model files
test_results = {
    'test_loss': test_loss,
    'test_dice_mean': _test_dice_mean,
    'test_dice_per_class': _test_dice_array.tolist(),
    'model_type': MODEL_TYPE,
    'timestamp': datetime.now().isoformat(),
}

results_path = os.path.join(MODEL_SAVE_PATH, f"test_results_{MODEL_TYPE}.json")
with open(results_path, 'w') as f:
    json.dump(test_results, f, indent=2)
print(f"\n测试结果已保存到: {results_path}")


## 10. 训练结果分析


In [None]:
# Print a summary report of the training run
def _print_first_and_last(series):
    """Print the first and last value of a non-empty history list."""
    print(f"      - Initial: {series[0]:.4f}")
    print(f"      - Final: {series[-1]:.4f}")


_banner = "=" * 80
print(_banner)
print("TRAINING RESULTS ANALYSIS REPORT")
print(_banner)

print("\n[1. Training Progress]")
_epochs_done = len(train_losses)
print(f"   Completed Epochs: {_epochs_done}/{NUM_EPOCHS}")
print(f"   Completion: {_epochs_done/NUM_EPOCHS*100:.1f}%")

print("\n[2. Loss Function Analysis]")
if train_losses:
    print("   Training Loss:")
    _print_first_and_last(train_losses)
    # Absolute and relative drop from the first to the last epoch
    improvement = train_losses[0] - train_losses[-1]
    improvement_pct = (improvement / train_losses[0]) * 100
    print(f"      - Improvement: {improvement:.4f} ({improvement_pct:.1f}%)")

if val_losses:
    print("\n   Validation Loss:")
    _print_first_and_last(val_losses)
    print(f"      - Best: {best_val_loss:.4f}")

print("\n[3. Dice Coefficient Analysis]")
if val_dice_scores:
    print("   Validation Dice Score:")
    _print_first_and_last(val_dice_scores)
    print(f"      - Best: {best_dice_score:.4f}")
    print(f"      - Average: {np.mean(val_dice_scores):.4f}")

print("\n[4. Test Set Performance]")
# Only reported when the test-evaluation cell has been run
if 'test_results' in locals():
    print(f"   Test Loss: {test_results['test_loss']:.4f}")
    print(f"   Test Dice Score: {test_results['test_dice_mean']:.4f}")

print("\n[5. Model Configuration]")
print(f"   Model Type: {MODEL_TYPE}")
print(f"   Image Size: {IMG_SIZE}x{IMG_SIZE}x{IMG_SIZE}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Learning Rate: {LEARNING_RATE}")

print("\n" + _banner)
print("Analysis Complete!")
print(_banner)
