In [12]:
import os
import numpy as np
import nibabel as nib
import json
import torch
import torch.nn as nn
import datetime
from typing import List
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
import pandas as pd
from tensorboardX import SummaryWriter

In [5]:
# 获取子文件夹中的所有文件
def subfiles(folder: str, join: bool = True, prefix: str = None, suffix: str = None, sort: bool = True) -> List[str]:
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res

# 获取 NIfTI 文件列表
def nifti_files(folder: str, join: bool = True, sort: bool = True) -> List[str]:
    return subfiles(folder, join=join, sort=sort, suffix='.nii.gz')

# 创建目录（如果不存在）
def maybe_mkdir_p(directory: str) -> None:
    os.makedirs(directory, exist_ok=True)

# 划分数据

In [12]:
# 获取数据集
ROOT_PATH = "D:/zlx/Medical_Image_Segmentation/data/patches/"
IMAGE_PATH = os.path.join(ROOT_PATH, "image")
LABEL_PATH = os.path.join(ROOT_PATH, "label")

# 获取数据集文件（按病人进行划分）
patients = sorted(set('_'.join(os.path.basename(f).split('_')[:2]) for f in nifti_files(IMAGE_PATH)))
image_files = nifti_files(IMAGE_PATH)
label_files = nifti_files(LABEL_PATH)

# 确保图像和标签数量匹配
if len(image_files) != len(label_files):
    raise ValueError("图像和标签文件数量不匹配，请检查数据。")

# 按病人划分数据集的函数
def split_dataset_by_patient(patients, image_files, label_files, num_folds=3, seed=42):
    # Adjust patient extraction to handle specific patient and patch naming conventions
    patient_ids = sorted(set('_'.join(os.path.basename(f).split('_')[:2]) for f in image_files))
    num_total = len(patients)
    indices = np.arange(num_total)
    np.random.seed(seed)
    np.random.shuffle(indices)

    fold_size = num_total // num_folds
    folds = []
    for i in range(num_folds):
        val_indices = indices[i * fold_size: (i + 1) * fold_size] if i != num_folds - 1 else indices[i * fold_size:]
        train_indices = np.setdiff1d(indices, val_indices)

        train_patients = [patient_ids[idx] for idx in train_indices]
        val_patients = [patient_ids[idx] for idx in val_indices]

        train_images = [f for f in image_files if any(patient in f for patient in train_patients)]
        train_labels = [f for f in label_files if any(patient in f for patient in train_patients)]
        val_images = [f for f in image_files if any(patient in f for patient in val_patients)]
        val_labels = [f for f in label_files if any(patient in f for patient in val_patients)]

        folds.append({
            'train_images': train_images,
            'train_labels': train_labels,
            'val_images': val_images,
            'val_labels': val_labels
        })
    return folds

# 设置交叉验证的折数
num_folds = 5
folds = split_dataset_by_patient(patients, image_files, label_files, num_folds=num_folds)

# 打印每一折的数据大小
for i, fold in enumerate(folds):
    print(f"折 {i + 1}:")
    print("训练集大小:", len(fold['train_images']))
    print("验证集大小:", len(fold['val_images']))

# 保存每一折的数据集列表到文件
OUTPUT_PATH = "D:/zlx/Medical_Image_Segmentation/data/splits/"
maybe_mkdir_p(OUTPUT_PATH)

for i, fold in enumerate(folds):
    split_file_path = os.path.join(OUTPUT_PATH, f"fold_{i + 1}.json")
    with open(split_file_path, 'w') as f:
        json.dump(fold, f, indent=4)
    print(f"保存折 {i + 1} 的数据列表到 {split_file_path}")


折 1:
训练集大小: 5624
验证集大小: 1438
折 2:
训练集大小: 5564
验证集大小: 1498
折 3:
训练集大小: 5586
验证集大小: 1476
折 4:
训练集大小: 5799
验证集大小: 1263
折 5:
训练集大小: 5675
验证集大小: 1387
保存折 1 的数据列表到 D:/zlx/Medical_Image_Segmentation/data/splits/fold_1.json
保存折 2 的数据列表到 D:/zlx/Medical_Image_Segmentation/data/splits/fold_2.json
保存折 3 的数据列表到 D:/zlx/Medical_Image_Segmentation/data/splits/fold_3.json
保存折 4 的数据列表到 D:/zlx/Medical_Image_Segmentation/data/splits/fold_4.json
保存折 5 的数据列表到 D:/zlx/Medical_Image_Segmentation/data/splits/fold_5.json


