# Code demo for both `calculate_close_axon_dend_contacts.py` and `visualize_close_axon_dend_contacts.py`

# Close Axon–Dendrite Contacts: Calculation + Visualization (Demo Notebook)

This notebook shows an end-to-end workflow to:
1. **Build/annotate two neurons** (axon/dendrite labels per skeleton vertex).
2. **Find close axon↔dendrite vertex pairs** within a distance threshold.
3. **Cluster matches into regions** using **skeleton connectivity on both neurons**.
4. **Visualize** results in Neuroglancer as **one line per region**, joining the region's midpoint on one neuron to its midpoint on the other (both directions in a single link).

> **Notes**
> - Set up your CAVE credentials and datastack access before using `caveclient`/`pcg_skel`.
> - Edit the configuration cell (Section 1) for your own datastack, root IDs, and soma locations.
> - The logic mirrors the two scripts named above, packaged as a shareable notebook.


## 1. Configuration & Imports


In [1]:
import numpy as np
import caveclient
import pcg_skel
import skeleton_plot as skelplot
from scipy.spatial import cKDTree
from nglui.statebuilder import ViewerState, ImageLayer, SegmentationLayer
from heapq import heappush, heappop
from collections import deque, defaultdict

# ---- Edit these values for your run ----
DATASTACK = "jchen_mouse_cortex"
# Root IDs of the two neurons to compare. SOMA1/SOMA2 are their soma locations in
# voxel coordinates; they are passed to pcg_skel with root_point_resolution=ROOT_RESOLUTION.
ROOT_ID1 = 720575941071680793
SOMA1 = [168103, 158705, 2257]
ROOT_ID2 = 720575941057622500
SOMA2 = [170885, 158942, 3651]
ROOT_RESOLUTION = (7.5, 7.5, 50)  # nm/px
# Soma-collapse radius passed to pcg_skel (presumably nm — confirm against pcg_skel docs).
COLLAPSE_RADIUS = 7500
# threshold_quality passed to pcg_skel.features.add_is_axon_annotation.
AXON_QUALITY_THRESH = 0.2
# Maximum axon↔dendrite vertex distance (nm) for a vertex pair to count as "close".
DIST_THRESHOLD_NM = 5000.0
# Per-direction cap on close pairs; larger sets are evenly subsampled down to this size.
MAX_PAIRS_PER_DIR = 5000

# Connect to the CAVE datastack (presumably needs a valid auth token for this datastack).
client = caveclient.CAVEclient(DATASTACK)
print('Connected to datastack:', DATASTACK)


Connected to datastack: jchen_mouse_cortex


## 2. Small Utilities


In [2]:
def nm_to_px(x_nm, res=ROOT_RESOLUTION):
    """Convert nanometre coordinates to voxel coordinates by dividing by ``res`` (nm/px), elementwise."""
    coords = np.asarray(x_nm, dtype=float)
    return np.divide(coords, np.asarray(res, dtype=float))

def px_to_nm(x_px, res=ROOT_RESOLUTION):
    """Convert voxel coordinates to nanometres by multiplying by ``res`` (nm/px), elementwise."""
    coords = np.asarray(x_px, dtype=float)
    return np.multiply(coords, np.asarray(res, dtype=float))


## 3. Project Mesh Annotations → Skeleton Compartments


