# DDM² : Inference

使用训练好的模型对测试集进行推理，生成去噪后的 CT 图像。

**输出**：
- `noisy_input.nii.gz` - 原始噪声输入
- `n2n_teacher.nii.gz` - Teacher N2N 结果
- `ddm2_first_step.nii.gz` - DDM² 第一步去噪
- `ddm2_final.nii.gz` - DDM² 最终去噪

In [None]:
#==============================================================================
# Base path configuration
#==============================================================================

PROJECT_ROOT = "/host/c/Users/ROG/Documents/Github/DDM2_new"          # project root; the notebook chdir's here
CONFIG_FILE = "config/ct_denoise.json" # path of the JSON config file
GPU_ID = "0"                           # GPU index exported as CUDA_VISIBLE_DEVICES

#==============================================================================
# Model checkpoint
#==============================================================================

# Set to None to auto-discover the most recent checkpoint,
# or give an explicit path such as "experiments/ct_denoise_241231_120000/checkpoint/latest"
CHECKPOINT = None

#==============================================================================
# Inference settings
#==============================================================================

# Inference mode
INFERENCE_MODE = "batch"  # "single" = one patient, "batch" = every patient

# Patient index (volume_idx) used when INFERENCE_MODE is "single"
PATIENT_IDX = 0

# Output directory
OUTPUT_DIR = "/host/d/file/pre/ddm2/pred_images"

# Which results to write to disk
SAVE_NOISY = True        # original noisy input
SAVE_N2N = True          # Teacher N2N result
SAVE_FIRST_STEP = True   # DDM² first-step denoising result
SAVE_FINAL = True        # DDM² final denoising result

#==============================================================================
# Post-processing settings
#==============================================================================

# Whether to invert the histogram equalization so results are back in original HU space
# True:  outputs are original HU values (same space as the input CT, directly comparable)
# False: outputs stay as HU values in HE space (the space used during training)
INVERSE_HE = True

In [None]:
# Initialise the environment
import os
import sys
import json
import numpy as np
import torch
import nibabel as nib
from tqdm.auto import tqdm

# Restrict the visible GPUs, then make PROJECT_ROOT both the working directory
# (so relative paths such as CONFIG_FILE resolve against it) and the first
# entry on sys.path (so the project packages imported below can be found).
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
os.chdir(PROJECT_ROOT)
sys.path.insert(0, PROJECT_ROOT)

# Project-local packages; these must come after the sys.path change above.
import data as Data
import model as Model
import core.logger as Logger

print(f"工作目录: {os.getcwd()}")
print(f"GPU: {GPU_ID}")

In [None]:
# 辅助函数

def find_latest_checkpoint(experiments_dir='experiments', prefix='ct_denoise'):
    """Return the resume path of the most recently written checkpoint.

    Scans the entries of ``experiments_dir`` whose names start with ``prefix``
    and looks for ``checkpoint/latest_gen.pth`` inside each one. The run whose
    file has the newest modification time is selected.

    Args:
        experiments_dir: Directory that holds one sub-directory per run.
        prefix: Only entries whose name starts with this string are considered.

    Returns:
        ``<experiments_dir>/<run>/checkpoint/latest`` for the newest run (the
        path stem without the ``_gen.pth`` suffix), or ``None`` when
        ``experiments_dir`` is not a directory or no matching checkpoint file
        exists.
    """
    # isdir (not exists) so that a plain file with this name cannot reach
    # os.listdir and raise NotADirectoryError.
    if not os.path.isdir(experiments_dir):
        return None

    newest_dir = None
    newest_mtime = None

    for run_name in os.listdir(experiments_dir):
        if not run_name.startswith(prefix):
            continue
        ckpt_dir = os.path.join(experiments_dir, run_name, 'checkpoint')
        try:
            # EAFP: a missing or unreadable file simply disqualifies the run.
            mtime = os.path.getmtime(os.path.join(ckpt_dir, 'latest_gen.pth'))
        except OSError:
            continue
        # Strict ">" keeps the first run seen when two files share an mtime.
        if newest_mtime is None or mtime > newest_mtime:
            newest_mtime = mtime
            newest_dir = os.path.join(ckpt_dir, 'latest')

    return newest_dir


