# Prospective Synapses

This notebook analyzes all the contacts (already found by a script such as find-contacts.py)
within a small volume and, for each one, constructs a line segment representing a possible synapse.
These segments are written as line annotations in Neuroglancer (NG) state format, so that a domain expert
can review each one in Neuroglancer and mark it as Synapse, Not Synapse, or Unclear (using the Description field).

See the ContactPathAnalysis notebook for a more detailed breakdown of the steps in constructing a single line bisecting a given contact.

In [None]:
# Imports
from zetta_utils.layer.volumetric.cloudvol import build_cv_layer
from zetta_utils.geometry import Vec3D
import cc3d
from collections import deque
import numpy as np
import zetta_utils.tensor_ops.convert as convert
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from skimage.morphology import skeletonize
import time

In [None]:
# Constants
# Segmentation layer paths (gs:// URLs), opened below with build_cv_layer.
# NOTE(review): "concact" looks like a misspelling of "contact"; confirm against the
# actual bucket contents before changing this path.
CONTACT_SEG_PATH = "gs://tmp_2w/joe/concact-20240816"
CELL_SEG_PATH = "gs://dkronauer-ant-001-kisuk/test/240507-finetune-z1400-3400/seg"

# Region of interest: voxel size in nm, and the [start, end) bounding box in voxels.
RESOLUTION = (16, 16, 42)
BOUNDS_START = (13995, 10633, 3063)
BOUNDS_END = (14123, 10761, 3103)

# Cell treated as presynaptic when orienting each putative synapse line.
PRESYN_CELL_ID = 74170031222807106

In [None]:
# Load the contact segmentation layer and cut out our bounding box.
cvl = build_cv_layer(path = CONTACT_SEG_PATH)
# One [start, end) slice per axis, in the same order as BOUNDS_START / BOUNDS_END.
bbox_slices = tuple(slice(lo, hi) for lo, hi in zip(BOUNDS_START, BOUNDS_END))
# Index as [resolution, axis0, axis1, axis2]; the result's leading axis is the channel,
# of which we keep only channel 0.
data = cvl[(RESOLUTION, *bbox_slices)][0]
print(f'Contacts: Loaded {data.shape} image of type {data.dtype}, with {len(np.unique(data))} unique values')

In [None]:
# And load the cell segmentation layer over the same bounding box.
cell_vol = build_cv_layer(path = CELL_SEG_PATH)
cell_data = cell_vol[RESOLUTION, BOUNDS_START[0]:BOUNDS_END[0], BOUNDS_START[1]:BOUNDS_END[1], BOUNDS_START[2]:BOUNDS_END[2]]
cell_data = cell_data[0] # (ignore dimension 0, channel)
# Report the cell layer's own dtype (previously this printed the contact layer's dtype).
print(f'Cells: Loaded {cell_data.shape} image of type {cell_data.dtype}, with {len(np.unique(cell_data))} unique values')

In [None]:
# Define a function to find the Z value that contains the most of a given contact ID.
def z_of_max_count(array: np.ndarray, value: int) -> int:
    """
    Return the index along axis 2 of the slice containing the most elements equal to `value`.

    `array` is summed over axes 0 and 1, so it is expected to be 3D with z as the last axis.
    Ties resolve to the lowest index (np.argmax semantics); if `value` does not occur at all,
    0 is returned.  The result is a plain Python int, matching the annotation, so it can be
    used safely in tuples and printed output.
    """
    counts = np.sum(array == value, axis=(0, 1))
    return int(np.argmax(counts))

