In [None]:
import os
import re
import numpy as np
import pandas as pd
import cv2
import imageio
import xml.etree.ElementTree as ET

import czifile
import tifffile as tiff

from skimage.feature import blob_log
from skimage.filters import gaussian, threshold_local, sobel
from skimage.morphology import (
    remove_small_objects, binary_opening, binary_closing, ball, binary_erosion
)
from scipy.ndimage import distance_transform_edt as edt
from skimage.measure import label, regionprops
from skimage.segmentation import watershed, find_boundaries
from scipy.ndimage import binary_fill_holes
import napari

# NEW imports for 3D textured surface
from scipy.ndimage import map_coordinates, gaussian_filter as ndi_gaussian_filter, gaussian_filter1d
from skimage.measure import marching_cubes

# ===============================
# USER CONFIGURATION
# ===============================
# Path to your file (.tif or .czi)
# Keep exactly one `file_path = ...` line active; the commented-out entries are
# alternative datasets kept for quick switching.
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_4h_1_051222.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_4h_2_051222.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_L3_1_Airy_010724.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_L3_2_Airy_010724.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_4 Airy-CBs_300425.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_3 Airy-CBs_300425-1.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_0h_1_051222.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_0h_2_051222.tif"

# ISP-MAPATZ
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/ISP MAPATZ 2025/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_L3_2_Airy_010724.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/ISP MAPATZ 2025/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_3 Airy-CBs_300425.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/ISP MAPATZ 2025/TMEM-HA 6h CB airy_3_170722.tif"

# NEW
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_6h_1_airy_020724.tif"
file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_5_Airy_CBs_040525.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_6_Airy_CBs_060525.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/stock66 40A_UAS-TMEM192-3xHA x 71G10 40A MARCM_18h_Airy_1_281223.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/40A_UAS-TMEM192-3xHA 66 x 71G10-G4 MARCM_18h CB airyscan_1_050224.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/stock66 40A_UAS-TMEM192-3xHA x 71G10 40A MARCM_18h_Airy_4_281223.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/stock66 40A_UAS-TMEM192-3xHA x 71G10 40A MARCM_18h_Airy_5_281223.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/new/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_L3_4_airy_020724.tif"

# Correct microscope calibration (µm per pixel)
# Used for both X and Y whenever the file's metadata lacks an X or Y size.
# If Z is also missing, the X size is reused for Z.
DEFAULT_VX_VY_UM = 0.04  # 0.04 before in OH-1,OH-2,6H-3,12H-3, L3 2
# Loading aborts with ValueError when the XY pixel size exceeds this limit.
MAX_REASONABLE_VXY_UM = 0.5  # µm/px sanity limit

# --- classification controls for inside/outside ---
# A lysosome centre within MARGIN_UM of the neuron mask counts as inside.
MARGIN_UM = 0.6        # soft band around neuron mask (µm)
# Otherwise it counts as inside when at least this fraction of its sphere
# overlaps the neuron mask.
OVERLAP_ALPHA = 0.3    # sphere-overlap fraction with neuron mask to count as inside
# Largest cube half-width searched when the lysosome centre sits on an
# unlabelled voxel.
NEIGHBOR_MAX_VOX = 6   # search radius (voxels) for nearest nonzero cell label

# --- visualization-only filtering ---
# Labels with fewer voxels are hidden in the napari cell layers; the
# quantitative tables still use every cell.
VIZ_MIN_VOXELS = 20000  # hide cells smaller than this (visualization only)

# --- erosion boost while keeping automatic method ---
# The auto-chosen erosion radius is multiplied by this factor, then clamped to
# at most 12 voxels.
ERODE_MULT = 5         # multiply the auto-chosen erosion radius (clamped)