In [3]:
def build_skel_compartments_from_mesh_annos(
    nrn,
    skel,
    axon_anno="is_axon",
    apical_anno=None,
    basal_anno=None,
    default_non_axon=3,
):
    """
    Project mesh-level annotations to a full per-skeleton-vertex label array.

    Labels: 2='axon', 3='basal', 4='apical'; every vertex not covered by an
    annotation gets ``default_non_axon``. When a vertex is covered by several
    annotations, later assignments win, in the order axon -> apical -> basal.

    Parameters
    ----------
    nrn : object whose ``nrn.anno.<name>.df`` DataFrames carry a ``mesh_index``
        (or, failing that, ``mesh_ind``) column. Names that are None or not
        present on ``nrn.anno`` are skipped.
    skel : object with a ``mesh_index`` sequence (one mesh index per skeleton
        vertex) and a ``vertex_properties`` dict.
    axon_anno, apical_anno, basal_anno : annotation names (or None).
    default_non_axon : label for uncovered vertices.

    Returns
    -------
    numpy.ndarray of int with one label per skeleton vertex. The same array is
    stored in ``skel.vertex_properties['compartment']``.

    Raises
    ------
    ValueError
        If a present annotation has neither index column, or ``skel.mesh_index``
        is missing/None.
    """
    def _anno_mesh_indices(name):
        """Return the unique mesh indices of annotation ``name``, or None if it is absent."""
        if name is None or not hasattr(nrn.anno, name):
            return None
        df = getattr(nrn.anno, name).df
        for col in ("mesh_index", "mesh_ind"):
            if col in df.columns:
                return np.unique(np.asarray(df[col].to_numpy(), dtype=np.int64))
        raise ValueError(f"{name} has no mesh_index/mesh_ind column.")

    axon_mesh = _anno_mesh_indices(axon_anno)
    apical_mesh = _anno_mesh_indices(apical_anno)
    basal_mesh = _anno_mesh_indices(basal_anno)

    if not hasattr(skel, "mesh_index") or skel.mesh_index is None:
        raise ValueError("skel.mesh_index is required.")
    # NOTE(review): assumes skel.mesh_index indexes the same mesh that nrn's
    # annotations are anchored to — confirm when skel and nrn are built separately.
    sk_mi = np.asarray(skel.mesh_index, dtype=np.int64)

    comp = np.full(len(sk_mi), default_non_axon, dtype=int)
    # Assignment order defines precedence: basal overrides apical overrides axon.
    for mesh_idx, label in ((axon_mesh, 2), (apical_mesh, 4), (basal_mesh, 3)):
        if mesh_idx is not None and mesh_idx.size:
            comp[np.isin(sk_mi, mesh_idx)] = label

    skel.vertex_properties["compartment"] = comp
    return comp


## 4. Build Two Neurons & Annotate Axon


In [4]:
def _build_neuron_with_is_axon(
    root_id, soma_location, client,
    root_resolution, collapse_radius, threshold_quality,
):
    """Build one neuron's skeleton and meshwork, add ``is_axon``, and label its skeleton vertices."""
    # Only the skeleton is used; the mesh and L2 lookup dicts are requested with the
    # same arguments as before but discarded.
    skel, _mesh, _l2_maps = pcg_skel.pcg_skeleton(
        root_id, client, return_mesh=True,
        root_point=soma_location, root_point_resolution=root_resolution,
        collapse_soma=True, collapse_radius=collapse_radius,
        return_l2dict=True,
    )
    nrn = pcg_skel.pcg_meshwork(
        root_id=root_id, client=client,
        root_point=soma_location, root_point_resolution=root_resolution,
        collapse_soma=True, collapse_radius=collapse_radius,
        synapses=True,
    )
    pcg_skel.features.add_synapse_count(nrn)
    pcg_skel.features.add_is_axon_annotation(
        nrn, pre_anno="pre_syn", post_anno="post_syn",
        annotation_name="is_axon", return_quality=True,
        threshold_quality=threshold_quality,
    )
    # NOTE(review): skel comes from a separate pcg_skeleton call than nrn; labelling
    # assumes both share the same mesh indexing — confirm.
    comp = build_skel_compartments_from_mesh_annos(nrn, skel, axon_anno="is_axon")
    return skel, nrn, comp


