In [None]:
import os, glob, csv, math, warnings, re
import numpy as np
import SimpleITK as sitk
import cv2

try:
    from PIL import Image
    PIL_AVAILABLE = True
except Exception:
    PIL_AVAILABLE = False

# --- Input / output locations -------------------------------------------
# IMG_DIR is scanned by build_ct_map() and MASK_DIR by iter_masks() for
# *.nii / *.nii.gz files; OUT_DIR receives one sub-folder per CT stem plus
# the index CSV written by main().
IMG_DIR  = r"\image"
MASK_DIR = r"\splitmasks"
OUT_DIR  = r"\JPG_output"

# Rendering mode read by process_one_ct_mask():
#   "axial_only" -> a single "ax" plane, always nearest-neighbour resampling
#   "mpr_hq"     -> "sag" and "cor" planes, Lanczos windowed-sinc for the CT
#   anything else -> "sag" and "cor" planes, nearest-neighbour resampling
RESAMPLE_MODE = "mpr_hq"

# Displacements of each rendered slice from the mask centre, applied along
# the plane normal in physical units (reported as "offset_mm" in the CSV).
LAYER_OFFSETS = [0.0, +3.0, -3.0, +6.0, -6.0]

# Canvas width/height in pixels is int(BASE_FOV_PX * CANVAS_SCALE_XY).
BASE_FOV_PX       = 900
CANVAS_SCALE_XY   = 3.0
# Margin (pixels) added around the mask bounding box before cropping.
PADDING_PIX       = 5
# Lower bounds applied to the in-plane pixel spacing and the slice thickness
# of the reference grid (same physical units as the CT spacing).
MIN_INPLANE_SP    = 0.25
MIN_THICKNESS_SP  = 0.25

# When enabled, the native in-plane spacing is divided by the factor before
# the MIN_INPLANE_SP floor is applied.
USE_MPR_SUPERRES        = True
INPLANE_SUPERRES_FACTOR = 3.0

# Flip the image row direction: -k_dir for sag/cor planes and -j_dir for the
# axial plane respectively.
# NOTE(review): the anatomical meaning assumes the CT direction matrix follows
# the usual patient-axis convention - confirm for the data at hand.
ORIENT_SUPERIOR_AT_TOP = True
ORIENT_ANTERIOR_AT_TOP = True

# None keeps the cropped size; otherwise a (width, height) pair passed to
# cv2.resize() in _export_block().
FINAL_RESIZE = None

# CLAHE contrast enhancement (see clahe_enhance()).
APPLY_CLAHE     = True
CLAHE_CLIP      = 2.0
CLAHE_TILE      = (8, 8)

# Two successive unsharp-mask passes (see unsharp()); radius is used as the
# Gaussian sigma, amount as the sharpening weight.
APPLY_UNSHARP   = True
UNSHARP1_RADIUS = 0.8
UNSHARP1_AMOUNT = 0.35
UNSHARP2_RADIUS = 1.6
UNSHARP2_AMOUNT = 0.15

# Output switches: 3-channel JPEG of the CT crop, its JPEG quality, and an
# optional binary mask PNG of the same crop.
SAVE_JPG_RGB24 = True
JPEG_QUALITY   = 100
SAVE_MASK_PNG  = False

# Intensity window mapped to 0..255 by win_to_u8():
# [LEVEL - WIDTH/2, LEVEL + WIDTH/2].
HU_WIN_WIDTH = 1500
HU_WIN_LEVEL = 400

def ensure_dir(p):
    """Create directory *p* including missing parents; an existing directory is fine."""
    os.makedirs(p, exist_ok=True)
def read_img(p):
    """Load the image stored at path *p* with SimpleITK and return it."""
    image = sitk.ReadImage(p)
    return image

def nii_stem(path: str) -> str:
    """Return the file name of *path* without its NIfTI extension.

    ``.nii.gz`` and ``.nii`` are matched case-insensitively; any other file
    name just loses its last extension.
    """
    name = os.path.basename(path)
    lowered = name.lower()
    # Longest suffix first so "x.nii.gz" is not reduced to "x.nii".
    for suffix in (".nii.gz", ".nii"):
        if lowered.endswith(suffix):
            return name[:-len(suffix)]
    stem, _ext = os.path.splitext(name)
    return stem

