# In [None]:
import os, sys, csv, time, traceback, re
from pathlib import Path
from typing import Tuple, Dict, List, Optional
import numpy as np
import pandas as pd
import SimpleITK as sitk
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Paths
# NOTE(review): these look like placeholder paths (rooted at the current
# drive on Windows) - set them before running.
# ---------------------------------------------------------------------------
# Input roots; both are searched recursively (rglob) for .nii / .nii.gz files.
IMAGES_DIR    = r"\image"
MASKS_DIR     = r"\mask"
# One output mask per kept instance, named "<case>__v<k><ext>".
OUT_MASKS_DIR = r"\split_masks"
# One sagittal QC PNG per kept instance, named "<case>__v<k>_sag.png".
QC_OUTPUT_DIR = r"\mask_QC_png"

# CSV with one row per written instance or per skipped/failed case.
LOG_CSV   = r"\mask_vertebra_split_log.csv"
# Truncated at the start of each run, then appended to for every error.
ERROR_LOG = r"\mask_vertebra_split_error.log"

# Bridge-thickness split (tried first).
# Voxels whose inside distance to the mask boundary is >= half of
# BRIDGE_MAX_THICKNESS_MM (mm) form the "thick" cores that seed instances.
ENABLE_BRIDGE_THICKNESS_SPLIT = True
BRIDGE_MAX_THICKNESS_MM = 2.0
# Cores are regrown only into voxels whose inside distance is >= this (mm).
BRIDGE_GROW_BACK_MM     = 1.2
# Maximum number of one-voxel growth steps per core.
BRIDGE_MAX_GEO_ITERS    = 12

# Watershed fallback, used when bridge splitting is disabled or declines.
# WS_GAUSS_SIGMA_MM: Gaussian sigma (mm) applied to the distance map; 0
# disables smoothing.
ENABLE_WATERSHED_SPLIT = True
WS_GAUSS_SIGMA_MM      = 1.0
# h-maxima height (mm) used to suppress shallow distance peaks before
# seeding; 0 uses the raw regional maxima.
WS_HMAX_MM             = 2.0
# Passed to MorphologicalWatershedFromMarkers (draw separating lines or not).
WS_MARK_WATERSHED_LINE = False
# Connectivity flag passed to the watershed-related filters.
WS_FULLY_CONNECTED     = True

# Instance filters applied in main(); a component failing any test is dropped.
MIN_VOXELS = 300
# Largest per-slice voxel count, slicing along numpy axis 0 (z for a 3D volume).
MIN_LARGEST_SLICE_AREA = 150
# Minimum bounding-box size in voxels, ordered (x, y, z).
MIN_BBOX_VOX = (5, 5, 2)
# Minimum physical volume in mm^3; 0 disables the check.
MIN_PHYS_VOL_MM3 = 0.0
# Connectivity for plain connected components (used when neither split
# method is enabled).
FULLY_CONNECTED = False
# Upper bound on instances written per case.
MAX_COMPONENTS_PER_CASE = 1000

# Pre-processing of the binarized mask.
# Ball opening radius in voxels (x, y, z); all <= 0 disables the opening.
MORPH_OPEN_RADIUS = (1, 1, 0)
# "fillholes" -> sitk.BinaryFillhole; any other value -> voting hole filling
# configured by HOLE_RADIUS and MAJORITY_THRESHOLD.
HOLE_FILL_MODE = "fillholes"
HOLE_RADIUS = (1, 1, 1)
MAJORITY_THRESHOLD = 1

# Sagittal QC rendering.
ENABLE_QC = True
# NOTE: when not None this truncates the list of cases that are processed
# at all (not only the QC output).
QC_MAX_CASES = None
QC_DPI = 160
# Transpose the extracted (z, y) sagittal plane before display.
SAGITTAL_TRANSPOSE = True
# Number of np.rot90 turns; only 1, 2 or 3 cause a rotation.
SAGITTAL_ROTATE_K = 1
# Vertical / horizontal flip switches for the QC slice.
# NOTE(review): confirm the orientation helper honours these flags.
SAGITTAL_FLIPUD   = False
SAGITTAL_FLIPLR   = False
# Intensity window (level, width) applied to the image underlay.
QC_WINDOW_LEVEL = 500.0
QC_WINDOW_WIDTH = 2000.0

# Process masks that have no matching image (QC then shows the mask only).
ALLOW_MASK_ONLY = True
# NIfTI extensions of interest.
# NOTE(review): check whether the file-discovery helpers consult this tuple.
VALID_EXTS = (".nii", ".nii.gz")
# Trailing "_" + 4 digits (e.g. "_0000") stripped when matching case names.
CHANNEL_REGEX = re.compile(r'_(\d{4})$')

