In [None]:
# =========================================================
# Lysosome detection + neuron segmentation (CZI or TIFF)
# =========================================================

import os
import re
import numpy as np
import pandas as pd
import cv2
import imageio
import xml.etree.ElementTree as ET

import czifile
import tifffile as tiff

from skimage.feature import blob_log
from skimage.filters import gaussian, threshold_local
from skimage.morphology import (
    remove_small_objects, binary_opening, binary_closing, ball, binary_erosion
)
from scipy.ndimage import distance_transform_edt as edt
from skimage.measure import label, regionprops
from skimage.segmentation import watershed, find_boundaries

import napari

# ==========================================
# CONFIG: set a file path (.czi or .tif/.tiff)
# ==========================================
# The active default path is the `file_path = os.environ.get(...)` line below.
# Set the LYSO_IMAGE_PATH environment variable to analyse a different stack
# without editing this cell.
# Earlier inputs, kept for quick switching:
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_4h_1_051222.czi"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_4h_2_051222.czi"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_L3_1_Airy_010724.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_L3_2_Airy_010724.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_4 Airy-CBs_300425.czi"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_3 Airy-CBs_300425-1.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_0h_1_051222.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/Airy scan_40A_UAS-TMEM-HA_CB_0h_2_051222.tif"

#new folder
file_path = os.environ.get(
    "LYSO_IMAGE_PATH",
    r"C:/Users/nahue/Downloads/PROYECT OREN/images/ISP MAPATZ 2025/40A_UAS-TMEM1923x-HA x 71G10 40A MARCM_L3_2_Airy_010724.tif",
)
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/ISP MAPATZ 2025/40A_UAS-TMEM192-3xHA x 40A 71G10 MARCM_around 12h - for quantification_3 Airy-CBs_300425.tif"
#file_path = r"C:/Users/nahue/Downloads/PROYECT OREN/images/ISP MAPATZ 2025/TMEM-HA 6h CB airy_3_170722.tif"

# ==========================================
# Generic loader for CZI and TIFF (OME/ImageJ)
# Returns: ch1, ch2, (vx_um, vy_um, vz_um), meta
# ==========================================
def _try_float(x, default=None):
    try:
        return float(x)
    except Exception:
        return default

def _parse_ome_xml(xml_text: str):
    """Parse OME-XML PhysicalSizeX/Y/Z → µm (graceful fallback)."""
    if not xml_text:
        return None
    def grab(attr):
        m = re.search(fr'PhysicalSize{attr}="([\d\.eE+-]+)"(?:\s+PhysicalSize{attr}Unit="([^"]+)")?', xml_text)
        if not m:
            return None, None
        return m.group(1), (m.group(2) or "")

    x_val, x_unit = grab('X')
    y_val, y_unit = grab('Y')
    z_val, z_unit = grab('Z')

    def to_um(val_str, unit):
        v = _try_float(val_str)
        if v is None:
            return None
        u = unit.lower()
        if u in ("µm", "um", "micrometer", "micrometre", "microns", "micron"):
            return v
        if u in ("m", "meter", "metre"):
            return v * 1e6
        # If unit missing: treat tiny as meters, else µm
        return v * 1e6 if v < 1e-3 else v

    vx = to_um(x_val, x_unit) if x_val is not None else None
    vy = to_um(y_val, y_unit) if y_val is not None else None
    vz = to_um(z_val, z_unit) if z_val is not None else None
    return vx, vy, vz

def _voxel_from_tiff_tags(tf: tiff.TiffFile):
    """
    Best-effort voxel size (µm) from baseline TIFF tags and an ImageJ-style description.

    X/Y: XResolution / YResolution (pixels per unit). The unit comes from
    ResolutionUnit: 2 = inch (also used when the tag is absent), 3 = cm.
    When ResolutionUnit is 1 (no absolute unit), the ``unit=`` line of the
    ImageDescription (e.g. ``unit=micron``) is used instead.
    Z: ``spacing`` / ``SliceSpacing`` from ImageDescription. It is scaled by
    the description unit when that unit is recognised, and taken as-is otherwise.

    Returns (vx_um, vy_um, vz_um); any entry is None if it cannot be found.
    Errors while reading tags are swallowed (best effort).
    """
    # µm per unit for unit names found in ImageJ-style descriptions.
    desc_um_per_unit = {
        "micron": 1.0, "microns": 1.0, "µm": 1.0, "μm": 1.0, "um": 1.0,
        "micrometer": 1.0, "micrometre": 1.0,
        "nm": 1e-3, "nanometer": 1e-3, "nanometre": 1e-3,
        "mm": 1e3, "millimeter": 1e3, "millimetre": 1e3,
        "cm": 1e4, "centimeter": 1e4, "centimetre": 1e4,
        "inch": 25400.0,
    }
    vx = vy = vz = None
    try:
        page0 = tf.pages[0]
        tags = getattr(page0, "tags", {})
        xres = tags.get("XResolution", None)
        yres = tags.get("YResolution", None)
        resunit = tags.get("ResolutionUnit", None)

        desc = tags.get("ImageDescription", None)
        txt = str(desc.value) if desc is not None else ""

        # "unit=..." on its own line; the micro sign may appear escaped as \u00B5.
        desc_scale = None
        m_unit = re.search(r'^\s*unit\s*=\s*(\S+)\s*$', txt, flags=re.MULTILINE)
        if m_unit:
            unit_name = (m_unit.group(1)
                         .replace("\\u00B5", "µ")
                         .replace("\\u00b5", "µ")
                         .lower())
            desc_scale = desc_um_per_unit.get(unit_name)

        try:
            unit_code = int(resunit.value) if resunit is not None else 2
        except (TypeError, ValueError):
            unit_code = None
        if unit_code == 2:      # inch
            um_per_unit = 25400.0
        elif unit_code == 3:    # centimeter
            um_per_unit = 10000.0
        elif unit_code == 1:    # no absolute unit: use the description's unit
            um_per_unit = desc_scale
        else:
            um_per_unit = None

        def res_to_um(res_tag):
            """Convert a resolution tag (pixels per unit) to µm per pixel."""
            if res_tag is None or um_per_unit is None:
                return None
            val = res_tag.value
            if isinstance(val, tuple) and len(val) == 2:
                num, den = val
            else:
                try:
                    num, den = val.numerator, val.denominator
                except Exception:
                    return None
            if den == 0:
                return None
            ppu = num / den  # pixels per unit
            if ppu <= 0:
                return None
            return um_per_unit / ppu  # µm per pixel

        vx = res_to_um(xres)
        vy = res_to_um(yres)

        # Z spacing. '-' sits last in the character class so it is a literal;
        # the old '[0-9.+-eE]' spelled a '+'..'e' range that also matched letters.
        m_z = re.search(r'(?:spacing|SliceSpacing)\s*[=:]\s*([0-9.eE+-]+)', txt)
        if m_z:
            vz = _try_float(m_z.group(1))
            if vz is not None and desc_scale is not None:
                vz *= desc_scale
    except Exception:
        pass
    return vx, vy, vz

