# Filtered Matching + Publish (confocal ↔ 2P)

Self-contained QC notebook to: (1) gather helpers in one cell, (2) run matching with distance + overlap gating, optional 1–1 on the 2P side, (3) compute QC metrics (IoU, splits/merges), and (4) publish filtered outputs under legacy variable names so your plotting cells pick them up.

Usage:
- Adjust paths and thresholds in the Config cell.
- Run cells top→bottom; plotting cells can live below and use the published variables (e.g., `pairs`, `match_df`, `summary_stats`).
- Inputs are read from disk; no dependency on the original notebook's kernel state.


In [None]:
# --- shared functions (auto) ---
from tools.paths import *
from tools.io_images import *
from tools.hcr_channels import *
from tools.image_ops import *
from tools.labels import *
from tools.transforms import *
from tools.qc_plots import *


In [None]:
# Default QC output directory.
# Every artifact this notebook writes (warp cache, evaluation exports) is
# placed under this folder. NOTE: the path is machine-specific; edit it when
# running on another workstation.
from pathlib import Path
QC_OUTPUT_DIR = Path('/Users/ddharmap/dataProcessing/2p_HCR/analysis/qc/output')
# Create the folder (and missing parents) eagerly; a no-op when it already exists.
QC_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print('QC_OUTPUT_DIR =', QC_OUTPUT_DIR)


In [None]:

# --- Config: file paths, spacings, transforms, and thresholds ---
# Several settings below use `globals().get(NAME, default)`: a value that is
# already defined in the kernel wins over the default written here.
# NOTE: flags wrapped in `bool(...)` treat any non-empty string (even 'False')
# as True, so override them with real booleans.
from pathlib import Path

# Try to auto-detect confocal labels in 2P grid from common filenames in CWD.
# The first candidate that exists is used; None when no candidate is present.
CANDIDATE_CONF_LABELS = [
    'confocal_labels_in_2P_space_labels_uint16.tif',
    'stack_confocal_labels_in_2P_space_labels_uint16.tif'
]
CONF_LABELS_2P_PATH = next((str(Path(p)) for p in CANDIDATE_CONF_LABELS if Path(p).exists()), None)

# 2P segmentation path (Cellpose *_seg.npy). Set to your dataset.
TWOP_SEG_NPY = globals().get('TWOP_SEG_NPY', '/Users/ddharmap/dataProcessing/2p_HCR/data/L427_f02/L427_f02_anatomy_2P_cort_seg_anis1.npy')

# Optional: original-space confocal label TIFF and intensities (for warping)
# If set and ANTsPy is available, the notebook can warp the confocal mask to 2P and cache results.
CONF_LABEL_TIFF = "/Volumes/jlarsch/default/D2c/07_Data/Danin/Cellpose/trainingCort/hcr_test/L427_f02_round1_channel2_cort_gauss_cp_masks.tif"
CONF_NRRD_PATH = "/Volumes/jlarsch/default/D2c/07_Data/Matilde/Microscopy/L427_f02/02_reg/00_preprocessing/r1/L427_f02_round1_channel2_cort.nrrd"
TWOP_NRRD_PATH = "/Volumes/jlarsch/default/D2c/07_Data/Matilde/Microscopy/L427_f02/02_reg/00_preprocessing/2p_anatomy/L427_f02_anatomy_2P_cort.nrrd"
WARP_PATH = "/Volumes/jlarsch/default/D2c/07_Data/Matilde/Microscopy/L427_f02/02_reg/01_r1-2p/transMatrices/L427_f02_round1_GCaMP_to_ref1Warp.nii.gz"
AFFINE_PATH = "/Volumes/jlarsch/default/D2c/07_Data/Matilde/Microscopy/L427_f02/02_reg/01_r1-2p/transMatrices/L427_f02_round1_GCaMP_to_ref0GenericAffine.mat"
# Whether the affine entry of the transform list is applied inverted.
INVERT_AFFINE   = bool(globals().get('INVERT_AFFINE', False))

# Warp caching: set FORCE_RECOMPUTE_WARP to ignore an existing cached warp.
# WARP_CACHE_BASENAME is a filename prefix; the warp helpers append suffixes to it.
FORCE_RECOMPUTE_WARP = bool(globals().get('FORCE_RECOMPUTE_WARP', False))
WARP_CACHE_BASENAME = QC_OUTPUT_DIR / 'confocal_labels_in_2P_space'

# Voxel spacings (µm) for the 2P and the confocal volumes, keyed by axis (dz, dy, dx).
VOX_2P = {"dz": 2.0, "dy": 0.6506220, "dx": 0.6506220}
VOX_CONF = {"dz": 1.0, "dy": 0.20756645602494875, "dx": 0.20756645602494875}
# Flip this to True to read spacings from the NRRDs via ANTs (slow on network paths).
USE_IMAGE_SPACING = False  # keeps notebook fast when False

# Matching + gating
MATCH_METHOD = globals().get('MATCH_METHOD', 'nn')     # 'nn' or 'hungarian'
# Centroid pairs farther apart than this (µm) are rejected by the distance gate.
MAX_DISTANCE_UM = float(globals().get('MAX_DISTANCE_UM', 10))
# When True, a pair must also share at least MIN_OVERLAP_VOXELS voxels.
REQUIRE_OVERLAP = bool(globals().get('REQUIRE_OVERLAP', True))
MIN_OVERLAP_VOXELS = int(globals().get('MIN_OVERLAP_VOXELS', 1))
# How to keep at most one accepted confocal partner per 2P label.
DEDUP_BY_TWOP = globals().get('DEDUP_BY_TWOP', 'max_overlap')  # {'none','closest','max_overlap'}

# What downstream plots should consume
PLOT_USE_FINAL_1TO1 = bool(globals().get('PLOT_USE_FINAL_1TO1', True))

# Echo the resolved settings so the run log records what was used.
print('CONF_LABELS_2P_PATH =', CONF_LABELS_2P_PATH)
print('TWOP_SEG_NPY         =', TWOP_SEG_NPY)
print('FORCE_RECOMPUTE_WARP =', FORCE_RECOMPUTE_WARP)
print('WARP_CACHE_BASENAME  =', WARP_CACHE_BASENAME)


In [None]:
# --- Evaluation datasets (pre-warped confocal in 2P space) ---
# Only external datasets here; the current ANTs run is added automatically from memory.
# Each entry maps a short key to:
#   'name'                – label shown in summaries and plots
#   'conf_labels_2p_path' – label volume already resampled onto the 2P grid
# Entries whose file is missing are skipped (with a warning) when the cache is built.
EVAL_DATASETS = {
    # Example placeholders — update to your files
    'bigwarp':  {'name': 'BigWarp',  'conf_labels_2p_path': '/Users/ddharmap/dataProcessing/cellpose/trainingCort/manualBW/L427_f02_round1_channel2_cort_gauss_cp_masks_BW_in_2P.tif'},
    'baseline': {'name': 'Baseline', 'conf_labels_2p_path': '/Users/ddharmap/dataProcessing/cellpose/trainingCort/manualBW/L427_f02_round1_channel2_cort_gauss_cp_masks_untransformed_in_2P.tif'},
}


In [None]:
def _detect_spacing_from_nrrd(path):
    """Read the voxel spacing recorded in an image file's header.

    Only the header is parsed (``ants.image_header_info``), so the voxel data
    is never loaded; this matters for large volumes on network shares.

    Args:
        path: Image file path (``str`` or ``Path``); may be ``None``.

    Returns:
        ``{"dx", "dy", "dz"}`` for 3-D images, ``{"dx", "dy"}`` for 2-D images,
        or ``None`` when ANTsPy cannot be imported, the path is ``None`` or
        missing, or the image has another dimensionality.
    """
    try:
        import ants  # type: ignore
    except Exception:
        # ANTsPy is optional; any import-time failure simply disables detection.
        return None
    from pathlib import Path
    if path is None or not Path(path).exists():
        return None
    header = ants.image_header_info(str(path))
    sp = tuple(float(s) for s in header['spacing'])  # (dx, dy, dz) or (dx, dy)
    if len(sp) == 3:
        return {"dx": sp[0], "dy": sp[1], "dz": sp[2]}
    if len(sp) == 2:
        return {"dx": sp[0], "dy": sp[1]}
    return None

def _merge_spacing(current: dict, detected: dict | None) -> tuple[dict, bool]:
    if detected is None:
        return current, False
    new = dict(current)
    for k in ("dz", "dy", "dx"):
        if k in detected:
            new[k] = detected[k]
    return new, True

# Report the spacings in use; optionally replace the configured values with
# those stored in the NRRD headers.
if not USE_IMAGE_SPACING:
    print("Using configured spacings (set USE_IMAGE_SPACING=True to detect):")
    print("VOX_2P  =", VOX_2P)
    print("VOX_CONF=", VOX_CONF)
else:
    # 2P volume first, then the confocal volume.
    vox2p_det = _detect_spacing_from_nrrd(TWOP_NRRD_PATH)
    VOX_2P, ch2p = _merge_spacing(VOX_2P, vox2p_det)

    voxconf_det = _detect_spacing_from_nrrd(CONF_NRRD_PATH)
    VOX_CONF, chconf = _merge_spacing(VOX_CONF, voxconf_det)

    # The suffix tells the reader where each value came from.
    _origin_2p = "(detected from NRRD)" if ch2p else "(kept config)"
    _origin_conf = "(detected from NRRD)" if chconf else "(kept config)"
    print("VOX_2P  =", VOX_2P, _origin_2p)
    print("VOX_CONF=", VOX_CONF, _origin_conf)

In [None]:
# --- Helpers: imports, loaders, matching, overlap, QC ---
import json, math
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from scipy.optimize import linear_sum_assignment
from skimage.measure import regionprops_table
from IPython.display import display, HTML
from pathlib import Path
import tifffile as tiff

def load_labels_any(path: str) -> np.ndarray:
    """Load a label volume from a ``.npy``, ``.npz``, ``.tif`` or ``.tiff`` file.

    Args:
        path: File location (``str`` or ``Path``). The extension is matched
            case-insensitively.

    Returns:
        The label array. For ``.npz`` archives the first of the keys
        ``masks``/``labels``/``arr_0`` that exists is returned; for ``.npy``
        files holding a pickled dict (as some Cellpose ``*_seg.npy`` files
        do) the ``masks`` or ``labels`` entry is returned.

    Raises:
        AssertionError: If ``path`` is ``None`` or does not exist.
        RuntimeError: If the extension or the file content is not supported.
    """
    assert path is not None and Path(path).exists(), f'Label file not found: {path}'
    path = str(path)  # accept Path objects as well as plain strings
    lowered = path.lower()
    if lowered.endswith(('.npy', '.npz')):
        # NOTE(security): allow_pickle=True executes pickled payloads; only
        # load files from trusted sources.
        obj = np.load(path, allow_pickle=True)
        if isinstance(obj, np.lib.npyio.NpzFile):
            # Close the archive handle once the array has been materialised.
            with obj:
                for k in ('masks', 'labels', 'arr_0'):
                    if k in obj:
                        return np.asarray(obj[k])
            raise RuntimeError(f'Unsupported npz structure in {path}')
        if isinstance(obj, np.ndarray):
            # Some Cellpose *_seg.npy are dicts stored in a 0-d object array; handle both
            if obj.dtype == object and obj.shape == () and isinstance(obj.item(), dict):
                d = obj.item()
                for k in ('masks', 'labels'):
                    if k in d:
                        return np.asarray(d[k])
                raise RuntimeError('Dict npy has no masks/labels key')
            return obj
        # Fallback if np.load returns a bare Python object (rare)
        if isinstance(obj, dict):
            for k in ('masks', 'labels'):
                if k in obj:
                    return np.asarray(obj[k])
        raise RuntimeError(f'Unsupported npy content in {path}')
    if lowered.endswith(('.tif', '.tiff')):
        return tiff.imread(path)
    raise RuntimeError(f'Unsupported label format: {path}')

def compute_centroids(mask: np.ndarray) -> pd.DataFrame:
    """Tabulate one centroid per label in ``mask``.

    Args:
        mask: Integer label image; assumed to be a 3-D (Z, Y, X) volume so
            that ``centroid-0/1/2`` map to ``z/y/x`` — TODO confirm for other
            dimensionalities.

    Returns:
        DataFrame with columns ``label``, ``z``, ``y``, ``x`` in voxel index
        units; rows for label 0 are excluded and the index is renumbered.
    """
    table = pd.DataFrame(regionprops_table(mask, properties=('label', 'centroid')))
    # regionprops_table names centroid columns by axis number; give them axis letters.
    axis_names = {'centroid-0': 'z', 'centroid-1': 'y', 'centroid-2': 'x'}
    table = table.rename(columns=axis_names)
    is_foreground = table['label'] != 0
    return table[is_foreground].reset_index(drop=True)

def idx_to_um(df: pd.DataFrame, vox: dict) -> np.ndarray:
    """Scale index-space coordinates by the per-axis voxel size.

    Args:
        df: Table with ``z``, ``y`` and ``x`` columns (voxel indices).
        vox: Voxel size per axis under the keys ``dz``, ``dy`` and ``dx``.

    Returns:
        ``(N, 3)`` array whose columns are the scaled z, y and x coordinates.
    """
    scaled_axes = [df[axis].to_numpy() * vox['d' + axis] for axis in ('z', 'y', 'x')]
    return np.column_stack(scaled_axes)

def nearest_neighbor_match(P_src_um: np.ndarray, P_dst_um: np.ndarray):
    """Find, for every source point, the closest destination point.

    Args:
        P_src_um: ``(N, D)`` query coordinates.
        P_dst_um: ``(M, D)`` coordinates the KD-tree is built from.

    Returns:
        ``(distances, indices)``: for each source point, the distance to its
        nearest destination point and that point's row index in ``P_dst_um``.
    """
    distance, index = cKDTree(P_dst_um).query(P_src_um, k=1)
    return distance, index

def hungarian_match(P_src_um: np.ndarray, P_dst_um: np.ndarray, max_cost=np.inf):
    """Globally optimal one-to-one assignment between two point sets.

    A dense pairwise distance matrix is built, so memory grows with
    ``len(P_src_um) * len(P_dst_um)``.

    Args:
        P_src_um: ``(N, D)`` source coordinates.
        P_dst_um: ``(M, D)`` destination coordinates.
        max_cost: When finite, costs above this value are clipped to it for
            the assignment only, so that very distant pairs do not dominate
            the optimisation.

    Returns:
        ``(dists, col_ind, row_ind)``: the true (unclipped) distance of each
        assigned pair, the destination row indices and the source row indices.
    """
    from scipy.spatial.distance import cdist
    D = cdist(P_src_um, P_dst_um)
    # Clip a copy: the reported distances must stay the real ones, otherwise a
    # far-away pair would be reported at exactly `max_cost` and could slip
    # through a downstream `distance <= gate` check.
    C = np.minimum(D, max_cost) if np.isfinite(max_cost) else D
    row_ind, col_ind = linear_sum_assignment(C)
    dists = D[row_ind, col_ind]
    return dists, col_ind, row_ind