def build_two_neurons_with_is_axon(
    root_id1, soma_location1,
    root_id2, soma_location2,
    client,
    root_resolution=ROOT_RESOLUTION,
    collapse_radius=COLLAPSE_RADIUS,
    threshold_quality=AXON_QUALITY_THRESH,
):
    """
    Build and annotate two neurons with pcg_skel.

    For each root ID:
    - build a skeleton (``pcg_skel.pcg_skeleton``) and a meshwork with synapses
      (``pcg_skel.pcg_meshwork``);
    - add synapse counts and an ``is_axon`` annotation;
    - project that annotation onto per-skeleton-vertex compartment labels.

    Parameters
    ----------
    root_id1, root_id2 : segment root IDs.
    soma_location1, soma_location2 : root points, in units given by ``root_resolution``.
    client : CAVEclient used for all data fetches.
    root_resolution : resolution of the soma locations (nm/px).
    collapse_radius : soma-collapse radius passed to pcg_skel.
    threshold_quality : passed to ``add_is_axon_annotation``.

    Returns
    -------
    (skel1, nrn1, comp1, skel2, nrn2, comp2), where comp* are the label arrays
    returned by ``build_skel_compartments_from_mesh_annos``.
    """
    opts = dict(
        root_resolution=root_resolution,
        collapse_radius=collapse_radius,
        threshold_quality=threshold_quality,
    )
    skel1, nrn1, comp1 = _build_neuron_with_is_axon(root_id1, soma_location1, client, **opts)
    skel2, nrn2, comp2 = _build_neuron_with_is_axon(root_id2, soma_location2, client, **opts)
    return (skel1, nrn1, comp1, skel2, nrn2, comp2)


## 5. Find Close Axon↔Dendrite Vertex Pairs


In [5]:
def _find_pairs_within_threshold(A_nm, B_nm, thr_nm, use_kdtree=True):
    """Return (pairs, dists) for points within <= thr_nm (nm)."""
    if A_nm.size == 0 or B_nm.size == 0:
        return np.empty((0,2), dtype=int), np.empty((0,), dtype=float)

    if use_kdtree:
        try:
            tree = cKDTree(B_nm)
            neigh = tree.query_ball_point(A_nm, r=thr_nm)
            pairs, dists = [], []
            for i, js in enumerate(neigh):
                if not js: continue
                a = A_nm[i]
                bsel = B_nm[np.asarray(js, dtype=int)]
                ds = np.linalg.norm(bsel - a, axis=1)
                dists.extend(ds.tolist())
                gi = np.full(len(js), i, dtype=int)
                gj = np.asarray(js, dtype=int)
                pairs.append(np.stack([gi, gj], axis=1))
            if pairs:
                pairs = np.concatenate(pairs, axis=0)
                dists = np.asarray(dists, dtype=float)
            else:
                pairs = np.empty((0,2), dtype=int)
                dists = np.empty((0,), dtype=float)
            return pairs, dists
        except Exception:
            pass

    thr2 = float(thr_nm) ** 2
    A2 = np.sum(A_nm*A_nm, axis=1, keepdims=True)
    B2 = np.sum(B_nm*B_nm, axis=1, keepdims=True).T
    AB = A_nm @ B_nm.T
    dist2 = A2 + B2 - 2.0 * AB
    np.maximum(dist2, 0.0, out=dist2)
    ai, bj = np.where(dist2 <= thr2)
    if ai.size == 0:
        return np.empty((0,2), dtype=int), np.empty((0,), dtype=float)
    pairs = np.stack([ai, bj], axis=1)
    dists = np.sqrt(dist2[ai, bj], dtype=float)
    return pairs, dists