In [14]:
# 查看保存的 JSON 文件内容
def view_split_file(fold_number):
    split_file_path = os.path.join(OUTPUT_PATH, f"fold_{fold_number}.json")
    if not os.path.exists(split_file_path):
        print(f"折 {fold_number} 的文件不存在。")
        return
    with open(split_file_path, 'r') as f:
        split_data = json.load(f)
        print(f"折 {fold_number} 数据集:")
        headers = ["训练集图像列表", "验证集图像列表"]
        print(f"{headers[0]:<40}{headers[1]:<40}")
        max_length = max(len(split_data['train_images']),  len(split_data['val_images']))
        for i in range(max_length):
            train_image = os.path.basename(split_data['train_images'][i]) if i < len(split_data['train_images']) else ""
            val_image = os.path.basename(split_data['val_images'][i]) if i < len(split_data['val_images']) else ""
            print(f"{train_image:<40}{val_image:<40}")


# 示例：查看折 1 的数据列表
view_split_file(1)

折 1 数据集:
训练集图像列表                                 验证集图像列表                                 
patient_02_patch_0_0_0.nii.gz           patient_01_patch_0_0_0.nii.gz           
patient_02_patch_0_0_128.nii.gz         patient_01_patch_0_0_128.nii.gz         
patient_02_patch_0_0_160.nii.gz         patient_01_patch_0_0_160.nii.gz         
patient_02_patch_0_0_192.nii.gz         patient_01_patch_0_0_192.nii.gz         
patient_02_patch_0_0_32.nii.gz          patient_01_patch_0_0_32.nii.gz          
patient_02_patch_0_0_64.nii.gz          patient_01_patch_0_0_64.nii.gz          
patient_02_patch_0_0_96.nii.gz          patient_01_patch_0_0_96.nii.gz          
patient_02_patch_0_128_0.nii.gz         patient_01_patch_0_128_0.nii.gz         
patient_02_patch_0_128_128.nii.gz       patient_01_patch_0_128_128.nii.gz       
patient_02_patch_0_128_160.nii.gz       patient_01_patch_0_128_160.nii.gz       
patient_02_patch_0_128_192.nii.gz       patient_01_patch_0_128_192.nii.gz       
patient_02_patch_0_

# 数据集类

In [17]:
# 数据集类定义
class MedicalImageDataset(Dataset):
    def __init__(self, image_paths: List[str], label_paths: List[str]):
        self.image_paths = image_paths
        self.label_paths = label_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label_path = self.label_paths[idx]

        # 读取图像和标签数据
        image = nib.load(image_path).get_fdata()
        label = nib.load(label_path).get_fdata()

        # 转换为 tensor
        image_tensor = torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # 添加通道维度
        label_tensor = torch.tensor(label, dtype=torch.long).unsqueeze(0)

        return image_tensor, label_tensor

In [25]:
# 示例：使用数据集类
fold_number = 1
json_path = "D:/zlx/Medical_Image_Segmentation/data/splits/"
split_file_path = os.path.join(json_path, f"fold_{fold_number}.json")
if os.path.exists(split_file_path):
    with open(split_file_path, 'r') as f:
        split_data = json.load(f)
        train_dataset = MedicalImageDataset(split_data['train_images'], split_data['train_labels'])
        val_dataset = MedicalImageDataset(split_data['val_images'], split_data['val_labels'])
        
        # 创建数据加载器
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)
        
        # 打印一些示例数据
        for images, labels in train_loader:
            print("图像批次形状:", images.shape)
            print("标签批次形状:", labels.shape)
            print("图像文件名:", [train_dataset.image_paths[idx] for idx in range(len(images))])
            print("标签文件名:", [train_dataset.label_paths[idx] for idx in range(len(labels))])
            break