def compute_label_overlap(conf_labels_2p: np.ndarray, twop_labels: np.ndarray, min_overlap_voxels=1) -> pd.DataFrame:
    """Count shared voxels for every (confocal label, 2P label) pair.

    Args:
        conf_labels_2p: Confocal label volume on the 2P grid.
        twop_labels: 2P label volume of identical shape.
        min_overlap_voxels: Pairs with fewer shared voxels are dropped; values
            of 1 or less keep every overlapping pair.

    Returns:
        DataFrame with columns ``conf_label``, ``twoP_label`` and
        ``overlap_voxels``; empty when no voxel is foreground in both volumes.

    Raises:
        AssertionError: If the two volumes differ in shape.
    """
    assert conf_labels_2p.shape == twop_labels.shape, 'Label volumes must share shape'
    conf_flat = conf_labels_2p.ravel()
    twop_flat = twop_labels.ravel()
    # Only voxels that are foreground in both volumes can contribute to a pair.
    both_foreground = (conf_flat != 0) & (twop_flat != 0)
    if not both_foreground.any():
        return pd.DataFrame(columns=['conf_label','twoP_label','overlap_voxels'], dtype=int)
    conf_ids = conf_flat[both_foreground].astype(np.int64, copy=False)
    twop_ids = twop_flat[both_foreground].astype(np.int64, copy=False)
    # Pack each pair into one int64 (confocal id in the high 32 bits, 2P id in
    # the low 32 bits) so a single np.unique call counts the pairs.
    # Presumably label ids are non-negative and well below 2**31 — verify if
    # larger ids ever occur, since the high half is a signed shift.
    packed = (conf_ids << 32) | twop_ids
    pair_keys, voxel_counts = np.unique(packed, return_counts=True)
    low_32_bits = (1 << 32) - 1
    table = pd.DataFrame({
        'conf_label': (pair_keys >> 32).astype(np.int64),
        'twoP_label': (pair_keys & low_32_bits).astype(np.int64),
        'overlap_voxels': voxel_counts.astype(int),
    })
    if min_overlap_voxels > 1:
        enough = table['overlap_voxels'] >= int(min_overlap_voxels)
        table = table[enough].reset_index(drop=True)
    return table

def summarize_distances(dists: np.ndarray, valid_mask: np.ndarray) -> dict:
    """Condense match distances and gate decisions into summary statistics.

    Args:
        dists: Distance of every candidate pair.
        valid_mask: Boolean flag per pair, True where the pair passed the gate.

    Returns:
        Dict with ``n``, ``mean``, ``median``, ``p90``, ``max`` of the
        distances plus ``within_gate`` (count) and ``within_gate_frac``
        (fraction) of the mask. All entries are zero when ``dists`` is empty.
    """
    distances = np.asarray(dists)
    accepted = np.asarray(valid_mask, dtype=bool)
    fields = ('n', 'mean', 'median', 'p90', 'max', 'within_gate', 'within_gate_frac')
    if distances.size == 0:
        values = (0, 0.0, 0.0, 0.0, 0.0, 0, 0.0)
    else:
        values = (
            int(distances.size),
            float(np.mean(distances)),
            float(np.median(distances)),
            float(np.percentile(distances, 90)),
            float(np.max(distances)),
            int(accepted.sum()),
            float(accepted.mean()),
        )
    return dict(zip(fields, values))

def display_scrollable(df: pd.DataFrame, max_h=600):
    """Render ``df`` as an HTML table inside a vertically scrollable box.

    Args:
        df: Table to show; the index is not rendered.
        max_h: Maximum height of the box in pixels.
    """
    # Inject an inline style into the opening <table tag produced by pandas.
    styled_open_tag = f'<table style="display:block; max-height:{max_h}px; overflow-y:auto; width:100%;"'
    markup = df.to_html(index=False)
    display(HTML(markup.replace('<table', styled_open_tag)))

# --- ANTs (optional) and warp helpers ---
# ANTsPy is an optional dependency: HAVE_ANTSPY records whether the import
# worked so the warp and point-transform code can check it before calling `ants`.
try:
    import ants  # type: ignore
except Exception as e:  # pragma: no cover
    HAVE_ANTSPY = False
    print('ANTsPy not available; warp/points transform cells will be disabled. ', e)
else:
    HAVE_ANTSPY = True

def fs_info(path: str) -> dict:
    """Describe a file for logging and cache fingerprinting.

    Args:
        path: File location; ``None`` or an empty string counts as missing.

    Returns:
        Dict with ``exists`` (bool), ``size_bytes`` (int), ``size_MB``
        (``size_bytes / 1e6``) and ``modified`` (local time formatted as
        ``YYYY-MM-DD HH:MM:SS``). The last three are ``None`` when the file
        does not exist or cannot be stat'ed.
    """
    import os, time
    missing = {'exists': False, 'size_bytes': None, 'size_MB': None, 'modified': None}
    if not path:
        return missing
    try:
        # One stat call yields both size and mtime, instead of three separate
        # filesystem lookups.
        st = os.stat(path)
    except (OSError, ValueError):
        return missing
    mtime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(st.st_mtime))
    # An empty file reports 0.0 MB rather than None: it exists, it is just empty.
    return {'exists': True, 'size_bytes': st.st_size, 'size_MB': st.st_size / 1e6, 'modified': mtime}

def _ants_clone_geometry(dst_img, like_img):
    """Copy spacing, origin and direction from ``like_img`` onto ``dst_img``.

    ``dst_img`` is modified in place through its ``set_*`` methods and is
    also returned.
    """
    for field in ('spacing', 'origin', 'direction'):
        setter = getattr(dst_img, 'set_' + field)
        setter(getattr(like_img, field))
    return dst_img

def warp_label_tiff_with_ants(tiff_path, mov_img_int, fix_img_int, warp_path, affine_path, *, vox_moving=None, save_basename=None, invert_affine=None):
    """Resample a label TIFF onto the fixed image grid and save the result.

    The TIFF is read as (Z, Y, X), transposed to (X, Y, Z), given the
    geometry of ``mov_img_int`` and resampled onto ``fix_img_int`` with the
    transform list ``[warp_path, affine_path]`` using nearest-neighbour
    interpolation so label ids are not blended.

    Args:
        tiff_path: 3-D label TIFF in the moving image's space.
        mov_img_int: ANTs image whose spacing/origin/direction and shape
            describe the moving space.
        fix_img_int: ANTs image defining the output grid.
        warp_path: Path of the first (warp) transform.
        affine_path: Path of the second (affine) transform.
        vox_moving: Unused; kept so existing callers keep working.
        save_basename: Output filename prefix (``str`` or ``Path``). Defaults
            to the TIFF path without its suffix plus ``_in_fixed``.
        invert_affine: Whether to invert the affine entry. ``None`` (default)
            falls back to the notebook-level ``INVERT_AFFINE`` setting.

    Returns:
        ``(warped_zyx, out)``: the warped int32 label volume in (Z, Y, X)
        order and a dict of written files (``tif`` when every id fits in
        uint16, otherwise ``npy`` and ``nii``).

    Raises:
        ValueError: If the TIFF is not 3-D.
        RuntimeError: If the label and moving-image shapes differ.
    """
    import numpy as np
    import tifffile as tiff
    from pathlib import Path
    if invert_affine is None:
        # Backward compatible: previous versions always read the global flag.
        invert_affine = INVERT_AFFINE
    tiff_path = Path(tiff_path)
    lab_zyx = tiff.imread(str(tiff_path))
    if lab_zyx.ndim != 3:
        raise ValueError(f'Expected 3D TIFF (Z,Y,X); got shape {lab_zyx.shape}')
    lab_xyz = np.transpose(lab_zyx, (2, 1, 0)).astype(np.int32, copy=False)
    mov_label_img = ants.from_numpy(lab_xyz)
    _ants_clone_geometry(mov_label_img, mov_img_int)
    if mov_label_img.shape != mov_img_int.shape:
        raise RuntimeError(f'Label XYZ shape {mov_label_img.shape} != moving intensity XYZ shape {mov_img_int.shape}.')
    transformlist = [str(warp_path), str(affine_path)]
    whichtoinvert = [False, bool(invert_affine)]
    warped_xyz = ants.apply_transforms(
        fixed=fix_img_int,
        moving=mov_label_img,
        transformlist=transformlist,
        whichtoinvert=whichtoinvert,
        interpolator='nearestNeighbor',
    ).numpy().astype(np.int32, copy=False)
    warped_zyx = np.transpose(warped_xyz, (2, 1, 0))
    if save_basename is None:
        save_basename = str(tiff_path.with_suffix('')) + '_in_fixed'
    save_basename = str(save_basename)  # tolerate Path prefixes in the concatenations below
    out = {}
    max_id = int(warped_zyx.max())
    nz = int((warped_zyx > 0).sum())
    print(f'[warp_label_tiff_with_ants] warped shape(ZYX)={warped_zyx.shape} maxID={max_id} nonzero={nz}')
    out_npy = save_basename + '_labels_int32.npy'
    if max_id <= 65535:
        out_tif = save_basename + '_labels_uint16.tif'
        tiff.imwrite(out_tif, warped_zyx.astype(np.uint16))
        out['tif'] = out_tif
        # The cache loader prefers the int32 .npy over the TIFF, so a leftover
        # .npy from an earlier run would shadow the file just written.
        try:
            Path(out_npy).unlink(missing_ok=True)
        except OSError as exc:
            print(f'[warp_label_tiff_with_ants] could not remove stale {out_npy}: {exc}')
    else:
        np.save(out_npy, warped_zyx)
        out['npy'] = out_npy
        # Also export with the fixed image's geometry for tools that read NIfTI.
        warped_img = ants.from_numpy(np.transpose(warped_zyx, (2, 1, 0)))
        _ants_clone_geometry(warped_img, fix_img_int)
        out_nii = save_basename + '_labels_int32.nii.gz'
        ants.image_write(warped_img, out_nii)
        out['nii'] = out_nii
    return warped_zyx, out

def warp_cache_candidates(save_basename: str):
    """List the cache files a warp may have produced, in lookup order.

    The int32 ``.npy`` comes first (faster to load), the uint16 TIFF second.
    """
    from pathlib import Path
    suffixes = ('_labels_int32.npy', '_labels_uint16.tif')
    return [Path(f'{save_basename}{suffix}') for suffix in suffixes]

def load_cached_warp(save_basename: str):
    """Load the first cached warp found for ``save_basename``.

    Returns:
        ``(labels, path)`` for the first existing candidate — ``.npy`` files
        are opened as read-only memory maps, ``.tif`` files are read fully —
        or ``(None, None)`` when no cache file exists.
    """
    import numpy as np
    import tifffile as tiff
    existing = (c for c in warp_cache_candidates(save_basename) if c.exists())
    for candidate in existing:
        kind = candidate.suffix
        if kind == '.npy':
            return np.load(candidate, mmap_mode='r'), candidate
        if kind == '.tif':
            return tiff.imread(str(candidate)), candidate
    return None, None

def warp_metadata_path(save_basename: str) -> Path:
    """Return the path of the JSON sidecar that belongs to a warp cache prefix."""
    sidecar_name = f'{save_basename}_warp_meta.json'
    return Path(sidecar_name)

def read_warp_metadata(save_basename: str):
    """Return the parsed JSON sidecar of a warp cache, or ``None`` if absent."""
    sidecar = warp_metadata_path(save_basename)
    if sidecar.exists():
        return json.loads(sidecar.read_text())
    return None

def write_warp_metadata(save_basename: str, metadata: dict) -> None:
    """Write ``metadata`` as indented JSON to the warp cache's sidecar file."""
    serialized = json.dumps(metadata, indent=2)
    warp_metadata_path(save_basename).write_text(serialized)

def warp_inputs_fingerprint() -> dict:
    """Snapshot the warp inputs so a cached result can be tied to them.

    Returns:
        Dict with one entry per configured input path (path, existence, size
        and modification time, or ``None`` when the setting is empty), the
        ``INVERT_AFFINE`` flag and, when ANTsPy is importable, its version.
    """
    def _stamp(path: str) -> dict:
        info = fs_info(path)
        return {'path': path, **{field: info[field] for field in ('exists', 'size_bytes', 'modified')}}

    tracked_inputs = (
        ('CONF_LABEL_TIFF', CONF_LABEL_TIFF),
        ('CONF_NRRD_PATH', CONF_NRRD_PATH),
        ('TWOP_NRRD_PATH', TWOP_NRRD_PATH),
        ('WARP_PATH', WARP_PATH),
        ('AFFINE_PATH', AFFINE_PATH),
    )
    fingerprint = {name: (_stamp(value) if value else None) for name, value in tracked_inputs}
    fingerprint['INVERT_AFFINE'] = bool(INVERT_AFFINE)
    if HAVE_ANTSPY:
        fingerprint['ANTsPy_version'] = getattr(ants, '__version__', None)
    return fingerprint

def apply_transforms_to_points_um(points_um: np.ndarray, mov_spacing_um: dict, fix_spacing_um: dict, *, warp_path: str, affine_path: str, invert_affine: bool = False):
    """Push (z, y, x) points through the warp + affine transform pair.

    The points are divided by the moving spacing, handed to
    ``ants.apply_transforms_to_points`` as x/y/z columns and the result is
    multiplied by the fixed spacing.

    NOTE(review): ANTs documents point transforms in physical coordinates and
    with the opposite direction to image resampling; confirm that feeding
    index coordinates with the image-style transform list is intended.

    Args:
        points_um: ``(N, 3)`` array with columns z, y, x.
        mov_spacing_um: Moving-space voxel size (``dz``, ``dy``, ``dx``).
        fix_spacing_um: Fixed-space voxel size (``dz``, ``dy``, ``dx``).
        warp_path: Path of the first transform in the list.
        affine_path: Path of the second transform in the list.
        invert_affine: Whether the affine entry is applied inverted.

    Returns:
        ``(N, 3)`` array with columns z, y, x after the transform.

    Raises:
        RuntimeError: If ANTsPy is not available.
    """
    import numpy as np
    if not HAVE_ANTSPY:
        raise RuntimeError('ANTsPy required for apply_transforms_to_points.')
    # µm -> index in moving; ANTs expects columns named x, y, z
    moving_idx = pd.DataFrame({
        'x': points_um[:, 2] / mov_spacing_um['dx'],
        'y': points_um[:, 1] / mov_spacing_um['dy'],
        'z': points_um[:, 0] / mov_spacing_um['dz'],
    })
    fixed_idx = ants.apply_transforms_to_points(
        3,
        moving_idx,
        [str(warp_path), str(affine_path)],
        whichtoinvert=[False, bool(invert_affine)],
    )
    # index -> µm in fixed, returned in z, y, x column order
    scaled_axes = [fixed_idx[axis].to_numpy() * fix_spacing_um['d' + axis] for axis in ('z', 'y', 'x')]
    return np.column_stack(scaled_axes)


In [None]:

# --- Warp confocal labels → 2P (cache-aware) ---
# Fast path: use cached warp without reading large NRRDs
cache_fp = warp_inputs_fingerprint()
cached_labels, cached_source = load_cached_warp(WARP_CACHE_BASENAME)
use_cache = (cached_labels is not None) and (not FORCE_RECOMPUTE_WARP)
if use_cache:
    conf_labels_2p = cached_labels
    print(f'Loaded cached warp from {cached_source}')
    # Compare the fingerprint stored with the cache against the current
    # inputs. A mismatch only warns: recomputing automatically could fail
    # when the network volumes holding the inputs are not mounted.
    try:
        cached_meta = read_warp_metadata(WARP_CACHE_BASENAME)
    except (OSError, ValueError) as exc:
        cached_meta = None
        print(f'[WARN] Could not read warp metadata: {exc}')
    if cached_meta is None:
        print('[WARN] No warp metadata found next to the cache; cannot verify it matches the current inputs.')
    elif cached_meta != cache_fp:
        changed_keys = sorted(
            k for k in set(cached_meta) | set(cache_fp)
            if cached_meta.get(k) != cache_fp.get(k)
        )
        print(f'[WARN] Cached warp may be stale; inputs differ from when it was written: {changed_keys}. '
              'Set FORCE_RECOMPUTE_WARP=True to refresh.')
elif HAVE_ANTSPY and all(v is not None for v in [CONF_LABEL_TIFF, CONF_NRRD_PATH, TWOP_NRRD_PATH, WARP_PATH, AFFINE_PATH]):
    # Slow path: need to run ANTs warp
    mov_img_int = ants.image_read(CONF_NRRD_PATH)
    fix_img_int = ants.image_read(TWOP_NRRD_PATH)
    conf_labels_2p, out_paths = warp_label_tiff_with_ants(
        CONF_LABEL_TIFF, mov_img_int, fix_img_int,
        warp_path=WARP_PATH, affine_path=AFFINE_PATH,
        save_basename=str(WARP_CACHE_BASENAME),
    )
    # Record which inputs produced this cache so later runs can detect drift.
    write_warp_metadata(WARP_CACHE_BASENAME, cache_fp)
    print('Saved warp artifacts:', out_paths)
else:
    if not HAVE_ANTSPY:
        print('ANTsPy not available; set CONF_LABELS_2P_PATH to pre-warped labels to proceed.')
    else:
        print('Warp inputs incomplete; using CONF_LABELS_2P_PATH if provided.')


In [None]:
# --- Load label volumes and compute centroids ---
# Prefer conf_labels_2p from the warp cell; otherwise load from disk if path is provided.
if 'conf_labels_2p' not in globals():
    assert CONF_LABELS_2P_PATH is not None, 'Could not auto-detect confocal labels in 2P grid; set CONF_LABELS_2P_PATH.'
    # Read the volume once; it used to be loaded three times in a row.
    conf_labels_2p = load_labels_any(CONF_LABELS_2P_PATH)
masks_2p = load_labels_any(TWOP_SEG_NPY)
# Both volumes must live on the same voxel grid for overlap counting.
assert conf_labels_2p.shape == masks_2p.shape, f'Shape mismatch: {conf_labels_2p.shape} vs {masks_2p.shape}'

df_conf = compute_centroids(conf_labels_2p)
df_2p   = compute_centroids(masks_2p)

# The confocal labels are already on the 2P grid, so both sets use VOX_2P.
P_conf_in_2p_um = idx_to_um(df_conf, VOX_2P)
P_2p_um         = idx_to_um(df_2p, VOX_2P)

print(f'Loaded confocal labels (warped→2P): {conf_labels_2p.shape}, dtype={conf_labels_2p.dtype}')
print(f'Loaded 2P labels: {masks_2p.shape}, dtype={masks_2p.dtype}')
print(f'Centroids: conf={P_conf_in_2p_um.shape[0]} | 2P={P_2p_um.shape[0]}')


In [None]:
# --- Matching + distance gate + overlap + optional 1–1 by 2P ---
# Produces `matches`: one row per candidate (conf_label, twoP_label) pair with
# its centroid distance and a `within_gate` flag that is narrowed step by step
# (distance gate -> overlap requirement -> de-duplication per 2P label).
labels_conf = df_conf['label'].to_numpy()
labels_2p   = df_2p['label'].to_numpy()

if MATCH_METHOD == 'nn':
    # Every confocal centroid takes its nearest 2P centroid; several confocal
    # labels may therefore point at the same 2P label.
    dists, nn = nearest_neighbor_match(P_conf_in_2p_um, P_2p_um)
    matched_twoP_labels = labels_2p[nn]
    matched_conf_labels = labels_conf
elif MATCH_METHOD == 'hungarian':
    # One-to-one assignment; row_ind/col_ind index the confocal/2P centroid tables.
    dists, col_ind, row_ind = hungarian_match(P_conf_in_2p_um, P_2p_um, max_cost=np.inf)
    matched_conf_labels = labels_conf[row_ind]
    matched_twoP_labels = labels_2p[col_ind]
else:
    raise ValueError('MATCH_METHOD must be "nn" or "hungarian"')

# Distance gate
valid = dists <= float(MAX_DISTANCE_UM)
matches = pd.DataFrame({
    'conf_label': matched_conf_labels,
    'twoP_label': matched_twoP_labels,
    'distance_um': dists,
    'within_gate': valid
}).sort_values('distance_um', ascending=True).reset_index(drop=True)

# Resolve label arrays for overlap
conf_warped_labels = conf_labels_2p
twoP_labels = masks_2p

if REQUIRE_OVERLAP or DEDUP_BY_TWOP in {'closest','max_overlap'}:
    overlap_df = compute_label_overlap(conf_warped_labels, twoP_labels, min_overlap_voxels=int(MIN_OVERLAP_VOXELS))
    # Left merge keeps every candidate pair; pairs absent from overlap_df
    # (no overlap, or below MIN_OVERLAP_VOXELS) end up with 0 voxels.
    matches = matches.merge(overlap_df[['conf_label','twoP_label','overlap_voxels']],
                           on=['conf_label','twoP_label'], how='left')
    matches['overlap_voxels'] = matches['overlap_voxels'].fillna(0).astype(int)

    if REQUIRE_OVERLAP:
        matches['within_gate'] = matches['within_gate'] & (matches['overlap_voxels'] >= int(MIN_OVERLAP_VOXELS))

    if DEDUP_BY_TWOP in {'closest','max_overlap'}:
        m = matches['within_gate'].to_numpy()
        if m.any():
            sub = matches.loc[m].copy()
            # Order each 2P label's candidates so the preferred one comes first.
            if DEDUP_BY_TWOP == 'closest':
                sub = sub.sort_values(['twoP_label','distance_um'], ascending=[True, True])
            else:
                sub = sub.sort_values(['twoP_label','overlap_voxels','distance_um'], ascending=[True, False, True])
            # `sub` keeps the row labels of `matches`, so the losers can be
            # switched off in `matches` by index.
            keep_idx = sub.drop_duplicates(subset=['twoP_label'], keep='first').index
            drop_idx = sub.index.difference(keep_idx)
            if len(drop_idx) > 0:
                matches.loc[drop_idx, 'within_gate'] = False

matches = matches.sort_values('distance_um', ascending=True).reset_index(drop=True)

# Distance statistics cover all candidate pairs; the gate counts cover accepted ones.
summary = summarize_distances(matches['distance_um'].to_numpy(), matches['within_gate'].to_numpy())
print('Summary:', json.dumps(summary, indent=2))
display_scrollable(matches)


In [None]:
# --- Post-matching QC (IoU-based) ---
# Defines: final_pairs (high-confidence 1–1), review (needs attention), and prints a QC summary.
# Notes:
# - IOU_MIN = 0.05 means intersection is at least 5% of the union (symmetric).
# - To require side-specific coverage, enable USE_FRAC_FILTERS and set MIN_OVERLAP_FRAC_*.

# Thresholds
IOU_MIN = 0.05            # intersection-over-union >= 5% of the union
USE_FRAC_FILTERS = False  # also require side-specific fractions?
MIN_OVERLAP_FRAC_CONF = 0.0  # minimum overlap / confocal-label volume (used only with USE_FRAC_FILTERS)
MIN_OVERLAP_FRAC_TWOP = 0.0  # minimum overlap / 2P-label volume (used only with USE_FRAC_FILTERS)

import numpy as np
import pandas as pd
try:
    from IPython.display import display
except Exception:
    display = None  # if not in a notebook

# Resolve label volumes in 2P space.
# The names produced by the loading cell are preferred; the alternative
# names are accepted as a fallback.
conf_warped_labels = globals().get('conf_labels_2p', globals().get('conf_warped_labels'))
twoP_labels        = globals().get('masks_2p',      globals().get('twoP_labels'))
assert conf_warped_labels is not None and twoP_labels is not None, \
    "Need `conf_labels_2p` (or conf_warped_labels) and `masks_2p` (or twoP_labels)."
assert conf_warped_labels.shape == twoP_labels.shape, \
    f"Shape mismatch: {conf_warped_labels.shape} vs {twoP_labels.shape}"

# Fallback for compute_label_overlap if not defined earlier
if 'compute_label_overlap' not in globals():
    def compute_label_overlap(conf, twop, min_overlap_voxels=1):
        """Count shared voxels per (conf_label, twoP_label) pair.

        Returns a DataFrame with columns ``conf_label``, ``twoP_label`` and
        ``overlap_voxels``. Pairs with fewer than ``min_overlap_voxels``
        shared voxels are dropped, matching the primary helper's behaviour.
        """
        conf = np.asarray(conf); twop = np.asarray(twop)
        a = conf.ravel(); b = twop.ravel()
        # Only voxels that are foreground in both volumes form a pair.
        m = (a != 0) & (b != 0)
        if not m.any():
            return pd.DataFrame(columns=['conf_label','twoP_label','overlap_voxels'], dtype=int)
        a = a[m].astype(np.int64, copy=False); b = b[m].astype(np.int64, copy=False)
        # Pack both ids into one int64 so one np.unique call counts the pairs.
        key = (a << 32) | b
        uniq, counts = np.unique(key, return_counts=True)
        df = pd.DataFrame({
            'conf_label': (uniq >> 32).astype(np.int64),
            'twoP_label': (uniq & ((1<<32)-1)).astype(np.int64),
            'overlap_voxels': counts.astype(int)
        })
        if min_overlap_voxels > 1:
            df = df[df['overlap_voxels'] >= int(min_overlap_voxels)].reset_index(drop=True)
        return df

# The QC needs the `matches` table from the matching step, including its gate flag.
assert 'matches' in globals(), "Expected `matches` DataFrame from matching step."
assert 'within_gate' in matches.columns, "Expected `within_gate` column in matches."

# Add overlap counts if the matching step did not compute them.
if 'overlap_voxels' not in matches.columns:
    _pair_cols = ['conf_label', 'twoP_label']
    ov = compute_label_overlap(conf_warped_labels, twoP_labels, min_overlap_voxels=1)
    matches = matches.merge(ov[_pair_cols + ['overlap_voxels']], on=_pair_cols, how='left')
    # Pairs without any shared voxel are absent from `ov`; record them as 0.
    matches['overlap_voxels'] = matches['overlap_voxels'].fillna(0).astype(int)

# Per-label volumes
def label_volumes(arr):
    """Return voxel counts per label as an int Series indexed by label id.

    Label 0 is left out of the result.
    """
    ids, sizes = np.unique(arr, return_counts=True)
    volumes = pd.Series(sizes, index=ids)
    if 0 in volumes.index:
        volumes = volumes.drop(index=0)
    return volumes.astype(int)

# Voxel counts per label on each side, looked up for every candidate pair.
conf_vol_s = label_volumes(conf_warped_labels)
twoP_vol_s = label_volumes(twoP_labels)

matches['conf_vol'] = matches['conf_label'].map(conf_vol_s).fillna(0).astype(int)
matches['twoP_vol'] = matches['twoP_label'].map(twoP_vol_s).fillna(0).astype(int)

def _ratio_or_zero(numerator, denominator):
    """Element-wise numerator / denominator, 0.0 wherever the denominator is not positive."""
    return np.divide(
        numerator, denominator,
        out=np.zeros_like(denominator, dtype=float), where=(denominator > 0)
    )

# IoU and overlap fractions
# union = |conf| + |2P| - |intersection|
den = matches['conf_vol'] + matches['twoP_vol'] - matches['overlap_voxels']
matches['iou'] = _ratio_or_zero(matches['overlap_voxels'], den)
matches['overlap_frac_conf'] = _ratio_or_zero(matches['overlap_voxels'], matches['conf_vol'])
matches['overlap_frac_twoP'] = _ratio_or_zero(matches['overlap_voxels'], matches['twoP_vol'])

# Topology among accepted (within_gate = True)
acc = matches.loc[matches['within_gate']].copy()
# Accepted rows per confocal label = number of 2P partners of that label.
conf_counts = acc['conf_label'].value_counts()
# Accepted rows per 2P label = number of confocal partners of that label.
twop_counts = acc['twoP_label'].value_counts()

def _pair_type(row):
    """Classify one row of `matches` by how many partners each side has.

    Returns 'rejected' for rows outside the gate, otherwise '1-1', 'merge'
    (several confocal labels share this 2P label), 'split' (this confocal
    label is paired with several 2P labels) or 'complex' (both).
    """
    if not row['within_gate']:
        return 'rejected'
    cm = int(conf_counts.get(row['conf_label'], 0))  # 2P partners of this confocal label
    tm = int(twop_counts.get(row['twoP_label'], 0))  # confocal partners of this 2P label
    if cm == 1 and tm == 1: return '1-1'
    # The two labels below were previously swapped relative to the counts.
    if cm == 1 and tm > 1:  return 'merge'   # many conf -> one 2P
    if cm > 1 and tm == 1:  return 'split'   # one conf -> many 2P
    return 'complex'        # many-to-many

matches['pair_type'] = matches.apply(_pair_type, axis=1)

# Quality label (IoU; optional side-specific fractions)
ok_frac = True
if USE_FRAC_FILTERS:
    ok_frac = (
        (matches['overlap_frac_conf'] >= float(MIN_OVERLAP_FRAC_CONF)) &
        (matches['overlap_frac_twoP'] >= float(MIN_OVERLAP_FRAC_TWOP))
    )

# 'good'  : accepted and above the IoU (and optional fraction) thresholds
# 'iffy'  : accepted but below a threshold
# 'rejected': not accepted by the gate
_accepted = matches['within_gate']
_passes_quality = _accepted & (matches['iou'] >= float(IOU_MIN)) & ok_frac
matches['quality'] = np.where(
    _passes_quality, 'good',
    np.where(_accepted, 'iffy', 'rejected')
)

# Final outputs
_is_one_to_one = matches['pair_type'] == '1-1'
_is_good = matches['quality'] == 'good'

# High-confidence pairs: unambiguous topology and good quality.
final_pairs = matches[_is_one_to_one & _is_good].copy().sort_values(['distance_um','twoP_label'])

# Pairs for manual review: accepted but ambiguous, or 1-1 but not good.
_ambiguous = matches['pair_type'].isin(['split','merge','complex']) & matches['within_gate']
_weak_one_to_one = _is_one_to_one & (matches['quality'] != 'good')
review = matches[_ambiguous | _weak_one_to_one].sort_values(
    ['pair_type','iou','distance_um'], ascending=[True, False, True]
).copy()

