In [12]:
import os
import csv
import numpy as np
import nibabel as nib
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from scipy.ndimage import map_coordinates
from scipy.spatial.distance import directed_hausdorff
os.environ["XFORMERS_DISABLED"] = "1"

# --------------------------------------------------
# Path setup (project layout)
# --------------------------------------------------
# Assumes the kernel's working directory is one level below the project root.
PROJECT_ROOT = Path.cwd().parent
DATA_ROOT = PROJECT_ROOT / "data"

# Inputs: image ROIs (DATA_RAS_DINO) and per-subject segmentation folders (DATA_COMPLETE).
DATA_RAS_DINO = DATA_ROOT / "ras_1mm_dinoreg"
DATA_COMPLETE = DATA_ROOT / "complete"
CSV_DIR = DATA_ROOT / "csv"

# Outputs: overlay figures, displacement fields, warped masks.
FIG_DIR = DATA_ROOT / "fig" / "dinoreg"
OUT_TRANSFORM = DATA_ROOT / "transforms_dinoreg"
OUT_WARP = DATA_ROOT / "warp_dinoreg"

# Only the output folders are created here; input folders must already exist.
for out_dir in (FIG_DIR, OUT_TRANSFORM, OUT_WARP):
    out_dir.mkdir(parents=True, exist_ok=True)

PAIRS_CSV = CSV_DIR / "pairs_dinoreg.csv"

STRUCTURES = ["scapula_left", "scapula_right", "humerus_left", "humerus_right"]


In [6]:
#!pip install omegaconf==2.3.0
#!pip install hydra-core==1.3.2

Collecting omegaconf==2.3.0
  Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
     ---------------------------------------- 0.0/79.5 kB ? eta -:--:--
     ----- ---------------------------------- 10.2/79.5 kB ? eta -:--:--
     ----- ---------------------------------- 10.2/79.5 kB ? eta -:--:--
     -------------- ----------------------- 30.7/79.5 kB 217.9 kB/s eta 0:00:01
     ------------------- ------------------ 41.0/79.5 kB 245.8 kB/s eta 0:00:01
     ----------------------------- -------- 61.4/79.5 kB 297.7 kB/s eta 0:00:01
     -------------------------------------- 79.5/79.5 kB 294.7 kB/s eta 0:00:00
Collecting antlr4-python3-runtime==4.9.*
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
     ---------------------------------------- 0.0/117.0 kB ? eta -:--:--
     --- ------------------------------------ 10.2/117.0 kB ? eta -:--:--
     -------------------- ------------------ 61.4/117.0 kB 1.1 MB/s eta 0:00:01
     ---------------------------------- - 112.6/117.

In [8]:
#!pip install tensorboard

Collecting tensorboard
  Downloading tensorboard-2.20.0-py3-none-any.whl (5.5 MB)
     ---------------------------------------- 0.0/5.5 MB ? eta -:--:--
     ---------------------------------------- 0.0/5.5 MB ? eta -:--:--
     ---------------------------------------- 0.0/5.5 MB 165.2 kB/s eta 0:00:34
     ---------------------------------------- 0.0/5.5 MB 217.9 kB/s eta 0:00:26
     ---------------------------------------- 0.1/5.5 MB 363.1 kB/s eta 0:00:16
     - -------------------------------------- 0.2/5.5 MB 1.1 MB/s eta 0:00:05
     ---- ----------------------------------- 0.7/5.5 MB 2.8 MB/s eta 0:00:02
     --------- ------------------------------ 1.3/5.5 MB 4.7 MB/s eta 0:00:01
     ------------------- -------------------- 2.7/5.5 MB 8.3 MB/s eta 0:00:01
     ------------------------- -------------- 3.6/5.5 MB 9.5 MB/s eta 0:00:01
     ------------------------------ --------- 4.3/5.5 MB 10.1 MB/s eta 0:00:01
     ----------------------------------- ---- 4.9/5.5 MB 10.7 MB/s 

In [13]:
# Do NOT install xformers: dinoReg falls back to plain torch attention when it is absent.
#!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Obtaining dependency information for torch from https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp310-cp310-win_amd64.whl.metadata
  Using cached https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp310-cp310-win_amd64.whl.metadata (27 kB)
