# DDM² Stage 3: Train Diffusion Model

使用Stage 2的匹配结果训练扩散模型。

训练目标: 给定噪声图像和时间步t，预测去噪后的图像。

确保Stage 2已完成，修改配置后运行。

In [None]:
#==============================================================================
# Base path configuration
#==============================================================================

PROJECT_ROOT = "/host/c/Users/ROG/Documents/Github/DDM2_new"         # project root directory
CONFIG_FILE = "config/ct_denoise.json" # config file path (relative to PROJECT_ROOT)
GPU_ID = "0"                           # index of the GPU to use

#==============================================================================
# Data paths
#==============================================================================

EXCEL_PATH = "/host/d/file/fixedCT_static_simulation_train_test_gaussian_local.xlsx"  # Excel file that indexes the data
DATA_ROOT = "/host/d/file/simulation/"                    # root directory of the CT data
TEACHER_N2N_ROOT = "/host/d/file/pre/noise2noise/pred_images/"  # directory holding the Teacher N2N predictions
TEACHER_N2N_EPOCH = 78                                    # which Teacher N2N epoch to use
BINS_FILE = "/host/d/file/histogram_equalization/bins.npy"        # histogram-equalization bins file
BINS_MAPPED_FILE = "/host/d/file/histogram_equalization/bins_mapped.npy"  # histogram-equalization mapping file

# Stage 2 matching file (produced by Stage 2)
# File format: one "volume_idx_slice_idx_t" entry per line, e.g. "3_45_156"
# means that slice 45 of case 3 is matched to t=156
STAGE2_FILE = "experiments/ct_denoise_teacher/stage2_matched.txt"

#==============================================================================
# Dataset split
#
# How the data is organised:
#   - batch: a group of cases (e.g. batches 0-4 for training, batch 5 for validation)
#   - volume: one case, i.e. one complete 3D CT scan
#   - slice: one 2D slice of a volume
#
# Example: volume_idx=3, slice_idx=45 means slice 45 of case 3
#==============================================================================

TRAIN_BATCHES = [0, 1, 2, 3, 4]  # batch indices used for training
VAL_BATCHES = [5]                # batch indices used for validation

SLICE_RANGE = [30, 80]  # slice range used from every volume, as [start, end)
                        # e.g. [30, 80] uses slices 30-79, 50 slices in total

# Sampling used for validation (a few samples are enough during training and save time)
VAL_VOLUME_IDX = "all"      # which volume (case) to validate on: a number selects that case
                            # NOTE(review): "all" presumably selects every validation volume -- confirm against the dataset loader
VAL_SLICE_IDX = 25    # which slice(s) to validate on: a list such as [25] means only slice 25
                      # NOTE(review): the value is a bare int although the comment describes a list;
                      # also confirm whether 25 is an absolute slice index or an offset into SLICE_RANGE
                         # only a few samples are needed to watch the trend; full evaluation happens in the test stage

#==============================================================================
# Preprocessing parameters
#==============================================================================

HU_MIN = -1000.0  # lower bound of the CT value (HU); values below it are clipped
HU_MAX = 2000.0   # upper bound of the CT value (HU); values above it are clipped
                  # typical values: air=-1000, water=0, bone=1000+

HISTOGRAM_EQUALIZATION = True  # whether to enable histogram equalization (must match Stage 2)

#==============================================================================
# Training parameters
#==============================================================================

N_ITER = 400000     # total number of training iterations
                    # one iteration = processing one batch
                    

VAL_FREQ = 1000      # run validation every this many iterations
                    # validation saves the denoised images to the results/ directory

SAVE_FREQ = 10000   # save a checkpoint every this many iterations
                    # checkpoints are stored in experiments/{name}/checkpoint/

PRINT_FREQ = 100    # print the loss every this many iterations

LEARNING_RATE = 1e-4  # learning rate
                      

BATCH_SIZE = 1      # number of samples per training step
                    # 512x512 images are large, so 1-4 is typical
                    # a larger batch needs more GPU memory

#==============================================================================
# Resuming from a checkpoint (optional)
#==============================================================================

RESUME_STATE = None  # resume training: set to a checkpoint path
                     # e.g. "experiments/ct_denoise/checkpoint/I100000_E29"
                     # None means training from scratch

In [2]:
# Write the notebook settings into the JSON config file.
import json
import os

# All relative paths below (CONFIG_FILE, STAGE2_FILE) are resolved from the project root.
os.chdir(PROJECT_ROOT)

with open(CONFIG_FILE, 'r') as _cfg_in:
    config = json.load(_cfg_in)

# Settings that are written identically to the train and the val dataset.
# Key order is kept the same as the order in which the keys are assigned,
# so keys that are new to the config land in the same position in the JSON.
_source_settings = {
    'dataroot': EXCEL_PATH,
    'data_root': DATA_ROOT,
    'teacher_n2n_root': TEACHER_N2N_ROOT,
    'teacher_n2n_epoch': TEACHER_N2N_EPOCH,
    'bins_file': BINS_FILE,
    'bins_mapped_file': BINS_MAPPED_FILE,
}
_preprocess_settings = {
    'HU_MIN': HU_MIN,
    'HU_MAX': HU_MAX,
    'histogram_equalization': HISTOGRAM_EQUALIZATION,
}