def inverse_histogram_equalization(img, bins, bins_mapped):
    """Map an image from HE space back to original HU values.

    Each pixel is located among the ``bins_mapped`` edges with
    ``np.digitize``; the resulting position (clamped to the valid index range
    of ``bins``) selects the value returned for that pixel.

    Args:
        img: Array of values in HE space.
        bins: 1-D array of original-space values to look up.
        bins_mapped: 1-D array of HE-space edges matching ``bins``.

    Returns:
        A float32 array with the shape of ``img``. When either lookup table is
        ``None`` the input is returned untouched.
    """
    if bins is not None and bins_mapped is not None:
        # Position of every pixel among the mapped edges, shifted to 0-based.
        positions = np.digitize(img.ravel(), bins_mapped) - 1
        # Pixels below the first edge or above the last one use the end bins.
        positions = positions.clip(0, len(bins) - 1)
        restored = bins[positions].reshape(img.shape)
        return restored.astype(np.float32)

    return img

In [None]:
# Load the configuration and derive the run-time settings

with open(CONFIG_FILE, 'r') as f:
    opt = json.load(f)

# Every setting below comes from the validation dataset section.
_val_cfg = opt['datasets']['val']

HU_MIN = _val_cfg.get('HU_MIN', -1000.0)
HU_MAX = _val_cfg.get('HU_MAX', 2000.0)

# Location of the Teacher N2N predictions
TEACHER_N2N_ROOT = _val_cfg.get('teacher_n2n_root')
TEACHER_N2N_EPOCH = _val_cfg.get('teacher_n2n_epoch', 78)

# HE lookup tables (needed only for the inverse transform)
bins_file = _val_cfg.get('bins_file')
bins_mapped_file = _val_cfg.get('bins_mapped_file')
bins, bins_mapped, use_inverse_he = None, None, False

if INVERSE_HE and bins_file and bins_mapped_file:
    if not (os.path.exists(bins_file) and os.path.exists(bins_mapped_file)):
        print(f"⚠ HE bins 文件不存在，将跳过逆向 HE")
    else:
        bins = np.load(bins_file).astype(np.float32)
        bins_mapped = np.load(bins_mapped_file).astype(np.float32)
        use_inverse_he = True
        print(f"✓ HE bins 已加载")

# Resolve the checkpoint: an explicit setting wins, otherwise auto-discover
checkpoint = CHECKPOINT if CHECKPOINT is not None else find_latest_checkpoint()

if checkpoint is not None:
    print(f"✓ Checkpoint: {checkpoint}")
else:
    print("❌ 未找到 checkpoint！")

# Summary of the effective settings
print(f"\n配置:")
print(f"  HU 范围: [{HU_MIN}, {HU_MAX}]")
print(f"  逆向 HE: {use_inverse_he}")
print(f"  Teacher N2N: {TEACHER_N2N_ROOT}")
print(f"  推理模式: {INFERENCE_MODE}")
print(f"  输出目录: {OUTPUT_DIR}")

In [None]:
# Build the validation dataset
print("加载数据集...")

# Shallow copy so the overrides below do not leak back into `opt`.
val_opt = opt['datasets']['val'].copy()

# "single" restricts the dataset to one volume; anything else loads them all.
val_opt['val_volume_idx'] = PATIENT_IDX if INFERENCE_MODE == "single" else 'all'
val_opt['val_slice_idx'] = 'all'

val_set = Data.create_dataset(val_opt, 'val', stage2_file=opt.get('stage2_file'))

# Work out which patient indices to process
if hasattr(val_set, 'n2n_pairs'):
    n_patients = len(val_set.n2n_pairs)
    if INFERENCE_MODE == "single":
        # Reject negative indices as well as ones past the end, so the
        # reported patient count is honest.
        patient_indices = [PATIENT_IDX] if 0 <= PATIENT_IDX < n_patients else []
    else:
        patient_indices = list(range(n_patients))
else:
    # Fall back to the volume indices found in the samples; sorted so the
    # processing order does not depend on set iteration order.
    patient_indices = sorted({s[0] for s in val_set.samples})

print(f"\n✓ 数据集加载完成")
print(f"  患者数量: {len(patient_indices)}")
print(f"  总样本数: {len(val_set)}")

In [None]:
# Build the diffusion model and restore its weights
print("加载模型...")

