# HuggingFace 医学图像分割 - 测试Notebook

本 notebook 用于测试基于 segmentation-models-pytorch 的预训练分割模型（ImageNet 预训练编码器）的基本功能，只做前向传播和损失计算，不进行训练。

## 功能：
1. 数据加载和预处理测试
2. 模型创建和前向传播测试
3. 数据可视化
4. 基本预测测试
5. 损失函数测试


## 1. 安装依赖和挂载Google Drive


In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package)
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Install the required packages with the IPython %pip magic (-q: quiet output)
%pip install segmentation-models-pytorch -q
%pip install nibabel -q
%pip install albumentations -q
%pip install matplotlib -q


In [None]:
# 导入必要的库
import os
import glob
import re
import numpy as np
import torch
import torch.nn as nn
import nibabel as nib
from collections import defaultdict
import matplotlib.pyplot as plt

# Select the compute device: CUDA when it is available, otherwise the CPU.
cuda_ready = torch.cuda.is_available()
device = torch.device('cuda') if cuda_ready else torch.device('cpu')
print(f"使用设备: {device}")
if cuda_ready:
    # Report the name and total memory (in GiB) of GPU 0.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"显存: {gpu_props.total_memory / 1024**3:.2f} GB")


## 2. 配置参数


In [None]:
# Dataset location on the mounted Google Drive
DRIVE_DATA_PATH = "/content/drive/MyDrive/data-brain-2024"

# Test parameters
IMG_SIZE = 256  # side length of the square images (printed below as IMG_SIZE x IMG_SIZE)
NUM_CLASSES = 4  # number of segmentation classes
SLICE_START = 22  # same value as the start_idx default of extract_slices_from_volume
NUM_SLICES = 100  # same value as the num_slices default of extract_slices_from_volume

# Model parameters
MODEL_TYPE = 'unet'
PRETRAINED_ENCODER = 'resnet34'

print(f"图像尺寸: {IMG_SIZE}x{IMG_SIZE}")
print(f"模型类型: {MODEL_TYPE}")
print(f"预训练编码器: {PRETRAINED_ENCODER}")


## 3. 数据加载函数


In [None]:
def get_patient_groups(data_path):
    """Group BraTS NIfTI files by patient and sequence, keeping complete sets.

    Scans ``data_path`` (non-recursively) for files named
    ``BraTS-GLI-<patient>-<sequence>-<modality>.nii`` or the gzip-compressed
    ``.nii.gz`` variant, where ``<modality>`` is one of ``t1n``, ``t2f``,
    ``t2w``, ``t1c`` or ``seg``. Files whose whole name does not follow this
    pattern are ignored.

    Args:
        data_path: Directory that contains the NIfTI files.

    Returns:
        A dict ``{patient_id: {sequence_id: {modality: file_path}}}`` that
        holds only the sequences providing ``t2f``, ``t1c`` and ``seg``.
        Entries are ordered by file name, so the result (and therefore
        "the first patient") is reproducible across runs. An empty dict is
        returned when nothing matches.
    """
    name_pattern = re.compile(
        r'BraTS-GLI-(\d+)-(\d+)-(t1n|t2f|t2w|t1c|seg)\.nii(?:\.gz)?'
    )
    required_modalities = ('t2f', 't1c', 'seg')

    # glob returns files in arbitrary (filesystem) order; sort for determinism.
    all_files = sorted(
        glob.glob(os.path.join(data_path, "*.nii"))
        + glob.glob(os.path.join(data_path, "*.nii.gz"))
    )

    patient_groups = defaultdict(lambda: defaultdict(dict))
    for file_path in all_files:
        match = name_pattern.fullmatch(os.path.basename(file_path))
        if match is None:
            continue
        patient_id, sequence_id, modality = match.groups()
        # "x.nii" sorts before "x.nii.gz", so setdefault keeps the plain
        # .nii file when both variants of the same scan are present.
        patient_groups[patient_id][sequence_id].setdefault(modality, file_path)

    complete_patients = {}
    for patient_id, sequences in patient_groups.items():
        for seq_id, modalities in sequences.items():
            if all(name in modalities for name in required_modalities):
                complete_patients.setdefault(patient_id, {})[seq_id] = modalities

    return complete_patients

def load_nifti_volume(file_path):
    """Load a NIfTI file and return its image data as a numpy array.

    Args:
        file_path: Path of the NIfTI file to read.

    Returns:
        The array produced by nibabel's ``get_fdata()`` for that file.
    """
    return nib.load(file_path).get_fdata()