def compute_close_pairs_and_meta(
    root_id1, verts1_nm, comp1,
    root_id2, verts2_nm, comp2,
    threshold_nm=DIST_THRESHOLD_NM,
    root_resolution=ROOT_RESOLUTION,
    max_pairs_per_dir=MAX_PAIRS_PER_DIR,
    dendrite_labels=(3, 4),
):
    """
    Find close axon↔dendrite vertex pairs in both directions between two neurons.

    Direction 1→2 pairs axon vertices (label 2) of neuron 1 with dendrite vertices
    (labels in ``dendrite_labels``) of neuron 2. Direction 2→1 is the reverse.

    Parameters
    ----------
    root_id1, root_id2 : accepted for interface symmetry; not used here.
    verts1_nm, verts2_nm : per-vertex coordinates (nm), one row per skeleton vertex.
    comp1, comp2 : per-vertex compartment labels aligned with the vertex arrays.
    threshold_nm : inclusive distance threshold (nm).
    root_resolution : only recorded in ``meta``.
    max_pairs_per_dir : if a direction yields more pairs than this, an evenly
        spaced subsample (``np.linspace`` over the pair list) is kept.
        ``None`` disables the cap.
    dendrite_labels : labels treated as dendrite.

    Returns
    -------
    meta : dict with pair counts (after capping), threshold and resolution.
    global_pairs_12, global_pairs_21 : (K, 2) int arrays of global vertex indices,
        given as (axon vertex, dendrite vertex).
    dists_12, dists_21 : (K,) float distances (nm) aligned with the pair arrays.
    """
    axon_label = 2
    verts1_nm = np.asarray(verts1_nm, dtype=float)
    verts2_nm = np.asarray(verts2_nm, dtype=float)
    comp1 = np.asarray(comp1)
    comp2 = np.asarray(comp2)

    mask1_ax = comp1 == axon_label
    mask1_dn = np.isin(comp1, dendrite_labels)
    mask2_ax = comp2 == axon_label
    mask2_dn = np.isin(comp2, dendrite_labels)

    def _cap(p, d):
        """Evenly subsample pairs/dists down to max_pairs_per_dir (no-op when None or under cap)."""
        if max_pairs_per_dir is None or p.shape[0] <= max_pairs_per_dir:
            return p, d
        keep = np.linspace(0, p.shape[0] - 1, max_pairs_per_dir, dtype=int)
        return p[keep], d[keep]

    def _direction(verts_ax, mask_ax, verts_dn, mask_dn):
        """Close pairs for one direction, mapped from per-compartment rows to global vertex ids."""
        local, dists = _find_pairs_within_threshold(verts_ax[mask_ax], verts_dn[mask_dn], threshold_nm)
        local, dists = _cap(local, dists)
        ax_idx = np.flatnonzero(mask_ax)
        dn_idx = np.flatnonzero(mask_dn)
        global_pairs = np.column_stack([ax_idx[local[:, 0]], dn_idx[local[:, 1]]]).astype(int)
        return global_pairs, dists

    global_pairs_12, dists_12 = _direction(verts1_nm, mask1_ax, verts2_nm, mask2_dn)
    global_pairs_21, dists_21 = _direction(verts2_nm, mask2_ax, verts1_nm, mask1_dn)

    meta = {
        "n_pairs_1axon_2dend": int(global_pairs_12.shape[0]),
        "n_pairs_2axon_1dend": int(global_pairs_21.shape[0]),
        "threshold_nm": float(threshold_nm),
        "resolution_nm_per_pixel": tuple(float(x) for x in root_resolution),
    }
    return meta, global_pairs_12, global_pairs_21, dists_12, dists_21


## 6. Cluster Pairs into Regions Using Skeleton Connectivity (Both Sides)


In [6]:
def build_adj_from_edges(edges):
    """
    Build an undirected adjacency map from an edge list.

    Each edge (u, v) appends v to u's list and u to v's list. The result is a
    ``defaultdict(list)``; vertex ids are the integer values from ``edges``.
    """
    neighbors = defaultdict(list)
    edge_array = np.asarray(edges, dtype=int)
    for edge in edge_array:
        src, dst = edge
        neighbors[src].append(dst)
        neighbors[dst].append(src)
    return neighbors