# Dilation (mm) applied to each kept instance before it is written; <= 0
# disables it. Per-axis radius is ceil(mm / spacing), at least one voxel.
FINAL_DILATE_MM = 0.8

def ts() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    from time import strftime
    return strftime("%Y-%m-%d %H:%M:%S")

def log_error(error_log_path: Path, key: str, err: str, detail: str = ""):
    """Append one framed error record to ``error_log_path``.

    Parent directories are created if needed. The record contains the
    timestamp, the object key, the error text and, when non-empty, ``detail``
    (typically a formatted traceback), surrounded by 80-character rules.
    """
    rule = "=" * 80
    error_log_path.parent.mkdir(parents=True, exist_ok=True)
    chunks = ["\n" + rule + "\n", f"Time: {ts()}\nObject: {key}\nError: {err}\n"]
    if detail:
        chunks.append("Details:\n" + detail + "\n")
    chunks.append(rule + "\n")
    with open(error_log_path, "a", encoding="utf-8") as handle:
        handle.write("".join(chunks))

def split_name_ext_nii(p: Path) -> Tuple[str, str]:
    """Split a file name into (stem, extension), treating ``.nii.gz`` as one extension.

    NIfTI extensions are matched case-insensitively and returned lower-case
    (``".nii.gz"`` / ``".nii"``); any other name falls back to
    ``(p.stem, p.suffix)``.
    """
    name = p.name
    lowered = name.lower()
    # Check the compound extension first so ".nii.gz" wins over ".gz".
    for ext in (".nii.gz", ".nii"):
        if lowered.endswith(ext):
            return name[: -len(ext)], ext
    return p.stem, p.suffix

def is_nii_file(p: Path) -> bool:
    """Return True if the file name ends in ``.nii`` or ``.nii.gz`` (any case)."""
    lowered = p.name.lower()
    return any(lowered.endswith(ext) for ext in (".nii", ".nii.gz"))

def norm_id(name_no_ext: str) -> str:
    """Normalise a file stem into the case key shared by images and masks.

    Strips, in order: a trailing ``_`` + 4-digit channel suffix
    (CHANNEL_REGEX, e.g. ``_0000``); a trailing ``_seg`` / ``_mask`` /
    ``_label`` / ``_labels`` / ``_gt`` tag (case-insensitive); and a trailing
    ``.<digits>``. The result is lower-cased and whitespace-trimmed.
    """
    key = CHANNEL_REGEX.sub("", name_no_ext)
    suffix_rules = (
        (r'(?:_seg|_mask|_label|_labels?|_gt)$', re.IGNORECASE),
        (r'\.\d+$', 0),
    )
    for pattern, flags in suffix_rules:
        key = re.sub(pattern, '', key, flags=flags)
    return key.lower().strip()

def binarize_mask(msk: sitk.Image) -> sitk.Image:
    """Return a UInt8 image that is 1 wherever ``msk`` is non-zero, else 0."""
    nonzero = sitk.NotEqual(msk, 0)
    return sitk.Cast(nonzero, sitk.sitkUInt8)

def fill_holes(binary: sitk.Image) -> sitk.Image:
    """Fill holes in a 0/1 binary mask according to HOLE_FILL_MODE.

    ``"fillholes"`` (case-insensitive) uses ``sitk.BinaryFillhole`` with full
    connectivity. Any other mode uses voting-based hole filling with
    HOLE_RADIUS and MAJORITY_THRESHOLD (foreground 1, background 0).
    """
    mode = HOLE_FILL_MODE.lower()
    if mode != "fillholes":
        voter = sitk.VotingBinaryHoleFillingImageFilter()
        voter.SetRadius(HOLE_RADIUS)
        voter.SetMajorityThreshold(MAJORITY_THRESHOLD)
        voter.SetBackgroundValue(0)
        voter.SetForegroundValue(1)
        return voter.Execute(binary)
    return sitk.BinaryFillhole(binary, True)

def morph_open(binary: sitk.Image, radius_xyz: Tuple[int,int,int]) -> sitk.Image:
    """Binary opening with a ball kernel of per-axis voxel radius ``radius_xyz``.

    Returns ``binary`` unchanged when no radius component is positive.
    """
    rx, ry, rz = radius_xyz
    if rx > 0 or ry > 0 or rz > 0:
        return sitk.BinaryMorphologicalOpening(binary, radius_xyz, sitk.sitkBall)
    return binary