Using cached https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp310-cp310-win_amd64.whl (2817.2 MB)
Using cached https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp310-cp310-win_amd64.whl (2817.2 MB)
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 2.9.1
    Uninstalling torch-2.9.1:
      Successfully uninstalled torch-2.9.1
Successfully installed torch-2.7.1+cu118


In [14]:
!pip uninstall xformers -y



In [15]:
import torch

# Sanity check that the CUDA build of torch imports cleanly.
torch_version = torch.__version__
print("Torch OK:", torch_version)

Torch OK: 2.7.1+cu118


In [16]:
import sys
sys.path.append(str(PROJECT_ROOT / "src"))

from dinoReg import dinoReg

xFormers NOT available — using torch attention instead


In [2]:
def dice_score(a, b):
    """Dice overlap 2|A∩B| / (|A| + |B|) between two masks.

    Any value > 0 counts as foreground. A 1e-6 term in the denominator
    avoids division by zero, so two empty masks score 0.0.
    """
    fg_a = np.asarray(a) > 0
    fg_b = np.asarray(b) > 0
    overlap = np.count_nonzero(fg_a & fg_b)
    total = np.count_nonzero(fg_a) + np.count_nonzero(fg_b)
    return 2 * overlap / (total + 1e-6)

def compute_hd95(a, b, spacing=None):
    """95th-percentile symmetric Hausdorff distance (HD95) between two masks.

    For every boundary voxel of ``a``, the distance to the nearest boundary
    voxel of ``b`` is measured, and vice versa. HD95 is the 95th percentile
    of all of these distances pooled together. This makes it robust to a few
    outlier voxels, unlike the plain (max) Hausdorff distance.

    Parameters
    ----------
    a, b : array-like
        Masks on the same voxel grid; any value > 0 counts as foreground.
    spacing : float or sequence of float, optional
        Voxel size per axis. None means 1.0, so the result is in voxel units.

    Returns
    -------
    float
        HD95 in units of ``spacing``, or ``np.nan`` if either mask is empty.

    Raises
    ------
    ValueError
        If the two masks have different shapes.
    """
    from scipy.ndimage import binary_erosion, distance_transform_edt

    fg_a = np.asarray(a) > 0
    fg_b = np.asarray(b) > 0
    if fg_a.shape != fg_b.shape:
        raise ValueError(f"mask shapes differ: {fg_a.shape} vs {fg_b.shape}")
    if not fg_a.any() or not fg_b.any():
        return np.nan

    # Boundary voxels = foreground voxels removed by a one-voxel erosion.
    surf_a = fg_a & ~binary_erosion(fg_a)
    surf_b = fg_b & ~binary_erosion(fg_b)

    # The EDT of the complement gives, at every voxel, the distance to the nearest surface voxel.
    dist_to_b = distance_transform_edt(~surf_b, sampling=spacing)
    dist_to_a = distance_transform_edt(~surf_a, sampling=spacing)

    distances = np.concatenate([dist_to_b[surf_a], dist_to_a[surf_b]])
    return float(np.percentile(distances, 95))

