# DDM² Stage 2: State Matching

为每个训练样本找到最优的扩散时间步 t，使得 `x_t = √ᾱ_t * teacher + √(1-ᾱ_t) * noise` 最接近原始噪声图像。

修改下面的配置，然后运行所有cell。

In [1]:
#==============================================================================
# Base path configuration
#==============================================================================

PROJECT_ROOT = "/host/c/Users/ROG/Documents/Github/DDM2_new"          # Project root directory
CONFIG_FILE = "config/ct_denoise.json" # Config file path (relative to PROJECT_ROOT)
STAGE2_EXP_NAME = "ct_denoise_teacher" # Stage 2 output directory name; results go to experiments/{STAGE2_EXP_NAME}/
GPU_ID = "0"                           # ID of the GPU to use

#==============================================================================
# Data paths
#==============================================================================

EXCEL_PATH = "/host/d/file/fixedCT_static_simulation_train_test_gaussian_local.xlsx"  # Excel file that indexes the data
DATA_ROOT = "/host/d/file/simulation/"                    # Root directory of the CT data
TEACHER_N2N_ROOT = "/host/d/file/pre/noise2noise/pred_images/"  # Directory with the teacher N2N predictions
TEACHER_N2N_EPOCH = 78                                    # Which teacher N2N epoch to use
BINS_FILE = "/host/d/file/histogram_equalization/bins.npy"        # Histogram-equalization bins file
BINS_MAPPED_FILE = "/host/d/file/histogram_equalization/bins_mapped.npy"  # Histogram-equalization mapping file

#==============================================================================
# Dataset split
# 
# Data organization:
#   - batch: a group of cases (e.g. batches 0-4 are the training set, batch 5 is the validation set)
#   - volume: one case, i.e. one complete 3D CT scan
#   - slice: one 2D slice of a volume
#
# Example: volume_idx=3, slice_idx=45 refers to slice 45 of case 3
#==============================================================================

TRAIN_BATCHES = [0, 1, 2, 3, 4]  # Batch numbers used for the training set
VAL_BATCHES = [5]                # Batch numbers used for the validation set

SLICE_RANGE = [30, 80]  # Slice range used from every volume, [start, end)
                        # e.g. [30, 80] uses slices 30-79, 50 slices in total
                        # Head and tail slices are usually dropped because edge slices are of poorer quality

# Validation sampling settings (for Stage 2, 'all' is recommended so every sample gets matched)
# NOTE(review): the example [25, 30] below lies outside SLICE_RANGE [30, 80); whether these
# indices are absolute or relative to SLICE_RANGE is not visible here - confirm in the dataset code.
VAL_VOLUME_IDX = "all"  # Volumes to validate: 'all' = every volume, or a number such as 8 for case 8 only
VAL_SLICE_IDX = "all"   # Slices to validate: 'all' = every slice, or a list such as [25, 30] for those slices only

#==============================================================================
# Preprocessing parameters
#==============================================================================

HU_MIN = -1000.0  # Lower CT value (HU) bound; values below it are clipped
HU_MAX = 2000.0   # Upper CT value (HU) bound; values above it are clipped
                  # Typical values: air = -1000, water = 0, bone = 1000+
                  # The range [-1000, 2000] covers most soft tissue and bone

HISTOGRAM_EQUALIZATION = True  # Whether histogram equalization is enabled
                               # True: boosts contrast, making it easier for the network to learn
                               # False: keeps the original HU distribution

In [2]:
# Preparation: copy the notebook settings into the JSON config and create the
# Stage 2 output directory.
import os, json

# CONFIG_FILE and the experiments/ directory are relative to the project root.
os.chdir(PROJECT_ROOT)

# Explicit UTF-8 so parsing does not depend on the platform's default encoding.
with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
    config = json.load(f)

# Settings that must be identical for the train and val datasets. They are
# defined once so the two sections cannot drift apart.
path_settings = {
    'dataroot': EXCEL_PATH,
    'data_root': DATA_ROOT,
    'teacher_n2n_root': TEACHER_N2N_ROOT,
    'teacher_n2n_epoch': TEACHER_N2N_EPOCH,
    'bins_file': BINS_FILE,
    'bins_mapped_file': BINS_MAPPED_FILE,
}
preprocessing_settings = {
    'HU_MIN': HU_MIN,
    'HU_MAX': HU_MAX,
    'histogram_equalization': HISTOGRAM_EQUALIZATION,
}

# Update datasets.train. Keys already present keep their position in the file;
# new keys are appended in the order listed here.
config['datasets']['train'].update({
    **path_settings,
    'train_batches': TRAIN_BATCHES,
    'val_batches': VAL_BATCHES,
    'slice_range': SLICE_RANGE,
    **preprocessing_settings,
})

