# DDM² Inference & Evaluation

1. 加载训练好的DDM²模型
2. 对测试数据进行推理
3. 量化评估 (MAE, PSNR, SSIM, LPIPS)
4. 可视化结果

In [None]:
import os
import sys
import numpy as np
import torch
import nibabel as nib
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim

# ============ 项目路径 ============
PROJECT_ROOT = '/path/to/ddm2'  # 修改为你的项目路径
sys.path.insert(0, PROJECT_ROOT)

import data as Data
import model as Model
import core.logger as Logger

# LPIPS
# LPIPS is an optional dependency: record whether it could be imported so the
# evaluation cell can skip the perceptual metric when it is missing.
try:
    import lpips
except ImportError:
    HAS_LPIPS = False
    print("LPIPS 不可用")
else:
    HAS_LPIPS = True
    print("LPIPS 可用")

# Report the PyTorch version and whether CUDA is usable.
for label, value in (("PyTorch", torch.__version__), ("CUDA", torch.cuda.is_available())):
    print(f"{label}: {value}")

## 1. 配置

In [None]:
# ==================== 全部配置 ====================

# Input locations: noisy data, ground truth, N2N output, the two histogram
# equalisation tables and the stage-2 matching file.
data_paths = dict(
    dataroot='/path/to/noise_data',
    gt_root='/path/to/gt_data',
    n2n_root='/path/to/n2n_output',
    bins_file='/path/to/bins.npy',
    bins_mapped_file='/path/to/bins_mapped.npy',
    stage2_file='/path/to/stage2_matched.txt',
)

# Intensity range, image size and the histogram-equalisation switch.
preprocess_cfg = dict(
    HU_MIN=-1000.0,
    HU_MAX=2000.0,
    image_size=256,
    use_histogram_eq=True,
)

# Network architecture options (forwarded as the model's 'unet' options).
model_cfg = dict(
    in_channel=2,
    out_channel=1,
    inner_channel=64,
    channel_mults=[1, 2, 4, 8],
    attn_res=[16],
    res_blocks=2,
    dropout=0.2,
    image_size=256,
)

# Beta schedule.
beta_schedule_cfg = dict(
    schedule='linear',
    n_timestep=2000,
    linear_start=1e-6,
    linear_end=1e-2,
)

# Inference selection: test patient ('all' or a specific index) and slices.
inference_cfg = dict(
    val_volume_idx=8,
    val_slice_idx='all',
)

# Evaluation: HU window used for the metrics and HU window used for display.
eval_cfg = dict(
    window=[0, 100],
    display_window=[0, 80],
)

config = {
    'data': data_paths,
    # None = look up the most recent checkpoint automatically, otherwise a path.
    'checkpoint': None,
    'preprocess': preprocess_cfg,
    'model': model_cfg,
    'beta_schedule': beta_schedule_cfg,
    'inference': inference_cfg,
    'eval': eval_cfg,
    # GPU index exported through CUDA_VISIBLE_DEVICES below.
    'gpu_id': 0,
}

# ===================================================

# NOTE(review): the import cell already calls torch.cuda.is_available(); if
# that call initialises CUDA, changing CUDA_VISIBLE_DEVICES here may have no
# effect in this process and a non-zero gpu_id could be ignored -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpu_id'])
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## 2. 加载 HE Bins

In [None]:
# Histogram-equalisation lookup tables. Both stay None when HE is disabled or
# when either file is missing (a warning is printed in the latter case).
bins = bins_mapped = None

if config['preprocess']['use_histogram_eq']:
    bins_file = config['data']['bins_file']
    bins_mapped_file = config['data']['bins_mapped_file']

    tables_present = os.path.exists(bins_file) and os.path.exists(bins_mapped_file)
    if not tables_present:
        print("警告: HE bins文件不存在")
    else:
        bins = np.load(bins_file).astype(np.float32)
        bins_mapped = np.load(bins_mapped_file).astype(np.float32)
        print("HE bins 已加载")
        # Show the value range covered by each table.
        for label, table in (("bins", bins), ("bins_mapped", bins_mapped)):
            print(f"  {label}: [{table.min():.1f}, {table.max():.1f}]")

## 3. 查找 Checkpoint