def norm_key(s: str) -> str:
    """Build a matching key from *s*: drop every non-[0-9a-zA-Z] run, then lowercase."""
    alnum_only = re.sub(r'[^0-9a-zA-Z]+', '', s)
    return alnum_only.lower()

def mask_base_from_stem(mask_stem: str) -> str:
    """Strip a trailing ``_v<digits>`` / ``_V<digits>`` suffix from *mask_stem*, if present."""
    base = re.sub(r'_[vV]\d+$', '', mask_stem)
    return base

def edt_center(mask_sitk):
    """Return the physical point of the voxel lying deepest inside the mask.

    Foreground is every voxel ``> 0``. A signed Maurer distance map (positive
    inside, image spacing respected, not squared) is computed and the voxel
    with the largest value is taken as the centre.

    Returns:
        ``np.ndarray`` of float physical coordinates, or ``None`` when the
        mask has no foreground voxel.
    """
    foreground = sitk.Cast(mask_sitk > 0, sitk.sitkUInt8)
    if int(sitk.GetArrayFromImage(foreground).sum()) == 0:
        return None

    distance_img = sitk.SignedMaurerDistanceMap(
        foreground,
        insideIsPositive=True,
        squaredDistance=False,
        useImageSpacing=True,
    )
    distances = sitk.GetArrayFromImage(distance_img)
    # The NumPy array is indexed (z, y, x); SimpleITK indices are (x, y, z).
    flat_pos = int(np.argmax(distances))
    z, y, x = np.unravel_index(flat_pos, distances.shape)
    physical = mask_sitk.TransformIndexToPhysicalPoint((int(x), int(y), int(z)))
    return np.array(physical, float)

def dircols(img):
    """Return the three unit-length direction-matrix columns of *img* plus the raw matrix.

    Returns:
        ``(col0, col1, col2, D)`` where ``D`` is the 3x3 direction matrix from
        ``img.GetDirection()`` and each column is normalised (a column whose
        norm is below 1e-12 is returned unchanged).
    """
    D = np.array(img.GetDirection(), float).reshape(3, 3)

    def _unit(vec):
        vec = np.array(vec, float)
        length = np.linalg.norm(vec)
        if length < 1e-12:
            return vec
        return vec / length

    first, second, third = (_unit(D[:, axis]) for axis in range(3))
    return first, second, third, D

def spacing_along(u, spacing, D):
    """Return the voxel spacing seen along direction *u*.

    Each column of the direction matrix *D* is projected onto *u*, scaled by
    the spacing of that axis, and the three contributions are combined as a
    Euclidean norm.

    Args:
        u: direction vector (3 components).
        spacing: the three per-axis spacings.
        D: 3x3 direction matrix whose columns are the image axes.
    """
    sx, sy, sz = spacing
    squared = [
        (np.dot(u, D[:, axis]) * step) ** 2
        for axis, step in enumerate((sx, sy, sz))
    ]
    return math.sqrt(squared[0] + squared[1] + squared[2])

def make_ref(center, row, col, sp_xy, thick, w=900, h=900):
    """Build a one-slice reference image centred on *center* for resampling.

    Args:
        center: physical point that ends up in the middle of the slice.
        row: physical direction of the image y axis (down the rows).
        col: physical direction of the image x axis (along the columns).
        sp_xy: isotropic in-plane pixel spacing.
        thick: spacing assigned to the (single-slice) third axis.
        w, h: slice width and height in pixels.

    Returns:
        A ``w x h x 1`` float32 SimpleITK image carrying the requested
        spacing, direction and origin.
    """
    row_unit = row / (np.linalg.norm(row) + 1e-12)
    col_unit = col / (np.linalg.norm(col) + 1e-12)
    normal = np.cross(col_unit, row_unit)
    normal = normal / (np.linalg.norm(normal) + 1e-12)

    # Direction matrix columns: x axis = col, y axis = row, z axis = normal.
    axes = np.column_stack((col_unit, row_unit, normal))

    ref = sitk.Image(int(w), int(h), 1, sitk.sitkFloat32)
    ref.SetSpacing((sp_xy, sp_xy, thick))
    ref.SetDirection(axes.flatten().tolist())

    # Shift from the centre to the first pixel by half of the covered extent.
    extent_x = (w - 1) * sp_xy
    extent_y = (h - 1) * sp_xy
    origin = center - 0.5 * extent_x * col_unit - 0.5 * extent_y * row_unit
    ref.SetOrigin(tuple(origin.tolist()))
    return ref