# Update datasets.val. 'train_batches' is deliberately set to VAL_BATCHES here:
# the val dataset uses the validation batches.
config['datasets']['val'].update({
    **path_settings,
    'train_batches': VAL_BATCHES,
    'val_batches': VAL_BATCHES,
    'slice_range': SLICE_RANGE,
    'val_volume_idx': VAL_VOLUME_IDX,
    'val_slice_idx': VAL_SLICE_IDX,
    **preprocessing_settings,
})

# Path of the Stage 2 result file.
config['stage2_file'] = f"experiments/{STAGE2_EXP_NAME}/stage2_matched.txt"

# Serialize before opening the file for writing: open(..., 'w') truncates, so a
# serialization error raised afterwards would leave a corrupted config behind.
serialized_config = json.dumps(config, indent=4)
with open(CONFIG_FILE, 'w', encoding='utf-8') as f:
    f.write(serialized_config)

# Create the output directory.
os.makedirs(f"experiments/{STAGE2_EXP_NAME}", exist_ok=True)

print(f"Config已更新: {CONFIG_FILE}")
print(f"输出目录: experiments/{STAGE2_EXP_NAME}")
print(f"Train batches: {TRAIN_BATCHES}, Val batches: {VAL_BATCHES}")
print(f"Slice range: {SLICE_RANGE}, HU: [{HU_MIN}, {HU_MAX}]")

Config已更新: config/ct_denoise.json
输出目录: experiments/ct_denoise_teacher
Train batches: [0, 1, 2, 3, 4], Val batches: [5]
Slice range: [30, 80], HU: [-1000.0, 2000.0]


In [3]:
# Run state matching (train + val)
# Output file format: volume_idx_slice_idx_t (one sample per line)
# {GPU_ID} and {CONFIG_FILE} are filled in by IPython from the notebook variables.
!CUDA_VISIBLE_DEVICES={GPU_ID} python3 match_state.py -p all -c {CONFIG_FILE}

export CUDA_VISIBLE_DEVICES=0
26-01-03 12:52:26.722 - INFO: [Stage 2] Markov chain state matching (using teacher N2N)!
[train] Histogram equalization enabled:
    bins: /host/d/file/histogram_equalization/bins.npy (shape: (2301,))
    bins_mapped: /host/d/file/histogram_equalization/bins_mapped.npy (shape: (2301,))
Found 69 N2N pairs
[Slice Detection] Noise: 100 slices, Teacher: 50 slices
[Slice Detection] Best offset: 30, correlation: 0.9998
[OK] Slice offset verified: 30
[train] CTDataset: pairs=69, slices=50, samples=3450
[train] Noise slice_range: [30, 80)
[train] HU range: [-1000.0, 2000.0]
[train] Histogram equalization: True
[train] Using teacher N2N from: /host/d/file/pre/noise2noise/pred_images/
26-01-03 12:52:32.871 - INFO: CT dataset [ct] is created. Size: 3450
[val] Histogram equalization enabled:
    bins: /host/d/file/histogram_equalization/bins.npy (shape: (2301,))
    bins_mapped: /host/d/file/histogram_equalization/bins_mapped.npy (shape: (2301,))
Found 14 N2N pairs
[S

In [4]:
# Verify the Stage 2 result and summarize the distribution of the matched t values
import glob

stage2_file = f"experiments/{STAGE2_EXP_NAME}/stage2_matched.txt"

if os.path.exists(stage2_file):
    with open(stage2_file, 'r') as f:
        lines = f.readlines()

    # Skip blank lines (e.g. an empty line at the end of the file): they are not
    # samples and would make int('') raise below.
    records = [line.strip() for line in lines if line.strip()]

    print(f"✓ Stage2完成: {stage2_file}")
    print(f"  总样本数: {len(records)}")

    # Each record is volume_idx_slice_idx_t; the matched timestep t is the last field.
    t_values = [int(record.split('_')[-1]) for record in records]

    if t_values:
        # Distribution of the t values
        print(f"  t值范围: [{min(t_values)}, {max(t_values)}]")
        print(f"  t值均值: {sum(t_values)/len(t_values):.1f}")

        # A larger t means more noise, so more denoising steps are needed
        print(f"\n  t < 100 (低噪声): {sum(1 for t in t_values if t < 100)} 样本")
        print(f"  100 <= t < 300: {sum(1 for t in t_values if 100 <= t < 300)} 样本")
        print(f"  t >= 300 (高噪声): {sum(1 for t in t_values if t >= 300)} 样本")
    else:
        # min()/max() and the mean are undefined for an empty result file.
        print("  ⚠ stage2文件为空，无法统计t值分布")
else:
    print(f"⚠ 未找到stage2文件: {stage2_file}")

✓ Stage2完成: experiments/ct_denoise_teacher/stage2_matched.txt
  总样本数: 3500
  t值范围: [16, 72]
  t值均值: 33.4

  t < 100 (低噪声): 3500 样本
  100 <= t < 300: 0 样本
  t >= 300 (高噪声): 0 样本
