In [4]:
import os
import json
from pathlib import Path

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from scipy.spatial import KDTree, cKDTree
from scipy.ndimage import affine_transform, distance_transform_edt, binary_erosion

# 可用：若后续需要用 SimpleITK 做其他注册，可直接调用
import SimpleITK as sitk

# ============================================================
# 0. 路径设置（与 preprocessing.ipynb 保持一致）
# ============================================================

# The project root is the parent of the current working directory, so this
# code expects to be run from a sub-folder of the project.
PROJECT_ROOT = Path.cwd().parent
DATA_ROOT    = PROJECT_ROOT / "data"

DATA_RAS      = DATA_ROOT / "ras_1mm"         # input: one sub-directory per patient, scanned by main_icp_benchmark
CSV_DIR       = DATA_ROOT / "csv"             # output: benchmark CSV (Dice / HD95)
FIG_DIR       = DATA_ROOT / "fig"             # output: 2D overlay PNGs
OUT_TRANSFORM = DATA_ROOT / "transforms_icp"  # output: 4x4 transforms saved as JSON
OUT_WARP      = DATA_ROOT / "warp_icp"        # output: warped moving masks (NIfTI)

# Create the output directories up front. DATA_RAS is an input directory and
# is deliberately not created here.
for p in [CSV_DIR, FIG_DIR, OUT_TRANSFORM, OUT_WARP]:
    p.mkdir(exist_ok=True, parents=True)

# ============================================================
# 1. 工具函数：加载、点云、ICP
# ============================================================

def load_nifti(path: Path):
    """
    Load a NIfTI file with nibabel.

    Returns a tuple ``(image, data, affine)``: the nibabel image object, the
    voxel array obtained from ``get_fdata()``, and the image's affine matrix.
    """
    image = nib.load(str(path))
    return image, image.get_fdata(), image.affine

def mask_to_pointcloud_from_img(img, max_points=5000):
    """
    Convert a 3D mask image into a world-coordinate point cloud for ICP.

    Every voxel with value > 0 is treated as foreground. Its voxel index is
    mapped through ``img.affine`` to world coordinates.

    Parameters
    ----------
    img : object exposing ``get_fdata()`` and ``affine`` (e.g. a nibabel image).
    max_points : upper bound on the number of returned points. When the mask
        has more foreground voxels, a random subset (without replacement,
        drawn from NumPy's global random state) is returned.

    Returns
    -------
    (N, 3) float array of world coordinates; shape (0, 3) for an empty mask.
    """
    foreground = img.get_fdata() > 0
    voxel_idx = np.argwhere(foreground)  # (N, 3) voxel indices
    n_voxels = voxel_idx.shape[0]

    if n_voxels == 0:
        return np.zeros((0, 3), dtype=float)

    # Homogeneous voxel coordinates -> world coordinates via the affine.
    homogeneous = np.column_stack((voxel_idx, np.ones(n_voxels)))
    world = np.dot(img.affine, homogeneous.T).T[:, :3]

    if world.shape[0] > max_points:
        keep = np.random.choice(world.shape[0], max_points, replace=False)
        world = world[keep]

    return world.astype(float)

def ICP_rigid(A, B, max_iter=50, tol=1e-5):
    """
    Point-to-point rigid ICP aligning point cloud A onto point cloud B.

    Each iteration matches every (current) point of A to its nearest
    neighbour in B, solves the best rigid motion for those pairs with an
    SVD (Kabsch), applies it, and composes it into the running transform.

    Parameters
    ----------
    A : (N, 3) array, moving points.
    B : (M, 3) array, fixed points.
    max_iter : maximum number of ICP iterations.
    tol : stop once the mean nearest-neighbour distance changes by less than
        this amount between two consecutive iterations.

    Returns
    -------
    T : (4, 4) homogeneous transform accumulated over ALL iterations, so that
        applying T to the original A reproduces A_c.
    A_c : (N, 3) aligned copy of A. When either cloud is empty, the identity
        and A itself are returned.
    """
    if A.shape[0] == 0 or B.shape[0] == 0:
        return np.eye(4), A

    A_c = A.copy()
    T = np.eye(4)
    prev_error = 1e10

    # B never changes, so the search tree only needs to be built once.
    tree = KDTree(B)

    for _ in range(max_iter):
        dist, idx = tree.query(A_c)
        B_match = B[idx]

        mu_A = A_c.mean(axis=0)
        mu_B = B_match.mean(axis=0)

        # Cross-covariance of the centred correspondences.
        W = (A_c - mu_A).T @ (B_match - mu_B)
        U, S, Vt = np.linalg.svd(W)
        R = Vt.T @ U.T

        # Reflection check: force a proper rotation (det = +1).
        if np.linalg.det(R) < 0:
            Vt[-1] *= -1
            R = Vt.T @ U.T

        t = mu_B - R @ mu_A

        A_c = (R @ A_c.T).T + t

        # R, t only describe the motion of THIS iteration (A_c was already
        # moved by the previous ones), so compose it onto the running total.
        T_step = np.eye(4)
        T_step[:3, :3] = R
        T_step[:3, 3] = t
        T = T_step @ T

        # dist was measured before this iteration's update was applied.
        mean_error = dist.mean()
        if abs(prev_error - mean_error) < tol:
            break
        prev_error = mean_error

    return T, A_c