def connected_components(binary: sitk.Image, fully_connected: bool=False) -> sitk.Image:
    """Label connected foreground regions, numbered largest-first.

    Args:
        binary: binary image (non-zero voxels are foreground).
        fully_connected: if True, diagonal neighbours also connect voxels
            (26-connectivity in 3D); otherwise only face neighbours (6).

    Returns:
        Label image with consecutive labels 1..N ordered by decreasing size;
        0 is background.
    """
    cc_filter = sitk.ConnectedComponentImageFilter()
    cc_filter.SetFullyConnected(fully_connected)
    labels = cc_filter.Execute(binary)
    # The 2nd positional parameter of sitk.RelabelComponent is
    # minimumObjectSize, not sortByObjectSize, so the former
    # RelabelComponent(cc, True) meant "minimumObjectSize=1". Name the
    # argument explicitly; sorting by size is RelabelComponent's default.
    return sitk.RelabelComponent(labels, minimumObjectSize=0)

def label_stats(label_img: sitk.Image) -> Dict[int, Dict]:
    """Collect per-label geometry from a label image.

    Returns a dict keyed by label value. Each entry holds:
    ``voxels`` (pixel count), ``bbox_min_idx`` / ``bbox_max_idx`` (inclusive
    index corners in SimpleITK x, y, z order), ``bbox_size`` (extent in voxels)
    and ``centroid_phys`` (centroid in physical coordinates).
    """
    shape_filter = sitk.LabelShapeStatisticsImageFilter()
    shape_filter.Execute(label_img)
    result = {}
    for label in shape_filter.GetLabels():
        box = shape_filter.GetBoundingBox(label)
        start = (box[0], box[1], box[2])
        extent = (box[3], box[4], box[5])
        cx, cy, cz = shape_filter.GetCentroid(label)
        result[int(label)] = {
            "voxels": int(shape_filter.GetNumberOfPixels(label)),
            "bbox_min_idx": start,
            "bbox_size": extent,
            "bbox_max_idx": tuple(s + n - 1 for s, n in zip(start, extent)),
            "centroid_phys": (cx, cy, cz),
        }
    return result

def isolate_component(label_img: sitk.Image, label_value: int, ref_like: sitk.Image) -> sitk.Image:
    """Return a UInt8 0/1 mask of ``label_value`` carrying ``ref_like``'s geometry."""
    selected = sitk.Cast(sitk.Equal(label_img, label_value), sitk.sitkUInt8)
    selected.CopyInformation(ref_like)
    return selected

def binary_dilate_mm(binary: sitk.Image, mm: float) -> sitk.Image:
    """Dilate a binary mask by roughly ``mm`` millimetres along each axis.

    The per-axis ball radius is ``ceil(mm / spacing)`` voxels and never less
    than one voxel, so thick-slice axes can grow by a full slice. Returns the
    input unchanged when ``mm`` is None or not positive; otherwise a UInt8
    image with the input's geometry.
    """
    if mm is None or mm <= 0:
        return binary
    spacing = binary.GetSpacing()

    def _radius(step):
        vox = int(np.ceil(mm / float(step))) if step > 0 else 1
        return max(vox, 1)

    radius = (_radius(spacing[0]), _radius(spacing[1]), _radius(spacing[2]))
    grown = sitk.Cast(sitk.BinaryDilate(binary, radius, sitk.sitkBall), sitk.sitkUInt8)
    grown.CopyInformation(binary)
    return grown

def build_index_by_id(folder: Path) -> Dict[str, List[Path]]:
    """Map normalised case keys to every NIfTI file found under ``folder``.

    The folder is walked recursively. Files appear in each list in the order
    ``rglob`` yields them.
    """
    index: Dict[str, List[Path]] = {}
    nifti_files = (f for f in folder.rglob("*") if f.is_file() and is_nii_file(f))
    for f in nifti_files:
        stem, _ = split_name_ext_nii(f)
        index.setdefault(norm_id(stem), []).append(f)
    return index

def choose_best_image(mask_base: str, candidates: List[Path]) -> Path:
    """Pick the image file that best matches a mask stem.

    Preference order:
      1. a candidate whose stem equals ``<mask_base>_0000`` (case-insensitive);
      2. any candidate whose stem ends in ``_0000``;
      3. the first candidate.

    Raises IndexError if ``candidates`` is empty.
    """
    target = mask_base.lower() + "_0000"
    exact = next((c for c in candidates if split_name_ext_nii(c)[0].lower() == target), None)
    if exact is not None:
        return exact
    return next((c for c in candidates if split_name_ext_nii(c)[0].endswith("_0000")), candidates[0])