def k_hop_neighborhoods(adj, seeds, k=1):
    """
    Map each unique seed vertex to the set of vertices within ``k`` hops of it.

    Each set always contains the seed itself; ``k <= 0`` yields just ``{seed}``.
    Seeds are coerced to ``int``; neighbours are taken from ``adj.get(node, [])``
    via breadth-first search.
    """
    unique_seeds = [int(s) for s in set(seeds)]
    hoods = {seed: {seed} for seed in unique_seeds}
    if k <= 0:
        return hoods

    for seed in unique_seeds:
        visited = hoods[seed]
        frontier = deque([(seed, 0)])
        while frontier:
            node, depth = frontier.popleft()
            if depth == k:
                continue  # do not expand past the hop limit
            for nbr in adj.get(node, []):
                if nbr in visited:
                    continue
                visited.add(nbr)
                frontier.append((nbr, depth + 1))
    return hoods

def dijkstra_subset_weighted(adj, coords_nm, start, allowed):
    """
    Single-source shortest paths from ``start`` restricted to the ``allowed`` vertices.

    Edge weights are Euclidean distances between ``coords_nm`` rows. Returns a
    dict {vertex: path length} for every reachable allowed vertex, or ``{}`` if
    ``start`` itself is not allowed.
    """
    permitted = set(allowed)
    if start not in permitted:
        return {}

    best = {start: 0.0}
    frontier = [(0.0, start)]
    while frontier:
        cost, node = heappop(frontier)
        if cost > best[node]:
            continue  # stale heap entry; a shorter path was already settled
        for nbr in adj.get(node, []):
            if nbr not in permitted:
                continue
            candidate = cost + float(np.linalg.norm(coords_nm[node] - coords_nm[nbr]))
            previous = best.get(nbr)
            if previous is None or candidate < previous:
                best[nbr] = candidate
                heappush(frontier, (candidate, nbr))
    return best

def induced_geodesic_diameter(adj, coords_nm, nodes_subset):
    """
    Longest shortest-path length within the subgraph induced by ``nodes_subset``.

    Runs ``dijkstra_subset_weighted`` from every node of the subset and keeps the
    farthest reachable node. If the induced subgraph is disconnected, the result
    is the largest diameter among its pieces.

    Returns (length_nm, (u, v)). With no nodes this is (0.0, (None, None)); with
    one node it is (0.0, (n, n)).
    """
    nodes = list(set(map(int, nodes_subset)))
    if not nodes:
        return 0.0, (None, None)
    if len(nodes) == 1:
        return 0.0, (nodes[0], nodes[0])

    allowed = set(nodes)
    best_len, best_pair = 0.0, (nodes[0], nodes[0])
    for src in nodes:
        reach = dijkstra_subset_weighted(adj, coords_nm, src, allowed)
        if not reach:
            continue
        far = max(reach, key=reach.get)
        if reach[far] > best_len:
            best_len, best_pair = reach[far], (src, far)
    return float(best_len), best_pair

def _pair_adjacency(pairs, A_khop, B_khop):
    """Ascending neighbour lists linking pair rows whose A and B vertices are both within hop tolerance."""
    a_list = pairs[:, 0].tolist()
    b_list = pairs[:, 1].tolist()
    # Index rows by their A vertex, so only candidate rows near a_i are examined
    # (instead of scanning all M rows per pair).
    rows_by_A = defaultdict(list)
    for row, a in enumerate(a_list):
        rows_by_A[a].append(row)

    pair_adj = []
    for i, (a_i, b_i) in enumerate(zip(a_list, b_list)):
        B_near = B_khop.get(b_i, {b_i})
        linked = set()
        for a in A_khop.get(a_i, {a_i}):
            for j in rows_by_A.get(int(a), ()):
                if j != i and b_list[j] in B_near:
                    linked.add(j)
        # k-hop neighbourhoods on an undirected graph are symmetric, so this
        # relation is symmetric; sorting reproduces the ascending neighbour
        # order of a pairwise i<j scan.
        pair_adj.append(sorted(linked))
    return pair_adj


