# In [None]:
import os
import re
import numpy as np
import pandas as pd
import cv2
import imageio
import xml.etree.ElementTree as ET
from datetime import datetime

import czifile
import tifffile as tiff

from skimage.feature import blob_log
from skimage.filters import gaussian, threshold_local
from skimage.morphology import (
    remove_small_objects, binary_opening, binary_closing, ball, binary_erosion
)
from scipy.ndimage import distance_transform_edt as edt
from skimage.measure import label
from skimage.segmentation import watershed
from scipy.ndimage import binary_fill_holes

import napari
from scipy.ndimage import gaussian_filter1d
import colorsys

# ===============================
# USER MODE
# ===============================
# NEURITE_MODE selects the segmentation strategy:
#   True  -> neurites-only datasets (no soma); recommended for neurite images.
#   False -> soma datasets; uses the original soma/watershed logic.
# NOTE(review): not referenced in the code visible here; presumably read by the
# segmentation stage further down the script.
NEURITE_MODE = True

# Export the full-size overlay RGB stack + mp4 (recommended).
# NOTE(review): presumably gates a call to export_fullsize_overlay_stack later on.
EXPORT_FULLSIZE_OVERLAY = True

# Lysosome marker colors passed to cv2.circle. OpenCV nominally treats color
# tuples as BGR, but export_fullsize_overlay_stack saves its frames as RGB.
# Both colors below have B == R, so the channel ordering does not change the
# rendered result.
LYS_EDGE_BGR = (0, 0, 0)          # black outline
LYS_MAGENTA_BGR = (255, 0, 255)   # magenta

# ===============================
# GUI (single unified dialog)
# The "Advanced" panel exposes ONLY these four settings:
#   - MARGIN_UM (µm)
#   - OVERLAP_ALPHA (0..1)
#   - NEIGHBOR_MAX_VOX (voxels)
#   - VIZ_MIN_VOXELS (voxels)
# Every other tunable is a fixed default passed through unchanged.
# ===============================
import tkinter as tk
from tkinter import ttk, filedialog, messagebox


def get_user_config_gui(
    # basic defaults
    default_vxy_um=0.04,
    default_vz_um=None,          # if None and Z missing -> fallback to XY
    default_erode_mult=1.0,
    default_blob_threshold=0.001,

    # advanced defaults (the only four shown in the "Advanced" panel)
    default_margin_um=0.5,          # µm (neurite-friendly default)
    default_overlap_alpha=0.4,      # unitless (0..1)
    default_neighbor_max_vox=6,     # voxels
    default_viz_min_voxels=200,     # voxels (neurite-friendly default)

    # fixed defaults (not shown in GUI)
    default_max_reasonable_vxy_um=0.5,
    default_ch1_smooth_sigma=1.0,
    default_blob_min_sigma=0.8,
    default_blob_max_sigma=3.0,
    default_blob_num_sigma=10,
    default_radial_max_radius_nm=300.0,
    default_radial_dr_nm=10.0,
    default_radial_min_drop_fraction=0.5,

    # neurite-friendly threshold defaults (fixed, not shown)
    default_ch2_smooth_sigma=0.9,
    default_thresh_block_size=151,
    default_thresh_offset_std_mult=0.25,

    default_video_fps=8,
    default_launch_viewer=True,
    default_generate_videos=True,
):
    """Show a Tk dialog that collects the run configuration and return it.

    The dialog asks for the input image, the output folder, ERODE_MULT, the
    blob_log threshold, optional XY/Z voxel-size overrides, the viewer/video
    toggles, the video FPS and four "advanced" settings. Every other
    ``default_*`` argument is copied into the result unchanged.

    Invalid input is reported in an error box and the dialog stays open. The
    checks are:
      - numbers must parse and be finite;
      - voxel-size overrides and FPS must be > 0;
      - MARGIN_UM, NEIGHBOR_MAX_VOX and VIZ_MIN_VOXELS must be >= 0;
      - OVERLAP_ALPHA must be in [0, 1].

    Returns:
        dict: ``"ok": True``, ``"file_path"``, ``"output_dir"`` plus the
        UPPER_CASE configuration keys consumed by the main script.

    Raises:
        SystemExit: if the window is closed / cancelled without a valid run.
    """
    import math

    cfg = {"ok": False}

    class _InvalidInput(ValueError):
        """Raised by _err after the user has already seen an error box."""

    root = tk.Tk()
    root.title("Lysosome + Neurite Segmentation (GUI)")
    root.resizable(False, False)

    file_var = tk.StringVar(value="")
    out_var = tk.StringVar(value="")

    erode_var = tk.StringVar(value=str(default_erode_mult))
    blob_var = tk.StringVar(value=str(default_blob_threshold))

    vxy_override_var = tk.StringVar(value="")  # blank = use metadata/default
    vz_override_var = tk.StringVar(value="")   # blank = use metadata/default

    show_adv = tk.BooleanVar(value=False)

    margin_var = tk.StringVar(value=str(default_margin_um))
    overlap_var = tk.StringVar(value=str(default_overlap_alpha))
    neighbor_var = tk.StringVar(value=str(default_neighbor_max_vox))
    vizmin_var = tk.StringVar(value=str(default_viz_min_voxels))

    fps_var = tk.StringVar(value=str(default_video_fps))

    launch_viewer_var = tk.BooleanVar(value=bool(default_launch_viewer))
    gen_videos_var = tk.BooleanVar(value=bool(default_generate_videos))

    def _suggest_output_dir(fp):
        # Timestamped sibling folder next to the input file.
        if not fp:
            return ""
        raw_dir = os.path.dirname(fp)
        raw_base = os.path.splitext(os.path.basename(fp))[0]
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return os.path.join(raw_dir, f"{raw_base}_outputs_{stamp}")

    def browse_file():
        fp = filedialog.askopenfilename(
            title="Select image file",
            filetypes=[("Image files", "*.tif *.tiff *.czi"), ("All files", "*.*")],
        )
        if fp:
            file_var.set(fp)
            if not out_var.get().strip():
                out_var.set(_suggest_output_dir(fp))

    def browse_output_dir():
        d = filedialog.askdirectory(title="Select output folder")
        if d:
            fp = file_var.get().strip()
            if fp:
                # Create a timestamped sub-folder inside the chosen directory.
                raw_base = os.path.splitext(os.path.basename(fp))[0]
                stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                out_var.set(os.path.join(d, f"{raw_base}_outputs_{stamp}"))
            else:
                out_var.set(d)

    def _err(msg):
        # Show the problem, then abort the current "Run" attempt; run_clicked
        # catches _InvalidInput so the dialog stays open for correction.
        messagebox.showerror("Invalid input", msg)
        raise _InvalidInput(msg)

    def _parse_float(s, name):
        try:
            val = float(s.replace(",", "."))
        except ValueError:
            _err(f"{name} must be a number (got: {s})")
        if not math.isfinite(val):
            _err(f"{name} must be a finite number (got: {s})")
        return val

    def _float_required(s, name):
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        return _parse_float(s, name)

    def _float_optional(s, name):
        s = (s or "").strip()
        if s == "":
            return None
        return _parse_float(s, name)

    def _int_required(s, name):
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        try:
            # Accept "6", "6.0" or "6,0"; fractional input is truncated.
            return int(float(s.replace(",", ".")))
        except (ValueError, OverflowError):
            # ValueError: not a number / NaN; OverflowError: +-inf
            _err(f"{name} must be an integer (got: {s})")

    def _toggle_adv():
        if show_adv.get():
            adv_frame.grid()
        else:
            adv_frame.grid_remove()

    def _collect_config():
        fp = file_var.get().strip()
        if not fp:
            _err("Please select a file.")
        if not os.path.isfile(fp):
            _err("Selected file does not exist.")

        outd = out_var.get().strip() or _suggest_output_dir(fp)

        erode = _float_required(erode_var.get(), "ERODE_MULT")
        blobt = _float_required(blob_var.get(), "blob_log threshold")
        vxy_ov = _float_optional(vxy_override_var.get(), "XY override (µm/px)")
        vz_ov = _float_optional(vz_override_var.get(), "Z override (µm/slice)")
        fps = _int_required(fps_var.get(), "video FPS")

        margin = _float_required(margin_var.get(), "MARGIN_UM (µm)")
        overlap = _float_required(overlap_var.get(), "OVERLAP_ALPHA (0..1)")
        neigh = _int_required(neighbor_var.get(), "NEIGHBOR_MAX_VOX (voxels)")
        vizmin = _int_required(vizmin_var.get(), "VIZ_MIN_VOXELS (voxels)")

        if not (0.0 <= overlap <= 1.0):
            _err("OVERLAP_ALPHA must be between 0 and 1.")
        # Voxel sizes are used as divisors downstream; 0 or negative values
        # would crash or silently invert the geometry.
        if vxy_ov is not None and vxy_ov <= 0:
            _err("XY override (µm/px) must be > 0.")
        if vz_ov is not None and vz_ov <= 0:
            _err("Z override (µm/slice) must be > 0.")
        if fps <= 0:
            _err("video FPS must be > 0.")
        if margin < 0:
            _err("MARGIN_UM must be >= 0.")
        if neigh < 0:
            _err("NEIGHBOR_MAX_VOX must be >= 0.")
        if vizmin < 0:
            _err("VIZ_MIN_VOXELS must be >= 0.")

        cfg.update({
            "ok": True,
            "file_path": fp,
            "output_dir": outd,

            "ERODE_MULT": float(erode),
            "BLOB_THRESHOLD": float(blobt),
            "vxy_override": None if vxy_ov is None else float(vxy_ov),
            "vz_override": None if vz_ov is None else float(vz_ov),

            "DEFAULT_VX_VY_UM": float(default_vxy_um),
            "DEFAULT_VZ_UM": None if default_vz_um is None else float(default_vz_um),

            "MAX_REASONABLE_VXY_UM": float(default_max_reasonable_vxy_um),

            # Advanced only
            "MARGIN_UM": float(margin),
            "OVERLAP_ALPHA": float(overlap),
            "NEIGHBOR_MAX_VOX": int(neigh),
            "VIZ_MIN_VOXELS": int(vizmin),

            # Fixed defaults (hidden)
            "CH1_SMOOTH_SIGMA": float(default_ch1_smooth_sigma),
            "BLOB_MIN_SIGMA": float(default_blob_min_sigma),
            "BLOB_MAX_SIGMA": float(default_blob_max_sigma),
            "BLOB_NUM_SIGMA": int(default_blob_num_sigma),

            "RADIAL_MAX_RADIUS_NM": float(default_radial_max_radius_nm),
            "RADIAL_DR_NM": float(default_radial_dr_nm),
            "RADIAL_MIN_DROP_FRACTION": float(default_radial_min_drop_fraction),

            "CH2_SMOOTH_SIGMA": float(default_ch2_smooth_sigma),
            "THRESH_BLOCK_SIZE": int(default_thresh_block_size),
            "THRESH_OFFSET_STD_MULT": float(default_thresh_offset_std_mult),

            "VIDEO_FPS": int(fps),
            "LAUNCH_VIEWER": bool(launch_viewer_var.get()),
            "GENERATE_VIDEOS": bool(gen_videos_var.get()),
        })

        root.destroy()

    def run_clicked():
        try:
            _collect_config()
        except _InvalidInput:
            # The error box has already been shown; keep the dialog open.
            return

    def cancel_clicked():
        root.destroy()

    root.protocol("WM_DELETE_WINDOW", cancel_clicked)

    pad = {"padx": 10, "pady": 6}
    frm = ttk.Frame(root)
    frm.grid(row=0, column=0, sticky="nsew", **pad)

    r = 0
    ttk.Label(frm, text="Image file:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=file_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_file).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Label(frm, text="Output folder:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=out_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_output_dir).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Label(frm, text="ERODE_MULT:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=erode_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="blob_log threshold:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=blob_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Label(frm, text="XY override (µm/px):").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=vxy_override_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="blank = metadata; if missing uses default").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="Z override (µm/slice):").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=vz_override_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="blank = metadata; if missing uses fallback").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Launch Napari viewer", variable=launch_viewer_var)\
        .grid(row=r, column=0, columnspan=2, sticky="w")
    ttk.Checkbutton(frm, text="Generate videos", variable=gen_videos_var)\
        .grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="Video FPS:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=fps_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="frames/sec").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Show advanced settings", variable=show_adv, command=_toggle_adv)\
        .grid(row=r, column=0, columnspan=3, sticky="w")
    r += 1

    # Advanced panel: gridded once so it remembers its slot, then hidden until
    # the checkbox above toggles it back in via grid().
    adv_frame = ttk.LabelFrame(frm, text="Advanced")
    adv_frame.grid(row=r, column=0, columnspan=3, sticky="we", pady=6)
    adv_frame.grid_remove()

    rr = 0

    def add_row(label_txt, var, hint):
        nonlocal rr
        ttk.Label(adv_frame, text=label_txt).grid(row=rr, column=0, sticky="w", padx=8, pady=3)
        ttk.Entry(adv_frame, textvariable=var, width=18).grid(row=rr, column=1, sticky="w", padx=8, pady=3)
        ttk.Label(adv_frame, text=hint).grid(row=rr, column=2, sticky="w", padx=8, pady=3)
        rr += 1

    add_row("MARGIN_UM:", margin_var, "µm (soft band around mask)")
    add_row("OVERLAP_ALPHA:", overlap_var, "0..1 (sphere overlap fraction)")
    add_row("NEIGHBOR_MAX_VOX:", neighbor_var, "voxels (ID search radius)")
    add_row("VIZ_MIN_VOXELS:", vizmin_var, "voxels (hide small components)")

    r += 1
    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    btns = ttk.Frame(frm)
    btns.grid(row=r, column=0, columnspan=3, sticky="e")
    ttk.Button(btns, text="Cancel", command=cancel_clicked).grid(row=0, column=0, padx=6)
    ttk.Button(btns, text="Run", command=run_clicked).grid(row=0, column=1, padx=6)

    root.mainloop()

    if not cfg.get("ok"):
        raise SystemExit("Cancelled.")
    return cfg


# ===============================
# Metadata parsing
# ===============================
def _parse_ome_xml(xml_text):
    if not xml_text:
        return None, None, None

    def grab(attr):
        m = re.search(fr'PhysicalSize{attr}="([\d\.eE+-]+)"', xml_text)
        return float(m.group(1)) if m else None

    return grab("X"), grab("Y"), grab("Z")


def _parse_czi_scaling(czi_text):
    """
    Robust CZI scaling parser:
    - Works when values are nested (<Value>...</Value>) or attributes
    - Works with namespaces
    - Converts meters -> µm (typical CZI)
    """
    if not czi_text:
        return None, None, None

    if isinstance(czi_text, (bytes, bytearray)):
        czi_text = czi_text.decode("utf-8", errors="ignore")

    czi_text = czi_text.replace("\x00", "")

    def _to_float(s):
        if s is None:
            return None
        try:
            return float(str(s).strip().replace(",", "."))
        except Exception:
            return None

    def _to_um(val, unit_hint=None):
        if val is None:
            return None
        if unit_hint:
            u = str(unit_hint).strip().lower()
            if u in ("m", "meter", "metre", "meters", "metres"):
                return val * 1e6
            if u in ("µm", "um", "micron", "microns", "micrometer", "micrometre"):
                return val
            if u in ("nm", "nanometer", "nanometre", "nanometers", "nanometres"):
                return val / 1000.0

        # heuristic: CZI values often in meters (~1e-8 to 1e-6)
        if val < 1e-3:
            return val * 1e6
        # sometimes already in µm
        if val < 10:
            return val
        # sometimes in nm
        if val < 1e5:
            return val / 1000.0
        return None

    try:
        root = ET.fromstring(czi_text)
    except Exception:
        def _grab(axis):
            mm = re.search(
                rf'<Distance[^>]*Id="{axis}"[^>]*>.*?<Value>\s*([0-9eE\+\-\.]+)\s*</Value>',
                czi_text,
                flags=re.IGNORECASE | re.DOTALL,
            )
            return _to_float(mm.group(1)) if mm else None

        return _to_um(_grab("X")), _to_um(_grab("Y")), _to_um(_grab("Z"))

    sx = sy = sz = None
    for d in root.findall(".//{*}Distance"):
        axis = d.attrib.get("Id") or d.attrib.get("id") or d.attrib.get("Axis") or d.attrib.get("axis")
        if not axis:
            continue
        axis = axis.upper()
        unit = d.attrib.get("Unit") or d.attrib.get("unit")

        valf = _to_float(d.attrib.get("Value") or d.attrib.get("value"))

        if valf is None:
            v_el = d.find(".//{*}Value")
            if v_el is not None and v_el.text:
                valf = _to_float(v_el.text)

        if valf is None:
            for child in d.iter():
                if child is d:
                    continue
                if str(child.tag).lower().endswith("value"):
                    valf = _to_float(child.attrib.get("Value") or child.attrib.get("value")) or _to_float(child.text)
                    if valf is not None:
                        break

        val_um = _to_um(valf, unit_hint=unit)
        if val_um is None:
            continue

        if axis == "X":
            sx = val_um
        elif axis == "Y":
            sy = val_um
        elif axis == "Z":
            sz = val_um

    return sx, sy, sz


def _split_two_channels(img, source_label):
    """Split a squeezed 4-D array into two 3-D channel volumes.

    The channel axis is taken to be the first of axis 0, axis 1 or the last
    axis whose length is 2, checked in that order.

    Raises:
        RuntimeError: if ``img`` is not 4-D or no candidate axis has length 2.
    """
    if img.ndim != 4:
        raise RuntimeError(f"Unexpected {source_label} shape: {img.shape}")
    if img.shape[0] == 2:
        return img[0], img[1]
    if img.shape[1] == 2:
        return img[:, 0], img[:, 1]
    if img.shape[-1] == 2:
        return img[..., 0], img[..., 1]
    raise RuntimeError(f"Unexpected {source_label} shape for 2 channels: {img.shape}")