# ============================================================
# 2. 标准 benchmark 的 Dice / HD95（在 mask 上计算）
# ============================================================

def compute_dice_from_mask(maskA_path, maskB_path):
    """
    Dice overlap between two mask files.

    Both files are loaded with nibabel and binarised with ``> 0``. Returns
    ``2 * |A & B| / (|A| + |B|)`` as a Python float, or 1.0 when both masks
    are empty.
    """
    fg_a = nib.load(str(maskA_path)).get_fdata() > 0
    fg_b = nib.load(str(maskB_path)).get_fdata() > 0

    overlap = np.count_nonzero(fg_a & fg_b)
    total = np.count_nonzero(fg_a) + np.count_nonzero(fg_b)

    # Two empty masks are treated as a perfect match.
    return 1.0 if total == 0 else float(2.0 * overlap / total)

def compute_hd95_from_mask(maskA_path, maskB_path):
    """
    95th-percentile symmetric surface distance between two mask files.

    Both masks are binarised with ``> 0``. Surface voxels are those removed
    by one binary erosion. For every surface voxel of one mask the distance
    to the other mask's foreground is read from a Euclidean distance
    transform. The two directed distance sets are concatenated and their
    95th percentile is returned.

    Voxel spacing is taken from the column norms of mask A's affine and is
    used for both distance transforms, so distances are in the affine's
    world units. The boolean indexing requires both masks to have the same
    array shape.

    Returns ``nan`` when either mask has no surface voxels.
    """
    nii_a = nib.load(str(maskA_path))
    nii_b = nib.load(str(maskB_path))

    fg_a = nii_a.get_fdata() > 0
    fg_b = nii_b.get_fdata() > 0

    # Per-axis voxel size = Euclidean norm of each column of the 3x3 part.
    voxel_size = np.sqrt(np.square(nii_a.affine[:3, :3]).sum(axis=0))

    # Surface = foreground voxels that disappear under erosion.
    border_a = np.logical_xor(fg_a, binary_erosion(fg_a))
    border_b = np.logical_xor(fg_b, binary_erosion(fg_b))

    if not (border_a.any() and border_b.any()):
        return float("nan")

    # Distance from every voxel to the nearest foreground voxel of each mask
    # (zero inside the foreground itself).
    dist_to_a = distance_transform_edt(np.logical_not(fg_a), sampling=voxel_size)
    dist_to_b = distance_transform_edt(np.logical_not(fg_b), sampling=voxel_size)

    surface_distances = np.concatenate([dist_to_b[border_a], dist_to_a[border_b]])
    return float(np.percentile(surface_distances, 95))

# ============================================================
# 3. 使用 ICP 刚体变换 warp segmentation mask（NIfTI）
# ============================================================

def warp_mask_with_world_transform(mov_path, tgt_path, T_full, out_path):
    """
    Resample a moving mask onto the target's voxel grid using a world-space
    transform, and save the result as NIfTI.

    Parameters
    ----------
    mov_path : moving mask NIfTI (e.g. scapula_left).
    tgt_path : target mask NIfTI (e.g. scapula_right); supplies the output
        grid shape and affine.
    T_full : (4, 4) world-coordinate transform, moving -> target.
    out_path : where the warped moving mask is written.

    For each target voxel index j the sampling location in the moving volume
    is

        i_mov = inv(A_mov) @ inv(T_full) @ A_tgt @ [j, 1]

    which is handed to ``scipy.ndimage.affine_transform`` as its output-to-
    input mapping (3x3 matrix + offset).
    """
    moving_img = nib.load(str(mov_path))
    target_img = nib.load(str(tgt_path))

    moving_data = moving_img.get_fdata()
    target_affine = target_img.affine

    # Target voxel index -> world -> moving world -> moving voxel index.
    world_inverse = np.linalg.inv(T_full)
    tgt_to_mov_voxel = np.linalg.inv(moving_img.affine) @ world_inverse @ target_affine

    resampled = affine_transform(
        moving_data,
        matrix=tgt_to_mov_voxel[:3, :3],
        offset=tgt_to_mov_voxel[:3, 3],
        output_shape=target_img.shape,
        order=0,  # nearest neighbour keeps the mask labels intact
    )

    nib.save(nib.Nifti1Image(resampled, target_affine), str(out_path))