def _connected_components(pair_adj):
    """Connected components of the pair graph via iterative DFS, as int arrays in discovery order."""
    M = len(pair_adj)
    seen = np.zeros(M, dtype=bool)
    comps = []
    for start in range(M):
        if seen[start]:
            continue
        seen[start] = True
        stack = [start]
        comp = [start]
        while stack:
            u = stack.pop()
            for v in pair_adj[u]:
                if not seen[v]:
                    seen[v] = True
                    stack.append(v)
                    comp.append(v)
        comps.append(np.array(comp, dtype=int))
    return comps


def _summarize_region(rows, pairs, adjA, adjB, coordsA_nm, coordsB_nm):
    """Build the region dict (vertices, induced geodesic diameters, endpoints, counts) for one component."""
    A_nodes = np.unique(pairs[rows, 0])
    B_nodes = np.unique(pairs[rows, 1])
    A_diam_nm, (Au, Av) = induced_geodesic_diameter(adjA, coordsA_nm, A_nodes)
    B_diam_nm, (Bu, Bv) = induced_geodesic_diameter(adjB, coordsB_nm, B_nodes)

    def _opt_int(x):
        return int(x) if x is not None else None

    return {
        "pair_rows": rows,
        "A_vertices": A_nodes,
        "B_vertices": B_nodes,
        "A_diameter_nm": float(A_diam_nm),
        "B_diameter_nm": float(B_diam_nm),
        "overlap_distance_nm": float(min(A_diam_nm, B_diam_nm)),
        "A_endpoints": (_opt_int(Au), _opt_int(Av)),
        "B_endpoints": (_opt_int(Bu), _opt_int(Bv)),
        "n_pairs": int(len(rows)),
        "n_A_vertices": int(len(A_nodes)),
        "n_B_vertices": int(len(B_nodes)),
    }


def cluster_pairs_by_skeleton_connectivity(
    pairs, edgesA, edgesB, coordsA_nm, coordsB_nm, hop_tol_A=1, hop_tol_B=1
):
    """
    Group close vertex pairs into contact regions using skeleton connectivity on both neurons.

    Two pairs (a_i, b_i) and (a_j, b_j) are linked when a_j is within ``hop_tol_A``
    hops of a_i on skeleton A AND b_j is within ``hop_tol_B`` hops of b_i on
    skeleton B. Regions are the connected components of this pair graph.

    Parameters
    ----------
    pairs : (M, 2) int array-like of (A vertex, B vertex) global indices.
    edgesA, edgesB : skeleton edge lists for neuron A / B.
    coordsA_nm, coordsB_nm : vertex coordinates (nm) for neuron A / B.
    hop_tol_A, hop_tol_B : hop tolerances on each skeleton.

    Returns
    -------
    list of region dicts, sorted by ``overlap_distance_nm`` (the smaller of the
    two induced geodesic diameters), largest first. An empty ``pairs`` returns ``[]``.
    """
    pairs = np.asarray(pairs, dtype=int)
    if pairs.size == 0:
        return []

    adjA = build_adj_from_edges(edgesA)
    adjB = build_adj_from_edges(edgesB)
    A_khop = k_hop_neighborhoods(adjA, pairs[:, 0], k=hop_tol_A)
    B_khop = k_hop_neighborhoods(adjB, pairs[:, 1], k=hop_tol_B)

    pair_adj = _pair_adjacency(pairs, A_khop, B_khop)
    regions = [
        _summarize_region(rows, pairs, adjA, adjB, coordsA_nm, coordsB_nm)
        for rows in _connected_components(pair_adj)
    ]
    regions.sort(key=lambda r: r["overlap_distance_nm"], reverse=True)
    return regions


## 7. Visualize Regions as Midpoint Lines (Single Neuroglancer Link)