# ---------------------------
# Helper: refine radii by distance transform
# ---------------------------
def refine_radii_via_dt(img3d, blobs, win_px=25, bin_method="sauvola"):
    """Refine blob radii using the distance transform of a local 2-D binary mask.

    For each blob row (z, y, x, r), a square window of half-width ``win_px`` is
    cut from the blob's own z-plane. The window is binarised with a Sauvola,
    local or Otsu threshold, cleaned with a 1-pixel opening and removal of
    components under 3 pixels, and reduced to the connected component that
    contains the blob centre. The Euclidean distance transform of that
    component at the centre pixel becomes the new radius (XY pixels).

    A blob keeps its previous radius when:
    - its centre is outside the volume;
    - the centre pixel is background;
    - the Otsu threshold cannot be computed;
    - the distance is not positive.

    Returns a float32 copy of ``blobs`` (or ``blobs`` itself when None/empty).
    """
    from skimage.filters import threshold_sauvola, threshold_local, threshold_otsu
    from skimage.morphology import remove_small_objects, disk, binary_opening
    from scipy.ndimage import distance_transform_edt as _edt
    from skimage.measure import label as _label

    if blobs is None or len(blobs) == 0:
        return blobs

    n_z, n_y, n_x = img3d.shape
    refined = blobs.copy().astype(np.float32)
    # Odd window (>= 21 px) shared by the Sauvola and local thresholds.
    block = max(21, 2 * (win_px // 2) + 1)
    open_fp = disk(1)

    for idx in range(len(refined)):
        cz, cy, cx = (int(round(v)) for v in (refined[idx, 0], refined[idx, 1], refined[idx, 2]))
        if not (0 <= cz < n_z and 0 <= cy < n_y and 0 <= cx < n_x):
            continue

        top, left = max(0, cy - win_px), max(0, cx - win_px)
        patch = img3d[cz, top:min(n_y, cy + win_px + 1), left:min(n_x, cx + win_px + 1)]

        if bin_method == "sauvola":
            fg = patch > threshold_sauvola(patch, window_size=block, k=0.2)
        elif bin_method == "local":
            fg = patch > threshold_local(patch, block_size=block, offset=-0.2 * np.std(patch))
        else:
            try:
                fg = patch > threshold_otsu(patch)
            except ValueError:
                continue

        fg = remove_small_objects(binary_opening(fg, footprint=open_fp), min_size=3, connectivity=2)

        # Blob centre in patch coordinates.
        py, px = cy - top, cx - left
        in_patch = 0 <= py < fg.shape[0] and 0 <= px < fg.shape[1]
        if not in_patch or not fg[py, px]:
            continue

        components = _label(fg)
        comp_id = components[py, px]
        if comp_id == 0:
            continue

        radius = float(_edt(components == comp_id)[py, px])
        if radius > 0:
            refined[idx, 3] = radius

    return refined

# ---------------------------
# NEW: refine radii via radial intensity ("Gaussian bell", 50% height)
# ---------------------------
def refine_radii_via_radial_intensity(
    img3d,
    blobs,
    vx_um,
    vy_um,
    vz_um,
    max_radius_nm=400.0,
    dr_nm=10.0,
    min_drop_fraction=0.3
):
    """
    Estimate each blob's radius from its radial intensity profile (half height).

    For every blob row (z, y, x, r_px):
    - Average the intensity of all voxels within ``max_radius_nm`` of the
      rounded centre in radial shells ``dr_nm`` wide. Distances are physical
      (µm), so anisotropic voxels are honoured.
    - Smooth the binned profile with a 1-D Gaussian (sigma = 1 bin).
    - Starting at the profile's peak, take the first bin whose smoothed
      intensity is <= 50 % of that peak. Convert its radius to XY pixels using
      the geometric mean of ``vx_um`` and ``vy_um``.

    The search starts at the peak, not at r = 0. A profile whose centre is
    dimmer than its surroundings therefore no longer yields a spurious
    near-zero radius.

    A blob keeps its previous radius when:
    - its centre is out of bounds;
    - the profile is non-positive;
    - the profile drops by less than ``min_drop_fraction`` of its peak;
    - the profile never falls to half height after the peak.

    Parameters
    ----------
    img3d : 3-D array indexed (z, y, x).
    blobs : (N, 4) array of (z, y, x, radius_px) in voxel units, or None.
    vx_um, vy_um, vz_um : voxel size in µm along x, y, z.
    max_radius_nm : largest radius sampled (nm).
    dr_nm : radial bin width (nm); a non-positive width leaves radii unchanged.
    min_drop_fraction : minimum (max - min) / max of the smoothed profile.

    Returns
    -------
    float32 copy of ``blobs`` with column 3 refined, or ``blobs`` if None/empty.
    """
    if blobs is None or len(blobs) == 0:
        return blobs

    img = img3d.astype(np.float32)
    Z, Y, X = img.shape
    blobs_out = blobs.copy().astype(np.float32)

    # Convert nm -> µm
    max_r_um = max_radius_nm / 1000.0
    dr_um = dr_nm / 1000.0
    if dr_um <= 0 or max_r_um <= 0:
        return blobs_out

    # Integer bin count: building the edges from it avoids the float
    # rounding of arange(), which could append an extra bin beyond max_r_um.
    n_bins = max(1, int(np.ceil(max_r_um / dr_um - 1e-9)))
    r_edges = np.arange(n_bins + 1) * dr_um
    r_centers = 0.5 * (r_edges[:-1] + r_edges[1:])

    # for converting back to pixels in XY
    px_um_xy = float(np.sqrt(vx_um * vy_um))

    # Half-widths (voxels) of the box enclosing the max_r_um sphere (loop-invariant).
    rz = max(1, int(np.ceil(max_r_um / vz_um)))
    ry = max(1, int(np.ceil(max_r_um / vy_um)))
    rx = max(1, int(np.ceil(max_r_um / vx_um)))

    for i, (zc, yc, xc, _r_px_init) in enumerate(blobs_out):
        z0 = int(round(zc))
        y0 = int(round(yc))
        x0 = int(round(xc))
        if not (0 <= z0 < Z and 0 <= y0 < Y and 0 <= x0 < X):
            continue

        z1, z2 = max(0, z0 - rz), min(Z, z0 + rz + 1)
        y1, y2 = max(0, y0 - ry), min(Y, y0 + ry + 1)
        x1, x2 = max(0, x0 - rx), min(X, x0 + rx + 1)
        if z1 >= z2 or y1 >= y2 or x1 >= x2:
            continue

        patch = img[z1:z2, y1:y2, x1:x2]

        # Physical distance (µm) of every patch voxel from the blob centre.
        zz, yy, xx = np.mgrid[z1:z2, y1:y2, x1:x2]
        dz_um = (zz - z0) * vz_um
        dy_um = (yy - y0) * vy_um
        dx_um = (xx - x0) * vx_um
        r_um = np.sqrt(dz_um**2 + dy_um**2 + dx_um**2)

        mask = (r_um <= max_r_um)
        if not np.any(mask):
            continue
        r_vals = r_um[mask]
        I_vals = patch[mask]

        # Voxels exactly at max_r_um fall on the last edge; fold them into the final bin.
        bin_idx = np.clip(np.digitize(r_vals, r_edges) - 1, 0, n_bins - 1)

        # Mean intensity per radial bin, restricted to bins that received samples.
        sums = np.bincount(bin_idx, weights=I_vals, minlength=n_bins)
        counts = np.bincount(bin_idx, minlength=n_bins)
        have = counts > 0
        if not np.any(have):
            continue
        r_prof = r_centers[have]
        I_prof = (sums[have] / counts[have]).astype(np.float32)

        I_smooth = gaussian_filter1d(I_prof, sigma=1.0)

        peak = int(np.argmax(I_smooth))
        I_max = float(I_smooth[peak])
        if I_max <= 0:
            continue

        # Require a significant drop so flat / noisy profiles keep their radius.
        I_min = float(I_smooth.min())
        if (I_max - I_min) / max(I_max, 1e-9) < min_drop_fraction:
            continue

        # First half-height crossing AFTER the peak.
        below = np.flatnonzero(I_smooth[peak:] <= 0.5 * I_max)
        if below.size == 0:
            continue
        r_half_um = float(r_prof[peak + below[0]])
        if r_half_um <= 0:
            continue

        # convert radius from µm back to pixels in XY plane
        blobs_out[i, 3] = r_half_um / max(px_um_xy, 1e-9)

    return blobs_out

# ===============================
# AUTO MORPHOLOGY (no per-image tuning)
# ===============================
from scipy.ndimage import distance_transform_edt as _edt
from skimage.measure import label as _label, regionprops as _regionprops

def _equiv_radius_from_area(px_area):
    return float(np.sqrt(max(px_area, 1.0) / np.pi))

def _component_size_percentile(mask_bool, pct=0.2):
    """Return the ``pct`` quantile (0-1) of connected-component sizes in a mask.

    Components are labelled with ``_label``. Sizes are pixel/voxel counts.
    Returns 0.0 when the mask has no foreground.
    """
    lab = _label(mask_bool)
    # One bincount pass gives every component's size. Index 0 is background;
    # label() numbers components 1..N, so the remaining entries are the
    # component sizes. This avoids building a regionprops object per region.
    sizes = np.bincount(np.asarray(lab).ravel())[1:]
    if sizes.size == 0:
        return 0.0
    return float(np.percentile(sizes, pct * 100.0))

def auto_morphology_params(neuron_mask, vx_um, vy_um, vz_um,
                           p_open=0.8, p_close=0.6, p_erode=0.15,
                           min_r_open=1, min_r_close=1, max_r=12):
    """Choose opening / closing / erosion radii (voxels) from the mask's geometry.

    - close radius = p_close × median interior distance (EDT) of the mask,
      clamped to [min_r_close, max_r].
    - open radius = p_open × equivalent radius of the 20th-percentile
      component size, clamped to [min_r_open, max_r].
    - erode radius = p_erode × median interior distance, clamped to [0, max_r].

    ``vx_um``, ``vy_um`` and ``vz_um`` are accepted but not used: all radii are
    in voxel units.

    NOTE(review): the component size is a 3-D voxel count, but it is turned
    into a radius with the 2-D disk formula sqrt(V / pi). For 3-D components
    that overestimates the radius; confirm whether this is intended.

    Returns a tuple (r_open, r_close, r_erode) of non-negative ints.
    """
    depth_map = _edt(neuron_mask)
    if np.any(neuron_mask):
        median_depth = float(np.median(depth_map[neuron_mask]))
    else:
        median_depth = 0.0

    small_component = _component_size_percentile(neuron_mask, pct=0.20)
    small_radius = _equiv_radius_from_area(small_component)

    def _clamped(scale, value, lower):
        return int(np.clip(round(scale * value), lower, max_r))

    r_close = _clamped(p_close, median_depth, min_r_close)
    r_open = _clamped(p_open, small_radius, min_r_open)
    r_erode = _clamped(p_erode, median_depth, 0)

    return tuple(max(r, 0) for r in (r_open, r_close, r_erode))

def apply_morphology_auto(neuron_mask, vx_um, vy_um, vz_um,
                          mode="dt", ch2_for_scoring=None,
                          area_stability=(0.85, 1.15)):
    """
    Clean a binary neuron mask with automatically chosen morphology radii.

    Radii come from ``auto_morphology_params``. The erosion radius is then
    multiplied by the module-level ``ERODE_MULT`` and clamped to [0, 12]
    voxels. A refinement applies, in order:
    1. an opening with a ball footprint;
    2. a closing with a ball footprint;
    3. an erosion with a ball footprint;
    4. removal of components smaller than max(8, 1e-5 × foreground voxels).
    A radius of 0 skips the corresponding step.

    Modes:
    - ``mode == "dt"`` (or no ``ch2_for_scoring``): apply the chosen radii once.
    - Any other mode: try each radius at -1 / 0 / +1 around the chosen value
      (clipped to [0, 12]) and keep the best-scoring mask. The score rewards a
      high mean Sobel edge strength of the min-max normalised Ch2 along the
      mask's outer boundary. It penalises area change relative to the input
      mask, with an extra penalty outside ``area_stability``.

    Returns (refined_mask, {'open': int, 'close': int, 'erode': int}).
    """
    r_open, r_close, r_erode = auto_morphology_params(neuron_mask, vx_um, vy_um, vz_um)

    # --- BOOST erosion automatically ---
    # Earlier variant, kept for reference: round(r_erode * ERODE_MULT + 8).
    max_r = 12
    r_erode = int(np.clip(round(r_erode * ERODE_MULT), 0, max_r))

    def _refine(mask, r_o, r_c, r_e):
        out = mask.copy()
        if r_o > 0:
            out = binary_opening(out, ball(r_o))
        if r_c > 0:
            out = binary_closing(out, ball(r_c))
        if r_e > 0:
            out = binary_erosion(out, ball(r_e))
        # Drop specks; the size floor scales with the remaining foreground.
        out = remove_small_objects(out, min_size=max(8, int(np.sum(out) * 1e-5)), connectivity=3)
        return out

    if mode == "dt" or ch2_for_scoring is None:
        refined = _refine(neuron_mask, r_open, r_close, r_erode)
        return refined, {'open': r_open, 'close': r_close, 'erode': r_erode}

    # grid search mode
    ch2 = ch2_for_scoring.astype(np.float32)
    # np.ptp() works on every NumPy version; the ndarray.ptp() method was
    # removed in NumPy 2.0 and would raise AttributeError here.
    ch2 = (ch2 - ch2.min()) / max(float(np.ptp(ch2)), 1e-9)
    edge = sobel(ch2)
    base_area = float(neuron_mask.sum())

    cand_ro = np.clip(np.array([r_open - 1, r_open, r_open + 1]), 0, max_r)
    cand_rc = np.clip(np.array([r_close - 1, r_close, r_close + 1]), 0, max_r)
    cand_re = np.clip(np.array([r_erode - 1, r_erode, r_erode + 1]), 0, max_r)

    best_score = -np.inf
    best_mask = neuron_mask
    best_trip = (r_open, r_close, r_erode)

    def _score(mask):
        bnd = find_boundaries(mask, mode='outer')
        align = float(edge[bnd].mean()) if np.any(bnd) else 0.0
        a = float(mask.sum())
        ratio = a / max(base_area, 1.0)
        stab = 0.0
        if ratio < area_stability[0] or ratio > area_stability[1]:
            stab = -abs(np.log(ratio))
        return 0.7 * align + 0.3 * (-abs(np.log(ratio)) + 0.1) + stab

    for ro in np.unique(cand_ro):
        for rc in np.unique(cand_rc):
            for r_e in np.unique(cand_re):
                m = _refine(neuron_mask, int(ro), int(rc), int(r_e))
                s = _score(m)
                if s > best_score:
                    best_score, best_mask, best_trip = s, m, (int(ro), int(rc), int(r_e))

    return best_mask, {'open': best_trip[0], 'close': best_trip[1], 'erode': best_trip[2]}

# ===============================
# Image loading and metadata
# ===============================
def _try_float(x):
    try:
        return float(x)
    except:
        return None

def _parse_ome_xml(xml_text):
    if not xml_text:
        return None
    def grab(attr):
        m = re.search(fr'PhysicalSize{attr}="([\d\.eE+-]+)"', xml_text)
        return float(m.group(1)) if m else None
    return grab("X"), grab("Y"), grab("Z")

def _split_two_channels(img, kind):
    """Split a squeezed 4-D two-channel stack into (ch1, ch2) z-stacks.

    The channel axis is taken to be the first of axis 0 (C, Z, Y, X),
    axis 1 (Z, C, Y, X) or the last axis (Z, Y, X, C) that has length 2.
    ``kind`` ("TIFF" / "CZI") only labels the RuntimeError messages.
    """
    if img.ndim != 4:
        raise RuntimeError(f"Unexpected {kind} shape")
    if img.shape[0] == 2:        # (C, Z, Y, X)
        return img[0], img[1]
    if img.shape[1] == 2:        # (Z, C, Y, X)
        return img[:, 0], img[:, 1]
    if img.shape[-1] == 2:       # (Z, Y, X, C)
        return img[..., 0], img[..., 1]
    raise RuntimeError(f"Unexpected {kind} shape for 2 channels")


def _parse_czi_scaling(xml_text):
    """Return (X, Y, Z) voxel sizes in µm from ZEISS CZI metadata XML.

    Reads ``<Distance Id="X|Y|Z"><Value>…</Value></Distance>`` entries
    (stored in metres in CZI metadata) and converts them to µm. Axes that are
    missing, non-numeric or non-positive, as well as unparsable XML, yield
    None for the affected entries.
    """
    sizes = {"X": None, "Y": None, "Z": None}
    if not xml_text:
        return None, None, None
    try:
        root = ET.fromstring(xml_text)
    except (ET.ParseError, ValueError):
        return None, None, None
    for dist in root.iter("Distance"):
        axis = dist.get("Id")
        if axis not in sizes or sizes[axis] is not None:
            continue
        try:
            metres = float(dist.findtext("Value"))
        except (TypeError, ValueError):
            continue
        if metres > 0:
            sizes[axis] = metres * 1e6
    return sizes["X"], sizes["Y"], sizes["Z"]


def load_any(file_path):
    """Load a two-channel z-stack and its voxel size from .tif/.tiff or .czi.

    - TIFF: the voxel size is read from OME-XML metadata when present.
    - CZI: the voxel size is read from the ``Scaling`` section of the CZI
      metadata XML.

    Returns ``(ch1, ch2, (vx_um, vy_um, vz_um), meta)``. Any voxel size the
    metadata does not provide is None. ``meta`` is ``{"type": "tiff"}`` or
    ``{"type": "czi"}``.

    Raises ValueError for unsupported extensions, and RuntimeError when the
    squeezed array is not 4-D with a length-2 channel axis.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext in (".tif", ".tiff"):
        vx_um = vy_um = vz_um = None
        with tiff.TiffFile(file_path) as tf:
            arr = tf.asarray()
            try:
                ome_xml = tf.ome_metadata
            except Exception:
                # Best effort: missing/unreadable OME metadata -> default calibration.
                ome_xml = None
        if ome_xml:
            parsed = _parse_ome_xml(ome_xml)
            if parsed is not None:
                vx_um, vy_um, vz_um = parsed
        ch1, ch2 = _split_two_channels(np.squeeze(arr), "TIFF")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "tiff"}

    elif ext == ".czi":
        # Minimal .czi support using czifile
        with czifile.CziFile(file_path) as cf:
            arr = cf.asarray()
            try:
                czi_xml = cf.metadata()  # XML string
            except Exception:
                # Best effort: missing/unreadable metadata -> default calibration.
                czi_xml = None

        vx_um = vy_um = vz_um = None
        if czi_xml:
            vx_um, vy_um, vz_um = _parse_czi_scaling(czi_xml)

        ch1, ch2 = _split_two_channels(np.squeeze(arr), "CZI")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "czi"}

    else:
        raise ValueError("Unsupported file format")

# ==========================================
# Load image and scales
# ==========================================
img_ch1, img_ch2, (vx_um, vy_um, vz_um), meta = load_any(file_path)
# Missing X or Y calibration: use the configured default for both axes.
if vx_um is None or vy_um is None:
    vx_um = vy_um = DEFAULT_VX_VY_UM
# Missing Z calibration: assume isotropic voxels.
# NOTE(review): z-steps are often coarser than XY on confocal/Airyscan
# stacks; confirm this fallback is acceptable for each dataset.
if vz_um is None:
    vz_um = vx_um

# Calibration sanity checks.
# - All sizes must be positive; they are used as divisors further below.
# - Neither XY axis may exceed the configured limit (previously only X was
#   checked).
if min(vx_um, vy_um, vz_um) <= 0:
    raise ValueError(f"Voxel size must be positive, got X={vx_um}, Y={vy_um}, Z={vz_um} µm")
if max(vx_um, vy_um) > MAX_REASONABLE_VXY_UM:
    raise ValueError(f"XY pixel size too large: {max(vx_um, vy_um)} µm/px")

px_um_xy = float(np.sqrt(vx_um * vy_um))
voxel_um3 = vx_um * vy_um * vz_um
print(f"Voxel size (µm): X={vx_um}, Y={vy_um}, Z={vz_um}")

# ===== Aliases used throughout =====
image   = img_ch1       # lysosome channel (Ch1)
image_2 = img_ch2       # neuron / context channel (Ch2)

# ==========================================
# Lysosome detection and scaling
# ==========================================
# Detect lysosomes as Laplacian-of-Gaussian blobs on a smoothed Ch1 stack.
image_smooth = gaussian(image, sigma=1.5)

# Each row is (z, y, x, sigma) in voxel units; column 3 is turned into a
# radius (XY pixels) below.
blobs = blob_log(image_smooth, min_sigma=0.8, max_sigma=3.0, num_sigma=7, threshold=0.005)
if len(blobs) > 0:
    blobs[:, 3] *= np.sqrt(3)  # LoG radius correction (pixels)

    # First refine via distance transform
    # (each refinement overwrites the radius only for blobs where it succeeds)
    blobs = refine_radii_via_dt(image_smooth, blobs)

    # Then refine radii via radial intensity ("Gaussian bell", 50% height)
    blobs = refine_radii_via_radial_intensity(
        image_smooth,
        blobs,
        vx_um,
        vy_um,
        vz_um,
        max_radius_nm=400.0,
        dr_nm=10.0,
        min_drop_fraction=0.3
    )

# Convert to physical units. Radii are in XY pixels, so they are scaled by the
# geometric-mean XY pixel size; volumes assume spherical lysosomes.
if len(blobs) > 0:
    z_um = blobs[:, 0] * vz_um
    y_um = blobs[:, 1] * vy_um
    x_um = blobs[:, 2] * vx_um
    radius_um = blobs[:, 3] * px_um_xy
    diameter_um = 2 * radius_um
    volume_um3 = (4/3) * np.pi * radius_um**3
    blob_ids = np.arange(1, len(blobs)+1, dtype=int)
else:
    z_um = y_um = x_um = radius_um = diameter_um = volume_um3 = np.array([])
    blob_ids = np.array([], dtype=int)

# Per-lysosome table (positions in µm, sizes in µm / µm^3), ids 1..N.
df = pd.DataFrame({
    "id": blob_ids,
    "z_um": z_um,
    "y_um": y_um,
    "x_um": x_um,
    "radius_um": radius_um,
    "diameter_um": diameter_um,
    "volume_um3": volume_um3,
})
df.to_csv("lysosome_blobs_regions.csv", index=False)
print("Saved: lysosome_blobs_regions.csv")

# === Optional: unique radii within ±5% and clamp to [0,0.4] µm ===
def _unique_radii_within_5pct(radius_series, low=0.0, high=0.4, max_frac=0.05):
    arr = radius_series.to_numpy().astype(float)
    uniq = np.empty_like(arr, dtype=float)
    used = set()
    for i, r in enumerate(arr):
        base = float(np.clip(r, low, high))
        dev = max(max_frac * max(abs(base), 1e-6), 1e-9)
        if base not in used:
            uniq[i] = base
            used.add(base)
            continue
        found = False
        step = dev / 20.0
        for k in range(1, 401):
            sgn = 1.0 if (k % 2 == 1) else -1.0
            cand = float(np.clip(base + sgn * min(dev, k * step), low, high))
            if abs(cand - base) <= dev and cand not in used:
                uniq[i] = cand
                used.add(cand)
                found = True
                break
        if not found:
            cand = float(np.clip(np.nextafter(base, high), low, high))
            uniq[i] = cand
            used.add(cand)
    return pd.Series(uniq, index=radius_series.index, name="radius_um")

# Replace the measured radii with clamped, de-duplicated values.
# - Each radius is clipped to [0.0, 0.4] µm.
# - Repeated values are nudged by roughly <= 5 % so every radius is distinct.
# - Diameter and volume are recomputed from the adjusted radius.
# `df` itself is replaced, so every downstream table uses these values.
# NOTE(review): this alters measured sizes (any radius > 0.4 µm is capped);
# confirm that is intended for quantification.
if len(df) > 0:
    df_unique = df.copy()
    df_unique["radius_um"] = _unique_radii_within_5pct(df_unique["radius_um"], low=0.0, high=0.4, max_frac=0.05)
    df_unique["diameter_um"] = 2.0 * df_unique["radius_um"]
    df_unique["volume_um3"]  = (4.0/3.0) * np.pi * (df_unique["radius_um"] ** 3)
    df_unique.to_csv("lysosome_blobs_regions_unique_radius.csv", index=False)
    print("Saved: lysosome_blobs_regions_unique_radius.csv (all radii distinct, ±5% max change, clamped to [0.0, 0.4] µm)")
    df = df_unique.copy()

# ==========================================
# CH2 segmentation (CELL vs OUTSIDE)
# ==========================================
# Min-max normalise Ch2 to [0, 1] (all zeros for a constant image), then smooth.
vol = image_2.astype(np.float32)
vmin, vmax = float(vol.min()), float(vol.max())
if vmax > vmin:
    vol = (vol - vmin) / (vmax - vmin)
else:
    vol[:] = 0.0
ch2 = gaussian(vol, sigma=1.5, preserve_range=True)

# Local threshold per z-plane: 301-px blocks with offset -0.2 * plane std.
# skimage subtracts the offset, so this raises the threshold above the local mean.
neuron_mask = np.zeros_like(ch2, dtype=bool)
for z in range(ch2.shape[0]):
    R = ch2[z]
    t = threshold_local(R, block_size=301, offset=-0.2*np.std(R))
    neuron_mask[z] = R > t

# --- automatic morphology (choose mode: "dt" or "grid") ---
neuron_mask_auto, chosen = apply_morphology_auto(
    neuron_mask, vx_um, vy_um, vz_um, mode="dt"
)
print(f"[auto-morphology] radii -> open:{chosen['open']}  close:{chosen['close']}  erode:{chosen['erode']}")

# Continue downstream with the refined mask
neuron_mask = binary_fill_holes(neuron_mask_auto)

# Soma via distance + cleanup
# NOTE(review): every foreground voxel has EDT >= 1, so with
# cell_min_radius_vox = 1 the distance test keeps the whole mask. Raise it to
# restrict the soma seeds to thicker regions.
dist = edt(neuron_mask)
cell_min_radius_vox = 1
cell_mask = (dist >= cell_min_radius_vox)
cell_mask &= neuron_mask
cell_mask = binary_fill_holes(binary_closing(binary_opening(cell_mask, ball(1)), ball(2)))

# Label and territories
body_lab = label(cell_mask, connectivity=3)
n_cells = int(body_lab.max())
print(f"Detected {n_cells} cells (soma).")

if n_cells > 0:
    # neuron_mask has not changed since `dist` was computed, so reuse that EDT
    # instead of recomputing the same full 3-D transform.
    dist_inside = dist
    # Grow each soma label through neuron_mask along the distance landscape.
    cell_seg = watershed(-dist_inside, markers=body_lab, mask=neuron_mask)
else:
    cell_seg = np.zeros_like(neuron_mask, dtype=np.int32)

print("neuron voxels:", int(neuron_mask.sum()))
print("cell voxels:", int(cell_mask.sum()))

# ==========================================
# Visualization-only filtering (hide tiny cells) + serial IDs
# ==========================================
# Works on a copy of cell_seg; cell_seg itself is unchanged.
# - Labels with fewer than VIZ_MIN_VOXELS voxels are set to 0.
# - The surviving labels are renumbered 1..N.
# cell_id_map_viz maps each surviving original label to its serial ID.
cell_seg_viz = cell_seg.copy()
cell_id_map_viz = {}  # mapping: original cell_id -> serial_id after filtering

if isinstance(cell_seg_viz, np.ndarray) and cell_seg_viz.max() > 0:
    # Remove tiny cells
    counts_viz = np.bincount(cell_seg_viz.ravel().astype(np.int64))
    tiny_labels = np.where(counts_viz < VIZ_MIN_VOXELS)[0]
    tiny_labels = tiny_labels[tiny_labels > 0]
    if tiny_labels.size > 0:
        tiny_mask = np.isin(cell_seg_viz, tiny_labels)
        cell_seg_viz[tiny_mask] = 0

    # Relabel remaining cells to serial IDs 1..N
    unique_labels = np.unique(cell_seg_viz)
    unique_labels = unique_labels[unique_labels > 0]

    if unique_labels.size > 0:
        # A lookup table relabels the whole volume in one vectorised pass,
        # instead of one full-volume comparison per surviving cell.
        relabel_lut = np.zeros(int(unique_labels.max()) + 1, dtype=np.int32)
        relabel_lut[unique_labels] = np.arange(1, unique_labels.size + 1, dtype=np.int32)
        new_seg = relabel_lut[cell_seg_viz]
        cell_id_map_viz = {
            int(old_id): new_id
            for new_id, old_id in enumerate(unique_labels.tolist(), start=1)
        }
        cell_seg_viz = new_seg

cell_mask_viz = (cell_seg_viz > 0)

# ==========================================
# Distance/overlap-aware classification helpers
# ==========================================
from scipy.ndimage import distance_transform_edt as _edt

# Physical distance (µm; anisotropic z/y/x sampling) from each voxel to the
# nearest neuron_mask voxel. Voxels inside the mask get 0.
dist_out_um = _edt(~neuron_mask, sampling=(vz_um, vy_um, vx_um)).astype(np.float32)
# neuron_mask grown by MARGIN_UM: centres in this band also count as inside.
soft_cell_mask = neuron_mask | (dist_out_um <= MARGIN_UM)

def nearest_cell_label(z, y, x, max_r=NEIGHBOR_MAX_VOX):
    """
    Return a nonzero ``cell_seg`` label near voxel (z, y, x), or 0 if none.

    Cubic windows of half-width 1, 2, ..., ``max_r`` voxels (clipped to the
    volume) are scanned around the voxel in the module-level ``cell_seg``.
    The first window containing any nonzero label decides: its most frequent
    nonzero label is returned, with ties going to the smallest label. This is
    a majority vote within that window, not strictly the closest labelled voxel.
    """
    depth, height, width = cell_seg.shape
    radius = 1
    while radius <= max_r:
        window = cell_seg[
            max(0, z - radius):min(depth, z + radius + 1),
            max(0, y - radius):min(height, y + radius + 1),
            max(0, x - radius):min(width, x + radius + 1),
        ]
        labelled = window[window > 0]
        if labelled.size:
            return int(np.argmax(np.bincount(labelled.ravel())))
        radius += 1
    return 0

def sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, mask_bool):
    """
    Fraction of a voxelised sphere that lies inside ``mask_bool``.

    The centre (µm) is rounded to the nearest voxel using the module-level
    voxel sizes ``vz_um``, ``vy_um`` and ``vx_um``. Every voxel of the
    surrounding box whose physical distance from that centre is <= ``r_um``
    belongs to the sphere. The box is clipped to the volume, so only the
    in-bounds part of the sphere is counted.

    Returns 0.0 when the radius is non-positive or no sphere voxel lies inside
    the volume.
    """
    if r_um <= 0:
        return 0.0

    cz = int(round(zc_um / vz_um))
    cy = int(round(yc_um / vy_um))
    cx = int(round(xc_um / vx_um))

    half_z = max(1, int(np.ceil(r_um / vz_um)))
    half_y = max(1, int(np.ceil(r_um / vy_um)))
    half_x = max(1, int(np.ceil(r_um / vx_um)))

    n_z, n_y, n_x = mask_bool.shape
    z_lo, z_hi = max(0, cz - half_z), min(n_z, cz + half_z + 1)
    y_lo, y_hi = max(0, cy - half_y), min(n_y, cy + half_y + 1)
    x_lo, x_hi = max(0, cx - half_x), min(n_x, cx + half_x + 1)
    if z_lo >= z_hi or y_lo >= y_hi or x_lo >= x_hi:
        return 0.0

    # Open grids broadcast to the full box when the squared offsets are summed.
    grid_z, grid_y, grid_x = np.ogrid[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi]
    dz = (grid_z - cz) * vz_um
    dy = (grid_y - cy) * vy_um
    dx = (grid_x - cx) * vx_um
    in_sphere = (dz * dz + dy * dy + dx * dx) <= (r_um * r_um)

    n_sphere = int(np.count_nonzero(in_sphere))
    if n_sphere == 0:
        return 0.0

    n_hit = int(np.count_nonzero(mask_bool[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi] & in_sphere))
    return float(n_hit) / float(n_sphere)

# ==========================================
# Map lysosomes to (cell/outside) with per-cell IDs (robust)
# ==========================================
# A lysosome is "cell" when any of these holds:
# - its centre voxel is in neuron_mask;
# - its centre voxel is within MARGIN_UM of neuron_mask (soft_cell_mask);
# - at least OVERLAP_ALPHA of its sphere overlaps neuron_mask.
# Its cell ID is the cell_seg label at the centre voxel. If that voxel is
# unlabelled, nearest_cell_label's majority label is used (0 if none found).
location_ch2 = []
cell_id_list = []

if len(df) > 0:
    Z, Y, X = neuron_mask.shape
    for (zc_um, yc_um, xc_um, r_um) in df[["z_um","y_um","x_um","radius_um"]].to_numpy():
        # Centre voxel (coordinates are stored in µm).
        zz = int(round(zc_um / vz_um))
        yy = int(round(yc_um / vy_um))
        xx = int(round(xc_um / vx_um))
        in_bounds = (0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X)

        inside_hard = in_bounds and neuron_mask[zz, yy, xx]
        inside_soft = in_bounds and soft_cell_mask[zz, yy, xx]
        is_inside = bool(inside_hard) or bool(inside_soft)
        if not is_inside:
            frac = sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, neuron_mask)
            if frac >= OVERLAP_ALPHA:
                is_inside = True

        if is_inside:
            cid = 0
            if in_bounds:
                cid = int(cell_seg[zz, yy, xx]) if cell_seg[zz, yy, xx] != 0 else nearest_cell_label(zz, yy, xx)
            location_ch2.append("cell")
            cell_id_list.append(cid)
        else:
            location_ch2.append("outside")
            cell_id_list.append(0)

    df["location_ch2"] = location_ch2
    df["cell_id_ch2"]  = cell_id_list

    # NEW: map original cell IDs to serial viz IDs (0 if filtered out / no cell)
    if 'cell_id_map_viz' in locals() and isinstance(cell_id_map_viz, dict):
        df["cell_id_ch2_viz"] = (
            df["cell_id_ch2"]
            .map(cell_id_map_viz)
            .fillna(0)
            .astype(int)
        )
    else:
        df["cell_id_ch2_viz"] = 0

    df.groupby("location_ch2").size().reset_index(name="count") \
      .to_csv("lysosome_counts_cell_vs_outside.csv", index=False)

    # Per-cell counts. "cell" lysosomes with no label found appear under
    # cell_id_ch2 == 0 here.
    (df[df["location_ch2"] == "cell"]
        .groupby("cell_id_ch2").size()
        .reset_index(name="count")
        .to_csv("lysosome_counts_by_cell.csv", index=False))

    # --- Per-cell serial lysosome IDs (1..N per cell) ---
    df["lys_id_in_cell"] = 0  # 0 for outside / unknown
    mask_in = (df["location_ch2"] == "cell") & (df["cell_id_ch2"] > 0)

    # Deterministic sort so numbering is stable (z, y, x)
    order_cols = ["cell_id_ch2", "z_um", "y_um", "x_um"]
    df_sorted = df.loc[mask_in].sort_values(order_cols).copy()

    # Number 1..N inside each cell
    df.loc[df_sorted.index, "lys_id_in_cell"] = (
        df_sorted.groupby("cell_id_ch2").cumcount().to_numpy() + 1
    ).astype(int)

    # Export full per-lysosome table
    df.to_csv("lysosomes_with_cell_vs_outside.csv", index=False)

    # Summary: how many lysosomes per cell (N is just max serial)
    lys_serial_counts = (
        df[df["lys_id_in_cell"] > 0]
        .groupby("cell_id_ch2")["lys_id_in_cell"]
        .max()
        .reset_index()
        .rename(columns={"lys_id_in_cell": "lysosomes_in_cell"})
    )
    lys_serial_counts.to_csv("lysosome_counts_by_cell_serial.csv", index=False)
    print("Saved: lysosomes_with_cell_vs_outside.csv (now includes lys_id_in_cell)")
    print("Saved: lysosome_counts_by_cell_serial.csv")

    # Only claim these files were written when this branch actually wrote them.
    print(
        "Saved: lysosome_counts_cell_vs_outside.csv, "
        "lysosome_counts_by_cell.csv, "
        "lysosomes_with_cell_vs_outside.csv, "
        "lysosome_counts_by_cell_serial.csv"
    )
else:
    print("No lysosomes detected: per-lysosome and per-cell count CSVs were not written.")

# Include zero-count cells
# Rewrites lysosome_counts_by_cell.csv so it lists every label 1..cell_seg.max().
# - Cells without lysosomes get count 0.
# - Rows for cell_id_ch2 == 0 ("cell" lysosomes with no label found) are
#   dropped by the left merge.
# Best effort: any exception is caught and printed.
try:
    if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
        all_cells = pd.DataFrame({
            "cell_id_ch2": np.arange(1, int(cell_seg.max()) + 1, dtype=int)
        })

        if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
            lys_counts_nonzero = (
                df[df["location_ch2"] == "cell"]
                .groupby("cell_id_ch2")
                .size()
                .reset_index(name="count")
            )
        else:
            lys_counts_nonzero = pd.DataFrame(columns=["cell_id_ch2", "count"])

        # Left merge keeps every cell; missing counts become NaN and are filled with 0.
        lys_counts_all = (
            all_cells.merge(lys_counts_nonzero, on="cell_id_ch2", how="left")
                     .fillna({"count": 0})
        )
        lys_counts_all["count"] = lys_counts_all["count"].astype(int)

        lys_counts_all.to_csv("lysosome_counts_by_cell.csv", index=False)
        print("Updated: lysosome_counts_by_cell.csv now includes cells with 0 lysosomes.")
except Exception as e:
    print("Could not expand lysosome_counts_by_cell with zero-count cells:", e)

# ==========================================
# Per-cell (Ch2) volumes (µm^3)
# ==========================================
# Volume of each cell_seg territory = voxel count × voxel_um3.
cell_volume_df = pd.DataFrame(columns=["cell_id_ch2", "voxel_count", "volume_um3"])
if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
    counts = np.bincount(cell_seg.ravel().astype(np.int64))
    cell_ids = np.arange(1, counts.size, dtype=int)
    voxels = counts[1:].astype(np.int64)
    vol_um3 = voxels.astype(float) * voxel_um3

    cell_volume_df = pd.DataFrame({
        "cell_id_ch2": cell_ids,
        "voxel_count": voxels,
        "volume_um3": vol_um3
    })
    cell_volume_df.to_csv("cell_volumes_ch2.csv", index=False)
    print("Saved: cell_volumes_ch2.csv")

    # Join per-cell lysosome counts (cells without lysosomes get 0).
    # Best effort: a failure is reported and the volume table above is kept.
    try:
        if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
            lys_counts = (df[df["location_ch2"] == "cell"]
                          .groupby("cell_id_ch2")
                          .size()
                          .reset_index(name="lysosome_count"))
            merged = (cell_volume_df
                      .merge(lys_counts, on="cell_id_ch2", how="left")
                      .fillna({"lysosome_count": 0}))
            # The left merge turns the column into float (NaN for cells without
            # lysosomes); restore integer counts so the CSV matches
            # lysosome_counts_by_cell.csv.
            merged["lysosome_count"] = merged["lysosome_count"].astype(int)
            merged.to_csv("cell_metrics_ch2.csv", index=False)
            print("Saved: cell_metrics_ch2.csv")
    except Exception as e:
        print("Merge with lysosome counts failed:", e)

# ==========================================
# Napari visualization layers (add AFTER segmentation exists)
# ==========================================
# Create the napari viewer that all image / label / point layers below are added to.
viewer = napari.Viewer()

# --- Hover-only labels helper (FIXED) ---
def attach_hover_tooltip(points_layer, viewer, template,
                         color='white', size=10, anchor='upper_left'):
    """
    Show a label only for the point under the cursor.

    Every point gets an empty text string except the one currently under the
    mouse. That point's string is ``template.format(**props)``, where
    ``props`` holds its entries from ``points_layer.properties``.

    Parameters
    ----------
    points_layer : napari Points layer
        Layer whose ``text`` is driven by the hover state.
    viewer : napari.Viewer
        Supplies ``camera.view_direction`` and ``dims.displayed`` for picking.
    template : str
        ``str.format`` template, e.g. ``'C:{cell}  d:{diameter_um:.3f}µm'``.
    color, size, anchor
        Forwarded to the layer's text settings.

    Returns
    -------
    callable
        The registered mouse-move callback, so it can be removed later.
    """
    strings = [''] * len(points_layer.data)
    state = {'last': None}

    def _render():
        # Hand the layer a fresh list so it never aliases our working buffer.
        points_layer.text = {'string': list(strings), 'color': color,
                             'size': size, 'anchor': anchor}

    def _pick(layer, event):
        """Return the index of the point under the cursor, or None."""
        kwargs = dict(view_direction=viewer.camera.view_direction,
                      dims_displayed=viewer.dims.displayed)
        try:
            idx = layer.get_value(event.position, world=True, **kwargs)
        except TypeError:
            # Older napari releases do not accept the ``world`` keyword.
            idx = layer.get_value(event.position, **kwargs)
        if isinstance(idx, (list, tuple)):
            idx = idx[0] if idx else None
        return idx

    def _on_move(layer, event):
        # Points may have been added/removed since attaching; keep the label
        # buffer the same length as the layer so indices stay valid.
        n = len(points_layer.data)
        resized = len(strings) != n
        if resized:
            strings[:] = [''] * n
            state['last'] = None

        idx = _pick(layer, event)
        if idx == state['last'] and not resized:
            return

        last = state['last']
        if last is not None and 0 <= last < n:
            strings[last] = ''

        if idx is not None and 0 <= idx < n:
            props = ({k: v[idx] for k, v in points_layer.properties.items()}
                     if points_layer.properties else {})
            try:
                strings[idx] = template.format(**props)
            except Exception:
                # Best-effort label: a missing key or bad format spec must not
                # break hovering.
                strings[idx] = ''

        state['last'] = idx
        _render()

    _render()
    points_layer.mouse_move_callbacks.append(_on_move)
    return _on_move

# base images
viewer.add_image(img_ch2, name="Ch2 raw")
viewer.add_image(img_ch1, name="Ch1 raw")

# cell mask and segmentation layers (VISUALIZATION-ONLY FILTER APPLIED)
cell_layer = viewer.add_labels(cell_mask_viz.astype(np.uint8), name='Cell (Ch2, viz)', opacity=0.35)
try:
    cell_layer.color = {1: (0.0, 1.0, 0.0, 1.0)}  # green
except Exception:
    # Purely cosmetic: if this colour assignment is rejected, keep napari's
    # default label colours.
    pass
cell_layer.blending = 'translucent_no_depth'

# Per-cell label IDs plus their outer boundaries (magenta) for orientation.
try:
    cellid_layer = viewer.add_labels(
        cell_seg_viz.astype(np.uint16),
        name='Cell ID (Ch2, viz)',
        opacity=0.25
    )
    cellid_layer.blending = 'translucent_no_depth'

    boundaries = find_boundaries(cell_seg_viz, connectivity=1, mode='outer')
    viewer.add_image(
        boundaries.astype(np.uint8),
        name='Cell ID boundaries (viz)',
        blending='additive',
        contrast_limits=(0, 1),
        colormap='magenta',
        opacity=0.6
    )
except Exception as e:
    print("Could not add cell ID / boundary layers:", e)

# Lysosome points overlay (show ONLY lysosomes inside cells) + HOVER tooltips (2D)
if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
    in_cell_mask = (df["location_ch2"].to_numpy() == "cell")
    if in_cell_mask.any():
        # Convert µm centres back to voxel indices so points overlay the images.
        blobs_cell = np.stack([
            df.loc[in_cell_mask, "z_um"].to_numpy() / vz_um,
            df.loc[in_cell_mask, "y_um"].to_numpy() / vy_um,
            df.loc[in_cell_mask, "x_um"].to_numpy() / vx_um
        ], axis=1)
        # Radius in voxels, using the geometric-mean XY pixel size.
        radii_vox = (df.loc[in_cell_mask, "radius_um"].to_numpy() / (np.sqrt(vx_um * vy_um) + 1e-12))
        pts = viewer.add_points(
            blobs_cell,
            size=np.clip(radii_vox * 2, 2, None),
            name='Lysosomes (cell only)'
        )
        try:
            pts.face_color = [0.0, 1.0, 1.0, 1.0]       # cyan
            pts.edge_color = 'black'
            pts.edge_width = 0.3

            pts.properties = {
                'lys_in_cell': df.loc[in_cell_mask, 'lys_id_in_cell'].to_numpy(),
                'cell':        df.loc[in_cell_mask, 'cell_id_ch2_viz'].to_numpy(),
                'diameter_um': df.loc[in_cell_mask, 'diameter_um'].to_numpy()
            }
            attach_hover_tooltip(
                pts, viewer,
                template='C:{cell}  ID:{lys_in_cell}  d:{diameter_um:.3f}µm',
                color='yellow', size=10, anchor='upper_left'
            )
        except Exception as e:
            print("Could not style/label 'Lysosomes (cell only)':", e)

# Show ALL lysosomes + HOVER tooltips
if len(df) > 0:
    try:
        blobs_all = np.stack([
            (df["z_um"].to_numpy() / vz_um),
            (df["y_um"].to_numpy() / vy_um),
            (df["x_um"].to_numpy() / vx_um)
        ], axis=1)
        radii_all_vox = df["radius_um"].to_numpy() / (np.sqrt(vx_um * vy_um) + 1e-12)
        pts_all = viewer.add_points(
            blobs_all,
            size=np.clip(radii_all_vox * 2, 2, None),
            name='Lysosomes (all)'
        )
        pts_all.face_color = [1.0, 1.0, 1.0, 1.0]  # white
        pts_all.edge_color = 'black'
        pts_all.edge_width = 0.3

        # Optional columns fall back to neutral values so every point still
        # gets a tooltip.
        pts_all.properties = {
            'lys_in_cell': df.get('lys_id_in_cell', pd.Series(np.zeros(len(df), dtype=int))).to_numpy(),
            'cell':        df.get('cell_id_ch2_viz', pd.Series(np.zeros(len(df), dtype=int))).to_numpy(),
            'where':       df.get('location_ch2', pd.Series(['unknown']*len(df))).to_numpy(),
            'diameter_um': df['diameter_um'].to_numpy()
        }
        attach_hover_tooltip(
            pts_all, viewer,
            template='C:{cell}  ID:{lys_in_cell}  d:{diameter_um:.3f}µm  {where}',
            color='white', size=11, anchor='upper_left'
        )
    except Exception as e:
        print("Could not add 'Lysosomes (all)' layer:", e)

# ==========================================
# 3D textured surface of the cell + lysosomes in 3D
# ==========================================
def add_textured_cell_surface(viewer, cell_mask, texture_vol,
                              iso_level=0.5, step_size=2,
                              smooth_sigma_vox=1.0,
                              opacity=0.75, colormap='turbo'):
    """
    Add an iso-surface of ``cell_mask``, coloured by ``texture_vol``, to ``viewer``.

    The mask is optionally Gaussian-smoothed and meshed with marching cubes at
    ``iso_level``. Each vertex is coloured by linearly sampling
    ``texture_vol`` at that vertex, min-max normalised to [0, 1].

    Vertices are kept in voxel-index coordinates (z, y, x). That is the same
    frame as the image, label and point layers in this script, which are all
    added without a ``scale``. Meshing in µm would shrink the surface by the
    voxel size relative to those layers.

    Parameters
    ----------
    viewer : napari.Viewer
    cell_mask : ndarray, 3D
        Cell mask; values are used as-is after a float cast (assumed 0/1).
    texture_vol : ndarray, same shape as ``cell_mask``
        Volume sampled to colour the surface.
    iso_level : float
        Marching-cubes level on the (smoothed) mask.
    step_size : int
        Marching-cubes step size; larger values give a coarser mesh.
    smooth_sigma_vox : float
        Gaussian sigma in voxels; 0/None disables smoothing.
    opacity, colormap
        Surface display settings.

    Returns
    -------
    napari Surface layer, or None when the (smoothed) mask never crosses
    ``iso_level`` (e.g. an empty mask). In that case nothing is added.
    """
    mask_f = cell_mask.astype(np.float32)
    if smooth_sigma_vox and smooth_sigma_vox > 0:
        mask_f = ndi_gaussian_filter(mask_f, smooth_sigma_vox)

    # marching_cubes raises if the level lies outside the data range; an empty
    # mask must not abort the rest of the visualisation.
    lo, hi = float(mask_f.min()), float(mask_f.max())
    if not (lo < iso_level < hi):
        print(f"Textured surface skipped: iso_level={iso_level} not inside "
              f"mask value range [{lo:.3g}, {hi:.3g}].")
        return None

    verts, faces, _normals, _vals = marching_cubes(
        volume=mask_f,
        level=iso_level,
        step_size=step_size,
        allow_degenerate=False
    )

    # Vertices are already voxel indices, so sample the texture directly.
    tex = map_coordinates(texture_vol.astype(np.float32), verts.T,
                          order=1, mode='nearest')

    tmin, tmax = float(tex.min()), float(tex.max())
    if tmax > tmin:
        tex = (tex - tmin) / (tmax - tmin)
    else:
        tex = np.zeros_like(tex, dtype=np.float32)

    surf = viewer.add_surface((verts, faces, tex), name='Cell surface (textured)')
    surf.colormap = colormap
    surf.shading = 'smooth'
    surf.opacity = opacity
    return surf

# ensure 3D
viewer.dims.ndisplay = 3

# soften Ch2 for nicer texture and add the textured surface (use viz mask)
ch2_tex = gaussian(image_2.astype(np.float32), sigma=1.0, preserve_range=True)
try:
    _ = add_textured_cell_surface(
        viewer,
        cell_mask=cell_mask_viz,
        texture_vol=ch2_tex,
        iso_level=0.5,
        step_size=2,
        smooth_sigma_vox=1.0,
        opacity=0.75,
        colormap='turbo'
    )
except Exception as e:
    # Best-effort: e.g. marching_cubes raises when the mask never crosses the
    # iso level. Keep going so the point layers and videos are still produced.
    print("Could not build textured cell surface:", e)

# emphasize lysosomes INSIDE the cell in 3D (explicit 3D layer) + HOVER tooltips
if len(df) > 0 and "location_ch2" in df:
    in_mask = (df["location_ch2"].to_numpy() == "cell")
    if in_mask.any():
        pts_zyx_um = np.stack([
            df.loc[in_mask, "z_um"].to_numpy(),
            df.loc[in_mask, "y_um"].to_numpy(),
            df.loc[in_mask, "x_um"].to_numpy()
        ], axis=1)

        # Slightly larger markers than the 2D layers so they stand out in 3D.
        radii_vox = (df.loc[in_mask, "radius_um"].to_numpy() / (np.sqrt(vx_um * vy_um) + 1e-12))
        sizes_3d = np.clip(radii_vox * 2.5, 2.0, None)

        pts3d = viewer.add_points(
            np.stack([pts_zyx_um[:, 0] / vz_um,
                      pts_zyx_um[:, 1] / vy_um,
                      pts_zyx_um[:, 2] / vx_um], axis=1),
            name='Lysosomes (inside, 3D)',
            size=sizes_3d
        )
        try:
            pts3d.face_color = [1.0, 1.0, 0.0, 1.0]   # bright yellow
            pts3d.edge_color = 'black'
            pts3d.edge_width = 0.2
        except Exception:
            # Cosmetic only; keep default point styling if rejected.
            pass

        try:
            pts3d.properties = {
                'cell_id':     df.loc[in_mask, 'cell_id_ch2_viz'].to_numpy(),
                'lys_in_cell': df.loc[in_mask, 'lys_id_in_cell'].to_numpy(),
                'diameter_um': df.loc[in_mask, 'diameter_um'].to_numpy()
            }
            attach_hover_tooltip(
                pts3d, viewer,
                template='C:{cell_id}  ID:{lys_in_cell}  d:{diameter_um:.3f}µm',
                color='white', size=10, anchor='upper_left'
            )
        except Exception as e:
            print("Could not label 'Lysosomes (inside, 3D)':", e)

# optional: set a nice camera pose
try:
    viewer.camera.zoom = 1.2
except Exception:
    pass

# ==========================================
# Optional fused 2D video (detect ALL lysosomes: inside + outside)
# Draw per-slice cross-sections for each 3D lysosome sphere
# Uses viz mask for the green overlay (small components hidden), but
# still draws ALL lysosomes regardless of location/classification.
# ==========================================
# Clip before scaling so values outside [0, 1] cannot wrap around in uint8.
img_norm_2 = (np.clip(ch2, 0.0, 1.0) * 255).astype(np.uint8)
frames_fused = []
Z = img_norm_2.shape[0]

px_um_xy = float(np.sqrt(vx_um * vy_um))  # XY voxel size
H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
min_radius_px = 3
thickness = 2

# Lysosome centres (voxel units) and radii (µm) do not depend on the slice,
# so filter and convert them once instead of once per z.
lys_zyxr = None
if 'df' in locals() and isinstance(df, pd.DataFrame) and len(df) > 0:
    dfv = df[
        np.isfinite(df["z_um"]) &
        np.isfinite(df["y_um"]) &
        np.isfinite(df["x_um"]) &
        np.isfinite(df["radius_um"])
    ]
    if not dfv.empty:
        lys_zyxr = (
            (dfv["z_um"].to_numpy() / vz_um).astype(float),
            (dfv["y_um"].to_numpy() / vy_um).astype(float),
            (dfv["x_um"].to_numpy() / vx_um).astype(float),
            dfv["radius_um"].to_numpy().astype(float),
        )

# LoG detections are used only when no usable df rows exist. Falling back per
# slice would mix two detection sources within one video.
log_blobs = blobs if ('blobs' in locals() and blobs is not None and len(blobs) > 0) else None

for z in range(Z):
    base = cv2.cvtColor(img_norm_2[z], cv2.COLOR_GRAY2BGR)

    # Use visualization mask for green overlay
    cell = (cell_mask_viz[z].astype(np.uint8) * 255)
    overlay = base.copy()
    overlay[..., 1] = np.maximum(overlay[..., 1], cell)
    overlay = cv2.addWeighted(base, 1.0, overlay, 0.35, 0.0)

    # Collect (y, x, radius_px) rings to draw on this slice.
    rings = []
    if lys_zyxr is not None:
        zc, yc, xc, r_um = lys_zyxr
        dz_um = np.abs(zc - z) * vz_um
        hits = dz_um <= r_um
        # Radius of the circle where a sphere is cut at distance dz from its centre.
        r_proj_vox = (np.sqrt(np.clip(r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
                      / max(px_um_xy, 1e-12))
        rings = list(zip(np.rint(yc[hits]).astype(int),
                         np.rint(xc[hits]).astype(int),
                         r_proj_vox))
    elif log_blobs is not None:
        z_blobs = log_blobs[np.abs(log_blobs[:, 0] - z) < 0.5]
        rings = [(int(round(bl[1])), int(round(bl[2])), bl[3]) for bl in z_blobs]

    for y, x, r_px in rings:
        rr = int(max(min_radius_px, round(r_px)))
        if 0 <= y < H and 0 <= x < W and rr > 0:
            center = (int(x), int(y))
            # Dark halo first, then the coloured ring on top for contrast.
            cv2.circle(overlay, center, rr, (0, 0, 0),
                       thickness + 2, lineType=cv2.LINE_AA)
            cv2.circle(overlay, center, rr, (255, 255, 0),
                       thickness, lineType=cv2.LINE_AA)

    cv2.putText(overlay, "FUSED (all lysosomes + viz mask)", (10, 22),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

    frames_fused.append(overlay)

# Save video (fallback to GIF if mp4 encoder unavailable)
try:
    imageio.mimsave('ch2_fused_cell.mp4', frames_fused, fps=8, format='FFMPEG')
    print("Saved: ch2_fused_cell.mp4")
except Exception as e:
    # A missing/failed ffmpeg backend does not raise only TypeError.
    print("MP4 export failed, falling back to GIF:", e)
    imageio.mimsave('ch2_fused_cell.gif', frames_fused, fps=8)
    print("Saved: ch2_fused_cell.gif")

# ==========================================
# Optional fused 2D video (RAW uses ALL image: Ch1+Ch2) with perfect sync
# ==========================================
def _norm_u8_stack(vol):
    """Linearly rescale ``vol`` onto the full uint8 range.

    The array's own minimum maps to 0 and its maximum to 255. Scaled values
    are truncated, not rounded, by the integer cast. A constant (or
    NaN-containing) array yields all zeros of the same shape.
    """
    lo = float(vol.min())
    hi = float(vol.max())
    if not hi > lo:
        return np.zeros_like(vol, dtype=np.uint8)
    unit = np.clip((vol - lo) / (hi - lo), 0, 1)
    return (unit * 255.0).astype(np.uint8)

# Per-volume normalization
ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
ch2_u8 = _norm_u8_stack(ch2.astype(np.float32) * 255.0 / max(1.0, ch2.max()))

frames_raw = []
frames_fused_all = []
frames_side_by_side = []

Z = ch2_u8.shape[0]
px_um_xy = float(np.sqrt(vx_um * vy_um))
H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
min_radius_px = 3
thickness = 2

# Lysosome centres (voxel units) and radii (µm) do not depend on the slice,
# so filter and convert them once instead of once per z.
lys_zyxr = None
if 'df' in locals() and isinstance(df, pd.DataFrame) and len(df) > 0:
    dfv = df[
        np.isfinite(df["z_um"]) &
        np.isfinite(df["y_um"]) &
        np.isfinite(df["x_um"]) &
        np.isfinite(df["radius_um"])
    ]
    if not dfv.empty:
        lys_zyxr = (
            (dfv["z_um"].to_numpy() / vz_um).astype(float),
            (dfv["y_um"].to_numpy() / vy_um).astype(float),
            (dfv["x_um"].to_numpy() / vx_um).astype(float),
            dfv["radius_um"].to_numpy().astype(float),
        )

# LoG detections are used only when no usable df rows exist. Falling back per
# slice would mix two detection sources within one video.
log_blobs = blobs if ('blobs' in locals() and blobs is not None and len(blobs) > 0) else None

for z in range(Z):
    # RAW frame (both channels): channels 0 and 2 carry Ch1, channel 1 carries
    # Ch2, so Ch1 shows as magenta and Ch2 as green.
    base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]])

    # FUSED overlay: use viz mask for context
    cell = (cell_mask_viz[z].astype(np.uint8) * 255)
    overlay = base.copy()
    overlay[..., 1] = np.maximum(overlay[..., 1], cell)
    overlay = cv2.addWeighted(base, 1.0, overlay, 0.35, 0.0)

    # Collect (y, x, radius_px) rings to draw on this slice.
    rings = []
    if lys_zyxr is not None:
        zc, yc, xc, r_um = lys_zyxr
        dz_um = np.abs(zc - z) * vz_um
        hits = dz_um <= r_um
        # Radius of the circle where a sphere is cut at distance dz from its centre.
        r_proj_vox = (np.sqrt(np.clip(r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
                      / max(px_um_xy, 1e-12))
        rings = list(zip(np.rint(yc[hits]).astype(int),
                         np.rint(xc[hits]).astype(int),
                         r_proj_vox))
    elif log_blobs is not None:
        z_blobs = log_blobs[np.abs(log_blobs[:, 0] - z) < 0.5]
        rings = [(int(round(bl[1])), int(round(bl[2])), bl[3]) for bl in z_blobs]

    for y, x, r_px in rings:
        rr = int(max(min_radius_px, round(r_px)))
        if 0 <= y < H and 0 <= x < W and rr > 0:
            center = (int(x), int(y))
            # Dark halo first, then the coloured ring on top for contrast.
            cv2.circle(overlay, center, rr, (0, 0, 0),
                       thickness + 2, lineType=cv2.LINE_AA)
            cv2.circle(overlay, center, rr, (255, 255, 0),
                       thickness, lineType=cv2.LINE_AA)

    cv2.putText(base, "RAW (Ch1+Ch2)", (10, 22),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(overlay, "FUSED (all lysosomes + viz mask)", (10, 22),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

    frames_raw.append(base)
    frames_fused_all.append(overlay)
    frames_side_by_side.append(cv2.hconcat([base, overlay]))

# --- Consistent save with identical timing for all videos ---
FPS = 8


def _pad_to_even(frame):
    """Return ``frame`` edge-padded by at most one row/column so H and W are even.

    libx264 with the yuv420p pixel format rejects odd frame dimensions, and
    ``macro_block_size=None`` (used below) disables imageio's own resizing.
    Without this padding, odd-sized frames would make the MP4 export fail.
    Frames that are already even are returned unchanged (same object).
    """
    h, w = frame.shape[:2]
    if h % 2 == 0 and w % 2 == 0:
        return frame
    pad = [(0, h % 2), (0, w % 2)] + [(0, 0)] * (frame.ndim - 2)
    return np.pad(frame, pad, mode='edge')


def _save_video_sync(basename, frames):
    """Write ``frames`` to ``<basename>.mp4`` at ``FPS``; fall back to a GIF.

    All videos share the same ``FPS`` so they play back in sync. If the MP4
    writer fails for any reason, the error is reported and
    ``<basename>.gif`` is written instead, using a per-frame duration of
    ``1 / FPS``.
    """
    try:
        with imageio.get_writer(f"{basename}.mp4",
                                fps=FPS,
                                format="FFMPEG",
                                codec="libx264",
                                macro_block_size=None) as w:
            for fr in frames:
                w.append_data(_pad_to_even(fr))
        print(f"Saved: {basename}.mp4 @ {FPS} fps")
    except Exception as e:
        print(f"MP4 export of {basename} failed ({e}); writing GIF instead.")
        imageio.mimsave(f"{basename}.gif", frames, duration=1.0/FPS)
        print(f"Saved: {basename}.gif @ {FPS} fps equivalent")

# Export the three frame sequences (same FPS, so they stay in sync), in the
# order fused -> raw -> side-by-side.
for _video_name, _video_frames in (
    ('ch2_fused_all_viz', frames_fused_all),
    ('ch2_raw', frames_raw),
    ('ch2_raw_and_fused_all_viz', frames_side_by_side),
):
    _save_video_sync(_video_name, _video_frames)

# ==========================================
# Run viewer
# ==========================================
napari.run()

  "cipher": algorithms.TripleDES,
  "class": algorithms.Blowfish,
  "class": algorithms.TripleDES,


Voxel size (µm): X=0.04, Y=0.04, Z=0.04