def list_pairs(images_dir: Path, masks_dir: Path, allow_mask_only: bool=True) -> List[Tuple[Optional[Path], Path]]:
    """Pair every NIfTI mask under ``masks_dir`` with its image, if any.

    Masks and images are matched by ``norm_id`` of their stems. Masks without
    an image are included as ``(None, mask)`` when ``allow_mask_only`` is True
    and dropped otherwise.
    """
    image_index = build_index_by_id(images_dir)
    result: List[Tuple[Optional[Path], Path]] = []
    for mask_path in masks_dir.rglob("*"):
        if not (mask_path.is_file() and is_nii_file(mask_path)):
            continue
        mask_base = split_name_ext_nii(mask_path)[0]
        candidates = image_index.get(norm_id(mask_base))
        if candidates:
            result.append((choose_best_image(mask_base, candidates), mask_path))
        elif allow_mask_only:
            result.append((None, mask_path))
    return result

def split_by_bridge_thickness(binary: sitk.Image,
                              bridge_max_thickness_mm: float,
                              grow_back_mm: float,
                              max_iters: int = 12) -> Optional[sitk.Image]:
    """Try to cut thin bridges between objects in a binary mask.

    Voxels whose inside distance to the mask boundary is at least
    ``bridge_max_thickness_mm / 2`` form "thick" cores. Parts thinner than
    roughly ``bridge_max_thickness_mm`` never reach that depth and drop out,
    which can disconnect the cores. Each core is then grown back with up to
    ``max_iters`` one-voxel ball dilations, each clipped to the voxels whose
    inside distance is at least ``grow_back_mm``. The grown cores are OR-ed
    together.

    Args:
        binary: 0/1 mask; distances are computed in physical units (mm).
        bridge_max_thickness_mm: thickness below which a connection is
            treated as a bridge.
        grow_back_mm: minimum inside depth a voxel needs to be regrown into.
        max_iters: upper bound on growth steps per core.

    Returns:
        A UInt8 0/1 mask with ``binary``'s geometry (not a label image), or
        ``None`` if either millimetre parameter is <= 0 or fewer than two
        cores are found.

    NOTE(review): ``allow`` is thresholded at ``grow_back_mm``. When
    ``grow_back_mm >= bridge_max_thickness_mm / 2`` (true for the module
    settings BRIDGE_GROW_BACK_MM=1.2, BRIDGE_MAX_THICKNESS_MM=2.0), ``allow``
    is a subset of the cores. Cores are separated under full connectivity, so
    a radius-1 dilation cannot reach another core. The loop therefore adds
    nothing and the result equals ``allow``, i.e. the mask eroded to depth
    ``grow_back_mm``. Confirm this is the intended behaviour.
    NOTE(review): each core costs up to ``max_iters`` full-volume filter
    passes plus two numpy copies per step; many tiny cores can be slow.
    """
    if bridge_max_thickness_mm <= 0 or grow_back_mm <= 0:
        return None

    # Signed Maurer distance (insideIsPositive=True, squaredDistance=False,
    # useImageSpacing=True): positive depth in mm inside the mask.
    D = sitk.SignedMaurerDistanceMap(binary, True, False, True)

    # Depth >= half the bridge thickness keeps only the "thick" parts.
    T = float(bridge_max_thickness_mm) / 2.0
    thick = sitk.Cast(sitk.GreaterEqual(D, T), sitk.sitkUInt8)
    thick = sitk.And(thick, binary)

    # connected_components relabels consecutively, so max == number of cores.
    cc = connected_components(thick, True)
    arr = sitk.GetArrayFromImage(cc)
    n_labels = int(arr.max())
    if n_labels < 2:
        # Nothing was disconnected; let the caller fall back.
        return None

    # Region the cores may grow back into.
    allow = sitk.Cast(sitk.GreaterEqual(D, float(grow_back_mm)), sitk.sitkUInt8)
    allow = sitk.And(allow, binary)

    out = sitk.Image(binary.GetSize(), sitk.sitkUInt8); out.CopyInformation(binary)

    for lab in range(1, n_labels+1):
        seed = sitk.Equal(cc, lab)
        prev = None
        # Geodesic growth: dilate by one voxel, clip to ``allow``, and stop
        # early once a step leaves the seed unchanged.
        for _ in range(max_iters):
            seed = sitk.And(sitk.BinaryDilate(seed, (1,1,1), sitk.sitkBall), allow)
            if prev is not None:
                if np.array_equal(sitk.GetArrayFromImage(seed), sitk.GetArrayFromImage(prev)):
                    break
            prev = seed
        # Union of all grown cores; touching cores are not kept apart here.
        out = sitk.Or(out, seed)

    return sitk.Cast(out, sitk.sitkUInt8)