In [None]:
def find_latest_checkpoint():
    """Return the path prefix of the most recently written generator checkpoint.

    Scans ``<PROJECT_ROOT>/experiments`` for run directories whose name starts
    with ``ct_denoise_`` and does not end with ``_teacher``, and compares the
    modification time of each run's ``checkpoint/latest_gen.pth``.

    Returns:
        ``<run>/checkpoint/latest`` (the file path without its ``_gen.pth``
        suffix) for the newest file found, or ``None`` when the experiments
        directory is missing, is not a directory, or contains no matching
        checkpoint file.
    """
    exp_path = os.path.join(PROJECT_ROOT, 'experiments')
    # isdir rather than exists: a stray regular file named "experiments"
    # should yield None instead of making os.listdir raise.
    if not os.path.isdir(exp_path):
        return None

    newest_prefix = None
    newest_mtime = None

    for run_name in os.listdir(exp_path):
        if not run_name.startswith('ct_denoise_') or run_name.endswith('_teacher'):
            continue
        ckpt_dir = os.path.join(exp_path, run_name, 'checkpoint')
        try:
            # Ask for the mtime directly; a missing (or concurrently removed)
            # file simply means this run has no usable checkpoint.
            mtime = os.path.getmtime(os.path.join(ckpt_dir, 'latest_gen.pth'))
        except OSError:
            continue
        # None as the initial sentinel so that any valid mtime is accepted.
        if newest_mtime is None or mtime > newest_mtime:
            newest_mtime = mtime
            newest_prefix = os.path.join(ckpt_dir, 'latest')

    return newest_prefix

# Use the configured checkpoint, or fall back to the newest one on disk.
checkpoint_path = config['checkpoint']
if checkpoint_path is None:
    checkpoint_path = find_latest_checkpoint()

if checkpoint_path:
    print(f"Checkpoint: {checkpoint_path}")
    # Write results to an "inference" directory next to the run's
    # "checkpoint" directory. The location is derived with os.path instead of
    # replacing the literal '/checkpoint/latest', so a user-supplied prefix
    # with a different file stem (or different path separators) no longer
    # ends up with output_dir equal to the checkpoint prefix itself.
    ckpt_dir = os.path.dirname(checkpoint_path)
    if os.path.basename(ckpt_dir) == 'checkpoint':
        run_dir = os.path.dirname(ckpt_dir)
    else:
        # Checkpoint stored outside a "checkpoint" folder: keep results beside it.
        run_dir = ckpt_dir
    output_dir = os.path.join(run_dir, 'inference')
else:
    # No checkpoint available: report it and fall back to a directory under
    # the project root so the remaining cells still have an output location.
    print("错误: 未找到checkpoint!")
    output_dir = os.path.join(PROJECT_ROOT, 'inference_results')

os.makedirs(output_dir, exist_ok=True)
print(f"输出目录: {output_dir}")

## 4. 加载模型

In [None]:
def create_model_opt(config, checkpoint_path):
    """Assemble the option dictionary used to create the model for validation.

    Args:
        config: Notebook configuration; ``config['model']`` is forwarded as the
            ``unet`` options and ``config['beta_schedule']`` as the ``val``
            beta schedule.
        checkpoint_path: Value stored under ``path.resume_state``.

    Returns:
        The options after conversion with ``Logger.dict_to_nonedict``.
    """
    network_opt = {
        'which_model_G': 'sr3',
        'finetune_norm': False,
        'unet': config['model'],
        'beta_schedule': {'val': config['beta_schedule']},
    }
    run_opt = dict(
        name='ddm2',
        gpu_ids=[0],
        distributed=False,
        phase='val',
        path={'resume_state': checkpoint_path},
        model=network_opt,
    )
    return Logger.dict_to_nonedict(run_opt)

# Build the options for the resolved checkpoint and create the model from them.
opt_model = create_model_opt(config, checkpoint_path)
diffusion = Model.create_model(opt_model)
# Apply the configured beta schedule for the 'val' phase.
diffusion.set_new_noise_schedule(config['beta_schedule'], schedule_phase='val')

print("模型加载完成!")

## 5. 加载数据集

In [None]:
# Short aliases for the configuration sections used below.
data_cfg = config['data']
prep_cfg = config['preprocess']
infer_cfg = config['inference']