def _split_two_channels(img, source):
    """Split a squeezed image array into two (Z, Y, X) channel stacks.

    4-D layouts: (C, Z, Y, X), (Z, C, Y, X), (Z, Y, X, C). The first axis of
    length 2, checked in that order, is taken as the channel axis, so a 4-D
    stack with exactly two Z-planes may be misread (TODO confirm for such files).
    3-D layouts: single-plane (Y, X, C) or (C, Y, X). A leading Z axis of
    length 1 is added so downstream 3-D code always receives (Z, Y, X).

    Raises RuntimeError if two channels cannot be identified.
    """
    if img.ndim == 4:
        if img.shape[0] == 2:        # (C, Z, Y, X)
            return img[0], img[1]
        if img.shape[1] == 2:        # (Z, C, Y, X)
            return img[:, 0], img[:, 1]
        if img.shape[-1] == 2:       # (Z, Y, X, C)
            return img[..., 0], img[..., 1]
        raise RuntimeError(f"Cannot infer 2 channels from {source} shape {img.shape}")
    if img.ndim == 3:
        if img.shape[-1] == 2:       # (Y, X, C) -> promote Z=1
            return img[..., 0][None, ...], img[..., 1][None, ...]
        if img.shape[0] == 2:        # (C, Y, X) -> promote Z=1
            return img[0][None, ...], img[1][None, ...]
        raise RuntimeError(f"{source} looks single-channel (shape {img.shape}); expected 2 channels.")
    raise RuntimeError(f"Unsupported {source} ndim: {img.ndim}")


def _czi_voxel_um(meta_xml):
    """Return (vx, vy, vz) in µm from CZI metadata XML, using 1.0 for anything missing.

    Reads ``Scaling/Items/Distance[@Id=axis]/Value``. CZI writers normally
    store these values in metres; the unit element next to them appears to be
    a display hint (TODO confirm against the CZI spec). Values below 1e-3 are
    therefore multiplied by 1e6 regardless of the unit text, and larger values
    are assumed to be in µm already.
    """
    try:
        root = ET.fromstring(meta_xml)
    except Exception:
        return 1.0, 1.0, 1.0

    def axis_um(axis):
        node = root.find(f".//{{*}}Scaling/{{*}}Items/{{*}}Distance[@Id='{axis}']/{{*}}Value")
        val = _try_float(node.text) if node is not None else None
        if val is None or val <= 0:
            return 1.0
        return val * 1e6 if val < 1e-3 else val

    return axis_um("X"), axis_um("Y"), axis_um("Z")