def extract_slices_from_volume(volume, start_idx=22, num_slices=100):
    """Return a contiguous run of 2D slices taken along the third axis.

    Args:
        volume: Array indexed as ``volume[:, :, k]``.
        start_idx: Index of the first slice to keep.
        num_slices: Maximum number of slices to keep.

    Returns:
        ``volume[:, :, start_idx:stop]`` where ``stop`` is capped at the size
        of the third axis, so fewer than ``num_slices`` slices may come back.
    """
    requested_stop = start_idx + num_slices
    available = volume.shape[2]
    # Never read past the end of the volume.
    stop = requested_stop if requested_stop < available else available
    return volume[:, :, slice(start_idx, stop)]

def normalize_slice(slice_data):
    """Scale a slice by its maximum value and return it as float32.

    Args:
        slice_data: Numpy array holding the slice intensities.

    Returns:
        A float32 copy divided by its maximum when that maximum is
        positive; otherwise the float32 copy is returned unscaled.
    """
    as_float = slice_data.astype(np.float32)
    peak = as_float.max()
    # Dividing only for a positive peak avoids a division by zero on empty slices.
    return as_float / peak if peak > 0 else as_float

def remap_labels(label_slice):
    """Return an int64 copy of the label slice with value 4 replaced by 3.

    Args:
        label_slice: Numpy array of label values.

    Returns:
        A new int64 array; the input array is left untouched.
    """
    # NOTE(review): if label 3 already occurs in the data, this merges two
    # classes into one — confirm against the dataset's label convention.
    remapped = label_slice.astype(np.int64)
    np.putmask(remapped, remapped == 4, 3)
    return remapped

# Collect every patient/sequence returned by get_patient_groups for the data directory
all_patient_groups = get_patient_groups(DRIVE_DATA_PATH)
patient_ids = list(all_patient_groups.keys())

print(f"找到 {len(patient_ids)} 个患者")
print(f"前5个患者ID: {patient_ids[:5]}")


## 4. 加载并可视化一个样本


In [None]:
# Pick the first patient and its first sequence for the test
# (raises IndexError when patient_ids is empty)
test_patient_id = patient_ids[0]
test_seq_id = list(all_patient_groups[test_patient_id].keys())[0]
test_modalities = all_patient_groups[test_patient_id][test_seq_id]

print(f"测试患者ID: {test_patient_id}")
print(f"测试序列ID: {test_seq_id}")
print(f"模态文件:")
for mod, path in test_modalities.items():
    print(f"  {mod}: {path}")

# Load the 3D volumes (t2f is reported as FLAIR, t1c as T1CE below)
t2f_volume = load_nifti_volume(test_modalities['t2f'])
t1c_volume = load_nifti_volume(test_modalities['t1c'])
seg_volume = load_nifti_volume(test_modalities['seg'])

print(f"\n体积形状:")
print(f"  FLAIR: {t2f_volume.shape}")
print(f"  T1CE: {t1c_volume.shape}")
print(f"  标签: {seg_volume.shape}")

# Take one 2D slice along the third axis
slice_idx = 50  # fixed index; NOTE(review): hard-coded, assumes the volumes have more than 50 slices
t2f_slice = t2f_volume[:, :, slice_idx]
t1c_slice = t1c_volume[:, :, slice_idx]
seg_slice = seg_volume[:, :, slice_idx]

# Scale the images by their maximum and remap the label values (4 -> 3)
t2f_slice_norm = normalize_slice(t2f_slice)
t1c_slice_norm = normalize_slice(t1c_slice)
seg_slice_remap = remap_labels(seg_slice)

print(f"\n切片形状:")
print(f"  FLAIR: {t2f_slice_norm.shape}")
print(f"  T1CE: {t1c_slice_norm.shape}")
print(f"  标签: {seg_slice_remap.shape}")
print(f"  标签值范围: {seg_slice_remap.min()} - {seg_slice_remap.max()}")
print(f"  标签值分布: {np.bincount(seg_slice_remap.flatten())}")


In [None]:
# Show the two normalised image slices and the remapped label slice side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# One entry per panel: (pixel data, title, imshow keyword arguments).
raw_panels = (
    (t2f_slice_norm, f'FLAIR (切片 {slice_idx})', {'cmap': 'gray'}),
    (t1c_slice_norm, f'T1CE (切片 {slice_idx})', {'cmap': 'gray'}),
    (seg_slice_remap, f'标签 (切片 {slice_idx})', {'cmap': 'tab10', 'vmin': 0, 'vmax': 3}),
)
for panel_ax, (pixels, panel_title, imshow_kwargs) in zip(axes, raw_panels):
    panel_ax.imshow(pixels, **imshow_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()


## 5. 测试数据预处理（Albumentations）


In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Build the preprocessing pipeline for the 2-channel image.
# The slices were already scaled to [0, 1] by normalize_slice, so
# max_pixel_value must be 1.0: A.Normalize computes
#   (img - mean * max_pixel_value) / (std * max_pixel_value)
# and its default max_pixel_value of 255 would squash the data into roughly
# [-1, -0.99] instead of the intended [-1, 1].
transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=[0.5, 0.5], std=[0.5, 0.5], max_pixel_value=1.0),  # -> [-1, 1]
    ToTensorV2()
])