# Disambiguated QC summary
# 'gate_only_pairs' re-applies the distance (and optional overlap) gate
# without the per-2P de-duplication. The fallback values mirror the Config
# cell defaults (10 µm, overlap required, 1 voxel) so the number stays
# consistent with `within_gate` even if a setting is missing from the kernel.
gate_mask = (matches['distance_um'] <= float(globals().get('MAX_DISTANCE_UM', 10)))
if bool(globals().get('REQUIRE_OVERLAP', True)):
    gate_mask = gate_mask & (matches['overlap_voxels'] >= int(globals().get('MIN_OVERLAP_VOXELS', 1)))

qc = pd.Series({
    'gate_only_pairs': int(gate_mask.sum()),
    'accepted_pairs_after_dedup': int(matches['within_gate'].sum()),
    'final_1to1_good': int(final_pairs.shape[0]),
    'splits_among_accepted': int((matches['pair_type'] == 'split').sum()),
    'merges_among_accepted': int((matches['pair_type'] == 'merge').sum()),
    'complex_among_accepted': int((matches['pair_type'] == 'complex').sum()),
}, name='QC Summary')

print(qc)
# Rich display inside a notebook, plain text otherwise.
print("\nfinal_pairs (head):")
if display:
    display(final_pairs.head(20))
else:
    print(final_pairs.head(20).to_string(index=False))
print("\nreview (head):")
if display:
    display(review.head(20))
else:
    print(review.head(20).to_string(index=False))


In [None]:
# --- Freeze evaluation pairs (final 1–1 good) ---
# The same label pairs are reused for every dataset, so only the two label
# columns are kept; the copy decouples them from later edits to final_pairs.
assert 'final_pairs' in globals(), 'Run the QC cell first to create final_pairs.'
_pair_columns = ['conf_label', 'twoP_label']
eval_pairs = final_pairs[_pair_columns].copy()
print(f'Frozen evaluation pairs: {len(eval_pairs)}')


In [None]:
# --- Precompute centroids per dataset (confocal in 2P space) ---
import numpy as np
import pandas as pd
from pathlib import Path
import tifffile as tiff

# Helper: generic label loader
def _load_labels_any(path: str):
    """Load a label volume from ``.tif``/``.tiff``, ``.npy`` or ``.npz``.

    ``.npy`` files holding a pickled dict (as some Cellpose ``*_seg.npy``
    files do) yield their ``masks`` or ``labels`` entry; ``.npz`` archives
    yield the first of ``masks``/``labels``/``arr_0`` that is present.

    Raises:
        RuntimeError: For unknown extensions, or when a dict/archive holds
            none of the expected keys.
    """
    p = str(path)
    lowered = p.lower()  # match extensions case-insensitively
    if lowered.endswith(('.tif', '.tiff')):
        return tiff.imread(p)
    if lowered.endswith('.npy'):
        # NOTE(security): allow_pickle=True executes pickled payloads; only
        # load files from trusted sources. It is needed for dict payloads.
        obj = np.load(p, allow_pickle=True)
        if isinstance(obj, np.ndarray) and obj.dtype == object and obj.shape == ():
            obj = obj.item()  # unwrap the 0-d object array around the payload
        if isinstance(obj, dict):
            for k in ('masks', 'labels'):
                if k in obj:
                    return np.asarray(obj[k])
            raise RuntimeError(f'Dict npy has no masks/labels key: {p}')
        return obj
    if lowered.endswith('.npz'):
        with np.load(p, allow_pickle=True) as obj:
            for k in ('masks','labels','arr_0'):
                if k in obj:
                    return obj[k]
        raise RuntimeError(f'Unsupported npz structure in {p}')
    raise RuntimeError(f'Unsupported label format: {p}')

# 2P lookup table: label id -> centroid coordinates
assert 'df_2p' in globals() and 'P_2p_um' in globals(), 'Need 2P centroids and coords.'
_twoP_lut = dict(zip(df_2p['label'].to_numpy(), P_2p_um))

def _eval_cache_entry(name, conf_arr, centroid_df, coords_um):
    """Bundle one dataset's label volume, centroid table, coordinates and label -> coordinate LUT."""
    return {
        'name': name,
        'conf_arr': conf_arr,
        'df_conf': centroid_df,
        'P_conf_um': coords_um,
        'conf_lut': dict(zip(centroid_df['label'].to_numpy(), coords_um)),
    }

# One entry per dataset, plus the folder that receives evaluation exports
EVAL_CACHE = {}
EVAL_EXPORT_DIR = QC_OUTPUT_DIR / 'eval'
EVAL_EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# The current run is registered as 'ANTs' straight from memory when its variables exist
_needed_in_memory = ('conf_labels_2p', 'df_conf', 'P_conf_in_2p_um')
if all(k in globals() for k in _needed_in_memory):
    EVAL_CACHE['ants'] = _eval_cache_entry('ANTs', conf_labels_2p, df_conf, P_conf_in_2p_um)

# External datasets are read from disk; entries without a usable file are skipped
for key, meta in EVAL_DATASETS.items():
    path = meta.get('conf_labels_2p_path')
    if path and Path(path).exists():
        arr = _load_labels_any(path)
        df_conf_ds = compute_centroids(arr)
        P_conf_ds  = idx_to_um(df_conf_ds, VOX_2P)  # already in 2P grid
        EVAL_CACHE[key] = _eval_cache_entry(meta.get('name', key), arr, df_conf_ds, P_conf_ds)
        conf_lut = EVAL_CACHE[key]['conf_lut']
    else:
        print(f"[WARN] Skipping '{key}' — missing file: {path}")

print('Eval datasets cached:', [f"{k}({v['name']})" for k,v in EVAL_CACHE.items()])


In [None]:
# --- Comparative summary across datasets (frozen eval_pairs) ---
import numpy as np, pandas as pd

def _eval_distances_for_pairs(conf_lut: dict, twoP_lut: dict, pairs_df: pd.DataFrame, drop_missing=True):
    """Measure the centroid distance of each label pair in ``pairs_df``.

    Args:
        conf_lut: Confocal label id -> 3-component coordinate.
        twoP_lut: 2P label id -> 3-component coordinate.
        pairs_df: Table with ``conf_label`` and ``twoP_label`` columns.
        drop_missing: When True, pairs with a label missing from either LUT
            are left out; when False they appear as NaN.

    Returns:
        ``(distances, n_missing)``: float array of Euclidean distances and the
        number of pairs for which at least one label was missing.
    """
    distances = []
    n_missing = 0
    for conf_id, twop_id in zip(pairs_df['conf_label'], pairs_df['twoP_label']):
        p_conf = conf_lut.get(int(conf_id))
        p_twop = twoP_lut.get(int(twop_id))
        if p_conf is not None and p_twop is not None:
            dz = p_conf[0] - p_twop[0]
            dy = p_conf[1] - p_twop[1]
            dx = p_conf[2] - p_twop[2]
            distances.append(float(np.sqrt(dz*dz + dy*dy + dx*dx)))
        else:
            n_missing += 1
            if not drop_missing:
                distances.append(np.nan)
    result = np.asarray(distances, dtype=float)
    if drop_missing:
        result = result[np.isfinite(result)]
    return result, n_missing

assert 'EVAL_CACHE' in globals() and 'eval_pairs' in globals(), 'Run precompute + freeze cells first.'
_twoP_lut = dict(zip(df_2p['label'].to_numpy(), P_2p_um))

def _stat_or_zero(reducer, values):
    """Apply ``reducer`` to ``values`` as a float; 0.0 when ``values`` is empty."""
    return float(reducer(values)) if values.size else 0.0

# One summary row per dataset, plus a tidy (dataset, distance) table for plotting
rows = []
_all_tidy = []
for key, payload in EVAL_CACHE.items():
    dists, dropped = _eval_distances_for_pairs(payload['conf_lut'], _twoP_lut, eval_pairs, drop_missing=True)
    stats = {
        'dataset': payload['name'],
        'n': int(dists.size),
        'median': _stat_or_zero(np.median, dists),
        'p90': _stat_or_zero(lambda v: np.percentile(v, 90), dists),
        'mean': _stat_or_zero(np.mean, dists),
        'max': _stat_or_zero(np.max, dists),
        'dropped_pairs': int(dropped),
    }
    rows.append(stats)
    _all_tidy.extend({'dataset': payload['name'], 'distance_um': float(x)} for x in dists)

compare_df = pd.DataFrame(rows).sort_values('dataset')
print('Comparative summary (final 1–1 good, frozen pairs):')
# Fall back to plain text when rich display is unavailable.
try:
    display(compare_df)
except Exception:
    print(compare_df.to_string(index=False))

dist_by_dataset = pd.DataFrame(_all_tidy)


In [None]:
# --- 3D viewer: switch datasets (confocal) + fixed 2P background ---
import numpy as np
import plotly.graph_objects as go
from skimage.measure import marching_cubes

# Guard: the precompute cell must have populated EVAL_CACHE with at least one dataset.
assert 'EVAL_CACHE' in globals() and len(EVAL_CACHE) > 0, 'Run precompute datasets cell.'

# Colors
CONF_COLOR = '#f254a6'  # confocal (magenta)
TWO_P_COLOR = '#33a6ff' # 2P (azure)
PAIR_LINE_COLOR = 'red'
PAIR_LINE_WIDTH = 5
OPACITY = 0.10  # applied to every mask mesh in this cell
STEP_SIZE = 1   # forwarded to marching_cubes(step_size=...)

# Voxel spacing (µm); any key missing from VOX_2P falls back to 1.0
dz = float(VOX_2P.get('dz', 1.0))
dy = float(VOX_2P.get('dy', 1.0))
dx = float(VOX_2P.get('dx', 1.0))

# Build 2P background mesh once
# All non-zero labels are merged into a single binary foreground before meshing.
mask_2p = (masks_2p > 0)
if np.any(mask_2p):
    # spacing=(dz, dy, dx) scales the vertices by the voxel size; vertex columns
    # are read back as (z, y, x) and handed to plotly as x/y/z.
    vT, fT, _, _ = marching_cubes(mask_2p.astype(np.uint8), level=0.5, spacing=(dz, dy, dx), step_size=STEP_SIZE)
    iT, jT, kT = fT.T.astype(np.int32, copy=False)
    zT, yT, xT = vT[:, 0], vT[:, 1], vT[:, 2]
    t_twoP = go.Mesh3d(x=xT, y=yT, z=zT, i=iT, j=jT, k=kT, name='2P mask', color=TWO_P_COLOR, opacity=OPACITY, lighting=dict(ambient=0.5))
else:
    # Empty placeholder so the trace list keeps the same layout when there is no 2P foreground.
    t_twoP = go.Mesh3d(x=[], y=[], z=[], i=[], j=[], k=[], name='2P mask', color=TWO_P_COLOR, opacity=OPACITY)

# 2P centroids
# P_2p_um columns are unpacked as (z, y, x).
Z2, Y2, X2 = P_2p_um[:,0], P_2p_um[:,1], P_2p_um[:,2]
pts_twoP = go.Scatter3d(x=X2, y=Y2, z=Z2, mode='markers', name='2P centroids', marker=dict(size=2, color=TWO_P_COLOR), showlegend=True)

# Build per-dataset traces (conf mesh, conf centroids, pair lines)
# Indices 0 and 1 of all_traces are the shared 2P background; every dataset then
# appends exactly three traces whose indices are recorded in trace_groups[key].
all_traces = [t_twoP, pts_twoP]
trace_groups = {}

# LUTs
# 2P label -> centroid row of P_2p_um (indexed below as z, y, x).
_twoP_lut = dict(zip(df_2p['label'].to_numpy(), P_2p_um))

for key, payload in EVAL_CACHE.items():
    name = payload['name']
    arr = payload['conf_arr']
    # Binary foreground of the confocal label array; meshed with the same
    # (dz, dy, dx) spacing as the 2P mask.
    conf_mask = (arr > 0)
    if np.any(conf_mask):
        vC, fC, _, _ = marching_cubes(conf_mask.astype(np.uint8), level=0.5, spacing=(dz, dy, dx), step_size=STEP_SIZE)
        iC, jC, kC = fC.T.astype(np.int32, copy=False)
        zC, yC, xC = vC[:, 0], vC[:, 1], vC[:, 2]
        t_conf = go.Mesh3d(x=xC, y=yC, z=zC, i=iC, j=jC, k=kC, name=f'{name} conf mask', color=CONF_COLOR, opacity=OPACITY, lighting=dict(ambient=0.5))
    else:
        # Empty placeholder keeps the three-traces-per-dataset layout intact.
        t_conf = go.Mesh3d(x=[], y=[], z=[], i=[], j=[], k=[], name=f'{name} conf mask', color=CONF_COLOR, opacity=OPACITY)

    # Confocal centroids for this dataset; columns are unpacked as (z, y, x).
    P_conf = payload['P_conf_um']
    Zc, Yc, Xc = P_conf[:,0], P_conf[:,1], P_conf[:,2]
    pts_conf = go.Scatter3d(x=Xc, y=Yc, z=Zc, mode='markers', name=f'{name} conf centroids', marker=dict(size=2, color=CONF_COLOR), showlegend=True)

    # Pair lines using frozen eval_pairs
    # Each segment is [start, end, None]; the None entry separates consecutive
    # segments inside the single Scatter3d line trace.
    xl, yl, zl = [], [], []
    conf_lut = payload['conf_lut']
    for _, r in eval_pairs.iterrows():
        a = conf_lut.get(int(r['conf_label']))
        b = _twoP_lut.get(int(r['twoP_label']))
        if a is None or b is None:
            # Pair has no centroid in this dataset (or in 2P): no line is drawn.
            continue
        x0, y0, z0 = a[2], a[1], a[0]
        x1, y1, z1 = b[2], b[1], b[0]
        xl += [x0, x1, None]; yl += [y0, y1, None]; zl += [z0, z1, None]
    pair_lines = go.Scatter3d(x=xl, y=yl, z=zl, mode='lines', name=f'{name} pairs', line=dict(color=PAIR_LINE_COLOR, width=PAIR_LINE_WIDTH), hoverinfo='skip', showlegend=True)

    # Remember which trace indices belong to this dataset for the dropdown.
    idx0 = len(all_traces)
    all_traces += [t_conf, pts_conf, pair_lines]
    trace_groups[key] = [idx0, idx0+1, idx0+2]

# Initial visibility: 2P background + first dataset
# (indices 0 and 1 are the shared 2P mesh and centroids).
visible = [True, True] + [False]*(len(all_traces)-2)
first_key = next(iter(trace_groups))
for i in trace_groups[first_key]:
    visible[i] = True

fig = go.Figure(data=all_traces)

# Apply the initial visibility mask. Without this every dataset is drawn at
# first render, because traces are visible by default and the dropdown only
# takes effect once a button is clicked.
for _trace, _is_visible in zip(fig.data, visible):
    _trace.visible = _is_visible

# Dropdown to switch dataset
# Each button shows the 2P background plus the three traces of one dataset and
# retitles the figure; button order follows trace_groups, so the default active
# button (the first one) matches the initially visible dataset.
buttons = []
for key, idxs in trace_groups.items():
    vis = [True, True] + [False]*(len(all_traces)-2)
    for i in idxs:
        vis[i] = True
    buttons.append(dict(label=EVAL_CACHE[key]['name'], method='update', args=[{'visible': vis}, {'title': f"3D view — dataset: {EVAL_CACHE[key]['name']}"}]))