# Options handed to the project's dataset factory.
dataset_opt = dict(
    name='CT_N2N',
    mode='N2N',
    dataroot=data_cfg['dataroot'],
    gt_root=data_cfg['gt_root'],
    n2n_root=data_cfg['n2n_root'],
    resolution=prep_cfg['image_size'],
    HU_MIN=prep_cfg['HU_MIN'],
    HU_MAX=prep_cfg['HU_MAX'],
    val_volume_idx=infer_cfg['val_volume_idx'],
    val_slice_idx=infer_cfg['val_slice_idx'],
)

# The HE table paths are only passed on when histogram equalisation is enabled.
if prep_cfg['use_histogram_eq']:
    dataset_opt.update(
        bins_file=data_cfg['bins_file'],
        bins_mapped_file=data_cfg['bins_mapped_file'],
    )

val_set = Data.create_dataset(dataset_opt, 'val', stage2_file=data_cfg['stage2_file'])
print(f"数据集大小: {len(val_set)}")

## 6. 辅助函数

In [None]:
def inverse_histogram_equalization(img, bins, bins_mapped):
    """Undo histogram equalisation with a table lookup.

    Every value of ``img`` is located among the edges in ``bins_mapped`` with
    ``np.digitize``; the index of the edge at or below the value (clamped to
    the valid index range of ``bins``) selects the entry of ``bins`` that
    replaces it.

    Args:
        img: Array of equalised values.
        bins: Table of original values, or ``None``.
        bins_mapped: Table of equalised edges, or ``None``.

    Returns:
        ``img`` itself when either table is ``None``; otherwise a float32
        array with the shape of ``img``.
    """
    tables_available = bins is not None and bins_mapped is not None
    if not tables_available:
        return img

    # digitize returns the insertion position; subtract one for the edge below.
    edge_idx = np.digitize(img.ravel(), bins_mapped) - 1
    last_valid = len(bins) - 1
    edge_idx = np.clip(edge_idx, 0, last_valid)

    restored = bins[edge_idx].reshape(img.shape)
    return restored.astype(np.float32)

def calc_metrics(img, ref, vmin, vmax, lpips_fn=None):
    """Compare ``img`` with ``ref`` inside the intensity window [vmin, vmax].

    Both images are clipped to the window before any metric is computed.

    Args:
        img: Image under test.
        ref: Reference image.
        vmin: Lower bound of the window.
        vmax: Upper bound of the window.
        lpips_fn: Optional callable taking two tensors and returning a tensor
            whose ``.item()`` is the LPIPS value; skipped when ``None``.

    Returns:
        Dict with keys ``'MAE'``, ``'PSNR'`` (``inf`` when the clipped images
        are identical), ``'SSIM'`` and ``'LPIPS'`` (``None`` without
        ``lpips_fn``).
    """
    data_range = vmax - vmin
    img_c = np.clip(img, vmin, vmax)
    ref_c = np.clip(ref, vmin, vmax)

    err = img_c - ref_c
    mae = np.mean(np.abs(err))
    mse = np.mean(err ** 2)
    if mse > 0:
        psnr = 10 * np.log10(data_range ** 2 / mse)
    else:
        psnr = float('inf')
    ssim_val = ssim(img_c, ref_c, data_range=data_range)

    def _as_lpips_input(clipped):
        # Rescale the window to [-1, 1], replicate the image into three
        # channels and add a leading batch axis.
        scaled = (clipped - vmin) / data_range * 2 - 1
        stacked = np.stack([scaled] * 3)[None].astype(np.float32)
        return torch.from_numpy(stacked).to(device)

    lpips_val = None
    if lpips_fn is not None:
        img_t = _as_lpips_input(img_c)
        ref_t = _as_lpips_input(ref_c)
        with torch.no_grad():
            lpips_val = lpips_fn(img_t, ref_t).item()

    return {'MAE': mae, 'PSNR': psnr, 'SSIM': ssim_val, 'LPIPS': lpips_val}

## 7. 推理

In [None]:
HU_MIN = config['preprocess']['HU_MIN']
HU_MAX = config['preprocess']['HU_MAX']


def _frame_to_hu(frame):
    """Map one output frame to the HU scale and undo HE when tables are loaded.

    Assumes the frame is in [-1, 1] (as the rescaling below implies) -- confirm
    against the model's output range.
    """
    unit = (frame.squeeze() + 1) / 2          # [-1, 1] -> [0, 1]
    hu = unit * (HU_MAX - HU_MIN) + HU_MIN    # [0, 1] -> [HU_MIN, HU_MAX]
    if bins is not None:
        hu = inverse_histogram_equalization(hu, bins, bins_mapped)
    return hu