# Labels use their own pipeline: nearest-neighbour resizing keeps the class
# ids intact (no interpolated in-between values).
label_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, interpolation=0),  # 0 == nearest neighbour
    ToTensorV2()
])

# Stack FLAIR and T1CE on the last axis -> (H, W, 2), channel-last as
# expected by the image transforms.
image_2ch = np.stack([t2f_slice_norm, t1c_slice_norm], axis=-1)

# Apply the transforms
transformed = transform(image=image_2ch)
image_tensor = transformed['image']  # (2, H, W)

# Cast the int64 label to uint8 before resizing: OpenCV-backed transforms
# expect uint8/float32 input, and the label values fit comfortably in uint8.
# The result is converted back to int64 with .long() below.
label_transformed = label_transform(image=seg_slice_remap.astype(np.uint8))
label_tensor = label_transformed['image'].squeeze(0).long()  # (H, W)

print(f"预处理后的形状:")
print(f"  图像: {image_tensor.shape}")
print(f"  标签: {label_tensor.shape}")
print(f"  图像值范围: [{image_tensor.min():.3f}, {image_tensor.max():.3f}]")
print(f"  标签值范围: {label_tensor.min()} - {label_tensor.max()}")


## 6. 创建并测试模型


In [None]:
import segmentation_models_pytorch as smp