fig.update_layout(
    width=1400, height=900,
    title=f"3D view — dataset: {EVAL_CACHE[first_key]['name']}",
    scene=dict(xaxis_title='x (µm)', yaxis_title='y (µm)', zaxis_title='z (µm)', aspectmode='data'),
    legend=dict(x=0.02, y=0.95, font=dict(size=10)),
    updatemenus=[dict(type='dropdown', direction='down', x=1.05, y=0.95, showactive=True, xanchor='left', yanchor='top', buttons=buttons)],
)

fig.show()


In [None]:
# --- Violin: Baseline, BigWarp, ANTs (10 µm ticks, offset labels) ---
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from math import ceil

# Requires the tidy per-pair distance table produced by the comparative summary cell.
assert 'dist_by_dataset' in globals() and not dist_by_dataset.empty, 'Run comparative summary cell first.'

# Enforce order
# Preferred names first (when present), then any other dataset in order of appearance.
preferred = ['Baseline', 'BigWarp', 'ANTs']
cats_all = list(dist_by_dataset['dataset'].unique())
cats = [c for c in preferred if c in cats_all] + [c for c in cats_all if c not in preferred]

# Gather series and stats
# series[i] holds the distances of cats[i]; stats feeds the table printed at the end.
series, stats = [], []
for c in cats:
    vals = dist_by_dataset.loc[dist_by_dataset['dataset'] == c, 'distance_um'].to_numpy(float)
    series.append(vals)
    stats.append({'dataset': c, 'n': int(vals.size), 'median': float(np.median(vals)) if vals.size else 0.0})

# Y-axis in 10 µm increments
# The upper limit is the data maximum rounded up to a multiple of 10 (minimum 10).
all_vals = np.concatenate(series) if any(s.size for s in series) else np.array([])
y_max_data = float(np.max(all_vals)) if all_vals.size else 10.0
y_max = max(10.0, 10.0 * ceil(y_max_data / 10.0))
yticks = np.arange(0.0, y_max + 0.1, 10.0)

plt.figure(figsize=(8, 4.5))
# Built-in mean/median/extrema markers are disabled; medians are drawn manually below.
parts = plt.violinplot(series, showmeans=False, showmedians=False, showextrema=False)

# Style
for pc in parts['bodies']:
    pc.set_facecolor('#87bfff')
    pc.set_edgecolor('black')
    pc.set_alpha(0.7)

# Offset annotations for legibility
x_offset = 0.20
y_offset = 0.02 * y_max
bbox_style = dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1.0)

# Median markers + labels
# Violin positions are 1-based, hence enumerate(..., start=1).
for i, vals in enumerate(series, start=1):
    if vals.size:
        med = float(np.median(vals))
        plt.scatter([i], [med], color='crimson', zorder=3, s=28)
        plt.text(i + x_offset, med + y_offset,
                 f"median={med:.2f} µm\nn={vals.size}",
                 va='bottom', ha='left', fontsize=8, bbox=bbox_style, clip_on=False, zorder=4)

plt.xticks(range(1, len(cats) + 1), cats)
plt.yticks(yticks)
plt.ylim(0, y_max)
plt.xlim(0.5, len(cats) + 0.8)  # right margin so labels don't clip
plt.ylabel('distance (µm)')
plt.title('Final 1–1 good pairs — distance per dataset')
plt.grid(axis='y', alpha=0.2)
plt.tight_layout()
plt.show()

# Print a compact median table in the same order
median_table = pd.DataFrame(stats, columns=['dataset', 'n', 'median'])
try:
    display(median_table)
except Exception:
    # `display` is only defined inside IPython/Jupyter; fall back to plain text.
    print(median_table.to_string(index=False))


In [None]:
# --- Axis-wise centroid errors: XY vs Z (Baseline, BigWarp, ANTs) ---
# Computes per-pair XY planar error and Z error (in µm) between confocal and 2P centroids,
# using the frozen eval_pairs and the cached datasets. Plots side-by-side violins.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import ceil

# Prerequisites: the frozen pair list and the 2P label table / centroid array.
assert 'eval_pairs' in globals(), "Run the freeze-eval-pairs cell first."
assert 'df_2p' in globals() and 'P_2p_um' in globals(), "Need 2P centroids/coords."
# Ensure ANTs (current run) is present in EVAL_CACHE
if 'EVAL_CACHE' not in globals():
    EVAL_CACHE = {}
# Side effect: the eval export directory is created under QC_OUTPUT_DIR if missing.
EVAL_EXPORT_DIR = QC_OUTPUT_DIR / 'eval'
EVAL_EXPORT_DIR.mkdir(parents=True, exist_ok=True)
# Register the current-run variables as the 'ants' dataset only when that key is
# absent and all three source variables exist; an existing entry is never overwritten.
if 'ants' not in EVAL_CACHE and all(k in globals() for k in ('conf_labels_2p','df_conf','P_conf_in_2p_um')):
    EVAL_CACHE['ants'] = {
        'name': 'ANTs',
        'conf_arr': conf_labels_2p,
        'df_conf': df_conf,
        'P_conf_um': P_conf_in_2p_um,
        'conf_lut': dict(zip(df_conf['label'].to_numpy(), P_conf_in_2p_um)),
    }

# Build 2P LUT once
# Maps each 2P label to its row of P_2p_um.
twoP_lut = dict(zip(df_2p['label'].to_numpy(), P_2p_um))

# Helper: compute XY and Z errors for a dataset (drop pairs missing on either side)
def axis_errors_for_dataset(conf_lut: dict, twoP_lut: dict, pairs_df: pd.DataFrame) -> tuple[np.ndarray, np.ndarray, int]:
    """Split each pair's centroid offset into an in-plane and an axial error.

    Args:
        conf_lut: confocal label -> centroid indexed as (z, y, x).
        twoP_lut: 2P label -> centroid indexed as (z, y, x).
        pairs_df: table with 'conf_label' and 'twoP_label' columns.

    Returns:
        (xy_errors, z_errors, n_dropped): the planar error sqrt(dy² + dx²) and
        the absolute z offset for every pair found in both LUTs, plus the
        number of pairs skipped because a label was missing on either side.
    """
    planar_errors = []
    axial_errors = []
    n_skipped = 0
    for _, pair in pairs_df.iterrows():
        conf_zyx = conf_lut.get(int(pair['conf_label']))
        twoP_zyx = twoP_lut.get(int(pair['twoP_label']))
        if conf_zyx is None or twoP_zyx is None:
            n_skipped += 1
        else:
            off_z, off_y, off_x = (float(conf_zyx[axis] - twoP_zyx[axis]) for axis in range(3))
            planar_errors.append(float(np.sqrt(off_y*off_y + off_x*off_x)))
            axial_errors.append(float(abs(off_z)))
    return np.asarray(planar_errors, float), np.asarray(axial_errors, float), n_skipped

# Desired order (only include those present)
# A dataset is "present" when some EVAL_CACHE entry's display name (falling back
# to its key) equals the wanted name.
order_conf = ['Baseline', 'BigWarp', 'ANTs']
present_order = [nm for nm in order_conf if any(v.get('name', k) == nm for k, v in EVAL_CACHE.items())]

# Build tidy dataframe of errors
# XY and Z errors are stored as separate rows: each row fills one of the two
# value columns and leaves the other as NaN (dropped again in collect_series).
rows = []
for key, payload in EVAL_CACHE.items():
    name = payload.get('name', key)
    if name not in order_conf:
        # Datasets outside the three expected names are ignored in this cell.
        continue
    # NOTE: `dropped` (pairs missing a centroid) is computed but not reported here.
    xy, z, dropped = axis_errors_for_dataset(payload['conf_lut'], twoP_lut, eval_pairs)
    rows += [{'dataset': name, 'xy_um': float(v), 'z_um': np.nan} for v in xy]
    rows += [{'dataset': name, 'xy_um': np.nan,      'z_um': float(v)} for v in z]

dist_axis = pd.DataFrame(rows)
if dist_axis.empty:
    raise RuntimeError("No axis-wise errors computed. Check EVAL_CACHE and eval_pairs.")

# Split per-axis series in desired order
def collect_series(df: pd.DataFrame, col: str, cats: list[str]) -> list[np.ndarray]:
    """Return one float array of non-NaN ``col`` values per dataset name in ``cats``.

    The output follows the order of ``cats``; a name with no rows in ``df``
    yields an empty array.
    """
    per_dataset = []
    for category in cats:
        in_category = df['dataset'] == category
        values = df.loc[in_category, col].dropna()
        per_dataset.append(values.to_numpy(float))
    return per_dataset

# Per-dataset arrays, ordered like present_order (NaN placeholder rows removed).
series_xy = collect_series(dist_axis, 'xy_um', present_order)
series_z  = collect_series(dist_axis, 'z_um',  present_order)

# Axis titles
# Same list object as present_order; used for tick labels and the median table.
labels = present_order  # e.g., ['Baseline','BigWarp','ANTs']

# Y-axis ticks: upper limit rounded up to a multiple of 10 µm (minimum 10);
# 2 µm steps when the limit is ≤ 20, otherwise 5 µm steps.
def y_ticks_from_series(series_list):
    """Return ``(ymax, yticks)`` for a violin axis covering every value in ``series_list``.

    Args:
        series_list: list of 1-D numeric arrays; empty arrays are ignored.

    Returns:
        ymax: data maximum rounded up to the next multiple of 10 (at least
            10.0; 10.0 when there is no data at all).
        yticks: tick positions from 0 to ``ymax`` inclusive.
    """
    populated = [s for s in series_list if s.size]
    if populated:
        largest = float(np.max(np.concatenate(populated)))
    else:
        largest = 10.0
    ymax = max(10.0, 10.0 * ceil(largest / 10.0))
    if ymax <= 20:
        step = 2.0
    else:
        step = 5.0
    # The +0.1 makes the arange end point inclusive of ymax.
    return ymax, np.arange(0.0, ymax + 0.1, step)

# Independent y-range and ticks per panel (the panels do not share a y-axis).
xy_ymax, xy_yticks = y_ticks_from_series(series_xy)
z_ymax,  z_yticks  = y_ticks_from_series(series_z)

# Plot side-by-side violins: XY error (left), Z error (right)
fig, axes = plt.subplots(1, 2, figsize=(12, 4.2), sharey=False)

for ax, ser, title, ymax, yticks in [
    (axes[0], series_xy, 'XY centroid error (µm)', xy_ymax, xy_yticks),
    (axes[1], series_z,  'Z centroid error (µm)',  z_ymax,  z_yticks),
]:
    # Built-in markers are disabled; medians are drawn manually below.
    parts = ax.violinplot(ser, showmeans=False, showmedians=False, showextrema=False)
    # Style
    for pc in parts['bodies']:
        pc.set_facecolor('#87bfff')
        pc.set_edgecolor('black')
        pc.set_alpha(0.7)
    # Median markers + offset labels
    x_offset = 0.18
    y_offset = 0.02 * ymax
    bbox_style = dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1.0)
    # Violin positions are 1-based, hence enumerate(..., start=1).
    for i, vals in enumerate(ser, start=1):
        if vals.size:
            med = float(np.median(vals))
            ax.scatter([i], [med], color='crimson', zorder=3, s=26)
            ax.text(i + x_offset, med + y_offset, f"median={med:.2f} µm\nn={vals.size}",
                    va='bottom', ha='left', fontsize=8, bbox=bbox_style, clip_on=False, zorder=4)
    ax.set_title(title)
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_ylim(0, ymax)
    ax.set_yticks(yticks)
    ax.grid(axis='y', alpha=0.2)

fig.suptitle('Axis-wise centroid errors by dataset')
plt.tight_layout()
plt.show()

# Print compact medians (same dataset order)
def med(col, ser_list, names):
    """Tabulate the sample count and median of each series.

    Args:
        col: suffix for the median column, which is named ``'median_' + col``.
        ser_list: arrays of values, paired positionally with ``names``.
        names: dataset names.

    Returns:
        DataFrame with columns 'dataset', 'n' and 'median_<col>'; the median is
        NaN for an empty series.
    """
    median_key = 'median_' + col

    def _row(dataset, values):
        median = float(np.median(values)) if values.size else np.nan
        return {'dataset': dataset, 'n': int(values.size), median_key: median}

    return pd.DataFrame([_row(dataset, values) for dataset, values in zip(names, ser_list)])

# One median table per axis, then joined on dataset name. Both tables also carry
# an 'n' column, so the merged frame holds pandas' suffixed 'n_x' / 'n_y' columns.
xy_med = med('xy_um', series_xy, labels)
z_med  = med('z_um',  series_z,  labels)
summary_axis = pd.merge(xy_med, z_med, on=['dataset'], how='outer')
try:
    display(summary_axis)
except Exception:
    # `display` is only defined inside IPython/Jupyter; fall back to plain text.
    print(summary_axis.to_string(index=False))


In [None]:
# --- Mask diameters (µm): 2P vs confocal datasets (Baseline, BigWarp, ANTs) ---
# Computes per-label diameters along Z, Y, X (via bbox extents) and compares distributions.
# Uses eval_pairs to restrict to the same matched labels across datasets.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import ceil
from skimage.measure import regionprops_table

# Prerequisites: voxel spacing, the 2P label volume and table, and the frozen pair list.
assert 'VOX_2P' in globals(), "Need VOX_2P with {'dz','dy','dx'}."
assert 'masks_2p' in globals() and 'df_2p' in globals() and 'P_2p_um' in globals(), "Need 2P labels and centroids."
assert 'eval_pairs' in globals(), "Run the freeze-eval-pairs cell first."
# Confocal datasets should be available in EVAL_CACHE; ensure ANTs(current) is included
if 'EVAL_CACHE' not in globals():
    EVAL_CACHE = {}
# Side effect: the eval export directory is created under QC_OUTPUT_DIR if missing.
EVAL_EXPORT_DIR = QC_OUTPUT_DIR / 'eval'
EVAL_EXPORT_DIR.mkdir(parents=True, exist_ok=True)
# Register the current-run variables as the 'ants' dataset only when that key is
# absent and all three source variables exist; an existing entry is never overwritten.
if 'ants' not in EVAL_CACHE and all(k in globals() for k in ('conf_labels_2p','df_conf','P_conf_in_2p_um')):
    EVAL_CACHE['ants'] = {
        'name': 'ANTs',
        'conf_arr': conf_labels_2p,
        'df_conf': df_conf,
        'P_conf_um': P_conf_in_2p_um,
        'conf_lut': dict(zip(df_conf['label'].to_numpy(), P_conf_in_2p_um)),
    }

