In [None]:
import os
import random
import time
from glob import glob

import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from tqdm import tqdm

from monai.data import Dataset, CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ConcatItemsd, ToTensord
)
from monai.utils import set_determinism

from Net import CascadeModel3

import logging
# Disable every `logging` record at WARNING severity and below (DEBUG/INFO/WARNING)
# process-wide; ERROR and CRITICAL records still get through. This does not affect
# messages emitted via the `warnings` module.
logging.disable(logging.WARNING)

In [None]:
def prepare(in_dir, pixdim=(1.0, 1.0, 1.0), batchsize=5, cache=False, seed=42, shuffle_train=False):
    """Build train/validation DataLoaders from ``<in_dir>/TrainData/<subdir>/*.nii.gz``.

    Each sample is a dict with one file path per key of ``path_dict``. After loading,
    the five structure images ("ptv", "bld_pos", "fhl_pos", "fhr_pos", "si_pos") are
    concatenated along dim 0 into ``"bev"``. Files are paired across sub-directories
    purely by sorted filename position, so every sub-directory must hold the same
    number of files.

    Args:
        in_dir: Root data directory containing ``TrainData``.
        pixdim: Unused by this function; retained for backward compatibility.
        batchsize: Batch size for both loaders.
        cache: If True, wrap samples in MONAI ``CacheDataset`` (``cache_rate=1.0``);
            otherwise use a plain MONAI ``Dataset``.
        seed: Seed passed to MONAI ``set_determinism``.
        shuffle_train: Shuffle the training loader each epoch. Defaults to False,
            which matches the original behaviour.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        FileNotFoundError: No PTV files were found.
        ValueError: The sub-directories contain different numbers of files.
    """
    # NOTE: set_determinism re-seeds the global random / NumPy / torch RNGs, so it
    # overrides any seeding the caller did before calling prepare().
    set_determinism(seed=seed)

    # Sub-directory (under TrainData/) holding the files for each sample key.
    path_dict = {
        "ptv": "PTV", "bld_pos": "Bladder_pos", "fhl_pos": "Femoral_head_l_pos",
        "fhr_pos": "Femoral_head_r_pos", "si_pos": "Small_intestine_pos",
        "dose": "Dose", "flu": "Fluence", "eptv": "expand_PTV"
    }
    keys = list(path_dict)

    path_train = {
        key: sorted(glob(os.path.join(in_dir, f"TrainData/{subdir}", "*.nii.gz")))
        for key, subdir in path_dict.items()
    }

    # Pairing is by sorted position: a count mismatch would either raise IndexError
    # or silently pair one patient's images with another patient's targets.
    counts = {key: len(files) for key, files in path_train.items()}
    if counts["ptv"] == 0:
        raise FileNotFoundError(
            f"No *.nii.gz files found under {os.path.join(in_dir, 'TrainData', path_dict['ptv'])}"
        )
    if len(set(counts.values())) != 1:
        raise ValueError(f"Mismatched file counts across TrainData sub-directories: {counts}")

    # One dict per sample: {key: path} in path_dict order.
    train_files = [dict(zip(keys, paths)) for paths in zip(*(path_train[k] for k in keys))]

    # Preprocessing pipeline.
    train_transforms = Compose([
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        ConcatItemsd(
            keys=["ptv", "bld_pos", "fhl_pos", "fhr_pos", "si_pos"],
            name="bev", dim=0
        ),
        ToTensord(keys=["bev", "dose", "flu", "eptv"]),
    ])

    # Train/validation split: first 72/81 of the sorted cases train, the rest validate
    # (presumably a 72/9 split of an 81-case dataset — confirm if the data changes).
    train_split = int(len(train_files) * 72 / 81)
    train_files_split = train_files[:train_split]
    val_files_split = train_files[train_split:]

    def make_dataset(files):
        # Only CacheDataset takes cache_rate; passing it to a plain Dataset raises TypeError.
        if cache:
            return CacheDataset(data=files, transform=train_transforms, cache_rate=1.0)
        return Dataset(data=files, transform=train_transforms)

    train_ds = make_dataset(train_files_split)
    val_ds = make_dataset(val_files_split)

    train_loader = DataLoader(train_ds, batch_size=batchsize, shuffle=shuffle_train)
    val_loader = DataLoader(val_ds, batch_size=batchsize)

    return train_loader, val_loader