# datasets.train
config['datasets']['train'].update({
    **_source_settings,
    'train_batches': TRAIN_BATCHES,
    'val_batches': VAL_BATCHES,
    'slice_range': SLICE_RANGE,
    **_preprocess_settings,
    'batch_size': BATCH_SIZE,
})

# datasets.val -- note that 'train_batches' is set to VAL_BATCHES here as well
config['datasets']['val'].update({
    **_source_settings,
    'train_batches': VAL_BATCHES,
    'val_batches': VAL_BATCHES,
    'slice_range': SLICE_RANGE,
    'val_volume_idx': VAL_VOLUME_IDX,
    'val_slice_idx': VAL_SLICE_IDX,
    **_preprocess_settings,
})

# Training schedule and optimizer
config['train'].update({
    'n_iter': N_ITER,
    'val_freq': VAL_FREQ,
    'save_checkpoint_freq': SAVE_FREQ,
    'print_freq': PRINT_FREQ,
})
config['train']['optimizer']['lr'] = LEARNING_RATE

# Stage 2 matching file and optional resume checkpoint
config['stage2_file'] = STAGE2_FILE
config['path']['resume_state'] = RESUME_STATE

with open(CONFIG_FILE, 'w') as _cfg_out:
    json.dump(config, _cfg_out, indent=4)

# Summary of what was written, emitted as one block of text
_summary = [
    f"Config已更新: {CONFIG_FILE}",
    f"Stage2文件: {STAGE2_FILE}",
    "\n训练参数:",
    f"  总迭代次数: {N_ITER:,}",
    f"  验证频率: 每{VAL_FREQ}iter",
    f"  保存频率: 每{SAVE_FREQ:,}iter",
    f"  学习率: {LEARNING_RATE}",
    f"  Batch size: {BATCH_SIZE}",
    f"\n断点续训: {RESUME_STATE or '从头训练'}",
]
print("\n".join(_summary))

Config已更新: config/ct_denoise.json
Stage2文件: experiments/ct_denoise_teacher/stage2_matched.txt

训练参数:
  总迭代次数: 200,000
  验证频率: 每1000iter
  保存频率: 每10,000iter
  学习率: 0.0001
  Batch size: 1

断点续训: 从头训练


In [3]:
# Check that the Stage 2 matching file exists and estimate how many epochs
# the configured number of iterations corresponds to.
if os.path.exists(STAGE2_FILE):
    with open(STAGE2_FILE, 'r') as f:
        # Keep only non-blank lines: an empty line is not a sample
        # and would otherwise inflate the count.
        lines = [line for line in f if line.strip()]
    samples_per_epoch = len(lines)
    print(f"✓ Stage2文件存在: {samples_per_epoch} 个样本")

    if samples_per_epoch > 0:
        # Estimated number of epochs: every iteration processes one batch
        # of BATCH_SIZE samples.
        total_epochs = N_ITER * BATCH_SIZE / samples_per_epoch
        print(f"  预计训练 {total_epochs:.1f} 个epoch")
    else:
        # An empty file would make the estimate divide by zero; report it instead.
        print("⚠ Stage2文件为空，无法估算epoch数")
        print("请重新运行 01_stage2_state_matching.ipynb")
else:
    print(f"⚠ Stage2文件不存在: {STAGE2_FILE}")
    print("请先运行 01_stage2_state_matching.ipynb")

✓ Stage2文件存在: 3500 个样本
  预计训练 57.1 个epoch


In [4]:
# Train the diffusion model
# Training logs are saved in experiments/{name}/logs/
# Validation results are saved in experiments/{name}/results/
# Checkpoints are saved in experiments/{name}/checkpoint/
# The line below is a notebook shell escape; {GPU_ID} and {CONFIG_FILE} are
# filled in from the notebook variables of the same name.
!CUDA_VISIBLE_DEVICES={GPU_ID} python3 train_diff_model.py -p train -c {CONFIG_FILE}

11.7
export CUDA_VISIBLE_DEVICES=0
26-01-03 15:15:40.221 - INFO:   name: ct_denoise
  phase: train
  gpu_ids: [0]
  path:[
    log: experiments/ct_denoise_260103_151540/logs
    tb_logger: experiments/ct_denoise_260103_151540/tb_logger
    results: experiments/ct_denoise_260103_151540/results
    checkpoint: experiments/ct_denoise_260103_151540/checkpoint
    resume_state: None
    experiments: experiments
    experiments_root: experiments/ct_denoise_260103_151540
  ]
  datasets:[
    train:[
      name: ct
      dataroot: /host/d/file/fixedCT_static_simulation_train_test_gaussian_local.xlsx
      data_root: /host/d/file/simulation/
      train_batches: [0, 1, 2, 3, 4]
      val_batches: [5]
      valid_mask: [0, 1000]
      slice_range: [30, 80]
      phase: train
      padding: 3
      val_volume_idx: all
      val_slice_idx: all
      batch_size: 1
      in_channel: 1
      num_workers: 4
      use_shuffle: True
      image_size: 512
      lr_flip: 0.5
      HU_MIN: -1000.0
      HU