# Helper: diameters (Z,Y,X) in µm from a label array in 2P grid, optionally restricted to a set of labels
def diameters_um_from_array(arr, vox, restrict_labels=None):
    """Return per-label bounding-box extents, scaled by the voxel size.

    Args:
        arr: label array; the six bbox columns read below imply a 3-D (z, y, x) array.
        vox: mapping with 'dz', 'dy' and 'dx' voxel sizes.
        restrict_labels: optional iterable of label ids to keep; None keeps all.

    Returns:
        DataFrame with columns 'label', 'z_um', 'y_um', 'x_um' (one row per
        label); an empty frame with those columns when no region is found.
    """
    table = pd.DataFrame(regionprops_table(arr, properties=('label','bbox')))
    if table.empty:
        return pd.DataFrame(columns=['label','z_um','y_um','x_um'])
    # For 3-D input the bbox columns are bbox-0..2 = (zmin, ymin, xmin) and
    # bbox-3..5 = (zmax, ymax, xmax); scikit-image documents the bbox as a
    # half-open interval, so max - min is the extent in voxels.
    out = pd.DataFrame({
        'label': table['label'].astype(int),
        'z_um': (table['bbox-3'] - table['bbox-0']) * float(vox['dz']),
        'y_um': (table['bbox-4'] - table['bbox-1']) * float(vox['dy']),
        'x_um': (table['bbox-5'] - table['bbox-2']) * float(vox['dx']),
    })
    if restrict_labels is None:
        return out
    wanted = np.array(list(set(restrict_labels)), dtype=int)
    return out[out['label'].isin(wanted)].reset_index(drop=True)

# Restrict to matched IDs for fair comparison
# Only labels that appear in the frozen eval_pairs are measured on either side.
conf_eval_ids = np.unique(eval_pairs['conf_label'].to_numpy(int))
twop_eval_ids = np.unique(eval_pairs['twoP_label'].to_numpy(int))

# 2P diameters (single dataset)
df_2p_diam = diameters_um_from_array(masks_2p, VOX_2P, restrict_labels=twop_eval_ids)
df_2p_diam['dataset'] = '2P'

# Confocal datasets (Baseline, BigWarp, ANTs)
order_conf = ['Baseline','BigWarp','ANTs']  # desired order
conf_dfs = []
present = []
for key, payload in EVAL_CACHE.items():
    name = payload.get('name', key)
    if name not in order_conf:
        # Datasets outside the three expected names are ignored in this cell.
        continue
    # Every confocal array is scaled with VOX_2P, i.e. it is treated as living on the 2P grid.
    arr = payload['conf_arr']
    df_d = diameters_um_from_array(arr, VOX_2P, restrict_labels=conf_eval_ids)
    df_d['dataset'] = name
    conf_dfs.append(df_d)
    present.append(name)

if not conf_dfs:
    raise RuntimeError("No confocal datasets found in EVAL_CACHE matching Baseline/BigWarp/ANTs.")

# Tidy table for plotting
df_conf_all = pd.concat(conf_dfs, ignore_index=True)
# Column order: Baseline, BigWarp, ANTs (only those present)
cats_conf = [c for c in order_conf if c in present]

# Build series per axis for plotting: [2P, Baseline, BigWarp, ANTs]
def series_for_axis(df_2p, df_conf, axis):
    """Collect the '<axis>_um' diameters as a list of arrays, 2P first.

    Args:
        df_2p: 2P diameter table.
        df_conf: concatenated confocal diameter table with a 'dataset' column.
        axis: 'x', 'y' or 'z'; selects the '<axis>_um' column.

    Returns:
        List of float arrays: the 2P values followed by one array per entry of
        the module-level ``cats_conf``, in that order.
    """
    column = f'{axis}_um'
    confocal = [
        df_conf.loc[df_conf['dataset']==dataset, column].to_numpy(float)
        for dataset in cats_conf
    ]
    return [df_2p[column].to_numpy(float)] + confocal

# One list of arrays per axis; each list is ordered [2P] + cats_conf.
series_x = series_for_axis(df_2p_diam, df_conf_all, 'x')
series_y = series_for_axis(df_2p_diam, df_conf_all, 'y')
series_z = series_for_axis(df_2p_diam, df_conf_all, 'z')

# Common y-axis (10 µm ticks)
# The limit is the maximum over all three axes, rounded up to a multiple of 10 (minimum 10).
all_vals = np.concatenate([a for a in series_x + series_y + series_z if a.size]) if any(
    (a.size for a in (series_x + series_y + series_z))) else np.array([])
y_max_data = float(np.max(all_vals)) if all_vals.size else 10.0
from math import ceil
y_max = max(10.0, 10.0 * ceil(y_max_data / 10.0))
yticks = np.arange(0.0, y_max + 0.1, 10.0)

# Labels and colors
cats_plot = ['2P'] + cats_conf  # e.g., ['2P','Baseline','BigWarp','ANTs']
colors = ['#bbbbbb'] + ['#87bfff'] * len(cats_conf)  # grey for 2P, blue for conf

# Plot 3 subplots: X, Y, Z
fig, axes = plt.subplots(1, 3, figsize=(20, 4.2), sharey=True)
for ax, ser, title in zip(axes, [series_x, series_y, series_z], ['X diameter (µm)', 'Y diameter (µm)', 'Z diameter (µm)']):
    # Built-in markers are disabled; medians are drawn manually below.
    parts = ax.violinplot(ser, showmeans=False, showmedians=False, showextrema=False)
    # Color violins
    # colors[i] is applied to the i-th violin body (2P grey, confocal blue).
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors[i])
        pc.set_edgecolor('black')
        pc.set_alpha(0.7)
    # Medians + n (offset for clarity)
    x_offset = 0.18
    y_offset = 0.02 * y_max
    bbox_style = dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1.0)
    # Violin positions are 1-based, hence enumerate(..., start=1).
    for i, vals in enumerate(ser, start=1):
        if vals.size:
            med = float(np.median(vals))
            ax.scatter([i], [med], color='crimson', zorder=3, s=26)
            ax.text(i + x_offset, med + y_offset, f"median={med:.2f} µm\nn={vals.size}",
                    va='bottom', ha='left', fontsize=8, bbox=bbox_style, clip_on=False, zorder=4)
    ax.set_title(title)
    ax.set_xticks(range(1, len(cats_plot)+1))
    ax.set_xticklabels(cats_plot, rotation=0)
    ax.set_ylim(0, y_max)
    ax.set_yticks(yticks)
    ax.grid(axis='y', alpha=0.2)

fig.suptitle('Mask diameters by axis — 2P vs confocal datasets')
plt.tight_layout()
plt.show()

# Print compact medians (in Baseline, BigWarp, ANTs order for conf; plus 2P)
def med_row(name, df, prefix=''):
    """Summarize one diameter table as a row of per-axis medians.

    Args:
        name: dataset label stored under 'dataset'.
        df: table with 'x_um', 'y_um' and 'z_um' columns.
        prefix: accepted for call compatibility; not used.

    Returns:
        dict with 'dataset', 'x_median', 'y_median', 'z_median' and 'n'.
        For an empty table the medians are NaN and 'n' is 0.
    """
    has_rows = not df.empty
    row = {'dataset': name}
    for axis in ('x', 'y', 'z'):
        values = df[f'{axis}_um'] if has_rows else ()
        row[f'{axis}_median'] = float(np.median(values)) if len(values) else np.nan
    row['n'] = int(len(df)) if has_rows else 0
    return row

# First row is 2P, followed by the confocal datasets in cats_conf order.
rows = [med_row('2P', df_2p_diam)]
for c in cats_conf:
    rows.append(med_row(c, df_conf_all.loc[df_conf_all['dataset']==c]))
med_table = pd.DataFrame(rows)
try:
    display(med_table)
except Exception:
    # `display` is only defined inside IPython/Jupyter; fall back to plain text.
    print(med_table.to_string(index=False))


In [None]:
# --- Summary for accepted_pairs_after_dedup (within_gate == True) ---
import numpy as np, json

# Requires the match table and the summarize_distances helper from earlier cells.
assert 'matches' in globals() and 'summarize_distances' in globals(), "Need `matches` and `summarize_distances`."

# Rows flagged within_gate are the accepted pairs; keep only their distances.
accepted_mask = matches['within_gate'].to_numpy()
accepted_dists = matches.loc[accepted_mask, 'distance_um'].to_numpy(float)

# Compute summary on accepted-only distances (n = number of accepted pairs)
# The all-True mask marks every distance passed in as accepted.
summary_accepted_after_dedup = summarize_distances(
    accepted_dists,
    np.ones_like(accepted_dists, dtype=bool)
)

print(f"accepted_pairs_after_dedup: {accepted_mask.sum()}")
print("Summary (accepted-only):")
print(json.dumps(summary_accepted_after_dedup, indent=2))

# Optional aliases if you want to reuse elsewhere
accepted_summary = summary_accepted_after_dedup
summary_within_gate_only = summary_accepted_after_dedup


In [None]:
# --- Publish to downstream plots under legacy names ---
# Use the final 1-to-1 pairs only when the config flag is set AND final_pairs
# exists; otherwise fall back to every within-gate match.
scope = 'final_1to1' if PLOT_USE_FINAL_1TO1 and ('final_pairs' in globals()) else 'within_gate'
if scope == 'final_1to1':
    plot_df = final_pairs.copy()
    # Flag the rows of `matches` whose (conf_label, twoP_label) pair is in final_pairs.
    key_final = set(map(tuple, plot_df[['conf_label','twoP_label']].to_numpy()))
    accepted_mask = matches[['conf_label','twoP_label']].apply(tuple, axis=1).isin(key_final).to_numpy()
else:
    plot_df = matches.loc[matches['within_gate']].copy()
    accepted_mask = matches['within_gate'].to_numpy()

matches_for_plots = plot_df
# dists_all covers every row of `matches`; dists_filtered only the published scope.
dists_all = matches['distance_um'].to_numpy()
dists_filtered = matches_for_plots['distance_um'].to_numpy()
summary_stats = summarize_distances(dists_all, accepted_mask)
distance_summary = summary_stats

# aliases
# All names below are bound to the same DataFrame object (no copies are made).
pair_table = matches_for_plots
pairings = matches_for_plots
pairs_df = matches_for_plots
pairs = matches_for_plots
match_df = matches_for_plots

# NOTE: accepted_dists and accepted_mask rebind names that the accepted-pairs
# summary cell also assigns.
dists = dists_all
distances = dists_all
accepted_dists = dists_filtered
distances_filtered = dists_filtered

valid_mask = accepted_mask
within_gate_mask = accepted_mask

print(f'Published scope for plots: {scope}')
print(f'- pairs: {len(matches_for_plots)} rows')
print(f'- dists_all: {dists_all.size} | dists_filtered: {dists_filtered.size}')


In [None]:

# --- Export masks/labels/overlay for accepted pairs ---
import numpy as np
import tifffile as tiff
from pathlib import Path

# Validate prerequisites
# The two label volumes are masked and overlaid voxel by voxel below, so they
# must exist and have identical shapes.
needed = ['conf_labels_2p','masks_2p','VOX_2P']
missing = [n for n in needed if n not in globals()]
if missing:
    raise RuntimeError(f"Missing required variables in notebook scope: {missing}")
if conf_labels_2p.shape != masks_2p.shape:
    raise ValueError(f"Shape mismatch: conf_labels_2p {conf_labels_2p.shape} vs masks_2p {masks_2p.shape}")

# Choose scope from published outputs (matches_for_plots preferred)
# Fallback: every within-gate row of `matches` when the publish cell has not run.
if 'matches_for_plots' in globals():
    scope_df = matches_for_plots.copy()
elif 'matches' in globals():
    scope_df = matches[matches['within_gate']].copy()
else:
    raise RuntimeError('No matches or matches_for_plots available for export scope.')

if scope_df.empty:
    print('No accepted matches to export. Adjust gates or recompute matches.')
else:
    # Unique labels on each side of the accepted pairs; 0 is background.
    conf_ids = np.unique(scope_df['conf_label'].to_numpy(dtype=int))
    twoP_ids = np.unique(scope_df['twoP_label'].to_numpy(dtype=int))
    conf_ids = conf_ids[conf_ids != 0]
    twoP_ids = twoP_ids[twoP_ids != 0]

    # Boolean masks (ZYX)
    conf_within_mask = np.isin(conf_labels_2p, conf_ids)
    twoP_within_mask = np.isin(masks_2p, twoP_ids)

    out_dir = QC_OUTPUT_DIR / 'tiff_exports'
    out_dir.mkdir(parents=True, exist_ok=True)

    # Voxel spacing; any key missing from VOX_2P falls back to 1.0.
    dz_um = float(VOX_2P.get('dz', 1.0))
    dy_um = float(VOX_2P.get('dy', 1.0))
    dx_um = float(VOX_2P.get('dx', 1.0))

    # Shared ImageJ write options. 'spacing' carries the z step; the XY pixel
    # size goes through the TIFF resolution tags (pixels per unit), so the
    # exported stacks are calibrated on all three axes instead of z only.
    imagej_kwargs = dict(
        imagej=True,
        metadata={'axes': 'ZYX', 'spacing': dz_um, 'unit': 'um'},
        compression='deflate',
    )
    if dx_um > 0 and dy_um > 0:
        imagej_kwargs['resolution'] = (1.0 / dx_um, 1.0 / dy_um)

    # 1) Binary masks (8-bit)
    tiff.imwrite(
        out_dir / 'conf_within_mask.tif',
        (conf_within_mask.astype(np.uint8) * 255),
        **imagej_kwargs,
    )
    tiff.imwrite(
        out_dir / 'twoP_within_mask.tif',
        (twoP_within_mask.astype(np.uint8) * 255),
        **imagej_kwargs,
    )

    # 2) Label stacks (preserve IDs; background=0)
    conf_within_labels = np.where(conf_within_mask, conf_labels_2p, 0)
    twoP_within_labels = np.where(twoP_within_mask, masks_2p, 0)

    def min_unsigned_dtype(max_val: int):
        """Return the narrowest of uint16/uint32/uint64 that can hold ``max_val``."""
        if max_val <= np.iinfo(np.uint16).max:
            return np.uint16
        elif max_val <= np.iinfo(np.uint32).max:
            return np.uint32
        else:
            return np.uint64

    conf_dtype = min_unsigned_dtype(int(conf_within_labels.max()))
    twoP_dtype = min_unsigned_dtype(int(twoP_within_labels.max()))
    conf_within_labels = conf_within_labels.astype(conf_dtype, copy=False)
    twoP_within_labels = twoP_within_labels.astype(twoP_dtype, copy=False)

    def imwrite_with_meta(path, arr):
        """Write ``arr`` to ``path``; uint8/uint16 stacks get the ImageJ calibration.

        Wider integer types are written as plain compressed TIFF without the
        ImageJ metadata.
        """
        if arr.dtype in (np.uint8, np.uint16):
            tiff.imwrite(path, arr, **imagej_kwargs)
        else:
            tiff.imwrite(path, arr, compression='deflate')

    imwrite_with_meta(out_dir / 'conf_within_labels.tif', conf_within_labels)
    imwrite_with_meta(out_dir / 'twoP_within_labels.tif', twoP_within_labels)

    # 3) RGB overlay (conf=magenta, 2P=azure)
    WRITE_OVERLAY = True
    if WRITE_OVERLAY:
        rgb = np.zeros(conf_within_mask.shape + (3,), dtype=np.uint8)
        rgb[..., 0] = np.where(conf_within_mask, 242, 0)
        rgb[..., 1] = np.where(conf_within_mask,  84, 0)
        rgb[..., 2] = np.where(conf_within_mask, 166, 0)
        # Where both masks overlap the channel sums are clipped to 255; the sum
        # is formed in a wider integer type before being stored back as uint8.
        rgb[..., 0] = np.clip(rgb[..., 0] + np.where(twoP_within_mask,  51, 0), 0, 255)
        rgb[..., 1] = np.clip(rgb[..., 1] + np.where(twoP_within_mask, 166, 0), 0, 255)
        rgb[..., 2] = np.clip(rgb[..., 2] + np.where(twoP_within_mask, 255, 0), 0, 255)
        tiff.imwrite(out_dir / 'within_gate_overlay_rgb.tif', rgb, photometric='rgb', compression='deflate')

    print(f'Saved exports to: {out_dir.resolve()}')


