# DDM² Stage 3: Diffusion Model Training

训练扩散去噪模型

In [None]:
#==============================================================================
# Path configuration - edit this section when moving to another machine
#==============================================================================

# Project root directory
PROJECT_ROOT = "/host/d/ddm2"

# Shared root for all data / artifact files. Every data path below is derived
# from it, so relocating the data only requires changing this one line.
FILE_ROOT = "/host/d/file"

# Data paths
EXCEL_PATH = f"{FILE_ROOT}/fixedCT_static_simulation_train_test_gaussian_local.xlsx"
DATA_ROOT = f"{FILE_ROOT}/simulation/"

# Teacher Noise2Noise (N2N) prediction results
TEACHER_N2N_ROOT = f"{FILE_ROOT}/pre/noise2noise/pred_images/"
TEACHER_N2N_EPOCH = 78

# Histogram-equalization files
BINS_FILE = f"{FILE_ROOT}/histogram_equalization/bins.npy"
BINS_MAPPED_FILE = f"{FILE_ROOT}/histogram_equalization/bins_mapped.npy"

# Config file path
CONFIG_FILE = f"{PROJECT_ROOT}/config/ct_denoise.json"

# Stage 2 file path (produced by the 01_stage2 notebook)
STAGE2_FILE = f"{PROJECT_ROOT}/experiments/ct_denoise_stage2/stage2_matched.txt"

# Experiment name (a timestamp is appended to form the output directory)
EXPERIMENT_NAME = "ct_denoise"

#==============================================================================
# Training parameters
#==============================================================================

# Dataset split (batch ids)
TRAIN_BATCHES = [0, 1, 2, 3, 4]  # training set
VAL_BATCHES = [5]                 # validation set

# Slice range
SLICE_RANGE = [30, 80]

# HU value range
HU_MIN = -1000.0
HU_MAX = 2000.0

# Validation sampling: volume index and slice indices used for validation
VAL_VOLUME_IDX = 8
VAL_SLICE_IDX = [25]

# Training iterations and periodic actions (all counted in iterations)
N_ITER = 100000    # total training iterations
PRINT_FREQ = 100   # log the loss every PRINT_FREQ iterations
VAL_FREQ = 1000    # run validation every VAL_FREQ iterations
SAVE_FREQ = 10000  # save a checkpoint every SAVE_FREQ iterations

# Resume training: set to a checkpoint path to resume,
# e.g. "experiments/xxx/checkpoint/latest"; None starts from scratch.
RESUME_STATE = None

#==============================================================================
# GPU configuration
#==============================================================================

# Exported as CUDA_VISIBLE_DEVICES by the setup cell
GPU_ID = "0"

In [None]:
# Environment setup: pick the GPU, make the project packages importable,
# and import the libraries used by the rest of the notebook.
import os
import sys
# Set before `import torch` below so that only the GPU selected by GPU_ID
# is visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
# Put PROJECT_ROOT first on sys.path so the project-local packages imported
# below (data, model, core) resolve from it, and make it the working
# directory so relative paths (e.g. a RESUME_STATE such as
# "experiments/xxx/checkpoint/latest") resolve against the project root.
sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

import json
import torch
import logging
from datetime import datetime

# Project-local modules (importable thanks to the sys.path entry above).
import data as Data
import model as Model
import core.logger as Logger
import core.metrics as Metrics

# Echo the working directory and GPU selection as a quick sanity check.
print(f"工作目录: {os.getcwd()}")
print(f"GPU: {GPU_ID}")

In [None]:
# Load the base JSON config and override it with the settings from the
# configuration cell.

# Fail fast, before building/mutating `opt`, if the Stage 2 output is missing.
if not os.path.exists(STAGE2_FILE):
    raise FileNotFoundError(f"Stage2文件不存在: {STAGE2_FILE}\n请先运行 01_stage2_state_matching.ipynb")

with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
    opt = json.load(f)

# Dataset options shared by the train and val phases (the same objects are
# assigned to both phases, exactly as with per-key assignment).
dataset_overrides = {
    'dataroot': EXCEL_PATH,
    'data_root': DATA_ROOT,
    'train_batches': TRAIN_BATCHES,
    'val_batches': VAL_BATCHES,
    'slice_range': SLICE_RANGE,
    'HU_MIN': HU_MIN,
    'HU_MAX': HU_MAX,
    'bins_file': BINS_FILE,
    'bins_mapped_file': BINS_MAPPED_FILE,
    'teacher_n2n_root': TEACHER_N2N_ROOT,
    'teacher_n2n_epoch': TEACHER_N2N_EPOCH,
}
for phase in ['train', 'val']:
    opt['datasets'][phase].update(dataset_overrides)

# Validation sampling settings
opt['datasets']['val']['val_volume_idx'] = VAL_VOLUME_IDX
opt['datasets']['val']['val_slice_idx'] = VAL_SLICE_IDX

# Training schedule parameters
opt['train'].update(
    n_iter=N_ITER,
    print_freq=PRINT_FREQ,
    val_freq=VAL_FREQ,
    save_checkpoint_freq=SAVE_FREQ,
)

# Stage 2 file
opt['stage2_file'] = STAGE2_FILE

# Count lines with a context manager so the handle is closed deterministically;
# binary mode avoids any decoding failure just for counting.
with open(STAGE2_FILE, 'rb') as stage2_fh:
    stage2_line_count = sum(1 for _ in stage2_fh)
print(f"Stage2文件: {STAGE2_FILE} (共 {stage2_line_count} 行)")

