# DDM² Stage 3: Diffusion Model Training

**直接调用 `train_diff_model.py`，自动读取 Stage 2 配置**

---

## 前提条件

- Stage 1 已完成（已训练并保存 Noise Model checkpoint）
- Stage 2 已完成（config 中的 `stage2_file` 和 `use_random_num` 已自动更新）

## 1. 配置

In [1]:
import os
import sys
import json

# ============ Edit here ============
PROJECT_ROOT = '/host/c/Users/ROG/Documents/GitHub/DDM2'          # DDM² project root
CONFIG_FILE = 'config/ct_denoise.json'   # config file (already updated by Stage 2)
# ===================================

# Work from the project root: train_diff_model.py is launched from here and
# relative paths stored in the config (e.g. stage2_file) are resolved against it.
os.chdir(PROJECT_ROOT)
# Make project modules importable; skip if already present so re-running the
# cell does not keep prepending duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

config_path = os.path.join(PROJECT_ROOT, CONFIG_FILE)
# JSON is UTF-8 by spec; be explicit instead of relying on the locale default.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

print(f"工作目录: {os.getcwd()}")
print(f"配置文件: {config_path}")

工作目录: /host/c/Users/ROG/Documents/GitHub/DDM2
配置文件: /host/c/Users/ROG/Documents/GitHub/DDM2/config/ct_denoise.json


## 2. 数据选择（可覆盖 Stage 2 设置）

In [2]:
# ================== Data selection ==================
# Defaults are inherited from the Stage 2 values stored in the config's
# datasets.train section; assign explicit values below to override them.

train_ds = config['datasets']['train']

# --- Which batches are used for training vs. validation ---
TRAIN_BATCHES = train_ds.get('train_batches', [0, 1, 2, 3, 4])
VAL_BATCHES = train_ds.get('val_batches', [5])

# --- Slice range, stored in the config as [start, end] under 'valid_mask' ---
valid_mask = train_ds.get('valid_mask', [0, 100])
SLICE_START = valid_mask[0]
SLICE_END = valid_mask[1]

# --- random_num (inherited from Stage 2) ---
USE_RANDOM_NUM = train_ds.get('use_random_num', 0)

# --- Volume / slice used for validation visualisation ---
VAL_VOLUME_IDX = 0
VAL_SLICE_IDX = 25

# =====================================================

# Write the same selection into both the train and val dataset sections
# (keyword order matches the original assignment order).
for phase in ('train', 'val'):
    section = config['datasets'][phase]
    section.update(
        train_batches=TRAIN_BATCHES,
        val_batches=VAL_BATCHES,
        valid_mask=[SLICE_START, SLICE_END],
        use_random_num=USE_RANDOM_NUM,
    )

# The visualisation indices only apply to the val section.
val_section = config['datasets']['val']
val_section['val_volume_idx'] = VAL_VOLUME_IDX
val_section['val_slice_idx'] = VAL_SLICE_IDX

summary_lines = [
    "数据选择配置:",
    f"  训练 batch: {TRAIN_BATCHES}",
    f"  验证 batch: {VAL_BATCHES}",
    f"  Slice 范围: [{SLICE_START}, {SLICE_END})",
    f"  random_num: {USE_RANDOM_NUM} (从 Stage 2 继承)",
    f"  验证 volume/slice: {VAL_VOLUME_IDX}/{VAL_SLICE_IDX}",
]
print("\n".join(summary_lines))

数据选择配置:
  训练 batch: [0, 1, 2, 3, 4]
  验证 batch: [5]
  Slice 范围: [0, 1000)
  random_num: 0 (从 Stage 2 继承)
  验证 volume/slice: 0/25


## 3. 训练参数

In [None]:
# ================== Training parameters ==================

N_ITER = 100000          # number of training iterations
BATCH_SIZE = 1           # batch size
LEARNING_RATE = 1e-4     # learning rate
VAL_FREQ = 5000          # validation frequency (written to train.val_freq)
SAVE_FREQ = 10000        # checkpoint frequency (written to train.save_checkpoint_freq)

# Resume training (None = start from scratch)
RESUME_STATE = None      # or 'experiments/xxx/checkpoint/latest'

# =========================================================