In [None]:
# PLOTLY_3D_MATCH_VIEW
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from skimage.measure import marching_cubes

# Tunables
# Each value is reused from notebook globals when already defined (for example
# by the dataset viewer cell); otherwise the default shown here applies.
OPACITY = globals().get('OPACITY', 0.10)
STEP_SIZE = globals().get('STEP_SIZE', 1)
PAIR_LINE_COLOR = globals().get('PAIR_LINE_COLOR', 'red')
PAIR_LINE_WIDTH = globals().get('PAIR_LINE_WIDTH', 5)

# Validate deps
need = ['conf_labels_2p','masks_2p','VOX_2P','df_conf','df_2p','P_conf_in_2p_um','P_2p_um']
missing = [n for n in need if n not in globals()]
if missing:
    raise RuntimeError(f'Missing prerequisites: {missing}')

# Choose which pairs to visualize
# Prefer the published scope; the fallback uses ALL rows of `matches`, gated or not.
df_pairs = None
if 'matches_for_plots' in globals():
    df_pairs = matches_for_plots.copy()
elif 'matches' in globals():
    df_pairs = matches.copy()
else:
    raise RuntimeError('Need matches or matches_for_plots for plotting.')

# Unlike the dataset viewer cell, the spacing keys are required here (KeyError if absent).
dz, dy, dx = float(VOX_2P['dz']), float(VOX_2P['dy']), float(VOX_2P['dx'])

def map_labels_to_values(label_vol, mapping, default_value, dtype):
    """Replace every label in ``label_vol`` by its mapped value via a lookup table.

    Args:
        label_vol: integer label array (used directly as an index into the table).
        mapping: label -> value; keys are converted with ``int``. Keys that are
            not in ``1..label_vol.max()`` are ignored, so label 0 always keeps
            ``default_value``.
        default_value: value for labels without a mapping entry.
        dtype: dtype of the lookup table and therefore of the result.

    Returns:
        Array shaped like ``label_vol`` holding the mapped values.
    """
    top = int(label_vol.max())
    table = np.full(top + 1, default_value, dtype=dtype)
    for raw_label, value in mapping.items():
        label = int(raw_label)
        if label < 1 or label > top:
            continue
        table[label] = value
    return table[label_vol]

# Fallback if within_gate column is absent
# Every pair is then treated as within the gate through a synthetic all-True column.
wg_col = 'within_gate' if 'within_gate' in df_pairs.columns else None
if wg_col is None:
    df_pairs = df_pairs.copy()
    df_pairs['_within_view'] = True
    wg_col = '_within_view'

# Lookup per-label best pair
# When a label occurs in several pairs, only its smallest-distance row is kept.
conf_lookup = (
    df_pairs.sort_values('distance_um')
            .drop_duplicates('conf_label', keep='first')
            .set_index('conf_label')
)
twoP_lookup = (
    df_pairs.sort_values('distance_um')
            .drop_duplicates('twoP_label', keep='first')
            .set_index('twoP_label')
)

# label -> within-gate flag and label -> distance, one pair of dicts per modality.
conf_within_map = conf_lookup[wg_col].to_dict()
twoP_within_map = twoP_lookup[wg_col].to_dict()
conf_dist_map = conf_lookup['distance_um'].to_dict()
twoP_dist_map = twoP_lookup['distance_um'].to_dict()

# Binary foreground of each label volume (all labels merged).
conf_mask_all = (conf_labels_2p > 0)
twoP_mask_all = (masks_2p > 0)

# Distances; cap for coloring
# The cap is the larger of 10.0 and the 95th percentile (NaNs ignored); 10.0
# when there is no finite distance at all.
all_d = df_pairs['distance_um'].to_numpy(dtype=float)
d_cap = float(np.nanmax([10.0, np.nanpercentile(all_d, 95)])) if np.isfinite(all_d).any() else 10.0

# Per-voxel metric volumes: labels without a pair get False / 0.0, and
# distances are clipped to [0, d_cap].
conf_within_vol = map_labels_to_values(conf_labels_2p, conf_within_map, False, np.bool_)
twoP_within_vol = map_labels_to_values(masks_2p,      twoP_within_map, False, np.bool_)
conf_dist_vol   = np.clip(map_labels_to_values(conf_labels_2p, conf_dist_map, 0.0, np.float32), 0.0, d_cap)
twoP_dist_vol   = np.clip(map_labels_to_values(masks_2p,      twoP_dist_map, 0.0, np.float32), 0.0, d_cap)

# Marching cubes surfaces (in µm coords)
def build_surface(mask):
    """Mesh a binary volume with marching cubes.

    Uses the module-level voxel spacing ``(dz, dy, dx)`` and ``STEP_SIZE``.

    Returns:
        ``(x, y, z, i, j, k)``: vertex coordinates (scaled by the spacing) and
        int32 triangle vertex indices. Six empty arrays when ``mask`` has no
        foreground voxel.
    """
    if np.any(mask):
        vertices, triangles, _, _ = marching_cubes(mask.astype(np.uint8), level=0.5, spacing=(dz, dy, dx), step_size=STEP_SIZE)
        tri_i, tri_j, tri_k = triangles.T.astype(np.int32, copy=False)
        # Vertex columns are (z, y, x); hand them back in x, y, z order.
        return vertices[:, 2], vertices[:, 1], vertices[:, 0], tri_i, tri_j, tri_k
    return tuple(np.array([]) for _ in range(6))

# One surface per modality, built from the merged (all-label) foreground masks.
xC, yC, zC, iC, jC, kC = build_surface(conf_mask_all)
xT, yT, zT, iT, jT, kT = build_surface(twoP_mask_all)

def sample_nearest(vol, z_um, y_um, x_um):
    """Sample ``vol`` at the voxel nearest to each physical coordinate.

    Coordinates are divided by the module-level spacing ``(dz, dy, dx)``,
    rounded, and clipped to the volume bounds.

    Returns:
        1-D array of sampled values; an empty array of ``vol.dtype`` when the
        volume or the coordinate list is empty.
    """
    if vol.size == 0 or len(z_um) == 0:
        return np.array([], dtype=vol.dtype)
    # NOTE(review): np.round rounds halves to even. If the callers' mesh vertices
    # lie exactly midway between voxels, the chosen voxel can be the one outside
    # the mask (background value) — confirm against the rendered colors.
    voxel_index = []
    for axis, (coords, step) in enumerate(zip((z_um, y_um, x_um), (dz, dy, dx))):
        upper = vol.shape[axis] - 1
        voxel_index.append(np.clip(np.round(coords / step).astype(int), 0, upper))
    return vol[tuple(voxel_index)]

# Intensities per metric
# One value per mesh vertex: the metric volumes are sampled at the vertex
# positions, and the z metric is simply the vertex z coordinate.
conf_int_within = sample_nearest(conf_within_vol, zC, yC, xC).astype(float)
twoP_int_within = sample_nearest(twoP_within_vol, zT, yT, xT).astype(float)
conf_int_dist   = sample_nearest(conf_dist_vol,   zC, yC, xC).astype(float)
twoP_int_dist   = sample_nearest(twoP_dist_vol,   zT, yT, xT).astype(float)
conf_int_z      = zC.astype(float)
twoP_int_z      = zT.astype(float)

# Colorscales per metric and the solid per-modality colors.
WITHIN_CS = [[0.0, '#e74c3c'], [1.0, '#2ecc71']]  # red/green
DIST_CS   = 'Viridis'
Z_CS      = 'Plasma'
CONF_COLOR = '#f254a6'
TWO_P_COLOR = '#33a6ff'

def mesh_metric(x, y, z, i, j, k, inten, name, colorscale, cmin, cmax, showscale=False):
    """Build a ``go.Mesh3d`` trace colored per vertex by ``inten``.

    Args:
        x, y, z: vertex coordinates.
        i, j, k: triangle vertex indices.
        inten: per-vertex intensity mapped through ``colorscale`` on ``[cmin, cmax]``.
        name: trace name.
        showscale: whether the trace shows its colorbar.

    Returns:
        The mesh trace, or an empty named placeholder when there are no vertices.
        Both use the module-level ``OPACITY``.
    """
    if len(x) == 0:
        placeholder = dict(x=[], y=[], z=[], i=[], j=[], k=[])
        return go.Mesh3d(name=name, opacity=OPACITY, showlegend=True, **placeholder)
    geometry = dict(x=x, y=y, z=z, i=i, j=j, k=k)
    coloring = dict(intensity=inten, colorscale=colorscale, cmin=cmin, cmax=cmax, showscale=showscale)
    return go.Mesh3d(name=name, opacity=OPACITY, lighting=dict(ambient=0.5), **geometry, **coloring)

def mesh_solid(x, y, z, i, j, k, color, name):
    """Build a single-color ``go.Mesh3d`` trace.

    Returns an empty named placeholder (without a color) when there are no
    vertices. Both variants use the module-level ``OPACITY``.
    """
    if len(x) > 0:
        geometry = dict(x=x, y=y, z=z, i=i, j=j, k=k)
        return go.Mesh3d(name=name, color=color, opacity=OPACITY, lighting=dict(ambient=0.5), **geometry)
    placeholder = dict(x=[], y=[], z=[], i=[], j=[], k=[])
    return go.Mesh3d(name=name, opacity=OPACITY, showlegend=True, **placeholder)

# Centroid traces per metric
# Centroid arrays are unpacked column-wise as (z, y, x).
conf_z, conf_y, conf_x = P_conf_in_2p_um[:,0], P_conf_in_2p_um[:,1], P_conf_in_2p_um[:,2]
twoP_z, twoP_y, twoP_x = P_2p_um[:,0], P_2p_um[:,1], P_2p_um[:,2]

def centroid_metric_traces(metric):
    """Build the confocal and 2P centroid scatter traces for one colour mode.

    Args:
        metric: Colouring mode, one of:

            * ``'within_gate'`` - per-label flag looked up in
              ``conf_within_map`` / ``twoP_within_map``; labels missing from
              the map count as False. Coloured with ``WITHIN_CS`` over [0, 1].
            * ``'distance_um'`` - per-label value looked up in
              ``conf_dist_map`` / ``twoP_dist_map``; missing labels get 0.0.
              Coloured with ``DIST_CS`` over [0, ``d_cap``].
            * ``'z_um'`` - the centroid's own z coordinate, coloured with
              ``Z_CS`` over the combined z range of both centroid sets
              (or [0, 1] when either set is empty).
            * ``'source'`` - fixed colours ``CONF_COLOR`` / ``TWO_P_COLOR``.

    Returns:
        ``(c_conf, c_twoP)``: two ``go.Scatter3d`` marker traces. In the
        colour-mapped modes only the confocal trace carries colour-bar
        settings; the 2P trace sets ``showscale=False``.

    Raises:
        ValueError: If ``metric`` is not one of the four supported names.
            (Previously an unknown name fell into the 'source' branch and then
            failed with an unbound-variable error.)
    """
    if metric == 'source':
        c_conf = go.Scatter3d(
            x=conf_x, y=conf_y, z=conf_z, mode='markers', name='Conf centroids',
            marker=dict(size=2, color=CONF_COLOR), showlegend=True,
        )
        c_twoP = go.Scatter3d(
            x=twoP_x, y=twoP_y, z=twoP_z, mode='markers', name='2P centroids',
            marker=dict(size=2, color=TWO_P_COLOR), showlegend=True,
        )
        return c_conf, c_twoP

    if metric == 'within_gate':
        conf_vals = df_conf['label'].map(conf_within_map).fillna(False).to_numpy().astype(float)
        twoP_vals = df_2p['label'].map(twoP_within_map).fillna(False).to_numpy().astype(float)
        cs, cmin, cmax = WITHIN_CS, 0.0, 1.0
    elif metric == 'distance_um':
        conf_vals = df_conf['label'].map(conf_dist_map).fillna(0.0).to_numpy(dtype=float)
        twoP_vals = df_2p['label'].map(twoP_dist_map).fillna(0.0).to_numpy(dtype=float)
        cs, cmin, cmax = DIST_CS, 0.0, d_cap
    elif metric == 'z_um':
        conf_vals = conf_z
        twoP_vals = twoP_z
        # min()/max() on an empty array would raise, so fall back to [0, 1].
        if P_conf_in_2p_um.size and P_2p_um.size:
            cmin = float(min(conf_z.min(), twoP_z.min()))
            cmax = float(max(conf_z.max(), twoP_z.max()))
        else:
            cmin, cmax = 0.0, 1.0
        cs = Z_CS
    else:
        raise ValueError(
            f"Unknown metric {metric!r}; expected 'within_gate', 'distance_um', 'z_um' or 'source'."
        )

    c_conf = go.Scatter3d(
        x=conf_x, y=conf_y, z=conf_z, mode='markers', name='Conf centroids',
        marker=dict(size=2, color=conf_vals, colorscale=cs, cmin=cmin, cmax=cmax,
                    colorbar=dict(title=metric, len=0.5)),
        showlegend=True,
    )
    c_twoP = go.Scatter3d(
        x=twoP_x, y=twoP_y, z=twoP_z, mode='markers', name='2P centroids',
        marker=dict(size=2, color=twoP_vals, colorscale=cs, cmin=cmin, cmax=cmax,
                    showscale=False),
        showlegend=True,
    )
    return c_conf, c_twoP

# Flat list of every trace in the figure, plus a lookup from metric name to
# the four list positions (conf mesh, 2P mesh, conf centroids, 2P centroids)
# that belong to it. The dropdown uses the lookup to toggle visibility.
all_traces = []
trace_index = {}


def _append_pair(first, second):
    """Append two traces to ``all_traces``; return the position of the first."""
    start = len(all_traces)
    all_traces.extend((first, second))
    return start


# within_gate metric
t_conf_w = mesh_metric(xC, yC, zC, iC, jC, kC, conf_int_within, 'Confocal mask', WITHIN_CS, 0, 1, showscale=True)
t_twoP_w = mesh_metric(xT, yT, zT, iT, jT, kT, twoP_int_within, '2P mask', WITHIN_CS, 0, 1, showscale=False)
idx0 = _append_pair(t_conf_w, t_twoP_w)
c_conf_w, c_twoP_w = centroid_metric_traces('within_gate')
idx1 = _append_pair(c_conf_w, c_twoP_w)
trace_index['within_gate'] = [idx0, idx0 + 1, idx1, idx1 + 1]

# distance metric
t_conf_d = mesh_metric(xC, yC, zC, iC, jC, kC, conf_int_dist, 'Confocal mask', DIST_CS, 0, d_cap, showscale=True)
t_twoP_d = mesh_metric(xT, yT, zT, iT, jT, kT, twoP_int_dist, '2P mask', DIST_CS, 0, d_cap, showscale=False)
idx0 = _append_pair(t_conf_d, t_twoP_d)
c_conf_d, c_twoP_d = centroid_metric_traces('distance_um')
idx1 = _append_pair(c_conf_d, c_twoP_d)
trace_index['distance_um'] = [idx0, idx0 + 1, idx1, idx1 + 1]