# 创建UNet模型（使用预训练编码器）
def create_model(model_type='unet', encoder_name='resnet34', num_classes=4, in_channels=2,
                 encoder_weights='imagenet'):
    """Create a segmentation model from segmentation_models_pytorch.

    Args:
        model_type: Architecture to build: ``'unet'``, ``'fpn'`` or
            ``'linknet'``.
        encoder_name: Name of the encoder backbone passed to the model.
        num_classes: Number of output channels (``classes`` argument).
        in_channels: Number of input image channels.
        encoder_weights: Pretrained weights for the encoder. Defaults to
            ``'imagenet'``; pass ``None`` to skip pretrained weights.

    Returns:
        The constructed model. ``activation=None`` is used, so the model
        outputs raw logits.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Map the public model_type names to the smp class names; the class is
    # looked up only after validation so an unknown type fails fast.
    architecture_names = {
        'unet': 'Unet',
        'fpn': 'FPN',
        'linknet': 'Linknet',
    }
    class_name = architecture_names.get(model_type)
    if class_name is None:
        raise ValueError(f"不支持的模型类型: {model_type}")

    model_cls = getattr(smp, class_name)
    return model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=num_classes,
        activation=None,  # keep logits; softmax is applied by the caller
    )

# Build the model from the configured architecture and encoder
print("正在创建模型...")
model = create_model(
    model_type=MODEL_TYPE,
    encoder_name=PRETRAINED_ENCODER,
    num_classes=NUM_CLASSES,
    in_channels=2  # FLAIR + T1CE
)

# Move the model to the selected device and switch to evaluation mode
model = model.to(device)
model.eval()

# Print the model configuration
print(f"\n模型信息:")
print(f"  模型类型: {MODEL_TYPE}")
print(f"  编码器: {PRETRAINED_ENCODER}")
print(f"  输入通道: 2 (FLAIR + T1CE)")
print(f"  输出类别: {NUM_CLASSES}")

# Count all parameters and those with requires_grad set (reported in millions)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n参数量:")
print(f"  总参数量: {total_params / 1e6:.2f}M")
print(f"  可训练参数量: {trainable_params / 1e6:.2f}M")


In [None]:
# Forward-pass smoke test (gradients disabled)
print("\n测试前向传播...")
with torch.no_grad():
    # Add the batch dimension and move the input to the model's device
    test_input = image_tensor.unsqueeze(0).to(device)  # (1, 2, H, W)
    print(f"输入形状: {test_input.shape}")
    
    # Forward pass
    test_output = model(test_input)
    print(f"输出形状: {test_output.shape}")
    print(f"期望输出形状: (1, {NUM_CLASSES}, {IMG_SIZE}, {IMG_SIZE})")
    
    # Summary statistics of the raw output
    print(f"\n输出统计:")
    print(f"  值范围: [{test_output.min().item():.3f}, {test_output.max().item():.3f}]")
    print(f"  均值: {test_output.mean().item():.3f}")
    print(f"  标准差: {test_output.std().item():.3f}")
    
    # Softmax over the class dimension gives probabilities; argmax gives the class map
    probs = torch.softmax(test_output, dim=1)
    pred_classes = torch.argmax(probs, dim=1).squeeze(0)  # (H, W)
    
    print(f"\n预测结果:")
    print(f"  预测类别形状: {pred_classes.shape}")
    print(f"  预测类别值范围: {pred_classes.min().item()} - {pred_classes.max().item()}")
    print(f"  预测类别分布: {torch.bincount(pred_classes.flatten())}")
    
print("\n✅ 模型前向传播测试成功！")


## 7. 可视化预测结果


In [None]:
# Plot the preprocessed inputs (top row) and the label, prediction and one
# class-probability map (bottom row).
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

gray_style = {'cmap': 'gray'}
label_style = {'cmap': 'tab10', 'vmin': 0, 'vmax': 3}

# One entry per panel: (axis, tensor to draw, title, imshow keyword arguments).
prediction_panels = [
    (axes[0, 0], image_tensor[0], 'FLAIR (预处理后)', gray_style),
    (axes[0, 1], image_tensor[1], 'T1CE (预处理后)', gray_style),
    (axes[1, 0], label_tensor, '真实标签', label_style),
    (axes[1, 1], pred_classes, '预测结果（未训练模型）', label_style),
]

# Probability map of the first non-background class only (class 1), drawn
# when such a class exists.
class_probs = probs.squeeze(0)  # (C, H, W)
if NUM_CLASSES > 1:
    c = 1
    prediction_panels.append(
        (axes[1, 2], class_probs[c], f'类别 {c} 概率图', {'cmap': 'hot', 'vmin': 0, 'vmax': 1})
    )

for panel_ax, panel_tensor, panel_title, panel_style in prediction_panels:
    panel_ax.imshow(panel_tensor.cpu().numpy(), **panel_style)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

# The top-right panel is intentionally left blank.
axes[0, 2].axis('off')

plt.tight_layout()
plt.show()

print("\n注意：这是未训练模型的预测结果，仅用于测试模型是否能正常运行。")
print("实际训练后，预测结果应该会更接近真实标签。")


## 8. 测试损失函数


In [None]:
# 测试损失函数
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the non-background classes.

    ``forward`` applies a softmax over dimension 1 of the logits, one-hot
    encodes the integer targets along that same dimension, and computes one
    Dice score per class index ``1 .. num_classes - 1`` (class 0 is skipped).
    Each score is accumulated over the whole batch. The loss is
    ``1 - mean(dice scores)``.

    Args:
        num_classes: Number of classes, background included.
        smooth: Constant added to numerator and denominator of each Dice score.
    """

    def __init__(self, num_classes=4, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        self.num_classes = num_classes

    def _dice_for_class(self, probs, one_hot, class_idx):
        """Return the smoothed Dice score of one class over the whole batch."""
        class_probs = probs[:, class_idx]
        class_target = one_hot[:, class_idx]
        overlap = torch.sum(class_probs * class_target)
        total = torch.sum(class_probs) + torch.sum(class_target)
        return (2.0 * overlap + self.smooth) / (total + self.smooth)

    def forward(self, pred, target):
        """Compute the Dice loss.

        Args:
            pred: Logits with the class dimension at index 1, e.g. (N, C, H, W).
            target: Integer class indices without the class dimension,
                e.g. (N, H, W).

        Returns:
            A scalar tensor holding the loss.
        """
        probs = torch.softmax(pred, dim=1)
        one_hot = torch.zeros_like(probs)
        one_hot.scatter_(1, target.unsqueeze(1), 1)

        # Index 0 (background) is deliberately left out of the average.
        per_class = [
            self._dice_for_class(probs, one_hot, class_idx)
            for class_idx in range(1, self.num_classes)
        ]
        return 1.0 - torch.stack(per_class).mean()

# Instantiate the two losses under test
dice_loss = DiceLoss(num_classes=NUM_CLASSES)
ce_loss = nn.CrossEntropyLoss()

# Reuse the forward-pass output as logits and add a batch dimension to the label
test_output_batch = test_output  # (1, C, H, W)
test_label_batch = label_tensor.unsqueeze(0).to(device)  # (1, H, W)

# Compute each loss and an equally weighted combination of the two
dice = dice_loss(test_output_batch, test_label_batch)
ce = ce_loss(test_output_batch, test_label_batch)
combined = 0.5 * dice + 0.5 * ce

print("损失函数测试:")
print(f"  Dice Loss: {dice.item():.4f}")
print(f"  CrossEntropy Loss: {ce.item():.4f}")
print(f"  组合损失: {combined.item():.4f}")

print("\n✅ 所有测试完成！模型和数据管道工作正常。")
print("可以开始训练了！")