图像批次形状: torch.Size([4, 1, 128, 128, 64])
标签批次形状: torch.Size([4, 1, 128, 128, 64])
图像文件名: ['D:/zlx/Medical_Image_Segmentation/data/patches/image\\patient_02_patch_0_0_0.nii.gz', 'D:/zlx/Medical_Image_Segmentation/data/patches/image\\patient_02_patch_0_0_128.nii.gz', 'D:/zlx/Medical_Image_Segmentation/data/patches/image\\patient_02_patch_0_0_160.nii.gz', 'D:/zlx/Medical_Image_Segmentation/data/patches/image\\patient_02_patch_0_0_192.nii.gz']
标签文件名: ['D:/zlx/Medical_Image_Segmentation/data/patches/label\\patient_02_label_patch_0_0_0.nii.gz', 'D:/zlx/Medical_Image_Segmentation/data/patches/label\\patient_02_label_patch_0_0_128.nii.gz', 'D:/zlx/Medical_Image_Segmentation/data/patches/label\\patient_02_label_patch_0_0_160.nii.gz', 'D:/zlx/Medical_Image_Segmentation/data/patches/label\\patient_02_label_patch_0_0_192.nii.gz']


# 模型

In [24]:
class Unet_module(nn.Module):
    
    def __init__(self, kernel_size, channel_list, down_up='down'):
        super(Unet_module, self).__init__()
        self.conv1 = nn.Conv3d(channel_list[0], channel_list[1], kernel_size, 1, (kernel_size - 1) // 2)
        self.conv2 = nn.Conv3d(channel_list[1], channel_list[2], kernel_size, 1, (kernel_size - 1) // 2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.bn1 = nn.BatchNorm3d(channel_list[1])
        self.bn2 = nn.BatchNorm3d(channel_list[2])

        if down_up == 'down':
            self.sample = nn.MaxPool3d(2, 2)
        else:
            self.sample = nn.Sequential(nn.ConvTranspose3d(channel_list[2], channel_list[2], kernel_size,
                                                        2, (kernel_size - 1) // 2, 1),
                                        nn.ReLU())
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.bn1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.bn2(x)

        next_layer = self.sample(x)

        return next_layer, x

class UNet(nn.Module):

    def __init__(self, kernel_size, in_channel=1, out_channel=2):
        super(UNet, self).__init__()

        self.encoder1 = Unet_module(kernel_size, (in_channel, 32, 64))
        self.encoder2 = Unet_module(kernel_size, (64, 64, 128))
        self.encoder3 = Unet_module(kernel_size, (128,128,256))

        self.decoder1 = Unet_module(kernel_size, (256, 256, 512), down_up='up')
        self.decoder2 = Unet_module(kernel_size, (768, 256, 256), down_up='up')
        self.decoder3 = Unet_module(kernel_size, (384, 128, 128), down_up='up')
        self.decoder4 = Unet_module(kernel_size, (192, 64, 64), down_up='up')

        self.last_conv = nn.Conv3d(64, out_channel, 1, 1, bias=False)

    def forward(self, x):
        # print('input', x.shape)
        # 下采样路径
        x, skip1 = self.encoder1(x)
        x, skip2 = self.encoder2(x)
        x, skip3 = self.encoder3(x)

        # 上采样路径
        x, _ = self.decoder1(x)
        x = torch.cat([x, skip3], dim=1)
        x, _ = self.decoder2(x)
        x = torch.cat([x, skip2], dim=1)
        x, _ = self.decoder3(x)
        x = torch.cat([x, skip1], dim=1)
        _, x = self.decoder4(x)

        # 最终卷积层
        output = self.last_conv(x)
        
        return output

In [30]:
# 查看模型结构的函数
def print_model_summary(model):
    print("模型结构:")
    for name, module in model.named_children():
        print(f"{name}: {module}")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"总参数数量: {total_params}")
          
model = UNet(kernel_size=3)
print_model_summary(model)

模型结构:
encoder1: Unet_module(
  (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv2): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (relu1): ReLU()
  (relu2): ReLU()
  (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sample): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
encoder2: Unet_module(
  (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv2): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (relu1): ReLU()
  (relu2): ReLU()
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sample): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

In [33]:
def hook_fn(module, input, output):
    print(f"{module.__class__.__name__} output shape: {output.shape}")

model = UNet(kernel_size=3)

# 注册钩子到每个卷积层
for name, layer in model.named_modules():
    if isinstance(layer, (nn.Conv3d, nn.ConvTranspose3d)):
        layer.register_forward_hook(hook_fn)

# 随机生成输入，测试钩子功能
dummy_input = torch.randn(1, 1, 128, 128, 64)  # 批次大小为1，通道为1，空间尺寸为64x64x64
output = model(dummy_input)

Conv3d output shape: torch.Size([1, 32, 128, 128, 64])
Conv3d output shape: torch.Size([1, 64, 128, 128, 64])
Conv3d output shape: torch.Size([1, 64, 64, 64, 32])
Conv3d output shape: torch.Size([1, 128, 64, 64, 32])
Conv3d output shape: torch.Size([1, 128, 32, 32, 16])
Conv3d output shape: torch.Size([1, 256, 32, 32, 16])
Conv3d output shape: torch.Size([1, 256, 16, 16, 8])
Conv3d output shape: torch.Size([1, 512, 16, 16, 8])
ConvTranspose3d output shape: torch.Size([1, 512, 32, 32, 16])
Conv3d output shape: torch.Size([1, 256, 32, 32, 16])
Conv3d output shape: torch.Size([1, 256, 32, 32, 16])
ConvTranspose3d output shape: torch.Size([1, 256, 64, 64, 32])
Conv3d output shape: torch.Size([1, 128, 64, 64, 32])
Conv3d output shape: torch.Size([1, 128, 64, 64, 32])
ConvTranspose3d output shape: torch.Size([1, 128, 128, 128, 64])
Conv3d output shape: torch.Size([1, 64, 128, 128, 64])
Conv3d output shape: torch.Size([1, 64, 128, 128, 64])
ConvTranspose3d output shape: torch.Size([1, 64, 256

# 模型初值化

In [22]:
# 权重初始化函数（使用Kaiming初始化）
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [26]:
# 在模型初始化之后，应用初始化函数
model = UNet(kernel_size=3)
model.apply(weights_init_kaiming)

# 检查模型的初始化情况
for name, param in model.named_parameters():
    if 'weight' in name:
        print(f"{name} - mean: {param.mean().item():.4f}, std: {param.std().item():.4f}")
    if 'bias' in name and param is not None:
        print(f"{name} - bias mean: {param.mean().item():.4f}")

encoder1.conv1.weight - mean: 0.0056, std: 0.2731
encoder1.conv1.bias - bias mean: -0.0085
encoder1.conv2.weight - mean: -0.0001, std: 0.0479
encoder1.conv2.bias - bias mean: 0.0011
encoder1.bn1.weight - mean: 0.9974, std: 0.0247
encoder1.bn1.bias - bias mean: 0.0000
encoder1.bn2.weight - mean: 0.9990, std: 0.0203
encoder1.bn2.bias - bias mean: 0.0000
encoder2.conv1.weight - mean: -0.0001, std: 0.0340
encoder2.conv1.bias - bias mean: -0.0000
encoder2.conv2.weight - mean: 0.0001, std: 0.0341
encoder2.conv2.bias - bias mean: -0.0019
encoder2.bn1.weight - mean: 1.0046, std: 0.0228
encoder2.bn1.bias - bias mean: 0.0000
encoder2.bn2.weight - mean: 0.9997, std: 0.0219
encoder2.bn2.bias - bias mean: 0.0000
encoder3.conv1.weight - mean: -0.0000, std: 0.0240
encoder3.conv1.bias - bias mean: -0.0011
encoder3.conv2.weight - mean: 0.0000, std: 0.0240
encoder3.conv2.bias - bias mean: 0.0001
encoder3.bn1.weight - mean: 0.9963, std: 0.0202
encoder3.bn1.bias - bias mean: 0.0000
encoder3.bn2.weight - m

# Loss

In [None]:
class SoftDiceLoss(nn.Module):

    def __init__(self, activation: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
                 ddp: bool = True):
        """
        初始化SoftDiceLoss类
        :param activation: 可选的激活函数（如softmax），用于将模型输出转换为概率
        :param batch_dice: 是否在批次级别计算Dice系数，默认为False
        :param do_bg: 是否计算背景类的Dice系数，默认为True
        :param smooth: 平滑因子，防止除零错误，默认为1.0
        :param ddp: 是否使用分布式数据并行计算，默认为True
        """
        super(SoftDiceLoss, self).__init__()

        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.activation = activation
        self.smooth = smooth
        self.ddp = ddp

    def forward(self, x, y, loss_mask=None):
        """
        前向传播方法
         :param x: 模型的输出，通常是logits
         :param y: 真实标签
         :param loss_mask: 可选的损失掩码，用于指定有效区域
         :return: 计算得到的Dice损失
        """
        if self.activation is not None:
            x = self.activation(x)
        # 确定需要进行求和的轴，通常是空间维度
        axes = tuple(range(2, x.ndim))

        with torch.no_grad():
            # 如果模型输出和真实标签的维度不一致，则调整真实标签的形状
            if x.ndim != y.ndim:
                #将y变形为（b, c(设为1), x, y(,z)）
                y = y.view((y.shape[0], 1, *y.shape[1:]))

            if x.shape == y.shape:

                y_onehot = y
            else:
                y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.bool)
                y_onehot.scatter_(1, y.long(), 1)

            if not self.do_bg:
                y_onehot = y_onehot[:, 1:]

            sum_gt = y_onehot.sum(axes) if loss_mask is None else (y_onehot * loss_mask).sum(axes)

        if not self.do_bg:
            x = x[:, 1:]

        if loss_mask is None:
            intersect = (x * y_onehot).sum(axes)
            sum_pred = x.sum(axes)
        else:
            intersect = (x * y_onehot * loss_mask).sum(axes)
            sum_pred = (x * loss_mask).sum(axes)

        if self.batch_dice:
            if self.ddp:
                intersect = AllGatherGrad.apply(intersect).sum(0)
                sum_pred = AllGatherGrad.apply(sum_pred).sum(0)
                sum_gt = AllGatherGrad.apply(sum_gt).sum(0)

            intersect = intersect.sum(0)
            sum_pred = sum_pred.sum(0)
            sum_gt = sum_gt.sum(0)
            
        # 计算每个类别的 Dice 系数
        dc_per_class = (2 * intersect + self.smooth) / (torch.clip(sum_gt + sum_pred + self.smooth, min=1e-8))
        
        # 计算总的平均 Dice 系数
        dc = dc_per_class.mean()

        return 1 - dc, 1 - dc_per_class

# 评估指标

In [6]:
def calculate_dice(pred, target, smooth=1e-8):
    intersection = (pred * target).sum(dim=(2, 3, 4))
    pred_sum = pred.sum(dim=(2, 3, 4))
    target_sum = target.sum(dim=(2, 3, 4))
    dice = (2 * intersection + smooth) / (pred_sum + target_sum + smooth)
    return dice

def calculate_iou(pred, target, smooth=1e-8):
    intersection = (pred * target).sum(dim=(2, 3, 4))
    union = pred.sum(dim=(2, 3, 4)) + target.sum(dim=(2, 3, 4)) - intersection
    iou = intersection / (union + smooth)
    return iou

def calculate_precision(pred, target, smooth=1e-8):
    tp = (pred * target).sum(dim=(2, 3, 4))
    fp = pred.sum(dim=(2, 3, 4)) - tp
    precision = tp / (tp + fp + smooth)
    return precision

def calculate_recall(pred, target, smooth=1e-8):
    tp = (pred * target).sum(dim=(2, 3, 4))
    fn = target.sum(dim=(2, 3, 4)) - tp
    recall = tp / (tp + fn + smooth)
    return recall

# 综合评估指标计算函数
def calculate_metrics(pred, target, threshold=0.5, activation=F.sig):
    # 使用 Softmax 激活函数
    pred = F.softmax(pred, dim=1)

    # 将预测概率二值化
    pred_bin = (pred > threshold).float()

    # 分别计算各个指标
    dice = calculate_dice(pred_bin, target)
    iou = calculate_iou(pred_bin, target)
    precision = calculate_precision(pred_bin, target)
    recall = calculate_recall(pred_bin, target)

    # 返回每个类别和总体的指标
    return {
        "dice": dice.mean().item(),
        "iou": iou.mean().item(),
        "precision": precision.mean().item(),
        "recall": recall.mean().item(),
        "dice_per_class": dice.tolist(),
        "iou_per_class": iou.tolist(),
        "precision_per_class": precision.tolist(),
        "recall_per_class": recall.tolist()
    }

# 日志

In [8]:
class Train_Logger():
    def __init__(self, save_path, save_name):
        self.log = None
        self.summary = None
        self.save_path = save_path
        self.save_name = save_name

    def update(self, epoch, train_log, val_log):
        item = OrderedDict({'epoch': epoch})
        item.update(train_log)
        item.update(val_log)
        # 打印训练和验证日志
        print("[0;33mTrain:[0m", train_log)
        print("[0;33mValid:[0m", val_log)
        self.update_csv(item)
        self.update_tensorboard(item)

    def update_csv(self, item):
        tmp = pd.DataFrame(item, index=[0])
        if self.log is not None:
            self.log = pd.concat([self.log, tmp], ignore_index=True)
        else:
            self.log = tmp
        self.log.to_csv('%s/%s.csv' % (self.save_path, self.save_name), index=False)

    def update_tensorboard(self, item):
        if self.summary is None:
            self.summary = SummaryWriter('%s/' % self.save_path)
        epoch = item['epoch']
        for key, value in item.items():
            if key != 'epoch':
                self.summary.add_scalar(key, value, epoch)
        with open(self.log_file, 'a') as f:
            for i, metrics in enumerate(metrics_per_class):
                f.write(f"Class {i} - epoch {epoch}: Dice: {metrics['dice']:.4f}, IoU: {metrics['iou']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}")

# 训练

In [14]:
# 模型训练函数
def train(data_path, num_epochs=10, learning_rate=0.001, batch_size=4, log_path="logs", model_save_path='saved_models'):
    
    #配置设备
    device_ids = [0, 1]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
    
    # 设置数据集路径
    image_path = os.path.join(data_path, "image")
    label_path = os.path.join(data_path, "label")

    # 加载数据集
    train_dataset = MedicalImageDataset(image_path, label_path, mode="train")
    val_dataset = MedicalImageDataset(image_path, label_path, mode="val")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # 创建模型
    model = UNet(kernel_size=3, in_channel=1, out_channel=args.n_classes).to(device)
    model.apply(weights_init_kaiming)
    model = nn.DataParallel(model, device_ids=device_ids)

    # 初始化日志记录器
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_name = f"training_log_{current_time}"
    logger = Train_Logger(save_path=log_path, save_name=log_name)

    # 定义损失函数和优化器
    criterion = SoftDiceLoss(activation=torch.sigmoid)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    best_dice = 0.0
    # 训练循环
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss, class_loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        train_log = {'train_loss': epoch_loss}

        # 评估模型在验证集上的表现
        model.eval()
        val_metrics = {"dice": 0, "iou": 0, "precision": 0, "recall": 0, "val_loss": 0}
        with torch.no_grad():
            for val_images, val_labels in val_loader:
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                val_outputs = model(val_images)
                loss = criterion(val_outputs, val_labels.float())
                val_loss, val_class_loss = criterion(val_outputs, val_labels.float())
                val_metrics['val_loss'] += val_loss.item() * val_images.size(0)
                metrics = calculate_metrics(val_outputs, val_labels.float())
                for key in ['dice', 'iou', 'precision', 'recall']:
                    val_metrics[key] += metrics[key] * val_images.size(0)

        # 计算平均评估指标
        val_metrics['val_loss'] /= len(val_loader.dataset)
        for key in ['dice', 'iou', 'precision', 'recall']:
            val_metrics[key] /= len(val_loader.dataset)
        print(f"Validation Metrics - Loss: {val_metrics['val_loss']:.4f}, Dice: {val_metrics['dice']:.4f}, IoU: {val_metrics['iou']:.4f}, Precision: {val_metrics['precision']:.4f}, Recall: {val_metrics['recall']:.4f}")

        # 保存最佳模型
        if val_metrics['dice'] > best_dice:
            best_dice = val_metrics['dice']
            best_model_path = os.path.join(model_save_path, 'best_model.pth')  # 保存模型的路径
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved with Dice: {best_dice:.4f}")

        # 更新日志
        val_log = {key: val_metrics[key] for key in ['val_loss', 'dice', 'iou', 'precision', 'recall']}
        logger.update(epoch + 1, train_log, val_log)
        
        # 保存每个类别的损失到 CSV
        for i in range(len(class_loss)):
            class_log = OrderedDict({
                'epoch': epoch + 1,
                'class': i,
                'train_class_loss': class_loss[i].item(),
                'val_class_loss': val_class_loss[i].item() if i < len(val_class_loss) else None
            })
            logger.update_csv(class_log)
        
        # 保存每个类别的指标到 CSV
        for i in range(len(metrics['dice_per_class'])):
            class_log = OrderedDict({
                'epoch': epoch + 1,
                'class': i,
                'dice': metrics['dice_per_class'][i],
                'iou': metrics['iou_per_class'][i],
                'precision': metrics['precision_per_class'][i],
                'recall': metrics['recall_per_class'][i]
            })
            logger.update_csv(class_log)

    print("训练完成!")

In [None]:
data_path = "D:/zlx/Medical_Image_Segmentation/data/patches"
model_save_path = 'saved_models'  # 你可以设置保存最佳模型的路径
log_path =
train(data_path=data_path, num_epochs=500, learning_rate=0.001, batch_size=4, log_path=log_path, model_save_path=model_save_path)