In [1]:
import os
import torch
import numpy as np
import scipy.io as sio
from scipy.interpolate import interpn
import torch.nn.functional as F
from tqdm.notebook import tqdm
import glob
from monai.metrics import DiceMetric
from models import CycleTransMorph

In [2]:
DATA_DIR = "/mnt/hot/public/4DCT_datasets/DIR-Lab/all"
MODEL_PATH = "./model_runs/ctm_run_1/best_model.pth" 
IMG_SIZE = (128, 128, 128)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def load_dir_lab_case(mat_path):
    """
    Loads image, mask, and landmark data from a single .mat file.
    
    *** ASSUMPTION ***
    This function ASSUMES the following key names inside your .mat file:
    - 'image_T00': Inhale (moving) image (Z, Y, X)
    - 'image_T50': Exhale (fixed) image (Z, Y, X)
    - 'mask_T00': Inhale (moving) mask (Z, Y, X)
    - 'mask_T50': Exhale (fixed) mask (Z, Y, X)
    - 'landmarks_T00': Inhale landmarks (N, 3) in (x, y, z) mm
    - 'landmarks_T50': Exhale landmarks (N, 3) in (x, y, z) mm
    - 'spacing': Voxel spacing (3,) in (x_mm, y_mm, z_mm)
    - 'original_size': Original image size (3,) in (Z, Y, X)
    
    You MUST adapt these keys if your .mat files are structured differently.
    """
    data = sio.loadmat(mat_path)
    
    print(f"\nLoading {os.path.basename(mat_path)}. Found keys: {list(data.keys())}")
    
    try:
        inhale_img = data['T00'].astype(np.float32)
        exhale_img = data['T50'].astype(np.float32)
        inhale_mask = data['mask_T00'].astype(np.float32)
        exhale_mask = data['mask_T50'].astype(np.float32)
        
        inhale_lms = data['landmarks_T00'].astype(np.float32)
        exhale_lms = data['landmarks_T50'].astype(np.float32)
        
        spacing_xyz = data['spacing'].squeeze().astype(np.float32)
        original_size_zyx = data['original_size'].squeeze().astype(np.int32)
        
        # Normalize images
        inhale_img = (inhale_img - np.min(inhale_img)) / (np.max(inhale_img) - np.min(inhale_img))
        exhale_img = (exhale_img - np.min(exhale_img)) / (np.max(exhale_img) - np.min(exhale_img))

        return {
            "inhale_image": inhale_img,
            "exhale_image": exhale_img,
            "inhale_mask": inhale_mask,
            "exhale_mask": exhale_mask,
            "inhale_landmarks": inhale_lms,
            "exhale_landmarks": exhale_lms,
            "spacing_xyz": spacing_xyz, # (sx, sy, sz)
            "original_size_zyx": original_size_zyx # (oz, oy, ox)
        }
    
    except KeyError as e:
        print(f"\n--- FATAL ERROR ---")
        print(f"Could not find key {e} in {mat_path}.")
        print(f"Please edit Cell 5 to match your .mat file structure.")
        print(f"Available keys are: {list(data.keys())}")
        print("---------------------\n")
        raise

In [11]:
def preprocess_image(np_img, target_size):
    """
    Prepares a numpy image for the model (resize and add batch/channel dims).
    """
    img_tensor = torch.from_numpy(np_img).unsqueeze(0).unsqueeze(0) # (1, 1, D, H, W)
    # Resize using F.interpolate (must be 5D for 3D data)
    resized_tensor = F.interpolate(img_tensor, size=target_size, mode='trilinear', align_corners=False)
    return resized_tensor