noisy_list, first_list, final_list = [], [], []

print(f"开始推理 {len(val_set)} slices...")

for idx in tqdm(range(len(val_set)), desc="Inference"):
    sample = val_set[idx]

    # Add a leading batch axis to every tensor entry; pass other entries as-is.
    batch = {}
    for key, value in sample.items():
        batch[key] = value.unsqueeze(0) if isinstance(value, torch.Tensor) else value

    diffusion.feed_data(batch)
    diffusion.test(continous=True)

    # NOTE(review): .numpy() requires a CPU tensor -- presumably
    # get_current_visuals() already returns one; confirm.
    frames = diffusion.get_current_visuals()['denoised'].numpy()

    # Frame 0 is kept as the noisy input, frame 1 as the first result and the
    # last frame as the final result.
    for store, frame_idx in ((noisy_list, 0), (first_list, 1), (final_list, -1)):
        store.append(_frame_to_hu(frames[frame_idx]))

print("推理完成!")

In [None]:
# Stack the per-slice results along a new last axis to obtain float32 volumes.
noisy_vol, first_vol, final_vol = (
    np.stack(slices, axis=-1).astype(np.float32)
    for slices in (noisy_list, first_list, final_list)
)

print(f"Volume shape: {first_vol.shape}")

# Save the first-step and final results as NIfTI files with an identity affine.
affine = np.eye(4)
for filename, volume in (('ddm2_first.nii.gz', first_vol), ('ddm2_final.nii.gz', final_vol)):
    nib.save(nib.Nifti1Image(volume, affine), os.path.join(output_dir, filename))
print(f"\n结果已保存到: {output_dir}")

## 8. 量化评估

In [None]:
# Load the ground-truth volume if it exists; HAS_GT tells later cells whether
# gt_data is available.
gt_root = config['data']['gt_root']
gt_path = os.path.join(gt_root, 'gt_img.nii.gz')  # adjust to the actual file layout

if not os.path.exists(gt_path):
    print(f"警告: GT不存在: {gt_path}")
    HAS_GT = False
else:
    gt_image = nib.load(gt_path)
    gt_data = gt_image.get_fdata().astype(np.float32)
    print(f"GT shape: {gt_data.shape}")
    HAS_GT = True

In [None]:
# Quantitative evaluation of the noisy input and both DDM² outputs against GT.
if HAS_GT:
    lpips_fn = lpips.LPIPS(net='alex').to(device) if HAS_LPIPS else None
    eval_vmin, eval_vmax = config['eval']['window']

    methods = [('Noisy', noisy_vol), ('DDM2_First', first_vol), ('DDM2_Final', final_vol)]
    metric_names = ('MAE', 'PSNR', 'SSIM', 'LPIPS')
    results = {name: {metric: [] for metric in metric_names} for name, _ in methods}

    def _finite_mean(values):
        """Mean of the finite entries of ``values``; nan when there are none.

        calc_metrics reports PSNR as inf for identical slices; a single such
        slice would otherwise turn the whole average into inf.
        """
        arr = np.asarray(values, dtype=np.float64)
        arr = arr[np.isfinite(arr)]
        return float(arr.mean()) if arr.size else float('nan')

    def _percent_change(value, baseline):
        """Signed change of ``value`` relative to ``baseline`` in percent."""
        if baseline == 0:
            return float('nan')
        return (value - baseline) / baseline * 100

    # NOTE(review): assumes slice s of the predicted volumes corresponds to
    # slice s of the GT volume with the same in-plane shape -- confirm.
    n_slices = min(first_vol.shape[-1], gt_data.shape[-1])

    for s in tqdm(range(n_slices), desc="Evaluating"):
        gt_s = gt_data[:, :, s]

        for name, vol in methods:
            m = calc_metrics(vol[:, :, s], gt_s, eval_vmin, eval_vmax, lpips_fn)
            for k, v in m.items():
                # LPIPS is None when the metric is unavailable; skip it.
                if v is not None:
                    results[name][k].append(v)

    summary = {
        name: {metric: _finite_mean(results[name][metric]) for metric in metric_names}
        for name, _ in methods
    }

    # Print the summary table.
    print("\n" + "=" * 70)
    print(f"评估结果 ({n_slices} slices, HU窗口: [{eval_vmin}, {eval_vmax}])")
    print("-" * 70)
    print(f"{'Method':<15} {'MAE ↓':>10} {'PSNR ↑':>10} {'SSIM ↑':>10} {'LPIPS ↓':>10}")
    print("-" * 70)

    for method, _ in methods:
        row = summary[method]
        print(f"{method:<15} {row['MAE']:>10.4f} {row['PSNR']:>10.2f} {row['SSIM']:>10.4f} {row['LPIPS']:>10.4f}")

    print("=" * 70)

    # Relative change versus the noisy input. The sign comes from the value
    # itself, so a regression prints e.g. "MAE +5.0%" rather than "MAE --5.0%".
    if results['Noisy']['MAE']:
        print("\n改进 (vs Noisy):")
        baseline = summary['Noisy']
        for method in ['DDM2_First', 'DDM2_Final']:
            mae_change = _percent_change(summary[method]['MAE'], baseline['MAE'])
            ssim_change = _percent_change(summary[method]['SSIM'], baseline['SSIM'])
            print(f"  {method}: MAE {mae_change:+.1f}%, SSIM {ssim_change:+.1f}%")