def load_any(file_path):
    """Load a two-channel 3-D stack and its voxel size from a TIFF or CZI file.

    Singleton axes are squeezed away and the remaining 4-D array is split
    into two channel volumes by ``_split_two_channels``. Voxel sizes come
    from OME-XML (TIFF) or the CZI Scaling metadata; missing values are None.

    Returns:
        tuple: (ch1, ch2, (vx_um, vy_um, vz_um), {"type": "tiff" | "czi"})

    Raises:
        ValueError: for an unsupported file extension.
        RuntimeError: if the squeezed array is not a recognisable 2-channel 4-D stack.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext in (".tif", ".tiff"):
        with tiff.TiffFile(file_path) as tf:
            arr = tf.asarray()
            try:
                ome_xml = tf.ome_metadata
            except Exception:
                # Metadata is best-effort; fall back to overrides/defaults.
                ome_xml = None
        vx_um, vy_um, vz_um = _parse_ome_xml(ome_xml) if ome_xml else (None, None, None)
        ch1, ch2 = _split_two_channels(np.squeeze(arr), "TIFF")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "tiff"}

    if ext == ".czi":
        with czifile.CziFile(file_path) as cf:
            arr = cf.asarray()
            try:
                czi_xml = cf.metadata()
            except Exception:
                # Metadata is best-effort; fall back to overrides/defaults.
                czi_xml = None
        vx_um, vy_um, vz_um = _parse_czi_scaling(czi_xml) if czi_xml else (None, None, None)
        ch1, ch2 = _split_two_channels(np.squeeze(arr), "CZI")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "czi"}

    raise ValueError("Unsupported file format")


# ===============================
# Helper: refine radii by distance transform
# ===============================
def refine_radii_via_dt(img3d, blobs, win_px=40, bin_method="sauvola"):
    """Re-estimate each blob's in-plane radius from a local binary mask.

    For each blob, a square window of half-width ``win_px`` is cut around its
    centre in the blob's own z-slice. The window is binarised ("sauvola",
    "local", or Otsu for any other ``bin_method``), opened with a radius-1
    disk, and stripped of components smaller than 3 px. The distance
    transform of the component under the centre, evaluated at the centre,
    becomes the new radius in pixels.

    A blob keeps its incoming radius if any of these holds:
      - its centre is out of bounds;
      - the centre lands on background;
      - Otsu raises ValueError;
      - the distance is not positive.

    Returns:
        float32 copy of ``blobs`` with refined radii, or ``blobs`` itself when
        it is None or empty.
    """
    from skimage.filters import threshold_sauvola, threshold_local, threshold_otsu
    from skimage.morphology import remove_small_objects, disk, binary_opening
    from scipy.ndimage import distance_transform_edt as _edt
    from skimage.measure import label as _label

    if blobs is None or len(blobs) == 0:
        return blobs

    n_z, n_y, n_x = img3d.shape
    refined = blobs.copy().astype(np.float32)
    # Both local thresholders need an odd window of at least 11 px.
    odd_window = max(11, 2 * (win_px // 2) + 1)

    for idx, row in enumerate(refined):
        zc, yc, xc, _ = row
        cz, cy, cx = int(round(zc)), int(round(yc)), int(round(xc))
        if not (0 <= cz < n_z and 0 <= cy < n_y and 0 <= cx < n_x):
            continue

        top, bottom = max(0, cy - win_px), min(n_y, cy + win_px + 1)
        left, right = max(0, cx - win_px), min(n_x, cx + win_px + 1)
        window = img3d[cz, top:bottom, left:right]

        if bin_method == "sauvola":
            fg = window > threshold_sauvola(window, window_size=odd_window, k=0.4)
        elif bin_method == "local":
            fg = window > threshold_local(
                window, block_size=odd_window, offset=-0.4 * np.std(window)
            )
        else:
            try:
                fg = window > threshold_otsu(window)
            except ValueError:
                continue

        fg = binary_opening(fg, footprint=disk(1))
        fg = remove_small_objects(fg, min_size=3, connectivity=2)

        ly, lx = cy - top, cx - left
        in_window = 0 <= ly < fg.shape[0] and 0 <= lx < fg.shape[1]
        if not in_window or not fg[ly, lx]:
            continue

        components = _label(fg)
        comp_id = components[ly, lx]
        if comp_id == 0:
            continue

        depth_px = float(_edt(components == comp_id)[ly, lx])
        if depth_px > 0:
            refined[idx, 3] = depth_px

    return refined


def refine_radii_via_radial_intensity(
    img3d,
    blobs,
    vx_um,
    vy_um,
    vz_um,
    max_radius_nm=300.0,
    dr_nm=10.0,
    min_drop_fraction=0.5,
):
    """Possibly enlarge blob radii using a radially averaged intensity profile.

    Around each in-bounds blob centre, voxels within ``max_radius_nm``
    (physical distance, using the per-axis voxel sizes) are binned into
    shells of width ``dr_nm``. The mean shell intensities are smoothed with
    a 1-D Gaussian (sigma = 1 bin).

    If the smoothed profile drops by at least ``min_drop_fraction`` of its
    maximum, the half-maximum extent around the peak is measured. Half of
    that extent, converted to XY pixels, is compared with the current
    radius, and the larger value is kept.

    NOTE(review): the profile is one-sided (radius from the centre), so
    halving the half-max extent yields roughly half the half-max radius when
    the peak is at the centre — confirm this is the intended estimate.

    Returns:
        float32 copy of ``blobs`` with updated radii, or ``blobs`` itself
        when it is None or empty.
    """
    if blobs is None or len(blobs) == 0:
        return blobs

    vol = img3d.astype(np.float32)
    depth, height, width = vol.shape
    result = blobs.copy().astype(np.float32)

    reach_um = max_radius_nm / 1000.0
    step_um = dr_nm / 1000.0

    edges = np.arange(0.0, reach_um + step_um, step_um)
    if edges.size < 2:
        return result
    centers = 0.5 * (edges[:-1] + edges[1:])
    n_shells = centers.size

    pixel_um = float(np.sqrt(vx_um * vy_um))

    for k, (zc, yc, xc, r_init) in enumerate(result):
        cz, cy, cx = int(round(zc)), int(round(yc)), int(round(xc))
        if not (0 <= cz < depth and 0 <= cy < height and 0 <= cx < width):
            continue

        hz, hy, hx = (max(1, int(np.ceil(reach_um / v))) for v in (vz_um, vy_um, vx_um))
        z_lo, z_hi = max(0, cz - hz), min(depth, cz + hz + 1)
        y_lo, y_hi = max(0, cy - hy), min(height, cy + hy + 1)
        x_lo, x_hi = max(0, cx - hx), min(width, cx + hx + 1)
        if z_lo >= z_hi or y_lo >= y_hi or x_lo >= x_hi:
            continue

        local = vol[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi]

        # Open grids broadcast to the full window shape.
        oz, oy, ox = np.ogrid[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi]
        off_z = (oz - cz) * vz_um
        off_y = (oy - cy) * vy_um
        off_x = (ox - cx) * vx_um
        dist = np.sqrt(off_z**2 + off_y**2 + off_x**2)

        within = dist <= reach_um
        if not within.any():
            continue

        shell = np.digitize(dist[within].ravel(), edges) - 1
        values = local[within].ravel()
        ok = (shell >= 0) & (shell < n_shells)
        if not ok.any():
            continue
        shell = shell[ok]
        values = values[ok]

        totals = np.bincount(shell, weights=values, minlength=n_shells)
        hits = np.bincount(shell, minlength=n_shells)
        with np.errstate(invalid="ignore", divide="ignore"):
            profile = totals / np.maximum(hits, 1)

        filled = hits > 0
        if not filled.any():
            continue

        prof_r = centers[filled]
        smooth = gaussian_filter1d(profile[filled].astype(np.float32), sigma=1.0)

        top = float(smooth.max())
        if top <= 0:
            continue
        bottom = float(smooth.min())
        if (top - bottom) / max(top, 1e-9) < min_drop_fraction:
            continue  # no clear edge: keep the incoming radius

        half = bottom + 0.5 * (top - bottom)
        peak = int(np.argmax(smooth))
        last = len(smooth) - 1

        # Walk outward from the peak to the innermost bins still >= half-max.
        lo = peak
        while lo > 0 and smooth[lo] >= half:
            lo -= 1
        if lo < peak and smooth[lo] < half:
            lo += 1

        hi = peak
        while hi < last and smooth[hi] >= half:
            hi += 1
        if hi > peak and smooth[hi] < half:
            hi -= 1

        if hi <= lo:
            continue

        half_width_um = 0.5 * (float(prof_r[hi]) - float(prof_r[lo]))
        if half_width_um <= 0:
            continue

        candidate_px = half_width_um / max(pixel_um, 1e-9)
        result[k, 3] = max(float(r_init), float(candidate_px))

    return result


# ===============================
# Distance/overlap-aware classification helpers
# ===============================
from scipy.ndimage import distance_transform_edt as _edt


def nearest_cell_label(cell_seg, z, y, x, max_r=12):
    """Return the dominant non-zero label in the smallest cube around (z, y, x) holding one.

    Cubes of half-width 1, 2, ..., ``max_r`` (clipped to the volume) are
    searched in order. The first cube containing any non-zero voxel decides
    the result: its most frequent non-zero label, with ties going to the
    smaller label via ``np.bincount(...).argmax()``.

    Returns 0 when no labelled voxel lies within ``max_r`` voxels (Chebyshev
    distance) of the point.
    """
    depth, height, width = cell_seg.shape
    r = 1
    while r <= max_r:
        window = cell_seg[
            max(0, z - r):min(depth, z + r + 1),
            max(0, y - r):min(height, y + r + 1),
            max(0, x - r):min(width, x + r + 1),
        ]
        labelled = window[window > 0]
        if labelled.size:
            return int(np.bincount(labelled.ravel()).argmax())
        r += 1
    return 0


def sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, mask_bool, vx_um, vy_um, vz_um):
    """Fraction of a voxelised sphere that lies inside ``mask_bool``.

    The centre (given in µm) is rounded to the nearest voxel. A voxel belongs
    to the sphere when its physical distance to that centre voxel is
    <= ``r_um``. Only voxels inside the volume are considered.

    Returns 0.0 for a non-positive radius, an empty clipped window, or a
    sphere with no in-volume voxels.
    """
    if r_um <= 0:
        return 0.0

    spacing = (vz_um, vy_um, vx_um)
    centre = [int(round(c / v)) for c, v in zip((zc_um, yc_um, xc_um), spacing)]
    half = [max(1, int(np.ceil(r_um / v))) for v in spacing]

    n_z, n_y, n_x = mask_bool.shape
    bounds = [
        (max(0, c - h), min(n, c + h + 1))
        for c, h, n in zip(centre, half, (n_z, n_y, n_x))
    ]
    if any(lo >= hi for lo, hi in bounds):
        return 0.0

    window = tuple(slice(lo, hi) for lo, hi in bounds)
    gz, gy, gx = np.ogrid[window]
    dz = (gz - centre[0]) * vz_um
    dy = (gy - centre[1]) * vy_um
    dx = (gx - centre[2]) * vx_um
    inside_sphere = (dz * dz + dy * dy + dx * dx) <= (r_um * r_um)

    n_sphere = int(np.count_nonzero(inside_sphere))
    if n_sphere == 0:
        return 0.0

    n_overlap = int(np.count_nonzero(mask_bool[window] & inside_sphere))
    return float(n_overlap) / float(n_sphere)


# ===============================
# Full-size overlay exporter (RGB stack + MP4)
# ===============================
def _norm_u8_stack(vol):
    vmin, vmax = float(vol.min()), float(vol.max())
    if vmax > vmin:
        out = (np.clip((vol - vmin) / (vmax - vmin), 0, 1) * 255.0).astype(np.uint8)
    else:
        out = np.zeros_like(vol, dtype=np.uint8)
    return out


def make_label_colormap(n_labels, seed_hue=0.13):
    """Build an (n_labels + 1, 3) uint8 RGB lookup table for label images.

    Row 0 (background) stays black. Label i (1-based) gets a fully
    saturated, full-value HSV colour whose hue starts at ``seed_hue`` and
    advances by 1/n_labels per label, wrapping around at 1.0.
    """
    table = np.zeros((n_labels + 1, 3), dtype=np.uint8)
    if n_labels <= 0:
        return table

    denom = max(n_labels, 1)
    for lab in range(1, n_labels + 1):
        hue = (seed_hue + (lab - 1) / denom) % 1.0
        rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
        table[lab] = tuple(int(255 * channel) for channel in rgb)
    return table


def export_fullsize_overlay_stack(
    img_ch1,
    img_ch2_raw,
    cell_seg_viz,
    df,
    vx_um, vy_um, vz_um,
    output_dir,
    alpha_labels=0.45,
    draw_only_inside=True,
    fps=8,
    basename="FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
):
    """Render a full-resolution RGB overlay per z-slice and save it as TIFF + video.

    Each frame is built in three steps:
      1. A composite background with channels (Ch1, Ch2, Ch1), each min-max
         scaled to uint8, so Ch1 shows as magenta and Ch2 as green.
      2. Cell labels from ``cell_seg_viz`` alpha-blended on top with
         ``alpha_labels``. This step is skipped when ``cell_seg_viz`` is not
         a numpy array.
      3. Lysosomes from ``df`` drawn as outlined magenta circles. A circle is
         drawn on every slice its sphere intersects, with the projected
         radius at that slice and a minimum of 3 px.

    When ``draw_only_inside`` is set and ``df`` has a ``location_ch2``
    column, only rows equal to "cell" are drawn. Rows with non-finite
    coordinates or radii are skipped.

    Returns:
        tuple: (tiff_path, video_path). ``video_path`` is the MP4 when FFMPEG
        succeeded, otherwise the GIF written as a fallback.
    """
    os.makedirs(output_dir, exist_ok=True)

    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(img_ch2_raw.astype(np.float32))

    Z, H, W = ch2_u8.shape
    has_labels = isinstance(cell_seg_viz, np.ndarray)
    n_labels = int(cell_seg_viz.max()) if has_labels else 0
    cmap = make_label_colormap(n_labels, seed_hue=0.13)

    use_df = None
    if isinstance(df, pd.DataFrame) and len(df) > 0 and {"z_um", "y_um", "x_um", "radius_um"}.issubset(df.columns):
        if draw_only_inside and "location_ch2" in df.columns:
            use_df = df[df["location_ch2"] == "cell"].copy()
        else:
            use_df = df.copy()
        use_df = use_df[
            np.isfinite(use_df["z_um"]) &
            np.isfinite(use_df["y_um"]) &
            np.isfinite(use_df["x_um"]) &
            np.isfinite(use_df["radius_um"])
        ].copy()

    px_um_xy = float(np.sqrt(vx_um * vy_um))

    # Lysosome geometry does not depend on the slice: convert once, outside the loop.
    lys_zc = lys_r_um = lys_ys = lys_xs = None
    if use_df is not None and len(use_df) > 0:
        lys_zc = (use_df["z_um"].to_numpy() / vz_um).astype(float)
        yc = (use_df["y_um"].to_numpy() / vy_um).astype(float)
        xc = (use_df["x_um"].to_numpy() / vx_um).astype(float)
        lys_r_um = use_df["radius_um"].to_numpy().astype(float)
        lys_ys = np.rint(yc).astype(int)
        lys_xs = np.rint(xc).astype(int)

    frames = np.zeros((Z, H, W, 3), dtype=np.uint8)

    for z in range(Z):
        # Background: channels (Ch1, Ch2, Ch1) -> Ch1 magenta, Ch2 green.
        base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]]).astype(np.float32)

        # Colored IDs
        if has_labels:
            lab2d = cell_seg_viz[z].astype(np.int32)
            lab_rgb = cmap[lab2d].astype(np.float32)
            mask = (lab2d > 0)[..., None].astype(np.float32)
            out = base * (1.0 - alpha_labels * mask) + lab_rgb * (alpha_labels * mask)
        else:
            out = base

        # Lysosomes (MAGENTA): spheres intersecting this slice, drawn at their
        # projected cross-section radius.
        if lys_zc is not None:
            dz_um = np.abs(lys_zc - z) * vz_um
            hits = dz_um <= lys_r_um
            if np.any(hits):
                r_proj_um = np.sqrt(np.clip(lys_r_um[hits] ** 2 - dz_um[hits] ** 2, 0.0, None))
                r_proj_px = r_proj_um / max(px_um_xy, 1e-12)

                out_u8 = np.clip(out, 0, 255).astype(np.uint8)
                for y, x, rp in zip(lys_ys[hits], lys_xs[hits], r_proj_px):
                    if not (0 <= y < H and 0 <= x < W):
                        continue
                    center = (int(x), int(y))  # cv2 expects plain Python ints
                    radius = int(max(3, round(rp)))  # keep tiny lysosomes visible
                    cv2.circle(out_u8, center, radius, LYS_EDGE_BGR, 4, lineType=cv2.LINE_AA)
                    cv2.circle(out_u8, center, radius, LYS_MAGENTA_BGR, 2, lineType=cv2.LINE_AA)
                    cv2.circle(out_u8, center, 1, LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)
                out = out_u8.astype(np.float32)

        frames[z] = np.clip(out, 0, 255).astype(np.uint8)

    # Save RGB TIFF stack
    tiff_path = os.path.join(output_dir, f"{basename}.tif")
    tiff.imwrite(tiff_path, frames, photometric="rgb")
    print("Saved full-size RGB TIFF stack:", tiff_path)

    # Save MP4 (fallback GIF)
    mp4_path = os.path.join(output_dir, f"{basename}.mp4")
    video_path = mp4_path
    try:
        with imageio.get_writer(
            mp4_path,
            fps=int(fps),
            format="FFMPEG",
            codec="libx264",
            macro_block_size=None
        ) as w:
            for fr in frames:
                w.append_data(fr)
        print("Saved full-size MP4:", mp4_path)
    except Exception as e:
        # FFMPEG / libx264 may be missing or fail mid-write. Remove any partial
        # MP4 so it is not mistaken for a valid output, then write a GIF.
        if os.path.exists(mp4_path):
            try:
                os.remove(mp4_path)
            except OSError:
                pass  # best-effort cleanup only
        gif_path = os.path.join(output_dir, f"{basename}.gif")
        imageio.mimsave(gif_path, list(frames), fps=int(fps))
        print("FFMPEG failed, saved GIF instead:", gif_path, "Error:", e)
        video_path = gif_path

    return tiff_path, video_path


# ===============================
# MAIN
# ===============================
DEFAULT_VX_VY_UM = 0.04
DEFAULT_VZ_UM = None  # if None and Z missing -> fallback to XY

cfg = get_user_config_gui(
    default_vxy_um=DEFAULT_VX_VY_UM,
    default_vz_um=DEFAULT_VZ_UM,
    default_erode_mult=1.0,
    default_blob_threshold=0.001,
)

# Resolve both paths to absolute BEFORE changing directory. Otherwise a
# relative path typed into the dialog would be re-resolved against
# output_dir, and the input file would no longer be found.
file_path = os.path.abspath(cfg["file_path"])
output_dir = os.path.abspath(cfg["output_dir"])
os.makedirs(output_dir, exist_ok=True)
os.chdir(output_dir)  # later relative writes (e.g. CSVs) land in output_dir

print("Selected file:", file_path)
print("Outputs will be saved to:", output_dir)

ERODE_MULT = cfg["ERODE_MULT"]
BLOB_THRESHOLD = cfg["BLOB_THRESHOLD"]
vxy_override = cfg["vxy_override"]
vz_override = cfg["vz_override"]

MAX_REASONABLE_VXY_UM = cfg["MAX_REASONABLE_VXY_UM"]
MARGIN_UM = cfg["MARGIN_UM"]
OVERLAP_ALPHA = cfg["OVERLAP_ALPHA"]
NEIGHBOR_MAX_VOX = cfg["NEIGHBOR_MAX_VOX"]
VIZ_MIN_VOXELS = cfg["VIZ_MIN_VOXELS"]

CH1_SMOOTH_SIGMA = cfg["CH1_SMOOTH_SIGMA"]
BLOB_MIN_SIGMA = cfg["BLOB_MIN_SIGMA"]
BLOB_MAX_SIGMA = cfg["BLOB_MAX_SIGMA"]
BLOB_NUM_SIGMA = cfg["BLOB_NUM_SIGMA"]

RADIAL_MAX_RADIUS_NM = cfg["RADIAL_MAX_RADIUS_NM"]
RADIAL_DR_NM = cfg["RADIAL_DR_NM"]
RADIAL_MIN_DROP_FRACTION = cfg["RADIAL_MIN_DROP_FRACTION"]

CH2_SMOOTH_SIGMA = cfg["CH2_SMOOTH_SIGMA"]
THRESH_BLOCK_SIZE = cfg["THRESH_BLOCK_SIZE"]
THRESH_OFFSET_STD_MULT = cfg["THRESH_OFFSET_STD_MULT"]

FPS = cfg["VIDEO_FPS"]
LAUNCH_VIEWER = cfg["LAUNCH_VIEWER"]
GENERATE_VIDEOS = cfg["GENERATE_VIDEOS"]

# Load data + metadata
img_ch1, img_ch2, (vx_um, vy_um, vz_um), meta = load_any(file_path)
print(f"[metadata] vx_um={vx_um}  vy_um={vy_um}  vz_um={vz_um}")

# Apply overrides / fallbacks (override > metadata > default)
if vxy_override is not None:
    vx_um = vy_um = float(vxy_override)
elif vx_um is None and vy_um is None:
    vx_um = vy_um = float(cfg["DEFAULT_VX_VY_UM"])
elif vx_um is None or vy_um is None:
    # Only one lateral size in the metadata: assume square pixels instead of
    # discarding the measured value.
    vx_um = vy_um = float(vx_um if vx_um is not None else vy_um)

if vz_override is not None:
    vz_um = float(vz_override)
elif vz_um is None:
    if cfg["DEFAULT_VZ_UM"] is not None:
        vz_um = float(cfg["DEFAULT_VZ_UM"])
    else:
        vz_um = float(vx_um)  # fallback only if Z missing

# Voxel sizes are divisors throughout the pipeline. Reject zero, negative or
# NaN values (e.g. from bad metadata) here instead of crashing later.
for _axis_name, _axis_um in (("X", vx_um), ("Y", vy_um), ("Z", vz_um)):
    if not (np.isfinite(_axis_um) and _axis_um > 0):
        raise ValueError(f"Invalid {_axis_name} voxel size: {_axis_um} µm")

if max(vx_um, vy_um) > MAX_REASONABLE_VXY_UM:
    raise ValueError(f"XY pixel size too large: {max(vx_um, vy_um)} µm/px")

px_um_xy = float(np.sqrt(vx_um * vy_um))
# NOTE(review): empirical 0.55 scale applied to the XY pixel size. It is used
# when converting blob radii (pixels) to µm further down; its origin is not
# documented here, so confirm before relying on absolute sizes.
px_um = px_um_xy * 0.55
voxel_um3 = vx_um * vy_um * vz_um
print(f"Voxel size (µm): X={vx_um}, Y={vy_um}, Z={vz_um}")

# ===== Aliases =====
# Ch1 = lysosome channel, Ch2 = neurite/cell channel (as named by the sections below).
image = img_ch1
image_2 = img_ch2

# ==========================================
# Lysosome detection (Ch1)
# ==========================================
# 3-D Gaussian smoothing; sigma is in voxel units on every axis.
# NOTE(review): with skimage's default preserve_range=False, integer input is
# presumably rescaled to [0, 1], which is the scale BLOB_THRESHOLD applies to — confirm.
image_smooth = gaussian(image, sigma=CH1_SMOOTH_SIGMA)

# Rows of `blobs` are (z, y, x, sigma) in voxel indices.
# NOTE(review): blob_log treats all axes isotropically in voxel units, so
# anisotropic Z spacing is not accounted for at this stage.
blobs = blob_log(
    image_smooth,
    min_sigma=BLOB_MIN_SIGMA,
    max_sigma=BLOB_MAX_SIGMA,
    num_sigma=BLOB_NUM_SIGMA,
    threshold=BLOB_THRESHOLD
)

if len(blobs) > 0:
    blobs[:, 3] *= np.sqrt(3)  # LoG radius correction (pixels): r ≈ sqrt(3)·sigma in 3-D
    # Refinement 1: per-slice local binarisation + distance transform.
    blobs = refine_radii_via_dt(image_smooth, blobs)
    # Refinement 2: radial intensity profile; keeps the larger of the two radii.
    blobs = refine_radii_via_radial_intensity(
        image_smooth,
        blobs,
        vx_um, vy_um, vz_um,
        max_radius_nm=RADIAL_MAX_RADIUS_NM,
        dr_nm=RADIAL_DR_NM,
        min_drop_fraction=RADIAL_MIN_DROP_FRACTION
    )

# Peak intensity in raw 16-bit Ch1: maximum of the unsmoothed image in a
# (2*rad+1)^3 neighbourhood (clipped at the volume edges) around each blob
# centre, stored as uint16.
# NOTE(review): values outside 0..65535 (e.g. float or 32-bit input) would not
# be represented faithfully by the uint16 cast.
peak_gray = np.zeros(len(blobs), dtype=np.uint16)
Z0, Y0, X0 = image.shape
rad = 1

for i, (zc, yc, xc, _) in enumerate(blobs):
    zc_i = int(round(zc))
    yc_i = int(round(yc))
    xc_i = int(round(xc))

    z1, z2 = max(0, zc_i - rad), min(Z0, zc_i + rad + 1)
    y1, y2 = max(0, yc_i - rad), min(Y0, yc_i + rad + 1)
    x1, x2 = max(0, xc_i - rad), min(X0, xc_i + rad + 1)

    peak_gray[i] = np.max(image[z1:z2, y1:y2, x1:x2]).astype(np.uint16)

# Convert voxel-index coordinates to µm using the per-axis voxel sizes.
# Radii use px_um (defined above), which includes an extra scale factor
# relative to the raw XY pixel size. Volumes assume spheres.
if len(blobs) > 0:
    z_um = blobs[:, 0] * vz_um
    y_um = blobs[:, 1] * vy_um
    x_um = blobs[:, 2] * vx_um
    radius_um = blobs[:, 3] * px_um
    diameter_um = 2 * radius_um
    volume_um3 = (4/3) * np.pi * radius_um**3
    blob_ids = np.arange(1, len(blobs) + 1, dtype=int)  # 1-based IDs
else:
    z_um = y_um = x_um = radius_um = diameter_um = volume_um3 = np.array([])
    blob_ids = np.array([], dtype=int)

df = pd.DataFrame({
    "id": blob_ids,
    "z_um": z_um,
    "y_um": y_um,
    "x_um": x_um,
    "radius_um": radius_um,
    "diameter_um": diameter_um,
    "volume_um3": volume_um3,
    "peak_gray": peak_gray,
})
# Relative path: written into the current directory, which the MAIN section
# changed to output_dir via os.chdir.
df.to_csv("lysosome_blobs_regions.csv", index=False)
print("Saved: lysosome_blobs_regions.csv")

# ==========================================
# CH2 segmentation (neurites mask)
# ==========================================
# Min-max normalise Ch2 to [0, 1]; a constant volume becomes all zeros.
vol = image_2.astype(np.float32)
vmin, vmax = float(vol.min()), float(vol.max())
if vmax > vmin:
    vol = (vol - vmin) / (vmax - vmin)
else:
    vol[:] = 0.0

ch2 = gaussian(vol, sigma=CH2_SMOOTH_SIGMA, preserve_range=True)

# Slice-wise adaptive threshold: skimage's threshold_local returns
# (local mean - offset); with offset = -k * std(slice) a voxel is foreground
# when it exceeds its local mean by k standard deviations of its slice.
neuron_mask = np.zeros_like(ch2, dtype=bool)
for z in range(ch2.shape[0]):
    R = ch2[z]
    t = threshold_local(R, block_size=THRESH_BLOCK_SIZE, offset=-THRESH_OFFSET_STD_MULT * np.std(R))
    neuron_mask[z] = R > t

if NEURITE_MODE:
    # neurite-safe cleanup: connect gaps, remove small speckles; avoid erosion/opening
    neuron_mask = binary_closing(neuron_mask, ball(1))
    neuron_mask = remove_small_objects(neuron_mask, min_size=200, connectivity=3)
    # NOTE: do NOT fill holes globally for neurites (can fill loops)
else:
    # fallback to soma-style refinement if you ever switch back
    neuron_mask = binary_fill_holes(neuron_mask)

print("neurite voxels:", int(neuron_mask.sum()))

# ---- ID segmentation ----
if NEURITE_MODE:
    # label connected neurite networks (often 1)
    cell_seg = label(neuron_mask, connectivity=3).astype(np.int32)
    cell_mask = neuron_mask.copy()
    print("Detected components (neurite networks):", int(cell_seg.max()))
else:
    # Original soma logic (kept for compatibility)
    dist = edt(neuron_mask)
    cell_min_radius_vox = 1
    # NOTE: every foreground voxel has an EDT of at least 1, so with
    # cell_min_radius_vox = 1 this keeps all of neuron_mask; a larger value
    # would keep only thicker (soma-like) cores as watershed seeds.
    cell_mask = (dist >= cell_min_radius_vox)
    cell_mask &= neuron_mask
    cell_mask = binary_fill_holes(binary_closing(binary_opening(cell_mask, ball(1)), ball(2)))

    body_lab = label(cell_mask, connectivity=3)
    n_cells = int(body_lab.max())
    print(f"Detected {n_cells} cells (soma).")

    if n_cells > 0:
        # neuron_mask has not changed since `dist` was computed, so reuse it
        # instead of running a second full-volume distance transform.
        dist_inside = dist
        cell_seg = watershed(-dist_inside, markers=body_lab, mask=neuron_mask)
    else:
        cell_seg = np.zeros_like(neuron_mask, dtype=np.int32)

print("ID voxels:", int((cell_seg > 0).sum()))

# ==========================================
# Visualization-only filtering (hide tiny components) + serial IDs
# ==========================================
# cell_seg_viz    : copy of cell_seg with components smaller than VIZ_MIN_VOXELS
#                   removed and the survivors renumbered 1..K in ascending order
#                   of their original label.
# cell_id_map_viz : {original label -> serial viz label} for the survivors.
cell_seg_viz = cell_seg.copy()
cell_id_map_viz = {}

if isinstance(cell_seg_viz, np.ndarray) and cell_seg_viz.max() > 0:
    counts_viz = np.bincount(cell_seg_viz.ravel().astype(np.int64))

    # Survivors: labels actually present (count > 0), large enough, not background.
    keep_labels = np.flatnonzero((counts_viz > 0) & (counts_viz >= VIZ_MIN_VOXELS))
    keep_labels = keep_labels[keep_labels > 0]

    # Lookup table old label -> new serial label (0 = removed or background).
    # A single vectorised gather replaces one full-volume comparison per label.
    relabel_lut = np.zeros(counts_viz.size, dtype=np.int32)
    relabel_lut[keep_labels] = np.arange(1, keep_labels.size + 1, dtype=np.int32)
    cell_seg_viz = relabel_lut[cell_seg_viz]
    cell_id_map_viz = {
        int(old_id): int(new_id)
        for old_id, new_id in zip(keep_labels, relabel_lut[keep_labels])
    }

cell_mask_viz = (cell_seg_viz > 0)

# ==========================================
# Classify lysosomes: inside vs outside neurite mask, and assign ID
# ==========================================
# Soft mask = neurite mask grown by MARGIN_UM (physical distance, anisotropic
# voxel sampling). A lysosome counts as "cell" when its rounded centre lies in
# the soft mask, or when at least OVERLAP_ALPHA of its sphere overlaps the mask.
dist_out_um = _edt(~neuron_mask, sampling=(vz_um, vy_um, vx_um)).astype(np.float32)
soft_cell_mask = neuron_mask | (dist_out_um <= MARGIN_UM)

location_ch2 = []
cell_id_list = []

if len(df) > 0:
    Z, Y, X = neuron_mask.shape
    centres = df[["z_um", "y_um", "x_um", "radius_um"]].to_numpy()
    for zc_um, yc_um, xc_um, r_um in centres:
        zz = int(round(zc_um / vz_um))
        yy = int(round(yc_um / vy_um))
        xx = int(round(xc_um / vx_um))
        in_bounds = 0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X

        # soft_cell_mask is a superset of neuron_mask, so one lookup covers
        # both the hard and the soft membership tests.
        is_inside = in_bounds and bool(soft_cell_mask[zz, yy, xx])
        if not is_inside:
            overlap = sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, neuron_mask, vx_um, vy_um, vz_um)
            is_inside = overlap >= OVERLAP_ALPHA

        if not is_inside:
            location_ch2.append("outside")
            cell_id_list.append(0)
            continue

        cid = 0
        if in_bounds:
            label_here = cell_seg[zz, yy, xx]
            if label_here != 0:
                cid = int(label_here)
            else:
                cid = nearest_cell_label(cell_seg, zz, yy, xx, max_r=NEIGHBOR_MAX_VOX)
        location_ch2.append("cell")
        cell_id_list.append(cid)

    df["location_ch2"] = location_ch2
    df["cell_id_ch2"] = cell_id_list
    if isinstance(cell_id_map_viz, dict):
        df["cell_id_ch2_viz"] = df["cell_id_ch2"].map(cell_id_map_viz).fillna(0).astype(int)
    else:
        df["cell_id_ch2_viz"] = 0

    # Serial lysosome IDs within each ID component (ordered by z, y, x)
    df["lys_id_in_cell"] = 0
    in_cell = (df["location_ch2"] == "cell") & (df["cell_id_ch2"] > 0)
    ordered = df.loc[in_cell].sort_values(["cell_id_ch2", "z_um", "y_um", "x_um"]).copy()
    serial = ordered.groupby("cell_id_ch2").cumcount().to_numpy() + 1
    df.loc[ordered.index, "lys_id_in_cell"] = serial.astype(int)

    df.to_csv("lysosomes_with_cell_vs_outside.csv", index=False)
    print("Saved: lysosomes_with_cell_vs_outside.csv")

    location_counts = df.groupby("location_ch2").size().reset_index(name="count")
    location_counts.to_csv("lysosome_counts_cell_vs_outside.csv", index=False)
    per_cell_counts = (
        df[df["location_ch2"] == "cell"]
        .groupby("cell_id_ch2")
        .size()
        .reset_index(name="count")
    )
    per_cell_counts.to_csv("lysosome_counts_by_cell.csv", index=False)
    print("Saved: lysosome_counts_cell_vs_outside.csv, lysosome_counts_by_cell.csv")

# Per-ID volumes (µm^3)
# One row per label 1..max(cell_seg) with its voxel count and physical volume;
# written to cell_volumes_ch2.csv and, joined with per-ID lysosome counts,
# to cell_metrics_ch2.csv.
cell_volume_df = pd.DataFrame(columns=["cell_id_ch2", "voxel_count", "volume_um3"])
if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
    counts = np.bincount(cell_seg.ravel().astype(np.int64))
    cell_ids = np.arange(1, counts.size, dtype=int)
    voxels = counts[1:].astype(np.int64)
    vol_um3 = voxels.astype(float) * voxel_um3

    cell_volume_df = pd.DataFrame({
        "cell_id_ch2": cell_ids,
        "voxel_count": voxels,
        "volume_um3": vol_um3
    })
    cell_volume_df.to_csv("cell_volumes_ch2.csv", index=False)
    print("Saved: cell_volumes_ch2.csv")

    try:
        if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
            lys_counts = (
                df[df["location_ch2"] == "cell"]
                .groupby("cell_id_ch2").size()
                .reset_index(name="lysosome_count")
            )
        else:
            # Integer-typed empty frame so the merge keys share a dtype.
            lys_counts = pd.DataFrame({
                "cell_id_ch2": pd.Series(dtype=int),
                "lysosome_count": pd.Series(dtype=int),
            })

        merged = cell_volume_df.merge(lys_counts, on="cell_id_ch2", how="left")
        # A left merge puts NaN (float) on IDs without lysosomes; restore an
        # integer count column instead of writing values like "3.0".
        merged["lysosome_count"] = merged["lysosome_count"].fillna(0).astype(int)
        merged.to_csv("cell_metrics_ch2.csv", index=False)
        print("Saved: cell_metrics_ch2.csv")
    except Exception as e:
        # Best-effort extra report; the volumes CSV above is already written.
        print("Merge with lysosome counts failed:", e)

# ==========================================
# Export full-size overlay (RAW + colored IDs + MAGENTA lysosomes)
# ==========================================
# Delegates to export_fullsize_overlay_stack with the visualization labels and
# the classified lysosome table; only lysosomes inside the neurites are drawn.
if EXPORT_FULLSIZE_OVERLAY:
    overlay_export_options = dict(
        img_ch1=img_ch1,
        img_ch2_raw=img_ch2,
        cell_seg_viz=cell_seg_viz,
        df=df,
        vx_um=vx_um,
        vy_um=vy_um,
        vz_um=vz_um,
        output_dir=output_dir,
        alpha_labels=0.45,
        draw_only_inside=True,   # set False to draw every lysosome
        fps=FPS,
        basename="FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
    )
    export_fullsize_overlay_stack(**overlay_export_options)

# ==========================================
# Videos (MAGENTA lysosomes)
# ==========================================
# Per-Z-slice videos (MP4 via FFMPEG, GIF fallback):
#   ch2_fused_cell_magenta            - smoothed Ch2 + mask tint + lysosome rings
#   ch2_raw                           - Ch1 (magenta) + Ch2 (green) composite
#   ch2_fused_all_viz_magenta         - composite + mask tint + lysosome rings
#   ch2_raw_and_fused_all_viz_magenta - the two above side by side
# NOTE: frames are built in OpenCV BGR order and handed to imageio, which
# presumably treats them as RGB. Every colour used here (black, white,
# magenta, green, and the (ch1, ch2, ch1) composite) is R/B-symmetric, so the
# saved colours are unaffected.
if GENERATE_VIDEOS:
    px_um_xy = float(np.sqrt(vx_um * vy_um))
    H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
    thickness = 2

    # Lysosome centres (voxel units) and radii (µm), converted once because
    # they do not depend on the slice being drawn.
    dfv = None
    if isinstance(df, pd.DataFrame) and len(df) > 0:
        dfv = df[
            np.isfinite(df["z_um"]) &
            np.isfinite(df["y_um"]) &
            np.isfinite(df["x_um"]) &
            np.isfinite(df["radius_um"])
        ]
    use_lys_table = dfv is not None and not dfv.empty
    if use_lys_table:
        lys_zc = (dfv["z_um"].to_numpy() / vz_um).astype(float)
        lys_yc = (dfv["y_um"].to_numpy() / vy_um).astype(float)
        lys_xc = (dfv["x_um"].to_numpy() / vx_um).astype(float)
        lys_r_um = dfv["radius_um"].to_numpy().astype(float)

    def _video_tint_mask(base, z):
        """Return a new frame: base plus the slice-z visualization mask as a green tint."""
        mask_u8 = (cell_mask_viz[z].astype(np.uint8) * 255)
        tinted = base.copy()
        tinted[..., 1] = np.maximum(tinted[..., 1], mask_u8)
        return cv2.addWeighted(base, 1.0, tinted, 0.35, 0.0)

    def _video_draw_ring(canvas, x, y, radius):
        """Draw one lysosome marker on canvas: black halo, magenta ring, magenta centre dot."""
        centre = (int(x), int(y))
        cv2.circle(canvas, centre, radius, LYS_EDGE_BGR, thickness + 2, lineType=cv2.LINE_AA)
        cv2.circle(canvas, centre, radius, LYS_MAGENTA_BGR, thickness, lineType=cv2.LINE_AA)
        cv2.circle(canvas, centre, 1, LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)

    def _video_draw_lysosomes(canvas, z):
        """Draw, in place, every lysosome whose sphere intersects slice z.

        Uses the lysosome table whenever it has valid rows. Raw blob_log
        detections centred on this slice are drawn only when no usable table
        exists. Previously they were also drawn on any slice where the table
        simply had no intersecting lysosome, re-adding blobs the table had
        excluded, at a different radius scale.
        """
        if use_lys_table:
            dz_um = np.abs(lys_zc - z) * vz_um
            hits = dz_um <= lys_r_um
            if not np.any(hits):
                return
            # Radius of each sphere's cross-section at this slice, µm -> XY pixels.
            r_proj_um = np.sqrt(np.clip(lys_r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
            r_proj_px = r_proj_um / max(px_um_xy, 1e-12)
            ys = np.rint(lys_yc[hits]).astype(int)
            xs = np.rint(lys_xc[hits]).astype(int)
            for y, x, rpv in zip(ys, xs, r_proj_px):
                if 0 <= y < H and 0 <= x < W:
                    _video_draw_ring(canvas, x, y, int(max(3, round(rpv))))
            return

        # Fallback: no usable lysosome table -> draw raw blobs (radius in pixels).
        if blobs is not None and len(blobs) > 0:
            for blob in blobs[np.abs(blobs[:, 0] - z) < 0.5]:
                y, x = int(round(blob[1])), int(round(blob[2]))
                if 0 <= y < H and 0 <= x < W:
                    _video_draw_ring(canvas, x, y, int(max(3, round(blob[3]))))

    def _video_caption(canvas, text):
        """Stamp a white caption in the top-left corner of canvas (in place)."""
        cv2.putText(canvas, text, (10, 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

    # ---- Video 1: smoothed Ch2 + mask + lysosomes ----
    img_norm_2 = (ch2 * 255).astype(np.uint8)
    frames_fused = []
    for z in range(img_norm_2.shape[0]):
        overlay = _video_tint_mask(cv2.cvtColor(img_norm_2[z], cv2.COLOR_GRAY2BGR), z)
        _video_draw_lysosomes(overlay, z)
        _video_caption(overlay, "FUSED (mask + MAGENTA lysosomes)")
        frames_fused.append(overlay)

    try:
        imageio.mimsave("ch2_fused_cell_magenta.mp4", frames_fused, fps=int(FPS), format="FFMPEG")
        print("Saved: ch2_fused_cell_magenta.mp4")
    except TypeError:
        imageio.mimsave("ch2_fused_cell_magenta.gif", frames_fused, fps=int(FPS))
        print("Saved: ch2_fused_cell_magenta.gif")

    # ---- RAW + FUSED side-by-side ----
    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(ch2.astype(np.float32) * 255.0 / max(1.0, ch2.max()))

    frames_raw = []
    frames_fused_all = []
    frames_side_by_side = []

    for z in range(ch2_u8.shape[0]):
        # Channel order B, G, R = Ch1, Ch2, Ch1 -> Ch1 shows magenta, Ch2 green.
        base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]])
        overlay = _video_tint_mask(base, z)  # built before base gets its caption
        _video_draw_lysosomes(overlay, z)

        _video_caption(base, "RAW (Ch1+Ch2)")
        _video_caption(overlay, "FUSED (mask + MAGENTA lysosomes)")

        frames_raw.append(base)
        frames_fused_all.append(overlay)
        frames_side_by_side.append(cv2.hconcat([base, overlay]))

    def _save_video_sync(basename, frames):
        """Write frames as <basename>.mp4 (libx264); fall back to <basename>.gif on any error."""
        try:
            with imageio.get_writer(
                f"{basename}.mp4",
                fps=int(FPS),
                format="FFMPEG",
                codec="libx264",
                macro_block_size=None
            ) as w:
                for fr in frames:
                    w.append_data(fr)
            print(f"Saved: {basename}.mp4 @ {int(FPS)} fps")
        except Exception:
            imageio.mimsave(f"{basename}.gif", frames, duration=1.0 / max(int(FPS), 1))
            print(f"Saved: {basename}.gif @ {int(FPS)} fps equivalent")

    _save_video_sync("ch2_raw", frames_raw)
    _save_video_sync("ch2_fused_all_viz_magenta", frames_fused_all)
    _save_video_sync("ch2_raw_and_fused_all_viz_magenta", frames_side_by_side)

# ==========================================
# Napari visualization
# ==========================================
# Interactive 3D viewer: both raw channels, the visualization mask / ID labels,
# and (when classification ran) the lysosomes classified as "cell" as points.
# NOTE(review): no `scale=` is passed to any layer, so everything is shown in
# voxel-index units and anisotropic Z spacing is not reflected in 3D — confirm
# this is intended.
if LAUNCH_VIEWER:
    viewer = napari.Viewer()
    viewer.dims.ndisplay = 3  # start in 3D rendering mode

    viewer.add_image(img_ch2, name="Ch2 raw")
    viewer.add_image(img_ch1, name="Ch1 raw")

    mask_layer = viewer.add_labels(
        cell_mask_viz.astype(np.uint8),
        name="Neurite mask (viz)" if NEURITE_MODE else "Cell mask (viz)",
        opacity=0.35
    )
    mask_layer.blending = "translucent_no_depth"

    id_layer = viewer.add_labels(
        cell_seg_viz.astype(np.uint16),
        name="ID (viz)",
        opacity=0.25
    )
    id_layer.blending = "translucent_no_depth"

    # The classification columns exist only when at least one blob was detected.
    if len(df) > 0 and "location_ch2" in df:
        in_mask = (df["location_ch2"].to_numpy() == "cell")
        if np.any(in_mask):
            # Centres converted from µm back to (z, y, x) voxel indices to match the unscaled image layers.
            pts_zyx = np.stack([
                (df.loc[in_mask, "z_um"].to_numpy() / vz_um),
                (df.loc[in_mask, "y_um"].to_numpy() / vy_um),
                (df.loc[in_mask, "x_um"].to_numpy() / vx_um),
            ], axis=1)

            # Point size = diameter in XY pixels (radius µm / geometric-mean XY pixel size, x2), at least 2.
            radii_vox = df.loc[in_mask, "radius_um"].to_numpy() / (np.sqrt(vx_um * vy_um) + 1e-12)
            sizes = np.clip(radii_vox * 2, 2, None).astype(np.float32)

            pts = viewer.add_points(pts_zyx, size=sizes, name="Lysosomes (inside 3D)")
            pts.face_color = [1.0, 0.0, 1.0, 1.0]  # MAGENTA in napari
            # NOTE(review): newer napari releases rename Points `edge_*` to `border_*`;
            # these setters may warn or fail depending on the installed version — confirm.
            pts.edge_color = "black"
            pts.edge_width = 0.3

            # Per-point labels; the optional columns fall back to zeros when absent.
            cell_ids = df.loc[in_mask, "cell_id_ch2_viz"].to_numpy().astype(int) if "cell_id_ch2_viz" in df.columns else np.zeros(int(np.sum(in_mask)), dtype=int)
            lys_ids  = df.loc[in_mask, "lys_id_in_cell"].to_numpy().astype(int) if "lys_id_in_cell" in df.columns else np.zeros(int(np.sum(in_mask)), dtype=int)
            diams    = df.loc[in_mask, "diameter_um"].to_numpy().astype(float)

            info = np.array(
                [f"ID:{c}  Ly:{l}  Diam:{d:.3f}µm" for c, l, d in zip(cell_ids, lys_ids, diams)],
                dtype=object
            )
            pts.properties = {"info": info}  # attach one "info" string per point

    # Remove any layer not created above (defensive; normally a no-op).
    keep = {"Lysosomes (inside 3D)", "ID (viz)", "Neurite mask (viz)", "Cell mask (viz)", "Ch1 raw", "Ch2 raw"}
    for lyr in list(viewer.layers):
        if lyr.name not in keep:
            viewer.layers.remove(lyr)

    # Cosmetic zoom; failures are deliberately ignored.
    try:
        viewer.camera.zoom = 1.2
    except Exception:
        pass

    napari.run()  # blocks until the viewer window is closed

In [None]:
#################last version###################

In [None]:
import os
import re
import numpy as np
import pandas as pd
import cv2
import imageio
import xml.etree.ElementTree as ET
from datetime import datetime

import czifile
import tifffile as tiff

from skimage.feature import blob_log
from skimage.filters import gaussian, threshold_local
from skimage.morphology import (
    remove_small_objects, binary_opening, binary_closing, ball, binary_erosion
)
from scipy.ndimage import distance_transform_edt as edt
from skimage.measure import label
from skimage.segmentation import watershed
from scipy.ndimage import binary_fill_holes

import napari
from scipy.ndimage import gaussian_filter1d
import colorsys

# ===============================
# USER MODE
# ===============================
# NEURITE_MODE selects the segmentation strategy:
#   True  -> neurites-only datasets (no soma); recommended for your new image.
#   False -> soma datasets; uses the original soma/watershed logic.
NEURITE_MODE = True

# When True, also export a full-size overlay RGB stack + mp4 (recommended).
EXPORT_FULLSIZE_OVERLAY = True

# Lysosome marker colours. OpenCV colour tuples are BGR (not RGB);
# (255, 0, 255) reads as magenta in either channel order.
LYS_EDGE_BGR = (0, 0, 0)          # black outline
LYS_MAGENTA_BGR = (255, 0, 255)   # magenta

# ===============================
# GUI (single unified interface)
# Advanced settings ONLY:
#   - MARGIN_UM (µm)
#   - OVERLAP_ALPHA (0..1)
#   - VIZ_MIN_VOXELS (voxels)
# ===============================
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, simpledialog


def get_user_config_gui(
    # basic defaults
    default_vxy_um=0.04,
    default_vz_um=None,          # if None and Z missing -> fallback to XY
    default_erode_mult=1.0,
    default_blob_threshold=0.001,

    # advanced defaults (only these 3 are editable in the GUI)
    default_margin_um=0.5,          # µm (neurite-friendly default)
    default_overlap_alpha=0.4,      # unitless (0..1)
    default_neighbor_max_vox=6,     # voxels (hidden; fixed default)
    default_viz_min_voxels=200,     # voxels (neurite-friendly default)

    # fixed defaults (not shown in GUI)
    default_max_reasonable_vxy_um=0.5,
    default_ch1_smooth_sigma=1.0,
    default_blob_min_sigma=0.8,
    default_blob_max_sigma=3.0,
    default_blob_num_sigma=10,
    default_radial_max_radius_nm=300.0,
    default_radial_dr_nm=10.0,
    default_radial_min_drop_fraction=0.5,

    # neurite-friendly threshold defaults (fixed, not shown)
    default_ch2_smooth_sigma=0.9,
    default_thresh_block_size=151,
    default_thresh_offset_std_mult=0.25,

    default_video_fps=8,
    default_launch_viewer=True,
    default_generate_videos=True,
):
    """Show a modal Tk dialog for the run configuration and return it as a dict.

    Editable fields are:
      * image file and output folder;
      * ERODE_MULT and the blob_log threshold;
      * video options;
      * under "Advanced": MARGIN_UM, OVERLAP_ALPHA and VIZ_MIN_VOXELS.
    Every other ``default_*`` argument is copied into the result unchanged.

    Invalid input shows an error dialog and keeps the window open so the user
    can correct the field and press Run again.

    Returns:
        dict with ``"ok": True``, ``"file_path"``, ``"output_dir"`` and the
        upper-case parameter keys (e.g. ``"BLOB_THRESHOLD"``, ``"MARGIN_UM"``,
        ``"VIDEO_FPS"``) read by the analysis script.

    Raises:
        SystemExit: if the dialog is cancelled or closed without pressing Run.
    """
    cfg = {"ok": False}

    root = tk.Tk()
    root.title("Lysosome + Neurite Segmentation (GUI)")
    root.resizable(False, False)

    file_var = tk.StringVar(value="")
    out_var = tk.StringVar(value="")

    erode_var = tk.StringVar(value=str(default_erode_mult))
    blob_var  = tk.StringVar(value=str(default_blob_threshold))

    show_adv = tk.BooleanVar(value=False)

    margin_var   = tk.StringVar(value=str(default_margin_um))
    overlap_var  = tk.StringVar(value=str(default_overlap_alpha))
    vizmin_var   = tk.StringVar(value=str(default_viz_min_voxels))

    fps_var = tk.StringVar(value=str(default_video_fps))

    launch_viewer_var = tk.BooleanVar(value=bool(default_launch_viewer))
    gen_videos_var    = tk.BooleanVar(value=bool(default_generate_videos))

    def _suggest_output_dir(fp):
        """Return '<dir of fp>/<stem>_outputs_<timestamp>', or '' when fp is empty."""
        if not fp:
            return ""
        raw_dir = os.path.dirname(fp)
        raw_base = os.path.splitext(os.path.basename(fp))[0]
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return os.path.join(raw_dir, f"{raw_base}_outputs_{stamp}")

    def browse_file():
        """Pick the input image; pre-fill the output folder if it is still empty."""
        fp = filedialog.askopenfilename(
            title="Select image file",
            filetypes=[("Image files", "*.tif *.tiff *.czi"), ("All files", "*.*")],
        )
        if fp:
            file_var.set(fp)
            if not out_var.get().strip():
                out_var.set(_suggest_output_dir(fp))

    def browse_output_dir():
        """Pick a parent folder; a timestamped sub-folder is used when a file is selected."""
        d = filedialog.askdirectory(title="Select output folder")
        if d:
            fp = file_var.get().strip()
            if fp:
                raw_base = os.path.splitext(os.path.basename(fp))[0]
                stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                out_var.set(os.path.join(d, f"{raw_base}_outputs_{stamp}"))
            else:
                out_var.set(d)

    def _err(msg):
        """Show msg in an error dialog, then raise ValueError to abort validation."""
        messagebox.showerror("Invalid input", msg)
        raise ValueError(msg)

    def _float_required(s, name):
        """Parse a required float (a comma is accepted as decimal separator)."""
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        try:
            return float(s.replace(",", "."))
        except Exception:
            _err(f"{name} must be a number (got: {s})")

    def _int_required(s, name):
        """Parse a required integer; decimal input is truncated toward zero."""
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        try:
            return int(float(s.replace(",", ".")))
        except Exception:
            _err(f"{name} must be an integer (got: {s})")

    def _toggle_adv():
        """Show or hide the Advanced frame according to the checkbox."""
        if show_adv.get():
            adv_frame.grid()
        else:
            adv_frame.grid_remove()

    def _collect_config():
        """Read and validate every field; return the cfg entries or raise ValueError."""
        fp = file_var.get().strip()
        if not fp:
            _err("Please select a file.")
        if not os.path.isfile(fp):
            _err("Selected file does not exist.")

        outd = out_var.get().strip() or _suggest_output_dir(fp)

        erode = _float_required(erode_var.get(), "ERODE_MULT")
        blobt = _float_required(blob_var.get(), "blob_log threshold")
        fps   = _int_required(fps_var.get(), "video FPS")

        margin  = _float_required(margin_var.get(), "MARGIN_UM (µm)")
        overlap = _float_required(overlap_var.get(), "OVERLAP_ALPHA (0..1)")
        vizmin  = _int_required(vizmin_var.get(), "VIZ_MIN_VOXELS (voxels)")

        if not (0.0 <= overlap <= 1.0):
            _err("OVERLAP_ALPHA must be between 0 and 1.")
        # The video writers use FPS as a frame rate / divisor.
        if fps < 1:
            _err("video FPS must be at least 1.")
        # Written as `not (x >= 0)` so NaN is rejected too.
        if not (margin >= 0.0):
            _err("MARGIN_UM must be >= 0.")
        if vizmin < 0:
            _err("VIZ_MIN_VOXELS must be >= 0.")

        return {
            "ok": True,
            "file_path": fp,
            "output_dir": outd,

            "ERODE_MULT": float(erode),
            "BLOB_THRESHOLD": float(blobt),

            "DEFAULT_VX_VY_UM": float(default_vxy_um),
            "DEFAULT_VZ_UM": None if default_vz_um is None else float(default_vz_um),

            "MAX_REASONABLE_VXY_UM": float(default_max_reasonable_vxy_um),

            # Advanced only
            "MARGIN_UM": float(margin),
            "OVERLAP_ALPHA": float(overlap),
            "VIZ_MIN_VOXELS": int(vizmin),

            # Fixed (hidden)
            "NEIGHBOR_MAX_VOX": int(default_neighbor_max_vox),

            # Fixed defaults (hidden)
            "CH1_SMOOTH_SIGMA": float(default_ch1_smooth_sigma),
            "BLOB_MIN_SIGMA": float(default_blob_min_sigma),
            "BLOB_MAX_SIGMA": float(default_blob_max_sigma),
            "BLOB_NUM_SIGMA": int(default_blob_num_sigma),

            "RADIAL_MAX_RADIUS_NM": float(default_radial_max_radius_nm),
            "RADIAL_DR_NM": float(default_radial_dr_nm),
            "RADIAL_MIN_DROP_FRACTION": float(default_radial_min_drop_fraction),

            "CH2_SMOOTH_SIGMA": float(default_ch2_smooth_sigma),
            "THRESH_BLOCK_SIZE": int(default_thresh_block_size),
            "THRESH_OFFSET_STD_MULT": float(default_thresh_offset_std_mult),

            "VIDEO_FPS": int(fps),
            "LAUNCH_VIEWER": bool(launch_viewer_var.get()),
            "GENERATE_VIDEOS": bool(gen_videos_var.get()),
        }

    def run_clicked():
        """Validate and, on success, store the config and close the dialog."""
        try:
            values = _collect_config()
        except ValueError:
            # _err has already shown the problem. Swallowing the exception here
            # keeps the dialog open and stops Tk from printing a callback
            # traceback for ordinary input mistakes.
            return
        cfg.update(values)
        root.destroy()

    def cancel_clicked():
        root.destroy()

    root.protocol("WM_DELETE_WINDOW", cancel_clicked)

    pad = {"padx": 10, "pady": 6}
    frm = ttk.Frame(root)
    frm.grid(row=0, column=0, sticky="nsew", **pad)

    r = 0
    ttk.Label(frm, text="Image file:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=file_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_file).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Label(frm, text="Output folder:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=out_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_output_dir).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Label(frm, text="ERODE_MULT:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=erode_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="blob_log threshold:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=blob_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Launch Napari viewer", variable=launch_viewer_var)\
        .grid(row=r, column=0, columnspan=2, sticky="w")
    ttk.Checkbutton(frm, text="Generate videos", variable=gen_videos_var)\
        .grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="Video FPS:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=fps_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="frames/sec").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Show advanced settings", variable=show_adv, command=_toggle_adv)\
        .grid(row=r, column=0, columnspan=3, sticky="w")
    r += 1

    # Advanced frame is laid out now but hidden until the checkbox is ticked.
    adv_frame = ttk.LabelFrame(frm, text="Advanced")
    adv_frame.grid(row=r, column=0, columnspan=3, sticky="we", pady=6)
    adv_frame.grid_remove()

    rr = 0
    def add_row(label_txt, var, hint):
        """Append a label / entry / hint row to the Advanced frame."""
        nonlocal rr
        ttk.Label(adv_frame, text=label_txt).grid(row=rr, column=0, sticky="w", padx=8, pady=3)
        ttk.Entry(adv_frame, textvariable=var, width=18).grid(row=rr, column=1, sticky="w", padx=8, pady=3)
        ttk.Label(adv_frame, text=hint).grid(row=rr, column=2, sticky="w", padx=8, pady=3)
        rr += 1

    add_row("MARGIN_UM:", margin_var, "µm (soft band around mask)")
    add_row("OVERLAP_ALPHA:", overlap_var, "0..1 (sphere overlap fraction)")
    add_row("VIZ_MIN_VOXELS:", vizmin_var, "voxels (hide small components)")

    r += 1
    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    btns = ttk.Frame(frm)
    btns.grid(row=r, column=0, columnspan=3, sticky="e")
    ttk.Button(btns, text="Cancel", command=cancel_clicked).grid(row=0, column=0, padx=6)
    ttk.Button(btns, text="Run", command=run_clicked).grid(row=0, column=1, padx=6)

    root.mainloop()

    if not cfg.get("ok"):
        raise SystemExit("Cancelled.")
    return cfg


# ===============================
# Metadata parsing
# ===============================
def _parse_ome_xml(xml_text):
    """Extract the physical voxel size (X, Y, Z) in µm from OME-XML text.

    The first ``PhysicalSize{X,Y,Z}`` attribute for each axis is used (single-
    or double-quoted).
      * If a ``PhysicalSize{axis}Unit`` attribute names a known length unit
        (m, cm, mm, µm/um, nm, pm, Å), the value is converted to µm.
      * With no unit attribute (OME's default unit is µm) or an unrecognised
        unit, the value is returned unchanged.

    Args:
        xml_text: OME-XML as ``str`` or UTF-8 ``bytes``; empty/None is allowed.

    Returns:
        Tuple ``(x_um, y_um, z_um)``; each entry is None when the attribute is
        missing or its value cannot be parsed as a number.
    """
    if not xml_text:
        return None, None, None
    if isinstance(xml_text, (bytes, bytearray)):
        xml_text = xml_text.decode("utf-8", errors="ignore")

    # Multiplicative factors converting a value in the given unit to µm.
    to_um = {
        "m": 1e6, "cm": 1e4, "mm": 1e3,
        "\u00b5m": 1.0,   # micro sign
        "\u03bcm": 1.0,   # Greek small mu
        "um": 1.0, "micron": 1.0, "microns": 1.0,
        "nm": 1e-3, "pm": 1e-6,
        "\u00c5": 1e-4,   # Angstrom (Latin capital A with ring)
        "\u212b": 1e-4,   # Angstrom sign
    }

    def grab(attr):
        m = re.search(
            rf'PhysicalSize{attr}\s*=\s*(["\'])\s*([\d.eE+-]+)\s*\1',
            xml_text,
        )
        if not m:
            return None
        try:
            value = float(m.group(2))
        except ValueError:
            # e.g. "1e-" matches the character class but is not a number
            return None

        unit_match = re.search(rf'PhysicalSize{attr}Unit\s*=\s*(["\'])(.*?)\1', xml_text)
        if unit_match:
            unit = unit_match.group(2).strip()
            factor = to_um.get(unit, to_um.get(unit.lower()))
            if factor is not None:
                return value * factor
        return value

    return grab("X"), grab("Y"), grab("Z")


def _parse_czi_scaling(czi_text):
    """
    Robust CZI scaling parser.

    - Works when values are nested (<Value>...</Value>) or attributes
    - Works with namespaces
    - Converts meters -> µm (typical CZI); a recognised ``Unit`` attribute
      takes precedence over the magnitude heuristic
    - Zero, negative and non-finite sizes are reported as None so callers
      fall back to a default instead of dividing by zero later

    Parameters
    ----------
    czi_text : str, bytes or None
        CZI XML metadata document.

    Returns
    -------
    tuple
        ``(x_um, y_um, z_um)``, each a float in µm or None when missing or
        unusable.
    """
    if not czi_text:
        return None, None, None

    if isinstance(czi_text, (bytes, bytearray)):
        czi_text = czi_text.decode("utf-8", errors="ignore")

    # NUL characters are not allowed in XML; strip them before parsing.
    czi_text = czi_text.replace("\x00", "")

    def _to_float(s):
        if s is None:
            return None
        try:
            return float(str(s).strip().replace(",", "."))
        except Exception:
            return None

    def _to_um(val, unit_hint=None):
        # A voxel size must be a positive, finite length; 0, negatives, NaN
        # and inf are treated as missing.
        if val is None or not np.isfinite(val) or val <= 0:
            return None
        if unit_hint:
            u = str(unit_hint).strip().lower()
            if u in ("m", "meter", "metre", "meters", "metres"):
                return val * 1e6
            if u in ("µm", "μm", "um", "micron", "microns",
                     "micrometer", "micrometre", "micrometers", "micrometres"):
                return val
            if u in ("nm", "nanometer", "nanometre", "nanometers", "nanometres"):
                return val / 1000.0

        # heuristic: CZI values often in meters (~1e-8 to 1e-6)
        if val < 1e-3:
            return val * 1e6
        # sometimes already in µm
        if val < 10:
            return val
        # sometimes in nm
        if val < 1e5:
            return val / 1000.0
        return None

    try:
        root = ET.fromstring(czi_text)
    except Exception:
        # Not parseable as XML: fall back to a regex scan of <Distance> blocks.
        def _grab(axis):
            mm = re.search(
                rf'<Distance[^>]*Id="{axis}"[^>]*>.*?<Value>\s*([0-9eE\+\-\.]+)\s*</Value>',
                czi_text,
                flags=re.IGNORECASE | re.DOTALL,
            )
            return _to_float(mm.group(1)) if mm else None

        return _to_um(_grab("X")), _to_um(_grab("Y")), _to_um(_grab("Z"))

    sx = sy = sz = None
    for d in root.findall(".//{*}Distance"):
        axis = d.attrib.get("Id") or d.attrib.get("id") or d.attrib.get("Axis") or d.attrib.get("axis")
        if not axis:
            continue
        axis = axis.upper()
        unit = d.attrib.get("Unit") or d.attrib.get("unit")

        # Value may be an attribute of <Distance> ...
        valf = _to_float(d.attrib.get("Value") or d.attrib.get("value"))

        # ... or the text of a (possibly namespaced) <Value> child ...
        if valf is None:
            v_el = d.find(".//{*}Value")
            if v_el is not None and v_el.text:
                valf = _to_float(v_el.text)

        # ... or any descendant whose tag ends in "value".
        if valf is None:
            for child in d.iter():
                if child is d:
                    continue
                if str(child.tag).lower().endswith("value"):
                    valf = _to_float(child.attrib.get("Value") or child.attrib.get("value")) or _to_float(child.text)
                    if valf is not None:
                        break

        val_um = _to_um(valf, unit_hint=unit)
        if val_um is None:
            continue

        # Later <Distance> entries for the same axis overwrite earlier ones.
        if axis == "X":
            sx = val_um
        elif axis == "Y":
            sy = val_um
        elif axis == "Z":
            sz = val_um

    return sx, sy, sz


def _split_two_channels(img, kind):
    """
    Split a squeezed image array into two per-channel volumes.

    Only 4-D arrays are accepted. The channel axis is the first of axis 0,
    axis 1 or the last axis whose length is exactly 2 (checked in that
    order).

    Raises
    ------
    RuntimeError
        If ``img`` is not 4-D or no candidate axis has length 2; ``kind``
        (e.g. "TIFF", "CZI") is embedded in the message.
    """
    if img.ndim != 4:
        raise RuntimeError(f"Unexpected {kind} shape")
    if img.shape[0] == 2:
        return img[0], img[1]
    if img.shape[1] == 2:
        return img[:, 0], img[:, 1]
    if img.shape[-1] == 2:
        return img[..., 0], img[..., 1]
    raise RuntimeError(f"Unexpected {kind} shape for 2 channels")


def load_any(file_path):
    """
    Load a two-channel image volume and its voxel sizes.

    Supports TIFF (``.tif``/``.tiff``; voxel sizes read from OME-XML when
    present) and CZI (``.czi``; voxel sizes read from the scaling metadata).
    Singleton axes are squeezed away; the remaining array must be 4-D with a
    length-2 channel axis (see ``_split_two_channels``).

    Returns
    -------
    tuple
        ``(ch1, ch2, (vx_um, vy_um, vz_um), meta)``. Any voxel size may be
        None when the metadata lacks it; ``meta`` is ``{"type": "tiff"}`` or
        ``{"type": "czi"}``.

    Raises
    ------
    RuntimeError
        Unexpected array shape.
    ValueError
        Unsupported file extension.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext in (".tif", ".tiff"):
        with tiff.TiffFile(file_path) as tf:
            arr = tf.asarray()
            try:
                ome_xml = tf.ome_metadata
            except Exception:
                ome_xml = None
        vx_um = vy_um = vz_um = None
        if ome_xml:
            vx_um, vy_um, vz_um = _parse_ome_xml(ome_xml)

        ch1, ch2 = _split_two_channels(np.squeeze(arr), "TIFF")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "tiff"}

    if ext == ".czi":
        with czifile.CziFile(file_path) as cf:
            arr = cf.asarray()
            try:
                czi_xml = cf.metadata()
            except Exception:
                czi_xml = None

        vx_um = vy_um = vz_um = None
        if czi_xml:
            vx_um, vy_um, vz_um = _parse_czi_scaling(czi_xml)

        ch1, ch2 = _split_two_channels(np.squeeze(arr), "CZI")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "czi"}

    raise ValueError("Unsupported file format")


