In [None]:
# Notebook-level path configuration.  Most of these names mirror options that
# main() parses (--fold_file, --raw_vols_folder, --raw_vols_metadata,
# --raw_segs_folder, --segs_csv).
# NOTE(review): no `processed_folder` is defined here although main() requires
# --processed_folder, and nothing visible passes these values to main().

# Fold CSV; main() reads its "patient_id" and "ct_path" columns.
fold_file = "./interpretability/interpretability_PDA_mixed5/fold_2_log.csv"
# Root folder of the raw CT series and the metadata CSV describing them.
raw_vols_folder = "./data/raw/CPTAC_PDA_93_surv/cptacpda_93/CPTAC-PDA"
raw_vols_metadata = "./data/raw/CPTAC_PDA_93_surv/cptacpda_93/metadata.csv"
# Root folder of the segmentation objects.
raw_segs_folder = "./data/raw/CPTAC_PDA_93_surv/Segmentations/CPTAC-PDA"
# NOTE(review): name looks like a typo for "raw_segs_metadata"; main() has no
# matching option — confirm whether this value is used anywhere.
raw_segd_metadata = "./data/raw/CPTAC_PDA_93_surv/Segmentations/metadata.csv"
# Annotation report; main() keeps only rows whose "Annotation Type" is "Segmentation".
segs_csv = "./data/metadata_annotations/Metadata_Report_CPTAC-PDA_2023_07_14.csv"

In [1]:
#!/usr/bin/env python3
"""
Find which raw DICOM volumes produced a given set of pre-processed 224×224×66
arrays, by re-running the exact same preprocessing pipeline and comparing the
results voxel-wise.

For every (patient, index) entry in the fold CSV it logs **all** raw CT series
whose pre-processed array matches the reference within ε, and writes the
mapping to an output CSV.

Assumptions confirmed by the user
---------------------------------
• Raw volumes live in:  <raw_vols_folder>/<Patient>/<SeriesUID>/…/*.dcm
• Pre-processed arrays live in: <processed_folder>/<Patient>/<idx>.npy
• Target isotropic volume size = (224, 224, 224)
• Final depth  = 66 slices   (central padding/trim if needed)
• oversampling = False
"""

import argparse
import csv
import logging
import os
import re
import sys
import time
from datetime import datetime
from multiprocessing import Pool
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd

# --------------------------------------------------------------------------- #
#  Your existing utility functions (import exactly as in the original script)
# --------------------------------------------------------------------------- #
sys.path.insert(0, "../")  # Change if these utils live elsewhere
from util.data_util import (  # noqa: E402
    get_occupied_slices,
    load_single_volume,
    preprocess,
    remap_occupied_slices,
)

# --------------------------------------------------------------------------- #
#  Constants
# --------------------------------------------------------------------------- #
# Shape handed to preprocess() as the target for every raw volume; its third
# entry is also the upper bound used when padding slice indices.
TARGET_SHAPE = [224, 224, 224]
# Number of slice indices enforce_depth() produces (selected along the last axis).
FIX_DEPTH = 66


# --------------------------------------------------------------------------- #
#  Logging
# --------------------------------------------------------------------------- #
def setup_logging(log_dir: str = "./logs") -> str:
    """Send root-logger output to a fresh timestamped file and to the console.

    The root logger is set to DEBUG.  The file handler records everything
    (DEBUG and above); the console handler only shows INFO and above.

    Handlers are attached explicitly rather than through
    ``logging.basicConfig``: ``basicConfig`` silently does nothing when the
    root logger already has handlers (a second call in the same notebook
    kernel, or a host that pre-configures logging), which left the returned
    log file uncreated.  Handlers installed by a previous call of this
    function are removed first, so repeated calls neither duplicate console
    lines nor keep writing to the old file.

    Parameters
    ----------
    log_dir : str
        Directory for the log file; created if it does not exist.

    Returns
    -------
    str
        Path of the log file that now receives the output.
    """
    os.makedirs(log_dir, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    logfile = Path(log_dir) / f"match_log_{ts}.log"
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    root = logging.getLogger()
    # Drop only the handlers this function installed earlier; handlers that
    # belong to the host environment are left alone.
    for stale in [h for h in root.handlers if getattr(h, "_match_log_owned", False)]:
        root.removeHandler(stale)
        stale.close()
    root.setLevel(logging.DEBUG)

    # UTF-8 so that non-ASCII characters in log messages are always writable.
    file_handler = logging.FileHandler(logfile, mode="w", encoding="utf-8")
    file_handler.setLevel(logging.DEBUG)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)

    for handler in (file_handler, console):
        handler.setFormatter(formatter)
        handler._match_log_owned = True  # marker used by the cleanup loop above
        root.addHandler(handler)

    logging.info("Logging initialised.")
    return str(logfile)