def resample(mov, ref, is_label=False, mode="mpr_nearest"):
    """Resample *mov* onto the grid of *ref* using an identity transform.

    Args:
        mov: image to resample.
        ref: reference image defining the output grid.
        is_label: label images always use nearest-neighbour and are padded
            with 0; intensity images are padded with -1024.
        mode: for intensity images, ``"mpr_hq"`` selects Lanczos windowed-sinc
            interpolation; every other value selects nearest-neighbour.
    """
    if is_label:
        interpolator, fill_value = sitk.sitkNearestNeighbor, 0
    elif mode == "mpr_hq":
        interpolator, fill_value = sitk.sitkLanczosWindowedSinc, -1024
    else:
        interpolator, fill_value = sitk.sitkNearestNeighbor, -1024

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(ref)
    resampler.SetTransform(sitk.Transform())
    resampler.SetInterpolator(interpolator)
    resampler.SetDefaultPixelValue(fill_value)
    return resampler.Execute(mov)

def win_to_u8(arr, ww=1500, wl=400):
    """Map intensities to 8-bit grey levels with a linear window.

    Values at or below ``wl - ww/2`` become 0, values at or above
    ``wl + ww/2`` become 255, and values in between are scaled linearly and
    rounded to the nearest integer.

    Args:
        arr: array-like of intensities.
        ww: window width.
        wl: window level (centre).

    Returns:
        ``np.uint8`` array with the same shape as *arr*.
    """
    lo, hi = wl - ww / 2.0, wl + ww / 2.0
    # Guard only the degenerate (zero/negative width) window; adding an
    # epsilon to every denominator made the top of the window land on 254.
    span = max(hi - lo, 1e-8)
    scaled = (np.clip(arr, lo, hi) - lo) / span * 255.0
    # Round instead of truncating so the full 0..255 range is reachable.
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

def clahe_enhance(gray8, clip=2.0, tile=(8,8)):
    """Apply OpenCV CLAHE to *gray8* with the given clip limit and tile grid size."""
    equalizer = cv2.createCLAHE(clipLimit=float(clip), tileGridSize=tuple(tile))
    equalized = equalizer.apply(gray8)
    return equalized

def _gauss_safe_sigma(r): return max(0.3, float(r))

def unsharp(img_u8, radius=1.0, amount=0.25):
    """Sharpen *img_u8* with an unsharp mask.

    The image is blurred with a Gaussian whose sigma is *radius* (floored by
    ``_gauss_safe_sigma``), and the result is
    ``(1 + amount) * image - amount * blur`` clipped to 0..255.

    A non-positive *amount* returns the input object unchanged.
    """
    if amount <= 0:
        return img_u8

    sigma = _gauss_safe_sigma(radius)
    source = img_u8.astype(np.float32)
    blurred = cv2.GaussianBlur(source, (0, 0), sigmaX=sigma, sigmaY=sigma)
    sharpened = cv2.addWeighted(source, 1 + amount, blurred, -amount, 0)
    return np.clip(sharpened, 0, 255).astype(np.uint8)