def save_overlay(fixed, warped_mask, png_path, slice_idx=None):
    """Save a PNG of one slice of ``fixed`` with ``warped_mask`` overlaid in red.

    Parameters
    ----------
    fixed : ndarray
        Background image; slices are taken along axis 0.
    warped_mask : ndarray
        Mask on the same grid as ``fixed``; voxels > 0 are drawn in red,
        background voxels are left transparent.
    png_path : str or Path
        Output PNG file.
    slice_idx : int, optional
        Index along axis 0 to show; defaults to the middle slice.
    """
    if slice_idx is None:
        slice_idx = fixed.shape[0] // 2

    # Mask out background voxels: otherwise the colormap's low end is blended
    # over every pixel and tints the whole image. Fixed vmin/vmax keep the
    # foreground at the saturated end of the colormap.
    fg = warped_mask[slice_idx] > 0
    overlay = np.ma.masked_where(~fg, fg.astype(float))

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(fixed[slice_idx], cmap="gray")
        ax.imshow(overlay, cmap="Reds", alpha=0.4, vmin=0.0, vmax=1.0)
        ax.set_axis_off()
        fig.savefig(png_path, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure; this runs once per registration in a loop.
        plt.close(fig)


In [36]:
# Thin wrapper around one DINO-Reg registration (`dinoReg` is imported from PROJECT_ROOT/src in an earlier cell).

def run_dinoreg_once(arr_mov, arr_fix, affine, configs):
    """Register ``arr_mov`` onto ``arr_fix`` with a freshly built DINO-Reg model.

    Parameters
    ----------
    arr_mov, arr_fix : ndarray
        Moving and fixed volumes.
    affine : ndarray
        Affine forwarded unchanged to ``case_inference``.
    configs : dict
        Reads the keys 'lr', 'smooth_weight', 'iter_smooth_num', 'feature_size'
        and 'fm_downsample'.

    Returns
    -------
    Whatever ``dinoReg.case_inference`` returns. The calling loop treats it as a
    4-D numpy array with the vector component on the last axis (TODO confirm
    against dinoReg).

    Notes
    -----
    The model and the CUDA caching allocator's free blocks are released after
    every call, including when inference raises (e.g. OutOfMemoryError). This
    keeps repeated calls in a loop from accumulating reserved GPU memory.
    """
    # NOTE(review): num_iter is fed from 'iter_smooth_num' (5), not 'num_iter' (1000); confirm intended.
    model = dinoReg(
        lr=configs['lr'],
        smooth_weight=configs['smooth_weight'],
        num_iter=configs['iter_smooth_num'],
        feat_size=configs['feature_size'],
    )
    try:
        disp = model.case_inference(
            arr_mov,
            arr_fix,
            arr_mov.shape,
            affine,
            case_id="tmp",
            disp_init=None,
            grid_sp_adam=configs["fm_downsample"],
            # NOTE(review): hard-coded False although configs['DINOReg_useMask'] is True; confirm.
            DINOReg_useMask=False,
        )
    finally:
        del model
        torch.cuda.empty_cache()
    return disp


In [35]:
# Registration hyper-parameters consumed by run_dinoreg_once.
configs = dict(
    smooth_weight=2,        # previously 50
    lr=3,
    num_iter=1000,
    fm_downsample=1,
    # DINOv2 feature-grid size. In the recorded run, (112, 96) resized the input
    # to (1568, 1344, 247) and CUDA ran out of memory on a 6 GiB GPU.
    # Other values tried: (80, 70), (150, 129).
    feature_size=(112, 96),
    useSavedPCA=False,
    DINOReg_useMask=True,
    window=True,
    convex=False,
    ztrans=False,
    iter_smooth_num=5,
    iter_smooth_kernel=7,
    final_upsample=1,
    mask='slice fill stack',
)

REPEAT = 3  # registrations per pair

# Start a fresh results table (overwrites any previous file) with its header row.
results_csv = CSV_DIR / "dinoreg_results.csv"
with results_csv.open("w", newline="") as fh:
    csv.writer(fh).writerow(["pid", "struct", "repeat", "dice", "hd95"])


In [20]:
# Each row is [moving_relpath, fixed_relpath], both relative to DATA_RAS_DINO.
with PAIRS_CSV.open("r") as fh:
    pair_list = list(csv.reader(fh))

pair_list[:5]


[['s0970/scapula_left.nii.gz', 's0970/scapula_right.nii.gz'],
 ['s0970/humerus_left.nii.gz', 's0970/humerus_right.nii.gz'],
 ['s1029/scapula_left.nii.gz', 's1029/scapula_right.nii.gz'],
 ['s1029/humerus_left.nii.gz', 's1029/humerus_right.nii.gz'],
 ['s1124/scapula_left.nii.gz', 's1124/scapula_right.nii.gz']]

In [37]:
def _opposite_side(struct_name):
    """Return the contralateral structure name by swapping 'left' <-> 'right'."""
    if "left" in struct_name:
        return struct_name.replace("left", "right")
    return struct_name.replace("right", "left")


def _require_same_shape(label_a, arr_a, label_b, arr_b):
    """Raise ValueError unless two arrays that must share a voxel grid have equal shapes."""
    if arr_a.shape != arr_b.shape:
        raise ValueError(
            f"{label_a} shape {arr_a.shape} != {label_b} shape {arr_b.shape}; "
            "both must be on the same voxel grid"
        )


def _warp_mask(mask, disp):
    """Backward-warp ``mask`` by ``disp`` with nearest-neighbour sampling.

    Expects ``disp`` with the vector component on the last axis. It is added to
    voxel indices, so it is interpreted in voxel units. The result has shape
    ``disp.shape[:3]``: output voxel x takes the value of ``mask`` at x + disp[x].
    """
    coords = np.indices(disp.shape[:3]) + np.moveaxis(disp, -1, 0)
    return map_coordinates(mask, coords, order=0)


for moving_rel, fixed_rel in pair_list:
    subject = moving_rel.split("/")[0]
    struct_name = moving_rel.split("/")[1].replace(".nii.gz", "")

    print(f"=== {subject} / {struct_name} ===")

    moving_path = DATA_RAS_DINO / moving_rel
    fixed_path  = DATA_RAS_DINO / fixed_rel

    # Load CT ROIs
    mov_img = nib.load(str(moving_path))
    fix_img = nib.load(str(fixed_path))

    arr_mov = mov_img.get_fdata()
    arr_fix = fix_img.get_fdata()
    aff_mov = mov_img.affine
    aff_fix = fix_img.affine

    # Masks: this structure (moving) and its contralateral counterpart (fixed)
    seg_dir = DATA_COMPLETE / subject / "segmentations"
    side2 = _opposite_side(struct_name)
    mask_mov = nib.load(str(seg_dir / f"{struct_name}.nii.gz")).get_fdata()
    mask_fix = nib.load(str(seg_dir / f"{side2}.nii.gz")).get_fdata()

    # Masks come from a different folder than the images. Sampling a mask with
    # another grid's coordinates silently gives garbage, so check here, before the
    # slow GPU registration, instead of crashing later inside the metrics.
    _require_same_shape("moving mask", mask_mov, "moving image", arr_mov)
    _require_same_shape("fixed mask", mask_fix, "fixed image", arr_fix)

    # -----------------------
    # Repeat N times
    # -----------------------
    for r in range(REPEAT):
        print(f"  Run {r+1}/{REPEAT}")

        # Run DINO-Reg
        disp = run_dinoreg_once(arr_mov, arr_fix, aff_mov, configs)

        # Save disp
        disp_path = OUT_TRANSFORM / f"{subject}_{struct_name}_disp_r{r}.nii.gz"
        nib.save(nib.Nifti1Image(disp, aff_mov), str(disp_path))

        # Warp the moving mask onto the displacement field's grid. That grid must be
        # the fixed grid for the overlay and the voxel-wise metrics below to be aligned.
        warped_mask = _warp_mask(mask_mov, disp)
        _require_same_shape("warped mask", warped_mask, "fixed mask", mask_fix)

        # The warped mask lives on the fixed grid, so store it with the fixed affine.
        warp_path = OUT_WARP / f"{subject}_{struct_name}_maskWarp_r{r}.nii.gz"
        nib.save(nib.Nifti1Image(warped_mask, aff_fix), str(warp_path))

        # Save overlay
        overlay_png = FIG_DIR / f"{subject}_{struct_name}_r{r}.png"
        save_overlay(arr_fix, warped_mask, overlay_png)

        # Metrics
        dice = dice_score(mask_fix, warped_mask)
        hd95 = compute_hd95(mask_fix, warped_mask)

        print("   dice =", dice, "  hd95 =", hd95)

        # Append to results csv
        with open(results_csv, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([subject, struct_name, r, dice, hd95])


=== s0970 / scapula_left ===
  Run 1/3
DINOv2 model found.
learning rate 3
preprocessed moving and fixed image, shape (184, 204, 247) (192, 211, 247)
14
112 96 247
resized input shape (1568, 1344, 247)


OutOfMemoryError: CUDA out of memory. Tried to allocate 6.90 GiB. GPU 0 has a total capacity of 6.00 GiB of which 0 bytes is free. Of the allocated memory 9.44 GiB is allocated by PyTorch, and 40.46 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)