# --------------------------------------------------------------------------- #
#  Helper utilities
# --------------------------------------------------------------------------- #
def central_trim(indices: List[int], target_len: int) -> List[int]:
    """Return `target_len` consecutive indices centred on the span of `indices`.

    When `indices` already has `target_len` entries or fewer it is returned
    unchanged (same object).  Otherwise the centre is the integer midpoint of
    the first and last entries, and the result is the consecutive run of
    `target_len` integers that starts `target_len // 2` below that centre.
    """
    if len(indices) > target_len:
        centre = (indices[0] + indices[-1]) // 2
        first = centre - target_len // 2
        return [first + offset for offset in range(target_len)]
    return indices


def enforce_depth(indices: List[int]) -> List[int]:
    """Pad or trim `indices` towards a length of FIX_DEPTH.

    * Exactly FIX_DEPTH entries: returned unchanged.
    * More than FIX_DEPTH entries: delegated to `central_trim`.
    * Fewer: neighbouring indices are added alternately below the first entry
      (never going under 0) and above the last entry (never reaching
      ``TARGET_SHAPE[2]``) until FIX_DEPTH entries exist.

    If neither side can grow any further (for example sparse input that
    already touches both 0 and ``TARGET_SHAPE[2] - 1``), padding stops and the
    list is returned shorter than FIX_DEPTH.  Previously this case looped
    forever.

    Parameters
    ----------
    indices : list of int
        Sorted, non-empty slice indices.

    Returns
    -------
    list of int
        The adjusted indices; the input list itself is never modified.

    Raises
    ------
    IndexError
        If `indices` is empty (there is nothing to pad around).
    """
    if len(indices) == FIX_DEPTH:
        return indices
    if len(indices) > FIX_DEPTH:
        return central_trim(indices, FIX_DEPTH)

    padded = indices.copy()
    left = indices[0] - 1
    right = indices[-1] + 1
    while len(padded) < FIX_DEPTH:
        grew = False
        if left >= 0:
            padded.insert(0, left)
            left -= 1
            grew = True
        if len(padded) < FIX_DEPTH and right < TARGET_SHAPE[2]:
            padded.append(right)
            right += 1
            grew = True
        if not grew:
            # Both boundaries reached: no further index is available, so stop
            # instead of spinning forever.
            break
    return padded


def normalise_path(p: str) -> str:
    """Swap back-slashes in `p` for ‘/’ on non-Windows hosts; on Windows return `p` as is."""
    if os.name == "nt":
        return p
    return p.replace("\\", "/")