# ============================================================
# 4. 2D overlay 可视化（warp 后与 target 对比）
# ============================================================

def plot_2d_overlay(pid, struct, warped_mask_path, tgt_mask_path, out_dir: Path):
    """
    Save a three-panel overlay of the warped mask (red) on the target mask
    (blue) and return the path of the written PNG.

    Each panel is a maximum-intensity projection of the binarised masks
    along one array axis: axis 2 ("Axial"), axis 1 ("Coronal"), axis 0
    ("Sagittal"). The anatomical labels assume the arrays are stored in an
    RAS-like axis order -- TODO confirm against the preprocessing step.

    The figure is written to ``out_dir / f"{pid}_{struct}_overlay.png"``;
    ``out_dir`` is created if missing.
    """
    out_dir.mkdir(exist_ok=True, parents=True)
    out_path = out_dir / f"{pid}_{struct}_overlay.png"

    warped = nib.load(str(warped_mask_path)).get_fdata() > 0
    tgt = nib.load(str(tgt_mask_path)).get_fdata() > 0

    # (panel title, array axis collapsed by the max projection)
    views = (("Axial", 2), ("Coronal", 1), ("Sagittal", 0))
    # Target is drawn first, warped mask on top.
    layers = ((tgt, "Blues"), (warped, "Reds"))

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, proj_axis) in zip(axes, views):
        for volume, cmap in layers:
            ax.imshow(volume.max(axis=proj_axis).T, cmap=cmap, alpha=0.4, origin="lower")
        ax.set_title(title)
        ax.axis("off")

    plt.suptitle(f"{pid} {struct} ICP warp vs target")
    plt.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)

    return out_path


# ============================================================
# 5. 保存 ICP transform（JSON 版本，保持你之前的习惯）
# ============================================================

def save_transform_json(T_full, out_path: Path):
    """
    Write a transform matrix to ``out_path`` as JSON of the form
    ``{"T": [[...], ...]}`` (2-space indented).

    ``T_full`` must provide ``tolist()`` (e.g. a NumPy array). The parent
    directory of ``out_path`` is created if it does not exist.
    """
    out_path.parent.mkdir(exist_ok=True, parents=True)
    with open(out_path, "w") as handle:
        handle.write(json.dumps({"T": T_full.tolist()}, indent=2))


# ============================================================
# 6. 单个 (pid, struct, repeat) 的完整流程：ICP + warp + Dice + HD95 + 可视化
# ============================================================

def evaluate_icp_single(pid, struct, src_path: Path, tgt_path: Path, repeat_id: int):
    """
    Run one mirror + ICP registration of a left-side mask onto the
    right-side mask and score it.

    Parameters
    ----------
    pid : patient identifier, used in output file names.
    struct : structure name (e.g. "scapula"), used in output file names.
    src_path : left-side (moving) mask, e.g. scapula_left.nii.gz.
    tgt_path : right-side (fixed) mask, e.g. scapula_right.nii.gz.
    repeat_id : index of this repeat, used in output file names.

    Side effects: writes the transform JSON to OUT_TRANSFORM, the warped
    mask to OUT_WARP and an overlay PNG to FIG_DIR.

    Returns
    -------
    dict with keys "pid", "struct", "repeat", "dice", "hd95", or ``None``
    when either mask produces an empty point cloud.
    """
    print(f"  [{struct}] repeat={repeat_id}: {src_path.name} -> {tgt_path.name}")

    moving_img = load_nifti(src_path)[0]
    fixed_img = load_nifti(tgt_path)[0]

    # 1) masks -> world-coordinate point clouds (moving first, then fixed)
    pts_src = mask_to_pointcloud_from_img(moving_img)
    pts_tgt = mask_to_pointcloud_from_img(fixed_img)

    if 0 in (pts_src.shape[0], pts_tgt.shape[0]):
        print("   Skip: empty mask")
        return None

    # 2) left -> right: flip the world X coordinate of the moving points
    flip_x = np.diag([-1.0, 1.0, 1.0])
    pts_src_mirrored = pts_src @ flip_x

    # 3) ICP on the mirrored cloud (the aligned points are not needed here)
    T_icp, _ = ICP_rigid(pts_src_mirrored, pts_tgt)

    # 4) full left -> right map: mirror first, then the ICP transform
    T_full = T_icp @ np.diag([-1.0, 1.0, 1.0, 1.0])

    # 5) persist the transform as JSON
    stem = f"{pid}_{struct}_r{repeat_id}"
    save_transform_json(T_full, OUT_TRANSFORM / f"{stem}.json")

    # 6) warp the segmentation mask with T_full (needed for the benchmark)
    warped_mask_path = OUT_WARP / f"{stem}.nii.gz"
    warped_mask_path.parent.mkdir(exist_ok=True, parents=True)
    warp_mask_with_world_transform(
        mov_path=src_path,
        tgt_path=tgt_path,
        T_full=T_full,
        out_path=warped_mask_path,
    )

    # 7) Dice / HD95 on the masks
    dice = compute_dice_from_mask(warped_mask_path, tgt_path)
    hd95 = compute_hd95_from_mask(warped_mask_path, tgt_path)

    # 8) 2D overlay figure
    plot_2d_overlay(pid, struct, warped_mask_path, tgt_path, FIG_DIR)

    return dict(pid=pid, struct=struct, repeat=repeat_id, dice=dice, hd95=hd95)