# NOTE(review): dict_to_nonedict presumably makes missing keys read as None;
# confirm in core.logger.
opt_model = Logger.dict_to_nonedict(opt)
# Resume from the checkpoint resolved earlier in the notebook.
opt_model['path']['resume_state'] = checkpoint

diffusion = Model.create_model(opt_model)

# Use the noise schedule configured for validation.
_val_schedule = opt_model['model']['beta_schedule']['val']
diffusion.set_new_noise_schedule(_val_schedule, schedule_phase='val')

print("✓ 模型加载完成")

In [None]:
# 加载 N2N 结果的辅助函数

def get_n2n_path(patient_id, patient_subid):
    """Build the path of the Teacher N2N prediction for one patient.

    Args:
        patient_id: Patient ID; converted with ``int()`` and zero-padded to 8 digits.
        patient_subid: Patient sub-ID; converted with ``int()`` and zero-padded
            to 10 digits.

    Returns:
        ``<TEACHER_N2N_ROOT>/<id>/<subid>/random_0/epoch<N>/pred_img.nii.gz``,
        or ``None`` when ``TEACHER_N2N_ROOT`` is not configured.
    """
    if TEACHER_N2N_ROOT is None:
        return None

    parts = [
        TEACHER_N2N_ROOT,
        format(int(patient_id), "08d"),
        format(int(patient_subid), "010d"),
        "random_0",
        f"epoch{TEACHER_N2N_EPOCH}",
        "pred_img.nii.gz",
    ]
    return os.path.join(*parts)


def load_n2n_volume(patient_id, patient_subid, slice_indices):
    """Load the Teacher N2N prediction for the requested slices of a patient.

    The prediction is read from a ``.npy`` file next to the NIfTI path when
    one exists, otherwise from the ``.nii.gz`` file itself. Each requested
    slice is taken along the third axis and goes through forward HE (if the
    dataset has it enabled), clipping to ``[HU_MIN, HU_MAX]``, and optionally
    the inverse HE, so it ends up in the same space as the saved DDM² outputs.

    Args:
        patient_id: Patient ID.
        patient_subid: Patient sub-ID.
        slice_indices: Slice indices to extract.

    Returns:
        A list of 2-D float32 arrays, one per entry of ``slice_indices``.
        Indices outside the volume, including negative ones, yield an
        all-zero slice. Returns ``None`` when no Teacher N2N root is
        configured or no prediction file is found.
    """
    pred_path = get_n2n_path(patient_id, patient_subid)
    if pred_path is None:
        return None

    npy_path = pred_path.replace('.nii.gz', '.npy')

    # Prefer the .npy copy when present; otherwise read the NIfTI file.
    if os.path.exists(npy_path):
        data = np.load(npy_path)
    elif os.path.exists(pred_path):
        data = nib.load(pred_path).get_fdata()
    else:
        return None

    # Loop-invariant: decide once whether forward HE applies, and import the
    # helper once instead of once per slice.
    apply_he = None
    if val_set.histogram_equalization and val_set.bins is not None:
        from data.ct_dataset import apply_histogram_equalization as apply_he

    hu_span = HU_MAX - HU_MIN
    n_slices = data.shape[2]

    n2n_slices = []
    for s in slice_indices:
        # A negative index would silently wrap to the end of the volume, so
        # treat it like any other out-of-range request.
        if not 0 <= s < n_slices:
            n2n_slices.append(np.zeros(data.shape[:2], dtype=np.float32))
            continue

        img = data[:, :, s].astype(np.float32)

        # Forward HE, mirroring the dataset preprocessing
        if apply_he is not None:
            img = apply_he(img, val_set.bins, val_set.bins_mapped)

        # Clip to the HU window and normalise to [0, 1]
        img = np.clip(img, HU_MIN, HU_MAX)
        img = (img - HU_MIN) / hu_span

        # Back to HU (still HE space when HE was applied)
        img_hu = img * hu_span + HU_MIN

        # Inverse HE back to original HU space
        if use_inverse_he:
            img_hu = inverse_histogram_equalization(img_hu, bins, bins_mapped)

        n2n_slices.append(img_hu)

    return n2n_slices

In [None]:
# 推理函数