In [12]:
def warp_landmarks_dvf(landmarks_xyz, dvf_np_zyx, spacing_xyz, original_size_zyx, target_size_zyx):
    """
    Warps a list of landmarks (in mm) using the predicted DVF.
    
    Args:
        landmarks_xyz: (N, 3) numpy array of landmarks in (x, y, z) PHYSICAL (mm) coordinates.
        dvf_np_zyx: (3, D, H, W) numpy array (e.g., 3x128x128x128) from the model.
                      The values are displacements in RESIZED VOXEL space.
                      Channels are (disp_z, disp_y, disp_x).
        spacing_xyz: (3,) numpy array of voxel spacing (sx, sy, sz) in mm.
        original_size_zyx: (3,) numpy array of original image size (oz, oy, ox).
        target_size_zyx: (3,) tuple of model's image size (tz, ty, tx).
    """
    
    # Get component DVF fields
    dvf_z, dvf_y, dvf_x = dvf_np_zyx[0], dvf_np_zyx[1], dvf_np_zyx[2]
    
    # Create interpolation grids (for the 128^3 DVF)
    grid_z = np.arange(target_size_zyx[0])
    grid_y = np.arange(target_size_zyx[1])
    grid_x = np.arange(target_size_zyx[2])
    
    # Get physical size (mm) / target size (voxels)
    # This gives us (mm / resized_voxel)
    oz, oy, ox = original_size_zyx
    tz, ty, tx = target_size_zyx
    sx, sy, sz = spacing_xyz
    
    scale_x = (ox * sx) / tx 
    scale_y = (oy * sy) / ty
    scale_z = (oz * sz) / tz
    
    warped_landmarks = []
    
    for (lx, ly, lz) in landmarks_xyz:
        # 1. Convert landmark (mm) to original voxel coords
        orig_vx = lx / sx
        orig_vy = ly / sy
        orig_vz = lz / sz
        
        # 2. Convert original voxel coords to resized (128) voxel coords (query point)
        query_x = orig_vx * (tx - 1) / (ox - 1)
        query_y = orig_vy * (ty - 1) / (oy - 1)
        query_z = orig_vz * (tz - 1) / (oz - 1)
        query_point_zyx = (query_z, query_y, query_x)
        
        # 3. Interpolate displacement vector (in resized voxel units)
        disp_z = interpn((grid_z, grid_y, grid_x), dvf_z, query_point_zyx, method='linear', bounds_error=False, fill_value=0)
        disp_y = interpn((grid_z, grid_y, grid_x), dvf_y, query_point_zyx, method='linear', bounds_error=False, fill_value=0)
        disp_x = interpn((grid_z, grid_y, grid_x), dvf_x, query_point_zyx, method='linear', bounds_error=False, fill_value=0)
        
        # 4. Convert displacement vector from (resized_voxel) to (mm)
        disp_mm_x = disp_x * scale_x
        disp_mm_y = disp_y * scale_y
        disp_mm_z = disp_z * scale_z
        disp_vector_mm = np.array([disp_mm_x, disp_mm_y, disp_mm_z])
        
        # 5. Add mm displacement to original mm landmark
        warped_lm_mm = np.array([lx, ly, lz]) + disp_vector_mm
        warped_landmarks.append(warped_lm_mm)
        
    return np.array(warped_landmarks)

In [13]:
def get_jacobian_determinant(dvf_numpy):
    """
    Calculates the 3D Jacobian determinant of the transformation T(p) = p + DVF(p).
    DVF shape is (3, D, H, W) with channels (disp_z, disp_y, disp_x).
    """
    
    # DVF components
    dvf_z = dvf_numpy[0]
    dvf_y = dvf_numpy[1]
    dvf_x = dvf_numpy[2]
    
    # Get gradients for each component
    # np.gradient returns (grad_z, grad_y, grad_x)
    grad_ux = np.gradient(dvf_x) 
    dux_dz, dux_dy, dux_dx = grad_ux[0], grad_ux[1], grad_ux[2]
    
    grad_uy = np.gradient(dvf_y)
    duy_dz, duy_dy, duy_dx = grad_uy[0], grad_uy[1], grad_uy[2]

    grad_uz = np.gradient(dvf_z)
    duz_dz, duz_dy, duz_dx = grad_uz[0], grad_uz[1], grad_uz[2]
    
    # Build Jacobian determinant
    # J = I + grad(DVF)
    J_11 = 1 + dux_dx
    J_12 = dux_dy
    J_13 = dux_dz
    
    J_21 = duy_dx
    J_22 = 1 + duy_dy
    J_23 = duy_dz
    
    J_31 = duz_dx
    J_32 = duz_dy
    J_33 = 1 + duz_dz
    
    # Compute determinant
    det = J_11 * (J_22 * J_33 - J_23 * J_32) \
        - J_12 * (J_21 * J_33 - J_23 * J_31) \
        + J_13 * (J_21 * J_32 - J_22 * J_31)
        
    return det

In [14]:
model = CycleTransMorph(img_size=IMG_SIZE).to(DEVICE)

try:
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
except FileNotFoundError:
    print(f"--- FATAL ERROR: Model file not found ---")
    print(f"Could not find model at: {MODEL_PATH}")
    print(f"Please update the MODEL_PATH variable in Cell 3.")
    raise

model.eval()
print("Model loaded successfully.")

  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


Model loaded successfully.