def enhance_pipeline(img8):
    """Run the configured enhancement steps on an 8-bit image.

    CLAHE is applied first when ``APPLY_CLAHE`` is set, followed by the two
    unsharp-mask passes when ``APPLY_UNSHARP`` is set.
    """
    out = clahe_enhance(img8, CLAHE_CLIP, CLAHE_TILE) if APPLY_CLAHE else img8
    if not APPLY_UNSHARP:
        return out

    passes = (
        (UNSHARP1_RADIUS, UNSHARP1_AMOUNT),
        (UNSHARP2_RADIUS, UNSHARP2_AMOUNT),
    )
    for radius, amount in passes:
        out = unsharp(out, radius, amount)
    return out

def to_rgb3_u8(gray2d_u8):
    """Replicate a 2D grey image into a contiguous HxWx3 ``uint8`` array.

    Singleton dimensions are squeezed away first. Non-``uint8`` input is
    clipped to 0..255 and cast.

    Raises:
        ValueError: if the squeezed input is not two-dimensional.
    """
    plane = np.squeeze(gray2d_u8)
    if plane.ndim != 2:
        raise ValueError(f"to_rgb3_u8 expects 2D, got {plane.shape}")
    if plane.dtype != np.uint8:
        plane = np.clip(plane, 0, 255).astype(np.uint8)
    channels = np.repeat(plane[:, :, np.newaxis], 3, axis=-1)
    return np.ascontiguousarray(channels)

def save_jpg_rgb(rgb_u8, path, quality=100):
    """Write an RGB ``uint8`` image to *path* as a JPEG (best effort).

    Pillow is used when available (no chroma subsampling, optimised
    tables). If Pillow is missing or fails, the image is encoded with OpenCV
    and written via ``ndarray.tofile``.

    Args:
        rgb_u8: HxWx3 ``uint8`` array in RGB channel order.
        path: destination file path.
        quality: JPEG quality, converted with ``int()``.

    Failures are reported with ``warnings.warn``; nothing is raised for a
    failed Pillow save or an unsuccessful OpenCV encode.
    """
    if PIL_AVAILABLE:
        try:
            Image.fromarray(rgb_u8, mode="RGB").save(path, format='JPEG',
                                                    quality=int(quality),
                                                    subsampling=0, optimize=True)
            return
        except Exception as e:
            warnings.warn(f"PIL JPG save failed {path}: {e}")

    # OpenCV encoders treat 3-channel data as BGR, so reverse the channel axis;
    # otherwise red and blue would be swapped for genuinely coloured input.
    pixels = np.asarray(rgb_u8)
    if pixels.ndim == 3 and pixels.shape[2] == 3:
        pixels = pixels[..., ::-1]
    ok, buf = cv2.imencode('.jpg', np.ascontiguousarray(pixels),
                           [cv2.IMWRITE_JPEG_QUALITY, int(quality)])
    if ok:
        buf.tofile(path)
    else:
        # Previously a failed encode was dropped without any trace.
        warnings.warn(f"OpenCV JPG encode failed {path}")

def save_mask_gray_png(gray, path):
    """Write a 2D grey image to *path* as PNG via OpenCV.

    Singleton dimensions are squeezed; non-``uint8`` data is clipped to
    0..255 and cast. Input that is not 2D after squeezing, or an unsuccessful
    encode, results in nothing being written.
    """
    plane = np.squeeze(gray)
    if plane.ndim == 2:
        if plane.dtype != np.uint8:
            plane = np.clip(plane, 0, 255).astype(np.uint8)
        encoded_ok, payload = cv2.imencode('.png', np.ascontiguousarray(plane))
        if encoded_ok:
            payload.tofile(path)

def same_space(a: sitk.Image, b: sitk.Image) -> bool:
    """Return True when *a* and *b* are defined on the same voxel grid.

    Two images share a grid when their dimension and size are equal and
    their spacing, origin and direction agree within ``np.allclose``
    default tolerances.
    """
    # Size and dimension are compared first: images with equal geometry but a
    # different extent are not the same grid, and mixing 2D/3D metadata would
    # make np.allclose fail on mismatched tuple lengths.
    if a.GetDimension() != b.GetDimension() or a.GetSize() != b.GetSize():
        return False
    return bool(np.allclose(a.GetSpacing(),   b.GetSpacing()) and
                np.allclose(a.GetOrigin(),    b.GetOrigin())  and
                np.allclose(a.GetDirection(), b.GetDirection()))