In [None]:
# Create a timestamped output directory for this run and record its layout
# in opt['path'].
ts = datetime.now().strftime('%y%m%d_%H%M%S')
exp_root = f"{PROJECT_ROOT}/experiments/{EXPERIMENT_NAME}_{ts}"

opt['name'] = EXPERIMENT_NAME

# opt['path'] key -> directory name under exp_root
_run_subdirs = {'log': 'logs', 'results': 'results', 'checkpoint': 'checkpoint'}
_run_dirs = {key: f"{exp_root}/{dirname}" for key, dirname in _run_subdirs.items()}
opt['path'] = dict(_run_dirs, resume_state=RESUME_STATE, experiments_root=exp_root)

for run_dir in _run_dirs.values():
    os.makedirs(run_dir, exist_ok=True)

# Convert to the project's dict wrapper from core.logger
# (presumably makes missing keys read as None — see Logger.dict_to_nonedict).
opt = Logger.dict_to_nonedict(opt)
print(f"实验目录: {exp_root}")

In [None]:
# Logging: the 'base' logger writes to the experiment log directory and the screen.
Logger.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
logger = logging.getLogger('base')

# cuDNN: enable it and let it benchmark kernels (chained assignment sets
# `enabled` first, then `benchmark`).
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True


def _build_split(split_phase):
    """Create the dataset and its dataloader for one phase ('train' or 'val')."""
    split_opt = opt['datasets'][split_phase]
    split_set = Data.create_dataset(split_opt, split_phase, stage2_file=STAGE2_FILE)
    return split_set, Data.create_dataloader(split_set, split_opt, split_phase)


train_set, train_loader = _build_split('train')
val_set, val_loader = _build_split('val')

logger.info(f'训练集: {len(train_set)}, 验证集: {len(val_set)}')

In [None]:
# Build the diffusion model from the fully assembled options.
diffusion = Model.create_model(opt)
logger.info('模型初始化完成')

# Report the total parameter count of the netG network.
total_params = sum(param.numel() for param in diffusion.netG.parameters())
print(f"模型参数量: {total_params:,}")

# Loop counters start from the model's begin_step / begin_epoch
# (presumably non-zero when a checkpoint is resumed).
current_step, current_epoch = diffusion.begin_step, diffusion.begin_epoch
n_iter = opt['train']['n_iter']

# Activate the training noise schedule before the loop starts.
_train_beta_schedule = opt['model']['beta_schedule']['train']
diffusion.set_new_noise_schedule(_train_beta_schedule, schedule_phase='train')
print(f"训练目标: {n_iter} iterations")

In [None]:
# Training loop: optimize until n_iter, with periodic logging, validation
# and checkpointing; Ctrl-C stops early and still saves the model once.
from tqdm.auto import tqdm

train_opt = opt['train']
print_freq = train_opt['print_freq']
val_freq = train_opt['val_freq']
save_freq = train_opt['save_checkpoint_freq']
beta_schedules = opt['model']['beta_schedule']

# Step of the most recent checkpoint written in this session; lets the final
# save be skipped when the periodic save already wrote the same step.
last_saved_step = None

# Clamp at 0: a resumed run may already be at (or past) n_iter.
pbar = tqdm(total=max(n_iter - current_step, 0), desc="Training")

try:
    while current_step < n_iter:
        current_epoch += 1
        epoch_had_batches = False
        for train_data in train_loader:
            # Check *before* incrementing so current_step never exceeds n_iter
            # and the final checkpoint records exactly the steps trained.
            if current_step >= n_iter:
                break
            epoch_had_batches = True
            current_step += 1

            diffusion.feed_data(train_data)
            diffusion.optimize_parameters()
            pbar.update(1)

            if current_step % print_freq == 0:
                logs = diffusion.get_current_log()
                logger.info(f'<epoch:{current_epoch}, iter:{current_step}> l_pix: {logs["l_pix"]:.4e}')
                pbar.set_postfix({'loss': f"{logs['l_pix']:.4e}"})

            if current_step % val_freq == 0:
                result_path = f"{opt['path']['results']}/{current_epoch}"
                os.makedirs(result_path, exist_ok=True)
                diffusion.set_new_noise_schedule(beta_schedules['val'], schedule_phase='val')
                try:
                    for idx, val_data in enumerate(val_loader):
                        diffusion.feed_data(val_data)
                        diffusion.test(continous=True)
                        visuals = diffusion.get_current_visuals()
                        Metrics.save_img(Metrics.tensor2img(visuals['denoised']), f'{result_path}/{current_step}_{idx}_denoised.png')
                        Metrics.save_img(Metrics.tensor2img(visuals['X']), f'{result_path}/{current_step}_{idx}_input.png')
                finally:
                    # Restore the training schedule even if validation is
                    # interrupted, so the final save is made in 'train' mode.
                    diffusion.set_new_noise_schedule(beta_schedules['train'], schedule_phase='train')
                logger.info(f'验证完成，结果保存在: {result_path}')

            if current_step % save_freq == 0:
                logger.info('保存checkpoint...')
                diffusion.save_network(current_epoch, current_step, save_last_only=True)
                last_saved_step = current_step

        if not epoch_had_batches:
            # An empty loader would otherwise make this while-loop spin forever.
            raise RuntimeError('train_loader yielded no batches; check the training dataset and batch settings')

except KeyboardInterrupt:
    # The model is saved once, by the final save below.
    logger.info('训练中断，保存模型...')
finally:
    pbar.close()

# Final save (also covers the interrupt path), skipped only when the periodic
# save above already wrote this exact step.
if last_saved_step != current_step:
    diffusion.save_network(current_epoch, current_step, save_last_only=True)
logger.info(f'训练完成! 模型保存在: {opt["path"]["checkpoint"]}')