# DDM² Stage 2: State Matching

对训练集的每个切片进行状态匹配（为每个切片估计与其噪声水平最接近的扩散步 t），结果写入 stage2 文件；完成后自动将该文件路径写回 config 文件

In [None]:
#==============================================================================
# Path configuration - edit here when moving to another machine
#==============================================================================

# Project root directory (code, config/ and experiments/ live under it)
PROJECT_ROOT = "/host/d/ddm2"

# Shared base directory for data, teacher predictions and histogram files.
# All data paths below are derived from it, so relocating needs one edit.
FILE_ROOT = "/host/d/file"

# Data paths
EXCEL_PATH = f"{FILE_ROOT}/fixedCT_static_simulation_train_test_gaussian_local.xlsx"
DATA_ROOT = f"{FILE_ROOT}/simulation/"

# Teacher Noise2Noise predictions and the epoch whose outputs are used
TEACHER_N2N_ROOT = f"{FILE_ROOT}/pre/noise2noise/pred_images/"
TEACHER_N2N_EPOCH = 78

# Histogram-equalization lookup files
HIST_EQ_DIR = f"{FILE_ROOT}/histogram_equalization"
BINS_FILE = f"{HIST_EQ_DIR}/bins.npy"
BINS_MAPPED_FILE = f"{HIST_EQ_DIR}/bins_mapped.npy"

# Config file that is read here and updated with the stage-2 output path
CONFIG_FILE = f"{PROJECT_ROOT}/config/ct_denoise.json"

# Stage-2 output directory and file name
STAGE2_OUTPUT_DIR = f"{PROJECT_ROOT}/experiments/ct_denoise_stage2"
STAGE2_OUTPUT_FILE = "stage2_matched.txt"

#==============================================================================
# Training parameters
#==============================================================================

# Dataset split: indices passed to the dataset as train_batches / val_batches
TRAIN_BATCHES = [0, 1, 2, 3, 4]  # training set
VAL_BATCHES = [5]                 # validation set

# Slice range passed to the dataset as slice_range
# NOTE(review): whether the upper bound is inclusive is decided by the dataset code — confirm
SLICE_RANGE = [30, 80]

# HU value range passed to the dataset as HU_MIN / HU_MAX
HU_MIN = -1000.0
HU_MAX = 2000.0

#==============================================================================
# GPU configuration
#==============================================================================

# Value assigned to CUDA_VISIBLE_DEVICES
GPU_ID = "0"

In [None]:
import os
import sys
# Restrict visible GPUs before torch is imported, so the setting is already in
# effect when CUDA gets initialized. Keep this above `import torch`.
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
# Make the project-local packages imported below (`data`, `core`) importable,
# and run relative to the project root. Must precede those imports.
sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

import json
import torch
import numpy as np
from scipy.stats import norm
from functools import partial
from tqdm.auto import tqdm

# Project-local modules (presumably located under PROJECT_ROOT)
import data as Data
import core.logger as Logger

print(f"工作目录: {os.getcwd()}")
print(f"GPU: {GPU_ID}")

In [None]:
# Load the base config and override the dataset options of both phases with
# the path/parameter constants defined at the top of this notebook.
with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
    opt = json.load(f)

# Identical overrides for 'train' and 'val'. Key names must match what the
# dataset code reads from its options dict.
dataset_overrides = {
    'dataroot': EXCEL_PATH,
    'data_root': DATA_ROOT,
    'train_batches': TRAIN_BATCHES,
    'val_batches': VAL_BATCHES,
    'slice_range': SLICE_RANGE,
    'HU_MIN': HU_MIN,
    'HU_MAX': HU_MAX,
    'bins_file': BINS_FILE,
    'bins_mapped_file': BINS_MAPPED_FILE,
    'teacher_n2n_root': TEACHER_N2N_ROOT,
    'teacher_n2n_epoch': TEACHER_N2N_EPOCH,
}
for phase in ['train', 'val']:
    opt['datasets'][phase].update(dataset_overrides)

# Where the state-matching results will be written
stage2_file = os.path.join(STAGE2_OUTPUT_DIR, STAGE2_OUTPUT_FILE)
opt['stage2_file'] = stage2_file

# Convert nested dicts with the project logger helper
opt = Logger.dict_to_nonedict(opt)
print("配置加载完成")
print(f"Stage2输出: {stage2_file}")

In [None]:
# Build the training dataset and a loader that yields one sample per step.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Copy of the train-phase options with settings the matching loop relies on:
# no shuffling and batch_size 1, so loader position i is meant to line up with
# train_set.samples[i].
# NOTE(review): this copy is a plain dict. If dict_to_nonedict returns a dict
# subclass that yields None for missing keys, that behavior is lost here;
# confirm the dataset code never reads an absent key.
dataset_opt = {
    **opt['datasets']['train'],
    'use_shuffle': False,
    'batch_size': 1,
    'num_workers': 0,
    'lr_flip': 0.0,
}