def load_any(file_path):
    """Load a two-channel 3-D stack from a .czi or .tif/.tiff file.

    Returns ``(ch1, ch2, (vx_um, vy_um, vz_um), meta)``:
    * ch1/ch2 are (Z, Y, X) arrays; single-plane images get Z=1.
    * Voxel sizes fall back to 1.0 µm when they cannot be read.
    * ``meta`` is ``{"type": "czi"}`` or ``{"type": "tiff", "ome": bool}``.

    Raises ValueError for unsupported extensions and RuntimeError when two
    channels cannot be identified.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if ext == ".czi":
        # -------- CZI --------
        with czifile.CziFile(file_path) as czi:
            img = czi.asarray()
            meta_xml = czi.metadata()
        img = np.squeeze(img)
        if img.ndim < 3:
            raise RuntimeError(f"CZI has unexpected ndim={img.ndim}")
        ch1, ch2 = _split_two_channels(img, "CZI")
        return ch1, ch2, _czi_voxel_um(meta_xml), {"type": "czi"}

    if ext in (".tif", ".tiff"):
        # -------- TIFF / OME-TIFF --------
        vx_um = vy_um = vz_um = None
        with tiff.TiffFile(file_path) as tf:
            arr = tf.asarray()
            ome_xml = None
            try:
                ome_xml = tf.ome_metadata
            except Exception:
                pass  # not an OME-TIFF; fall back to plain tags below

            if ome_xml:
                vx_um, vy_um, vz_um = _parse_ome_xml(ome_xml) or (None, None, None)
            if vx_um is None or vy_um is None or vz_um is None:
                # Fill only the axes OME-XML did not provide.
                tx, ty, tz = _voxel_from_tiff_tags(tf)
                vx_um = vx_um if vx_um is not None else tx
                vy_um = vy_um if vy_um is not None else ty
                vz_um = vz_um if vz_um is not None else tz

        ch1, ch2 = _split_two_channels(np.squeeze(arr), "TIFF")
        return ch1, ch2, (vx_um or 1.0, vy_um or 1.0, vz_um or 1.0), {"type": "tiff", "ome": bool(ome_xml)}

    raise ValueError(f"Unsupported file extension: {ext}")

# ==========================================
# 1) Load image (CZI or TIFF) + voxel size (µm)
# ==========================================
img_ch1, img_ch2, (vx_um, vy_um, vz_um), meta = load_any(file_path)

# Channel roles: Ch1 carries the lysosome signal; Ch2 the neuron signal used
# for the CELL vs OUTSIDE segmentation.
image, image_2 = img_ch1, img_ch2

# Volume of one voxel (µm^3) and the edge of a cube with that same volume.
voxel_um3 = vz_um * vy_um * vx_um
lin_equiv_um = voxel_um3 ** (1.0 / 3.0)

print(
    f"[{meta.get('type','?').upper()}] Voxel size (µm): "
    f"X={vx_um:.4g}, Y={vy_um:.4g}, Z={vz_um:.4g} | equiv linear µm={lin_equiv_um:.4g}"
)
print("Ch1 shape:", image.shape, "| Ch2 shape:", image_2.shape)

# ==========================================
# 2) CH1: blobs (lysosomes) — metrics in µm
# ==========================================
# Smooth Ch1 lightly, then detect blobs with Laplacian-of-Gaussian.
# Trailing "was" comments record earlier parameter choices.
image_smooth = gaussian(image, sigma=1)
blobs = blob_log(
    image_smooth,
    min_sigma=1,        # was 1
    max_sigma=4,        # was 10
    num_sigma=3,        # was 8
    threshold=0.0025,   # tune (0.003 was stricter)
    overlap=0.8,        # was 0.5
)

# Column 3 holds sigma ([z, y, x, sigma] for 3-D input). Multiplying by
# sqrt(3) turns it into the 3-D blob radius in index units (pixels).
if len(blobs):
    blobs[:, 3] = blobs[:, 3] * np.sqrt(3)
print(f"Detected {len(blobs)} lysosomes.")

if not len(blobs):
    z_um = y_um = x_um = np.array([])
    radius_um = diameter_um = volume_um3 = np.array([])
else:
    # Blob centres in µm, one axis at a time.
    z_um, y_um, x_um = (
        blobs[:, col] * step for col, step in ((0, vz_um), (1, vy_um), (2, vx_um))
    )
    # Equivalent-sphere size using the volume-preserving linear scale.
    # NOTE(review): the 0.5 factor makes radius_um half of blobs[:, 3] (in µm),
    # while the viewer and videos draw blobs[:, 3] itself as the radius.
    # Confirm this calibration is intended.
    radius_um = (blobs[:, 3] * lin_equiv_um) * 0.5
    diameter_um = radius_um * 2.0
    volume_um3 = (4.0/3.0) * np.pi * (radius_um ** 3)

# Coarse 4x4x4 spatial grid over the stack. region_labels is computed, but its
# "region_id" CSV column is currently disabled below.
num_bins = (4, 4, 4)
z_bins, y_bins, x_bins = (
    np.linspace(0, image.shape[ax], num_bins[ax] + 1, dtype=int) for ax in range(3)
)

if not len(blobs):
    region_labels = np.array([], dtype=int)
else:
    z_idx, y_idx, x_idx = (
        np.clip(np.digitize(blobs[:, ax], edges) - 1, 0, num_bins[ax] - 1)
        for ax, edges in enumerate((z_bins, y_bins, x_bins))
    )
    region_labels = z_idx * (num_bins[1] * num_bins[2]) + y_idx * num_bins[2] + x_idx

# Per-blob table, µm units only (ids start at 1).
blob_ids = np.arange(1, len(blobs) + 1)
df = pd.DataFrame(dict(
    id=blob_ids,
    z_um=z_um,
    y_um=y_um,
    x_um=x_um,
    diameter_um=diameter_um,
    radius_um=radius_um,
    volume_um3=volume_um3,
))
df.to_csv("lysosome_blobs_regions.csv", index=False)
print("Saved: lysosome_blobs_regions.csv (µm-only)")

# ==========================================
# 3) Viewer base
# ==========================================
# Open a napari viewer and show both raw channels with additive blending, so
# where Ch1 (lysosomes) and Ch2 (neuron) overlap stays visible.
# Ch2 is added first, then Ch1.
viewer = napari.Viewer()
viewer.add_image(image_2, name='Ch2 raw', blending='additive')
viewer.add_image(image,  name='Ch1 raw', blending='additive')

# ==========================================
# 4) CH2: segmentation (CELL vs OUTSIDE)
# ==========================================
# Normalize Ch2 to [0, 1] (a constant volume becomes all zeros) and denoise.
vol = image_2.astype(np.float32)
vmin, vmax = float(vol.min()), float(vol.max())
if vmax > vmin:
    vol = (vol - vmin) / (vmax - vmin)
else:
    vol[:] = 0.0
ch2 = gaussian(vol, sigma=0.6, preserve_range=True)

# Base neuron "foreground": local threshold computed per z-slice.
# threshold_local returns (local mean - offset). With the negative offset used
# here, a voxel must exceed its neighbourhood mean by 0.3 standard deviations
# of its slice.
neuron_mask = np.zeros_like(ch2, dtype=bool)
for z in range(ch2.shape[0]):
    R = ch2[z]
    t = threshold_local(R, block_size=91, offset=-0.3*np.std(R))  # was 91, 0.3
    neuron_mask[z] = R > t
neuron_mask = remove_small_objects(neuron_mask, min_size=20000, connectivity=3)  # was 2000
neuron_mask = binary_closing(neuron_mask, ball(8))   # was 7
neuron_mask = binary_erosion(neuron_mask, ball(2))   # was 1, 2

# CELL (soma): voxels at least cell_min_radius_vox away from background.
# NOTE: distances are in voxel units and ignore anisotropic voxel size.
dist = edt(neuron_mask)
cell_min_radius_vox = 1   # tune (e.g., 4–6) for thicker soma
cell_mask = (dist >= cell_min_radius_vox)
cell_mask &= neuron_mask
cell_mask = binary_opening(cell_mask, ball(3))   # was 3
# NOTE(review): closing can add voxels outside neuron_mask. The watershed below
# is masked by neuron_mask, so such voxels get cell_seg == 0.
cell_mask = binary_closing(cell_mask, ball(8))   # was 4

# Label all cells and partition the neuron mask into per-cell territories.
body_lab = label(cell_mask, connectivity=3)
n_cells = int(body_lab.max())
print(f"Detected {n_cells} cells (soma).")

if n_cells > 0:
    # neuron_mask has not changed since `dist` was computed, so reuse that EDT
    # instead of recomputing it over the whole volume.
    dist_inside = dist
    cell_seg = watershed(-dist_inside, markers=body_lab, mask=neuron_mask)
else:
    cell_seg = np.zeros_like(neuron_mask, dtype=np.int32)

print("neuron voxels:", int(neuron_mask.sum()))
print("cell voxels:", int(cell_mask.sum()))

# ==========================================
# ADD-ON: filter out oversized cell bodies from segmentation & visualization
# (Remove any cell whose voxel count exceeds MAX_BODY_VOXELS)
# ==========================================
# Threshold in voxels (e.g. 20000). None or a value <= 0 disables the filter,
# so the current setting of 0 keeps every cell.
MAX_BODY_VOXELS = 0
try:
    if MAX_BODY_VOXELS is None or MAX_BODY_VOXELS <= 0:
        print("Cell size filter disabled (MAX_BODY_VOXELS <= 0); nothing filtered.")
    elif isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
        # Voxel count per label (array index == label id).
        counts = np.bincount(cell_seg.ravel().astype(np.int64))
        drop_labels = np.flatnonzero(counts > MAX_BODY_VOXELS)
        # Other selections tried earlier, for reference:
        #   (counts > 560000) | ((counts > 44000) & (counts < 270000))
        #   counts < 270000
        drop_labels = drop_labels[drop_labels != 0]  # never drop background (label 0)

        if drop_labels.size > 0:
            removed_voxels = int(counts[drop_labels].sum())
            to_remove = np.isin(cell_seg, drop_labels)

            # Remove from segmentation and cell mask.
            cell_seg[to_remove] = 0
            if 'cell_mask' in globals() and isinstance(cell_mask, np.ndarray):
                cell_mask[to_remove] = False
            # Optional: also exclude from neuron_mask (uncomment if desired)
            # if 'neuron_mask' in globals() and isinstance(neuron_mask, np.ndarray):
            #     neuron_mask[to_remove] = False

            print(f"Filtered {int(drop_labels.size)} oversized cells (> {MAX_BODY_VOXELS} voxels). "
                  f"Removed voxels: {removed_voxels}. IDs: {drop_labels.tolist()}")

            pd.DataFrame({
                "filtered_cell_id_ch2": drop_labels,
                "voxel_count": counts[drop_labels]
            }).to_csv("filtered_cells_gt_threshold.csv", index=False)
        else:
            print(f"No cells exceeded {MAX_BODY_VOXELS} voxels; nothing filtered.")
except Exception as e:
    print("Cell size filtering failed:", e)

# ==========================================
# 5) Map lysosomes to (cell/outside) with per-cell IDs
# ==========================================
# Each blob centre is rounded to the nearest voxel. np.rint rounds half to
# even, the same rule as Python's round().
# - Centre voxel set in cell_mask -> "cell", with the cell_seg label at that
#   voxel as its id (0 when no cells were found).
# - Centre outside the volume or outside cell_mask -> "outside", id 0.
# NOTE(review): cell_mask may contain voxels where cell_seg is 0, so a "cell"
# lysosome can still get id 0. Confirm how such hits should be counted.
location_ch2 = []
cell_id_list = []

if len(blobs) > 0:
    centres = np.rint(blobs[:, :3]).astype(int)
    inside_volume = np.all(
        (centres >= 0) & (centres < np.asarray(neuron_mask.shape)), axis=1
    )
    in_cell = np.zeros(len(centres), dtype=bool)
    in_cell[inside_volume] = cell_mask[tuple(centres[inside_volume].T)]
    lyso_cell_ids = np.zeros(len(centres), dtype=np.int64)
    if n_cells > 0:
        lyso_cell_ids[in_cell] = cell_seg[tuple(centres[in_cell].T)]
    location_ch2 = np.where(in_cell, "cell", "outside").tolist()
    cell_id_list = lyso_cell_ids.tolist()

if len(df) > 0:
    df["location_ch2"] = location_ch2
    df["cell_id_ch2"]  = cell_id_list

    # Per-class counts
    df.groupby("location_ch2").size().reset_index(name="count") \
      .to_csv("lysosome_counts_cell_vs_outside.csv", index=False)

    # Per-cell counts (for those inside 'cell')
    (df[df["location_ch2"] == "cell"]
        .groupby("cell_id_ch2").size()
        .reset_index(name="count")
        .to_csv("lysosome_counts_by_cell.csv", index=False))

    # Full table (µm-only)
    df.to_csv("lysosomes_with_cell_vs_outside.csv", index=False)

    print("Saved: lysosome_counts_cell_vs_outside.csv, lysosome_counts_by_cell.csv, lysosomes_with_cell_vs_outside.csv (µm-only metrics)")
else:
    print("No lysosomes detected; per-cell lysosome CSVs were not written.")

# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# ADD-ON: include zero-count cells in lysosome_counts_by_cell (append-only)
# This overwrites lysosome_counts_by_cell.csv with an all-cells-inclusive table.
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# Every label 1..max(cell_seg) gets a row; labels without lysosomes get count 0.
try:
    if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
        all_cells = pd.DataFrame(
            {"cell_id_ch2": np.arange(1, int(cell_seg.max()) + 1, dtype=int)}
        )

        if not (len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df):
            lys_counts_nonzero = pd.DataFrame(columns=["cell_id_ch2", "count"])
        else:
            lys_counts_nonzero = (
                df.loc[df["location_ch2"] == "cell"]
                  .groupby("cell_id_ch2")
                  .size()
                  .reset_index(name="count")
            )

        lys_counts_all = all_cells.merge(
            lys_counts_nonzero, how="left", on="cell_id_ch2"
        ).fillna({"count": 0})
        lys_counts_all["count"] = lys_counts_all["count"].astype(int)

        lys_counts_all.to_csv("lysosome_counts_by_cell.csv", index=False)
        print("Updated: lysosome_counts_by_cell.csv now includes cells with 0 lysosomes.")
except Exception as e:
    print("Could not expand lysosome_counts_by_cell with zero-count cells:", e)

# Cell ids (> 0) that received at least one "cell" lysosome; used below to
# colour the territory stamps.
ids_with_lyso = set()
if len(df) > 0 and {"location_ch2", "cell_id_ch2"}.issubset(df.columns):
    hits = df[df["location_ch2"].eq("cell") & df["cell_id_ch2"].gt(0)]
    ids_with_lyso = set(hits["cell_id_ch2"].astype(int).unique())

# ==========================================
# 5b) Per-cell (Ch2) volumes (µm^3)
# ==========================================
# Voxel count and physical volume for every label 1..max(cell_seg). Labels
# with no voxels (e.g. removed by the size filter) still get a row with 0.
cell_volume_df = pd.DataFrame(columns=["cell_id_ch2", "voxel_count", "volume_um3"])
if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0:
    counts = np.bincount(cell_seg.ravel().astype(np.int64))
    cell_ids = np.arange(1, counts.size, dtype=int)
    voxels = counts[1:].astype(np.int64)          # drop background (label 0)
    vol_um3 = voxels.astype(float) * voxel_um3

    cell_volume_df = pd.DataFrame({
        "cell_id_ch2": cell_ids,
        "voxel_count": voxels,
        "volume_um3": vol_um3
    })
    cell_volume_df.to_csv("cell_volumes_ch2.csv", index=False)
    print("Saved: cell_volumes_ch2.csv")

    # Merge volume with lysosome counts per cell
    try:
        if len(df) > 0 and "location_ch2" in df and "cell_id_ch2" in df:
            lys_counts = (df[df["location_ch2"] == "cell"]
                          .groupby("cell_id_ch2")
                          .size()
                          .reset_index(name="lysosome_count"))
            merged = (cell_volume_df
                      .merge(lys_counts, on="cell_id_ch2", how="left")
                      .fillna({"lysosome_count": 0}))
            # The left merge introduces NaN (float) before fillna; store whole counts.
            merged["lysosome_count"] = merged["lysosome_count"].astype(int)
            merged.to_csv("cell_metrics_ch2.csv", index=False)
            print("Saved: cell_metrics_ch2.csv")
    except Exception as e:
        print("Merge with lysosome counts failed:", e)

# ==========================================
# 6) Visualization (CELL vs OUTSIDE only)
# ==========================================
# Cell (green): the binary cell_mask as a labels layer. Label 1 is mapped to
# opaque green; that mapping is skipped if this napari version rejects `.color`.
cell_layer = viewer.add_labels(cell_mask.astype(np.uint8), name='Cell (Ch2)', opacity=0.35)
try:
    cell_layer.color = {1: (0.0, 1.0, 0.0, 1.0)}  # green
except Exception:
    pass
cell_layer.blending = 'translucent_no_depth'

# Per-cell territories (labels): one label per watershed region in cell_seg.
# Best effort: any error silently skips the territory and boundary layers.
# NOTE(review): the uint16 cast would wrap label ids above 65535.
try:
    cellid_layer = viewer.add_labels(
        cell_seg.astype(np.uint16),
        name='Cell ID (Ch2)',
        opacity=0.25
    )
    cellid_layer.blending = 'translucent_no_depth'

    # Optional boundaries overlay (magenta): skimage find_boundaries in
    # mode='outer', shown as a 0/1 image.
    boundaries = find_boundaries(cell_seg, connectivity=1, mode='outer')
    viewer.add_image(
        boundaries.astype(np.uint8),
        name='Cell ID boundaries',
        blending='additive',
        contrast_limits=(0, 1),
        colormap='magenta',
        opacity=0.6
    )
except Exception:
    pass

# Lysosomes colored (cyan=cell, white=outside)
if len(blobs) > 0:
    # One RGBA colour per blob, following the location_ch2 column from step 5.
    loc = np.array(df["location_ch2"].tolist()) if "location_ch2" in df else np.array([])
    colors = np.zeros((len(loc), 4), dtype=float)
    if loc.size > 0:
        colors[loc == "cell"]    = [0.0, 1.0, 1.0, 1.0]  # cyan
        colors[loc == "outside"] = [1.0, 1.0, 1.0, 1.0]  # white

    # Point size is the LoG diameter in pixels (2 * radius), at least 2.
    pts = viewer.add_points(
        blobs[:, :3],
        size=np.clip(blobs[:, 3] * 2, 2, None),
        name='Lysosomes (cell vs outside)'
    )
    if loc.size > 0:
        # Each styling step is isolated, so one failure (e.g. a renamed napari
        # attribute) does not silently drop the colours, table or labels.
        try:
            pts.face_color = colors
        except Exception as e:
            print("Could not colour lysosome points:", e)
        try:
            # napari >= 0.5 renamed the point outline from edge_* to border_*.
            outline = "border" if hasattr(pts, "border_color") else "edge"
            setattr(pts, f"{outline}_color", 'black')
            setattr(pts, f"{outline}_width", 0.3)
        except Exception as e:
            print("Could not style lysosome point outlines:", e)
        try:
            pts.properties = {
                'lys_id': df['id'].to_numpy(),
                'where':  df['location_ch2'].to_numpy(),
                'cell':   df['cell_id_ch2'].to_numpy(),
                'diameter_um': df['diameter_um'].to_numpy(),
                'volume_um3': df['volume_um3'].to_numpy()
            }
            pts.text = {'text': 'ID:{lys_id}  C:{cell}', 'size': 10, 'color': 'yellow', 'anchor': 'upper left'}
        except Exception as e:
            print("Could not attach lysosome labels:", e)

# Title + stamps
try:
    viewer.title = "Neuron segmentation"
except Exception:
    pass

# Territory stamps: up to MAX_POINTS_PER_CELL randomly sampled voxels per cell
# (seeded RNG, so reproducible), each labelled with the cell id as text.
# Text is yellow for cells with >= 1 lysosome and gray otherwise. The points
# themselves are fully transparent.
try:
    if isinstance(cell_seg, np.ndarray) and cell_seg.max() > 0 and isinstance(body_lab, np.ndarray) and body_lab.max() > 0:
        STAMP_TERRITORIES = True
        MAX_POINTS_PER_CELL = 60
        if STAMP_TERRITORIES:
            from scipy.ndimage import find_objects

            rng = np.random.default_rng(42)
            coords_all, texts_all, colors_all = [], [], []
            col_yes = np.array([1.0, 1.0, 0.0, 0.9])    # yellow
            col_no  = np.array([0.8, 0.8, 0.8, 0.85])   # gray

            # find_objects gives one bounding box per label (index i -> label
            # i + 1, None when absent). Each cell is therefore searched only
            # inside its box instead of scanning the whole volume once per cell.
            # Voxels come out in the same C order as a full-volume np.where, so
            # the RNG draws and sampled voxels are unchanged.
            for cid, box in enumerate(find_objects(cell_seg), start=1):
                if box is None:
                    continue
                zz, yy, xx = np.nonzero(cell_seg[box] == cid)
                if zz.size == 0:
                    continue
                zz = zz + box[0].start
                yy = yy + box[1].start
                xx = xx + box[2].start
                k = min(MAX_POINTS_PER_CELL, zz.size)
                idx = rng.choice(zz.size, size=k, replace=False)
                sample = np.stack([zz[idx], yy[idx], xx[idx]], axis=1)

                coords_all.append(sample)
                texts_all.extend([f"{cid}"] * k)
                colors_all.append(np.tile(col_yes if cid in ids_with_lyso else col_no, (k, 1)))

            if coords_all:
                coords_all = np.concatenate(coords_all, axis=0)
                colors_all = np.concatenate(colors_all, axis=0)
                # napari >= 0.5 names the point outline border_*; older versions edge_*.
                outline = "border" if hasattr(napari.layers.Points, "border_color") else "edge"
                terr = viewer.add_points(
                    coords_all.astype(float),
                    name="Cell ID territory stamps",
                    size=0.1,
                    face_color=[0, 0, 0, 0],
                    **{f"{outline}_color": [0, 0, 0, 0], f"{outline}_width": 0},
                )
                terr.text = {"text": texts_all, "size": 10, "color": colors_all, "anchor": "center"}
                terr.blending = "translucent_no_depth"
except Exception as e:
    print("Label overlay error:", e)

# ==========================================
# 7) Quick fused 2D video (optional)
# ==========================================
# One frame per Z slice, built from:
# - the smoothed Ch2 volume (ch2 from section 4) as grayscale;
# - cell_mask blended into the green channel;
# - a circle for each lysosome centred in the slice: (255, 255, 0) if its
#   centre voxel is in cell_mask, white otherwise.
# NOTE(review): frames are built with OpenCV (BGR) and passed to imageio
# without a BGR->RGB conversion, unlike section 11. Circle colours may
# therefore differ between the two videos; confirm which is intended.
img_norm_2 = (ch2 * 255).astype(np.uint8)
frames_fused = []
Z = img_norm_2.shape[0]
for z in range(Z):
    base = cv2.cvtColor(img_norm_2[z], cv2.COLOR_GRAY2BGR)
    cell = (cell_mask[z].astype(np.uint8) * 255)

    overlay = base.copy()
    overlay[..., 1] = np.maximum(overlay[..., 1], cell)  # green for cell
    overlay = cv2.addWeighted(base, 1.0, overlay, 0.35, 0.0)

    if len(blobs) > 0:
        z_blobs = blobs[np.abs(blobs[:, 0] - z) < 0.5]
        for b in z_blobs:
            y, x = int(round(b[1])), int(round(b[2]))
            r = max(2, int(round(b[3])))
            if 0 <= y < cell_mask.shape[1] and 0 <= x < cell_mask.shape[2] and cell_mask[z, y, x]:
                color = (255, 255, 0)  # cell
            else:
                color = (255, 255, 255)  # white
            cv2.circle(overlay, (x, y), r, color, 2)

    frames_fused.append(overlay)

try:
    imageio.mimsave('ch2_fused_cell.mp4', frames_fused, fps=8, format='FFMPEG')
    print("Saved: ch2_fused_cell.mp4")
except Exception as e:
    # A missing or broken FFMPEG backend may raise exceptions other than
    # TypeError. Fall back to GIF rather than aborting before the viewer starts.
    print("MP4 export failed; falling back to GIF. Error:", e)
    try:
        imageio.mimsave('ch2_fused_cell.gif', frames_fused, fps=8)
        print("Saved: ch2_fused_cell.gif")
    except Exception as e2:
        print("GIF export failed:", e2)

# ==========================================
# 8) Save 3D screenshots (XY, YZ, XZ)
# ==========================================
def _save_3d_view(viewer, path, angles, elevation, azimuth, view_name):
    """Switch *viewer* to 3-D, orient the camera and save a canvas screenshot.

    ``camera.angles`` is tried first. If that assignment raises,
    ``camera.elevation``/``camera.azimuth`` are tried instead (errors ignored).
    The image is written with imageio, falling back to OpenCV (RGBA -> BGRA).
    """
    viewer.dims.ndisplay = 3
    try:
        viewer.camera.angles = angles
    except Exception:
        try:
            viewer.camera.elevation = elevation
            viewer.camera.azimuth = azimuth
        except Exception:
            pass
    shot = viewer.screenshot(canvas_only=True)
    try:
        imageio.imwrite(path, shot)
    except Exception:
        cv2.imwrite(path, cv2.cvtColor(shot, cv2.COLOR_RGBA2BGRA))
    print(f"Saved 3D {view_name} screenshot: {path}")


def save_xy_3d_screenshot(viewer, path='ch2_segmentation_XY_3d.png'):
    """Save the 3-D view with camera angles (90, 0, 0) to *path*."""
    _save_3d_view(viewer, path, angles=(90, 0, 0), elevation=90, azimuth=0, view_name="XY")


def save_yz_3d_screenshot(viewer, path='ch2_segmentation_YZ_3d.png'):
    """Save the 3-D view with camera angles (0, 90, 0) to *path*."""
    _save_3d_view(viewer, path, angles=(0, 90, 0), elevation=0, azimuth=90, view_name="YZ")


def save_xz_3d_screenshot(viewer, path='ch2_segmentation_XZ_3d.png'):
    """Save the 3-D view with camera angles (0, 0, 0) to *path*."""
    _save_3d_view(viewer, path, angles=(0, 0, 0), elevation=0, azimuth=0, view_name="XZ")

# Render the final 3-D scene under a new title and capture the three views.
viewer.dims.ndisplay = 3
try:
    viewer.title = "CELL SEGMENTATION WITH LYSOSOMES"
except Exception:
    pass

for save_view, out_png in (
    (save_xy_3d_screenshot, 'cells_segmentation_lysosomes_XY_3d.png'),
    (save_yz_3d_screenshot, 'cells_segmentation_lysosomes_YZ_3d.png'),
    (save_xz_3d_screenshot, 'cells_segmentation_lysosomes_XZ_3d.png'),
):
    save_view(viewer, path=out_png)

# ==========================================
# 10) EXTRA: Export Original stack (Ch1) as MP4
# ==========================================
# Min-max normalise the raw Ch1 stack to uint8 and write one grayscale frame
# per Z slice. Any failure (e.g. no FFMPEG backend) is reported, not raised.
try:
    raw_stack = np.array(image, dtype=np.float32)
    # Use np.ptp, not the ndarray.ptp method (removed in NumPy 2.0). The 1e-8
    # keeps a constant stack from dividing by zero.
    raw_norm = (255 * (raw_stack - raw_stack.min()) / (np.ptp(raw_stack) + 1e-8)).astype(np.uint8)

    frames_raw = [cv2.cvtColor(frame_gray, cv2.COLOR_GRAY2BGR) for frame_gray in raw_norm]

    mp4_name = "original_raw_ch1.mp4"
    imageio.mimsave(mp4_name, frames_raw, fps=8, format='FFMPEG')
    print(f"Saved: {mp4_name}")
except Exception as e:
    print("MP4 export of original stack failed:", e)

# ==========================================
# 11) EXTRA: Side-by-side (RAW | SEGMENTED) MP4
# ==========================================
# Each frame has three parts:
# - left: the raw Ch1 slice;
# - a 4-px dark divider;
# - right: the Ch2 slice with cell_mask blended in green and lysosome circles.
# Frames are built in OpenCV BGR order and converted to RGB for imageio. If the
# MP4 writer fails, the same frames are written as a GIF instead.
try:
    def to_uint8_grayscale(vol):
        """Min-max scale *vol* to uint8 (all zeros for a constant volume)."""
        vol = vol.astype(np.float32)
        vmin, vmax = float(vol.min()), float(vol.max())
        if vmax <= vmin:
            return (np.zeros_like(vol, dtype=np.uint8))
        return (255.0 * (vol - vmin) / (vmax - vmin)).astype(np.uint8)

    def render_side_by_side(z):
        """Return side-by-side frame *z* as an RGB uint8 image."""
        left_bgr = cv2.cvtColor(raw_stack_u8[z], cv2.COLOR_GRAY2BGR)
        base_bgr = cv2.cvtColor(seg_base_u8[z], cv2.COLOR_GRAY2BGR)

        if 'cell_mask' in globals():
            cell = (cell_mask[z].astype(np.uint8) * 255)
            overlay = base_bgr.copy()
            overlay[..., 1] = np.maximum(overlay[..., 1], cell)
            right_bgr = cv2.addWeighted(base_bgr, 1.0, overlay, 0.35, 0.0)
        else:
            right_bgr = base_bgr

        if 'blobs' in globals() and len(blobs) > 0:
            for b in blobs[np.abs(blobs[:, 0] - z) < 0.5]:
                y, x = int(round(b[1])), int(round(b[2]))
                r = max(2, int(round(b[3])))
                if 0 <= y < H and 0 <= x < W:
                    color = (255, 255, 0) if ('cell_mask' in globals() and cell_mask[z, y, x]) else (255, 255, 255)
                    cv2.circle(right_bgr, (x, y), r, color, 2)

        if left_bgr.shape != right_bgr.shape:
            right_bgr = cv2.resize(right_bgr, (left_bgr.shape[1], left_bgr.shape[0]),
                                   interpolation=cv2.INTER_NEAREST)
        divider = np.full((left_bgr.shape[0], 4, 3), 32, dtype=np.uint8)
        frame = cv2.hconcat([left_bgr, divider, right_bgr])
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    raw_stack_u8 = to_uint8_grayscale(image)      # left panel (raw Ch1)
    seg_base_u8  = to_uint8_grayscale(image_2)    # right base (Ch2)

    Z, H, W = raw_stack_u8.shape
    fps = 8

    out_name = "raw_vs_segmented_side_by_side.mp4"
    # The context manager closes the writer even if a frame fails part-way,
    # so no FFMPEG handle is left open when the GIF fallback runs.
    with imageio.get_writer(out_name, fps=fps, format="FFMPEG") as writer:
        for z in range(Z):
            writer.append_data(render_side_by_side(z))
    print(f"Saved: {out_name}")

except Exception as e:
    try:
        print("FFMPEG writer failed; attempting GIF fallback. Error:", e)
        frames = [render_side_by_side(z) for z in range(raw_stack_u8.shape[0])]
        imageio.mimsave("raw_vs_segmented_side_by_side.gif", frames, fps=8)
        print("Saved: raw_vs_segmented_side_by_side.gif")
    except Exception as e2:
        print("Side-by-side export failed:", e2)

# ==========================================
# 12) Run viewer
# ==========================================
# Hand control to napari so the viewer window stays open. This presumably
# blocks until the window is closed when run as a script; confirm the
# behaviour inside a notebook.
napari.run()