def resample_like(moving: sitk.Image, ref: sitk.Image, is_label: bool) -> sitk.Image:
    """Resample *moving* onto the grid of *ref* with an identity transform.

    Label images use nearest-neighbour interpolation and 0 as fill value;
    other images use linear interpolation and -1024 as fill value.
    """
    if is_label:
        interpolator, background = sitk.sitkNearestNeighbor, 0
    else:
        interpolator, background = sitk.sitkLinear, -1024

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(ref)
    resampler.SetTransform(sitk.Transform())
    resampler.SetInterpolator(interpolator)
    resampler.SetDefaultPixelValue(background)
    return resampler.Execute(moving)

def mm_tag(d_mm: float) -> str:
    """Format an offset as a file-name tag.

    The magnitude is rounded to a whole number; the result is ``"0mm"``,
    ``"p<N>mm"`` for positive offsets or ``"m<N>mm"`` for negative ones.
    """
    magnitude = int(round(abs(d_mm)))
    if not magnitude:
        return "0mm"
    prefix = "p" if d_mm > 0 else "m"
    return f"{prefix}{magnitude}mm"

def _export_block(a_ct, a_ms, out_dir, mask_stem, plane_tag, offset_tag):
    """Crop one CT slice to the mask bounding box, enhance it and save it.

    Args:
        a_ct: 2D CT slice.
        a_ms: 2D mask slice on the same pixel grid.
        out_dir: directory receiving the output files.
        mask_stem, plane_tag, offset_tag: parts of the output file name
            ``{mask_stem}_{plane_tag}_{offset_tag}_rgb.jpg`` / ``_mask.png``.

    Returns:
        ``False`` when the mask slice has a maximum of 0 (nothing exported),
        ``True`` otherwise.
    """
    if a_ms.max() == 0:
        return False

    # Bounding box of the mask foreground, grown by PADDING_PIX and clamped
    # to the CT slice.
    foreground = a_ms > 0
    hit_rows = np.nonzero(foreground.any(axis=1))[0]
    hit_cols = np.nonzero(foreground.any(axis=0))[0]
    top = max(0, hit_rows[0] - PADDING_PIX)
    left = max(0, hit_cols[0] - PADDING_PIX)
    bottom = min(a_ct.shape[0] - 1, hit_rows[-1] + PADDING_PIX)
    right = min(a_ct.shape[1] - 1, hit_cols[-1] + PADDING_PIX)
    window = (slice(top, bottom + 1), slice(left, right + 1))

    img8 = enhance_pipeline(win_to_u8(a_ct[window], HU_WIN_WIDTH, HU_WIN_LEVEL))

    resize_requested = FINAL_RESIZE is not None
    if resize_requested:
        img8 = cv2.resize(img8, (int(FINAL_RESIZE[0]), int(FINAL_RESIZE[1])),
                          interpolation=cv2.INTER_LANCZOS4)

    rgb = to_rgb3_u8(img8)
    if SAVE_JPG_RGB24:
        jpg_path = os.path.join(out_dir, f"{mask_stem}_{plane_tag}_{offset_tag}_rgb.jpg")
        save_jpg_rgb(rgb, jpg_path, quality=JPEG_QUALITY)

    if SAVE_MASK_PNG:
        m8 = (a_ms[window] > 0).astype(np.uint8) * 255
        if resize_requested:
            # Nearest-neighbour keeps the mask strictly binary.
            m8 = cv2.resize(m8, (int(FINAL_RESIZE[0]), int(FINAL_RESIZE[1])),
                            interpolation=cv2.INTER_NEAREST)
        png_path = os.path.join(out_dir, f"{mask_stem}_{plane_tag}_{offset_tag}_mask.png")
        save_mask_gray_png(m8, png_path)
    return True