# Push the values above into the loaded config, in the same key order as before.
train_opt = config['train']
train_opt['n_iter'] = N_ITER
train_opt['val_freq'] = VAL_FREQ
train_opt['save_checkpoint_freq'] = SAVE_FREQ
train_opt['optimizer']['lr'] = LEARNING_RATE
config['datasets']['train']['batch_size'] = BATCH_SIZE
config['path']['resume_state'] = RESUME_STATE

# A falsy RESUME_STATE (None / '') means a fresh run.
resume_desc = RESUME_STATE or '从头开始'
for text in (
    "训练参数:",
    f"  迭代次数: {N_ITER:,}",
    f"  批大小: {BATCH_SIZE}",
    f"  学习率: {LEARNING_RATE}",
    f"  继续训练: {resume_desc}",
):
    print(text)

训练参数:
  迭代次数: 100,000
  批大小: 1
  学习率: 0.0001
  继续训练: 从头开始


## 4. 检查依赖（自动读取）

In [4]:
# Pre-flight check: verify that the inputs Stage 3 reads from the config exist.
# Results are only reported here; `all_ok` is not enforced by the later cells.
print("=" * 60)
print("Stage 3 依赖检查")
print("=" * 60)

all_ok = True

# 1. Stage 2 output file (path stored by Stage 2 in config['stage2_file']).
#    Relative paths resolve against the working directory set in section 1.
stage2_file = config.get('stage2_file')
print("\n[Stage 2 文件]")
print(f"  路径: {stage2_file}")
# isfile() rather than exists(): a directory at this path would pass exists()
# and then fail when opened below.
if stage2_file and os.path.isfile(stage2_file):
    # Stream the file to count lines instead of loading it with readlines();
    # binary mode avoids any dependence on the file's text encoding.
    with open(stage2_file, 'rb') as f:
        n = sum(1 for _ in f)
    print(f"  状态: ✓ 存在 ({n} 样本)")
else:
    print("  状态: ✗ 不存在! 请先运行 Stage 2")
    all_ok = False

# 2. random_num setting ('both' uses both noise realisations).
print("\n[random_num 设置]")
print(f"  值: {USE_RANDOM_NUM}")
if USE_RANDOM_NUM == 'both':
    print("  说明: 使用两个噪声实现 (N2N 模式)")
else:
    print(f"  说明: 只用 random_num={USE_RANDOM_NUM}")


# 3. Data index file referenced by datasets.train.dataroot.
excel = config['datasets']['train']['dataroot']
excel_ok = os.path.exists(excel)  # checked once, reused below
print("\n[Excel 数据]")
print(f"  路径: {excel}")
print(f"  状态: {'✓' if excel_ok else '✗'}")
if not excel_ok:
    all_ok = False

print("\n" + "=" * 60)
if all_ok:
    print("✓ 所有依赖检查通过！")
else:
    print("✗ 请修复上述问题")
print("=" * 60)

Stage 3 依赖检查

[Stage 2 文件]
  路径: experiments/ct_denoise_teacher/stage2_matched.txt
  状态: ✓ 存在 (3450 样本)

[random_num 设置]
  值: 0
  说明: 只用 random_num=0

[Excel 数据]
  路径: /host/d/file/fixedCT_static_simulation_train_test_gaussian_local.xlsx
  状态: ✓

✓ 所有依赖检查通过！


## 5. 运行训练

In [5]:
import shlex
import subprocess

# Write the in-memory config (including the overrides from the cells above) to
# a temporary file that train_diff_model.py reads via -c. The next cell removes it.
temp_config = os.path.join(PROJECT_ROOT, 'config', '_stage3_temp.json')
with open(temp_config, 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=4)

# Launch with the notebook kernel's own interpreter (sys.executable) rather than
# whatever `python` comes first on PATH, so training runs in the same environment.
# Each argument is shell-quoted because the command string is run with shell=True.
cmd_args = [sys.executable, 'train_diff_model.py', '-p', 'train', '-c', temp_config]
cmd = ' '.join(shlex.quote(arg) for arg in cmd_args)
print(f"执行: {cmd}")
print("=" * 60)
print("开始训练...")
print("=" * 60)

执行: python train_diff_model.py -p train -c /host/c/Users/ROG/Documents/GitHub/DDM2/config/_stage3_temp.json
开始训练...


In [6]:
# Run training as a child process and stream its combined stdout/stderr into
# the notebook.
import signal