# ===============================
# Helper: refine radii by distance transform
# ===============================
def refine_radii_via_dt(img3d, blobs, win_px=40, bin_method="sauvola"):
    """
    Re-estimate each blob radius (pixels) from a local 2-D binarisation.

    For every blob row ``(z, y, x, r)``, the XY window of half-size
    ``win_px`` (clipped to the image) around the rounded centre in slice
    ``z`` is binarised ("sauvola", "local", or any other value -> Otsu),
    cleaned with a radius-1 opening and removal of objects smaller than
    3 px (8-connected), and the Euclidean distance transform of the
    connected component under the centre is sampled at the centre. A
    positive sample replaces ``r``. Blobs whose centre is out of bounds or
    on background, or whose Otsu threshold fails, keep their radius.

    Returns ``blobs`` unchanged when it is None/empty, otherwise a float32
    copy with column 3 possibly replaced.
    """
    from skimage.filters import threshold_sauvola, threshold_local, threshold_otsu
    from skimage.morphology import remove_small_objects, disk, binary_opening
    from scipy.ndimage import distance_transform_edt as _edt
    from skimage.measure import label as _label

    if blobs is None or len(blobs) == 0:
        return blobs

    n_z, n_y, n_x = img3d.shape
    refined = blobs.copy().astype(np.float32)

    def _binarize(tile):
        """Foreground mask of ``tile`` for the chosen method; None if Otsu fails."""
        if bin_method == "sauvola":
            window = max(11, 2 * (win_px // 2) + 1)
            return tile > threshold_sauvola(tile, window_size=window, k=0.4)
        if bin_method == "local":
            window = max(11, 2 * (win_px // 2) + 1)
            return tile > threshold_local(tile, block_size=window, offset=-0.4 * np.std(tile))
        try:
            return tile > threshold_otsu(tile)
        except ValueError:
            return None

    for idx, (cz, cy, cx, _) in enumerate(refined):
        zi, yi, xi = int(round(cz)), int(round(cy)), int(round(cx))
        in_volume = 0 <= zi < n_z and 0 <= yi < n_y and 0 <= xi < n_x
        if not in_volume:
            continue

        top, bottom = max(0, yi - win_px), min(n_y, yi + win_px + 1)
        left, right = max(0, xi - win_px), min(n_x, xi + win_px + 1)

        fg = _binarize(img3d[zi, top:bottom, left:right])
        if fg is None:
            continue

        fg = binary_opening(fg, footprint=disk(1))
        fg = remove_small_objects(fg, min_size=3, connectivity=2)

        # Centre position inside the (possibly clipped) window.
        ly, lx = yi - top, xi - left
        centre_ok = 0 <= ly < fg.shape[0] and 0 <= lx < fg.shape[1]
        if not centre_ok or not fg[ly, lx]:
            continue

        components = _label(fg)
        comp_id = components[ly, lx]
        if comp_id == 0:
            continue

        depth = float(_edt(components == comp_id)[ly, lx])
        if depth <= 0:
            continue
        refined[idx, 3] = depth

    return refined


def refine_radii_via_radial_intensity(
    img3d,
    blobs,
    vx_um,
    vy_um,
    vz_um,
    max_radius_nm=300.0,
    dr_nm=10.0,
    min_drop_fraction=0.5,
):
    """
    Possibly enlarge blob radii using a radial intensity profile.

    Parameters
    ----------
    img3d : ndarray
        Intensity volume indexed as ``img3d[z, y, x]`` (cast to float32).
    blobs : ndarray or None
        Rows of ``(z, y, x, r_px)`` in voxel / XY-pixel units.
    vx_um, vy_um, vz_um : float
        Voxel sizes in µm, used to measure the physical distance of each
        voxel from the blob centre.
    max_radius_nm : float
        Outer radius of the profile (nm).
    dr_nm : float
        Radial bin width (nm).
    min_drop_fraction : float
        Minimum relative contrast ``(max - min) / max`` of the smoothed
        profile; flatter profiles leave the blob unchanged.

    Returns
    -------
    ndarray or None
        ``blobs`` itself if it is None/empty; otherwise a float32 copy in
        which column 3 is replaced by ``max(r_px, r_halfmax_px)`` for blobs
        where a half-maximum interval could be measured.
    """
    if blobs is None or len(blobs) == 0:
        return blobs

    img = img3d.astype(np.float32)
    Z, Y, X = img.shape
    blobs_out = blobs.copy().astype(np.float32)

    max_r_um = max_radius_nm / 1000.0
    dr_um = dr_nm / 1000.0

    # Radial bin edges 0, dr, 2*dr, ... (µm); the profile is reported at the
    # bin centres.
    r_edges = np.arange(0.0, max_r_um + dr_um, dr_um)
    if r_edges.size < 2:
        return blobs_out
    r_centers = 0.5 * (r_edges[:-1] + r_edges[1:])

    # Geometric-mean XY pixel size: converts the µm radius back to pixels.
    px_um_xy = float(np.sqrt(vx_um * vy_um))

    for i, (zc, yc, xc, r_px_init) in enumerate(blobs_out):
        z0 = int(round(zc))
        y0 = int(round(yc))
        x0 = int(round(xc))

        if not (0 <= z0 < Z and 0 <= y0 < Y and 0 <= x0 < X):
            continue

        # Half-extent (voxels) of the box enclosing the max_r_um sphere.
        # These are the same for every blob (could be hoisted out of the loop).
        rz = max(1, int(np.ceil(max_r_um / vz_um)))
        ry = max(1, int(np.ceil(max_r_um / vy_um)))
        rx = max(1, int(np.ceil(max_r_um / vx_um)))

        z1, z2 = max(0, z0 - rz), min(Z, z0 + rz + 1)
        y1, y2 = max(0, y0 - ry), min(Y, y0 + ry + 1)
        x1, x2 = max(0, x0 - rx), min(X, x0 + rx + 1)
        if z1 >= z2 or y1 >= y2 or x1 >= x2:
            continue

        patch = img[z1:z2, y1:y2, x1:x2]

        # Physical distance (µm) of every voxel in the box from the centre voxel.
        zz, yy, xx = np.mgrid[z1:z2, y1:y2, x1:x2]
        dz_um = (zz - z0) * vz_um
        dy_um = (yy - y0) * vy_um
        dx_um = (xx - x0) * vx_um
        r_um = np.sqrt(dz_um**2 + dy_um**2 + dx_um**2)

        mask = (r_um <= max_r_um)
        if not np.any(mask):
            continue

        r_vals = r_um[mask].ravel()
        I_vals = patch[mask].ravel()

        # Assign each voxel to a radial bin; drop voxels outside [0, last edge).
        bin_idx = np.digitize(r_vals, r_edges) - 1
        valid = (bin_idx >= 0) & (bin_idx < r_centers.size)
        if not np.any(valid):
            continue

        bin_idx = bin_idx[valid]
        I_vals = I_vals[valid]

        # Mean intensity per radial bin.
        sums = np.bincount(bin_idx, weights=I_vals, minlength=r_centers.size)
        counts = np.bincount(bin_idx, minlength=r_centers.size)
        with np.errstate(invalid="ignore", divide="ignore"):
            prof = sums / np.maximum(counts, 1)

        # Keep only bins that actually received voxels (fine bins can be empty
        # when dr is smaller than the voxel spacing).
        have = counts > 0
        if not np.any(have):
            continue

        r_prof = r_centers[have]
        I_prof = prof[have].astype(np.float32)

        # Light smoothing (sigma = 1 bin) before locating half-max crossings.
        I_smooth = gaussian_filter1d(I_prof, sigma=1.0)
        I_max = float(I_smooth.max())
        if I_max <= 0:
            continue
        I_min = float(I_smooth.min())

        # Require enough contrast between the brightest and dimmest radius.
        if (I_max - I_min) / max(I_max, 1e-9) < min_drop_fraction:
            continue

        I_half = I_min + 0.5 * (I_max - I_min)

        peak_idx = int(np.argmax(I_smooth))
        n_bins = len(I_smooth)

        # Walk toward smaller radii until the profile drops below half-max,
        # then step back to the last bin still >= half-max.
        left_idx = peak_idx
        while left_idx > 0 and I_smooth[left_idx] >= I_half:
            left_idx -= 1
        if left_idx < peak_idx and I_smooth[left_idx] < I_half:
            left_idx += 1

        # Same walk toward larger radii.
        right_idx = peak_idx
        while right_idx < n_bins - 1 and I_smooth[right_idx] >= I_half:
            right_idx += 1
        if right_idx > peak_idx and I_smooth[right_idx] < I_half:
            right_idx -= 1

        if right_idx <= left_idx:
            continue

        # NOTE(review): r_prof holds distances from the centre (>= 0). When
        # the peak is the innermost bin (left_idx == 0), the difference below
        # already spans roughly centre -> half-max, so the extra 0.5 about
        # halves that distance. Because the result is combined with max()
        # below, this only makes the refinement more conservative. Confirm
        # the factor is intended.
        radius_um = 0.5 * (float(r_prof[right_idx]) - float(r_prof[left_idx]))
        if radius_um <= 0:
            continue

        # Only ever enlarge the incoming radius.
        r_fwhm_px = radius_um / max(px_um_xy, 1e-9)
        blobs_out[i, 3] = max(float(r_px_init), float(r_fwhm_px))

    return blobs_out


# ===============================
# Distance/overlap-aware classification helpers
# ===============================
from scipy.ndimage import distance_transform_edt as _edt


def nearest_cell_label(cell_seg, z, y, x, max_r=12):
    """
    Return the dominant non-zero label in the smallest cube around a voxel.

    Cubes of half-size 1, 2, ..., ``max_r`` voxels (clipped to the volume)
    centred on ``(z, y, x)`` are examined in turn. The first cube that
    contains any label > 0 decides: its most frequent label is returned
    (ties go to the smallest label). Returns 0 if no cube up to ``max_r``
    contains a label.
    """
    depth, height, width = cell_seg.shape
    radius = 1
    while radius <= max_r:
        window = cell_seg[
            max(0, z - radius):min(depth, z + radius + 1),
            max(0, y - radius):min(height, y + radius + 1),
            max(0, x - radius):min(width, x + radius + 1),
        ]
        labelled = window[window > 0]
        if labelled.size > 0:
            return int(np.bincount(labelled.ravel()).argmax())
        radius += 1
    return 0


def sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, mask_bool, vx_um, vy_um, vz_um):
    """
    Fraction of a sphere's voxels that lie inside ``mask_bool``.

    The sphere is centred at the voxel nearest to ``(zc_um, yc_um, xc_um)``
    (µm) and has radius ``r_um`` (µm), measured with the anisotropic voxel
    sizes. Only voxels inside the volume are counted, so spheres cut by the
    border are judged on their in-bounds part. Returns 0.0 for
    non-positive radii or when no sphere voxel is in bounds.
    """
    if r_um <= 0:
        return 0.0

    cz = int(round(zc_um / vz_um))
    cy = int(round(yc_um / vy_um))
    cx = int(round(xc_um / vx_um))

    def _axis_window(centre, step, size):
        # Bounding-box half-extent along one axis (at least 1 voxel), clipped.
        half = max(1, int(np.ceil(r_um / step)))
        return max(0, centre - half), min(size, centre + half + 1)

    depth, height, width = mask_bool.shape
    z_lo, z_hi = _axis_window(cz, vz_um, depth)
    y_lo, y_hi = _axis_window(cy, vy_um, height)
    x_lo, x_hi = _axis_window(cx, vx_um, width)
    if z_lo >= z_hi or y_lo >= y_hi or x_lo >= x_hi:
        return 0.0

    # Open grids broadcast to the full box without materialising three
    # full-size index arrays.
    gz, gy, gx = np.ogrid[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi]
    off_z = (gz - cz) * vz_um
    off_y = (gy - cy) * vy_um
    off_x = (gx - cx) * vx_um
    in_sphere = (off_z * off_z + off_y * off_y + off_x * off_x) <= (r_um * r_um)

    n_sphere = in_sphere.sum()
    if n_sphere == 0:
        return 0.0

    n_hit = (mask_bool[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi] & in_sphere).sum()
    return float(n_hit) / float(n_sphere)


# ===============================
# Full-size overlay exporter (RGB stack + MP4)
# ===============================
def _norm_u8_stack(vol):
    """
    Min-max rescale an array to uint8 in [0, 255].

    The minimum maps to 0 and the maximum to 255 (values truncated toward
    zero by the uint8 cast). A constant (or all-NaN) input yields zeros.
    """
    lo = float(vol.min())
    hi = float(vol.max())
    if not hi > lo:
        return np.zeros_like(vol, dtype=np.uint8)
    scaled = np.clip((vol - lo) / (hi - lo), 0, 1) * 255.0
    return scaled.astype(np.uint8)


def make_label_colormap(n_labels, seed_hue=0.13):
    """
    Build an ``(n_labels + 1, 3)`` uint8 RGB lookup table for a label image.

    Row 0 (background) is black. Rows 1..n_labels get fully saturated,
    full-brightness colours whose hues are evenly spaced around the HSV
    wheel, starting at ``seed_hue``.
    """
    table = np.zeros((n_labels + 1, 3), dtype=np.uint8)
    if n_labels > 0:
        for idx in range(1, n_labels + 1):
            hue = (seed_hue + (idx - 1) / max(n_labels, 1)) % 1.0
            rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
            table[idx] = tuple(int(255 * channel) for channel in rgb)
    return table


def export_fullsize_overlay_stack(
    img_ch1,
    img_ch2_raw,
    cell_seg_viz,
    df,
    vx_um, vy_um, vz_um,
    output_dir,
    alpha_labels=0.45,
    draw_only_inside=True,
    fps=8,
    basename="FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
):
    """
    Render and save a full-resolution RGB overlay of every Z slice.

    Each frame has three layers:

    - A raw composite with channels (Ch1, Ch2, Ch1), each min-max scaled to
      uint8 over the whole stack.
    - The colour-coded label image ``cell_seg_viz``, blended with weight
      ``alpha_labels`` on labelled pixels.
    - For every lysosome whose sphere intersects the slice, a magenta circle
      with a black outline and a magenta centre dot. The circle radius is the
      sphere's cross-section at that slice, converted with the
      geometric-mean XY pixel size and floored at 3 px.

    Parameters
    ----------
    img_ch1, img_ch2_raw : ndarray
        Channel volumes; ``img_ch2_raw`` sets the frame shape ``(Z, H, W)``.
    cell_seg_viz : ndarray of int
        Label volume (0 = background) of the same shape.
    df : pandas.DataFrame or None
        Lysosome table. Rows are drawn only if it has ``z_um``, ``y_um``,
        ``x_um`` and ``radius_um`` columns; non-finite rows are skipped.
        With ``draw_only_inside`` and a ``location_ch2`` column, only rows
        equal to ``"cell"`` are drawn.
    vx_um, vy_um, vz_um : float
        Voxel sizes in µm.
    output_dir : str
        Created if needed; receives ``<basename>.tif`` and a video file.
    alpha_labels : float
        Blend weight of the label colours.
    fps : int
        Video frame rate.
    basename : str
        File name stem for the outputs.

    Returns
    -------
    (str, str)
        Path of the RGB TIFF stack and path of the video actually written:
        ``<basename>.mp4`` normally, ``<basename>.gif`` if the FFMPEG writer
        failed.
    """
    os.makedirs(output_dir, exist_ok=True)

    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(img_ch2_raw.astype(np.float32))

    Z, H, W = ch2_u8.shape
    n_labels = int(cell_seg_viz.max()) if isinstance(cell_seg_viz, np.ndarray) else 0
    cmap = make_label_colormap(n_labels, seed_hue=0.13)

    # Select the lysosomes to draw (finite coordinates/radii only).
    use_df = None
    if isinstance(df, pd.DataFrame) and len(df) > 0 and {"z_um", "y_um", "x_um", "radius_um"}.issubset(df.columns):
        if draw_only_inside and "location_ch2" in df.columns:
            use_df = df[df["location_ch2"] == "cell"].copy()
        else:
            use_df = df.copy()
        use_df = use_df[
            np.isfinite(use_df["z_um"]) &
            np.isfinite(use_df["y_um"]) &
            np.isfinite(use_df["x_um"]) &
            np.isfinite(use_df["radius_um"])
        ].copy()

    px_um_xy = float(np.sqrt(vx_um * vy_um))

    # Lysosome geometry does not depend on the slice: convert it to voxel
    # coordinates once instead of once per Z slice.
    have_lys = use_df is not None and len(use_df) > 0
    if have_lys:
        lys_z = (use_df["z_um"].to_numpy() / vz_um).astype(float)
        lys_y = (use_df["y_um"].to_numpy() / vy_um).astype(float)
        lys_x = (use_df["x_um"].to_numpy() / vx_um).astype(float)
        lys_r_um = use_df["radius_um"].to_numpy().astype(float)

    frames = np.zeros((Z, H, W, 3), dtype=np.uint8)

    for z in range(Z):
        # Background channels (Ch1, Ch2, Ch1): symmetric, so it reads the same
        # under RGB or BGR ordering.
        base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]]).astype(np.float32)

        # Colored IDs, blended only where a label is present.
        lab2d = cell_seg_viz[z].astype(np.int32)
        lab_rgb = cmap[lab2d].astype(np.float32)

        mask = (lab2d > 0)[..., None].astype(np.float32)
        out = base * (1.0 - alpha_labels * mask) + lab_rgb * (alpha_labels * mask)

        # Lysosomes (MAGENTA): every sphere cutting this slice is drawn at its
        # cross-section radius.
        if have_lys:
            dz_um = np.abs(lys_z - z) * vz_um
            hits = dz_um <= lys_r_um
            if np.any(hits):
                r_proj_um = np.sqrt(np.clip(lys_r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
                r_proj_px = r_proj_um / max(px_um_xy, 1e-12)
                ys = np.rint(lys_y[hits]).astype(int)
                xs = np.rint(lys_x[hits]).astype(int)

                out_u8 = np.clip(out, 0, 255).astype(np.uint8)
                for y, x, rp in zip(ys, xs, r_proj_px):
                    rr = int(max(3, round(rp)))  # floor at 3 px for visibility
                    if 0 <= y < H and 0 <= x < W:
                        centre = (int(x), int(y))
                        cv2.circle(out_u8, centre, rr, LYS_EDGE_BGR, 4, lineType=cv2.LINE_AA)
                        cv2.circle(out_u8, centre, rr, LYS_MAGENTA_BGR, 2, lineType=cv2.LINE_AA)
                        cv2.circle(out_u8, centre, 1,  LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)
                out = out_u8.astype(np.float32)

        frames[z] = np.clip(out, 0, 255).astype(np.uint8)

    # Save RGB TIFF stack
    tiff_path = os.path.join(output_dir, f"{basename}.tif")
    tiff.imwrite(tiff_path, frames, photometric="rgb")
    print("Saved full-size RGB TIFF stack:", tiff_path)

    # Save MP4 (fallback GIF)
    video_path = os.path.join(output_dir, f"{basename}.mp4")
    try:
        with imageio.get_writer(
            video_path,
            fps=int(fps),
            format="FFMPEG",
            codec="libx264",
            macro_block_size=None
        ) as w:
            for fr in frames:
                w.append_data(fr)
        print("Saved full-size MP4:", video_path)
    except Exception as e:
        gif_path = os.path.join(output_dir, f"{basename}.gif")
        imageio.mimsave(gif_path, list(frames), fps=int(fps))
        print("FFMPEG failed, saved GIF instead:", gif_path, "Error:", e)
        # Report the file that was actually produced.
        video_path = gif_path

    return tiff_path, video_path


# ===============================
# MAIN
# ===============================
# Default voxel sizes handed to the configuration GUI below.
DEFAULT_VX_VY_UM = 0.04
DEFAULT_VZ_UM = None  # if None and Z missing -> fallback to XY

# Show the configuration window. Cancelling raises SystemExit inside
# get_user_config_gui.
cfg = get_user_config_gui(
    default_vxy_um=DEFAULT_VX_VY_UM,
    default_vz_um=DEFAULT_VZ_UM,
    default_erode_mult=1.0,
    default_blob_threshold=0.001,
)

file_path = cfg["file_path"]
output_dir = cfg["output_dir"]
os.makedirs(output_dir, exist_ok=True)
# Change into output_dir so the relative CSV/video file names used later
# are written there.
os.chdir(output_dir)

print("Selected file:", file_path)
print("Outputs will be saved to:", output_dir)

# --- Parameters collected by the GUI, copied into module-level names ---
ERODE_MULT = cfg["ERODE_MULT"]
BLOB_THRESHOLD = cfg["BLOB_THRESHOLD"]

# Advanced settings: XY sanity limit (µm/px), soft margin around the mask (µm),
# sphere-overlap fraction (0..1), neighbour search radius (voxels) and the
# visualization size filter (voxels).
MAX_REASONABLE_VXY_UM = cfg["MAX_REASONABLE_VXY_UM"]
MARGIN_UM = cfg["MARGIN_UM"]
OVERLAP_ALPHA = cfg["OVERLAP_ALPHA"]
NEIGHBOR_MAX_VOX = cfg["NEIGHBOR_MAX_VOX"]
VIZ_MIN_VOXELS = cfg["VIZ_MIN_VOXELS"]

# Ch1 pre-smoothing and blob_log sigma range.
CH1_SMOOTH_SIGMA = cfg["CH1_SMOOTH_SIGMA"]
BLOB_MIN_SIGMA = cfg["BLOB_MIN_SIGMA"]
BLOB_MAX_SIGMA = cfg["BLOB_MAX_SIGMA"]
BLOB_NUM_SIGMA = cfg["BLOB_NUM_SIGMA"]

# Radial-profile radius refinement (refine_radii_via_radial_intensity).
RADIAL_MAX_RADIUS_NM = cfg["RADIAL_MAX_RADIUS_NM"]
RADIAL_DR_NM = cfg["RADIAL_DR_NM"]
RADIAL_MIN_DROP_FRACTION = cfg["RADIAL_MIN_DROP_FRACTION"]

# Ch2 smoothing and per-slice local threshold.
CH2_SMOOTH_SIGMA = cfg["CH2_SMOOTH_SIGMA"]
THRESH_BLOCK_SIZE = cfg["THRESH_BLOCK_SIZE"]
THRESH_OFFSET_STD_MULT = cfg["THRESH_OFFSET_STD_MULT"]

# Output options.
FPS = cfg["VIDEO_FPS"]
LAUNCH_VIEWER = cfg["LAUNCH_VIEWER"]
GENERATE_VIDEOS = cfg["GENERATE_VIDEOS"]

# Load data + metadata. Any of the voxel sizes (µm) may come back as None
# when the file's metadata lacks it.
img_ch1, img_ch2, (vx_um, vy_um, vz_um), meta = load_any(file_path)
print(f"[metadata] vx_um={vx_um}  vy_um={vy_um}  vz_um={vz_um}")

# Prompt only when metadata is missing
def _prompt_missing_size(title, prompt, default_value):
    """
    Ask the user for a voxel size that is missing from the file metadata.

    A yes/no box offers ``default_value``. On "No", a float entry dialog
    opens, pre-filled with the default and with a minimum of 1e-12.

    Returns
    -------
    float
        The accepted default or the value entered by the user.

    Raises
    ------
    SystemExit
        If the entry dialog is cancelled.
    """
    # simpledialog is not among the tkinter names imported at the top of the
    # file (ttk, filedialog, messagebox), so import it here. Without this,
    # the "No" branch would raise NameError.
    from tkinter import simpledialog

    root = tk.Tk()
    root.withdraw()
    try:
        try:
            root.attributes("-topmost", True)
        except Exception:
            pass

        use_default = messagebox.askyesno(
            title,
            f"{prompt}\n\nUse default value: {default_value} ?"
        )
        if use_default:
            return float(default_value)

        val = simpledialog.askfloat(
            title,
            "Enter a new value:",
            initialvalue=float(default_value),
            minvalue=1e-12
        )
    finally:
        # Always tear down the hidden root window, even if a dialog raises.
        root.destroy()

    if val is None:
        raise SystemExit("Cancelled.")
    return float(val)

# XY
if vx_um is None or vy_um is None:
    vx_um = vy_um = _prompt_missing_size(
        title="Missing metadata",
        prompt="XY pixel size metadata is missing (µm/px).",
        default_value=float(cfg["DEFAULT_VX_VY_UM"]),
    )

# Z (defaults to the XY size when no Z default is configured)
if vz_um is None:
    z_default = float(cfg["DEFAULT_VZ_UM"]) if (cfg["DEFAULT_VZ_UM"] is not None) else float(vx_um)
    vz_um = _prompt_missing_size(
        title="Missing metadata",
        prompt="Z step metadata is missing (µm/slice).",
        default_value=z_default,
    )

# Every later step divides by these sizes. Fail fast on unusable values
# (e.g. a zero or negative size read from metadata) instead of hitting a
# ZeroDivisionError or producing nonsense deep inside the pipeline.
if not all(np.isfinite(s) and s > 0 for s in (vx_um, vy_um, vz_um)):
    raise ValueError(
        f"Voxel sizes must be positive and finite (µm): X={vx_um}, Y={vy_um}, Z={vz_um}"
    )

# Sanity check on XY sampling. It applies to both X and Y, not just X.
if max(vx_um, vy_um) > MAX_REASONABLE_VXY_UM:
    raise ValueError(f"XY pixel size too large: {max(vx_um, vy_um)} µm/px")

px_um_xy = float(np.sqrt(vx_um * vy_um))
# NOTE(review): px_um converts blob radii (pixels) to µm further down
# (radius_um = blobs[:, 3] * px_um). The 0.55 factor is an undocumented
# scale, presumably an empirical calibration; confirm.
px_um = px_um_xy * 0.55
voxel_um3 = vx_um * vy_um * vz_um
print(f"Voxel size (µm): X={vx_um}, Y={vy_um}, Z={vz_um}")

# ===== Aliases =====
image = img_ch1
image_2 = img_ch2

# ==========================================
# Lysosome detection (Ch1)
# ==========================================
image_smooth = gaussian(image, sigma=CH1_SMOOTH_SIGMA)

# Multi-scale LoG blob detection on the smoothed volume. Each row is
# (z, y, x, sigma) in voxel units; the loops below unpack exactly four
# columns.
blobs = blob_log(
    image_smooth,
    min_sigma=BLOB_MIN_SIGMA,
    max_sigma=BLOB_MAX_SIGMA,
    num_sigma=BLOB_NUM_SIGMA,
    threshold=BLOB_THRESHOLD
)

# Radius pipeline for column 3:
#   1. sigma * sqrt(3);
#   2. replaced by the distance-transform radius where one can be measured;
#   3. possibly enlarged by the radial-intensity estimate.
if len(blobs) > 0:
    blobs[:, 3] *= np.sqrt(3)  # LoG radius correction (pixels)
    blobs = refine_radii_via_dt(image_smooth, blobs)
    blobs = refine_radii_via_radial_intensity(
        image_smooth,
        blobs,
        vx_um, vy_um, vz_um,
        max_radius_nm=RADIAL_MAX_RADIUS_NM,
        dr_nm=RADIAL_DR_NM,
        min_drop_fraction=RADIAL_MIN_DROP_FRACTION
    )

# Peak intensity in raw 16-bit Ch1
# Max raw Ch1 value in a 3x3x3 box (rad = 1, clipped at the borders) around
# each rounded centre.
# NOTE(review): results are cast to uint16, so non-integer or >65535 input
# values would not be preserved. This assumes <=16-bit integer data.
peak_gray = np.zeros(len(blobs), dtype=np.uint16)
Z0, Y0, X0 = image.shape
rad = 1

for i, (zc, yc, xc, _) in enumerate(blobs):
    zc_i = int(round(zc))
    yc_i = int(round(yc))
    xc_i = int(round(xc))

    z1, z2 = max(0, zc_i - rad), min(Z0, zc_i + rad + 1)
    y1, y2 = max(0, yc_i - rad), min(Y0, yc_i + rad + 1)
    x1, x2 = max(0, xc_i - rad), min(X0, xc_i + rad + 1)

    peak_gray[i] = np.max(image[z1:z2, y1:y2, x1:x2]).astype(np.uint16)

# Convert to physical units. Centres are scaled per axis. The radius uses
# px_um (0.55 x the geometric-mean XY pixel size, defined above), so
# diameter and volume inherit that factor. Volume assumes a sphere.
if len(blobs) > 0:
    z_um = blobs[:, 0] * vz_um
    y_um = blobs[:, 1] * vy_um
    x_um = blobs[:, 2] * vx_um
    radius_um = blobs[:, 3] * px_um
    diameter_um = 2 * radius_um
    volume_um3 = (4/3) * np.pi * radius_um**3
    blob_ids = np.arange(1, len(blobs) + 1, dtype=int)
else:
    z_um = y_um = x_um = radius_um = diameter_um = volume_um3 = np.array([])
    blob_ids = np.array([], dtype=int)

# One row per detected blob; ids are 1-based in detection order.
df = pd.DataFrame({
    "id": blob_ids,
    "z_um": z_um,
    "y_um": y_um,
    "x_um": x_um,
    "radius_um": radius_um,
    "diameter_um": diameter_um,
    "volume_um3": volume_um3,
    "peak_gray": peak_gray,
})
df.to_csv("lysosome_blobs_regions.csv", index=False)
print("Saved: lysosome_blobs_regions.csv")

# ==========================================
# CH2 segmentation (neurites mask)
# ==========================================
# Min-max normalize Ch2 to [0, 1] over the whole volume (a constant volume
# becomes all zeros).
vol = image_2.astype(np.float32)
vmin, vmax = float(vol.min()), float(vol.max())
if vmax > vmin:
    vol = (vol - vmin) / (vmax - vmin)
else:
    vol[:] = 0.0

ch2 = gaussian(vol, sigma=CH2_SMOOTH_SIGMA, preserve_range=True)

# Per-slice adaptive threshold. threshold_local subtracts `offset` from the
# local weighted mean; the offset is negative here, so the cut-off is raised
# by THRESH_OFFSET_STD_MULT x the slice's standard deviation.
neuron_mask = np.zeros_like(ch2, dtype=bool)
for z in range(ch2.shape[0]):
    R = ch2[z]
    t = threshold_local(R, block_size=THRESH_BLOCK_SIZE, offset=-THRESH_OFFSET_STD_MULT * np.std(R))
    neuron_mask[z] = R > t

if NEURITE_MODE:
    # neurite-safe cleanup: connect gaps, remove small speckles; avoid erosion/opening
    # (objects under 200 voxels, 26-connected, are dropped)
    neuron_mask = binary_closing(neuron_mask, ball(1))
    neuron_mask = remove_small_objects(neuron_mask, min_size=200, connectivity=3)
    # NOTE: do NOT fill holes globally for neurites (can fill loops)
else:
    # fallback to soma-style refinement if you ever switch back
    neuron_mask = binary_fill_holes(neuron_mask)

print("neurite voxels:", int(neuron_mask.sum()))

# ---- ID segmentation ----
# cell_seg: integer label volume (0 = background) used for per-ID statistics.
if NEURITE_MODE:
    # label connected neurite networks (often 1)
    cell_seg = label(neuron_mask, connectivity=3).astype(np.int32)
    cell_mask = neuron_mask.copy()
    print("Detected components (neurite networks):", int(cell_seg.max()))
else:
    # Original soma logic (kept for compatibility)
    dist = edt(neuron_mask)
    # NOTE(review): with cell_min_radius_vox = 1 this keeps every foreground
    # voxel, because the EDT of any foreground voxel is >= 1.
    cell_min_radius_vox = 1
    cell_mask = (dist >= cell_min_radius_vox)
    cell_mask &= neuron_mask
    cell_mask = binary_fill_holes(binary_closing(binary_opening(cell_mask, ball(1)), ball(2)))

    # Connected soma bodies seed a watershed that splits the full mask.
    body_lab = label(cell_mask, connectivity=3)
    n_cells = int(body_lab.max())
    print(f"Detected {n_cells} cells (soma).")

    if n_cells > 0:
        # Same EDT as `dist` above (recomputed).
        dist_inside = edt(neuron_mask)
        cell_seg = watershed(-dist_inside, markers=body_lab, mask=neuron_mask)
    else:
        cell_seg = np.zeros_like(neuron_mask, dtype=np.int32)

print("ID voxels:", int((cell_seg > 0).sum()))

# ==========================================
# Visualization-only filtering (hide tiny components) + serial IDs
# ==========================================
# cell_seg_viz is a copy of cell_seg. Components smaller than VIZ_MIN_VOXELS
# are removed, and the survivors are renumbered 1..N in increasing
# original-label order. cell_id_map_viz maps original label -> serial label.
# The statistics below (cell_id_ch2, volumes) use the unfiltered cell_seg.
cell_seg_viz = cell_seg.copy()
cell_id_map_viz = {}

if isinstance(cell_seg_viz, np.ndarray) and cell_seg_viz.max() > 0:
    counts_viz = np.bincount(cell_seg_viz.ravel().astype(np.int64))
    tiny_labels = np.where(counts_viz < VIZ_MIN_VOXELS)[0]
    tiny_labels = tiny_labels[tiny_labels > 0]
    if tiny_labels.size > 0:
        cell_seg_viz[np.isin(cell_seg_viz, tiny_labels)] = 0

    unique_labels = np.unique(cell_seg_viz)
    unique_labels = unique_labels[unique_labels > 0]

    if unique_labels.size > 0:
        # Renumber through a lookup table: a single pass over the volume
        # instead of one full-volume comparison per label.
        new_ids = np.arange(1, unique_labels.size + 1, dtype=np.int32)
        lut = np.zeros(int(unique_labels.max()) + 1, dtype=np.int32)
        lut[unique_labels] = new_ids
        cell_seg_viz = lut[cell_seg_viz]
        cell_id_map_viz = {int(old): int(new) for old, new in zip(unique_labels, new_ids)}

cell_mask_viz = (cell_seg_viz > 0)

# ==========================================
# Classify lysosomes: inside vs outside neurite mask, and assign ID
# ==========================================
# Physical distance (µm) from each voxel outside the mask to the nearest mask
# voxel (0 inside the mask). soft_cell_mask grows the mask by MARGIN_UM.
dist_out_um = _edt(~neuron_mask, sampling=(vz_um, vy_um, vx_um)).astype(np.float32)
soft_cell_mask = neuron_mask | (dist_out_um <= MARGIN_UM)

location_ch2 = []
cell_id_list = []

if len(df) > 0:
    Z, Y, X = neuron_mask.shape
    for (zc_um, yc_um, xc_um, r_um) in df[["z_um", "y_um", "x_um", "radius_um"]].to_numpy():
        # Nearest voxel to the lysosome centre.
        zz = int(round(zc_um / vz_um))
        yy = int(round(yc_um / vy_um))
        xx = int(round(xc_um / vx_um))

        # A lysosome counts as inside if any of these holds:
        #   1. its centre voxel is in the mask;
        #   2. its centre is within MARGIN_UM of the mask;
        #   3. at least OVERLAP_ALPHA of its sphere's in-bounds voxels lie in
        #      the mask.
        # Because soft_cell_mask contains neuron_mask, rule 2 subsumes rule 1.
        inside_hard = (0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X and neuron_mask[zz, yy, xx])
        inside_soft = (0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X and soft_cell_mask[zz, yy, xx])

        is_inside = bool(inside_hard)
        if not is_inside and inside_soft:
            is_inside = True

        if not is_inside:
            frac = sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, neuron_mask, vx_um, vy_um, vz_um)
            if frac >= OVERLAP_ALPHA:
                is_inside = True

        if is_inside:
            # Component ID: the label at the centre, else the dominant label
            # within NEIGHBOR_MAX_VOX voxels.
            # NOTE(review): cid stays 0 when no label is found (or the centre
            # is out of bounds). Such rows are still "cell", skip the
            # lys_id_in_cell numbering, and appear as cell_id_ch2 == 0 in
            # lysosome_counts_by_cell.csv.
            cid = 0
            if 0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X:
                if cell_seg[zz, yy, xx] != 0:
                    cid = int(cell_seg[zz, yy, xx])
                else:
                    cid = nearest_cell_label(cell_seg, zz, yy, xx, max_r=NEIGHBOR_MAX_VOX)
            location_ch2.append("cell")
            cell_id_list.append(cid)
        else:
            location_ch2.append("outside")
            cell_id_list.append(0)

    df["location_ch2"] = location_ch2
    df["cell_id_ch2"] = cell_id_list
    # Serial (visualization) ID; 0 for components hidden by the viz filter.
    # cell_id_map_viz is always a dict here, so the isinstance guard always passes.
    df["cell_id_ch2_viz"] = df["cell_id_ch2"].map(cell_id_map_viz).fillna(0).astype(int) if isinstance(cell_id_map_viz, dict) else 0

    # Serial lysosome IDs within each ID component
    # (1-based, ordered by z, then y, then x; 0 for outside / cid == 0)
    df["lys_id_in_cell"] = 0
    mask_in = (df["location_ch2"] == "cell") & (df["cell_id_ch2"] > 0)
    df_sorted = df.loc[mask_in].sort_values(["cell_id_ch2", "z_um", "y_um", "x_um"]).copy()
    df.loc[df_sorted.index, "lys_id_in_cell"] = (df_sorted.groupby("cell_id_ch2").cumcount().to_numpy() + 1).astype(int)

    df.to_csv("lysosomes_with_cell_vs_outside.csv", index=False)
    print("Saved: lysosomes_with_cell_vs_outside.csv")

    df.groupby("location_ch2").size().reset_index(name="count").to_csv("lysosome_counts_cell_vs_outside.csv", index=False)
    (df[df["location_ch2"] == "cell"]
        .groupby("cell_id_ch2").size()
        .reset_index(name="count")
        .to_csv("lysosome_counts_by_cell.csv", index=False))
    print("Saved: lysosome_counts_cell_vs_outside.csv, lysosome_counts_by_cell.csv")

# Per-ID volumes (µm^3)
# One row per label 1..max(cell_seg): voxel count and physical volume. When
# lysosomes were classified, they are merged with per-ID lysosome counts.
cell_volume_df = pd.DataFrame(columns=["cell_id_ch2", "voxel_count", "volume_um3"])
if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
    counts = np.bincount(cell_seg.ravel().astype(np.int64))
    cell_ids = np.arange(1, counts.size, dtype=int)
    voxels = counts[1:].astype(np.int64)
    vol_um3 = voxels.astype(float) * voxel_um3

    cell_volume_df = pd.DataFrame({
        "cell_id_ch2": cell_ids,
        "voxel_count": voxels,
        "volume_um3": vol_um3
    })
    cell_volume_df.to_csv("cell_volumes_ch2.csv", index=False)
    print("Saved: cell_volumes_ch2.csv")

    # Best-effort: a failure here is reported, but the script does not abort.
    try:
        if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
            lys_counts = (
                df[df["location_ch2"] == "cell"]
                .groupby("cell_id_ch2").size()
                .reset_index(name="lysosome_count")
            )
        else:
            lys_counts = pd.DataFrame(columns=["cell_id_ch2", "lysosome_count"])

        merged = cell_volume_df.merge(lys_counts, on="cell_id_ch2", how="left")
        # IDs without lysosomes get NaN from the left merge, which makes the
        # column float (CSV would show "3.0"). Restore integer counts.
        merged["lysosome_count"] = (
            pd.to_numeric(merged["lysosome_count"]).fillna(0).astype(int)
        )
        merged.to_csv("cell_metrics_ch2.csv", index=False)
        print("Saved: cell_metrics_ch2.csv")
    except Exception as e:
        print("Merge with lysosome counts failed:", e)

# ==========================================
# Export full-size overlay (RAW + colored IDs + MAGENTA lysosomes)
# ==========================================
# Uses the viz-filtered label volume (tiny components hidden, serial IDs).
# With draw_only_inside=True, only lysosomes classified as "cell" are drawn
# when that column exists.
if EXPORT_FULLSIZE_OVERLAY:
    export_fullsize_overlay_stack(
        img_ch1,
        img_ch2,
        cell_seg_viz,
        df,
        vx_um,
        vy_um,
        vz_um,
        output_dir,
        alpha_labels=0.45,
        draw_only_inside=True,   # True = only inside neurites; set False for all
        fps=FPS,
        basename="FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
    )

# ==========================================
# Videos (MAGENTA lysosomes)
# ==========================================
if GENERATE_VIDEOS:
    img_norm_2 = (ch2 * 255).astype(np.uint8)
    frames_fused = []
    Z = img_norm_2.shape[0]
    px_um_xy = float(np.sqrt(vx_um * vy_um))

    for z in range(Z):
        base = cv2.cvtColor(img_norm_2[z], cv2.COLOR_GRAY2BGR)

        mask_u8 = (cell_mask_viz[z].astype(np.uint8) * 255)
        overlay = base.copy()
        overlay[..., 1] = np.maximum(overlay[..., 1], mask_u8)
        overlay = cv2.addWeighted(base, 1.0, overlay, 0.35, 0.0)

        drew_any = False

        if isinstance(df, pd.DataFrame) and len(df) > 0:
            dfv = df[
                np.isfinite(df["z_um"]) &
                np.isfinite(df["y_um"]) &
                np.isfinite(df["x_um"]) &
                np.isfinite(df["radius_um"])
            ]
            if not dfv.empty:
                zc = (dfv["z_um"].to_numpy() / vz_um).astype(float)
                yc = (dfv["y_um"].to_numpy() / vy_um).astype(float)
                xc = (dfv["x_um"].to_numpy() / vx_um).astype(float)
                r_um = dfv["radius_um"].to_numpy().astype(float)

                dz_um = np.abs(zc - z) * vz_um
                hits = dz_um <= r_um

                if np.any(hits):
                    r_proj_um = np.sqrt(np.clip(r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
                    r_proj_px = r_proj_um / max(px_um_xy, 1e-12)

                    ys = np.rint(yc[hits]).astype(int)
                    xs = np.rint(xc[hits]).astype(int)

                    H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
                    thickness = 2

                    for y, x, rpv in zip(ys, xs, r_proj_px):
                        rr = int(max(3, round(rpv)))
                        if 0 <= y < H and 0 <= x < W and rr > 0:
                            cv2.circle(overlay, (x, y), rr, LYS_EDGE_BGR, thickness + 2, lineType=cv2.LINE_AA)
                            cv2.circle(overlay, (x, y), rr, LYS_MAGENTA_BGR, thickness, lineType=cv2.LINE_AA)
                            cv2.circle(overlay, (x, y), 1,  LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)
                    drew_any = True

        # Fallback: draw raw blobs if df empty
        if (not drew_any) and (blobs is not None) and (len(blobs) > 0):
            z_blobs = blobs[np.abs(blobs[:, 0] - z) < 0.5]
            H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
            thickness = 2
            for b in z_blobs:
                y, x = int(round(b[1])), int(round(b[2]))
                r = int(max(3, round(b[3])))
                if 0 <= y < H and 0 <= x < W and r > 0:
                    cv2.circle(overlay, (x, y), r, LYS_EDGE_BGR, thickness + 2, lineType=cv2.LINE_AA)
                    cv2.circle(overlay, (x, y), r, LYS_MAGENTA_BGR, thickness, lineType=cv2.LINE_AA)
                    cv2.circle(overlay, (x, y), 1, LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)

        cv2.putText(overlay, "FUSED (mask + MAGENTA lysosomes)", (10, 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

        frames_fused.append(overlay)

    try:
        imageio.mimsave("ch2_fused_cell_magenta.mp4", frames_fused, fps=int(FPS), format="FFMPEG")
        print("Saved: ch2_fused_cell_magenta.mp4")
    except TypeError:
        imageio.mimsave("ch2_fused_cell_magenta.gif", frames_fused, fps=int(FPS))
        print("Saved: ch2_fused_cell_magenta.gif")

    # RAW+FUSED side-by-side
    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(ch2.astype(np.float32) * 255.0 / max(1.0, ch2.max()))

    frames_raw = []
    frames_fused_all = []
    frames_side_by_side = []

    Z = ch2_u8.shape[0]
    px_um_xy = float(np.sqrt(vx_um * vy_um))

    for z in range(Z):
        b = ch1_u8[z]
        g = ch2_u8[z]
        r = ch1_u8[z]
        base = np.dstack([b, g, r])

        mask_u8 = (cell_mask_viz[z].astype(np.uint8) * 255)
        overlay = base.copy()
        overlay[..., 1] = np.maximum(overlay[..., 1], mask_u8)
        overlay = cv2.addWeighted(base, 1.0, overlay, 0.35, 0.0)

        drew_any = False

        if isinstance(df, pd.DataFrame) and len(df) > 0:
            dfv = df[
                np.isfinite(df["z_um"]) &
                np.isfinite(df["y_um"]) &
                np.isfinite(df["x_um"]) &
                np.isfinite(df["radius_um"])
            ]
            if not dfv.empty:
                zc = (dfv["z_um"].to_numpy() / vz_um).astype(float)
                yc = (dfv["y_um"].to_numpy() / vy_um).astype(float)
                xc = (dfv["x_um"].to_numpy() / vx_um).astype(float)
                r_um = dfv["radius_um"].to_numpy().astype(float)

                dz_um = np.abs(zc - z) * vz_um
                hits = dz_um <= r_um

                if np.any(hits):
                    r_proj_um = np.sqrt(np.clip(r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
                    r_proj_px = r_proj_um / max(px_um_xy, 1e-12)

                    ys = np.rint(yc[hits]).astype(int)
                    xs = np.rint(xc[hits]).astype(int)

                    H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
                    thickness = 2

                    for y, x, rpv in zip(ys, xs, r_proj_px):
                        rr = int(max(3, round(rpv)))
                        if 0 <= y < H and 0 <= x < W and rr > 0:
                            cv2.circle(overlay, (x, y), rr, LYS_EDGE_BGR, thickness + 2, lineType=cv2.LINE_AA)
                            cv2.circle(overlay, (x, y), rr, LYS_MAGENTA_BGR, thickness, lineType=cv2.LINE_AA)
                            cv2.circle(overlay, (x, y), 1,  LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)
                    drew_any = True

        if (not drew_any) and (blobs is not None) and (len(blobs) > 0):
            z_blobs = blobs[np.abs(blobs[:, 0] - z) < 0.5]
            H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
            thickness = 2
            for b_ in z_blobs:
                y, x = int(round(b_[1])), int(round(b_[2]))
                rpx = int(max(3, round(b_[3])))
                if 0 <= y < H and 0 <= x < W and rpx > 0:
                    cv2.circle(overlay, (x, y), rpx, LYS_EDGE_BGR, thickness + 2, lineType=cv2.LINE_AA)
                    cv2.circle(overlay, (x, y), rpx, LYS_MAGENTA_BGR, thickness, lineType=cv2.LINE_AA)
                    cv2.circle(overlay, (x, y), 1, LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)

        cv2.putText(base, "RAW (Ch1+Ch2)", (10, 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(overlay, "FUSED (mask + MAGENTA lysosomes)", (10, 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

        frames_raw.append(base)
        frames_fused_all.append(overlay)
        frames_side_by_side.append(cv2.hconcat([base, overlay]))

    def _save_video_sync(basename, frames):
        try:
            with imageio.get_writer(
                f"{basename}.mp4",
                fps=int(FPS),
                format="FFMPEG",
                codec="libx264",
                macro_block_size=None
            ) as w:
                for fr in frames:
                    w.append_data(fr)
            print(f"Saved: {basename}.mp4 @ {int(FPS)} fps")
        except Exception:
            imageio.mimsave(f"{basename}.gif", frames, duration=1.0 / max(int(FPS), 1))
            print(f"Saved: {basename}.gif @ {int(FPS)} fps equivalent")

    _save_video_sync("ch2_raw", frames_raw)
    _save_video_sync("ch2_fused_all_viz_magenta", frames_fused_all)
    _save_video_sync("ch2_raw_and_fused_all_viz_magenta", frames_side_by_side)

# ==========================================
# Napari visualization (3D viewer of raw channels, masks and lysosome points)
# ==========================================
# Runs only when LAUNCH_VIEWER is set. Relies on globals created earlier in this cell:
# img_ch1, img_ch2, cell_mask_viz, cell_seg_viz, df, vx_um, vy_um, vz_um, NEURITE_MODE.
if LAUNCH_VIEWER:
    viewer = napari.Viewer()
    viewer.dims.ndisplay = 3  # start in 3D rendering mode

    viewer.add_image(img_ch2, name="Ch2 raw")
    viewer.add_image(img_ch1, name="Ch1 raw")

    # Binary mask as a labels layer; the layer name depends on the dataset mode.
    mask_layer = viewer.add_labels(
        cell_mask_viz.astype(np.uint8),
        name="Neurite mask (viz)" if NEURITE_MODE else "Cell mask (viz)",
        opacity=0.35
    )
    mask_layer.blending = "translucent_no_depth"

    # Per-object ID labels (uint16 so more than 255 IDs can be shown).
    id_layer = viewer.add_labels(
        cell_seg_viz.astype(np.uint16),
        name="ID (viz)",
        opacity=0.25
    )
    id_layer.blending = "translucent_no_depth"

    # NOTE(review): assumes df is a DataFrame here (len(df) would fail on None),
    # unlike the video code above, which checks isinstance(df, pd.DataFrame) first.
    if len(df) > 0 and "location_ch2" in df:
        # Only lysosomes classified as inside the mask ("cell") are shown.
        in_mask = (df["location_ch2"].to_numpy() == "cell")
        if np.any(in_mask):
            # µm -> voxel index coordinates, in napari's (z, y, x) order.
            pts_zyx = np.stack([
                (df.loc[in_mask, "z_um"].to_numpy() / vz_um),
                (df.loc[in_mask, "y_um"].to_numpy() / vy_um),
                (df.loc[in_mask, "x_um"].to_numpy() / vx_um),
            ], axis=1)

            # Point size = 2 x radius in XY pixels (geometric-mean pixel size), floored at 2.
            radii_vox = df.loc[in_mask, "radius_um"].to_numpy() / (np.sqrt(vx_um * vy_um) + 1e-12)
            sizes = np.clip(radii_vox * 2, 2, None).astype(np.float32)

            pts = viewer.add_points(pts_zyx, size=sizes, name="Lysosomes (inside 3D)")
            pts.face_color = [1.0, 0.0, 1.0, 1.0]  # MAGENTA in napari
            # NOTE(review): newer napari releases renamed edge_* to border_*; these may
            # emit deprecation warnings there — verify against the installed version.
            pts.edge_color = "black"
            pts.edge_width = 0.3

            # Optional ID columns fall back to zeros when absent.
            cell_ids = df.loc[in_mask, "cell_id_ch2_viz"].to_numpy().astype(int) if "cell_id_ch2_viz" in df.columns else np.zeros(int(np.sum(in_mask)), dtype=int)
            lys_ids  = df.loc[in_mask, "lys_id_in_cell"].to_numpy().astype(int) if "lys_id_in_cell" in df.columns else np.zeros(int(np.sum(in_mask)), dtype=int)
            diams    = df.loc[in_mask, "diameter_um"].to_numpy().astype(float)

            # One descriptive string per point, attached as the "info" property.
            info = np.array(
                [f"ID:{c}  Ly:{l}  Diam:{d:.3f}µm" for c, l, d in zip(cell_ids, lys_ids, diams)],
                dtype=object
            )
            pts.properties = {"info": info}

    # Drop any layer not created above (whitelist by layer name).
    keep = {"Lysosomes (inside 3D)", "ID (viz)", "Neurite mask (viz)", "Cell mask (viz)", "Ch1 raw", "Ch2 raw"}
    for lyr in list(viewer.layers):
        if lyr.name not in keep:
            viewer.layers.remove(lyr)

    # Best-effort initial zoom; failures are deliberately ignored.
    try:
        viewer.camera.zoom = 1.2
    except Exception:
        pass

    # Hand control to napari's event loop.
    napari.run()

In [None]:
############# Last version, with skeletonization #############

In [None]:
import os
import re
import numpy as np
import pandas as pd
import cv2
import imageio
import xml.etree.ElementTree as ET
from datetime import datetime

import czifile
import tifffile as tiff

from skimage.feature import blob_log
from skimage.filters import gaussian, threshold_local
from skimage.morphology import (
    remove_small_objects, binary_opening, binary_closing, ball, binary_erosion,
    skeletonize_3d, binary_dilation
)
from scipy.ndimage import distance_transform_edt as edt
from skimage.measure import label
from skimage.segmentation import watershed
from scipy.ndimage import binary_fill_holes
from skimage.draw import line_nd
from scipy.ndimage import convolve

import napari
from scipy.ndimage import gaussian_filter1d
import colorsys

# ===============================
# USER MODE
# ===============================
# NEURITE_MODE: set True for neurite-only datasets (no soma) — the recommended setting
# for the neurite images this version targets. Set False for soma datasets to use the
# original soma/watershed logic.
NEURITE_MODE = True

# When True, export the full-size overlay as an RGB stack plus an MP4 video (recommended).
EXPORT_FULLSIZE_OVERLAY = True

# OpenCV colors are BGR (not RGB).
LYS_EDGE_BGR = (0, 0, 0)          # black outline
LYS_MAGENTA_BGR = (255, 0, 255)   # magenta

# ===============================
# GUI (single unified interface)
# Advanced settings ONLY:
#   - MARGIN_UM (µm)
#   - OVERLAP_ALPHA (0..1)
#   - VIZ_MIN_VOXELS (voxels)
# ===============================
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, simpledialog


def get_user_config_gui(
    # basic defaults
    default_vxy_um=0.04,
    default_vz_um=None,          # if None and Z missing -> fallback to XY
    default_erode_mult=1.0,
    default_blob_threshold=0.001,

    # advanced defaults (ONLY the 3 shown in GUI)
    default_margin_um=0.5,          # µm (neurite-friendly default)
    default_overlap_alpha=0.4,      # unitless (0..1)
    default_neighbor_max_vox=6,     # voxels (hidden; fixed default)
    default_viz_min_voxels=200,     # voxels (neurite-friendly default)

    # fixed defaults (not shown in GUI)
    default_max_reasonable_vxy_um=0.5,
    default_ch1_smooth_sigma=1.0,
    default_blob_min_sigma=0.8,
    default_blob_max_sigma=3.0,
    default_blob_num_sigma=10,
    default_radial_max_radius_nm=300.0,
    default_radial_dr_nm=10.0,
    default_radial_min_drop_fraction=0.5,

    # neurite-friendly threshold defaults (fixed, not shown)
    default_ch2_smooth_sigma=1.2,  # 0.9
    default_thresh_block_size=201,  # 151
    default_thresh_offset_std_mult=0.12,  # 0.25

    default_video_fps=8,
    default_launch_viewer=True,
    default_generate_videos=True,
):
    """Show a Tk dialog that collects the pipeline configuration.

    The user picks the input image (.tif/.tiff/.czi), an output folder, ERODE_MULT,
    the blob_log threshold, the viewer/video options and, under "Show advanced
    settings", MARGIN_UM, OVERLAP_ALPHA and VIZ_MIN_VOXELS. Every other ``default_*``
    argument is copied into the result unchanged.

    Invalid input shows an error box and keeps the dialog open. Validation errors
    are handled inside the Run callback rather than escaping into Tk, which would
    otherwise print a traceback for every typo.

    Returns:
        dict: ``{"ok": True, "file_path": ..., "output_dir": ..., <UPPER_CASE keys>}``.

    Raises:
        SystemExit: if the dialog is cancelled or closed without a successful Run.
    """
    import math

    cfg = {"ok": False}

    class _InputError(ValueError):
        """Validation failure whose message has already been shown to the user."""

    root = tk.Tk()
    root.title("Lysosome + Neurite Segmentation (GUI)")
    root.resizable(False, False)

    file_var = tk.StringVar(value="")
    out_var = tk.StringVar(value="")

    erode_var = tk.StringVar(value=str(default_erode_mult))
    blob_var  = tk.StringVar(value=str(default_blob_threshold))

    show_adv = tk.BooleanVar(value=False)

    margin_var   = tk.StringVar(value=str(default_margin_um))
    overlap_var  = tk.StringVar(value=str(default_overlap_alpha))
    vizmin_var   = tk.StringVar(value=str(default_viz_min_voxels))

    fps_var = tk.StringVar(value=str(default_video_fps))

    launch_viewer_var = tk.BooleanVar(value=bool(default_launch_viewer))
    gen_videos_var    = tk.BooleanVar(value=bool(default_generate_videos))

    def _suggest_output_dir(fp):
        # <raw dir>/<raw name>_outputs_<timestamp>
        if not fp:
            return ""
        raw_dir = os.path.dirname(fp)
        raw_base = os.path.splitext(os.path.basename(fp))[0]
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return os.path.join(raw_dir, f"{raw_base}_outputs_{stamp}")

    def browse_file():
        fp = filedialog.askopenfilename(
            title="Select image file",
            filetypes=[("Image files", "*.tif *.tiff *.czi"), ("All files", "*.*")],
        )
        if fp:
            file_var.set(fp)
            if not out_var.get().strip():
                out_var.set(_suggest_output_dir(fp))

    def browse_output_dir():
        d = filedialog.askdirectory(title="Select output folder")
        if d:
            fp = file_var.get().strip()
            if fp:
                raw_base = os.path.splitext(os.path.basename(fp))[0]
                stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                out_var.set(os.path.join(d, f"{raw_base}_outputs_{stamp}"))
            else:
                out_var.set(d)

    def _err(msg):
        # Show the problem, then abort the current validation pass.
        messagebox.showerror("Invalid input", msg)
        raise _InputError(msg)

    def _float_required(s, name):
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        try:
            val = float(s.replace(",", "."))  # accept a decimal comma
        except ValueError:
            _err(f"{name} must be a number (got: {s})")
        if not math.isfinite(val):
            _err(f"{name} must be a finite number (got: {s})")
        return val

    def _int_required(s, name):
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        try:
            return int(float(s.replace(",", ".")))
        except (ValueError, OverflowError):  # "nan" -> ValueError, "inf" -> OverflowError
            _err(f"{name} must be an integer (got: {s})")

    def _toggle_adv():
        if show_adv.get():
            adv_frame.grid()
        else:
            adv_frame.grid_remove()

    def _collect_config():
        """Validate all fields and return the config dict; raises _InputError."""
        fp = file_var.get().strip()
        if not fp:
            _err("Please select a file.")
        if not os.path.isfile(fp):
            _err("Selected file does not exist.")

        outd = out_var.get().strip() or _suggest_output_dir(fp)

        erode = _float_required(erode_var.get(), "ERODE_MULT")
        blobt = _float_required(blob_var.get(), "blob_log threshold")
        fps   = _int_required(fps_var.get(), "video FPS")

        margin  = _float_required(margin_var.get(), "MARGIN_UM (µm)")
        overlap = _float_required(overlap_var.get(), "OVERLAP_ALPHA (0..1)")
        vizmin  = _int_required(vizmin_var.get(), "VIZ_MIN_VOXELS (voxels)")

        if not (0.0 <= overlap <= 1.0):
            _err("OVERLAP_ALPHA must be between 0 and 1.")
        if fps < 1:
            # Video writers need a positive frame rate.
            _err("Video FPS must be at least 1.")

        return {
            "ok": True,
            "file_path": fp,
            "output_dir": outd,

            "ERODE_MULT": float(erode),
            "BLOB_THRESHOLD": float(blobt),

            "DEFAULT_VX_VY_UM": float(default_vxy_um),
            "DEFAULT_VZ_UM": None if default_vz_um is None else float(default_vz_um),

            "MAX_REASONABLE_VXY_UM": float(default_max_reasonable_vxy_um),

            # Advanced only
            "MARGIN_UM": float(margin),
            "OVERLAP_ALPHA": float(overlap),
            "VIZ_MIN_VOXELS": int(vizmin),

            # Fixed (hidden)
            "NEIGHBOR_MAX_VOX": int(default_neighbor_max_vox),

            # Fixed defaults (hidden)
            "CH1_SMOOTH_SIGMA": float(default_ch1_smooth_sigma),
            "BLOB_MIN_SIGMA": float(default_blob_min_sigma),
            "BLOB_MAX_SIGMA": float(default_blob_max_sigma),
            "BLOB_NUM_SIGMA": int(default_blob_num_sigma),

            "RADIAL_MAX_RADIUS_NM": float(default_radial_max_radius_nm),
            "RADIAL_DR_NM": float(default_radial_dr_nm),
            "RADIAL_MIN_DROP_FRACTION": float(default_radial_min_drop_fraction),

            "CH2_SMOOTH_SIGMA": float(default_ch2_smooth_sigma),
            "THRESH_BLOCK_SIZE": int(default_thresh_block_size),
            "THRESH_OFFSET_STD_MULT": float(default_thresh_offset_std_mult),

            "VIDEO_FPS": int(fps),
            "LAUNCH_VIEWER": bool(launch_viewer_var.get()),
            "GENERATE_VIDEOS": bool(gen_videos_var.get()),
        }

    def run_clicked():
        try:
            values = _collect_config()
        except _InputError:
            return  # message already shown; keep the dialog open for correction
        cfg.update(values)
        root.destroy()

    def cancel_clicked():
        root.destroy()

    root.protocol("WM_DELETE_WINDOW", cancel_clicked)

    pad = {"padx": 10, "pady": 6}
    frm = ttk.Frame(root)
    frm.grid(row=0, column=0, sticky="nsew", **pad)

    r = 0
    ttk.Label(frm, text="Image file:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=file_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_file).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Label(frm, text="Output folder:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=out_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_output_dir).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Label(frm, text="ERODE_MULT:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=erode_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="blob_log threshold:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=blob_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Launch Napari viewer", variable=launch_viewer_var)\
        .grid(row=r, column=0, columnspan=2, sticky="w")
    ttk.Checkbutton(frm, text="Generate videos", variable=gen_videos_var)\
        .grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="Video FPS:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=fps_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="frames/sec").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Show advanced settings", variable=show_adv, command=_toggle_adv)\
        .grid(row=r, column=0, columnspan=3, sticky="w")
    r += 1

    # Advanced frame is laid out now but hidden until the checkbox is ticked.
    adv_frame = ttk.LabelFrame(frm, text="Advanced")
    adv_frame.grid(row=r, column=0, columnspan=3, sticky="we", pady=6)
    adv_frame.grid_remove()

    rr = 0
    def add_row(label_txt, var, hint):
        nonlocal rr
        ttk.Label(adv_frame, text=label_txt).grid(row=rr, column=0, sticky="w", padx=8, pady=3)
        ttk.Entry(adv_frame, textvariable=var, width=18).grid(row=rr, column=1, sticky="w", padx=8, pady=3)
        ttk.Label(adv_frame, text=hint).grid(row=rr, column=2, sticky="w", padx=8, pady=3)
        rr += 1

    add_row("MARGIN_UM:", margin_var, "µm (soft band around mask)")
    add_row("OVERLAP_ALPHA:", overlap_var, "0..1 (sphere overlap fraction)")
    add_row("VIZ_MIN_VOXELS:", vizmin_var, "voxels (hide small components)")

    r += 1
    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    btns = ttk.Frame(frm)
    btns.grid(row=r, column=0, columnspan=3, sticky="e")
    ttk.Button(btns, text="Cancel", command=cancel_clicked).grid(row=0, column=0, padx=6)
    ttk.Button(btns, text="Run", command=run_clicked).grid(row=0, column=1, padx=6)

    root.mainloop()

    if not cfg.get("ok"):
        raise SystemExit("Cancelled.")
    return cfg


# ===============================
# Metadata parsing
# ===============================
def _parse_ome_xml(xml_text):
    if not xml_text:
        return None, None, None

    def grab(attr):
        m = re.search(fr'PhysicalSize{attr}="([\d\.eE+-]+)"', xml_text)
        return float(m.group(1)) if m else None

    return grab("X"), grab("Y"), grab("Z")


def _parse_czi_scaling(czi_text):
    """
    Robust CZI scaling parser:
    - Works when values are nested (<Value>...</Value>) or attributes
    - Works with namespaces
    - Converts meters -> µm (typical CZI)
    """
    if not czi_text:
        return None, None, None

    if isinstance(czi_text, (bytes, bytearray)):
        czi_text = czi_text.decode("utf-8", errors="ignore")

    czi_text = czi_text.replace("\x00", "")

    def _to_float(s):
        if s is None:
            return None
        try:
            return float(str(s).strip().replace(",", "."))
        except Exception:
            return None

    def _to_um(val, unit_hint=None):
        if val is None:
            return None
        if unit_hint:
            u = str(unit_hint).strip().lower()
            if u in ("m", "meter", "metre", "meters", "metres"):
                return val * 1e6
            if u in ("µm", "um", "micron", "microns", "micrometer", "micrometre"):
                return val
            if u in ("nm", "nanometer", "nanometre", "nanometers", "nanometres"):
                return val / 1000.0

        # heuristic: CZI values often in meters (~1e-8 to 1e-6)
        if val < 1e-3:
            return val * 1e6
        # sometimes already in µm
        if val < 10:
            return val
        # sometimes in nm
        if val < 1e5:
            return val / 1000.0
        return None

    try:
        root = ET.fromstring(czi_text)
    except Exception:
        def _grab(axis):
            mm = re.search(
                rf'<Distance[^>]*Id="{axis}"[^>]*>.*?<Value>\s*([0-9eE\+\-\.]+)\s*</Value>',
                czi_text,
                flags=re.IGNORECASE | re.DOTALL,
            )
            return _to_float(mm.group(1)) if mm else None

        return _to_um(_grab("X")), _to_um(_grab("Y")), _to_um(_grab("Z"))

    sx = sy = sz = None
    for d in root.findall(".//{*}Distance"):
        axis = d.attrib.get("Id") or d.attrib.get("id") or d.attrib.get("Axis") or d.attrib.get("axis")
        if not axis:
            continue
        axis = axis.upper()
        unit = d.attrib.get("Unit") or d.attrib.get("unit")

        valf = _to_float(d.attrib.get("Value") or d.attrib.get("value"))

        if valf is None:
            v_el = d.find(".//{*}Value")
            if v_el is not None and v_el.text:
                valf = _to_float(v_el.text)

        if valf is None:
            for child in d.iter():
                if child is d:
                    continue
                if str(child.tag).lower().endswith("value"):
                    valf = _to_float(child.attrib.get("Value") or child.attrib.get("value")) or _to_float(child.text)
                    if valf is not None:
                        break

        val_um = _to_um(valf, unit_hint=unit)
        if val_um is None:
            continue

        if axis == "X":
            sx = val_um
        elif axis == "Y":
            sy = val_um
        elif axis == "Z":
            sz = val_um

    return sx, sy, sz


def _split_two_channels(arr, fmt):
    """Squeeze ``arr`` and split its 2-long channel axis into (ch1, ch2).

    After squeezing, the array must be 4-D. The channel axis is the first of axis 0,
    axis 1 or the last axis (checked in that order) whose length is 2.

    Raises:
        RuntimeError: ``"Unexpected <fmt> shape"`` when not 4-D after squeezing, or
        ``"Unexpected <fmt> shape for 2 channels"`` when no candidate axis has length 2.
    """
    img = np.squeeze(arr)
    if img.ndim != 4:
        raise RuntimeError(f"Unexpected {fmt} shape")
    if img.shape[0] == 2:
        return img[0], img[1]
    if img.shape[1] == 2:
        return img[:, 0], img[:, 1]
    if img.shape[-1] == 2:
        return img[..., 0], img[..., 1]
    raise RuntimeError(f"Unexpected {fmt} shape for 2 channels")


def load_any(file_path):
    """Load a two-channel 3-D stack from a TIFF/OME-TIFF or CZI file.

    Returns:
        (ch1, ch2, (vx_um, vy_um, vz_um), info) where the voxel sizes come from the
        file's OME/CZI metadata (each may be None when unavailable) and ``info`` is
        ``{"type": "tiff"}`` or ``{"type": "czi"}``.

    Raises:
        RuntimeError: if the squeezed pixel data is not a 4-D two-channel array.
        ValueError: if the file extension is not .tif, .tiff or .czi.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext in (".tif", ".tiff"):
        with tiff.TiffFile(file_path) as tf:
            arr = tf.asarray()
            try:
                ome_xml = tf.ome_metadata
            except Exception:  # metadata is optional; missing/broken OME is tolerated
                ome_xml = None
        vx_um = vy_um = vz_um = None
        if ome_xml:
            vx_um, vy_um, vz_um = _parse_ome_xml(ome_xml)
        ch1, ch2 = _split_two_channels(arr, "TIFF")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "tiff"}

    if ext == ".czi":
        with czifile.CziFile(file_path) as cf:
            arr = cf.asarray()
            try:
                czi_xml = cf.metadata()
            except Exception:  # metadata is optional; fall back to defaults upstream
                czi_xml = None
        vx_um = vy_um = vz_um = None
        if czi_xml:
            vx_um, vy_um, vz_um = _parse_czi_scaling(czi_xml)
        ch1, ch2 = _split_two_channels(arr, "CZI")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "czi"}

    raise ValueError("Unsupported file format")


# ===============================
# NEW (Option B): direction-aware stitching of broken neurites
# ===============================
def stitch_neurite_fragments_by_orientation(
    neuron_mask,
    vx_um, vy_um, vz_um,
    max_gap_um=1.0,      # max distance allowed to bridge (µm)
    cos_min=0.75,        # direction alignment threshold (0.75 ~ 41°)
    bridge_dilate_radius_vox=1,  # thickness of added bridge in vox
    skip_same_component=True,    # never bridge two endpoints of one connected piece
):
    """Connect broken neurite pieces by stitching aligned skeleton endpoints.

    Endpoints are skeleton voxels with exactly one 26-neighbour. Each endpoint gets an
    outward unit direction (µm space) from its neighbour to itself. Two endpoints are
    paired when all of the following hold:
      (1) their physical distance is at most ``max_gap_um``;
      (2) each endpoint's direction points at the other with cosine >= ``cos_min``;
      (3) with ``skip_same_component`` (default), they lie in different 26-connected
          components of ``neuron_mask``, so a bridge really joins two separate
          fragments instead of closing a loop inside one.
    Pairs are accepted greedily, shortest first, and each endpoint is used at most
    once. Accepted pairs are joined by a straight voxel line, dilated with
    ``ball(bridge_dilate_radius_vox)``, OR-ed into the mask, and the result is closed
    with ``ball(1)``.

    Candidate pairs are found with a KD-tree radius query and filtered vectorially.
    This replaces an O(n^2) Python double loop over endpoints; the tie order
    (distance, then i, then j) is unchanged.

    Returns:
        The bridged mask, or ``neuron_mask`` itself when it is None or empty, or when
        no pair qualifies.
    """
    from scipy.ndimage import label as _ndi_label
    from scipy.spatial import cKDTree

    if neuron_mask is None or neuron_mask.sum() == 0:
        return neuron_mask

    skel = skeletonize_3d(neuron_mask).astype(bool)
    if skel.sum() == 0:
        return neuron_mask

    # neighbor count on skeleton (26-neighborhood)
    kernel = np.ones((3, 3, 3), dtype=np.uint8)
    kernel[1, 1, 1] = 0
    neigh = convolve(skel.astype(np.uint8), kernel, mode="constant", cval=0)

    endpoints = skel & (neigh == 1)
    pts = np.argwhere(endpoints)  # (N,3) zyx
    n = pts.shape[0]
    if n < 2:
        return neuron_mask

    spacing = np.array([vz_um, vy_um, vx_um], dtype=np.float32)

    # endpoint direction: vector from its only neighbor -> endpoint (points "outward")
    dirs = np.zeros((n, 3), dtype=np.float32)
    Z, Y, X = skel.shape

    for i, (z, y, x) in enumerate(pts):
        z1, z2 = max(0, z - 1), min(Z, z + 2)
        y1, y2 = max(0, y - 1), min(Y, y + 2)
        x1, x2 = max(0, x - 1), min(X, x + 2)

        patch = skel[z1:z2, y1:y2, x1:x2]
        nbrs = np.argwhere(patch) + np.array([z1, y1, x1], dtype=np.int32)
        nbrs = nbrs[~((nbrs[:, 0] == z) & (nbrs[:, 1] == y) & (nbrs[:, 2] == x))]
        if nbrs.shape[0] == 0:
            continue

        d2 = np.sum((nbrs - np.array([z, y, x], dtype=np.int32)) ** 2, axis=1)
        nb = nbrs[int(np.argmin(d2))]

        v_um = (np.array([z, y, x], dtype=np.float32) - nb.astype(np.float32)) * spacing
        norm = float(np.linalg.norm(v_um))
        if norm > 0:
            dirs[i] = v_um / norm

    # Endpoints without a usable direction never take part in a bridge.
    has_dir = np.isfinite(dirs).all(axis=1) & (np.linalg.norm(dirs, axis=1) > 0)

    # Candidate pairs within max_gap_um (µm space). The radius is padded slightly so
    # that float32 rounding in the exact test below cannot drop a valid pair.
    pts_um = pts.astype(np.float64) * spacing.astype(np.float64)
    pair_set = cKDTree(pts_um).query_pairs(r=float(max_gap_um) * (1.0 + 1e-6) + 1e-9)
    if not pair_set:
        return neuron_mask
    pairs = np.array(sorted(pair_set), dtype=np.intp)  # (i, j) with i < j, lexicographic
    ii, jj = pairs[:, 0], pairs[:, 1]

    keep = has_dir[ii] & has_dir[jj]
    if skip_same_component:
        comp, _ = _ndi_label(neuron_mask, structure=np.ones((3, 3, 3), dtype=bool))
        comp_id = comp[pts[:, 0], pts[:, 1], pts[:, 2]]
        keep &= comp_id[ii] != comp_id[jj]
    ii, jj = ii[keep], jj[keep]

    d_um_vec = (pts[jj] - pts[ii]).astype(np.float32) * spacing
    dist = np.linalg.norm(d_um_vec, axis=1)
    in_range = (dist > 0) & (dist <= max_gap_um)
    ii, jj, d_um_vec, dist = ii[in_range], jj[in_range], d_um_vec[in_range], dist[in_range]
    if dist.size == 0:
        return neuron_mask

    v = d_um_vec / dist[:, None]  # unit vector i->j (in µm-space)
    # i should point toward j, and j should point toward i
    aligned = (
        (np.sum(dirs[ii] * v, axis=1) >= cos_min)
        & (np.sum(dirs[jj] * -v, axis=1) >= cos_min)
    )
    ii, jj, dist = ii[aligned], jj[aligned], dist[aligned]
    if dist.size == 0:
        return neuron_mask

    # Shortest gap first; ties keep the (i, j) enumeration order.
    order = np.lexsort((jj, ii, dist))

    used = np.zeros(n, dtype=bool)
    bridge = np.zeros_like(neuron_mask, dtype=bool)

    # greedy matching: connect closest valid endpoint pairs first
    for k in order:
        i, j = int(ii[k]), int(jj[k])
        if used[i] or used[j]:
            continue
        rr = line_nd(tuple(pts[i]), tuple(pts[j]), endpoint=True)  # tuple of z,y,x index arrays
        bridge[rr] = True
        used[i] = True
        used[j] = True

    if not bridge.any():
        return neuron_mask

    bridged = neuron_mask | binary_dilation(bridge, ball(int(bridge_dilate_radius_vox)))
    bridged = binary_closing(bridged, ball(1))
    return bridged


# ===============================
# Helper: refine radii by distance transform
# ===============================
def refine_radii_via_dt(img3d, blobs, win_px=40, bin_method="sauvola"):
    """Re-estimate each blob's XY radius (pixels) with a local distance transform.

    For each blob row (z, y, x, r), a window of +/- ``win_px`` pixels around (y, x) in
    slice z is binarised using one of:
      - ``"sauvola"``: Sauvola threshold, k=0.4;
      - ``"local"``: local threshold with offset -0.4 * std;
      - anything else: Otsu, skipping the blob if Otsu raises ValueError.
    The mask is then opened with ``disk(1)`` and stripped of objects smaller than 3
    pixels.

    If the blob centre lies on foreground, the Euclidean distance from the centre to
    the edge of its connected component becomes the new radius (when positive).
    Otherwise the radius is left unchanged.

    Returns:
        A float32 copy of ``blobs``, or ``blobs`` itself when it is None/empty.
    """
    from skimage.filters import threshold_sauvola, threshold_local, threshold_otsu
    from skimage.morphology import remove_small_objects, disk, binary_opening
    from scipy.ndimage import distance_transform_edt as _edt
    from skimage.measure import label as _label

    if blobs is None or len(blobs) == 0:
        return blobs

    depth, height, width = img3d.shape
    refined = blobs.copy().astype(np.float32)

    def _binarize(tile):
        # Returns the foreground mask, or None when Otsu cannot threshold the tile.
        if bin_method == "sauvola":
            window = max(11, 2 * (win_px // 2) + 1)
            return tile > threshold_sauvola(tile, window_size=window, k=0.4)
        if bin_method == "local":
            window = max(11, 2 * (win_px // 2) + 1)
            return tile > threshold_local(tile, block_size=window, offset=-0.4 * np.std(tile))
        try:
            return tile > threshold_otsu(tile)
        except ValueError:
            return None

    for idx in range(len(refined)):
        zc, yc, xc, _ = refined[idx]
        zi, yi, xi = int(round(zc)), int(round(yc)), int(round(xc))
        inside_volume = 0 <= zi < depth and 0 <= yi < height and 0 <= xi < width
        if not inside_volume:
            continue

        top, bottom = max(0, yi - win_px), min(height, yi + win_px + 1)
        left, right = max(0, xi - win_px), min(width, xi + win_px + 1)
        tile = img3d[zi, top:bottom, left:right]

        fg = _binarize(tile)
        if fg is None:
            continue
        fg = binary_opening(fg, footprint=disk(1))
        fg = remove_small_objects(fg, min_size=3, connectivity=2)

        ly, lx = yi - top, xi - left
        centre_in_tile = 0 <= ly < fg.shape[0] and 0 <= lx < fg.shape[1]
        if not centre_in_tile or not fg[ly, lx]:
            continue

        components = _label(fg)
        comp_id = components[ly, lx]
        if comp_id == 0:
            continue

        radius_px = float(_edt(components == comp_id)[ly, lx])
        if radius_px > 0:
            refined[idx, 3] = radius_px

    return refined


def refine_radii_via_radial_intensity(
    img3d,
    blobs,
    vx_um,
    vy_um,
    vz_um,
    max_radius_nm=300.0,
    dr_nm=10.0,
    min_drop_fraction=0.5,
):
    """Refine blob radii from the radially averaged intensity around each centre.

    For each blob row (z, y, x, r_px):
      1. Intensities within ``max_radius_nm`` of the centre (physical distance using
         the voxel sizes) are averaged in spherical shells ``dr_nm`` wide.
      2. The profile is smoothed with a Gaussian (sigma = 1 shell).
      3. Blobs whose profile drops by less than ``min_drop_fraction`` of its maximum
         keep their radius.

    The profile I(r) is one-sided (r >= 0), so the radius is the distance at which it
    first falls below half maximum outward of its peak: the half width at half maximum,
    linearly interpolated between the two bracketing shells. If it never falls below
    half maximum outward of the peak, the outermost populated shell is used.

    The radius is converted to XY pixels (geometric mean of vx_um and vy_um) and
    replaces r_px only when it is larger.

    Returns:
        A float32 copy of ``blobs`` with column 3 updated, or ``blobs`` itself when it
        is None/empty.
    """
    if blobs is None or len(blobs) == 0:
        return blobs

    img = img3d.astype(np.float32)
    Z, Y, X = img.shape
    blobs_out = blobs.copy().astype(np.float32)

    max_r_um = max_radius_nm / 1000.0
    dr_um = dr_nm / 1000.0

    r_edges = np.arange(0.0, max_r_um + dr_um, dr_um)
    if r_edges.size < 2:
        return blobs_out
    r_centers = 0.5 * (r_edges[:-1] + r_edges[1:])

    px_um_xy = float(np.sqrt(vx_um * vy_um))

    # Window half-sizes (voxels) are the same for every blob.
    rz = max(1, int(np.ceil(max_r_um / vz_um)))
    ry = max(1, int(np.ceil(max_r_um / vy_um)))
    rx = max(1, int(np.ceil(max_r_um / vx_um)))

    for i, (zc, yc, xc, r_px_init) in enumerate(blobs_out):
        z0 = int(round(zc))
        y0 = int(round(yc))
        x0 = int(round(xc))

        if not (0 <= z0 < Z and 0 <= y0 < Y and 0 <= x0 < X):
            continue

        z1, z2 = max(0, z0 - rz), min(Z, z0 + rz + 1)
        y1, y2 = max(0, y0 - ry), min(Y, y0 + ry + 1)
        x1, x2 = max(0, x0 - rx), min(X, x0 + rx + 1)
        if z1 >= z2 or y1 >= y2 or x1 >= x2:
            continue

        patch = img[z1:z2, y1:y2, x1:x2]

        zz, yy, xx = np.mgrid[z1:z2, y1:y2, x1:x2]
        dz_um = (zz - z0) * vz_um
        dy_um = (yy - y0) * vy_um
        dx_um = (xx - x0) * vx_um
        r_um = np.sqrt(dz_um**2 + dy_um**2 + dx_um**2)

        mask = (r_um <= max_r_um)
        if not np.any(mask):
            continue

        r_vals = r_um[mask].ravel()
        I_vals = patch[mask].ravel()

        # Shell index of every voxel; out-of-range radii are discarded.
        bin_idx = np.digitize(r_vals, r_edges) - 1
        valid = (bin_idx >= 0) & (bin_idx < r_centers.size)
        if not np.any(valid):
            continue

        bin_idx = bin_idx[valid]
        I_vals = I_vals[valid]

        sums = np.bincount(bin_idx, weights=I_vals, minlength=r_centers.size)
        counts = np.bincount(bin_idx, minlength=r_centers.size)
        with np.errstate(invalid="ignore", divide="ignore"):
            prof = sums / np.maximum(counts, 1)

        have = counts > 0
        if not np.any(have):
            continue

        r_prof = r_centers[have]
        I_prof = prof[have].astype(np.float32)

        I_smooth = gaussian_filter1d(I_prof, sigma=1.0)
        I_max = float(I_smooth.max())
        if I_max <= 0:
            continue
        I_min = float(I_smooth.min())
        if I_max <= I_min:
            continue  # flat profile: no edge to measure

        if (I_max - I_min) / max(I_max, 1e-9) < min_drop_fraction:
            continue

        I_half = I_min + 0.5 * (I_max - I_min)
        peak_idx = int(np.argmax(I_smooth))

        # First shell outward of the peak whose intensity is below half maximum.
        # I_smooth[peak_idx] == I_max > I_half, so any hit is strictly past the peak.
        below = np.flatnonzero(I_smooth[peak_idx:] < I_half)
        if below.size == 0:
            radius_um = float(r_prof[-1])
        else:
            k = peak_idx + int(below[0])
            I0, I1 = float(I_smooth[k - 1]), float(I_smooth[k])  # I0 >= I_half > I1
            frac = (I0 - I_half) / (I0 - I1) if I0 != I1 else 0.0
            r0, r1 = float(r_prof[k - 1]), float(r_prof[k])
            radius_um = r0 + frac * (r1 - r0)

        if radius_um <= 0:
            continue

        r_hwhm_px = radius_um / max(px_um_xy, 1e-9)
        blobs_out[i, 3] = max(float(r_px_init), float(r_hwhm_px))

    return blobs_out


# ===============================
# Distance/overlap-aware classification helpers
# ===============================
from scipy.ndimage import distance_transform_edt as _edt


def nearest_cell_label(cell_seg, z, y, x, max_r=12):
    """Return the dominant nonzero label in the smallest cube around (z, y, x).

    Cubes of half-width 1, 2, ..., ``max_r`` voxels (clipped to the volume)
    are examined in order. The first cube that contains any nonzero voxel
    decides the result: its most frequent label, with ties going to the
    smallest label (``np.bincount`` + ``argmax``). Returns 0 when no labelled
    voxel lies within ``max_r`` voxels.
    """
    depth, height, width = cell_seg.shape
    for radius in range(1, max_r + 1):
        window = cell_seg[
            max(0, z - radius):min(depth, z + radius + 1),
            max(0, y - radius):min(height, y + radius + 1),
            max(0, x - radius):min(width, x + radius + 1),
        ]
        labelled = window[window > 0]
        if labelled.size > 0:
            return int(np.argmax(np.bincount(labelled)))
    return 0


def sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, mask_bool, vx_um, vy_um, vz_um):
    """Fraction of a sphere's voxels that are True in ``mask_bool``.

    The sphere is centred at (zc_um, yc_um, xc_um) µm with radius ``r_um`` µm.
    The centre is snapped to the nearest voxel, and voxel offsets are scaled
    by the per-axis sizes (vz_um, vy_um, vx_um). Only the part of the sphere
    lying inside the volume is counted, in both numerator and denominator.

    Returns 0.0 for a non-positive radius, or when the sphere's bounding box
    (or the sphere itself) misses the volume.
    """
    if r_um <= 0:
        return 0.0

    spacing = (vz_um, vy_um, vx_um)
    centre = [int(round(c / s)) for c, s in zip((zc_um, yc_um, xc_um), spacing)]
    half_width = [max(1, int(np.ceil(r_um / s))) for s in spacing]

    lo = [max(0, c - h) for c, h in zip(centre, half_width)]
    hi = [min(n, c + h + 1) for c, h, n in zip(centre, half_width, mask_bool.shape)]
    if any(a >= b for a, b in zip(lo, hi)):
        return 0.0

    window = tuple(slice(a, b) for a, b in zip(lo, hi))

    # Squared physical distance of every window voxel from the centre,
    # accumulated axis by axis (z, then y, then x) via broadcasting.
    dist_sq = 0.0
    for axis_idx, c, s in zip(np.ogrid[window], centre, spacing):
        offset = (axis_idx - c) * s
        dist_sq = dist_sq + offset * offset
    sphere = dist_sq <= (r_um * r_um)

    n_sphere = sphere.sum()
    if n_sphere == 0:
        return 0.0
    return float((mask_bool[window] & sphere).sum()) / float(n_sphere)


# ===============================
# Full-size overlay exporter (RGB stack + MP4)
# ===============================
def _norm_u8_stack(vol):
    """Min-max rescale ``vol`` linearly onto uint8 [0, 255].

    A constant (or NaN-containing, since the comparison then fails) input
    yields an all-zero uint8 array of the same shape. Fractional values are
    truncated by the uint8 cast.
    """
    lo = float(vol.min())
    hi = float(vol.max())
    if not hi > lo:
        return np.zeros_like(vol, dtype=np.uint8)
    scaled = np.clip((vol - lo) / (hi - lo), 0, 1) * 255.0
    return scaled.astype(np.uint8)


def make_label_colormap(n_labels, seed_hue=0.13):
    """Build an (n_labels + 1, 3) uint8 RGB lookup table for label images.

    Row 0 (background) stays black. Rows 1..n_labels get fully saturated,
    full-brightness colours whose hues are evenly spaced around the HSV wheel,
    starting at ``seed_hue``.
    """
    palette = np.zeros((n_labels + 1, 3), dtype=np.uint8)
    for label_id in range(1, n_labels + 1):
        hue = (seed_hue + (label_id - 1) / n_labels) % 1.0
        rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
        palette[label_id] = tuple(int(255 * channel) for channel in rgb)
    return palette


def export_fullsize_overlay_stack(
    img_ch1,
    img_ch2_raw,
    cell_seg_viz,
    df,
    vx_um, vy_um, vz_um,
    output_dir,
    alpha_labels=0.45,
    draw_only_inside=True,
    fps=8,
    basename="FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
):
    """Render a full-resolution RGB overlay per Z slice and save TIFF + video.

    Each slice shows:
    - a raw composite background (channels Ch1, Ch2, Ch1; this order is
      symmetric, so RGB/BGR interpretation does not matter);
    - label colours from ``cell_seg_viz``, alpha-blended on top;
    - every lysosome sphere intersecting the slice, drawn as a magenta
      circle with a black outline at its cross-section radius (min 3 px).

    Parameters
    ----------
    img_ch1, img_ch2_raw : ndarray
        Raw channel volumes, indexed (z, y, x). Each is min-max normalised to
        uint8 independently.
    cell_seg_viz : ndarray
        Integer label volume (0 = background), same shape as the channels.
    df : pandas.DataFrame
        Lysosome table with ``z_um``, ``y_um``, ``x_um`` and ``radius_um``
        (µm). When ``draw_only_inside`` is true and ``location_ch2`` exists,
        only rows with ``location_ch2 == "cell"`` are drawn. Rows with
        non-finite values are skipped.
    vx_um, vy_um, vz_um : float
        Voxel size in µm, used to map µm coordinates back to voxel indices.
    output_dir : str
        Created if missing; the outputs are written here.
    alpha_labels : float
        Opacity of the label colours over the background.
    fps : int
        Frame rate of the MP4 (or fallback GIF).
    basename : str
        File-name stem for the outputs.

    Returns
    -------
    tuple of str
        ``(tiff_path, video_path)``. ``video_path`` is the MP4 when FFMPEG
        succeeded, otherwise the GIF written as a fallback.
    """

    def _write_gif(path, stack, rate):
        # Older imageio accepts ``fps``; newer releases (Pillow v3 GIF plugin)
        # raise TypeError for ``fps`` and expect ``duration`` in milliseconds.
        rate = max(int(rate), 1)
        try:
            imageio.mimsave(path, list(stack), fps=rate)
        except TypeError:
            imageio.mimsave(path, list(stack), duration=1000.0 / rate)

    os.makedirs(output_dir, exist_ok=True)

    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(img_ch2_raw.astype(np.float32))

    Z, H, W = ch2_u8.shape
    n_labels = int(cell_seg_viz.max()) if isinstance(cell_seg_viz, np.ndarray) else 0
    cmap = make_label_colormap(n_labels, seed_hue=0.13)

    # Lysosome geometry in voxel units; computed once (it does not depend on z).
    zc = np.empty(0, dtype=float)
    ys = np.empty(0, dtype=int)
    xs = np.empty(0, dtype=int)
    r_um = np.empty(0, dtype=float)
    required = {"z_um", "y_um", "x_um", "radius_um"}
    if isinstance(df, pd.DataFrame) and len(df) > 0 and required.issubset(df.columns):
        use_df = df
        if draw_only_inside and "location_ch2" in df.columns:
            use_df = use_df[use_df["location_ch2"] == "cell"]
        use_df = use_df[
            np.isfinite(use_df["z_um"]) &
            np.isfinite(use_df["y_um"]) &
            np.isfinite(use_df["x_um"]) &
            np.isfinite(use_df["radius_um"])
        ]
        zc = use_df["z_um"].to_numpy().astype(float) / vz_um
        ys = np.rint(use_df["y_um"].to_numpy().astype(float) / vy_um).astype(int)
        xs = np.rint(use_df["x_um"].to_numpy().astype(float) / vx_um).astype(int)
        r_um = use_df["radius_um"].to_numpy().astype(float)

    # Geometric-mean XY pixel size, guarded against division by zero.
    px_um_xy = max(float(np.sqrt(vx_um * vy_um)), 1e-12)
    frames = np.zeros((Z, H, W, 3), dtype=np.uint8)

    for z in range(Z):
        # Background (same style as the RAW composite): B=Ch1, G=Ch2, R=Ch1
        base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]]).astype(np.float32)

        # Alpha-blend the label colours wherever a label is present.
        lab2d = cell_seg_viz[z].astype(np.int32)
        lab_rgb = cmap[lab2d].astype(np.float32)
        label_mask = (lab2d > 0)[..., None].astype(np.float32)
        blended = base * (1.0 - alpha_labels * label_mask) + lab_rgb * (alpha_labels * label_mask)
        frame = np.clip(blended, 0, 255).astype(np.uint8)

        # Lysosomes (MAGENTA): spheres whose Z extent reaches this slice.
        if zc.size:
            dz_um = np.abs(zc - z) * vz_um
            hits = dz_um <= r_um
            if np.any(hits):
                # Cross-section radius of each sphere at this slice, in XY pixels.
                r_proj_px = np.sqrt(np.clip(r_um[hits]**2 - dz_um[hits]**2, 0.0, None)) / px_um_xy
                for y, x, rp in zip(ys[hits], xs[hits], r_proj_px):
                    if 0 <= y < H and 0 <= x < W:
                        center = (int(x), int(y))
                        rr = int(max(3, round(rp)))  # keep tiny sections visible
                        cv2.circle(frame, center, rr, LYS_EDGE_BGR, 4, lineType=cv2.LINE_AA)
                        cv2.circle(frame, center, rr, LYS_MAGENTA_BGR, 2, lineType=cv2.LINE_AA)
                        cv2.circle(frame, center, 1, LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)

        frames[z] = frame

    # Save RGB TIFF stack
    tiff_path = os.path.join(output_dir, f"{basename}.tif")
    tiff.imwrite(tiff_path, frames, photometric="rgb")
    print("Saved full-size RGB TIFF stack:", tiff_path)

    # Save MP4; on any FFMPEG failure fall back to a GIF and report that path.
    video_path = os.path.join(output_dir, f"{basename}.mp4")
    try:
        with imageio.get_writer(
            video_path,
            fps=int(fps),
            format="FFMPEG",
            codec="libx264",
            macro_block_size=None
        ) as w:
            for fr in frames:
                w.append_data(fr)
        print("Saved full-size MP4:", video_path)
    except Exception as e:
        video_path = os.path.join(output_dir, f"{basename}.gif")
        _write_gif(video_path, frames, fps)
        print("FFMPEG failed, saved GIF instead:", video_path, "Error:", e)

    return tiff_path, video_path


# ===============================
# MAIN
# ===============================
# Fallback voxel sizes offered to the user when the file lacks metadata.
DEFAULT_VX_VY_UM = 0.04
DEFAULT_VZ_UM = None  # if None and Z missing -> fallback to XY

# Collect every run parameter from the Tk configuration dialog.
cfg = get_user_config_gui(
    default_vxy_um=DEFAULT_VX_VY_UM,
    default_vz_um=DEFAULT_VZ_UM,
    default_erode_mult=1.0,
    default_blob_threshold=0.001,
)

file_path = cfg["file_path"]
output_dir = cfg["output_dir"]
os.makedirs(output_dir, exist_ok=True)
# Working directory becomes output_dir, so every relative filename written
# below (CSV tables, videos) lands in the output folder.
os.chdir(output_dir)

print("Selected file:", file_path)
print("Outputs will be saved to:", output_dir)

# NOTE(review): ERODE_MULT is not referenced in the code visible here — confirm it is still used.
ERODE_MULT = cfg["ERODE_MULT"]
BLOB_THRESHOLD = cfg["BLOB_THRESHOLD"]

# Classification / visualization settings (units per the GUI: µm, 0..1, voxels).
MAX_REASONABLE_VXY_UM = cfg["MAX_REASONABLE_VXY_UM"]
MARGIN_UM = cfg["MARGIN_UM"]
OVERLAP_ALPHA = cfg["OVERLAP_ALPHA"]
NEIGHBOR_MAX_VOX = cfg["NEIGHBOR_MAX_VOX"]
VIZ_MIN_VOXELS = cfg["VIZ_MIN_VOXELS"]

# Ch1 (lysosome) LoG blob-detection settings.
CH1_SMOOTH_SIGMA = cfg["CH1_SMOOTH_SIGMA"]
BLOB_MIN_SIGMA = cfg["BLOB_MIN_SIGMA"]
BLOB_MAX_SIGMA = cfg["BLOB_MAX_SIGMA"]
BLOB_NUM_SIGMA = cfg["BLOB_NUM_SIGMA"]

# Radial-profile radius refinement settings (nm).
RADIAL_MAX_RADIUS_NM = cfg["RADIAL_MAX_RADIUS_NM"]
RADIAL_DR_NM = cfg["RADIAL_DR_NM"]
RADIAL_MIN_DROP_FRACTION = cfg["RADIAL_MIN_DROP_FRACTION"]

# Ch2 (neurite) local-threshold segmentation settings.
CH2_SMOOTH_SIGMA = cfg["CH2_SMOOTH_SIGMA"]
THRESH_BLOCK_SIZE = cfg["THRESH_BLOCK_SIZE"]
THRESH_OFFSET_STD_MULT = cfg["THRESH_OFFSET_STD_MULT"]

# Output toggles.
FPS = cfg["VIDEO_FPS"]
LAUNCH_VIEWER = cfg["LAUNCH_VIEWER"]
GENERATE_VIDEOS = cfg["GENERATE_VIDEOS"]

# Load data + metadata
# Any of vx_um / vy_um / vz_um may be None when the file carries no size
# metadata; the prompts below fill them in.
img_ch1, img_ch2, (vx_um, vy_um, vz_um), meta = load_any(file_path)
print(f"[metadata] vx_um={vx_um}  vy_um={vy_um}  vz_um={vz_um}")

# Prompt only when metadata is missing
def _prompt_missing_size(title, prompt, default_value):
    """Ask the user for a voxel size that the file metadata did not provide.

    A yes/no dialog first offers ``default_value``. On "No", a float entry
    dialog (pre-filled with the default, minimum 1e-12) is shown.

    Parameters
    ----------
    title : str
        Dialog window title.
    prompt : str
        Explanation shown above the "Use default value" question.
    default_value : float
        Value returned when the user accepts the default.

    Returns
    -------
    float
        The accepted default or the value the user typed.

    Raises
    ------
    SystemExit
        If the user cancels the entry dialog.
    """
    # Local import: the module-level tkinter import may not include
    # simpledialog, and that would only fail after the user clicks "No".
    from tkinter import simpledialog

    root = tk.Tk()
    root.withdraw()
    try:
        try:
            root.attributes("-topmost", True)
        except tk.TclError:
            pass  # cosmetic only; not every window manager supports it

        if messagebox.askyesno(
            title,
            f"{prompt}\n\nUse default value: {default_value} ?"
        ):
            return float(default_value)

        val = simpledialog.askfloat(
            title,
            "Enter a new value:",
            initialvalue=float(default_value),
            minvalue=1e-12
        )
    finally:
        # Always release the hidden Tk root, even if a dialog raises.
        root.destroy()

    if val is None:
        raise SystemExit("Cancelled.")
    return float(val)

# XY: prompt only when the file carried no XY pixel-size metadata.
if vx_um is None or vy_um is None:
    vx_um = vy_um = _prompt_missing_size(
        title="Missing metadata",
        prompt="XY pixel size metadata is missing (µm/px).",
        default_value=float(cfg["DEFAULT_VX_VY_UM"]),
    )

# Z: default to the configured Z step, or to the XY size (isotropic guess).
if vz_um is None:
    z_default = float(cfg["DEFAULT_VZ_UM"]) if (cfg["DEFAULT_VZ_UM"] is not None) else float(vx_um)
    vz_um = _prompt_missing_size(
        title="Missing metadata",
        prompt="Z step metadata is missing (µm/slice).",
        default_value=z_default,
    )

# Sanity checks: every later µm <-> voxel conversion divides by these sizes,
# and the XY plausibility limit applies to both in-plane axes.
for _axis_name, _axis_size in (("X", vx_um), ("Y", vy_um), ("Z", vz_um)):
    if not (np.isfinite(_axis_size) and _axis_size > 0):
        raise ValueError(f"Invalid {_axis_name} voxel size: {_axis_size} µm")
if vx_um > MAX_REASONABLE_VXY_UM:
    raise ValueError(f"XY pixel size too large: {vx_um} µm/px")
if vy_um > MAX_REASONABLE_VXY_UM:
    raise ValueError(f"XY pixel size too large: {vy_um} µm/px")

px_um_xy = float(np.sqrt(vx_um * vy_um))
# NOTE(review): 0.55 is an unexplained empirical factor. In this cell px_um
# only converts blob radii (px) to the reported radius_um — confirm intent.
px_um = px_um_xy * 0.55
voxel_um3 = vx_um * vy_um * vz_um
print(f"Voxel size (µm): X={vx_um}, Y={vy_um}, Z={vz_um}")

# ===== Aliases =====
image = img_ch1
image_2 = img_ch2

# ==========================================
# Lysosome detection (Ch1)
# ==========================================
# Smooth Ch1 before LoG detection. skimage's gaussian() converts integer input
# to float (preserve_range defaults to False), so BLOB_THRESHOLD applies to
# that rescaled intensity scale.
image_smooth = gaussian(image, sigma=CH1_SMOOTH_SIGMA)

# 3D Laplacian-of-Gaussian blob detection; each row is (z, y, x, sigma) in voxels.
blobs = blob_log(
    image_smooth,
    min_sigma=BLOB_MIN_SIGMA,
    max_sigma=BLOB_MAX_SIGMA,
    num_sigma=BLOB_NUM_SIGMA,
    threshold=BLOB_THRESHOLD
)

if len(blobs) > 0:
    # For a 3D LoG, the blob radius is approximately sqrt(3) * sigma.
    blobs[:, 3] *= np.sqrt(3)  # LoG radius correction (pixels)
    # Further radius refinements (helpers defined elsewhere in the file);
    # presumably they keep column 3 in pixels — confirm.
    blobs = refine_radii_via_dt(image_smooth, blobs)
    blobs = refine_radii_via_radial_intensity(
        image_smooth,
        blobs,
        vx_um, vy_um, vz_um,
        max_radius_nm=RADIAL_MAX_RADIUS_NM,
        dr_nm=RADIAL_DR_NM,
        min_drop_fraction=RADIAL_MIN_DROP_FRACTION
    )

# Peak intensity in raw 16-bit Ch1
# Max of the raw (unsmoothed) Ch1 in a 3x3x3 window (clipped at the borders)
# around each rounded blob centre. NOTE(review): the uint16 cast assumes Ch1 is
# integer data <= 16 bit; float or wider data would be truncated/wrapped.
peak_gray = np.zeros(len(blobs), dtype=np.uint16)
Z0, Y0, X0 = image.shape
rad = 1

for i, (zc, yc, xc, _) in enumerate(blobs):
    zc_i = int(round(zc))
    yc_i = int(round(yc))
    xc_i = int(round(xc))

    z1, z2 = max(0, zc_i - rad), min(Z0, zc_i + rad + 1)
    y1, y2 = max(0, yc_i - rad), min(Y0, yc_i + rad + 1)
    x1, x2 = max(0, xc_i - rad), min(X0, xc_i + rad + 1)

    peak_gray[i] = np.max(image[z1:z2, y1:y2, x1:x2]).astype(np.uint16)

# Convert voxel coordinates to µm. The radius uses px_um (the scaled XY pixel
# size), and diameter/volume follow from a spherical model.
if len(blobs) > 0:
    z_um = blobs[:, 0] * vz_um
    y_um = blobs[:, 1] * vy_um
    x_um = blobs[:, 2] * vx_um
    radius_um = blobs[:, 3] * px_um
    diameter_um = 2 * radius_um
    volume_um3 = (4/3) * np.pi * radius_um**3
    blob_ids = np.arange(1, len(blobs) + 1, dtype=int)
else:
    z_um = y_um = x_um = radius_um = diameter_um = volume_um3 = np.array([])
    blob_ids = np.array([], dtype=int)

# One row per detected lysosome; written to the output folder (cwd).
df = pd.DataFrame({
    "id": blob_ids,
    "z_um": z_um,
    "y_um": y_um,
    "x_um": x_um,
    "radius_um": radius_um,
    "diameter_um": diameter_um,
    "volume_um3": volume_um3,
    "peak_gray": peak_gray,
})
df.to_csv("lysosome_blobs_regions.csv", index=False)
print("Saved: lysosome_blobs_regions.csv")

# ==========================================
# CH2 segmentation (neurites mask)
# ==========================================
# Min-max normalise Ch2 to [0, 1] (all zeros for a constant image).
vol = image_2.astype(np.float32)
vmin, vmax = float(vol.min()), float(vol.max())
if vmax > vmin:
    vol = (vol - vmin) / (vmax - vmin)
else:
    vol[:] = 0.0

ch2 = gaussian(vol, sigma=CH2_SMOOTH_SIGMA, preserve_range=True)

# Per-slice adaptive threshold. threshold_local subtracts `offset` from the
# local mean, so the negative offset gives: voxel > local_mean + k * std(slice).
neuron_mask = np.zeros_like(ch2, dtype=bool)
for z in range(ch2.shape[0]):
    R = ch2[z]
    t = threshold_local(R, block_size=THRESH_BLOCK_SIZE, offset=-THRESH_OFFSET_STD_MULT * np.std(R))
    neuron_mask[z] = R > t

if NEURITE_MODE:
    # neurite-safe cleanup: connect gaps, remove small speckles; avoid erosion/opening
    neuron_mask = binary_closing(neuron_mask, ball(2))#1
    neuron_mask = remove_small_objects(neuron_mask, min_size=30, connectivity=3)

    # NEW (Option B): stitch broken fragments into continuous branches (direction-aware)
    # (helper defined elsewhere in the file; parameters below are its tuning knobs)
    neuron_mask = stitch_neurite_fragments_by_orientation(
        neuron_mask,
        vx_um=vx_um, vy_um=vy_um, vz_um=vz_um,
        max_gap_um=2.5,#1
        cos_min=0.6,#0.75
        bridge_dilate_radius_vox=2,#1
    )

    # now remove leftover speckles#NEW
    neuron_mask = remove_small_objects(neuron_mask, min_size=200, connectivity=3)

    # NOTE: do NOT fill holes globally for neurites (can fill loops)
else:
    # fallback to soma-style refinement if you ever switch back
    neuron_mask = binary_fill_holes(neuron_mask)

print("neurite voxels:", int(neuron_mask.sum()))

# ---- ID segmentation ----
if NEURITE_MODE:
    # label connected neurite networks (often 1)
    # 26-connectivity in 3D (connectivity=3).
    cell_seg = label(neuron_mask, connectivity=3).astype(np.int32)
    cell_mask = neuron_mask.copy()
    print("Detected components (neurite networks):", int(cell_seg.max()))
else:
    # Original soma logic (kept for compatibility)
    # Soma cores = mask voxels at least cell_min_radius_vox from the background,
    # cleaned morphologically; they seed a watershed over the full mask.
    dist = edt(neuron_mask)
    cell_min_radius_vox = 1
    cell_mask = (dist >= cell_min_radius_vox)
    cell_mask &= neuron_mask
    cell_mask = binary_fill_holes(binary_closing(binary_opening(cell_mask, ball(1)), ball(2)))

    body_lab = label(cell_mask, connectivity=3)
    n_cells = int(body_lab.max())
    print(f"Detected {n_cells} cells (soma).")

    if n_cells > 0:
        dist_inside = edt(neuron_mask)
        cell_seg = watershed(-dist_inside, markers=body_lab, mask=neuron_mask)
    else:
        cell_seg = np.zeros_like(neuron_mask, dtype=np.int32)

print("ID voxels:", int((cell_seg > 0).sum()))

# ==========================================
# Visualization-only filtering (hide tiny components) + serial IDs
# ==========================================
# Display-only copy of the labels: components smaller than VIZ_MIN_VOXELS are
# hidden, and the survivors are renumbered 1..K. Analysis keeps using cell_seg.
# cell_id_map_viz maps original label -> display label.
cell_seg_viz = cell_seg.copy()
cell_id_map_viz = {}

if isinstance(cell_seg_viz, np.ndarray) and cell_seg_viz.max() > 0:
    counts_viz = np.bincount(cell_seg_viz.ravel().astype(np.int64))
    tiny_labels = np.where(counts_viz < VIZ_MIN_VOXELS)[0]
    tiny_labels = tiny_labels[tiny_labels > 0]
    if tiny_labels.size > 0:
        cell_seg_viz[np.isin(cell_seg_viz, tiny_labels)] = 0

    unique_labels = np.unique(cell_seg_viz)
    unique_labels = unique_labels[unique_labels > 0]

    if unique_labels.size > 0:
        # One vectorised lookup instead of a full-volume comparison per label
        # (which costs O(n_labels * n_voxels)). Every nonzero voxel's label is
        # in unique_labels, so the table covers all values present.
        lut = np.zeros(int(unique_labels.max()) + 1, dtype=np.int32)
        lut[unique_labels] = np.arange(1, unique_labels.size + 1, dtype=np.int32)
        cell_seg_viz = lut[cell_seg_viz]
        cell_id_map_viz = {old_id: new_id for new_id, old_id in enumerate(unique_labels, start=1)}

cell_mask_viz = (cell_seg_viz > 0)

# ==========================================
# Classify lysosomes: inside vs outside neurite mask, and assign ID
# ==========================================
# Physical distance (µm, anisotropic sampling) from every voxel to the nearest
# neurite-mask voxel (mask voxels are 0 in ~neuron_mask). The "soft" mask grows
# the neurite mask by MARGIN_UM.
dist_out_um = _edt(~neuron_mask, sampling=(vz_um, vy_um, vx_um)).astype(np.float32)
soft_cell_mask = neuron_mask | (dist_out_um <= MARGIN_UM)

location_ch2 = []
cell_id_list = []

if len(df) > 0:
    Z, Y, X = neuron_mask.shape
    for (zc_um, yc_um, xc_um, r_um) in df[["z_um", "y_um", "x_um", "radius_um"]].to_numpy():
        # Nearest voxel index of the lysosome centre.
        zz = int(round(zc_um / vz_um))
        yy = int(round(yc_um / vy_um))
        xx = int(round(xc_um / vx_um))

        # Inside if the centre voxel is in the hard mask or the MARGIN_UM soft
        # mask, or if >= OVERLAP_ALPHA of the sphere's voxels are in the hard mask.
        inside_hard = (0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X and neuron_mask[zz, yy, xx])
        inside_soft = (0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X and soft_cell_mask[zz, yy, xx])

        is_inside = bool(inside_hard)
        if not is_inside and inside_soft:
            is_inside = True

        if not is_inside:
            frac = sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, neuron_mask, vx_um, vy_um, vz_um)
            if frac >= OVERLAP_ALPHA:
                is_inside = True

        if is_inside:
            # Component ID: the label at the centre voxel. Otherwise, the dominant
            # label in the smallest cube (half-width <= NEIGHBOR_MAX_VOX) that
            # contains one. 0 if none, or if the centre is outside the volume.
            cid = 0
            if 0 <= zz < Z and 0 <= yy < Y and 0 <= xx < X:
                if cell_seg[zz, yy, xx] != 0:
                    cid = int(cell_seg[zz, yy, xx])
                else:
                    cid = nearest_cell_label(cell_seg, zz, yy, xx, max_r=NEIGHBOR_MAX_VOX)
            location_ch2.append("cell")
            cell_id_list.append(cid)
        else:
            location_ch2.append("outside")
            cell_id_list.append(0)

    df["location_ch2"] = location_ch2
    df["cell_id_ch2"] = cell_id_list
    # Display ID; 0 when the component was hidden as tiny in the viz labels.
    df["cell_id_ch2_viz"] = df["cell_id_ch2"].map(cell_id_map_viz).fillna(0).astype(int) if isinstance(cell_id_map_viz, dict) else 0

    # Serial lysosome IDs within each ID component
    # (1..n per component, ordered by z, then y, then x; 0 for outside/unassigned).
    df["lys_id_in_cell"] = 0
    mask_in = (df["location_ch2"] == "cell") & (df["cell_id_ch2"] > 0)
    df_sorted = df.loc[mask_in].sort_values(["cell_id_ch2", "z_um", "y_um", "x_um"]).copy()
    df.loc[df_sorted.index, "lys_id_in_cell"] = (df_sorted.groupby("cell_id_ch2").cumcount().to_numpy() + 1).astype(int)

    df.to_csv("lysosomes_with_cell_vs_outside.csv", index=False)
    print("Saved: lysosomes_with_cell_vs_outside.csv")

    # Summary counts: per location, and per component (inside only; includes cell_id 0).
    df.groupby("location_ch2").size().reset_index(name="count").to_csv("lysosome_counts_cell_vs_outside.csv", index=False)
    (df[df["location_ch2"] == "cell"]
        .groupby("cell_id_ch2").size()
        .reset_index(name="count")
        .to_csv("lysosome_counts_by_cell.csv", index=False))
    print("Saved: lysosome_counts_cell_vs_outside.csv, lysosome_counts_by_cell.csv")

# Per-ID volumes (µm^3)
# Per-component voxel counts and volumes (voxel count * voxel_um3), from the
# unfiltered analysis labels (cell_seg, not the viz copy). Every label number
# 1..max gets a row, including any with zero voxels.
cell_volume_df = pd.DataFrame(columns=["cell_id_ch2", "voxel_count", "volume_um3"])
if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
    counts = np.bincount(cell_seg.ravel().astype(np.int64))
    cell_ids = np.arange(1, counts.size, dtype=int)
    voxels = counts[1:].astype(np.int64)
    vol_um3 = voxels.astype(float) * voxel_um3

    cell_volume_df = pd.DataFrame({
        "cell_id_ch2": cell_ids,
        "voxel_count": voxels,
        "volume_um3": vol_um3
    })
    cell_volume_df.to_csv("cell_volumes_ch2.csv", index=False)
    print("Saved: cell_volumes_ch2.csv")

    # Join the inside-lysosome counts per component (0 where none) and save.
    # Failures are reported, not raised, so the rest of the script still runs.
    try:
        if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
            lys_counts = (
                df[df["location_ch2"] == "cell"]
                .groupby("cell_id_ch2").size()
                .reset_index(name="lysosome_count")
            )
        else:
            lys_counts = pd.DataFrame(columns=["cell_id_ch2", "lysosome_count"])

        merged = cell_volume_df.merge(lys_counts, on="cell_id_ch2", how="left").fillna({"lysosome_count": 0})
        merged.to_csv("cell_metrics_ch2.csv", index=False)
        print("Saved: cell_metrics_ch2.csv")
    except Exception as e:
        print("Merge with lysosome counts failed:", e)

# ==========================================
# Export full-size overlay (RAW + colored IDs + MAGENTA lysosomes)
# ==========================================
if EXPORT_FULLSIZE_OVERLAY:
    # Data arguments positionally (Ch1, Ch2, viz labels, lysosome table,
    # voxel sizes, output folder); rendering options by keyword.
    export_fullsize_overlay_stack(
        img_ch1, img_ch2, cell_seg_viz, df,
        vx_um, vy_um, vz_um,
        output_dir,
        alpha_labels=0.45,
        draw_only_inside=True,   # True = only inside neurites; set False for all
        fps=FPS,
        basename="FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
    )

# ==========================================
# Videos (MAGENTA lysosomes)
# ==========================================
if GENERATE_VIDEOS:
    # Per-slice videos (MP4, GIF fallback) of the Ch2 mask overlay with the
    # detected lysosomes drawn as magenta circles. Frames are built with
    # OpenCV (BGR), but every colour used (grey, green tint, magenta, black,
    # white) is R/B-symmetric, so writing them as RGB is harmless.

    def _write_gif_compat(path, frames, fps):
        """Save frames as a GIF across imageio versions.

        Older imageio accepts ``fps``; newer releases (Pillow v3 GIF plugin)
        raise TypeError for it and expect ``duration`` in milliseconds.
        """
        fps = max(int(fps), 1)
        try:
            imageio.mimsave(path, frames, fps=fps)
        except TypeError:
            imageio.mimsave(path, frames, duration=1000.0 / fps)

    def _save_video_sync(basename, frames):
        """Write <basename>.mp4 (libx264, no resizing); on any failure write a GIF."""
        try:
            with imageio.get_writer(
                f"{basename}.mp4",
                fps=int(FPS),
                format="FFMPEG",
                codec="libx264",
                macro_block_size=None
            ) as w:
                for fr in frames:
                    w.append_data(fr)
            print(f"Saved: {basename}.mp4 @ {int(FPS)} fps")
        except Exception as e:
            _write_gif_compat(f"{basename}.gif", frames, FPS)
            print(f"Saved: {basename}.gif @ {int(FPS)} fps equivalent (MP4 failed: {e})")

    px_um_xy = float(np.sqrt(vx_um * vy_um))
    px_um_safe = max(px_um_xy, 1e-12)
    H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
    thickness = 2
    text_style = (cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

    # Lysosome geometry in voxel units, computed once (independent of the slice).
    lys_z = np.empty(0, dtype=float)
    lys_y = np.empty(0, dtype=int)
    lys_x = np.empty(0, dtype=int)
    lys_r_um = np.empty(0, dtype=float)
    if isinstance(df, pd.DataFrame) and len(df) > 0:
        dfv = df[
            np.isfinite(df["z_um"]) &
            np.isfinite(df["y_um"]) &
            np.isfinite(df["x_um"]) &
            np.isfinite(df["radius_um"])
        ]
        if not dfv.empty:
            lys_z = dfv["z_um"].to_numpy().astype(float) / vz_um
            lys_y = np.rint(dfv["y_um"].to_numpy().astype(float) / vy_um).astype(int)
            lys_x = np.rint(dfv["x_um"].to_numpy().astype(float) / vx_um).astype(int)
            lys_r_um = dfv["radius_um"].to_numpy().astype(float)

    def _draw_marker(img, x, y, r):
        """Black-outlined magenta circle of radius r plus a magenta centre dot."""
        cv2.circle(img, (x, y), r, LYS_EDGE_BGR, thickness + 2, lineType=cv2.LINE_AA)
        cv2.circle(img, (x, y), r, LYS_MAGENTA_BGR, thickness, lineType=cv2.LINE_AA)
        cv2.circle(img, (x, y), 1, LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)

    def _draw_lysosomes(img, z):
        """Draw, in place, every lysosome whose sphere intersects slice z."""
        if lys_z.size > 0:
            dz_um = np.abs(lys_z - z) * vz_um
            hits = dz_um <= lys_r_um
            if not np.any(hits):
                return
            # Cross-section radius at this slice in XY pixels (min 3 px for visibility).
            r_proj_px = np.sqrt(np.clip(lys_r_um[hits]**2 - dz_um[hits]**2, 0.0, None)) / px_um_safe
            for y, x, rp in zip(lys_y[hits], lys_x[hits], r_proj_px):
                if 0 <= y < H and 0 <= x < W:
                    _draw_marker(img, int(x), int(y), int(max(3, round(rp))))
        elif (blobs is not None) and (len(blobs) > 0):
            # Fallback when the lysosome table has no usable rows: draw the raw
            # blobs centred within half a slice of z (radius already in px).
            for b in blobs[np.abs(blobs[:, 0] - z) < 0.5]:
                y, x = int(round(b[1])), int(round(b[2]))
                if 0 <= y < H and 0 <= x < W:
                    _draw_marker(img, x, y, int(max(3, round(b[3]))))

    def _tint_mask(base, z):
        """Return base with the viz mask of slice z blended in as a green tint."""
        mask_u8 = (cell_mask_viz[z].astype(np.uint8) * 255)
        tinted = base.copy()
        tinted[..., 1] = np.maximum(tinted[..., 1], mask_u8)
        return cv2.addWeighted(base, 1.0, tinted, 0.35, 0.0)

    # --- Video 1: smoothed Ch2 (grey) + mask tint + lysosomes ---
    img_norm_2 = (ch2 * 255).astype(np.uint8)
    frames_fused = []
    for z in range(img_norm_2.shape[0]):
        overlay = _tint_mask(cv2.cvtColor(img_norm_2[z], cv2.COLOR_GRAY2BGR), z)
        _draw_lysosomes(overlay, z)
        cv2.putText(overlay, "FUSED (mask + MAGENTA lysosomes)", (10, 22), *text_style)
        frames_fused.append(overlay)
    _save_video_sync("ch2_fused_cell_magenta", frames_fused)

    # --- Videos 2-4: RAW composite, FUSED overlay, and both side by side ---
    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(ch2.astype(np.float32) * 255.0 / max(1.0, ch2.max()))

    frames_raw = []
    frames_fused_all = []
    frames_side_by_side = []
    for z in range(ch2_u8.shape[0]):
        base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]])
        overlay = _tint_mask(base, z)
        _draw_lysosomes(overlay, z)

        # Labels are drawn after the overlay was derived, so base's caption
        # does not leak into the overlay frame.
        cv2.putText(base, "RAW (Ch1+Ch2)", (10, 22), *text_style)
        cv2.putText(overlay, "FUSED (mask + MAGENTA lysosomes)", (10, 22), *text_style)

        frames_raw.append(base)
        frames_fused_all.append(overlay)
        frames_side_by_side.append(cv2.hconcat([base, overlay]))

    _save_video_sync("ch2_raw", frames_raw)
    _save_video_sync("ch2_fused_all_viz_magenta", frames_fused_all)
    _save_video_sync("ch2_raw_and_fused_all_viz_magenta", frames_side_by_side)

# ==========================================
# Napari visualization
# ==========================================
if LAUNCH_VIEWER:
    # Interactive 3D review: raw channels, viz mask, viz IDs, and the lysosomes
    # classified as "cell" as magenta points (coordinates in voxel units).
    viewer = napari.Viewer()
    viewer.dims.ndisplay = 3

    viewer.add_image(img_ch2, name="Ch2 raw")
    viewer.add_image(img_ch1, name="Ch1 raw")

    mask_layer = viewer.add_labels(
        cell_mask_viz.astype(np.uint8),
        name="Neurite mask (viz)" if NEURITE_MODE else "Cell mask (viz)",
        opacity=0.35
    )
    mask_layer.blending = "translucent_no_depth"

    id_layer = viewer.add_labels(
        cell_seg_viz.astype(np.uint16),
        name="ID (viz)",
        opacity=0.25
    )
    id_layer.blending = "translucent_no_depth"

    if len(df) > 0 and "location_ch2" in df:
        in_mask = (df["location_ch2"].to_numpy() == "cell")
        if np.any(in_mask):
            # µm -> voxel indices (z, y, x) to match the image layers.
            pts_zyx = np.stack([
                (df.loc[in_mask, "z_um"].to_numpy() / vz_um),
                (df.loc[in_mask, "y_um"].to_numpy() / vy_um),
                (df.loc[in_mask, "x_um"].to_numpy() / vx_um),
            ], axis=1)

            # Point diameter in XY pixels (minimum 2) from the µm radius.
            radii_vox = df.loc[in_mask, "radius_um"].to_numpy() / (np.sqrt(vx_um * vy_um) + 1e-12)
            sizes = np.clip(radii_vox * 2, 2, None).astype(np.float32)

            pts = viewer.add_points(pts_zyx, size=sizes, name="Lysosomes (inside 3D)")
            pts.face_color = [1.0, 0.0, 1.0, 1.0]  # MAGENTA in napari
            # NOTE(review): newer napari releases renamed Points edge_* to
            # border_*; confirm these attributes against the installed version.
            pts.edge_color = "black"
            pts.edge_width = 0.3

            # Per-point hover info: display ID, serial lysosome ID, diameter.
            cell_ids = df.loc[in_mask, "cell_id_ch2_viz"].to_numpy().astype(int) if "cell_id_ch2_viz" in df.columns else np.zeros(int(np.sum(in_mask)), dtype=int)
            lys_ids  = df.loc[in_mask, "lys_id_in_cell"].to_numpy().astype(int) if "lys_id_in_cell" in df.columns else np.zeros(int(np.sum(in_mask)), dtype=int)
            diams    = df.loc[in_mask, "diameter_um"].to_numpy().astype(float)

            info = np.array(
                [f"ID:{c}  Ly:{l}  Diam:{d:.3f}µm" for c, l, d in zip(cell_ids, lys_ids, diams)],
                dtype=object
            )
            pts.properties = {"info": info}

    # Defensive cleanup: drop any layer not created above.
    keep = {"Lysosomes (inside 3D)", "ID (viz)", "Neurite mask (viz)", "Cell mask (viz)", "Ch1 raw", "Ch2 raw"}
    for lyr in list(viewer.layers):
        if lyr.name not in keep:
            viewer.layers.remove(lyr)

    try:
        viewer.camera.zoom = 1.2
    except Exception:
        pass

    napari.run()

In [None]:
##### Last version (with the interval) #####

In [None]:
import os
import re
import numpy as np
import pandas as pd
import cv2
import imageio
import xml.etree.ElementTree as ET
from datetime import datetime

import czifile
import tifffile as tiff

from skimage.feature import blob_log
from skimage.filters import gaussian, threshold_local
from skimage.morphology import (
    remove_small_objects, binary_opening, binary_closing, ball, binary_erosion,
    skeletonize_3d, binary_dilation
)
from scipy.ndimage import distance_transform_edt as edt
from skimage.measure import label
from skimage.segmentation import watershed
from scipy.ndimage import binary_fill_holes
from skimage.draw import line_nd
from scipy.ndimage import convolve

import napari
from scipy.ndimage import gaussian_filter1d
import colorsys

# >>> NEW (for the Napari dock widget / table)
from qtpy.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QLabel, QDoubleSpinBox,
    QCheckBox, QPushButton, QTableWidget, QTableWidgetItem, QHeaderView
)

# ===============================
# USER MODE
# ===============================
# True for neurite-only datasets (no soma); False for the soma/watershed path.
NEURITE_MODE: bool = True
# Also write the full-size RGB overlay TIFF stack + MP4.
EXPORT_FULLSIZE_OVERLAY: bool = True

# OpenCV drawing colours, in BGR channel order.
LYS_EDGE_BGR = (0, 0, 0)          # black outline (OpenCV BGR)
LYS_MAGENTA_BGR = (255, 0, 255)   # magenta

# ===============================
# GUI (single unified interface)
# ===============================
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, simpledialog


def get_user_config_gui(
    default_vxy_um=0.04,
    default_vz_um=None,
    default_erode_mult=1.0,
    default_blob_threshold=0.001,
    default_margin_um=0.5,
    default_overlap_alpha=0.4,
    default_neighbor_max_vox=6,
    default_viz_min_voxels=200,
    default_max_reasonable_vxy_um=0.5,
    default_ch1_smooth_sigma=1.0,
    default_blob_min_sigma=0.8,
    default_blob_max_sigma=3.0,
    default_blob_num_sigma=10,
    default_radial_max_radius_nm=300.0,
    default_radial_dr_nm=10.0,
    default_radial_min_drop_fraction=0.5,
    default_ch2_smooth_sigma=1.2,
    default_thresh_block_size=201,
    default_thresh_offset_std_mult=0.12,
    default_video_fps=8,
    default_launch_viewer=True,
    default_generate_videos=True,
):
    """Show the run-configuration dialog and return the chosen settings.

    Editable fields: input file, output folder, ERODE_MULT, blob_log threshold,
    the "Launch Napari viewer" / "Generate videos" toggles, video FPS and (under
    "Advanced") MARGIN_UM, OVERLAP_ALPHA and VIZ_MIN_VOXELS. Every other
    ``default_*`` argument is copied into the result unchanged.

    Invalid input shows an error box and keeps the dialog open so the user can
    correct it.

    Returns:
        dict: ``{"ok": True, "file_path": ..., "output_dir": ..., ...}`` with one
        upper-case key per pipeline setting.

    Raises:
        SystemExit: if the user cancels or closes the window.
    """
    cfg = {"ok": False}

    class _InvalidInput(ValueError):
        """Raised by the validators after the user has been shown an error box."""

    root = tk.Tk()
    root.title("Lysosome + Neurite Segmentation (GUI)")
    root.resizable(False, False)

    file_var = tk.StringVar(value="")
    out_var = tk.StringVar(value="")

    erode_var = tk.StringVar(value=str(default_erode_mult))
    blob_var  = tk.StringVar(value=str(default_blob_threshold))

    show_adv = tk.BooleanVar(value=False)

    margin_var   = tk.StringVar(value=str(default_margin_um))
    overlap_var  = tk.StringVar(value=str(default_overlap_alpha))
    vizmin_var   = tk.StringVar(value=str(default_viz_min_voxels))

    fps_var = tk.StringVar(value=str(default_video_fps))

    launch_viewer_var = tk.BooleanVar(value=bool(default_launch_viewer))
    gen_videos_var    = tk.BooleanVar(value=bool(default_generate_videos))

    def _suggest_output_dir(fp):
        # Timestamped "<name>_outputs_<stamp>" folder next to the input file.
        if not fp:
            return ""
        raw_dir = os.path.dirname(fp)
        raw_base = os.path.splitext(os.path.basename(fp))[0]
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return os.path.join(raw_dir, f"{raw_base}_outputs_{stamp}")

    def browse_file():
        fp = filedialog.askopenfilename(
            title="Select image file",
            filetypes=[("Image files", "*.tif *.tiff *.czi"), ("All files", "*.*")],
        )
        if fp:
            file_var.set(fp)
            if not out_var.get().strip():
                out_var.set(_suggest_output_dir(fp))

    def browse_output_dir():
        d = filedialog.askdirectory(title="Select output folder")
        if d:
            fp = file_var.get().strip()
            if fp:
                # Create a per-run timestamped subfolder inside the chosen folder.
                raw_base = os.path.splitext(os.path.basename(fp))[0]
                stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                out_var.set(os.path.join(d, f"{raw_base}_outputs_{stamp}"))
            else:
                out_var.set(d)

    def _err(msg):
        messagebox.showerror("Invalid input", msg)
        raise _InvalidInput(msg)

    def _float_required(s, name):
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        try:
            # Accept a decimal comma as well as a decimal point.
            return float(s.replace(",", "."))
        except ValueError:
            _err(f"{name} must be a number (got: {s})")

    def _int_required(s, name):
        s = (s or "").strip()
        if s == "":
            _err(f"{name} is required.")
        try:
            return int(float(s.replace(",", ".")))
        except (ValueError, OverflowError):  # OverflowError: "inf"
            _err(f"{name} must be an integer (got: {s})")

    def _toggle_adv():
        if show_adv.get():
            adv_frame.grid()
        else:
            adv_frame.grid_remove()

    def _collect_config():
        # Validate every field; raises _InvalidInput on the first bad one.
        fp = file_var.get().strip()
        if not fp:
            _err("Please select a file.")
        if not os.path.isfile(fp):
            _err("Selected file does not exist.")

        outd = out_var.get().strip() or _suggest_output_dir(fp)

        erode = _float_required(erode_var.get(), "ERODE_MULT")
        blobt = _float_required(blob_var.get(), "blob_log threshold")
        fps   = _int_required(fps_var.get(), "video FPS")

        margin  = _float_required(margin_var.get(), "MARGIN_UM (µm)")
        overlap = _float_required(overlap_var.get(), "OVERLAP_ALPHA (0..1)")
        vizmin  = _int_required(vizmin_var.get(), "VIZ_MIN_VOXELS (voxels)")

        if not (0.0 <= overlap <= 1.0):
            _err("OVERLAP_ALPHA must be between 0 and 1.")
        # Video writers cannot encode at 0 or negative frames per second.
        if fps < 1:
            _err("Video FPS must be at least 1.")

        return {
            "ok": True,
            "file_path": fp,
            "output_dir": outd,

            "ERODE_MULT": float(erode),
            "BLOB_THRESHOLD": float(blobt),

            "DEFAULT_VX_VY_UM": float(default_vxy_um),
            "DEFAULT_VZ_UM": None if default_vz_um is None else float(default_vz_um),

            "MAX_REASONABLE_VXY_UM": float(default_max_reasonable_vxy_um),

            "MARGIN_UM": float(margin),
            "OVERLAP_ALPHA": float(overlap),
            "VIZ_MIN_VOXELS": int(vizmin),

            "NEIGHBOR_MAX_VOX": int(default_neighbor_max_vox),

            "CH1_SMOOTH_SIGMA": float(default_ch1_smooth_sigma),
            "BLOB_MIN_SIGMA": float(default_blob_min_sigma),
            "BLOB_MAX_SIGMA": float(default_blob_max_sigma),
            "BLOB_NUM_SIGMA": int(default_blob_num_sigma),

            "RADIAL_MAX_RADIUS_NM": float(default_radial_max_radius_nm),
            "RADIAL_DR_NM": float(default_radial_dr_nm),
            "RADIAL_MIN_DROP_FRACTION": float(default_radial_min_drop_fraction),

            "CH2_SMOOTH_SIGMA": float(default_ch2_smooth_sigma),
            "THRESH_BLOCK_SIZE": int(default_thresh_block_size),
            "THRESH_OFFSET_STD_MULT": float(default_thresh_offset_std_mult),

            "VIDEO_FPS": int(fps),
            "LAUNCH_VIEWER": bool(launch_viewer_var.get()),
            "GENERATE_VIDEOS": bool(gen_videos_var.get()),
        }

    def run_clicked():
        try:
            values = _collect_config()
        except _InvalidInput:
            # The error box has already been shown; keep the dialog open instead of
            # letting the exception escape into Tk's callback handler.
            return
        cfg.update(values)
        root.destroy()

    def cancel_clicked():
        root.destroy()

    root.protocol("WM_DELETE_WINDOW", cancel_clicked)

    pad = {"padx": 10, "pady": 6}
    frm = ttk.Frame(root)
    frm.grid(row=0, column=0, sticky="nsew", **pad)

    r = 0
    ttk.Label(frm, text="Image file:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=file_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_file).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Label(frm, text="Output folder:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=out_var, width=60).grid(row=r, column=1, sticky="we")
    ttk.Button(frm, text="Browse...", command=browse_output_dir).grid(row=r, column=2, sticky="e")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Label(frm, text="ERODE_MULT:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=erode_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="blob_log threshold:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=blob_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="unitless").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Launch Napari viewer", variable=launch_viewer_var)\
        .grid(row=r, column=0, columnspan=2, sticky="w")
    ttk.Checkbutton(frm, text="Generate videos", variable=gen_videos_var)\
        .grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Label(frm, text="Video FPS:").grid(row=r, column=0, sticky="w")
    ttk.Entry(frm, textvariable=fps_var, width=20).grid(row=r, column=1, sticky="w")
    ttk.Label(frm, text="frames/sec").grid(row=r, column=2, sticky="w")
    r += 1

    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    ttk.Checkbutton(frm, text="Show advanced settings", variable=show_adv, command=_toggle_adv)\
        .grid(row=r, column=0, columnspan=3, sticky="w")
    r += 1

    # The advanced frame keeps its grid slot but starts hidden (see _toggle_adv).
    adv_frame = ttk.LabelFrame(frm, text="Advanced")
    adv_frame.grid(row=r, column=0, columnspan=3, sticky="we", pady=6)
    adv_frame.grid_remove()

    rr = 0
    def add_row(label_txt, var, hint):
        nonlocal rr
        ttk.Label(adv_frame, text=label_txt).grid(row=rr, column=0, sticky="w", padx=8, pady=3)
        ttk.Entry(adv_frame, textvariable=var, width=18).grid(row=rr, column=1, sticky="w", padx=8, pady=3)
        ttk.Label(adv_frame, text=hint).grid(row=rr, column=2, sticky="w", padx=8, pady=3)
        rr += 1

    add_row("MARGIN_UM:", margin_var, "µm (soft band around mask)")
    add_row("OVERLAP_ALPHA:", overlap_var, "0..1 (sphere overlap fraction)")
    add_row("VIZ_MIN_VOXELS:", vizmin_var, "voxels (hide small components)")

    r += 1
    ttk.Separator(frm).grid(row=r, column=0, columnspan=3, sticky="we", pady=8)
    r += 1

    btns = ttk.Frame(frm)
    btns.grid(row=r, column=0, columnspan=3, sticky="e")
    ttk.Button(btns, text="Cancel", command=cancel_clicked).grid(row=0, column=0, padx=6)
    ttk.Button(btns, text="Run", command=run_clicked).grid(row=0, column=1, padx=6)

    root.mainloop()

    if not cfg.get("ok"):
        raise SystemExit("Cancelled.")
    return cfg


# ===============================
# Metadata parsing
# ===============================
def _parse_ome_xml(xml_text):
    if not xml_text:
        return None, None, None

    def grab(attr):
        m = re.search(fr'PhysicalSize{attr}="([\d\.eE+-]+)"', xml_text)
        return float(m.group(1)) if m else None

    return grab("X"), grab("Y"), grab("Z")


def _parse_czi_scaling(czi_text):
    if not czi_text:
        return None, None, None

    if isinstance(czi_text, (bytes, bytearray)):
        czi_text = czi_text.decode("utf-8", errors="ignore")

    czi_text = czi_text.replace("\x00", "")

    def _to_float(s):
        if s is None:
            return None
        try:
            return float(str(s).strip().replace(",", "."))
        except Exception:
            return None

    def _to_um(val, unit_hint=None):
        if val is None:
            return None
        if unit_hint:
            u = str(unit_hint).strip().lower()
            if u in ("m", "meter", "metre", "meters", "metres"):
                return val * 1e6
            if u in ("µm", "um", "micron", "microns", "micrometer", "micrometre"):
                return val
            if u in ("nm", "nanometer", "nanometre", "nanometers", "nanometres"):
                return val / 1000.0

        if val < 1e-3:
            return val * 1e6
        if val < 10:
            return val
        if val < 1e5:
            return val / 1000.0
        return None

    try:
        root = ET.fromstring(czi_text)
    except Exception:
        def _grab(axis):
            mm = re.search(
                rf'<Distance[^>]*Id="{axis}"[^>]*>.*?<Value>\s*([0-9eE\+\-\.]+)\s*</Value>',
                czi_text,
                flags=re.IGNORECASE | re.DOTALL,
            )
            return _to_float(mm.group(1)) if mm else None

        return _to_um(_grab("X")), _to_um(_grab("Y")), _to_um(_grab("Z"))

    sx = sy = sz = None
    for d in root.findall(".//{*}Distance"):
        axis = d.attrib.get("Id") or d.attrib.get("id") or d.attrib.get("Axis") or d.attrib.get("axis")
        if not axis:
            continue
        axis = axis.upper()
        unit = d.attrib.get("Unit") or d.attrib.get("unit")

        valf = _to_float(d.attrib.get("Value") or d.attrib.get("value"))

        if valf is None:
            v_el = d.find(".//{*}Value")
            if v_el is not None and v_el.text:
                valf = _to_float(v_el.text)

        if valf is None:
            for child in d.iter():
                if child is d:
                    continue
                if str(child.tag).lower().endswith("value"):
                    valf = _to_float(child.attrib.get("Value") or child.attrib.get("value")) or _to_float(child.text)
                    if valf is not None:
                        break

        val_um = _to_um(valf, unit_hint=unit)
        if val_um is None:
            continue

        if axis == "X":
            sx = val_um
        elif axis == "Y":
            sy = val_um
        elif axis == "Z":
            sz = val_um

    return sx, sy, sz


def _split_two_channels(img, fmt_name):
    """Split a squeezed 4-D array into its two channels.

    The channel axis is the first of (axis 0, axis 1, last axis) whose length is 2,
    checked in that order.

    Raises:
        RuntimeError: if *img* is not 4-D, or none of those axes has length 2.
    """
    if img.ndim != 4:
        raise RuntimeError(f"Unexpected {fmt_name} shape")
    if img.shape[0] == 2:
        return img[0], img[1]
    if img.shape[1] == 2:
        return img[:, 0], img[:, 1]
    if img.shape[-1] == 2:
        return img[..., 0], img[..., 1]
    raise RuntimeError(f"Unexpected {fmt_name} shape for 2 channels")


def load_any(file_path):
    """Load a two-channel 3-D image from a TIFF/OME-TIFF or CZI file.

    Voxel sizes come from OME-XML (TIFF) or the CZI scaling metadata. Metadata
    reading is best-effort: any size that cannot be read is returned as None.

    Returns:
        tuple: ``(ch1, ch2, (vx_um, vy_um, vz_um), meta)`` where ``meta`` is
        ``{"type": "tiff"}`` or ``{"type": "czi"}``.

    Raises:
        ValueError: for an unsupported file extension.
        RuntimeError: if the squeezed array is not 4-D with a length-2 channel axis.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext in (".tif", ".tiff"):
        with tiff.TiffFile(file_path) as tf:
            arr = tf.asarray()
            try:
                ome_xml = tf.ome_metadata
            except Exception:
                ome_xml = None
        vx_um, vy_um, vz_um = _parse_ome_xml(ome_xml) if ome_xml else (None, None, None)
        ch1, ch2 = _split_two_channels(np.squeeze(arr), "TIFF")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "tiff"}

    if ext == ".czi":
        with czifile.CziFile(file_path) as cf:
            arr = cf.asarray()
            try:
                czi_xml = cf.metadata()
            except Exception:
                czi_xml = None
        vx_um, vy_um, vz_um = _parse_czi_scaling(czi_xml) if czi_xml else (None, None, None)
        ch1, ch2 = _split_two_channels(np.squeeze(arr), "CZI")
        return ch1, ch2, (vx_um, vy_um, vz_um), {"type": "czi"}

    raise ValueError("Unsupported file format")


# ===============================
# Neurite stitching (direction-aware)
# ===============================
def stitch_neurite_fragments_by_orientation(
    neuron_mask,
    vx_um, vy_um, vz_um,
    max_gap_um=1.0,
    cos_min=0.75,
    bridge_dilate_radius_vox=1,
):
    """Bridge short gaps between neurite fragments whose free ends face each other.

    How it works:
      1. Skeletonize the mask.
      2. Take skeleton voxels with exactly one skeleton voxel among their 26
         neighbours as endpoints.
      3. Give each endpoint an outward unit direction (in µm space) pointing away
         from its nearest skeleton neighbour.
      4. Join two endpoints with a straight voxel line when both hold:
         - they are at most ``max_gap_um`` apart;
         - each endpoint's direction points towards the other with cosine
           >= ``cos_min``.
      5. Accept pairs greedily from the shortest gap up; each endpoint is used at
         most once.
      6. Dilate the lines with ``ball(bridge_dilate_radius_vox)``, OR them into the
         mask and close the result with ``ball(1)``.

    Args:
        neuron_mask: 3-D mask indexed (z, y, x), presumably boolean; may be None.
        vx_um, vy_um, vz_um: voxel size along x, y, z in µm.
        max_gap_um: largest gap (µm) that may be bridged.
        cos_min: minimum cosine between an endpoint direction and the gap vector.
        bridge_dilate_radius_vox: radius (voxels) used to thicken the bridges.

    Returns:
        ``neuron_mask`` itself when nothing is bridged, otherwise a new bool array.
    """
    if neuron_mask is None or neuron_mask.sum() == 0:
        return neuron_mask
    # A non-positive (or NaN) gap can never admit a pair of distinct endpoints.
    if not (max_gap_um > 0):
        return neuron_mask

    skel = skeletonize_3d(neuron_mask).astype(bool)
    if skel.sum() == 0:
        return neuron_mask

    # Count skeleton neighbours (26-connectivity) of every voxel.
    kernel = np.ones((3, 3, 3), dtype=np.uint8)
    kernel[1, 1, 1] = 0
    neigh = convolve(skel.astype(np.uint8), kernel, mode="constant", cval=0)

    endpoints = skel & (neigh == 1)
    pts = np.argwhere(endpoints)  # (n, 3) zyx voxel indices
    n = pts.shape[0]
    if n < 2:
        return neuron_mask

    spacing = np.array([vz_um, vy_um, vx_um], dtype=np.float32)

    dirs = np.zeros((n, 3), dtype=np.float32)
    Z, Y, X = skel.shape

    for i, (z, y, x) in enumerate(pts):
        z1, z2 = max(0, z - 1), min(Z, z + 2)
        y1, y2 = max(0, y - 1), min(Y, y + 2)
        x1, x2 = max(0, x - 1), min(X, x + 2)

        patch = skel[z1:z2, y1:y2, x1:x2]
        nbrs = np.argwhere(patch)
        nbrs = nbrs + np.array([z1, y1, x1], dtype=np.int32)
        nbrs = nbrs[~((nbrs[:, 0] == z) & (nbrs[:, 1] == y) & (nbrs[:, 2] == x))]
        if nbrs.shape[0] == 0:
            continue

        d2 = np.sum((nbrs - np.array([z, y, x], dtype=np.int32)) ** 2, axis=1)
        nb = nbrs[int(np.argmin(d2))]

        # Outward direction: from the neighbour towards the endpoint.
        v_vox = (np.array([z, y, x], dtype=np.float32) - nb.astype(np.float32))
        v_um = v_vox * spacing
        norm = float(np.linalg.norm(v_um))
        if norm > 0:
            dirs[i] = v_um / norm

    # Endpoints without a usable direction never take part in a bridge.
    has_dir = np.isfinite(dirs).all(axis=1) & (np.linalg.norm(dirs, axis=1) != 0)

    # Only pairs within max_gap_um can be bridged, so ask a KD-tree for them instead
    # of testing all n*(n-1)/2 pairs. The radius is padded slightly so float64 tree
    # distances never drop a pair that the float32 check below would accept.
    from scipy.spatial import cKDTree

    tree = cKDTree(pts.astype(np.float64) * spacing.astype(np.float64))
    near_pairs = tree.query_pairs(r=float(max_gap_um) * (1.0 + 1e-5) + 1e-9)

    candidates = []
    for i, j in near_pairs:  # query_pairs yields i < j
        if not (has_dir[i] and has_dir[j]):
            continue

        d_um_vec = (pts[j] - pts[i]).astype(np.float32) * spacing
        dist = float(np.linalg.norm(d_um_vec))
        if dist <= 0 or dist > max_gap_um:
            continue

        v = d_um_vec / dist
        # Both ends must face the gap: i outward towards j, j outward towards i.
        if float(np.dot(dirs[i], v)) < cos_min:
            continue
        if float(np.dot(dirs[j], -v)) < cos_min:
            continue

        candidates.append((dist, int(i), int(j)))

    if not candidates:
        return neuron_mask

    # Shortest gaps first; ties resolved by endpoint index for determinism.
    candidates.sort()

    used = np.zeros(n, dtype=bool)
    bridge = np.zeros_like(neuron_mask, dtype=bool)

    for _dist, i, j in candidates:
        if used[i] or used[j]:
            continue
        line_idx = line_nd(tuple(pts[i]), tuple(pts[j]), endpoint=True)
        bridge[line_idx] = True
        used[i] = True
        used[j] = True

    if not bridge.any():
        return neuron_mask

    bridged = neuron_mask | binary_dilation(bridge, ball(int(bridge_dilate_radius_vox)))
    bridged = binary_closing(bridged, ball(1))
    return bridged


# ===============================
# Refine radii
# ===============================
def refine_radii_via_dt(img3d, blobs, win_px=40, bin_method="sauvola"):
    """Replace blob radii with a 2-D distance-transform estimate from the blob's slice.

    For each blob row ``(z, y, x, r)``:
      1. Take a square window of half-width ``win_px`` around the rounded centre,
         on z-slice ``z``.
      2. Binarize it with Sauvola, local (block) or Otsu thresholding, chosen by
         ``bin_method``; any value other than "sauvola"/"local" means Otsu.
      3. Clean the mask: open with ``disk(1)`` and drop objects smaller than 3 px.
      4. If the centre pixel lies on an object, the new radius is that object's
         Euclidean distance transform value at the centre, in pixels.

    Blobs whose centre is outside the volume, or whose centre is not on a
    foreground object, keep their radius.

    Returns:
        float32 copy of ``blobs``, or ``blobs`` itself when it is None or empty.
    """
    from skimage.filters import threshold_sauvola, threshold_local, threshold_otsu
    from skimage.morphology import remove_small_objects, disk, binary_opening
    from scipy.ndimage import distance_transform_edt as _edt
    from skimage.measure import label as _label

    if blobs is None or len(blobs) == 0:
        return blobs

    n_z, n_y, n_x = img3d.shape
    refined = blobs.copy().astype(np.float32)

    for idx, (zc, yc, xc, _) in enumerate(refined):
        z = int(round(zc))
        y = int(round(yc))
        x = int(round(xc))
        inside_volume = 0 <= z < n_z and 0 <= y < n_y and 0 <= x < n_x
        if not inside_volume:
            continue

        top, bottom = max(0, y - win_px), min(n_y, y + win_px + 1)
        left, right = max(0, x - win_px), min(n_x, x + win_px + 1)
        window = img3d[z, top:bottom, left:right]

        if bin_method in ("sauvola", "local"):
            # Both adaptive methods need an odd window of at least 11 px.
            odd_window = max(11, 2 * (win_px // 2) + 1)
            if bin_method == "sauvola":
                thr = threshold_sauvola(window, window_size=odd_window, k=0.4)
            else:
                thr = threshold_local(window, block_size=odd_window, offset=-0.4 * np.std(window))
        else:
            try:
                thr = threshold_otsu(window)
            except ValueError:
                continue
        fg = window > thr

        fg = binary_opening(fg, footprint=disk(1))
        fg = remove_small_objects(fg, min_size=3, connectivity=2)

        cy, cx = y - top, x - left
        centre_in_window = 0 <= cy < fg.shape[0] and 0 <= cx < fg.shape[1]
        if not centre_in_window or not fg[cy, cx]:
            continue

        components = _label(fg)
        obj_id = components[cy, cx]
        if obj_id == 0:
            continue

        radius_px = float(_edt(components == obj_id)[cy, cx])
        if radius_px > 0:
            refined[idx, 3] = radius_px

    return refined


def refine_radii_via_radial_intensity(
    img3d,
    blobs,
    vx_um,
    vy_um,
    vz_um,
    max_radius_nm=300.0,
    dr_nm=10.0,
    min_drop_fraction=0.5,
):
    """Enlarge blob radii to the half-maximum radius of their radial intensity profile.

    For each blob row ``(z, y, x, r_px)`` (voxel coordinates):
      1. Bin the voxels within ``max_radius_nm`` of the rounded centre by physical
         distance (µm, using vx/vy/vz) into shells of width ``dr_nm``.
      2. Build the mean-intensity profile and smooth it with a Gaussian
         (sigma = 1 bin).
      3. Continue only if the profile drops by at least ``min_drop_fraction`` of
         its maximum.
      4. The radius is where the profile first falls below half-maximum (halfway
         between profile min and max) on the outward side of its peak, linearly
         interpolated between bins.

    Because the profile is already radial (r >= 0), that crossing *is* the
    half-width-at-half-maximum radius. It is converted to XY pixels and only ever
    replaces a smaller existing radius.

    Returns:
        float32 copy of ``blobs`` with updated radii, or ``blobs`` itself when
        it is None or empty.
    """
    if blobs is None or len(blobs) == 0:
        return blobs

    img = img3d.astype(np.float32)
    Z, Y, X = img.shape
    blobs_out = blobs.copy().astype(np.float32)

    max_r_um = max_radius_nm / 1000.0
    dr_um = dr_nm / 1000.0

    r_edges = np.arange(0.0, max_r_um + dr_um, dr_um)
    if r_edges.size < 2:
        return blobs_out
    r_centers = 0.5 * (r_edges[:-1] + r_edges[1:])

    px_um_xy = float(np.sqrt(vx_um * vy_um))

    # Half-extent (voxels) of the cubic search window; identical for every blob.
    rz = max(1, int(np.ceil(max_r_um / vz_um)))
    ry = max(1, int(np.ceil(max_r_um / vy_um)))
    rx = max(1, int(np.ceil(max_r_um / vx_um)))

    for i, (zc, yc, xc, r_px_init) in enumerate(blobs_out):
        z0 = int(round(zc))
        y0 = int(round(yc))
        x0 = int(round(xc))

        if not (0 <= z0 < Z and 0 <= y0 < Y and 0 <= x0 < X):
            continue

        z1, z2 = max(0, z0 - rz), min(Z, z0 + rz + 1)
        y1, y2 = max(0, y0 - ry), min(Y, y0 + ry + 1)
        x1, x2 = max(0, x0 - rx), min(X, x0 + rx + 1)
        if z1 >= z2 or y1 >= y2 or x1 >= x2:
            continue

        patch = img[z1:z2, y1:y2, x1:x2]

        zz, yy, xx = np.mgrid[z1:z2, y1:y2, x1:x2]
        dz_um = (zz - z0) * vz_um
        dy_um = (yy - y0) * vy_um
        dx_um = (xx - x0) * vx_um
        r_um = np.sqrt(dz_um**2 + dy_um**2 + dx_um**2)

        mask = (r_um <= max_r_um)
        if not np.any(mask):
            continue

        r_vals = r_um[mask].ravel()
        I_vals = patch[mask].ravel()

        bin_idx = np.digitize(r_vals, r_edges) - 1
        valid = (bin_idx >= 0) & (bin_idx < r_centers.size)
        if not np.any(valid):
            continue

        bin_idx = bin_idx[valid]
        I_vals = I_vals[valid]

        # Mean intensity per radial shell; empty shells are dropped below.
        sums = np.bincount(bin_idx, weights=I_vals, minlength=r_centers.size)
        counts = np.bincount(bin_idx, minlength=r_centers.size)
        with np.errstate(invalid="ignore", divide="ignore"):
            prof = sums / np.maximum(counts, 1)

        have = counts > 0
        if not np.any(have):
            continue

        r_prof = r_centers[have]
        I_prof = prof[have].astype(np.float32)

        I_smooth = gaussian_filter1d(I_prof, sigma=1.0)
        I_max = float(I_smooth.max())
        if I_max <= 0:
            continue
        I_min = float(I_smooth.min())

        # Require a clear peak over the local background.
        if (I_max - I_min) / max(I_max, 1e-9) < min_drop_fraction:
            continue

        I_half = I_min + 0.5 * (I_max - I_min)
        peak_idx = int(np.argmax(I_smooth))

        # First bin outward of the peak that is below half-maximum.
        below = np.nonzero(I_smooth[peak_idx + 1:] < I_half)[0]
        if below.size == 0:
            # Never drops below half-max inside the window: radius not measurable.
            continue
        k = peak_idx + 1 + int(below[0])

        # Bin k-1 is still >= I_half, bin k is below it: interpolate the crossing.
        I0, I1 = float(I_smooth[k - 1]), float(I_smooth[k])
        r0, r1 = float(r_prof[k - 1]), float(r_prof[k])
        frac = (I0 - I_half) / (I0 - I1) if I0 > I1 else 0.0
        radius_um = r0 + frac * (r1 - r0)
        if radius_um <= 0:
            continue

        r_hwhm_px = radius_um / max(px_um_xy, 1e-9)
        blobs_out[i, 3] = max(float(r_px_init), float(r_hwhm_px))

    return blobs_out


# ===============================
# Classification helpers
# ===============================
from scipy.ndimage import distance_transform_edt as _edt


def nearest_cell_label(cell_seg, z, y, x, max_r=12):
    """Return the dominant positive label in the smallest cube around (z, y, x) that has one.

    Cubes of half-width 1, 2, ..., ``max_r`` voxels (clipped to the volume) are
    examined in turn. The first cube containing any positive voxel decides the
    result: its most frequent positive label, with ties going to the smaller label.
    Returns 0 when no positive voxel lies within ``max_r`` voxels.
    """
    depth, height, width = cell_seg.shape
    for half in range(1, max_r + 1):
        window = cell_seg[
            max(0, z - half):min(depth, z + half + 1),
            max(0, y - half):min(height, y + half + 1),
            max(0, x - half):min(width, x + half + 1),
        ]
        labelled = window[window > 0]
        if labelled.size == 0:
            continue
        return int(np.bincount(labelled.ravel()).argmax())
    return 0


def sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, mask_bool, vx_um, vy_um, vz_um):
    """Return the fraction of a rasterised sphere's voxels that lie inside ``mask_bool``.

    The sphere is centred at (zc_um, yc_um, xc_um) µm, snapped to the nearest
    voxel, with radius ``r_um`` µm on the anisotropic voxel grid. It is clipped to
    the volume, so only in-volume sphere voxels count towards the denominator.

    Returns 0.0 for a non-positive radius, or when no sphere voxel falls inside
    the volume.
    """
    if r_um <= 0:
        return 0.0

    # Sphere centre in voxel indices.
    cz = int(round(zc_um / vz_um))
    cy = int(round(yc_um / vy_um))
    cx = int(round(xc_um / vx_um))

    # Half-size of the bounding box per axis, at least one voxel.
    hz, hy, hx = (max(1, int(np.ceil(r_um / step))) for step in (vz_um, vy_um, vx_um))

    depth, height, width = mask_bool.shape
    z_lo, z_hi = max(0, cz - hz), min(depth, cz + hz + 1)
    y_lo, y_hi = max(0, cy - hy), min(height, cy + hy + 1)
    x_lo, x_hi = max(0, cx - hx), min(width, cx + hx + 1)
    if z_lo >= z_hi or y_lo >= y_hi or x_lo >= x_hi:
        return 0.0

    gz, gy, gx = np.mgrid[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi]
    oz = (gz - cz) * vz_um
    oy = (gy - cy) * vy_um
    ox = (gx - cx) * vx_um
    ball_mask = (oz*oz + oy*oy + ox*ox) <= (r_um * r_um)

    n_sphere = ball_mask.sum()
    if n_sphere == 0:
        return 0.0

    n_inside = (mask_bool[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi] & ball_mask).sum()
    return float(n_inside) / float(n_sphere)


# ===============================
# Full-size overlay exporter
# ===============================
def _norm_u8_stack(vol):
    vmin, vmax = float(vol.min()), float(vol.max())
    if vmax > vmin:
        out = (np.clip((vol - vmin) / (vmax - vmin), 0, 1) * 255.0).astype(np.uint8)
    else:
        out = np.zeros_like(vol, dtype=np.uint8)
    return out


def make_label_colormap(n_labels, seed_hue=0.13):
    """Build an ``(n_labels + 1, 3)`` uint8 RGB lookup table indexed by label id.

    Row 0 (background) is black. Rows 1..n_labels are fully saturated, full-value
    hues spaced evenly around the colour wheel, starting at ``seed_hue``.
    """
    lut = np.zeros((n_labels + 1, 3), dtype=np.uint8)
    if n_labels <= 0:
        return lut

    denom = max(n_labels, 1)
    for label_id in range(1, n_labels + 1):
        hue = (seed_hue + (label_id - 1) / denom) % 1.0
        red, green, blue = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
        lut[label_id] = (int(255 * red), int(255 * green), int(255 * blue))
    return lut


def export_fullsize_overlay_stack(
    img_ch1,
    img_ch2_raw,
    cell_seg_viz,
    df,
    vx_um, vy_um, vz_um,
    output_dir,
    alpha_labels=0.45,
    draw_only_inside=True,
    fps=8,
    basename="FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
):
    """Render a full-resolution RGB overlay stack and save it as a TIFF plus a video.

    Each z-slice is built in three steps:
      1. Compose R=Ch1, G=Ch2, B=Ch1, each min-max scaled to uint8.
      2. Blend with the colour of ``cell_seg_viz`` (weight ``alpha_labels``)
         wherever the label is > 0.
      3. Draw a black-edged magenta circle for every lysosome in ``df`` whose
         sphere intersects that slice. The circle radius is the in-plane
         cross-section, at least 3 px.

    When ``draw_only_inside`` is set and ``df`` has a ``location_ch2`` column,
    only rows with ``location_ch2 == "cell"`` are drawn.

    Args:
        img_ch1, img_ch2_raw: 3-D (z, y, x) channel volumes.
        cell_seg_viz: 3-D integer label volume, or None for no tint.
        df: DataFrame with ``z_um``, ``y_um``, ``x_um``, ``radius_um`` columns (µm).
        vx_um, vy_um, vz_um: voxel size in µm.
        output_dir: folder for the outputs (created if missing).
        alpha_labels: label tint opacity, 0..1.
        draw_only_inside: restrict circles to lysosomes classified as "cell".
        fps: video frame rate.
        basename: file name stem for the outputs.

    Returns:
        tuple: ``(tiff_path, video_path)``. ``video_path`` is the MP4 when FFMPEG
        succeeds, otherwise the GIF written as a fallback.
    """
    os.makedirs(output_dir, exist_ok=True)

    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(img_ch2_raw.astype(np.float32))

    Z, H, W = ch2_u8.shape

    # Label volume for the colour tint; without labels nothing is tinted.
    if cell_seg_viz is None:
        labels = np.zeros((Z, H, W), dtype=np.int32)
    else:
        labels = np.asarray(cell_seg_viz)
    n_labels = max(0, int(labels.max())) if labels.size else 0
    cmap = make_label_colormap(n_labels, seed_hue=0.13)

    use_df = None
    if isinstance(df, pd.DataFrame) and len(df) > 0 and {"z_um", "y_um", "x_um", "radius_um"}.issubset(df.columns):
        if draw_only_inside and "location_ch2" in df.columns:
            use_df = df[df["location_ch2"] == "cell"].copy()
        else:
            use_df = df.copy()
        use_df = use_df[
            np.isfinite(use_df["z_um"]) &
            np.isfinite(use_df["y_um"]) &
            np.isfinite(use_df["x_um"]) &
            np.isfinite(use_df["radius_um"])
        ].copy()

    # Lysosome centres (voxel units) and radii (µm) do not depend on the slice.
    lys = None
    if use_df is not None and len(use_df) > 0:
        lys = (
            (use_df["z_um"].to_numpy() / vz_um).astype(float),
            (use_df["y_um"].to_numpy() / vy_um).astype(float),
            (use_df["x_um"].to_numpy() / vx_um).astype(float),
            use_df["radius_um"].to_numpy().astype(float),
        )

    px_um_xy = float(np.sqrt(vx_um * vy_um))
    frames = np.zeros((Z, H, W, 3), dtype=np.uint8)

    for z in range(Z):
        base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]]).astype(np.float32)

        lab2d = labels[z].astype(np.int32)
        lab_rgb = cmap[lab2d].astype(np.float32)

        mask = (lab2d > 0)[..., None].astype(np.float32)
        out = base * (1.0 - alpha_labels * mask) + lab_rgb * (alpha_labels * mask)

        if lys is not None:
            zc, yc, xc, r_um = lys
            dz_um = np.abs(zc - z) * vz_um
            hits = dz_um <= r_um
            if np.any(hits):
                # Radius of the sphere's cross-section in this slice.
                r_proj_um = np.sqrt(np.clip(r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
                r_proj_px = r_proj_um / max(px_um_xy, 1e-12)
                ys = np.rint(yc[hits]).astype(int)
                xs = np.rint(xc[hits]).astype(int)

                # The frame is RGB although the colour constants are named BGR;
                # black and magenta (255, 0, 255) read the same in either order.
                out_u8 = np.clip(out, 0, 255).astype(np.uint8)
                for y, x, rp in zip(ys, xs, r_proj_px):
                    rr = int(max(3, round(rp)))
                    if 0 <= y < H and 0 <= x < W:
                        cv2.circle(out_u8, (x, y), rr, LYS_EDGE_BGR, 4, lineType=cv2.LINE_AA)
                        cv2.circle(out_u8, (x, y), rr, LYS_MAGENTA_BGR, 2, lineType=cv2.LINE_AA)
                        cv2.circle(out_u8, (x, y), 1,  LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)
                out = out_u8.astype(np.float32)

        frames[z] = np.clip(out, 0, 255).astype(np.uint8)

    tiff_path = os.path.join(output_dir, f"{basename}.tif")
    tiff.imwrite(tiff_path, frames, photometric="rgb")
    print("Saved full-size RGB TIFF stack:", tiff_path)

    video_path = os.path.join(output_dir, f"{basename}.mp4")
    try:
        with imageio.get_writer(
            video_path,
            fps=int(fps),
            format="FFMPEG",
            codec="libx264",
            macro_block_size=None
        ) as w:
            for fr in frames:
                w.append_data(fr)
        print("Saved full-size MP4:", video_path)
    except Exception as e:
        # Remove a partially written MP4 so it is not mistaken for a valid output.
        try:
            if os.path.exists(video_path):
                os.remove(video_path)
        except OSError:
            pass
        gif_path = os.path.join(output_dir, f"{basename}.gif")
        imageio.mimsave(gif_path, list(frames), fps=int(fps))
        print("FFMPEG failed, saved GIF instead:", gif_path, "Error:", e)
        video_path = gif_path

    return tiff_path, video_path


# ===============================
# MAIN
# ===============================
# Fallback voxel sizes offered when the file's metadata lacks them. A Z default of
# None makes the missing-Z prompt below default to the XY size instead.
DEFAULT_VX_VY_UM = 0.04
DEFAULT_VZ_UM = None

# Blocks until the dialog closes; raises SystemExit("Cancelled.") if the user cancels.
cfg = get_user_config_gui(
    default_vxy_um=DEFAULT_VX_VY_UM,
    default_vz_um=DEFAULT_VZ_UM,
    default_erode_mult=1.0,
    default_blob_threshold=0.001,
)

file_path = cfg["file_path"]
output_dir = cfg["output_dir"]
os.makedirs(output_dir, exist_ok=True)
# NOTE: changes the process working directory, so relative paths used from here
# on resolve inside output_dir.
os.chdir(output_dir)

print("Selected file:", file_path)
print("Outputs will be saved to:", output_dir)

# Unpack the GUI configuration into module-level constants used by the pipeline.
ERODE_MULT = cfg["ERODE_MULT"]
BLOB_THRESHOLD = cfg["BLOB_THRESHOLD"]

MAX_REASONABLE_VXY_UM = cfg["MAX_REASONABLE_VXY_UM"]
MARGIN_UM = cfg["MARGIN_UM"]
OVERLAP_ALPHA = cfg["OVERLAP_ALPHA"]
NEIGHBOR_MAX_VOX = cfg["NEIGHBOR_MAX_VOX"]
VIZ_MIN_VOXELS = cfg["VIZ_MIN_VOXELS"]

# Ch1 (lysosome) detection parameters for gaussian() + blob_log().
CH1_SMOOTH_SIGMA = cfg["CH1_SMOOTH_SIGMA"]
BLOB_MIN_SIGMA = cfg["BLOB_MIN_SIGMA"]
BLOB_MAX_SIGMA = cfg["BLOB_MAX_SIGMA"]
BLOB_NUM_SIGMA = cfg["BLOB_NUM_SIGMA"]

# Parameters for refine_radii_via_radial_intensity (nm).
RADIAL_MAX_RADIUS_NM = cfg["RADIAL_MAX_RADIUS_NM"]
RADIAL_DR_NM = cfg["RADIAL_DR_NM"]
RADIAL_MIN_DROP_FRACTION = cfg["RADIAL_MIN_DROP_FRACTION"]

# Ch2 smoothing / local-threshold parameters (consumed by code outside this view).
CH2_SMOOTH_SIGMA = cfg["CH2_SMOOTH_SIGMA"]
THRESH_BLOCK_SIZE = cfg["THRESH_BLOCK_SIZE"]
THRESH_OFFSET_STD_MULT = cfg["THRESH_OFFSET_STD_MULT"]

FPS = cfg["VIDEO_FPS"]
LAUNCH_VIEWER = cfg["LAUNCH_VIEWER"]
GENERATE_VIDEOS = cfg["GENERATE_VIDEOS"]

# Voxel sizes are in µm; any that the metadata does not provide come back as None.
img_ch1, img_ch2, (vx_um, vy_um, vz_um), meta = load_any(file_path)
print(f"[metadata] vx_um={vx_um}  vy_um={vy_um}  vz_um={vz_um}")


def _prompt_missing_size(title, prompt, default_value):
    """Ask the user to accept a default voxel size or type a replacement.

    A yes/no dialog offers ``default_value``; answering "no" opens a float
    entry dialog (minimum 1e-12) pre-filled with the default.

    Args:
        title: Dialog window title.
        prompt: Explanation of which metadata value is missing.
        default_value: Value offered as the default (anything ``float()`` accepts).

    Returns:
        The accepted or entered value as a ``float``.

    Raises:
        SystemExit: If the user cancels the entry dialog.
    """
    # simpledialog is not among the module-level tkinter imports
    # (only ttk, filedialog, messagebox), so import it here.
    from tkinter import simpledialog

    root = tk.Tk()
    root.withdraw()
    try:
        try:
            root.attributes("-topmost", True)
        except tk.TclError:
            pass  # best effort: some window managers reject "-topmost"

        use_default = messagebox.askyesno(
            title,
            f"{prompt}\n\nUse default value: {default_value} ?"
        )
        if use_default:
            return float(default_value)

        val = simpledialog.askfloat(
            title,
            "Enter a new value:",
            initialvalue=float(default_value),
            minvalue=1e-12
        )
    finally:
        # Always release the hidden Tk root, even if a dialog call raises.
        root.destroy()

    if val is None:
        raise SystemExit("Cancelled.")
    return float(val)


# Fill in any voxel size the file metadata did not provide, then sanity-check all three.
if vx_um is None or vy_um is None:
    # One prompt sets both X and Y (square pixels assumed).
    vx_um = vy_um = _prompt_missing_size(
        title="Missing metadata",
        prompt="XY pixel size metadata is missing (µm/px).",
        default_value=float(cfg.get("DEFAULT_VX_VY_UM", DEFAULT_VX_VY_UM)),
    )

if vz_um is None:
    # Without a configured Z default, offer the X pixel size (isotropic guess).
    z_cfg_default = cfg.get("DEFAULT_VZ_UM", DEFAULT_VZ_UM)
    z_default = float(z_cfg_default) if (z_cfg_default is not None) else float(vx_um)
    vz_um = _prompt_missing_size(
        title="Missing metadata",
        prompt="Z step metadata is missing (µm/slice).",
        default_value=z_default,
    )

vx_um, vy_um, vz_um = float(vx_um), float(vy_um), float(vz_um)

# Every later µm<->voxel conversion divides by these, so reject 0 / negative / NaN early.
for axis_name, size_um in (("X", vx_um), ("Y", vy_um), ("Z", vz_um)):
    if not np.isfinite(size_um) or size_um <= 0:
        raise ValueError(f"Invalid {axis_name} voxel size: {size_um} µm (must be finite and > 0)")

# Check both XY axes (previously only X was checked).
largest_xy_um = max(vx_um, vy_um)
if largest_xy_um > MAX_REASONABLE_VXY_UM:
    raise ValueError(f"XY pixel size too large: {largest_xy_um} µm/px")

px_um_xy = float(np.sqrt(vx_um * vy_um))  # geometric-mean XY pixel size (µm/px)
# Scale used to turn blob radii into µm (radius_um = blobs[:, 3] * px_um).
# NOTE(review): the 0.55 factor is an empirical correction whose origin is not
# visible here — confirm before changing it.
px_um = px_um_xy * 0.55
voxel_um3 = vx_um * vy_um * vz_um
print(f"Voxel size (µm): X={vx_um}, Y={vy_um}, Z={vz_um}")

image = img_ch1
image_2 = img_ch2

# ==========================================
# Lysosome detection (Ch1)
# ==========================================
image_smooth = gaussian(image, sigma=CH1_SMOOTH_SIGMA)

# Laplacian-of-Gaussian blobs: rows of (z, y, x, sigma) in voxel units.
blobs = blob_log(
    image_smooth,
    threshold=BLOB_THRESHOLD,
    min_sigma=BLOB_MIN_SIGMA,
    max_sigma=BLOB_MAX_SIGMA,
    num_sigma=BLOB_NUM_SIGMA,
)

if len(blobs) > 0:
    # sigma -> radius for a 3-D LoG blob (r = sqrt(3) * sigma), then refine the radii.
    blobs[:, 3] *= np.sqrt(3)
    blobs = refine_radii_via_dt(image_smooth, blobs)
    blobs = refine_radii_via_radial_intensity(
        image_smooth,
        blobs,
        vx_um, vy_um, vz_um,
        max_radius_nm=RADIAL_MAX_RADIUS_NM,
        dr_nm=RADIAL_DR_NM,
        min_drop_fraction=RADIAL_MIN_DROP_FRACTION
    )

# Peak raw Ch1 intensity in a (2*rad+1)^3 window around each rounded centre,
# clipped to the image bounds.
Z0, Y0, X0 = image.shape
rad = 1
peak_gray = np.zeros(len(blobs), dtype=np.uint16)
for i, blob in enumerate(blobs):
    zc_i, yc_i, xc_i = (int(round(v)) for v in blob[:3])
    window = image[
        max(0, zc_i - rad):min(Z0, zc_i + rad + 1),
        max(0, yc_i - rad):min(Y0, yc_i + rad + 1),
        max(0, xc_i - rad):min(X0, xc_i + rad + 1),
    ]
    peak_gray[i] = np.max(window).astype(np.uint16)

if len(blobs) == 0:
    z_um = y_um = x_um = radius_um = diameter_um = volume_um3 = np.array([])
    blob_ids = np.array([], dtype=int)
else:
    # Voxel coordinates -> µm, axis by axis (z uses vz_um, y vy_um, x vx_um).
    z_um, y_um, x_um = (blobs[:, axis] * step for axis, step in enumerate((vz_um, vy_um, vx_um)))
    radius_um = blobs[:, 3] * px_um
    diameter_um = 2 * radius_um
    volume_um3 = (4/3) * np.pi * radius_um**3
    blob_ids = np.arange(1, len(blobs) + 1, dtype=int)

blob_columns = {
    "id": blob_ids,
    "z_um": z_um,
    "y_um": y_um,
    "x_um": x_um,
    "radius_um": radius_um,
    "diameter_um": diameter_um,
    "volume_um3": volume_um3,
    "peak_gray": peak_gray,
}
df = pd.DataFrame(blob_columns)
df.to_csv("lysosome_blobs_regions.csv", index=False)
print("Saved: lysosome_blobs_regions.csv")

# ==========================================
# CH2 segmentation (neurites mask)
# ==========================================
# Min-max normalise Ch2 to [0, 1]; a constant volume becomes all zeros.
vol = image_2.astype(np.float32)
vmin, vmax = float(vol.min()), float(vol.max())
vol = (vol - vmin) / (vmax - vmin) if vmax > vmin else np.zeros_like(vol)

ch2 = gaussian(vol, sigma=CH2_SMOOTH_SIGMA, preserve_range=True)

# Slice-by-slice adaptive threshold: a pixel is foreground when it exceeds its
# local (block) threshold raised by THRESH_OFFSET_STD_MULT slice std-devs
# (threshold_local subtracts `offset`, so a negative offset raises it).
neuron_mask = np.zeros(ch2.shape, dtype=bool)
for z, plane in enumerate(ch2):
    local_thresh = threshold_local(
        plane,
        block_size=THRESH_BLOCK_SIZE,
        offset=-THRESH_OFFSET_STD_MULT * np.std(plane),
    )
    neuron_mask[z] = plane > local_thresh

if NEURITE_MODE:
    # Drop specks, bridge gaps between neurite fragments, then drop what is still small.
    # (A binary_closing with ball(2) before this step was disabled.)
    neuron_mask = remove_small_objects(neuron_mask, min_size=30, connectivity=3)
    neuron_mask = stitch_neurite_fragments_by_orientation(
        neuron_mask,
        vx_um=vx_um, vy_um=vy_um, vz_um=vz_um,
        max_gap_um=2.5,
        cos_min=0.6,
        bridge_dilate_radius_vox=2,
    )
    neuron_mask = remove_small_objects(neuron_mask, min_size=200, connectivity=3)
    print("neurite voxels:", int(neuron_mask.sum()))

    # ---- ID segmentation: each 26-connected neurite network is one ID ----
    cell_seg = label(neuron_mask, connectivity=3).astype(np.int32)
    cell_mask = neuron_mask.copy()
    print("Detected components (neurite networks):", int(cell_seg.max()))
else:
    neuron_mask = binary_fill_holes(neuron_mask)
    print("neurite voxels:", int(neuron_mask.sum()))

    # ---- ID segmentation: seed bodies, then watershed across the whole mask ----
    dist = edt(neuron_mask)
    cell_min_radius_vox = 1
    seeds = (dist >= cell_min_radius_vox) & neuron_mask
    cell_mask = binary_fill_holes(binary_closing(binary_opening(seeds, ball(1)), ball(2)))

    body_lab = label(cell_mask, connectivity=3)
    n_cells = int(body_lab.max())
    print(f"Detected {n_cells} cells (soma).")

    if n_cells == 0:
        cell_seg = np.zeros_like(neuron_mask, dtype=np.int32)
    else:
        # neuron_mask is unchanged since `dist` was computed, so reuse that distance map.
        cell_seg = watershed(-dist, markers=body_lab, mask=neuron_mask)

print("ID voxels:", int((cell_seg > 0).sum()))

# ==========================================
# Visualization-only filtering (hide tiny components) + serial IDs
# ==========================================
# cell_seg_viz: cell_seg with components smaller than VIZ_MIN_VOXELS set to 0
# and the surviving labels renumbered 1..N in ascending order of their
# original label. cell_id_map_viz maps original label -> serial viz label.
# The unfiltered cell_seg is still what the CSV analysis uses.
cell_seg_viz = cell_seg.copy()
cell_id_map_viz = {}

if isinstance(cell_seg_viz, np.ndarray) and cell_seg_viz.max() > 0:
    counts_viz = np.bincount(cell_seg_viz.ravel().astype(np.int64))
    # Keep labels that actually occur and are large enough; background 0 is never kept.
    keep_viz = (counts_viz > 0) & (counts_viz >= VIZ_MIN_VOXELS)
    keep_viz[0] = False
    kept_labels = np.flatnonzero(keep_viz)

    # Lookup table old label -> serial label (0 = hidden). One gather over the
    # volume replaces a full-volume comparison per label (O(labels * voxels)).
    lut_viz = np.zeros(counts_viz.size, dtype=np.int32)
    lut_viz[kept_labels] = np.arange(1, kept_labels.size + 1, dtype=np.int32)
    cell_seg_viz = lut_viz[cell_seg_viz]
    cell_id_map_viz = {int(old): int(new) for old, new in zip(kept_labels, lut_viz[kept_labels])}

cell_mask_viz = (cell_seg_viz > 0)

# ==========================================
# Classify lysosomes: inside vs outside neurite mask, and assign ID
# ==========================================
# A lysosome is "cell" when its rounded centre voxel lies in the neurite mask
# grown by MARGIN_UM (physical-distance EDT), or when
# sphere_overlap_fraction(...) >= OVERLAP_ALPHA (presumably the fraction of its
# sphere inside the mask). "Cell" lysosomes take the ID under their centre, or
# nearest_cell_label(..., max_r=NEIGHBOR_MAX_VOX) when the centre is unlabelled.
# Columns and CSVs are produced even when no lysosomes were detected, so the
# output set is the same for every run.
# NOTE(review): `_edt` is not among the imports visible here (the header imports
# distance_transform_edt as `edt`); confirm it is defined elsewhere in the file.
dist_out_um = _edt(~neuron_mask, sampling=(vz_um, vy_um, vx_um)).astype(np.float32)
soft_cell_mask = neuron_mask | (dist_out_um <= MARGIN_UM)

location_ch2 = []
cell_id_list = []
Z, Y, X = neuron_mask.shape

for (zc_um, yc_um, xc_um, r_um) in df[["z_um", "y_um", "x_um", "radius_um"]].to_numpy():
    zz = int(round(zc_um / vz_um))
    yy = int(round(yc_um / vy_um))
    xx = int(round(xc_um / vx_um))

    in_volume = (0 <= zz < Z) and (0 <= yy < Y) and (0 <= xx < X)
    is_inside = in_volume and bool(neuron_mask[zz, yy, xx] or soft_cell_mask[zz, yy, xx])

    if not is_inside:
        # Fallback for centres just outside the margin: partial sphere overlap.
        frac = sphere_overlap_fraction(zc_um, yc_um, xc_um, r_um, neuron_mask, vx_um, vy_um, vz_um)
        if frac >= OVERLAP_ALPHA:
            is_inside = True

    if not is_inside:
        location_ch2.append("outside")
        cell_id_list.append(0)
        continue

    cid = 0
    if in_volume:
        if cell_seg[zz, yy, xx] != 0:
            cid = int(cell_seg[zz, yy, xx])
        else:
            cid = nearest_cell_label(cell_seg, zz, yy, xx, max_r=NEIGHBOR_MAX_VOX)
    location_ch2.append("cell")
    cell_id_list.append(cid)

df["location_ch2"] = np.asarray(location_ch2, dtype=object)
df["cell_id_ch2"] = np.asarray(cell_id_list, dtype=np.int64)
# Serial viz ID (0 for unassigned or hidden-as-tiny components).
df["cell_id_ch2_viz"] = df["cell_id_ch2"].map(cell_id_map_viz).fillna(0).astype(int)

# Per-cell lysosome index 1..k, ordered by z, then y, then x within each cell.
df["lys_id_in_cell"] = 0
mask_in = (df["location_ch2"] == "cell") & (df["cell_id_ch2"] > 0)
df_sorted = df.loc[mask_in].sort_values(["cell_id_ch2", "z_um", "y_um", "x_um"])
if not df_sorted.empty:
    df.loc[df_sorted.index, "lys_id_in_cell"] = (df_sorted.groupby("cell_id_ch2").cumcount().to_numpy() + 1).astype(int)

df.to_csv("lysosomes_with_cell_vs_outside.csv", index=False)
print("Saved: lysosomes_with_cell_vs_outside.csv")

df.groupby("location_ch2").size().reset_index(name="count").to_csv("lysosome_counts_cell_vs_outside.csv", index=False)
(df[df["location_ch2"] == "cell"]
    .groupby("cell_id_ch2").size()
    .reset_index(name="count")
    .to_csv("lysosome_counts_by_cell.csv", index=False))
print("Saved: lysosome_counts_cell_vs_outside.csv, lysosome_counts_by_cell.csv")

# Per-ID volumes (µm^3) from the unfiltered cell_seg, plus lysosome counts per ID.
cell_volume_df = pd.DataFrame(columns=["cell_id_ch2", "voxel_count", "volume_um3"])
if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
    counts = np.bincount(cell_seg.ravel().astype(np.int64))
    cell_ids = np.arange(1, counts.size, dtype=int)
    voxels = counts[1:].astype(np.int64)
    # bincount also emits zero-count bins for label values that never occur;
    # they are not real components, so keep them out of the table.
    present = voxels > 0
    cell_ids, voxels = cell_ids[present], voxels[present]
    vol_um3 = voxels.astype(float) * voxel_um3

    cell_volume_df = pd.DataFrame({
        "cell_id_ch2": cell_ids,
        "voxel_count": voxels,
        "volume_um3": vol_um3
    })
    cell_volume_df.to_csv("cell_volumes_ch2.csv", index=False)
    print("Saved: cell_volumes_ch2.csv")

    try:
        if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
            lys_counts = (
                df[df["location_ch2"] == "cell"]
                .groupby("cell_id_ch2").size()
                .reset_index(name="lysosome_count")
            )
        else:
            lys_counts = pd.DataFrame({
                "cell_id_ch2": pd.Series(dtype=int),
                "lysosome_count": pd.Series(dtype=int),
            })

        merged = cell_volume_df.merge(lys_counts, on="cell_id_ch2", how="left")
        # IDs without lysosomes come back as NaN from the left merge. Report them as 0
        # and keep the column integer; otherwise the NaNs force floats ("3.0") in the CSV.
        merged["lysosome_count"] = merged["lysosome_count"].fillna(0).astype(int)
        merged.to_csv("cell_metrics_ch2.csv", index=False)
        print("Saved: cell_metrics_ch2.csv")
    except Exception as e:
        # Best effort: the volume CSV above is already written.
        print("Merge with lysosome counts failed:", e)

# ==========================================
# Export full-size overlay (RAW + colored IDs + MAGENTA lysosomes)
# ==========================================
if EXPORT_FULLSIZE_OVERLAY:
    # Labels use the viz-filtered, serially renumbered IDs; only lysosomes
    # classified as inside are drawn (draw_only_inside=True).
    overlay_kwargs = {
        "img_ch1": img_ch1,
        "img_ch2_raw": img_ch2,
        "cell_seg_viz": cell_seg_viz,
        "df": df,
        "vx_um": vx_um,
        "vy_um": vy_um,
        "vz_um": vz_um,
        "output_dir": output_dir,
        "alpha_labels": 0.45,
        "draw_only_inside": True,
        "fps": FPS,
        "basename": "FULLSIZE_overlay_ID_Lysosomes_MAGENTA",
    }
    export_fullsize_overlay_stack(**overlay_kwargs)

# ==========================================
# Videos (MAGENTA lysosomes)
# ==========================================
if GENERATE_VIDEOS:
    # Frames are assembled in OpenCV's BGR order and handed to imageio. Every
    # colour used here (magenta, black, white, green mask, and the Ch1/Ch2/Ch1
    # composite) is symmetric in the first/third channel, so a BGR<->RGB
    # interpretation swap has no visible effect.
    px_um_xy = float(np.sqrt(vx_um * vy_um))
    H, W = cell_mask_viz.shape[1], cell_mask_viz.shape[2]
    thickness = 2

    # Lysosome centres (voxels) and radii (µm) do not depend on the slice, so
    # extract the finite rows once rather than once per frame.
    lys_zc = lys_yc = lys_xc = lys_r_um = None
    if isinstance(df, pd.DataFrame) and len(df) > 0:
        dfv = df[
            np.isfinite(df["z_um"]) &
            np.isfinite(df["y_um"]) &
            np.isfinite(df["x_um"]) &
            np.isfinite(df["radius_um"])
        ]
        if not dfv.empty:
            lys_zc = (dfv["z_um"].to_numpy() / vz_um).astype(float)
            lys_yc = (dfv["y_um"].to_numpy() / vy_um).astype(float)
            lys_xc = (dfv["x_um"].to_numpy() / vx_um).astype(float)
            lys_r_um = dfv["radius_um"].to_numpy().astype(float)

    def _draw_marker(canvas, x, y, radius_px):
        """Draw a black-edged magenta ring plus centre dot at (x, y); skip points outside the frame."""
        if 0 <= y < H and 0 <= x < W and radius_px > 0:
            cv2.circle(canvas, (x, y), radius_px, LYS_EDGE_BGR, thickness + 2, lineType=cv2.LINE_AA)
            cv2.circle(canvas, (x, y), radius_px, LYS_MAGENTA_BGR, thickness, lineType=cv2.LINE_AA)
            cv2.circle(canvas, (x, y), 1, LYS_MAGENTA_BGR, -1, lineType=cv2.LINE_AA)

    def _draw_lysosomes(canvas, z):
        """Draw, in place, every lysosome whose sphere intersects slice z.

        Uses the lysosome table when any row's sphere reaches slice z;
        otherwise falls back to raw blobs centred within half a slice of z.
        Rings are at least 3 px in radius.
        """
        drew_any = False
        if lys_zc is not None:
            dz_um = np.abs(lys_zc - z) * vz_um
            hits = dz_um <= lys_r_um
            if np.any(hits):
                # Radius of the sphere's cross-section at this slice, in XY pixels.
                r_proj_um = np.sqrt(np.clip(lys_r_um[hits]**2 - dz_um[hits]**2, 0.0, None))
                r_proj_px = r_proj_um / max(px_um_xy, 1e-12)
                ys = np.rint(lys_yc[hits]).astype(int)
                xs = np.rint(lys_xc[hits]).astype(int)
                for y, x, rpv in zip(ys, xs, r_proj_px):
                    _draw_marker(canvas, x, y, int(max(3, round(rpv))))
                drew_any = True

        if (not drew_any) and (blobs is not None) and (len(blobs) > 0):
            for b_ in blobs[np.abs(blobs[:, 0] - z) < 0.5]:
                y, x = int(round(b_[1])), int(round(b_[2]))
                _draw_marker(canvas, x, y, int(max(3, round(b_[3]))))

    def _fused_overlay(base, z):
        """Return a copy of base with the viz mask blended into green and lysosomes drawn for slice z."""
        mask_u8 = (cell_mask_viz[z].astype(np.uint8) * 255)
        overlay = base.copy()
        overlay[..., 1] = np.maximum(overlay[..., 1], mask_u8)
        overlay = cv2.addWeighted(base, 1.0, overlay, 0.35, 0.0)
        _draw_lysosomes(overlay, z)
        return overlay

    # --- Video 1: smoothed Ch2 + mask + lysosomes ---
    img_norm_2 = (ch2 * 255).astype(np.uint8)
    frames_fused = []
    Z = img_norm_2.shape[0]
    for z in range(Z):
        base = cv2.cvtColor(img_norm_2[z], cv2.COLOR_GRAY2BGR)
        overlay = _fused_overlay(base, z)
        cv2.putText(overlay, "FUSED (mask + MAGENTA lysosomes)", (10, 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)
        frames_fused.append(overlay)

    try:
        imageio.mimsave("ch2_fused_cell_magenta.mp4", frames_fused, fps=int(FPS), format="FFMPEG")
        print("Saved: ch2_fused_cell_magenta.mp4")
    except Exception as e:
        # Fall back to GIF on any FFMPEG failure (e.g. missing backend), as
        # _save_video_sync does below. Catching only TypeError let such failures
        # abort the script before the remaining videos and the viewer.
        print("FFMPEG failed, saving GIF instead. Error:", e)
        imageio.mimsave("ch2_fused_cell_magenta.gif", frames_fused, fps=int(FPS))
        print("Saved: ch2_fused_cell_magenta.gif")

    # --- Videos 2-4: Ch1/Ch2 composite, fused overlay, and side-by-side ---
    ch1_u8 = _norm_u8_stack(img_ch1.astype(np.float32))
    ch2_u8 = _norm_u8_stack(ch2.astype(np.float32) * 255.0 / max(1.0, ch2.max()))

    frames_raw = []
    frames_fused_all = []
    frames_side_by_side = []

    Z = ch2_u8.shape[0]
    for z in range(Z):
        base = np.dstack([ch1_u8[z], ch2_u8[z], ch1_u8[z]])  # channels: Ch1, Ch2, Ch1
        # Build the overlay before the RAW caption is stamped onto base.
        overlay = _fused_overlay(base, z)

        cv2.putText(base, "RAW (Ch1+Ch2)", (10, 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(overlay, "FUSED (mask + MAGENTA lysosomes)", (10, 22),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

        frames_raw.append(base)
        frames_fused_all.append(overlay)
        frames_side_by_side.append(cv2.hconcat([base, overlay]))

    def _save_video_sync(basename, frames):
        """Write frames to <basename>.mp4 (libx264); fall back to a GIF if that fails."""
        try:
            with imageio.get_writer(
                f"{basename}.mp4",
                fps=int(FPS),
                format="FFMPEG",
                codec="libx264",
                macro_block_size=None
            ) as w:
                for fr in frames:
                    w.append_data(fr)
            print(f"Saved: {basename}.mp4 @ {int(FPS)} fps")
        except Exception:
            imageio.mimsave(f"{basename}.gif", frames, duration=1.0 / max(int(FPS), 1))
            print(f"Saved: {basename}.gif @ {int(FPS)} fps equivalent")

    _save_video_sync("ch2_raw", frames_raw)
    _save_video_sync("ch2_fused_all_viz_magenta", frames_fused_all)
    _save_video_sync("ch2_raw_and_fused_all_viz_magenta", frames_side_by_side)

# ==========================================
# Napari visualization  (UPDATED with diameter filter widget)
# ==========================================
if LAUNCH_VIEWER:
    # NOTE(review): the Qt widget classes used below (QWidget, QVBoxLayout, QLabel,
    # QHBoxLayout, QDoubleSpinBox, QCheckBox, QTableWidget, QTableWidgetItem,
    # QPushButton, QHeaderView) are not imported in the visible part of this file;
    # confirm they are imported elsewhere (e.g. from qtpy.QtWidgets).
    viewer = napari.Viewer()
    viewer.dims.ndisplay = 3

    viewer.add_image(img_ch2, name="Ch2 raw")
    viewer.add_image(img_ch1, name="Ch1 raw")

    mask_layer = viewer.add_labels(
        cell_mask_viz.astype(np.uint8),
        name="Neurite mask (viz)" if NEURITE_MODE else "Cell mask (viz)",
        opacity=0.35
    )
    mask_layer.blending = "translucent_no_depth"

    id_layer = viewer.add_labels(
        cell_seg_viz.astype(np.uint16),
        name="ID (viz)",
        opacity=0.25
    )
    id_layer.blending = "translucent_no_depth"

    # -------------------------------
    # Diameter filter dock widget
    # -------------------------------
    def _make_points_from_df(df_sel: pd.DataFrame):
        """Convert lysosome rows into napari point data.

        Returns:
            (points as voxel z/y/x float32, point sizes = diameter in XY pixels
            with a 2 px floor, per-point "ID/Ly/Diam" info strings).
        """
        if df_sel is None or df_sel.empty:
            return np.zeros((0, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32), np.array([], dtype=object)

        pts_zyx = np.stack([
            (df_sel["z_um"].to_numpy(dtype=float) / vz_um),
            (df_sel["y_um"].to_numpy(dtype=float) / vy_um),
            (df_sel["x_um"].to_numpy(dtype=float) / vx_um),
        ], axis=1).astype(np.float32)

        radii_vox = df_sel["radius_um"].to_numpy(dtype=float) / (np.sqrt(vx_um * vy_um) + 1e-12)
        sizes = np.clip(radii_vox * 2.0, 2.0, None).astype(np.float32)

        cell_ids = df_sel["cell_id_ch2_viz"].to_numpy(dtype=int) if "cell_id_ch2_viz" in df_sel.columns else np.zeros(len(df_sel), dtype=int)
        lys_ids  = df_sel["lys_id_in_cell"].to_numpy(dtype=int) if "lys_id_in_cell" in df_sel.columns else np.zeros(len(df_sel), dtype=int)
        diams    = df_sel["diameter_um"].to_numpy(dtype=float)

        info = np.array(
            [f"ID:{c}  Ly:{l}  Diam:{d:.3f}µm" for c, l, d in zip(cell_ids, lys_ids, diams)],
            dtype=object
        )
        return pts_zyx, sizes, info

    if not isinstance(df, pd.DataFrame) or df.empty or "diameter_um" not in df.columns:
        df_all = pd.DataFrame(columns=["z_um", "y_um", "x_um", "radius_um", "diameter_um"])
    else:
        df_all = df.copy()

    if "location_ch2" in df_all.columns:
        df_inside = df_all[df_all["location_ch2"] == "cell"].copy()
    else:
        df_inside = df_all.copy()

    if not df_all.empty:
        dmin0 = float(np.nanmin(df_all["diameter_um"].to_numpy(dtype=float)))
        dmax0 = float(np.nanmax(df_all["diameter_um"].to_numpy(dtype=float)))
        if not np.isfinite(dmin0): dmin0 = 0.0
        if not np.isfinite(dmax0): dmax0 = 1.0
    else:
        dmin0, dmax0 = 0.0, 1.0

    # QDoubleSpinBox rounds its range and value to `spin_decimals` places. Snap the
    # bounds outward (floor / ceil) onto that grid: plain rounding could lift the
    # minimum above the smallest diameter (or drop the maximum below the largest),
    # silently hiding those lysosomes from the initial "show all" selection.
    spin_decimals = 4
    spin_scale = 10.0 ** spin_decimals
    dmin0 = float(np.floor(dmin0 * spin_scale) / spin_scale)
    dmax0 = float(np.ceil(dmax0 * spin_scale) / spin_scale)
    if dmax0 <= dmin0:
        dmax0 = dmin0 + 1.0 / spin_scale

    filtered_layer_name = "Lysosomes (filtered)"
    if filtered_layer_name in viewer.layers:
        pts_layer = viewer.layers[filtered_layer_name]
    else:
        pts0, sizes0, info0 = _make_points_from_df(df_inside)
        pts_layer = viewer.add_points(pts0, size=sizes0, name=filtered_layer_name)
        pts_layer.face_color = [1.0, 0.0, 1.0, 1.0]
        pts_layer.edge_color = "black"
        pts_layer.edge_width = 0.3
        pts_layer.properties = {"info": info0}

    dock = QWidget()
    vbox = QVBoxLayout(dock)

    title = QLabel("Diameter filter (µm)")
    vbox.addWidget(title)

    row = QHBoxLayout()
    vbox.addLayout(row)

    row.addWidget(QLabel("Min:"))
    spin_min = QDoubleSpinBox()
    spin_min.setDecimals(spin_decimals)
    spin_min.setRange(dmin0, dmax0)
    spin_min.setSingleStep(max((dmax0 - dmin0) / 100.0, 0.001))
    spin_min.setValue(dmin0)
    row.addWidget(spin_min)

    row.addWidget(QLabel("Max:"))
    spin_max = QDoubleSpinBox()
    spin_max.setDecimals(spin_decimals)
    spin_max.setRange(dmin0, dmax0)
    spin_max.setSingleStep(max((dmax0 - dmin0) / 100.0, 0.001))
    spin_max.setValue(dmax0)
    row.addWidget(spin_max)

    inside_only = QCheckBox("Inside neurites only")
    inside_only.setChecked(True)
    vbox.addWidget(inside_only)

    count_lbl = QLabel("0 lysosomes shown")
    vbox.addWidget(count_lbl)

    table = QTableWidget()
    vbox.addWidget(table)

    btn_row = QHBoxLayout()
    vbox.addLayout(btn_row)
    export_btn = QPushButton("Export filtered CSV")
    btn_row.addWidget(export_btn)

    # Current filtered selection, shared by _apply_filter and _export_filtered.
    state = {"df_sel": pd.DataFrame()}

    def _fill_table(df_show: pd.DataFrame, max_rows=500):
        """Show up to max_rows rows of the known lysosome columns in the dock table."""
        cols = [c for c in [
            "id", "cell_id_ch2_viz", "lys_id_in_cell",
            "diameter_um", "radius_um", "volume_um3", "peak_gray",
            "z_um", "y_um", "x_um", "location_ch2"
        ] if c in df_show.columns]

        df_show2 = df_show[cols].copy() if cols else df_show.copy()
        if len(df_show2) > max_rows:
            df_show2 = df_show2.iloc[:max_rows].copy()

        table.clear()
        table.setColumnCount(len(df_show2.columns))
        table.setRowCount(len(df_show2))
        table.setHorizontalHeaderLabels(list(df_show2.columns))

        for r in range(len(df_show2)):
            for c, colname in enumerate(df_show2.columns):
                val = df_show2.iloc[r, c]
                if isinstance(val, float):
                    txt = f"{val:.6g}"
                else:
                    txt = str(val)
                table.setItem(r, c, QTableWidgetItem(txt))

        header = table.horizontalHeader()
        header.setSectionResizeMode(QHeaderView.ResizeToContents)

    def _apply_filter():
        """Filter by the spin-box diameter range (inclusive) and refresh points, label and table."""
        dmin = float(spin_min.value())
        dmax = float(spin_max.value())
        if dmin > dmax:
            # Swap crossed bounds without re-triggering valueChanged.
            dmin, dmax = dmax, dmin
            spin_min.blockSignals(True); spin_max.blockSignals(True)
            spin_min.setValue(dmin); spin_max.setValue(dmax)
            spin_min.blockSignals(False); spin_max.blockSignals(False)

        src = df_inside if inside_only.isChecked() else df_all
        if src.empty:
            df_sel = src
        else:
            dd = src["diameter_um"].to_numpy(dtype=float)
            m = np.isfinite(dd) & (dd >= dmin) & (dd <= dmax)
            df_sel = src.loc[m].copy()

        state["df_sel"] = df_sel

        pts, sizes, info = _make_points_from_df(df_sel)
        pts_layer.data = pts
        pts_layer.size = sizes
        pts_layer.properties = {"info": info}

        count_lbl.setText(f"{len(df_sel)} lysosomes shown (diameter {dmin:.4g}–{dmax:.4g} µm)")
        _fill_table(df_sel, max_rows=500)

    def _export_filtered():
        """Write the current filtered selection to a CSV in the working (output) directory."""
        df_sel = state.get("df_sel", pd.DataFrame())
        if df_sel is None or df_sel.empty:
            print("Nothing to export (filtered selection is empty).")
            return
        dmin = float(spin_min.value())
        dmax = float(spin_max.value())
        fn = f"lysosomes_filtered_diam_{dmin:.4g}_to_{dmax:.4g}.csv".replace(" ", "_").replace("µ", "u")
        df_sel.to_csv(fn, index=False)
        print("Saved:", fn)

    spin_min.valueChanged.connect(_apply_filter)
    spin_max.valueChanged.connect(_apply_filter)
    inside_only.stateChanged.connect(lambda _: _apply_filter())
    export_btn.clicked.connect(_export_filtered)

    viewer.window.add_dock_widget(dock, name="Lysosome diameter filter", area="right")
    _apply_filter()

    keep = {
        "Lysosomes (filtered)",
        "ID (viz)",
        "Neurite mask (viz)",
        "Cell mask (viz)",
        "Ch1 raw",
        "Ch2 raw"
    }
    for lyr in list(viewer.layers):
        if lyr.name not in keep:
            viewer.layers.remove(lyr)

    try:
        viewer.camera.zoom = 1.2
    except Exception:
        pass

    napari.run()

Selected file: C:/Users/nahue/Downloads/PROYECT OREN/axons/Airy scan_40A_UAS-TMEM-HA_axon_0h_again_3_051222.czi
Outputs will be saved to: C:/Users/nahue/Downloads/PROYECT OREN/axons\Airy scan_40A_UAS-TMEM-HA_axon_0h_again_3_051222_outputs_20260122_130603
[metadata] vx_um=0.04579717184801375  vy_um=0.04579717184801375  vz_um=0.25
Voxel size (µm): X=0.04579717184801375, Y=0.04579717184801375, Z=0.25