In [None]:
def find_endpoints(skeleton):
    """
    Find the endpoints of a 2D skeleton image.

    An endpoint is a pixel equal to 1 whose 8-connected neighbors sum to exactly 1.
    Every pixel is examined, including those on the image border; positions outside the
    image count as background.

    Returns a list of (row, col) tuples of Python ints, in row-major order.  If no endpoint
    exists (e.g. the skeleton is a closed loop), returns two consecutive entries from the
    middle of the row-major list of nonzero pixels as arbitrary starting points, or None
    if the skeleton has fewer than 3 nonzero pixels.
    """
    skel = np.asarray(skeleton)
    rows, cols = skel.shape
    # Pad with a zero border so border pixels get a full 3x3 neighborhood; cast to int so
    # summing counts pixels (adding NumPy bool arrays would act as a logical OR).
    padded = np.pad(skel.astype(np.int64), 1)
    # Sum of the 8 neighbors of every pixel, built from shifted views of the padded image.
    neighbor_sum = sum(
        padded[1 + di:1 + di + rows, 1 + dj:1 + dj + cols]
        for di in (-1, 0, 1)
        for dj in (-1, 0, 1)
        if (di, dj) != (0, 0)
    )
    endpoint_mask = (skel == 1) & (neighbor_sum == 1)
    endpoints = [(int(r), int(c)) for r, c in np.argwhere(endpoint_mask)]
    if endpoints:
        return endpoints

    # No endpoints?  Presumably a loop -- return two arbitrary adjacent-in-order points.
    skel_points = np.argwhere(skel)
    if len(skel_points) < 3:
        return None
    mid = len(skel_points) // 2
    return [tuple(int(v) for v in skel_points[mid]),
            tuple(int(v) for v in skel_points[mid + 1])]

def trace_path(skeleton, start):
    """
    Walk the skeleton breadth-first from `start`, returning visited (row, col) points in order.

    From each point, the 3x3 neighborhood is scanned row by row; within each row the first
    unvisited skeleton pixel (value 1) is queued.  Because the `break` exits only the inner
    (column) loop, up to one neighbor per row can be queued.  For a thin, unbranched skeleton
    this yields the pixels in order along the curve.

    Neighbor positions are clipped to the image, so paths touching the border neither wrap
    around to the opposite edge nor raise IndexError.
    """
    n_rows, n_cols = skeleton.shape[:2]
    path = []
    queue = deque([start])
    visited = set()

    while queue:
        current = queue.popleft()
        if current in visited:
            continue
        visited.add(current)
        path.append(current)

        i, j = current
        # Look at the (in-bounds) 8 neighbors
        for ni in range(max(i - 1, 0), min(i + 2, n_rows)):
            for nj in range(max(j - 1, 0), min(j + 2, n_cols)):
                if (ni, nj) != (i, j) and skeleton[ni, nj] == 1 and (ni, nj) not in visited:
                    queue.append((ni, nj))
                    break # Found the next step in this row of the neighborhood
    return path

def estimate_slope(path, midpoint, window=3):
    """
    Estimate the local direction of `path` at `midpoint`.

    Takes the points `window` steps before and after `midpoint` in `path` (clamped to the
    ends of the list) and returns the vector between them as (dx, dy), where dx/dy are the
    differences in the points' first/second coordinates.  The vector is normalized to unit
    length unless it is zero, in which case it is returned unchanged.
    """
    center = path.index(midpoint)
    last = len(path) - 1
    lo = path[max(0, center - window)]
    hi = path[min(last, center + window)]

    vx, vy = hi[0] - lo[0], hi[1] - lo[1]
    norm = np.hypot(vx, vy)
    if norm == 0:
        # Degenerate segment: nothing to normalize.
        return vx, vy
    return vx / norm, vy / norm

def estimate_avg_slope(path, midpoint):
    """
    Average estimate_slope's direction at `midpoint` over windows 1, 2, 3 and 4.

    Returns the component-wise mean (dx, dy) of the four estimates (each unit length, or
    zero for a degenerate window).  The mean is not re-normalized, so it can be shorter
    than unit length when the per-window directions disagree.
    """
    estimates = [estimate_slope(path, midpoint, w) for w in range(1, 5)]
    n = len(estimates)
    return (sum(e[0] for e in estimates) / n,
            sum(e[1] for e in estimates) / n)