In [7]:
def regions_midpoint_lines(regions, coordsA_nm, coordsB_nm, root_resolution, method='centroid'):
    """
    Compute one line per region, from its midpoint on neuron A to its midpoint on neuron B.

    Parameters
    ----------
    regions : region dicts as produced by ``cluster_pairs_by_skeleton_connectivity``.
    coordsA_nm, coordsB_nm : vertex coordinates (nm) for neuron A / B.
    root_resolution : nm/px used to convert midpoints to voxel coordinates.
    method : 'centroid' uses the mean of the region's vertices on each side.
        'endpoints' uses the midpoint of the region's diameter endpoints on each
        side, falling back to the centroid if any endpoint is None.

    Returns
    -------
    (A_px, B_px) : lists of coordinate tuples in voxel units, one per region.

    Raises
    ------
    ValueError
        If ``method`` is not 'centroid' or 'endpoints' (checked even when
        ``regions`` is empty).
    """
    if method not in ('centroid', 'endpoints'):
        raise ValueError("method must be 'centroid' or 'endpoints'")
    # Ensure arithmetic below is element-wise even if plain lists are passed.
    coordsA_nm = np.asarray(coordsA_nm, dtype=float)
    coordsB_nm = np.asarray(coordsB_nm, dtype=float)

    def _centroid(coords_nm, verts):
        return coords_nm[np.asarray(verts, dtype=int)].mean(axis=0)

    def _midpoints(r):
        if method == 'endpoints':
            Au, Av = r.get("A_endpoints", (None, None))
            Bu, Bv = r.get("B_endpoints", (None, None))
            if None not in (Au, Av, Bu, Bv):
                return (0.5 * (coordsA_nm[int(Au)] + coordsA_nm[int(Av)]),
                        0.5 * (coordsB_nm[int(Bu)] + coordsB_nm[int(Bv)]))
        return _centroid(coordsA_nm, r["A_vertices"]), _centroid(coordsB_nm, r["B_vertices"])

    A_px, B_px = [], []
    for r in regions:
        A_mid_nm, B_mid_nm = _midpoints(r)
        A_px.append(tuple(nm_to_px(A_mid_nm, root_resolution).tolist()))
        B_px.append(tuple(nm_to_px(B_mid_nm, root_resolution).tolist()))
    return A_px, B_px

def build_combined_midpoint_link(
    root_id1, root_id2,
    regions_12, regions_21,
    coords1_nm, coords2_nm,
    root_resolution,
    method='centroid',
    IMAGE_SOURCE_URL='precomputed://gs://zetta_jchen_mouse_cortex_001_alignment/img',
    SEG_SOURCE_URL='graphene://middleauth+https://cave.fanc-fly.com/segmentation/table/jchen_mouse_cortex/',
    client=None,
    color_12='gold',
    color_21='deepskyblue'
):
    """
    Build one Neuroglancer state showing both neurons plus region-midpoint lines for both directions.

    Layers are: the image source, a segmentation layer with both root IDs
    selected, and two line layers (1→2 in ``color_12``, 2→1 in ``color_21``).
    Each line joins a region's midpoint on one neuron to its midpoint on the
    other, computed with ``regions_midpoint_lines`` using ``method``.

    Returns a shortened link via ``to_link_shortener(client=client)`` when a client
    is given, otherwise ``to_url()``.
    """
    # For 1→2, neuron 1 supplies the A (axon) side and neuron 2 the B side; 2→1 swaps them.
    lines_12 = regions_midpoint_lines(regions_12, coords1_nm, coords2_nm, root_resolution, method)
    lines_21 = regions_midpoint_lines(regions_21, coords2_nm, coords1_nm, root_resolution, method)

    state = ViewerState()
    state = state.add_layer(ImageLayer(source=IMAGE_SOURCE_URL))
    seg_layer = (
        SegmentationLayer()
        .add_source(SEG_SOURCE_URL)
        .add_segments([int(root_id1), int(root_id2)])
    )
    state = state.add_layer(seg_layer)

    line_specs = (
        (f"{root_id1} → {root_id2} region midpoints ({method})", lines_12, color_12),
        (f"{root_id2} → {root_id1} region midpoints ({method})", lines_21, color_21),
    )
    for layer_name, (pts_a, pts_b), layer_color in line_specs:
        state = state.add_lines(
            name=layer_name, point_a_column=pts_a, point_b_column=pts_b, color=layer_color
        )

    if client is None:
        return state.to_url()
    return state.to_link_shortener(client=client)