def inference_patient(patient_idx):
    """Run DDM² inference on every slice of one patient and assemble volumes.

    Args:
        patient_idx: Index into ``val_set.n2n_pairs``; also matched against
            the first element of each entry of ``val_set.samples``.

    Returns:
        ``None`` when the index is out of range or the patient has no
        samples. Otherwise a dict with keys ``patient_id``, ``patient_subid``,
        ``noisy``, ``first``, ``final`` (float32 arrays with one slice per
        sample stacked along the last axis), ``n2n`` (same layout, or ``None``
        when no Teacher N2N prediction is available) and ``affine`` (read from
        the patient's ``noise_0`` file, identity matrix as fallback).
    """
    if not 0 <= patient_idx < len(val_set.n2n_pairs):
        return None

    pair = val_set.n2n_pairs[patient_idx]
    patient_id = pair['patient_id']
    patient_subid = pair['patient_subid']

    # All dataset samples that belong to this patient
    patient_samples = [(i, s) for i, s in enumerate(val_set.samples) if s[0] == patient_idx]
    if not patient_samples:
        return None

    hu_span = HU_MAX - HU_MIN

    def to_hu(img):
        """Convert one model image from [-1, 1] to HU, undoing HE if enabled."""
        hu = (img.squeeze() + 1) / 2 * hu_span + HU_MIN
        if use_inverse_he:
            hu = inverse_histogram_equalization(hu, bins, bins_mapped)
        return hu

    noisy_results = []
    first_results = []
    final_results = []
    slice_indices = []

    for sample_idx, (vol_idx, slice_idx) in tqdm(patient_samples,
                                                  desc=f"Patient {patient_idx}",
                                                  leave=False):
        slice_indices.append(slice_idx)
        sample = val_set[sample_idx]

        # Add a leading batch dimension to every tensor entry
        batch = {k: v.unsqueeze(0) if isinstance(v, torch.Tensor) else v
                 for k, v in sample.items()}

        # `continous` (sic) is the keyword exactly as the model call spells it
        diffusion.feed_data(batch)
        diffusion.test(continous=True)
        visuals = diffusion.get_current_visuals()

        # Index 0 is used as the noisy input, index 1 as the first denoising
        # step and the last entry as the final result.
        all_imgs = visuals['denoised'].numpy()
        noisy_results.append(to_hu(all_imgs[0]))
        first_results.append(to_hu(all_imgs[1]))
        final_results.append(to_hu(all_imgs[-1]))

    # Teacher N2N prediction for the same slices (may be None)
    n2n_results = load_n2n_volume(patient_id, patient_subid, slice_indices)

    # Stack the per-slice results into 3-D volumes
    result = {
        'patient_id': patient_id,
        'patient_subid': patient_subid,
        'noisy': np.stack(noisy_results, axis=-1).astype(np.float32),
        'first': np.stack(first_results, axis=-1).astype(np.float32),
        'final': np.stack(final_results, axis=-1).astype(np.float32),
    }

    if n2n_results is not None:
        result['n2n'] = np.stack(n2n_results, axis=-1).astype(np.float32)
    else:
        result['n2n'] = None

    # Take the affine from the source volume; fall back to identity
    affine = np.eye(4)
    noise_path = pair['noise_0']
    if hasattr(val_set, '_fix_path'):
        noise_path = val_set._fix_path(noise_path)
    if os.path.exists(noise_path):
        try:
            affine = nib.load(noise_path).affine
        except Exception as exc:
            # Still best-effort, but report the failure and let
            # KeyboardInterrupt/SystemExit propagate.
            tqdm.write(f"  ⚠ 无法读取 affine ({noise_path}): {exc}，使用单位矩阵")
    result['affine'] = affine

    return result