# --------------------------------------------------------------------------- #
#  Worker
# --------------------------------------------------------------------------- #
def worker(task) -> List[Tuple]:
    """
    Rebuild the pre-processed array for each segmentation row of one patient
    and compare it with a single reference array.

    For every row of ``seg_df`` the referenced CT series is looked up in the
    raw metadata, loaded, pre-processed, sliced to the occupied depth range
    and compared element-wise with the reference.  Any candidate that cannot
    be evaluated (unknown series, missing CT folder, missing or empty
    segmentation folder, empty volume, no occupied slices, shape mismatch) is
    logged at DEBUG level and skipped, so one bad row never aborts the job.

    Parameters
    ----------
    task : dict with keys
        patient_id, idx, ref_np_path, seg_df (filtered for this patient),
        raw_meta_df, args

    Returns
    -------
    list of (patient, idx, raw_series_uid, raw_folder, delta)
        One tuple per successful match.  Empty when the reference array file
        does not exist or nothing matches within ``args.epsilon``.
    """
    patient = task["patient_id"]
    idx = task["idx"]
    ref_np = task["ref_np_path"]
    seg_df = task["seg_df"]
    raw_meta = task["raw_meta_df"]
    args = task["args"]

    out_tuples = []

    try:
        ref_arr = np.load(ref_np)
    except FileNotFoundError:
        logging.warning(f"[{patient} | {idx}] reference array not found: {ref_np}")
        return out_tuples

    # =======================================================================
    #  For each segmentation object → find its referenced CT series
    # =======================================================================
    for _, seg_row in seg_df.iterrows():
        # Keep only the part of the location after the last ".\" prefix.
        seg_folder_rel = normalise_path(seg_row["File Location"].split(".\\")[-1])
        seg_path = Path(args.raw_segs_folder) / seg_folder_rel
        seg_path = seg_path.resolve()

        ct_series_uid = seg_row["ReferencedSeriesInstanceUID"].strip()

        # -------------------------------------------------------------------
        #  Locate the raw CT folder from metadata
        # -------------------------------------------------------------------
        ct_meta_row = raw_meta[raw_meta["Series UID"] == ct_series_uid]
        if ct_meta_row.empty:
            logging.debug(
                f"[{patient} | {idx}] Series UID {ct_series_uid} not in raw metadata"
            )
            continue

        ct_folder_rel = normalise_path(ct_meta_row.iloc[0]["File Location"])
        # The first two path components of the metadata location are dropped
        # before joining onto the raw-volume root.
        ct_path = Path(args.raw_vols_folder) / Path(*Path(ct_folder_rel).parts[2:])
        ct_path = ct_path.resolve()

        if not ct_path.exists():
            # Fallback for possible “-NA” suffix stripping
            ct_path = Path(str(ct_path).replace("-NA", ""))
            if not ct_path.exists():
                logging.debug(
                    f"[{patient} | {idx}] CT folder missing: {ct_path} (after -NA fix)"
                )
                continue

        # -------------------------------------------------------------------
        #  Re-run preprocessing (exact replica of original pipeline)
        # -------------------------------------------------------------------
        vol, dim, dcm_slices, direction = load_single_volume(str(ct_path))
        if vol is None:
            logging.debug(f"[{patient} | {idx}] Empty volume at {ct_path}")
            continue

        if direction == "sagittal":
            vol = vol.transpose(1, 0, 2)
        elif direction == "coronal":
            vol = vol.transpose(2, 0, 1)

        # A missing or empty segmentation folder is skipped like every other
        # unusable candidate instead of raising and taking down the pool.
        try:
            seg_files = os.listdir(seg_path)
        except OSError as exc:
            logging.debug(
                f"[{patient} | {idx}] Segmentation folder unreadable: {seg_path} ({exc})"
            )
            continue
        if not seg_files:
            logging.debug(f"[{patient} | {idx}] Segmentation folder empty: {seg_path}")
            continue

        # The first directory entry is used as the segmentation file.
        occupied = get_occupied_slices(str(seg_path / seg_files[0]),
                                       dcm_slices, direction)
        if not occupied:
            logging.debug(f"[{patient} | {idx}] No occupied slices for {ct_path}")
            continue

        vol, zoom = preprocess(vol, TARGET_SHAPE)
        occupied = remap_occupied_slices(occupied, zoom[0])
        occupied = enforce_depth(sorted(occupied))

        new_arr = vol[:, :, occupied]  # → (224,224,66)
        if new_arr.shape != ref_arr.shape:
            logging.debug(
                f"[{patient} | {idx}] shape mismatch {new_arr.shape} vs "
                f"{ref_arr.shape}  series {ct_series_uid}"
            )
            continue

        # Largest absolute element-wise difference, computed in float32.
        delta = float(np.abs(new_arr.astype(np.float32) -
                             ref_arr.astype(np.float32)).max())

        if delta <= args.epsilon:
            logging.info(
                f"[{patient} | {idx}] MATCH  "
                f"{ct_series_uid}  →  Δ={delta:.2e},  path={ct_path}"
            )
            out_tuples.append((patient, idx, ct_series_uid, str(ct_path), delta))
        else:
            logging.debug(
                f"[{patient} | {idx}] no match (Δ={delta:.3e})  series {ct_series_uid}"
            )

    return out_tuples