# z metric: both meshes share one colour range spanning all vertex z values
if len(conf_int_z) and len(twoP_int_z):
    zmin = float(min(conf_int_z.min(), twoP_int_z.min()))
    zmax = float(max(conf_int_z.max(), twoP_int_z.max()))
else:
    # An empty mesh has no min/max; use a placeholder range.
    zmin, zmax = 0.0, 1.0
t_conf_z = mesh_metric(xC, yC, zC, iC, jC, kC, conf_int_z, 'Confocal mask', Z_CS, zmin, zmax, showscale=True)
t_twoP_z = mesh_metric(xT, yT, zT, iT, jT, kT, twoP_int_z, '2P mask', Z_CS, zmin, zmax, showscale=False)
idx0 = _append_pair(t_conf_z, t_twoP_z)
c_conf_z, c_twoP_z = centroid_metric_traces('z_um')
idx1 = _append_pair(c_conf_z, c_twoP_z)
trace_index['z_um'] = [idx0, idx0 + 1, idx1, idx1 + 1]

# source metric (solid colors)
t_conf_s = mesh_solid(xC, yC, zC, iC, jC, kC, CONF_COLOR, 'Confocal mask')
t_twoP_s = mesh_solid(xT, yT, zT, iT, jT, kT, TWO_P_COLOR, '2P mask')
idx0 = _append_pair(t_conf_s, t_twoP_s)
c_conf_s, c_twoP_s = centroid_metric_traces('source')
idx1 = _append_pair(c_conf_s, c_twoP_s)
trace_index['source'] = [idx0, idx0 + 1, idx1, idx1 + 1]

# Pair lines based on df_pairs mapping (conf_label -> matched twoP)
# Confocal centroids, each tagged with the 2P label it was matched to
# (NaN when there is no match or no lookup table at all).
conf_pts = df_conf[['label']].copy()
conf_pts['z_um'] = conf_z
conf_pts['y_um'] = conf_y
conf_pts['x_um'] = conf_x
if conf_lookup.empty:
    conf_pts['matched_twoP_label'] = np.nan
else:
    conf_pts['matched_twoP_label'] = conf_pts['label'].map(conf_lookup['twoP_label'])

# 2P centroids keyed by label: {label: {'z_um': ..., 'y_um': ..., 'x_um': ...}}
twoP_pts = df_2p[['label']].copy()
twoP_pts['z_um'] = twoP_z
twoP_pts['y_um'] = twoP_y
twoP_pts['x_um'] = twoP_x
twoP_map = twoP_pts.set_index('label').to_dict('index')

# One line segment per matched pair; a None entry after each segment breaks
# the polyline so separate pairs are not joined together.
xl, yl, zl = [], [], []
conf_rows = zip(
    conf_pts['matched_twoP_label'],
    conf_pts['x_um'],
    conf_pts['y_um'],
    conf_pts['z_um'],
)
for matched_label, cx, cy, cz in conf_rows:
    if pd.isna(matched_label):
        continue
    partner = twoP_map.get(int(matched_label))
    if partner is None:
        continue
    xl.extend((float(cx), float(partner['x_um']), None))
    yl.extend((float(cy), float(partner['y_um']), None))
    zl.extend((float(cz), float(partner['z_um']), None))

pair_lines = go.Scatter3d(
    x=xl, y=yl, z=zl,
    mode='lines',
    name='pairs (centroid links)',
    line=dict(color=PAIR_LINE_COLOR, width=PAIR_LINE_WIDTH),
    hoverinfo='skip',
    showlegend=True,
)
# The pair-line trace goes last; remember its position for the dropdown.
pair_idx = len(all_traces)
all_traces.append(pair_lines)

# Initial visibility: 'within_gate' + pair lines
_initially_shown = set(trace_index['within_gate']) | {pair_idx}
visible = [slot in _initially_shown for slot in range(len(all_traces))]
for trace, is_on in zip(all_traces, visible):
    trace.visible = is_on

fig = go.Figure(data=all_traces)


def _metric_visibility(metric_key):
    """Visibility flags for one dropdown entry; the pair-line slot keeps its initial state."""
    shown = set(trace_index[metric_key])
    flags = [slot in shown for slot in range(len(all_traces))]
    flags[pair_idx] = visible[pair_idx]
    return flags


# Dropdown to switch metric (keep pair-line state)
metrics = ['within_gate', 'distance_um', 'z_um', 'source']
buttons = [
    dict(
        label=f'Color by: {metric}',
        method='update',
        args=[
            {'visible': _metric_visibility(metric)},
            {'title': f'Mask+Centroids — color by {metric}'},
        ],
    )
    for metric in metrics
]

scene_settings = dict(
    xaxis_title='x (µm)',
    yaxis_title='y (µm)',
    zaxis_title='z (µm)',
    aspectmode='data',
)
metric_menu = dict(
    type='dropdown',
    direction='down',
    x=1.05,
    y=0.95,
    showactive=True,
    xanchor='left',
    yanchor='top',
    buttons=buttons,
)
fig.update_layout(
    width=1500,
    height=1000,
    title='Mask+Centroids — color by within_gate',
    scene=scene_settings,
    legend=dict(x=0.02, y=0.70, font=dict(size=10)),
    updatemenus=[metric_menu],
)

fig.show()


In [None]:
# SLAB_VIEWER_2D
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider, Checkbox

# convenience projections (centroid columns are ordered z, y, x)
Z2 = P_2p_um[:, 0]
Y2 = P_2p_um[:, 1]
X2 = P_2p_um[:, 2]
Zc = P_conf_in_2p_um[:, 0]
Yc = P_conf_in_2p_um[:, 1]
Xc = P_conf_in_2p_um[:, 2]

# Slider range: the combined z extent of both centroid sets, or a placeholder
# [0, 1] when either set is empty (min/max of an empty array would raise).
if P_2p_um.size and P_conf_in_2p_um.size:
    zmin = float(min(Z2.min(), Zc.min()))
    zmax = float(max(Z2.max(), Zc.max()))
else:
    zmin, zmax = 0.0, 1.0
default_thick = 4.0

def _plot_slice(z_um=0.0, thickness_um=default_thick, show_conf=True, show_2p=True, show_pairs=False):
    """Scatter the centroids that fall inside a z slab, in the x/y plane.

    Args:
        z_um: Slab centre along z.
        thickness_um: Slab half-thickness; a centroid is shown when
            ``|z - z_um| <= thickness_um``.
        show_conf: Draw the confocal centroids (in 2P space).
        show_2p: Draw the 2P centroids.
        show_pairs: Draw a line for every row of the global
            ``matches_for_plots`` whose two centroids both lie in the slab.
            Silently ignored when that global does not exist.
    """
    plt.figure()
    if show_2p:
        in_slab_2p = np.abs(Z2 - z_um) <= thickness_um
        plt.scatter(X2[in_slab_2p], Y2[in_slab_2p], s=8, label='2P', alpha=0.9)
    if show_conf:
        in_slab_conf = np.abs(Zc - z_um) <= thickness_um
        plt.scatter(Xc[in_slab_conf], Yc[in_slab_conf], s=8, alpha=0.6, label='Conf→2P')
    if show_pairs and 'matches_for_plots' in globals():
        # build label→coord maps; each value is one centroid row (z, y, x)
        coord_2p = dict(zip(df_2p['label'].to_numpy(), P_2p_um))
        coord_conf = dict(zip(df_conf['label'].to_numpy(), P_conf_in_2p_um))
        # Iterate the two label columns directly rather than row-by-row.
        pair_labels = zip(matches_for_plots['conf_label'], matches_for_plots['twoP_label'])
        for conf_label, twoP_label in pair_labels:
            a = coord_conf.get(int(conf_label))
            b = coord_2p.get(int(twoP_label))
            if a is None or b is None:
                continue
            if (abs(a[0] - z_um) <= thickness_um) and (abs(b[0] - z_um) <= thickness_um):
                plt.plot([a[2], b[2]], [a[1], b[1]], linewidth=0.5)
    ax = plt.gca()
    ax.invert_yaxis()
    plt.xlabel('x (µm)'); plt.ylabel('y (µm)')
    plt.title(f'Centroids near z = {z_um:.2f} µm (±{thickness_um:.2f})')
    # With both scatter layers switched off there is nothing labelled, and
    # calling legend() would only emit a "no artists with labels" warning.
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        plt.legend(loc='upper right')
    plt.tight_layout(); plt.show()

# Wire the slab plot to widgets: a z slider spanning [zmin, zmax] that starts
# at the midpoint, a slab half-thickness slider, and three on/off toggles.
# The plot function is re-run whenever a widget value changes.
interact(
    _plot_slice,
    z_um=FloatSlider(min=zmin, max=zmax, step=0.5, value=(zmin+zmax)/2, description='z (µm)'),
    thickness_um=FloatSlider(min=0.5, max=20.0, step=0.5, value=default_thick, description='slab ±µm'),
    show_conf=Checkbox(value=True, description='show Conf→2P'),
    show_2p=Checkbox(value=True, description='show 2P'),
    show_pairs=Checkbox(value=False, description='show pair lines'),
)


In [None]:
# --- Interactive Histogram: final 1–1 good pairs (fixed bins + fixed Y) ---
# Uses only `final_pairs` (1–1 + IoU “good”), and shows how many are ≤ threshold.

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider, Checkbox

# Resolve final 1–1 good pairs
if 'final_pairs' in globals():
    fp = final_pairs.copy()
else:
    # Fallback if not defined: try to build from matches
    assert 'matches' in globals(), "Need `final_pairs` or `matches`."
    assert {'pair_type','quality'}.issubset(matches.columns), \
        "If `final_pairs` is missing, `matches` must have 'pair_type' and 'quality'."
    fp = matches[(matches['pair_type'] == '1-1') & (matches['quality'] == 'good')].copy()
    assert not fp.empty, "No 1–1 good pairs found. Run the QC cell to define `final_pairs`."

assert 'distance_um' in fp.columns, "`final_pairs` must include 'distance_um'."

# Distances of final 1–1 good pairs.
# Rows with a NaN/inf distance cannot be binned (np.histogram_bin_edges raises
# on a non-finite data range), so drop them from `fp` and `d_final` together;
# the plot function indexes `fp` with masks computed from `d_final`, so the
# two must stay row-aligned.
d_final = fp['distance_um'].to_numpy(dtype=float)
finite_rows = np.isfinite(d_final)
if not finite_rows.all():
    print(f"Dropping {int((~finite_rows).sum())} final pair(s) with non-finite distance_um.")
    fp = fp.loc[finite_rows].copy()
    d_final = d_final[finite_rows]
n_final = d_final.size
if n_final == 0:
    raise RuntimeError("No final 1–1 good pairs to plot.")

# Fixed bin edges (computed once): about sqrt(n) bins, clamped to [20, 200]
bins_count = max(20, min(200, int(np.sqrt(n_final))))
edges = np.histogram_bin_edges(d_final, bins=bins_count)

# Fixed Y-axis (based on all final pairs), with 10% headroom
counts_all, _ = np.histogram(d_final, bins=edges)
y_max = max(int(counts_all.max() * 1.10), 1)

# Slider defaults (ensure max ≥ 10): start at the 75th percentile and let the
# slider reach at least the 99th. n_final > 0 is guaranteed by the check above.
thr_suggest = float(np.nanpercentile(d_final, 75))
pct99 = float(np.nanpercentile(d_final, 99))
slider_max = max(10.0, pct99)
slider_value = min(thr_suggest, slider_max)

def _plot(threshold_um=slider_value, update_globals=False):
    """Histogram the final-pair distances and highlight those within a threshold.

    Draws all of ``d_final`` in grey using the precomputed ``edges`` and
    ``y_max`` (so the axes do not jump as the threshold moves), overlays the
    accepted subset (``distance <= threshold_um``), and prints summary stats.

    Args:
        threshold_um: Distance cut-off.
        update_globals: When True and a global ``matches`` exists, publish
            ``MAX_DISTANCE_UM``, ``accepted_dists``, ``distances_filtered``,
            ``within_gate_mask`` and ``valid_mask`` into the notebook
            namespace. The two masks are boolean arrays with one entry per row
            of ``matches``, True where the row's (conf_label, twoP_label) is an
            accepted final pair.
    """
    # Threshold subset
    cutoff = float(threshold_um)
    accepted_mask = d_final <= cutoff
    accepted_dists = d_final[accepted_mask]
    accepted = int(accepted_mask.sum())
    frac = (accepted / n_final) if n_final else 0.0

    # Plot with fixed edges and fixed Y
    plt.figure(figsize=(6,4))
    plt.hist(d_final, bins=edges, color='lightgray', alpha=0.7, label=f'final 1–1 good (n={n_final})')
    if accepted > 0:
        plt.hist(accepted_dists, bins=edges, color='steelblue', alpha=0.85,
                 label=f'≤ {threshold_um:.2f} µm (n={accepted}, {frac:.1%})')
    plt.axvline(threshold_um, color='crimson', linestyle='--', linewidth=1.5,
                label=f'Threshold = {threshold_um:.2f} µm')
    plt.ylim(0, y_max)
    plt.xlabel('distance (µm)'); plt.ylabel('count'); plt.legend()
    plt.title('Final 1–1 good pairs vs distance threshold')
    plt.tight_layout(); plt.show()

    print(f"Accepted final pairs: {accepted} / {n_final}  ({frac:.1%})")
    if accepted > 0:
        print(f"Accepted stats — median: {np.median(accepted_dists):.3f} µm | "
              f"p90: {np.percentile(accepted_dists, 90):.3f} µm | max: {np.max(accepted_dists):.3f} µm")

    # Optional: publish mask back to original `matches` for reuse
    if update_globals and 'matches' in globals():
        accepted_pairs = fp.loc[accepted_mask, ['conf_label', 'twoP_label']]
        key_acc = set(zip(accepted_pairs['conf_label'], accepted_pairs['twoP_label']))
        # Column-wise membership test; always yields a 1-D mask, including
        # when `matches` has no rows.
        full_mask = np.array(
            [pair in key_acc for pair in zip(matches['conf_label'], matches['twoP_label'])],
            dtype=bool,
        )

        globals()['MAX_DISTANCE_UM'] = cutoff
        globals()['accepted_dists'] = accepted_dists
        globals()['distances_filtered'] = accepted_dists
        globals()['within_gate_mask'] = full_mask
        globals()['valid_mask'] = full_mask
        print("Published: MAX_DISTANCE_UM, accepted_dists, distances_filtered, within_gate_mask, valid_mask")

# Wire the histogram to widgets: a threshold slider from 0 to slider_max
# (starting at slider_value) and a checkbox that, when ticked, makes the plot
# function publish the thresholded results as notebook globals.
interact(
    _plot,
    threshold_um=FloatSlider(min=0.0, max=slider_max, step=0.1, value=slider_value,
                             description='MAX_DISTANCE_UM'),
    update_globals=Checkbox(value=False, description='publish vars')
)