In [None]:
def _run_epoch(model, loader, loss_fn, device, dose_weight, optimizer=None):
    """Run one pass over ``loader`` and return mean ``(total, dose, fluence)`` losses.

    With an ``optimizer`` the model is put in train mode and updated per batch
    (zero_grad / backward / step). Without one, it runs in eval mode with autograd
    disabled. An empty loader yields a NaN triple instead of dividing by zero.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    sums = [0.0, 0.0, 0.0]
    n_batches = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            bev = batch["bev"].to(device)
            flu = batch["flu"].to(device)
            dose = batch["dose"].to(device)
            eptv = batch["eptv"].to(device)

            if training:
                optimizer.zero_grad()
            pred_dose, pred_flu = model(bev, eptv)

            dose_loss = loss_fn(pred_dose, dose)
            flu_loss = loss_fn(pred_flu, flu)
            # The dose term is scaled by dose_weight relative to the fluence term.
            total_loss = dose_weight * dose_loss + flu_loss

            if training:
                total_loss.backward()
                optimizer.step()

            sums[0] += total_loss.item()
            sums[1] += dose_loss.item()
            sums[2] += flu_loss.item()
            n_batches += 1

    if n_batches == 0:
        return float("nan"), float("nan"), float("nan")
    return tuple(s / n_batches for s in sums)


def _log_metrics(history, split, label, metrics, model_dir):
    """Print one split's epoch metrics, append them to ``history`` and save each series as .npy."""
    loss, dose_metric, flu_metric = metrics
    print(f"{label} Loss: {loss:.4f}")
    print(f"Dose Metric: {dose_metric:.4f}")
    print(f"Fluence Metric: {flu_metric:.4f}")

    for name, value in (("loss", loss), ("dose_metric", dose_metric), ("fluence_metric", flu_metric)):
        key = f"{split}_{name}"
        history[key].append(value)
        # Re-saved every epoch so partial curves survive an interrupted run.
        np.save(os.path.join(model_dir, f"{key}.npy"), history[key])


def train(model, data_in, loss_fn, optimizer, max_epochs, model_dir, test_interval=1,
          device=torch.device("cuda:0"), dose_weight=5):
    """Train ``model`` for ``max_epochs``, validating every ``test_interval`` epochs.

    Per epoch, the mean total/dose/fluence losses are printed. They are also saved
    under ``model_dir`` as ``{train,test}_{loss,dose_metric,fluence_metric}.npy``.
    Two checkpoints are kept:
    ``best_model.pth`` (lowest validation fluence loss) and ``best_loss_model.pth``
    (lowest validation total loss).

    Args:
        model: Module called as ``model(bev, eptv) -> (pred_dose, pred_flu)``.
        data_in: ``(train_loader, test_loader)``. Each batch is a dict with tensors
            under "bev", "flu", "dose" and "eptv".
        loss_fn: Loss applied separately to dose and fluence predictions.
        optimizer: Optimizer over ``model``'s parameters.
        max_epochs: Number of training epochs.
        model_dir: Output directory for metrics and checkpoints. It is created if missing.
        test_interval: Validate when ``(epoch + 1) % test_interval == 0``.
        device: Device the batch tensors are moved to.
        dose_weight: Weight of the dose loss in ``dose_weight * dose + fluence``.

    Raises:
        ValueError: ``train_loader`` has no batches.

    If the validation loader is empty, validation metrics are NaN and no checkpoint
    is written.
    """
    train_loader, test_loader = data_in
    if len(train_loader) == 0:
        raise ValueError("train loader is empty; check the training data / split")
    # np.save / torch.save fail if the output directory does not exist yet.
    os.makedirs(model_dir, exist_ok=True)

    best_flu_metric = float("inf")
    best_loss = float("inf")
    history = {
        f"{split}_{name}": []
        for split in ("train", "test")
        for name in ("loss", "dose_metric", "fluence_metric")
    }

    for epoch in range(max_epochs):
        print(f"\n{'-'*60}\nEpoch {epoch+1}/{max_epochs}")

        train_metrics = _run_epoch(model, train_loader, loss_fn, device, dose_weight, optimizer)
        _log_metrics(history, "train", "Train", train_metrics, model_dir)

        # Validation phase.
        if (epoch + 1) % test_interval != 0:
            continue

        val_metrics = _run_epoch(model, test_loader, loss_fn, device, dose_weight)
        _log_metrics(history, "test", "Test", val_metrics, model_dir)
        val_loss, _, val_flu_metric = val_metrics

        # NaN (empty validation loader) never compares less-than, so nothing is saved then.
        if val_flu_metric < best_flu_metric:
            best_flu_metric = val_flu_metric
            torch.save(model.state_dict(), os.path.join(model_dir, "best_model.pth"))

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), os.path.join(model_dir, "best_loss_model.pth"))

    print("\n✅ Training Complete.")

In [None]:
def get_random_seed(seed):
    """Seed every RNG used here and pin cuDNN to deterministic kernels.

    Despite the name, nothing is returned: the function seeds Python's ``random``,
    NumPy's legacy global RNG, torch (CPU) and all CUDA devices with ``seed``, and
    exports ``PYTHONHASHSEED``. Assigning ``PYTHONHASHSEED`` at runtime does not
    change hash randomisation of the already-running interpreter; it only reaches
    child processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

    # Reproducible kernel selection instead of cuDNN autotuning.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

In [None]:
# Configure paths and parameters
data_dir = './data'
model_dir = '../model'
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Make sure the output directory exists before anything is saved into it.
os.makedirs(model_dir, exist_ok=True)

# Set random seeds.
# NOTE(review): prepare() calls MONAI set_determinism(seed=42), which re-seeds the
# global RNGs again, so model initialisation below follows seed 42 rather than 831.
# Confirm which seed is intended.
get_random_seed(831)

# Prepare data and model
data_in = prepare(data_dir, batchsize=3, cache=True)
model = CascadeModel3().to(device)
loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)

# Start training. Pass the configured device explicitly; otherwise train() falls
# back to its own cuda:0 default and batches could land on a different device than the model.
train(model, data_in, loss_fn, optimizer, max_epochs=200, model_dir=model_dir, device=device)