In [17]:
all_tre_errors = []
all_non_positive_jac = []
dice_metric = DiceMetric(include_background=True, reduction="mean_batch")

case_files = sorted(glob.glob(os.path.join(DATA_DIR, 'case*.mat')))

if not case_files:
    print(f"--- FATAL ERROR: No .mat files found ---")
    print(f"Could not find any 'case*.mat' files in: {DATA_DIR}")
    print(f"Please update the DATA_DIR variable in Cell 3.")
else:
    print(f"Found {len(case_files)} cases. Starting evaluation...")

    with torch.no_grad():
        for case_path in tqdm(case_files, desc="Processing DIR-Lab Cases"):
            # 1. Load Data
            data = load_dir_lab_case(case_path)
            
            # 2. Preprocess images and masks
            inhale_tensor = preprocess_image(data['inhale_image'], IMG_SIZE).to(DEVICE)
            exhale_tensor = preprocess_image(data['exhale_image'], IMG_SIZE).to(DEVICE)
            inhale_mask_tensor = preprocess_image(data['inhale_mask'], IMG_SIZE).to(DEVICE)
            exhale_mask_tensor = preprocess_image(data['exhale_mask'], IMG_SIZE).to(DEVICE)
            
            # 3. Run Model (Inhale -> Exhale)
            warped_inhale, dvf, svf = model(inhale_tensor, exhale_tensor)
            dvf_np = dvf.squeeze(0).cpu().numpy()
            
            # 4. Calculate TRE
            warped_lms = warp_landmarks_dvf(
                data['inhale_landmarks'],
                dvf_np,
                data['spacing_xyz'],
                data['original_size_zyx'],
                IMG_SIZE
            )
            tre_errors = np.sqrt(np.sum((warped_lms - data['exhale_landmarks']) ** 2, axis=1))
            all_tre_errors.extend(tre_errors)
            
            # 5. Calculate Jacobian
            jacobian_det = get_jacobian_determinant(dvf_np)
            all_non_positive_jac.append(np.sum(jacobian_det <= 0) / np.prod(jacobian_det.shape))
            
            # 6. Calculate DSC
            # Warp the inhale mask to the exhale space
            warped_inhale_mask = model.spatial_transformer(inhale_mask_tensor, dvf)
            
            # Binarize masks
            warped_mask_binary = (warped_inhale_mask > 0.5).float()
            exhale_mask_binary = (exhale_mask_tensor > 0.5).float()
            
            # Compute dice (MONAI expects one-hot, but binary (B,C,D,H,W) is fine)
            dice_metric(y_pred=warped_mask_binary, y=exhale_mask_binary)

    print("\nEvaluation loop complete.")

Found 10 cases. Starting evaluation...


Processing DIR-Lab Cases:   0%|          | 0/10 [00:00<?, ?it/s]


Loading case1.mat. Found keys: ['__header__', '__version__', '__globals__', 'T00', 'T10', 'T20', 'T30', 'T40', 'T50', 'T60', 'T70', 'T80', 'T90']

--- FATAL ERROR ---
Could not find key 'image_T00' in /mnt/hot/public/4DCT_datasets/DIR-Lab/all/case1.mat.
Please edit Cell 5 to match your .mat file structure.
Available keys are: ['__header__', '__version__', '__globals__', 'T00', 'T10', 'T20', 'T30', 'T40', 'T50', 'T60', 'T70', 'T80', 'T90']
---------------------



KeyError: 'image_T00'

In [None]:
if all_tre_errors:
    mean_tre = np.mean(all_tre_errors)
    std_tre = np.std(all_tre_errors)
    print(f"--- Target Registration Error (TRE) ---")
    print(f"Mean TRE:   {mean_tre:.4f} mm")
    print(f"Std TRE:    {std_tre:.4f} mm")
    print(f"(SOTA Target: < 1.5 mm)")
else:
    print("TRE calculation failed. Check data loading.")

if all_non_positive_jac:
    mean_jac = np.mean(all_non_positive_jac) * 100
    print(f"\n--- Jacobian Plausibility ---")
    print(f"Mean % Non-Positive Jacobians: {mean_jac:.6f} %")
    print(f"(Target: < 0.1 %)")
else:
    print("Jacobian calculation failed.")

try:
    mean_dice = dice_metric.aggregate().item()
    print(f"\n--- Dice Similarity Coefficient (DSC) ---")
    print(f"Mean DSC:   {mean_dice:.4f}")
    print(f"(Target: > 0.95)")
except Exception as e:
    print(f"\nDice calculation failed: {e}")