def calc_bisection(data2D, contact_id, half_length=3):
    """
    Compute a short line segment crossing the given contact at the middle of its skeleton.

    The pixels of `data2D` equal to `contact_id` are skeletonized, the skeleton is traced
    from its first endpoint, and a segment extending `half_length` pixels to each side is
    placed perpendicular to the averaged direction at the middle of the traced path.

    Returns ((r0, c0), (r1, c1), '') as Python-int indices into `data2D`, with both points
    clamped to lie inside the image so callers can index with them safely.  Returns None
    if no usable skeleton endpoints were found.
    """
    skeleton = skeletonize(data2D == contact_id)
    endpoints = find_endpoints(skeleton)
    if not endpoints:
        return None
    path = trace_path(skeleton, endpoints[0])
    midpoint = path[len(path) // 2]
    dx, dy = estimate_avg_slope(path, midpoint)
    n_rows, n_cols = data2D.shape[:2]

    def to_pixel(value, size):
        # int() guarantees a Python int index whatever scalar type round() returns;
        # clamping keeps contacts near the image edge from producing out-of-range points.
        return min(max(int(round(value)), 0), size - 1)

    # (-dy, dx) is perpendicular to the path direction (dx, dy).
    point_a = (to_pixel(midpoint[0] - dy * half_length, n_rows),
               to_pixel(midpoint[1] + dx * half_length, n_cols))
    point_b = (to_pixel(midpoint[0] + dy * half_length, n_rows),
               to_pixel(midpoint[1] - dx * half_length, n_cols))
    return (point_a, point_b, '')

In [None]:
# Define a function to return the 3D line for a given contact ID.
def compute_line(data3D: np.ndarray, cell_data: np.ndarray, contact_id: int):
    """
    Compute a short 3D line segment bisecting the given contact.

    Chooses the z slice (axis 2) of `data3D` containing the most voxels of `contact_id`,
    bisects the contact within that 2D slice with calc_bisection, and then orients the
    line using `cell_data` (same shape as `data3D`): if the first endpoint lies in
    PRESYN_CELL_ID, the endpoints are swapped so the presynaptic cell is at the second
    point.  If neither endpoint is in PRESYN_CELL_ID, the order is left as computed.

    Return ((x0, y0, z), (x1, y1, z), description) in voxel indices of `data3D`, or None
    if no line could be computed.
    """
    # Plain int so the returned tuples (and any JSON printed from them) hold Python numbers.
    best_z = int(z_of_max_count(data3D, contact_id))
    line2D = calc_bisection(data3D[:, :, best_z], contact_id)
    if not line2D:
        return None
    start, end, desc = line2D

    cell_img = cell_data[:, :, best_z]
    n_rows, n_cols = cell_img.shape[:2]
    # Only consult the cell image when the point is inside it; negative indices would
    # otherwise silently wrap to the opposite edge.
    start_inside = 0 <= start[0] < n_rows and 0 <= start[1] < n_cols
    if start_inside and cell_img[start] == PRESYN_CELL_ID:
        start, end = end, start
    return (start + (best_z,), end + (best_z,), desc)

In [None]:
# Now let's find the 3D lines for all our IDs.  Store them in a list,
# where each element is (id, start, end, description)
lines = []
for contact_id in np.unique(data):
    if contact_id == 0:
        continue  # ID 0 is skipped (presumably unlabeled background)
    line = compute_line(data, cell_data, contact_id)
    if line:
        lines.append((contact_id,) + line)
    else:
        print(f"Couldn't find valid line for contact {contact_id}")
# Guard the sample print: lines[0] would raise IndexError if nothing was found.
if lines:
    print(f'Found {len(lines)} putative synapses, such as:')
    print(lines[0])
else:
    print('Found 0 putative synapses.')

In [None]:
# Print them in JSON format, suitable for pasting into Neuroglancer.
import json

z_offset = 0.5  # (needed because NG actually displays 0.5 units off of where it claims)


def _to_global(pos):
    """Convert a local (x, y, z) voxel position to global NG coordinates as plain Python numbers."""
    return [int(pos[0]) + BOUNDS_START[0],
            int(pos[1]) + BOUNDS_START[1],
            float(pos[2]) + BOUNDS_START[2] + z_offset]


print('"annotations": [')
for i, (contact_id, pos_A, pos_B, desc) in enumerate(lines):
    print('{')
    # json.dumps guarantees valid JSON: numbers rather than NumPy scalar reprs (which
    # NumPy >= 2 prints as e.g. np.float64(...)), and properly escaped strings.
    print(f'"pointA": {json.dumps(_to_global(pos_A))},')
    print(f'"pointB": {json.dumps(_to_global(pos_B))},')
    print('"type": "line",')
    if desc:
        print(f'"description": {json.dumps(desc)},')
    print(f'"id": {json.dumps(str(contact_id))}')
    print('},' if i < len(lines) - 1 else '}')
print("],")