def save_results(result, output_dir):
    """Write the volumes of one patient as ``.nii.gz`` files.

    Files go to ``<output_dir>/<patient_id>/<patient_subid>/``. Numeric IDs,
    including NumPy scalar types, are zero-padded to 8 and 10 digits
    respectively, as in the Teacher N2N path layout; any other ID is used
    via ``str()``.

    Args:
        result: Dict as returned by ``inference_patient`` (keys
            ``patient_id``, ``patient_subid``, ``affine``, ``noisy``, ``n2n``,
            ``first``, ``final``).
        output_dir: Root directory for the output tree; created if missing.

    Returns:
        List of the file stems written (without ``.nii.gz``), in the order
        noisy, n2n, first, final. A volume is skipped when its ``SAVE_*`` flag
        is off or its value is ``None``.
    """
    patient_id = result['patient_id']
    patient_subid = result['patient_subid']
    affine = result['affine']

    # NumPy scalars are not instances of int, so list them explicitly;
    # otherwise such IDs would lose their zero padding.
    numeric_types = (int, float, np.integer, np.floating)
    pid_str = f"{int(patient_id):08d}" if isinstance(patient_id, numeric_types) else str(patient_id)
    psid_str = f"{int(patient_subid):010d}" if isinstance(patient_subid, numeric_types) else str(patient_subid)

    output_subdir = os.path.join(output_dir, pid_str, psid_str)
    os.makedirs(output_subdir, exist_ok=True)

    # (enabled flag, key in `result`, output file stem)
    outputs = (
        (SAVE_NOISY, 'noisy', 'noisy_input'),
        (SAVE_N2N, 'n2n', 'n2n_teacher'),
        (SAVE_FIRST_STEP, 'first', 'ddm2_first_step'),
        (SAVE_FINAL, 'final', 'ddm2_final'),
    )

    saved_files = []
    for enabled, key, stem in outputs:
        if not enabled:
            continue
        volume = result[key]
        if volume is None:
            # e.g. no Teacher N2N prediction was found for this patient
            continue
        path = os.path.join(output_subdir, f'{stem}.nii.gz')
        nib.save(nib.Nifti1Image(volume, affine), path)
        saved_files.append(stem)

    return saved_files

In [None]:
# Run inference over all selected patients
print(f"\n开始推理 ({len(patient_indices)} 个患者)...")
print("=" * 60)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# One summary dict per successfully processed patient
all_stats = []

for patient_idx in tqdm(patient_indices, desc="Total Progress"):
    result = inference_patient(patient_idx)

    if result is None:
        # tqdm.write (not print) so the message does not break the progress bar
        tqdm.write(f"  跳过 Patient {patient_idx} (无数据)")
        continue

    # Write the volumes to disk
    saved = save_results(result, OUTPUT_DIR)

    # Record summary statistics (mean value of each volume)
    stats = {
        'patient_id': result['patient_id'],
        'shape': result['final'].shape,
        'noisy_mean': result['noisy'].mean(),
        'n2n_mean': result['n2n'].mean() if result['n2n'] is not None else None,
        'first_mean': result['first'].mean(),
        'final_mean': result['final'].mean(),
    }
    all_stats.append(stats)

    tqdm.write(f"  ✓ Patient {result['patient_id']}: shape={result['final'].shape}, "
               f"saved=[{', '.join(saved)}]")

print("\n" + "=" * 60)
print(f"✓ 推理完成!")
print(f"  处理患者数: {len(all_stats)}")
print(f"  输出目录: {OUTPUT_DIR}")

In [None]:
# Summary table of per-patient mean values
if all_stats:
    print("\n统计信息汇总 (Mean HU)")
    print("=" * 80)
    print(f"{'Patient ID':<12} {'Shape':<18} {'Noisy':>10} {'N2N':>10} {'First':>10} {'Final':>10}")
    print("-" * 80)

    for s in all_stats:
        # The N2N column shows "N/A" when no teacher prediction was available.
        n2n_str = "N/A" if s['n2n_mean'] is None else f"{s['n2n_mean']:.1f}"
        id_and_shape = f"{str(s['patient_id']):<12} {str(s['shape']):<18} "
        means = (f"{s['noisy_mean']:>10.1f} {n2n_str:>10} "
                 f"{s['first_mean']:>10.1f} {s['final_mean']:>10.1f}")
        print(id_and_shape + means)

    print("=" * 80)

    # Legend for the files written per patient
    print("\n输出文件说明:")
    print("  - noisy_input.nii.gz    : 原始噪声CT输入")
    print("  - n2n_teacher.nii.gz    : Teacher N2N 去噪结果")
    print("  - ddm2_first_step.nii.gz: DDM² 第一步去噪 (粗去噪)")
    print("  - ddm2_final.nii.gz     : DDM² 最终去噪 (细去噪)")