# ============================================================
# 7. 主函数：遍历所有病人 + scapula / humerus + 多次 repeat
# ============================================================

def main_icp_benchmark(n_repeats=3):
    """
    Run the ICP benchmark over every patient directory under DATA_RAS.

    For each patient and each structure ("scapula", "humerus") that has both
    ``<struct>_left.nii.gz`` and ``<struct>_right.nii.gz``, the registration
    is evaluated ``n_repeats`` times. A failure in a single repeat is
    printed and skipped so that the remaining cases still run.

    All successful results are written to
    ``CSV_DIR / "icp_benchmark_dice_hd95.csv"``. Nothing is returned.
    """
    patient_dirs = sorted(d for d in DATA_RAS.iterdir() if d.is_dir())
    if not patient_dirs:
        print("No patients found in", DATA_RAS)
        return

    results = []

    for patient_dir in patient_dirs:
        pid = patient_dir.name
        print(f"\n=== Patient {pid} ===")

        for struct in ("scapula", "humerus"):
            src_path = patient_dir / f"{struct}_left.nii.gz"
            tgt_path = patient_dir / f"{struct}_right.nii.gz"

            if not (src_path.exists() and tgt_path.exists()):
                print(f"  [{struct}] missing files, skip.")
                continue

            for repeat_id in range(n_repeats):
                # Best effort: one failing case must not abort the whole run.
                try:
                    outcome = evaluate_icp_single(pid, struct, src_path, tgt_path, repeat_id)
                except Exception as e:
                    print(f"  ERROR {pid} {struct} repeat={repeat_id}: {e}")
                else:
                    if outcome is not None:
                        results.append(outcome)

    if not results:
        print("\nNo valid ICP results, nothing saved.")
        return

    import pandas as pd
    out_csv = CSV_DIR / "icp_benchmark_dice_hd95.csv"
    pd.DataFrame(results).to_csv(out_csv, index=False)
    print("\nAll done. Results saved to:", out_csv)


# ============================================================
# 8. 运行
# ============================================================

#if __name__ == "__main__":



In [5]:
main_icp_benchmark(n_repeats=3)


=== Patient s0970 ===
  [scapula] repeat=0: scapula_left.nii.gz -> scapula_right.nii.gz
  [scapula] repeat=1: scapula_left.nii.gz -> scapula_right.nii.gz
  [scapula] repeat=2: scapula_left.nii.gz -> scapula_right.nii.gz
  [humerus] repeat=0: humerus_left.nii.gz -> humerus_right.nii.gz
  [humerus] repeat=1: humerus_left.nii.gz -> humerus_right.nii.gz
  [humerus] repeat=2: humerus_left.nii.gz -> humerus_right.nii.gz

=== Patient s1029 ===
  [scapula] repeat=0: scapula_left.nii.gz -> scapula_right.nii.gz
  [scapula] repeat=1: scapula_left.nii.gz -> scapula_right.nii.gz
  [scapula] repeat=2: scapula_left.nii.gz -> scapula_right.nii.gz
  [humerus] repeat=0: humerus_left.nii.gz -> humerus_right.nii.gz
  [humerus] repeat=1: humerus_left.nii.gz -> humerus_right.nii.gz
  [humerus] repeat=2: humerus_left.nii.gz -> humerus_right.nii.gz

=== Patient s1124 ===
  [scapula] repeat=0: scapula_left.nii.gz -> scapula_right.nii.gz
  [scapula] repeat=1: scapula_left.nii.gz -> scapula_right.nii.gz
  [scap