def process_one_ct_mask(ct_path, mask_path, out_case_dir, mask_stem, csv_rows):
    """Render reformatted CT views around one mask and log them in *csv_rows*.

    The mask is brought onto the CT grid if needed, its centre is located
    with ``edt_center``, and for every offset in ``LAYER_OFFSETS`` a slice is
    resampled through ``centre + offset * normal`` and passed to
    ``_export_block``.

    Planes depend on ``RESAMPLE_MODE``:
      * ``"axial_only"``: one "ax" plane, nearest-neighbour resampling.
      * anything else: "sag" and "cor" planes, resampled with
        ``RESAMPLE_MODE`` as interpolation mode.

    Args:
        ct_path: path of the CT volume.
        mask_path: path of the mask volume.
        out_case_dir: output directory (created if missing).
        mask_stem: mask name used as output file prefix.
        csv_rows: list that receives one row per exported slice (mutated).

    Returns:
        None. An all-zero mask prints a warning and exports nothing.
    """
    ensure_dir(out_case_dir)

    ct = read_img(ct_path)
    ms = read_img(mask_path)
    if not same_space(ms, ct):
        ms = resample_like(ms, ct, is_label=True)

    center = edt_center(ms)
    if center is None:
        print(f"[WARN] {mask_stem}: mask is all zeros, skipped")
        return

    i_dir, j_dir, k_dir, D = dircols(ct)
    spacing = np.array(ct.GetSpacing(), float)
    case_id = nii_stem(ct_path)

    W = int(BASE_FOV_PX * CANVAS_SCALE_XY)
    H = int(BASE_FOV_PX * CANVAS_SCALE_XY)

    def inplane_sp(row_vec, col_vec):
        # Finest native spacing of the two in-plane axes, optionally
        # super-resolved, never below MIN_INPLANE_SP.
        base_row = spacing_along(row_vec, spacing, D)
        base_col = spacing_along(col_vec, spacing, D)
        native = min(base_row, base_col)
        sp_xy = native / INPLANE_SUPERRES_FACTOR if USE_MPR_SUPERRES else native
        return max(sp_xy, MIN_INPLANE_SP)

    def export_plane(plane_tag, row, col, norm, interp_mode):
        # Export every configured offset of one plane and log what was saved.
        sp_xy = inplane_sp(row, col)
        thk = max(spacing_along(norm, spacing, D), MIN_THICKNESS_SP)
        layer_step = thk if thk > 0 else 1.0
        for d_mm in LAYER_OFFSETS:
            ref = make_ref(center + d_mm * norm, row, col, sp_xy, thk, W, H)
            ct2d = resample(ct, ref, is_label=False, mode=interp_mode)
            ms2d = resample(ms, ref, is_label=True, mode=interp_mode)
            a_ct = sitk.GetArrayFromImage(ct2d)[0]
            # Binarise before the uint8 cast: a plain cast wraps label values
            # (256 -> 0, -1 -> 255), which disagrees with the "> 0" foreground
            # rule used when locating the mask centre.
            a_ms = (sitk.GetArrayFromImage(ms2d)[0] > 0).astype(np.uint8)
            tag = mm_tag(d_mm)
            if _export_block(a_ct, a_ms, out_case_dir, mask_stem, plane_tag, tag):
                csv_rows.append([case_id, ct_path, mask_path, mask_stem,
                                 plane_tag, f"{d_mm/layer_step:.4f}", f"{d_mm:.4f}",
                                 os.path.join(out_case_dir, f"{mask_stem}_{plane_tag}_{tag}"),
                                 RESAMPLE_MODE])

    if RESAMPLE_MODE == "axial_only":
        row = (-j_dir) if ORIENT_ANTERIOR_AT_TOP else j_dir
        export_plane("ax", row, i_dir, k_dir, "mpr_nearest")
        return

    vertical = (-k_dir) if ORIENT_SUPERIOR_AT_TOP else k_dir
    # (plane tag, row direction, column direction, plane normal)
    export_plane("sag", vertical, j_dir, i_dir, RESAMPLE_MODE)
    export_plane("cor", vertical, i_dir, j_dir, RESAMPLE_MODE)