# start_new_session=True puts the shell and everything it spawns (trainer,
# data-loading workers) in their own process group so they can all be stopped
# together below. POSIX-only (os.killpg).
process = subprocess.Popen(
    cmd, shell=True,
    stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
    text=True, bufsize=1,
    start_new_session=True,
)
try:
    for line in process.stdout:
        print(line, end='')
    process.wait()

    print("\n" + "=" * 60)
    print(f"返回码: {process.returncode}")
finally:
    # If streaming was interrupted (e.g. KeyboardInterrupt from stopping the
    # cell), the trainer would otherwise keep running in the background;
    # terminate its whole process group, escalating to SIGKILL if needed.
    if process.poll() is None:
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except ProcessLookupError:
            pass  # group already exited between poll() and killpg()
        try:
            process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            os.killpg(process.pid, signal.SIGKILL)
            process.wait()
    process.stdout.close()

    # Clean up: remove the temporary config written by the previous cell.
    if os.path.exists(temp_config):
        os.remove(temp_config)

PyTorch version: 2.0.1
CUDA version: 11.7
CUDA available: True
export CUDA_VISIBLE_DEVICES=0
26-01-02 20:09:53.954 - INFO:   name: ct_denoise_teacher
  phase: train
  gpu_ids: [0]
  path:[
    log: experiments/ct_denoise_teacher_260102_200953/logs
    tb_logger: experiments/ct_denoise_teacher_260102_200953/tb_logger
    results: experiments/ct_denoise_teacher_260102_200953/results
    checkpoint: experiments/ct_denoise_teacher_260102_200953/checkpoint
    resume_state: None
    experiments_root: experiments/ct_denoise_teacher_260102_200953
  ]
  datasets:[
    train:[
      name: ct
      dataroot: /host/d/file/fixedCT_static_simulation_train_test_gaussian_local.xlsx
      data_root: /host/d/file/simulation/
      train_batches: [0, 1, 2, 3, 4]
      val_batches: [5]
      valid_mask: [0, 1000]
      slice_range: [30, 80]
      phase: train
      padding: 3
      val_volume_idx: all
      val_slice_idx: all
      batch_size: 1
      in_channel: 1
      num_workers: 4
      use_shuffle:

KeyboardInterrupt: 

## 6. 查看训练状态

In [None]:
import glob
from collections import deque

# Find the most recently modified experiment directory; plain files under
# experiments/ are ignored so they can never be reported as an experiment.
exp_dirs = sorted(
    (p for p in glob.glob('experiments/*') if os.path.isdir(p)),
    key=os.path.getmtime,
    reverse=True,
)

if exp_dirs:
    latest = exp_dirs[0]
    print(f"最新实验: {latest}")

    # Checkpoints, ordered by modification time (oldest first). Lexicographic
    # order would misplace names that embed unpadded iteration numbers.
    ckpt_dir = os.path.join(latest, 'checkpoint')
    if os.path.exists(ckpt_dir):
        ckpts = sorted(
            os.listdir(ckpt_dir),
            key=lambda name: os.path.getmtime(os.path.join(ckpt_dir, name)),
        )
        print(f"Checkpoints: {ckpts}")

    # Logs may live at the top level of the run or under logs/.
    logs = glob.glob(f"{latest}/*.log") + glob.glob(f"{latest}/logs/*.log")
    if logs:
        # glob order is arbitrary; show the most recently written log file.
        log_file = max(logs, key=os.path.getmtime)
        print(f"\n最后 10 行日志 ({log_file}):")
        with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
            # deque(maxlen=10) keeps only the tail while streaming the file,
            # so long training logs are not loaded into memory.
            for line in deque(f, maxlen=10):
                print(f"  {line}", end='')
else:
    print("未找到实验目录 (experiments/*)")

最新实验: experiments/ct_denoise_teacher_260102_175536
Checkpoints: []

最后 10 行日志:
    ]
    stage2_file: experiments/ct_denoise_teacher/stage2_matched.txt
    distributed: False
  
  26-01-02 17:55:39.185 - INFO: CT dataset [ct] is created. Size: 3450
  26-01-02 17:55:40.548 - INFO: CT dataset [ct] is created. Size: 0
  26-01-02 17:55:40.549 - INFO: Initial Dataset Finished
  26-01-02 17:55:42.574 - INFO: Initialization method [orthogonal]
  26-01-02 17:55:43.374 - INFO: [DDM2] is created.
  26-01-02 17:55:43.374 - INFO: Initial Model Finished


---

## 下一步

训练完成后，使用 `3_inference/inference_ddm2.ipynb` 进行推理。