def watershed_split(binary: sitk.Image) -> sitk.Image:
    """Split a binary mask into instances with a marker-based distance watershed.

    Steps:
      1. Signed Maurer distance map in mm (positive inside the mask).
      2. Optional Gaussian smoothing with sigma WS_GAUSS_SIGMA_MM.
      3. Seeds = regional maxima of the distance map, after optional h-maxima
         suppression (WS_HMAX_MM), labelled by connected components.
      4. Watershed of the negated distance from those markers, masked to
         ``binary``.
      5. Watershed regions relabelled to consecutive ids, largest first.

    Args:
        binary: 0/1 mask.

    Returns:
        Label image with ``binary``'s geometry; 0 is background.
    """
    dist = sitk.SignedMaurerDistanceMap(binary, True, False, True)
    if WS_GAUSS_SIGMA_MM and WS_GAUSS_SIGMA_MM > 0:
        dist = sitk.SmoothingRecursiveGaussian(dist, WS_GAUSS_SIGMA_MM)

    if WS_HMAX_MM and WS_HMAX_MM > 0:
        seed_source = sitk.HMaxima(dist, WS_HMAX_MM)
    else:
        seed_source = dist

    # Configure RegionalMaxima explicitly. The procedural call
    # sitk.RegionalMaxima(img, True) binds True to *backgroundValue*, which
    # equals the default foregroundValue (1). Every voxel then became a seed,
    # giving a single marker and no split at all.
    rmax = sitk.RegionalMaximaImageFilter()
    rmax.SetBackgroundValue(0.0)
    rmax.SetForegroundValue(1.0)
    rmax.SetFullyConnected(WS_FULLY_CONNECTED)
    rmax.SetFlatIsMaxima(True)
    seeds_bin = rmax.Execute(seed_source)

    seed_lab_f = sitk.ConnectedComponentImageFilter()
    seed_lab_f.SetFullyConnected(WS_FULLY_CONNECTED)
    seeds_lab = seed_lab_f.Execute(sitk.Cast(seeds_bin, sitk.sitkUInt8))

    ws = sitk.MorphologicalWatershedFromMarkers(-dist, seeds_lab,
                                                WS_MARK_WATERSHED_LINE,
                                                WS_FULLY_CONNECTED)
    ws_masked = sitk.Mask(ws, binary)

    # Keep the watershed region ids. Running connected components on
    # ``ws_masked > 0`` re-merges touching regions whenever no watershed line
    # is drawn (WS_MARK_WATERSHED_LINE = False), which undid the split.
    ws_relab = sitk.RelabelComponent(ws_masked, minimumObjectSize=0)
    ws_relab.CopyInformation(binary)
    return ws_relab

def _extract_sagittal(img: sitk.Image, x_index: int, transpose: bool=True) -> np.ndarray:
    """Return the plane at column ``x_index`` (clamped to range) of a 3D image.

    The numpy view of a SimpleITK volume is indexed (z, y, x), so the plane is
    ``[:, :, x]`` with shape (z, y), or (y, z) when ``transpose`` is True.
    """
    volume = sitk.GetArrayFromImage(img)
    last = volume.shape[2] - 1
    col = int(np.clip(x_index, 0, last))
    plane = volume[:, :, col]
    if transpose:
        return plane.T
    return plane

def _sag_idx_from_centroid(image_like: sitk.Image, centroid_phys) -> int:
    """Return the x (sagittal) voxel index nearest a physical point, clamped to the image.

    Falls back to the middle column if the point cannot be transformed.
    """
    # GetSize() is (x, y, z), so GetSize()[0] equals the numpy array's last
    # axis length. This avoids copying the whole volume into numpy; the old
    # code did that twice per call.
    size_x = image_like.GetSize()[0]
    try:
        x_idx = image_like.TransformPhysicalPointToIndex(centroid_phys)[0]
    except Exception:
        # Deliberate best-effort fallback for QC rendering.
        x_idx = size_x // 2
    return int(np.clip(x_idx, 0, size_x - 1))

def _apply_orientation(arr2d: np.ndarray, rotate_k: Optional[int] = None,
                       flipud: Optional[bool] = None,
                       fliplr: Optional[bool] = None) -> np.ndarray:
    """Rotate and then optionally flip a 2D QC slice for display.

    Args:
        arr2d: 2D array to reorient.
        rotate_k: number of 90-degree ``np.rot90`` turns. Only 1, 2 or 3 rotate;
            any other value leaves the array unrotated. ``None`` uses
            SAGITTAL_ROTATE_K.
        flipud: flip vertically after rotating; ``None`` uses SAGITTAL_FLIPUD.
        fliplr: flip horizontally after rotating; ``None`` uses SAGITTAL_FLIPLR.

    Returns:
        The reoriented array (a view when possible).
    """
    # Resolve defaults at call time so the config flags are honoured; the
    # previous version silently ignored SAGITTAL_FLIPUD / SAGITTAL_FLIPLR.
    if rotate_k is None:
        rotate_k = SAGITTAL_ROTATE_K
    if flipud is None:
        flipud = SAGITTAL_FLIPUD
    if fliplr is None:
        fliplr = SAGITTAL_FLIPLR

    out = arr2d
    if rotate_k in (1, 2, 3):
        out = np.rot90(out, rotate_k)
    if flipud:
        out = np.flipud(out)
    if fliplr:
        out = np.fliplr(out)
    return out