def build_ct_map(img_dir):
    """Index the NIfTI files in *img_dir* by normalised stem.

    Args:
        img_dir: directory searched (non-recursively) for ``*.nii`` and
            ``*.nii.gz`` files.

    Returns:
        dict mapping ``norm_key(stem)`` to ``(path, stem)``. When several
        files share a key the first one is kept and a warning is printed for
        each later one.
    """
    paths = []
    for pattern in ("*.nii", "*.nii.gz"):
        # glob order is filesystem-dependent; sorting makes the winner of a
        # key conflict reproducible between runs and machines.
        paths.extend(sorted(glob.glob(os.path.join(img_dir, pattern))))

    ct_map = {}
    for p in paths:
        stem = nii_stem(p)
        key = norm_key(stem)
        if key in ct_map:
            print(f"[WARN] CT prefix conflict (multiple CTs share the same normalized key): {stem} | existing: {ct_map[key][1]}")
            continue
        ct_map[key] = (p, stem)
    return ct_map

def iter_masks(mask_dir):
    """Yield ``(path, stem)`` for every NIfTI file in *mask_dir*, sorted by path.

    Every ``*.nii`` / ``*.nii.gz`` file is yielded, whether or not its stem
    ends in a ``_v<digits>`` suffix. The former suffix check had an empty
    body and never filtered anything, so it has been removed.
    """
    paths = []
    for pattern in ("*.nii", "*.nii.gz"):
        paths.extend(glob.glob(os.path.join(mask_dir, pattern)))
    for p in sorted(paths):
        yield p, nii_stem(p)

def main():
    """Match every mask in MASK_DIR to a CT in IMG_DIR, export views, write the index CSV.

    Masks are matched to CTs by the normalised mask stem with any trailing
    ``_v<digits>`` suffix removed. A failure while processing one pair is
    printed and does not stop the run.
    """
    for label, folder in (("Image", IMG_DIR), ("Mask", MASK_DIR)):
        if not os.path.isdir(folder):
            print(f"[ERROR] {label} directory does not exist: {folder}")
            return
    ensure_dir(OUT_DIR)

    index_csv = os.path.join(OUT_DIR, "index_generated_views.csv")
    header = ["case_id", "img_path", "mask_path", "mask_stem", "plane",
              "offset_layers", "offset_mm", "out_stem", "mode"]
    rows = [header]

    ct_map = build_ct_map(IMG_DIR)
    if not ct_map:
        print(f"[ERROR] No .nii/.nii.gz images found in {IMG_DIR}")
        return
    print(f"[INFO] Number of indexed CTs: {len(ct_map)}")

    total_masks = 0
    matched_masks = 0
    for mask_path, mask_stem in iter_masks(MASK_DIR):
        total_masks += 1
        base = mask_base_from_stem(mask_stem)
        entry = ct_map.get(norm_key(base))
        if entry is None:
            print(f"[WARN] No matching CT found for mask: mask={mask_stem}  base={base}")
            continue
        ct_path, ct_stem = entry
        matched_masks += 1

        out_case_dir = os.path.join(OUT_DIR, ct_stem)
        print(f"[INFO] {ct_stem} <- {mask_stem}   mode={RESAMPLE_MODE}  offsets_mm={LAYER_OFFSETS}  "
              f"canvas×={CANVAS_SCALE_XY}  superres×={INPLANE_SUPERRES_FACTOR}  superior_on_top={ORIENT_SUPERIOR_AT_TOP}")
        try:
            process_one_ct_mask(ct_path, mask_path, out_case_dir, mask_stem, rows)
        except Exception as e:
            # Top-level boundary: report the failure and continue with the next mask.
            print(f"[ERROR] Processing failed: CT={ct_stem}  MASK={mask_stem}  Reason: {e}")

    with open(index_csv, "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.writer(f)
        writer.writerows(rows)
    print(f"[DONE] Finished: found {total_masks} masks, matched {matched_masks}; index saved: {index_csv}")


if __name__ == "__main__":
    main()