train_set = Data.create_dataset(dataset_opt, 'train')
# NOTE(review): the loader is created with phase 'val' although it wraps the
# training set, presumably to get the non-shuffling evaluation loader; confirm.
train_loader = Data.create_dataloader(train_set, dataset_opt, 'val')
print(f'训练集大小: {len(train_set)}')

In [None]:
# Beta schedule
def _rev_warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
    betas = linear_start * np.ones(n_timestep, dtype=np.float64)
    warmup_time = int(n_timestep * warmup_frac)
    betas[n_timestep - warmup_time:] = np.linspace(
        linear_start, linear_end, warmup_time, dtype=np.float64)
    return betas

# Helper that builds float32 tensors on the (single visible) GPU.
to_torch = partial(torch.tensor, dtype=torch.float32, device='cuda:0')

# NOTE(review): the step count (1000) and warm-up fraction (0.7) are hard-coded;
# they must agree with the schedule used by the Stage-3 diffusion model. Confirm
# against the noise_model config.
betas = _rev_warmup_beta(
    linear_start=opt['noise_model']['beta_schedule']['linear_start'],
    linear_end=opt['noise_model']['beta_schedule']['linear_end'],
    n_timestep=1000,
    warmup_frac=0.7,
)
alphas = 1. - betas
alphas_cumprod = alphas.cumprod(axis=0)
# Prepend 1.0 so index 0 means "no noise"; the result has n_timestep + 1 entries.
sqrt_alphas_cumprod_prev = to_torch(np.sqrt(np.concatenate(([1.], alphas_cumprod))))
print("Beta schedule 计算完成")

In [None]:
# State matching: for every training slice, pick the diffusion step t whose
# noise std sqrt(1 - sqrt_alphas_cumprod_prev[t]**2) is closest to the fitted
# std of the residual X - sqrt_alphas_cumprod_prev[t] * denoised.
# Output: one line "<volume_idx>_<slice_idx>_<t>" per slice.
os.makedirs(STAGE2_OUTPUT_DIR, exist_ok=True)

# Target std for every step, computed once instead of once per slice per step
# (same element-wise ops as the per-step version, so the values are unchanged).
target_stds = (1 - sqrt_alphas_cumprod_prev ** 2).sqrt().cpu().numpy()

# `with` guarantees the output file is closed and flushed even if the loop
# raises; no_grad skips autograd bookkeeping for this inference-only pass.
with open(stage2_file, 'w') as stage_file, torch.no_grad():
    for idx, data in tqdm(enumerate(train_loader), total=len(train_loader), desc="State Matching"):
        # Relies on an unshuffled loader with batch_size 1, so loader position
        # idx corresponds to train_set.samples[idx].
        volume_idx, slice_idx = train_set.samples[idx]

        if 'denoised' not in data:
            # No denoised estimate for this slice: write a fixed t = 500.
            stage_file.write('%d_%d_%d\n' % (volume_idx, slice_idx, 500))
            continue

        denoised = data['denoised'].cuda()
        X = data['X'].cuda()

        min_lh, min_t, prev_diff = 999, -1, 999.
        for t in range(sqrt_alphas_cumprod_prev.shape[0]):
            noise = X - sqrt_alphas_cumprod_prev[t] * denoised
            noise = noise - torch.mean(noise)
            _, std = norm.fit(noise.cpu().numpy())
            diff = np.abs(target_stds[t] - std)
            if diff < min_lh:
                min_lh, min_t = diff, t
            # Stop at the first step where the mismatch grows again; this
            # assumes the mismatch is unimodal in t.
            if diff > prev_diff:
                break
            prev_diff = diff

        stage_file.write('%d_%d_%d\n' % (volume_idx, slice_idx, min_t))

print(f'\n状态匹配完成! 输出: {stage2_file}')

In [None]:
# Write the stage-2 output path back into CONFIG_FILE so the next stage can
# read it. Only the 'stage2_file' key is changed on disk; the in-memory dataset
# overrides applied to `opt` are not persisted.
with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
    config = json.load(f)

config['stage2_file'] = stage2_file

# Explicit UTF-8 plus ensure_ascii=False keeps any non-ASCII text in the config
# (e.g. Chinese strings) readable instead of rewriting it as \uXXXX escapes.
with open(CONFIG_FILE, 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=4, ensure_ascii=False)

print(f"已更新 {CONFIG_FILE}")
print(f"stage2_file = {stage2_file}")
print("\n可以运行 Stage 3 了!")