def _window_to_uint8(slice2d: np.ndarray, level: float, width: float) -> np.ndarray:
    """Map intensities through a level/width window onto 0-255 ``uint8``.

    Values at or below ``level - width/2`` become 0 and values at or above
    ``level + width/2`` map to the top of the range. In between the mapping is
    linear, computed in float32 and truncated to integers.
    """
    half = width / 2.0
    lo, hi = level - half, level + half
    clipped = np.clip(slice2d.astype(np.float32), lo, hi)
    unit = np.clip((clipped - lo) / (hi - lo + 1e-6), 0, 1)
    return (unit * 255.0).astype(np.uint8)

def save_sagittal_png(out_png: Path, mask_after: sitk.Image, centroid_phys, image: Optional[sitk.Image],
                      transpose=True, dpi=160):
    """Save a sagittal QC PNG of one instance mask, optionally over its image.

    The sagittal column is the x index nearest ``centroid_phys``. With an
    image, the windowed slice (QC_WINDOW_LEVEL / QC_WINDOW_WIDTH) is drawn in
    gray with the mask overlaid in red. Without one, only the mask is drawn.
    Parent directories of ``out_png`` are created as needed.
    """
    x_idx = _sag_idx_from_centroid(mask_after, centroid_phys)
    m2d = _apply_orientation(_extract_sagittal(mask_after > 0, x_idx, transpose).astype(np.uint8))

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(figsize=(6, 6), dpi=dpi)
    # Always close the figure: pyplot keeps references to open figures, and
    # the caller catches QC exceptions, so a leak here would accumulate over
    # the whole batch.
    try:
        ax = fig.add_subplot(1, 1, 1)
        if image is not None:
            im2d = _apply_orientation(_extract_sagittal(image, x_idx, transpose))
            im_show = _window_to_uint8(im2d, level=QC_WINDOW_LEVEL, width=QC_WINDOW_WIDTH)
            ax.imshow(im_show, cmap="gray")
            m2d_masked = np.ma.masked_where(m2d == 0, m2d)
            ax.imshow(m2d_masked, cmap="Reds", alpha=0.7, interpolation="nearest")
        else:
            ax.imshow(m2d * 255, cmap="Reds")
        ax.axis("off")
        fig.savefig(str(out_png), bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

def _empty_case_row(case_name: str, note: str) -> Dict[str, object]:
    """Return a CSV row for a case that produced no instance (skipped or failed)."""
    return {"case_name": case_name, "instance_index": "", "voxels": 0,
            "bbox_min_x": "", "bbox_min_y": "", "bbox_min_z": "",
            "bbox_max_x": "", "bbox_max_y": "", "bbox_max_z": "",
            "centroid_x": "", "centroid_y": "", "centroid_z": "",
            "out_image": "", "out_mask": "", "note": note}


def _split_labels(filled_global: sitk.Image) -> Tuple[sitk.Image, str]:
    """Turn a cleaned binary mask into a label image and name the method used.

    Bridge-thickness splitting is tried first when enabled. If it is disabled
    or returns None, the watershed fallback is used when enabled; otherwise
    plain connected components with FULLY_CONNECTED.
    Returns ``(label_image, method)`` with method in {"bridge", "watershed", "cc"}.
    """
    if ENABLE_BRIDGE_THICKNESS_SPLIT:
        new_mask = split_by_bridge_thickness(
            filled_global,
            bridge_max_thickness_mm=BRIDGE_MAX_THICKNESS_MM,
            grow_back_mm=BRIDGE_GROW_BACK_MM,
            max_iters=BRIDGE_MAX_GEO_ITERS
        )
        if new_mask is not None:
            return connected_components(new_mask, True), "bridge"
    if ENABLE_WATERSHED_SPLIT:
        return watershed_split(filled_global), "watershed"
    return connected_components(filled_global, FULLY_CONNECTED), "cc"


def _largest_axial_area(arr_label: np.ndarray, lab: int, info: Dict) -> int:
    """Return the largest voxel count of ``lab`` in any slice along array axis 0.

    Only the label's bounding box is scanned. The array is indexed (z, y, x)
    while the bbox indices are (x, y, z). Slices outside the box contain no
    voxels of ``lab``, so the result equals a full-volume scan.
    """
    x0, y0, z0 = info["bbox_min_idx"]
    x1, y1, z1 = info["bbox_max_idx"]
    comp_mask = arr_label[z0:z1 + 1, y0:y1 + 1, x0:x1 + 1] == lab
    return int(comp_mask.reshape(comp_mask.shape[0], -1).sum(axis=1).max())


def main():
    """Split every multi-vertebra mask into per-instance masks and write logs.

    For each (image, mask) pair from ``list_pairs``, the pipeline:
      1. binarizes the mask, opens it and fills holes;
      2. splits it into labels (bridge thickness, then watershed, then plain
         components, depending on the ENABLE_* switches);
      3. drops components failing the MIN_* thresholds;
      4. optionally dilates each kept component by FINAL_DILATE_MM;
      5. writes each as ``<case>__v<k><ext>`` and optionally saves a sagittal
         QC PNG.

    One CSV row is logged per instance, or per skipped/failed case, and
    exceptions are appended to ERROR_LOG. Exits with status 1 if either input
    directory is missing.
    """
    images_dir, masks_dir = Path(IMAGES_DIR), Path(MASKS_DIR)
    out_msk_dir, qc_dir = Path(OUT_MASKS_DIR), Path(QC_OUTPUT_DIR)
    log_csv, err_log = Path(LOG_CSV), Path(ERROR_LOG)

    for p in [out_msk_dir, qc_dir, log_csv.parent, err_log.parent]:
        p.mkdir(parents=True, exist_ok=True)
    # Start a fresh error log for this run; log_error() appends to it.
    with open(err_log, "w", encoding="utf-8") as f:
        f.write(f"Error Log - Generated at: {ts()}\n")

    print("="*80)
    print(" Multi-vertebra mask → single-vertebra instance splitting (bridge thickness + watershed fallback; QC sagittal PNG)")
    print("="*80)
    print(f"[Path] images: {images_dir}")
    print(f"[Path] masks : {masks_dir}")
    print(f"[Out ] masks : {out_msk_dir}")
    print(f"[QC  ] out   : {qc_dir}")
    print(f"[Bridge] ENABLE={ENABLE_BRIDGE_THICKNESS_SPLIT}, MAX_THICK={BRIDGE_MAX_THICKNESS_MM}mm, GROW_BACK≥{BRIDGE_GROW_BACK_MM}mm")
    print(f"[Fallback] watershed: GAUSS={WS_GAUSS_SIGMA_MM}mm, HMAX={WS_HMAX_MM}mm")
    print(f"[Thresh] vox≥{MIN_VOXELS}, area2D≥{MIN_LARGEST_SLICE_AREA}, bbox≥{MIN_BBOX_VOX}, vol_mm3≥{MIN_PHYS_VOL_MM3}")
    print(f"[Morph] OPEN_RADIUS={MORPH_OPEN_RADIUS}, holefill={HOLE_FILL_MODE}")
    print(f"[Orient] ROTATE_K={SAGITTAL_ROTATE_K}, TRANSPOSE={SAGITTAL_TRANSPOSE}")
    print(f"[Window] WL={QC_WINDOW_LEVEL}, WW={QC_WINDOW_WIDTH}")
    print(f"[Switch] ENABLE_QC={ENABLE_QC}, ALLOW_MASK_ONLY={ALLOW_MASK_ONLY}")

    if not images_dir.exists() or not masks_dir.exists():
        print("[Error] Input directory does not exist")
        sys.exit(1)

    pairs = list_pairs(images_dir, masks_dir, allow_mask_only=ALLOW_MASK_ONLY)
    if QC_MAX_CASES is not None:
        pairs = pairs[:QC_MAX_CASES]
    print(f"[Info] Total cases to process: {len(pairs)}")

    rows = []
    total_cases = total_instances = skipped_cases = error_cases = 0
    for idx, (img_path, msk_path) in enumerate(pairs, 1):
        base, ext = split_name_ext_nii(msk_path)
        print("\n" + "-"*80)
        print(f"[{idx}/{len(pairs)}] Processing: {base}{ext}")
        try:
            img = sitk.ReadImage(str(img_path)) if img_path is not None else None
            msk = sitk.ReadImage(str(msk_path))

            bin_msk = binarize_mask(msk)
            bin_clean = morph_open(bin_msk, MORPH_OPEN_RADIUS)
            filled_global = fill_holes(bin_clean)

            if np.count_nonzero(sitk.GetArrayFromImage(filled_global)) == 0:
                note = "Mask is all zeros (still empty after opening/hole filling), skipped"
                print("  [Skip] " + note)
                rows.append(_empty_case_row(base, note))
                skipped_cases += 1
                continue

            # Track the method that actually produced the labels. The CSV note
            # used to say "bridge" whenever bridge splitting was merely
            # enabled, even when a fallback had been used.
            label_img, split_method = _split_labels(filled_global)
            stats = label_stats(label_img)

            spacing = msk.GetSpacing()
            vox_mm3 = float(spacing[0]*spacing[1]*spacing[2]) if spacing else 1.0
            arr_label = sitk.GetArrayFromImage(label_img)

            keep_labels = []
            for lab, info in stats.items():
                vox = info["voxels"]
                if vox < MIN_VOXELS:
                    continue
                sx, sy, sz = info["bbox_size"]
                if sx < MIN_BBOX_VOX[0] or sy < MIN_BBOX_VOX[1] or sz < MIN_BBOX_VOX[2]:
                    continue
                if _largest_axial_area(arr_label, lab, info) < MIN_LARGEST_SLICE_AREA:
                    continue
                if MIN_PHYS_VOL_MM3 > 0.0 and vox * vox_mm3 < MIN_PHYS_VOL_MM3:
                    continue
                keep_labels.append(lab)

            keep_labels = keep_labels[:MAX_COMPONENTS_PER_CASE]

            if not keep_labels:
                note = "No connected components meet the thresholds"
                print("  [Skip] " + note)
                rows.append(_empty_case_row(base, note))
                skipped_cases += 1
                continue

            print(f"  Total labels: {len(stats)}, kept: {len(keep_labels)}")

            for k, lab in enumerate(keep_labels, 1):
                out_msk = out_msk_dir / f"{base}__v{k}{ext}"
                comp_after = isolate_component(label_img, lab, msk)

                if FINAL_DILATE_MM and FINAL_DILATE_MM > 0:
                    comp_after = binary_dilate_mm(comp_after, FINAL_DILATE_MM)

                sitk.WriteImage(comp_after, str(out_msk))

                info = stats[int(lab)]
                (minx, miny, minz) = info["bbox_min_idx"]
                (maxx, maxy, maxz) = info["bbox_max_idx"]
                (cx, cy, cz) = info["centroid_phys"]

                if ENABLE_QC:
                    sag_png = qc_dir / f"{base}__v{k}_sag.png"
                    try:
                        save_sagittal_png(sag_png, comp_after, (cx, cy, cz), img,
                                          transpose=SAGITTAL_TRANSPOSE, dpi=QC_DPI)
                        print(f"  [QC] Saved sagittal: {sag_png}")
                    except Exception as e:
                        # QC is best-effort: log it and keep writing instances.
                        print(f"  [QC Error] {e}")
                        log_error(err_log, f"{base}__v{k} SAG", str(e), traceback.format_exc())

                rows.append({
                    "case_name": base, "instance_index": k, "voxels": info["voxels"],
                    "bbox_min_x": minx, "bbox_min_y": miny, "bbox_min_z": minz,
                    "bbox_max_x": maxx, "bbox_max_y": maxy, "bbox_max_z": maxz,
                    "centroid_x": round(cx, 3), "centroid_y": round(cy, 3), "centroid_z": round(cz, 3),
                    "out_image": "", "out_mask": str(out_msk),
                    "note": f"{split_method}; fill={HOLE_FILL_MODE}{'; no_image_found' if img is None else ''}"
                })
                total_instances += 1

            total_cases += 1

        except Exception as e:
            # Per-case boundary: record the failure and continue with the batch.
            print(f"  [Error] {e}")
            log_error(err_log, f"{base}{ext}", str(e), traceback.format_exc())
            rows.append(_empty_case_row(base, f"Error: {e}"))
            error_cases += 1

    with open(log_csv, "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=[
            "case_name","instance_index","voxels",
            "bbox_min_x","bbox_min_y","bbox_min_z",
            "bbox_max_x","bbox_max_y","bbox_max_z",
            "centroid_x","centroid_y","centroid_z",
            "out_image","out_mask","note"
        ])
        writer.writeheader()
        writer.writerows(rows)

    print("\n" + "="*80)
    print("Done")
    print("="*80)
    print(f"  Successful cases : {total_cases}")
    print(f"  Skipped cases    : {skipped_cases}")
    print(f"  Error cases      : {error_cases}")
    print(f"  Total instances  : {total_instances}")
    print(f"  Log CSV          : {LOG_CSV}")
    print(f"  QC PNG dir       : {QC_OUTPUT_DIR}")
    if error_cases:
        print(f"  Error log        : {ERROR_LOG}")

if __name__ == "__main__":
    # Run the batch pipeline only when executed as a script, not on import.
    main()
