In [3]:
#!/usr/bin/env python3
"""
preprocessing_dinoreg.py

DINO-Reg preprocessing with unified project directory layout and a GLOBAL TEMP folder.

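Directory layout (consistent with icp_baseline):

PROJECT_ROOT/
    data/
        raw/
        ras_1mm/
        complete/               <-- input subjects: s*/ct.nii.gz + s*/segmentations/
        ras_3mm_dinoreg/        <-- DINO-Reg output written by this script
        TEMP/                   <-- unified temp folder
        csv/
        fig/
        transforms_icp/
        warp_icp/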
"""

import logging
import shutil
from pathlib import Path

import numpy as np
import SimpleITK as sitk
import nibabel as nib


# Root logging: INFO level, "timestamp - level - message" lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# -------------------------------------------------------------
# Global Directory Structure (consistent with icp_baseline)
# -------------------------------------------------------------
PROJECT_ROOT = Path.cwd().parent
DATA_ROOT    = PROJECT_ROOT / "data"

DATA_RAW        = DATA_ROOT / "raw"
DATA_RAS        = DATA_ROOT / "ras_1mm"
DATA_RAS_DINO   = DATA_ROOT / "ras_3mm_dinoreg"    # New DINO-Reg output
DATA_TEMP       = DATA_ROOT / "TEMP"               # <--- unified TEMP folder
DATA_COMPLETE = DATA_ROOT / "complete"
CSV_DIR         = DATA_ROOT / "csv"
FIG_DIR         = DATA_ROOT / "fig"
OUT_TRANSFORM   = DATA_ROOT / "transforms_icp"
OUT_WARP        = DATA_ROOT / "warp_icp"

for p in [CSV_DIR, FIG_DIR, OUT_TRANSFORM, OUT_WARP, DATA_RAS_DINO, DATA_TEMP]:
    p.mkdir(exist_ok=True, parents=True)

STRUCTURES = ["scapula_left", "scapula_right", "humerus_left", "humerus_right"]


# -------------------------------------------------------------
# Preprocessor
# -------------------------------------------------------------
class DinoRegPreprocessor:
    """Crop, reorient, resample and normalise CT volumes for DINO-Reg.

    For every subject folder ``DATA_COMPLETE/s*`` that contains ``ct.nii.gz``
    and ``segmentations/<structure>.nii.gz``, each structure in ``STRUCTURES``
    is turned into ``DATA_RAS_DINO/<subject>/<structure>.nii.gz`` by:

    1. cropping the CT to the mask bounding box plus ``margin_mm``,
    2. reorienting to nibabel's closest canonical (RAS+) voxel order,
    3. resampling to ``target_spacing`` with linear interpolation,
    4. optionally resizing to a fixed voxel grid (``fixed_size``),
    5. clipping intensities and z-score normalising.

    Left/right pairs are written to ``CSV_DIR/pairs_dinoreg.csv``.

    Args:
        target_spacing: output voxel spacing (x, y, z) in the image's physical
            units (presumably mm for these CTs).
        margin_mm: margin added on every side of the mask bounding box,
            converted to whole voxels with the CT spacing.
        fixed_size: optional (x, y, z) voxel grid. When given, the resampled
            volume is additionally resized to this grid before normalisation.
            ``None`` (default) keeps the ``target_spacing`` grid, which is what
            earlier runs of this pipeline wrote to disk.
    """

    def __init__(self, target_spacing=(3.0, 3.0, 3.0), margin_mm=50, fixed_size=None):
        self.target_spacing = target_spacing
        self.margin_mm = margin_mm
        self.fixed_size = fixed_size

    # ---------------------------------------------------------
    # IO
    # ---------------------------------------------------------
    def load_image(self, path: Path) -> sitk.Image:
        """Read an image with SimpleITK.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not path.exists():
            raise FileNotFoundError(str(path))
        return sitk.ReadImage(str(path))

    # ---------------------------------------------------------
    # CT x mask
    # ---------------------------------------------------------
    def extract_bone(self, ct_img, mask_img):
        """Return ``ct_img`` with every voxel outside ``mask_img > 0`` set to 0.

        Geometry (origin/spacing/direction) is copied from ``ct_img``. The two
        images must have the same voxel grid. Currently not called by
        ``process_structure``.
        """
        ct = sitk.GetArrayFromImage(ct_img)
        mask = sitk.GetArrayFromImage(mask_img)
        bone = ct * (mask > 0)
        out = sitk.GetImageFromArray(bone)
        out.CopyInformation(ct_img)
        return out

    # ---------------------------------------------------------
    # ROI based on mask bounding box
    # ---------------------------------------------------------
    def crop_roi(self, img, mask_img):
        """Crop ``img`` to the bounding box of ``mask_img > 0`` plus a margin.

        The margin is ``self.margin_mm`` converted to whole voxels per axis
        using ``img``'s spacing, and the box is clamped to the image bounds.

        Raises:
            ValueError: if ``img`` and ``mask_img`` differ in size (the mask's
                voxel indices are applied to ``img`` directly).
            RuntimeError: if the mask has no foreground voxel.
        """
        if img.GetSize() != mask_img.GetSize():
            raise ValueError(
                f"Image size {img.GetSize()} does not match mask size {mask_img.GetSize()}"
            )

        # numpy arrays from SimpleITK are indexed (z, y, x); reverse to (x, y, z).
        fg = np.nonzero(sitk.GetArrayFromImage(mask_img) > 0)
        if fg[0].size == 0:
            raise RuntimeError("Mask is empty!")

        # SimpleITK's SetIndex/SetSize need plain Python ints, not numpy ints.
        lo_idx = [int(axis.min()) for axis in reversed(fg)]
        hi_idx = [int(axis.max()) for axis in reversed(fg)]
        img_size = img.GetSize()
        spacing = img.GetSpacing()

        start, size = [], []
        for d in range(3):
            margin = int(self.margin_mm / spacing[d])
            first = max(0, lo_idx[d] - margin)
            stop = min(img_size[d], hi_idx[d] + margin + 1)  # exclusive end
            start.append(first)
            size.append(stop - first)

        extractor = sitk.ExtractImageFilter()
        extractor.SetIndex(start)
        extractor.SetSize(size)
        return extractor.Execute(img)

    # ---------------------------------------------------------
    # Convert to RAS using nibabel.as_closest_canonical
    # (uses the global DATA_TEMP folder)
    # ---------------------------------------------------------
    def convert_to_ras(self, img, subject_id: str, struct_name: str):
        """Reorder ``img``'s voxel axes to nibabel's closest canonical (RAS+).

        The image is round-tripped through NIfTI files in a per-structure
        folder ``DATA_TEMP/<subject_id>/<struct_name>``. That folder is removed
        afterwards even if the conversion fails; a cleanup failure is logged,
        not raised.
        """
        tmp_dir = DATA_TEMP / subject_id / struct_name
        tmp_dir.mkdir(parents=True, exist_ok=True)
        # Separate input/output files so nibabel never overwrites the file it
        # is still proxying data from.
        src_path = tmp_dir / "input_tmp.nii.gz"
        ras_path = tmp_dir / "ras_tmp.nii.gz"

        try:
            sitk.WriteImage(img, str(src_path))
            ras_img = nib.as_closest_canonical(nib.load(str(src_path)))
            nib.save(ras_img, str(ras_path))
            # ReadImage loads the voxels into memory, so the file can be removed.
            return sitk.ReadImage(str(ras_path))
        finally:
            try:
                shutil.rmtree(tmp_dir)
            except OSError as exc:
                logger.warning(f"Could not remove temp folder {tmp_dir}: {exc}")

    # ---------------------------------------------------------
    # Resample to target spacing
    # ---------------------------------------------------------
    def resample(self, img):
        """Resample ``img`` to ``self.target_spacing`` with linear interpolation.

        Origin and direction are kept. The voxel count per axis is
        ``size * spacing / target_spacing``, rounded to the nearest integer and
        at least 1. Samples falling outside the input get the filter's default
        value (0).
        """
        spacing = img.GetSpacing()
        size = img.GetSize()

        # Round (not truncate) so the physical extent is preserved as closely
        # as possible, and never produce an empty axis.
        new_size = [
            max(1, int(round(size[i] * (spacing[i] / self.target_spacing[i]))))
            for i in range(3)
        ]

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetOutputSpacing(self.target_spacing)
        resampler.SetOutputOrigin(img.GetOrigin())
        resampler.SetOutputDirection(img.GetDirection())
        resampler.SetSize(new_size)

        return resampler.Execute(img)

    def resize_to_fixed_size(self, img: sitk.Image, target_size=(128, 128, 128), is_label=False):
        """Resample ``img`` onto a ``target_size`` voxel grid covering the same extent.

        Spacing is scaled by ``original_size / target_size`` per axis, and
        origin/direction are kept. Labels use nearest-neighbour interpolation
        and intensities use linear interpolation.
        """
        original_size = img.GetSize()
        original_spacing = img.GetSpacing()
        # Plain ints: SimpleITK's SetSize rejects numpy integer types.
        new_size = [int(s) for s in target_size]
        new_spacing = [
            original_spacing[i] * (original_size[i] / new_size[i]) for i in range(3)
        ]

        resampler = sitk.ResampleImageFilter()
        resampler.SetSize(new_size)
        resampler.SetOutputSpacing(new_spacing)
        resampler.SetOutputOrigin(img.GetOrigin())
        resampler.SetOutputDirection(img.GetDirection())
        resampler.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear)

        return resampler.Execute(img)

    # ---------------------------------------------------------
    # Intensity normalization
    # ---------------------------------------------------------
    def normalize(self, img, clip_min=-1000, clip_max=2000):
        """Clip intensities to [clip_min, clip_max] and z-score normalise.

        The default bounds presumably correspond to Hounsfield units for CT.
        The statistics are computed over the whole volume, and geometry is
        copied from ``img``.
        """
        arr = np.clip(sitk.GetArrayFromImage(img), clip_min, clip_max)
        arr = (arr - arr.mean()) / (arr.std() + 1e-5)

        out = sitk.GetImageFromArray(arr)
        out.CopyInformation(img)
        return out

    # ---------------------------------------------------------
    # Process ONE structure
    # ---------------------------------------------------------
    def process_structure(self, subject_id: str, ct_path: Path, mask_path: Path, ct_img=None):
        """Preprocess one structure and write it to ``DATA_RAS_DINO/<subject_id>/``.

        Args:
            subject_id: subject folder name, used for output and temp paths.
            ct_path: CT volume. It is only read when ``ct_img`` is None.
            mask_path: binary mask ``<structure>.nii.gz``. It defines the ROI
                and the output file name.
            ct_img: optional already-loaded CT. Pass it to avoid re-reading the
                same CT for every structure of a subject.

        Returns:
            Output path relative to ``DATA_RAS_DINO``.
        """
        logger.info(f"[{subject_id}] Processing {mask_path.name}")
        mask_name = mask_path.name.replace(".nii.gz", "")

        if ct_img is None:
            ct_img = self.load_image(ct_path)
        mask_img = self.load_image(mask_path)

        roi_img = self.crop_roi(ct_img, mask_img)
        ras_img = self.convert_to_ras(roi_img, subject_id, mask_name)

        vol = self.resample(ras_img)
        if self.fixed_size is not None:
            vol = self.resize_to_fixed_size(vol, target_size=self.fixed_size, is_label=False)
        final_img = self.normalize(vol)

        subj_out_dir = DATA_RAS_DINO / subject_id
        subj_out_dir.mkdir(exist_ok=True, parents=True)

        out_path = subj_out_dir / f"{mask_name}.nii.gz"
        sitk.WriteImage(final_img, str(out_path))

        logger.info(f"[{subject_id}] Saved: {out_path}")
        return out_path.relative_to(DATA_RAS_DINO)

    # ---------------------------------------------------------
    # Process ONE subject
    # ---------------------------------------------------------
    def process_subject(self, subject_dir: Path):
        """Process every available structure of one subject.

        A structure whose mask is missing is skipped silently. One that raises
        RuntimeError/ValueError during processing (e.g. an empty mask) is
        skipped with a warning.

        Returns:
            List of (left, right) output paths relative to ``DATA_RAS_DINO``,
            one per scapula/humerus side pair where both sides were processed.
            The list is empty when the CT or the segmentations folder is missing.
        """
        subject_id = subject_dir.name

        ct_path = subject_dir / "ct.nii.gz"
        seg_dir = subject_dir / "segmentations"

        if not ct_path.exists() or not seg_dir.exists():
            logger.warning(f"[{subject_id}] Missing CT or segmentations")
            return []

        ct_img = None  # loaded lazily, once, on the first existing mask
        processed = {}

        for struct in STRUCTURES:
            mask_path = seg_dir / f"{struct}.nii.gz"
            if not mask_path.exists():
                continue
            if ct_img is None:
                ct_img = self.load_image(ct_path)
            try:
                processed[struct] = self.process_structure(
                    subject_id, ct_path, mask_path, ct_img=ct_img
                )
            except (RuntimeError, ValueError) as exc:
                logger.warning(f"[{subject_id}] Skipping {struct}: {exc}")

        # generate left -> right pairs
        side_pairs = (
            ("scapula_left", "scapula_right"),
            ("humerus_left", "humerus_right"),
        )
        return [
            (processed[left], processed[right])
            for left, right in side_pairs
            if left in processed and right in processed
        ]

    # ---------------------------------------------------------
    # Batch processing
    # ---------------------------------------------------------
    def batch(self):
        """Process every ``DATA_COMPLETE/s*`` subject and write the CSV indexes.

        Writes ``pairs_dinoreg.csv`` with one ``left,right`` line per pair,
        using POSIX paths relative to ``DATA_RAS_DINO``. Also writes
        ``structures_dinoreg.csv``, which contains the single line ``1``; its
        meaning is defined by the downstream DINO-Reg tooling (TODO confirm).
        """
        logger.info("Starting DINO-Reg preprocessing ...")

        if not DATA_COMPLETE.is_dir():
            logger.warning(f"Input folder does not exist: {DATA_COMPLETE}")

        subject_dirs = sorted(p for p in DATA_COMPLETE.glob("s*") if p.is_dir())
        logger.info(f"Found {len(subject_dirs)} subjects")

        all_pairs = []
        for subj in subject_dirs:
            all_pairs.extend(self.process_subject(subj))

        pairs_csv = CSV_DIR / "pairs_dinoreg.csv"
        with pairs_csv.open("w") as f:
            for m, r in all_pairs:
                f.write(f"{m.as_posix()},{r.as_posix()}\n")

        logger.info(f"Written pairs_dinoreg.csv → {pairs_csv}")

        structures_csv = CSV_DIR / "structures_dinoreg.csv"
        with structures_csv.open("w") as f:
            f.write("1\n")

        logger.info(f"Written structures_dinoreg.csv → {structures_csv}")


# -------------------------------------------------------------
# Main
# -------------------------------------------------------------
# No `if __name__ == "__main__":` guard: the batch is started from the next notebook cell.



In [4]:
# Run the whole batch with the default settings spelled out:
# 3 mm isotropic output spacing and a 50 mm margin around each mask.
pre = DinoRegPreprocessor(target_spacing=(3.0, 3.0, 3.0), margin_mm=50)
pre.batch()

2025-12-14 15:29:47,808 - INFO - Starting DINO-Reg preprocessing ...
2025-12-14 15:29:47,811 - INFO - Found 4 subjects
2025-12-14 15:29:47,815 - INFO - [s0970] Processing scapula_left.nii.gz
2025-12-14 15:29:49,012 - INFO - [s0970] Saved: C:\Users\lenovo\Documents\registration_project\data\ras_3mm_dinoreg\s0970\scapula_left.nii.gz
2025-12-14 15:29:49,029 - INFO - [s0970] Processing scapula_right.nii.gz
2025-12-14 15:29:50,169 - INFO - [s0970] Saved: C:\Users\lenovo\Documents\registration_project\data\ras_3mm_dinoreg\s0970\scapula_right.nii.gz
2025-12-14 15:29:50,183 - INFO - [s0970] Processing humerus_left.nii.gz
2025-12-14 15:29:51,437 - INFO - [s0970] Saved: C:\Users\lenovo\Documents\registration_project\data\ras_3mm_dinoreg\s0970\humerus_left.nii.gz
2025-12-14 15:29:51,460 - INFO - [s0970] Processing humerus_right.nii.gz
2025-12-14 15:29:52,635 - INFO - [s0970] Saved: C:\Users\lenovo\Documents\registration_project\data\ras_3mm_dinoreg\s0970\humerus_right.nii.gz
2025-12-14 15:29:52,6

In [None]:
# Manual check: is the cropped ROI complete?
# margin = 70 mm: the scapula ROI contains both the scapula and the humerus.
# margin = 50 mm: the scapula ROI still contains both the scapula and the humerus.