## 8. Run End-to-End


In [8]:
# Build neurons + compartments
skel1, nrn1, comp1, skel2, nrn2, comp2 = build_two_neurons_with_is_axon(
    ROOT_ID1, SOMA1, ROOT_ID2, SOMA2, client,
    root_resolution=ROOT_RESOLUTION,
    collapse_radius=COLLAPSE_RADIUS,
    threshold_quality=AXON_QUALITY_THRESH,
)

# Find close pairs
meta, gp12, gp21, d12, d21 = compute_close_pairs_and_meta(
    ROOT_ID1, skel1.vertices, comp1,
    ROOT_ID2, skel2.vertices, comp2,
    threshold_nm=DIST_THRESHOLD_NM,
    root_resolution=ROOT_RESOLUTION,
    max_pairs_per_dir=MAX_PAIRS_PER_DIR
)
print('Pairs 1→2:', meta['n_pairs_1axon_2dend'], '| Pairs 2→1:', meta['n_pairs_2axon_1dend'])

# Cluster into regions using skeleton connectivity on both sides
regions_12 = cluster_pairs_by_skeleton_connectivity(
    gp12, skel1.edges, skel2.edges, skel1.vertices, skel2.vertices, hop_tol_A=1, hop_tol_B=1
)
regions_21 = cluster_pairs_by_skeleton_connectivity(
    gp21, skel2.edges, skel1.edges, skel2.vertices, skel1.vertices, hop_tol_A=1, hop_tol_B=1
)
print(f"Regions 1→2: {len(regions_12)} | Regions 2→1: {len(regions_21)}")




Pairs 1→2: 28 | Pairs 2→1: 13
Regions 1→2: 8 | Regions 2→1: 2


## 9. Build a Single Neuroglancer Link for Both Directions


In [9]:
# Optional: short URLs if you pass a CAVEclient (already available as `client`)
# With client=None the function returns a full (long) state URL instead.
link_both = build_combined_midpoint_link(
    ROOT_ID1, ROOT_ID2,
    regions_12, regions_21,
    coords1_nm=skel1.vertices, coords2_nm=skel2.vertices,
    root_resolution=ROOT_RESOLUTION,
    method='centroid',  # or 'endpoints' to use each region's diameter endpoints
    client=client
)
print('Neuroglancer link (both directions):')
print(link_both)


0
Neuroglancer link (both directions):
https://spelunker.cave-explorer.org/#!middleauth+https://global.daf-apis.com/nglstate/api/v1/5399438355857408


## 10. (Optional) Save a Small Artifact Bundle


In [None]:
# Saved to the current working directory; np.savez writes an uncompressed .npz archive.
OUT_PATH = "close_contacts_artifacts_demo.npz"
# Regions are not saved. They can be recomputed from the saved pairs, edges and
# vertices with cluster_pairs_by_skeleton_connectivity.
np.savez(
    OUT_PATH,
    root_id1=np.array([ROOT_ID1], dtype=np.int64),
    root_id2=np.array([ROOT_ID2], dtype=np.int64),
    root_resolution=np.array(ROOT_RESOLUTION, dtype=float),
    skel1_vertices=skel1.vertices, skel2_vertices=skel2.vertices,
    skel1_edges=skel1.edges, skel2_edges=skel2.edges,
    global_pairs_1axon_2dend=gp12, global_pairs_2axon_1dend=gp21,
    dists_1axon_2dend_nm=d12, dists_2axon_1dend_nm=d21
)
print(f"[saved] {OUT_PATH}")