# --------------------------------------------------------------------------- #
#  Main
# --------------------------------------------------------------------------- #
def main(argv=None):
    """Command-line driver: build one job per fold row, run them in a pool, save matches.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]`` (argparse's behaviour
        when ``None`` is passed).

    Steps: parse options, start logging, read the fold / raw-metadata /
    annotation CSVs, create a job for every fold row whose ``ct_path`` ends in
    ``<digits>_embeddings.npy`` and whose patient has segmentation rows, fan
    the jobs out to ``worker`` through a process pool, and write all returned
    tuples to ``--out_csv`` (nothing is written when there are no matches).
    """
    cli = argparse.ArgumentParser(description="Match raw CT volumes to "
                                              "pre-processed 224×224×66 arrays")
    # Mandatory path options, registered in the same order as before.
    for flag in (
        "--fold_file",
        "--processed_folder",
        "--raw_vols_folder",
        "--raw_vols_metadata",
        "--raw_segs_folder",
        "--segs_csv",
    ):
        cli.add_argument(flag, required=True)

    cli.add_argument("--workers", type=int, default=4)
    cli.add_argument("--epsilon", type=float, default=1e-6)
    cli.add_argument("--out_csv", default="matched_raw_volumes.csv")

    args = cli.parse_args(argv)
    log_file = setup_logging()

    started = time.time()

    # ------------------------------------------------------------------ #
    #  Read the three input tables
    # ------------------------------------------------------------------ #
    fold_df = pd.read_csv(args.fold_file)
    raw_meta_df = pd.read_csv(args.raw_vols_metadata)
    segs_df = pd.read_csv(args.segs_csv)

    # Keep only the annotation rows that describe segmentations.
    is_segmentation = segs_df["Annotation Type"] == "Segmentation"
    segs_df = segs_df[is_segmentation]

    logging.info(f"Fold rows: {len(fold_df)}")
    logging.info(f"Segmentation objects: {len(segs_df)}")

    # ------------------------------------------------------------------ #
    #  One job per usable fold row
    # ------------------------------------------------------------------ #
    # The array index is the run of digits right before "_embeddings.npy".
    index_pattern = re.compile(r"(\d+)_embeddings\.npy$")
    jobs = []
    for _, fold_row in fold_df.iterrows():
        patient = fold_row["patient_id"]
        found = index_pattern.search(fold_row["ct_path"])
        if found is None:
            logging.warning(f"Could not parse index from ct_path: {fold_row['ct_path']}")
            continue
        idx = int(found.group(1))

        reference_file = Path(args.processed_folder) / patient / f"{idx}.npy"
        rows_for_patient = segs_df[segs_df["PatientID"] == patient]

        if rows_for_patient.empty:
            logging.debug(f"No segmentation rows for patient {patient}")
            continue

        jobs.append({
            "patient_id": patient,
            "idx": idx,
            "ref_np_path": str(reference_file),
            "seg_df": rows_for_patient.copy(),
            "raw_meta_df": raw_meta_df,
            "args": args,
        })

    # ------------------------------------------------------------------ #
    #  Fan out to the process pool
    # ------------------------------------------------------------------ #
    logging.info(f"Launching pool with {args.workers} workers …")
    results = []
    with Pool(args.workers) as pool:
        # Results arrive in completion order, not job order.
        for job_matches in pool.imap_unordered(worker, jobs):
            results += job_matches

    # ------------------------------------------------------------------ #
    #  Save the matches
    # ------------------------------------------------------------------ #
    if not results:
        logging.warning("No matches were found!")
    else:
        header = ["patient_id", "idx", "raw_series_uid", "raw_folder", "delta"]
        with open(args.out_csv, "w", newline="") as handle:
            writer = csv.writer(handle)
            writer.writerow(header)
            writer.writerows(results)
        logging.info(f"Wrote {len(results)} matches → {args.out_csv}")

    elapsed = time.time() - started
    logging.info(f"Done in {time.strftime('%H:%M:%S', time.gmtime(elapsed))}")
    logging.info(f"Full log: {log_file}")


# Script entry point.  main() is called without an explicit argv, so argparse
# reads the process's own command line.  When this cell runs inside a notebook
# kernel that command line belongs to ipykernel_launcher, the required options
# are absent, and argparse exits with status 2 (see the recorded output
# below).  In a notebook, pass the options explicitly instead, e.g.
# main(["--fold_file", fold_file, "--processed_folder", ..., ...]).
if __name__ == "__main__":
    main()


usage: ipykernel_launcher.py [-h] --fold_file FOLD_FILE --processed_folder
                             PROCESSED_FOLDER --raw_vols_folder
                             RAW_VOLS_FOLDER --raw_vols_metadata
                             RAW_VOLS_METADATA --raw_segs_folder
                             RAW_SEGS_FOLDER --segs_csv SEGS_CSV
                             [--workers WORKERS] [--epsilon EPSILON]
                             [--out_csv OUT_CSV]
ipykernel_launcher.py: error: the following arguments are required: --processed_folder, --raw_vols_folder, --raw_vols_metadata, --raw_segs_folder, --segs_csv


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