## 9. 可视化

In [None]:
# Show the middle slice of the prediction.
vis_slice = first_vol.shape[-1] // 2
if HAS_GT:
    # The GT volume may hold fewer slices than the prediction; clamp so the
    # same index is valid for both instead of raising IndexError.
    vis_slice = min(vis_slice, gt_data.shape[-1] - 1)
disp_vmin, disp_vmax = config['eval']['display_window']
gray_kwargs = dict(cmap='gray', vmin=disp_vmin, vmax=disp_vmax)

if HAS_GT:
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))

    # Top row: the images themselves.
    image_panels = [
        ('GT', gt_data),
        ('Noisy', noisy_vol),
        ('DDM² First', first_vol),
        ('DDM² Final', final_vol),
    ]
    for ax, (title, vol) in zip(axes[0], image_panels):
        ax.imshow(vol[:, :, vis_slice], **gray_kwargs)
        ax.set_title(title)
        ax.axis('off')

    # Bottom row: difference to GT (the first cell stays empty).
    axes[1, 0].axis('off')
    gt_slice = gt_data[:, :, vis_slice]
    diff_panels = [('Noisy', noisy_vol), ('First', first_vol), ('Final', final_vol)]
    for ax, (name, vol) in zip(axes[1, 1:], diff_panels):
        diff = vol[:, :, vis_slice] - gt_slice
        ax.imshow(diff, cmap='RdBu', vmin=-30, vmax=30)
        ax.set_title(f'{name} - GT\nMAE={np.mean(np.abs(diff)):.2f}')
        ax.axis('off')
else:
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    image_panels = [
        ('Noisy', noisy_vol),
        ('DDM² First', first_vol),
        ('DDM² Final', final_vol),
    ]
    for ax, (title, vol) in zip(axes, image_panels):
        ax.imshow(vol[:, :, vis_slice], **gray_kwargs)
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(output_dir, f'visualization_slice{vis_slice}.png'), dpi=150)
plt.show()

print(f"\n可视化已保存")

## 10. 统计信息

In [None]:
# Min / max / mean / std of every volume; the GT row is added when available.
stat_rows = [('Noisy', noisy_vol), ('DDM2 First', first_vol), ('DDM2 Final', final_vol)]
if HAS_GT:
    stat_rows.append(('GT', gt_data))

heavy_rule = "=" * 60
light_rule = "-" * 60

print(heavy_rule)
print("Volume 统计 (HU)")
print(light_rule)
print(f"{'Image':<15} {'Min':>10} {'Max':>10} {'Mean':>10} {'Std':>10}")
print(light_rule)

for name, vol in stat_rows:
    print(f"{name:<15} {vol.min():>10.1f} {vol.max():>10.1f} {vol.mean():>10.1f} {vol.std():>10.1f}")

print(heavy_rule)

## 完成

**输出文件:**
- `ddm2_first.nii.gz`: 第一步结果
- `ddm2_final.nii.gz`: 最终结果
- `visualization_sliceX.png`: 可视化