In [1]:
import json
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import tqdm
import os
import sys
from skimage import morphology, measure
import tifffile
from matplotlib.backends.backend_agg import FigureCanvasAgg


In [2]:
def orient_all(last_pt, spline_dict, constrain_frame=None, constrain_nose=None, crop_n=350):
    """
    Orient every spline so that it starts at the end closest to the previous start.

    Frames are visited in order 0 .. len(spline_dict) - 1.  For each non-empty
    spline, the end that is nearer to `last_pt` becomes the first point (the
    spline is reversed if necessary), the spline is truncated to `crop_n`
    points, and `last_pt` is updated to the new first point.

    Parameters
    ----------
    last_pt : array-like
        Reference point used to orient the first non-empty frame.
    spline_dict : dict
        Frame index (0, 1, ...) -> sequence of spline points.
    constrain_frame : int, optional
        Frame at which the reference point is reset to `constrain_nose`
        before that frame is oriented.
    constrain_nose : array-like, optional
        Reference point used at `constrain_frame`.
    crop_n : int
        Maximum number of points kept per oriented spline (default 350, the
        value that used to be hard-coded; same name as in `orient_all_auto`).

    Returns
    -------
    dict
        Frame index -> oriented, truncated spline.  Empty splines are passed
        through unchanged.
    """
    out_dict = {}
    for i in range(len(spline_dict)):
        data = spline_dict[i]
        data_arr = np.array(data)
        # no spline sometimes if tracking failed and/or segmentation failed
        if len(data_arr) == 0:
            out_dict[i] = data
            continue
        # apply constraint if specified
        if constrain_frame is not None and i == constrain_frame:
            last_pt = np.array(constrain_nose)
        # the end nearer to the previous first point is taken as the new first point
        dist_unflipped = np.linalg.norm(data_arr[0] - last_pt)
        dist_flip = np.linalg.norm(data_arr[-1] - last_pt)
        if dist_flip < dist_unflipped:
            data = data[::-1]
        # truncate before updating last_pt so the reference is the first kept point
        data = data[:crop_n]
        last_pt = np.array(data[0])
        out_dict[i] = data
    return out_dict

def visualize_frame(seg, nodes, spline_dilation=4):
    """
    Return `seg` with the spline pixels removed.

    Parameters
    ----------
    seg : ndarray (H, W)
        Segmentation mask; any non-zero value counts as foreground.
    nodes : iterable of (row, col)
        Spline points.  Coordinates are rounded to the nearest pixel; points
        that fall outside the image are ignored.
    spline_dilation : int
        Currently unused: the dilation step is disabled, so each node removes
        exactly one pixel.  Kept so existing callers keep working.

    Returns
    -------
    ndarray of bool (H, W)
        True where `seg` is foreground and no spline node lies.
    """
    # np.bool was removed in NumPy 1.24; the builtin bool is the supported spelling.
    out = np.zeros(seg.shape, dtype=bool)
    n_rows, n_cols = seg.shape[0], seg.shape[1]
    for node in nodes:
        if len(node) < 2:
            continue
        row, col = int(round(float(node[0]))), int(round(float(node[1])))
        # Index with two scalars: out[[row, col]] would select two whole rows.
        if 0 <= row < n_rows and 0 <= col < n_cols:
            out[row, col] = True
    # out = morphology.isotropic_dilation(out, spline_dilation)
    return np.logical_and(seg, np.logical_not(out))


In [3]:

# Recording folder; it is expected to contain label.tif (segmentation stack)
# and spline.json (per-frame spline points).
pth = r'C:\Users\munib\POSTDOC\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm002'
label = tifffile.imread(os.path.join(pth, 'label.tif'))

# JSON object keys are always strings, so convert them back to integer frame numbers.
with open(os.path.join(pth, 'spline.json'), 'r') as f:
    spline_dict = {int(frame_key): frame_pts for frame_key, frame_pts in json.load(f).items()}

In [None]:
# Dimensions of the label stack; the author's note gives the axis order as (frames, height, width).
label.shape # FHW

In [None]:
# spline_dict keys are frame numbers (0, 1, ...).
# Each value is a nested list of points, so spline_dict[0][0] is frame 0's first point,
# a 2-element coordinate pair.
# NOTE(review): the plotting and nose-detection cells below use these pairs as
# (row, col) / (y, x), not (x, y) — confirm against the code that writes spline.json.
spline_dict[0][0]

In [None]:
# Number of spline points in every frame, in frame order (0 for frames without a spline).
n_points = [len(spline_dict[frame_idx]) for frame_idx in range(len(spline_dict))]

In [None]:
# Longest spline over all frames; used below as the first dimension of the
# (spline point, frame) widths array.
max_n = max(n_points)

In [None]:
# for each frame, determine the width of the skeleton, defined by label, by using spline
%matplotlib qt
plt.close('all')

frame = 0
label_frame = label[frame] # HW
spline_frame = spline_dict[frame]

# plot the label and the spline for this frame
plt.imshow(label_frame, cmap='gray')
spline_frame_arr = np.array(spline_frame).reshape(-1, 2) # Nx2
# need to swap x and y for plotting since spline is in xy and label is in hw
plt.plot(spline_frame_arr[:, 1], spline_frame_arr[:, 0], 'r-')
plt.show()

# for each spline point, determine the width of the skeleton by:
# 1. drawing a line perpendicular to the spline at that point
# 2. finding the intersection of that line with the label
# 3. calculating the distance between the two intersection points, which gives the width of the skeleton at that point
# 4. plotting labels and the perpendicular line for visualization
widths = []
for i in range(len(spline_frame)):
# for i in range(500,450, -1):
    point = spline_frame[i]
    if i == 0:
        next_point = spline_frame[i + 1]
        tangent = np.array(next_point) - np.array(point)
    elif i == len(spline_frame) - 1:
        prev_point = spline_frame[i - 1]
        tangent = np.array(point) - np.array(prev_point)
    else:
        next_point = spline_frame[i + 1]
        prev_point = spline_frame[i - 1]
        tangent = np.array(next_point) - np.array(prev_point)
    # get perpendicular vector
    perp_vector = np.array([-tangent[1], tangent[0]])
    perp_vector = perp_vector / np.linalg.norm(perp_vector) # normalize
    # define line endpoints for visualization
    line_length = 50 # length of the perpendicular line for visualization
    line_start = np.array(point) - perp_vector * line_length
    line_end = np.array(point) + perp_vector * line_length

    # determine width of perpendicular line by finding intersection with label
    # sample points along the perpendicular line and check if they intersect with the label
    num_samples = 100
    # extend line in both directions until it reaches the label boundary or image boundary
    line_points = np.linspace(line_start, line_end, num_samples)
    line_points_extended = np.concatenate([line_points, np.linspace(line_start - perp_vector * line_length, line_start, num_samples)])
    line_points_extended = np.concatenate([line_points_extended, np.linspace(line_end, line_end + perp_vector * line_length, num_samples)])
    line_points_hw = line_points_extended[:, ::-1] # convert xy to hw for indexing
    line_points_hw = line_points_hw.astype(int)
    # filter out points that are outside the image boundaries    
    line_points_hw = line_points_hw[(line_points_hw[:, 0] >= 0) & (line_points_hw[:, 0] < label_frame.shape[0]) & (line_points_hw[:, 1] >= 0) & (line_points_hw[:, 1] < label_frame.shape[1])]
    # check which points intersect with the label
    intersecting_points = line_points_hw[label_frame[line_points_hw[:, 0], line_points_hw[:, 1]] > 0]
    if len(intersecting_points) >= 2:
        # calculate width as distance between first and last intersecting points
        width = np.linalg.norm(intersecting_points[0] - intersecting_points[-1])
    else:
        width = 0
    widths.append(width)
    
    

    # plot the label and the perpendicular line for visualization
    plt.imshow(label_frame, cmap='gray')
    # plot every 10th point's perpendicular line for visualization
    if i % 20 == 0:
        # plt.plot([line_start[1], line_end[1]], [line_start[0], line_end[0]], 'g-', linewidth=1)
        # plot line_points_hw for visualization
        plt.plot(line_points_hw[:, 0], line_points_hw[:, 1], 'g-', linewidth=1)
    if i==0:
        plt.plot(spline_frame_arr[:, 1], spline_frame_arr[:, 0], 'r-')
    plt.show()



plt.figure()
plt.plot(widths)
plt.title('Skeleton Width Along Spline')
plt.show()




In [None]:
def compute_spline_widths(mask, spline_pts, half_length=100.0, n_samples=401):
    """
    Width of `mask` across the spline at every spline point of one frame.

    At each point a line perpendicular to the local tangent is sampled out to
    `half_length` px on both sides.  The width is the distance between the two
    ends of the contiguous foreground run that passes through (or lies closest
    to) the spline point.

    Parameters
    ----------
    mask : ndarray (H, W)
        Label frame; non-zero pixels count as foreground.
    spline_pts : sequence of (row, col)
        Spline points.  NOTE(review): (row, col) order is how the plotting and
        nose-detection cells use them — confirm against the spline writer.
    half_length : float
        Reach (px) of the perpendicular on each side of the spline.
    n_samples : int
        Number of samples along the perpendicular.

    Returns
    -------
    ndarray (N,)
        One width per spline point; 0 where the line never touches the mask
        or no tangent can be formed (fewer than 2 points, repeated points).
    """
    pts = np.asarray(spline_pts, dtype=float).reshape(-1, 2)
    n_pts = len(pts)
    frame_widths = np.zeros(n_pts)
    if n_pts < 2:
        return frame_widths

    n_rows, n_cols = mask.shape
    # samples are ordered from one side of the spline to the other
    offsets = np.linspace(-half_length, half_length, n_samples)
    centre = n_samples // 2
    for i in range(n_pts):
        # one-sided difference at the two ends, central difference elsewhere
        tangent = pts[min(i + 1, n_pts - 1)] - pts[max(i - 1, 0)]
        norm = np.linalg.norm(tangent)
        if not np.isfinite(norm) or norm == 0:
            continue  # repeated points give no direction; avoid dividing by zero
        normal = np.array([-tangent[1], tangent[0]]) / norm

        line_rc = np.rint(pts[i] + offsets[:, None] * normal).astype(int)
        in_image = ((line_rc[:, 0] >= 0) & (line_rc[:, 0] < n_rows)
                    & (line_rc[:, 1] >= 0) & (line_rc[:, 1] < n_cols))
        on_mask = np.zeros(n_samples, dtype=bool)
        on_mask[in_image] = mask[line_rc[in_image, 0], line_rc[in_image, 1]] > 0

        hit_idx = np.flatnonzero(on_mask)
        if hit_idx.size == 0:
            continue
        # grow the run outwards from the foreground sample nearest the spline point
        lo = hi = hit_idx[np.argmin(np.abs(hit_idx - centre))]
        while lo > 0 and on_mask[lo - 1]:
            lo -= 1
        while hi < n_samples - 1 and on_mask[hi + 1]:
            hi += 1
        frame_widths[i] = np.linalg.norm(line_rc[hi] - line_rc[lo])
    return frame_widths


# rows = spline point index, columns = frame; shorter splines leave trailing zeros
widths = np.zeros((max_n, len(spline_dict)))

for frame in tqdm.tqdm(range(len(spline_dict)), desc="Measuring widths"):
    label_frame = label[frame] # HW
    spline_frame = spline_dict[frame]
    frame_widths = compute_spline_widths(label_frame, spline_frame)
    widths[:len(frame_widths), frame] = frame_widths

In [None]:
%matplotlib qt
plt.figure()
plt.imshow(widths, aspect='auto', cmap='viridis')
plt.colorbar(label='Skeleton Width')
plt.title('Skeleton Width Along Spline Across Frames')
plt.xlabel('Frame')
plt.ylabel('Spline Point Index')
plt.show()

In [None]:
# Smooth each spline point's width over time with a moving average across frames.
window_size = 20
kernel = np.ones(window_size) / window_size
# mode='same' zero-pads beyond the first and last frame, which drags the values near
# both ends towards 0.  Dividing by the share of the kernel that actually overlaps
# the data turns those edge values back into true local averages.
edge_gain = np.convolve(np.ones(widths.shape[1]), kernel, mode='same')
widths_smoothed = np.empty_like(widths)
for i in range(widths.shape[0]):
    widths_smoothed[i, :] = np.convolve(widths[i, :], kernel, mode='same') / edge_gain

plt.figure()
plt.imshow(widths_smoothed, aspect='auto', cmap='viridis')
plt.colorbar(label='Skeleton Width')
plt.title('Skeleton Width Along Spline Across Frames')
plt.xlabel('Frame')
plt.ylabel('Spline Point Index')
plt.show()

# average over frames (axis=1) gives one mean width per spline point
plt.figure()
plt.plot(np.mean(widths_smoothed, axis=1))
plt.title('Average Skeleton Width Along Spline')
plt.xlabel('Spline Point Index')
plt.ylabel('Average Width')
plt.show()

In [None]:
# One line per spline point, plotted against frame number.
# plt.plot draws each column of a 2-D array against the row index, so the
# (spline point, frame) array is transposed to put frames on the x axis,
# as the axis label says.
plt.figure()
plt.plot(widths_smoothed.T)
plt.title('Smoothed Skeleton Width Along Spline Across Frames')
plt.xlabel('Frame')
plt.ylabel('Skeleton Width')
plt.show()

In [None]:
# Shape of one frame's width profile.  Indexes widths_smoothed directly because
# widths_frame is only assigned in the next cell (NameError on a fresh Run All).
widths_smoothed[:, frame].shape

In [None]:
# Example frame: overlay the spline on the label, with each point coloured by its smoothed width.
frame = 0
label_frame = label[frame] # HW
spline_frame = spline_dict[frame]
spline_frame_arr = np.array(spline_frame).reshape(-1, 2) # Nx2
widths_frame = widths_smoothed[:, frame]
# widths_frame has max_n entries; keep only as many as this frame has spline points
n_frame_pts = len(spline_frame_arr[:, 1])

fig, ax = plt.subplots()
ax.imshow(label_frame, cmap='gray')
width_dots = ax.scatter(spline_frame_arr[:, 1], spline_frame_arr[:, 0],
                        c=widths_frame[0:n_frame_pts], cmap='viridis')
fig.colorbar(width_dots, ax=ax, label='Skeleton Width')
ax.set_title(f'Skeleton Width Along Spline for Frame {frame}')
plt.show()

## Automated Nose Detection

**Goal**: Automatically determine which end of the spline is the nose vs the body-exit point, without manual annotation.

**Key constraint**: Only the head of the worm is in the frame. The body extends out of the field of view. So:
- The **nose** endpoint is interior to the image, and the mask terminates naturally around it.
- The **body-exit** endpoint is near the image border, and the mask is abruptly truncated by the edge.

### Approach — three complementary features

1. **Distance to image border**: The body-exit endpoint is closer to the edge of the image. The nose is more interior. Simple and strong.

2. **Mask border contact**: Count how many mask pixels touch the image border (row 0, row H-1, col 0, col W-1) near each endpoint. The body-exit end has many; the nose has few or none.

3. **Angular coverage**: Sample rays in all directions from each endpoint and measure what fraction hit mask pixels within a short distance. The nose has high angular coverage (mask is found in nearly every direction), while the body-exit endpoint only has mask on the body side, so its coverage is lower.

Combined into a signed score per frame, then blended with temporal continuity in a two-pass algorithm.

In [4]:
from scipy.ndimage import gaussian_filter1d
from skimage.measure import find_contours


def border_distance_score(endpoint_yx, img_shape, margin=20):
    """
    Normalized distance from an endpoint to the closest image edge.

    The pixel distance to the nearest of the four edges is divided by half of
    the shorter image side and clipped to [0, 1], so 0 means "on (or outside)
    the border" and 1 means "at least half the short side away from any edge".

    `margin` is accepted but not used by the computation.
    """
    row, col = endpoint_yx[0], endpoint_yx[1]
    n_rows, n_cols = img_shape
    gaps = (row, col, n_rows - 1 - row, n_cols - 1 - col)
    half_short_side = min(n_rows, n_cols) / 2.0
    normalized = min(gaps) / half_short_side
    return float(np.clip(normalized, 0, 1))


def mask_border_contact(endpoint_yx, mask, radius=40, border_width=1):
    """
    Fraction of the mask pixels within `radius` of the endpoint that lie on
    the image border.

    Body-exit endpoint -> high contact.  Nose -> low/zero contact.

    Parameters
    ----------
    endpoint_yx : (row, col)
        Endpoint position; rounded to the nearest pixel.
    mask : ndarray (H, W)
        Segmentation mask; non-zero pixels count as foreground.
    radius : int
        Radius (px) of the disc around the endpoint that is inspected.
    border_width : int
        Thickness (px) of the band along each image edge that counts as
        "border".  The default of 1 is the original behaviour (row 0,
        row H-1, col 0, col W-1 only).  A larger value also catches masks
        that stop a pixel or two short of the image edge, which otherwise
        score 0 even for an endpoint sitting on the border.

    Returns
    -------
    float in [0, 1]
        0.0 when there are no mask pixels inside the disc.
    """
    r, c = int(round(endpoint_yx[0])), int(round(endpoint_yx[1]))
    H, W = mask.shape
    # bounding box of the disc, clamped to the image
    r0, r1 = max(0, r - radius), min(H, r + radius + 1)
    c0, c1 = max(0, c - radius), min(W, c + radius + 1)
    if r0 >= r1 or c0 >= c1:
        # the disc lies entirely outside the image
        return 0.0
    patch = mask[r0:r1, c0:c1].astype(bool)
    rr, cc = np.mgrid[r0:r1, c0:c1]
    in_circle = ((rr - r)**2 + (cc - c)**2) <= radius**2
    mask_in_circle = patch & in_circle
    n_mask = mask_in_circle.sum()
    if n_mask == 0:
        return 0.0
    # which of those mask pixels are within `border_width` px of an image edge?
    on_border = ((rr < border_width) | (rr >= H - border_width)
                 | (cc < border_width) | (cc >= W - border_width))
    n_border = (mask_in_circle & on_border).sum()
    return float(n_border) / float(n_mask)


def angular_coverage(endpoint_yx, mask, radius=20, n_rays=36):
    """
    Share of ray directions around the endpoint that reach mask within `radius`.

    `n_rays` evenly spaced rays are cast from the endpoint.  Each ray is walked
    in int(radius) steps from distance 1 to `radius`; it counts as a hit if a
    mask pixel is met before the ray leaves the image, and as a miss if it
    leaves the image first or never meets the mask.

    Nose (interior, mask wraps around) -> high coverage.
    Body-exit (truncated by border) -> low coverage.
    Returns a value in [0, 1].
    """
    y0, x0 = float(endpoint_yx[0]), float(endpoint_yx[1])
    n_rows, n_cols = mask.shape
    step_lengths = np.linspace(1, radius, int(radius))

    def ray_reaches_mask(theta):
        # direction convention: cos(theta) moves along rows, sin(theta) along columns
        dy, dx = np.cos(theta), np.sin(theta)
        for t in step_lengths:
            row = int(round(y0 + t * dy))
            col = int(round(x0 + t * dx))
            if not (0 <= row < n_rows and 0 <= col < n_cols):
                # left the image before meeting the mask
                return False
            if mask[row, col]:
                return True
        return False

    directions = np.linspace(0, 2 * np.pi, n_rays, endpoint=False)
    n_hits = sum(1 for theta in directions if ray_reaches_mask(theta))
    return float(n_hits) / float(n_rays)


def detect_nose_endpoint(spline_pts, mask, border_radius=40, ang_radius=20):
    """
    Decide which end of a spline is the nose and which is the body-exit.

    Three features are evaluated at both ends of the spline and each pair is
    turned into a signed, relative score (positive favours the first point):
      1. Border distance  - the nose is farther from the image edge.
      2. Mask border contact - the body-exit has mask touching the image
         border, so this score enters with a flipped sign.
      3. Angular coverage - the nose has mask in more directions around it.
    The scores are combined with weights 0.25 / 0.45 / 0.30.

    Parameters
    ----------
    spline_pts : array-like (N, 2)
        Spline points as (row, col).
    mask : ndarray (H, W)
        Segmentation mask of the same frame.
    border_radius : int
        Radius passed to `mask_border_contact`.
    ang_radius : int
        Radius passed to `angular_coverage`.

    Returns
    -------
    nose_idx : int
        0 if first point is nose, -1 if last point is nose.
    score : float
        Signed combined score (positive -> first point is nose).
    features : dict
        Per-feature values for diagnostics.  Empty when the spline has fewer
        than 3 points (in which case nose_idx is 0 and score is 0.0).
    """
    pts = np.asarray(spline_pts)
    if len(pts) < 3:
        return 0, 0.0, {}

    first_pt, last_pt = pts[0], pts[-1]
    n_rows, n_cols = mask.shape

    def relative_diff(at_first, at_last):
        # (a - b) / (a + b), or 0.0 when both values are (almost) zero
        total = at_first + at_last
        return (at_first - at_last) / total if total > 1e-8 else 0.0

    # Feature 1: distance to image border (higher = more interior = nose-like)
    bd_first = border_distance_score(first_pt, (n_rows, n_cols))
    bd_last = border_distance_score(last_pt, (n_rows, n_cols))
    bd_score = relative_diff(bd_first, bd_last)

    # Feature 2: mask border contact (higher = more border contact = body-exit),
    # hence the leading minus sign
    bc_first = mask_border_contact(first_pt, mask, radius=border_radius)
    bc_last = mask_border_contact(last_pt, mask, radius=border_radius)
    bc_total = bc_first + bc_last
    bc_score = -(bc_first - bc_last) / bc_total if bc_total > 1e-8 else 0.0

    # Feature 3: angular coverage (higher = mask wraps around = nose)
    ac_first = angular_coverage(first_pt, mask, radius=ang_radius)
    ac_last = angular_coverage(last_pt, mask, radius=ang_radius)
    ac_score = relative_diff(ac_first, ac_last)

    # Weighted combination (border contact carries the largest weight)
    combined = 0.25 * bd_score + 0.45 * bc_score + 0.30 * ac_score

    feats = {
        'bd_start': bd_first, 'bd_end': bd_last, 'bd_score': bd_score,
        'bc_start': bc_first, 'bc_end': bc_last, 'bc_score': bc_score,
        'ac_start': ac_first, 'ac_end': ac_last, 'ac_score': ac_score,
        'combined': combined,
    }
    if combined >= 0:
        return 0, combined, feats
    return -1, combined, feats

In [14]:
# Test nose detection on a single frame
%matplotlib qt

frame = 0
mask_frame = label[frame].astype(bool)
spline_frame = spline_dict[frame]
pts = np.array(spline_frame)

nose_idx, score, feats = detect_nose_endpoint(pts, mask_frame)

print(f"Nose detection for frame {frame}:")
print(f"  Border dist  — start: {feats['bd_start']:.3f}, end: {feats['bd_end']:.3f}, score: {feats['bd_score']:.3f}")
print(f"  Border cont. — start: {feats['bc_start']:.3f}, end: {feats['bc_end']:.3f}, score: {feats['bc_score']:.3f}")
print(f"  Angular cov. — start: {feats['ac_start']:.3f}, end: {feats['ac_end']:.3f}, score: {feats['ac_score']:.3f}")
print(f"  Combined score: {score:.3f}")
print(f"  Nose is at: {'start (pts[0])' if nose_idx == 0 else 'end (pts[-1])'}")

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
H, W = mask_frame.shape

# Panel 1: mask with spline endpoints and detected nose
axes[0].imshow(mask_frame, cmap='gray')
axes[0].plot(pts[:, 1], pts[:, 0], 'c-', linewidth=1, alpha=0.5)
axes[0].scatter(pts[0, 1], pts[0, 0], c='lime', s=100, zorder=5,
                label=f'Start (BD={feats["bd_start"]:.2f})')
axes[0].scatter(pts[-1, 1], pts[-1, 0], c='red', s=100, zorder=5,
                label=f'End (BD={feats["bd_end"]:.2f})')
nose_pt = pts[0] if nose_idx == 0 else pts[-1]
axes[0].scatter(nose_pt[1], nose_pt[0], c='yellow', s=200, marker='*',
                zorder=6, label='Detected nose')
axes[0].legend(fontsize=9)
axes[0].set_title(f'Spline endpoints & detected nose\nScore={score:.3f}')

# Panel 2: border contact visualization
axes[1].imshow(mask_frame, cmap='gray')
# highlight image border pixels
border_mask = np.zeros_like(mask_frame, dtype=bool)
border_mask[0, :] = True; border_mask[-1, :] = True
border_mask[:, 0] = True; border_mask[:, -1] = True
border_and_mask = mask_frame.astype(bool) & border_mask
axes[1].imshow(border_and_mask, cmap='Reds', alpha=0.6)
for pt_i, color, lbl in [(pts[0], 'lime', f'Start (BC={feats["bc_start"]:.2f})'),
                          (pts[-1], 'red', f'End (BC={feats["bc_end"]:.2f})')]:
    circle = plt.Circle((pt_i[1], pt_i[0]), 40, fill=False,
                         color=color, linewidth=2, linestyle='--')
    axes[1].add_patch(circle)
    axes[1].scatter(pt_i[1], pt_i[0], c=color, s=60, zorder=5, label=lbl)
axes[1].legend(fontsize=9)
axes[1].set_title('Mask border contact\n(red overlay = mask touching border)')

# Panel 3: angular coverage visualization
axes[2].imshow(mask_frame, cmap='gray', alpha=0.5)
ang_radius_vis = 20
for pt_i, color, lbl in [(pts[0], 'lime', f'Start (AC={feats["ac_start"]:.2f})'),
                          (pts[-1], 'red', f'End (AC={feats["ac_end"]:.2f})')]:
    y_c, x_c = float(pt_i[0]), float(pt_i[1])
    angles = np.linspace(0, 2 * np.pi, 36, endpoint=False)
    for theta in angles:
        ey = y_c + ang_radius_vis * np.cos(theta)
        ex = x_c + ang_radius_vis * np.sin(theta)
        # check if ray hits mask
        hit = False
        for t in np.linspace(1, ang_radius_vis, ang_radius_vis):
            ry = int(round(y_c + t * np.cos(theta)))
            rx = int(round(x_c + t * np.sin(theta)))
            if 0 <= ry < H and 0 <= rx < W and mask_frame[ry, rx]:
                hit = True; break
            elif not (0 <= ry < H and 0 <= rx < W):
                break
        ray_color = color if hit else 'gray'
        ray_alpha = 0.8 if hit else 0.2
        axes[2].plot([x_c, ex], [y_c, ey], color=ray_color, alpha=ray_alpha, linewidth=1)
    axes[2].scatter(x_c, y_c, c=color, s=60, zorder=5, label=lbl)
axes[2].legend(fontsize=9)
axes[2].set_title('Angular coverage\n(colored rays = hit mask)')

plt.suptitle(f'Frame {frame} — Score: {score:.3f}  '
             f'({"start" if nose_idx == 0 else "end"} is nose)',
             fontsize=14)
plt.tight_layout()
plt.show()

Nose detection for frame 0:
  Border dist  — start: 0.000, end: 0.035, score: -1.000
  Border cont. — start: 0.000, end: 0.004, score: 1.000
  Angular cov. — start: 0.472, end: 1.000, score: -0.358
  Combined score: 0.092
  Nose is at: start (pts[0])


In [9]:
# Score every frame independently and plot the scores and their ingredients over time
n_frames_total = len(spline_dict)
all_scores = np.zeros(n_frames_total)
all_bd_start = np.zeros(n_frames_total)
all_bd_end = np.zeros(n_frames_total)
all_bc_start = np.zeros(n_frames_total)
all_bc_end = np.zeros(n_frames_total)
all_ac_start = np.zeros(n_frames_total)
all_ac_end = np.zeros(n_frames_total)

# feature key returned by detect_nose_endpoint -> per-frame array it is stored in
feature_tracks = {
    'bd_start': all_bd_start, 'bd_end': all_bd_end,
    'bc_start': all_bc_start, 'bc_end': all_bc_end,
    'ac_start': all_ac_start, 'ac_end': all_ac_end,
}

for i in tqdm.tqdm(range(n_frames_total), desc="Detecting nose per frame"):
    pts_i = np.array(spline_dict[i])
    if len(pts_i) < 3:
        # too few points to score; the frame keeps its zeros
        continue
    mask_i = label[i].astype(bool)
    _, sc, ft = detect_nose_endpoint(pts_i, mask_i)
    all_scores[i] = sc
    if ft:
        for key, track in feature_tracks.items():
            track[i] = ft[key]

# Visualize score distribution and time series
fig, axes = plt.subplots(2, 2, figsize=(14, 8))
ax_score, ax_hist = axes[0, 0], axes[0, 1]
ax_border, ax_angle = axes[1, 0], axes[1, 1]

ax_score.plot(all_scores, linewidth=0.5)
ax_score.axhline(0, color='k', linestyle='--', alpha=0.5)
ax_score.set_xlabel('Frame')
ax_score.set_ylabel('Combined score')
ax_score.set_title('Nose score over time (>0 = start is nose)')

# frames with a score of exactly 0 are left out of the histogram
nonzero = all_scores[all_scores != 0]
ax_hist.hist(nonzero, bins=50, color='steelblue', edgecolor='k')
ax_hist.axvline(0, color='k', linestyle='--')
ax_hist.set_xlabel('Combined score')
ax_hist.set_ylabel('Count')
ax_hist.set_title('Score distribution')

ax_border.plot(all_bd_start, label='Start (border dist)', alpha=0.7, linewidth=0.5)
ax_border.plot(all_bd_end, label='End (border dist)', alpha=0.7, linewidth=0.5)
ax_border.plot(all_bc_start, label='Start (border contact)', alpha=0.5, linewidth=0.5, linestyle='--')
ax_border.plot(all_bc_end, label='End (border contact)', alpha=0.5, linewidth=0.5, linestyle='--')
ax_border.legend(fontsize=8)
ax_border.set_xlabel('Frame')
ax_border.set_ylabel('Value')
ax_border.set_title('Border distance & contact over time')

ax_angle.plot(all_ac_start, label='Start (angular cov.)', alpha=0.7, linewidth=0.5)
ax_angle.plot(all_ac_end, label='End (angular cov.)', alpha=0.7, linewidth=0.5)
ax_angle.legend(fontsize=8)
ax_angle.set_xlabel('Frame')
ax_angle.set_ylabel('Angular coverage')
ax_angle.set_title('Angular coverage over time')

# percentages are taken over all frames, including unscored ones
pct_start = np.mean(all_scores > 0) * 100
pct_end = np.mean(all_scores < 0) * 100
plt.suptitle(f'Per-frame nose detection — {pct_start:.0f}% start is nose, '
             f'{pct_end:.0f}% end is nose', fontsize=14)
plt.tight_layout()
plt.show()

Detecting nose per frame: 100%|██████████| 751/751 [00:00<00:00, 1365.82it/s]


In [12]:
def orient_all_auto(spline_dict, label_stack, crop_n=350,
                    border_radius=40, ang_radius=20,
                    continuity_weight=0.4):
    """
    Orient all splines nose-first without any manual nose annotation.

    Intended for recordings where only the HEAD of the worm is in frame, so
    one spline end (the nose) is interior to the image and the other (the
    body-exit) is near the image border.

    The work is done in two passes:
      Pass 1 - every frame gets an independent shape score from
               `detect_nose_endpoint` (positive = first point is nose).
      Pass 2 - frames are visited in order; each frame's shape score is
               blended with a continuity score that favours the end closest
               to the previous frame's nose, and the spline is reversed when
               the blended score is negative.

    Parameters
    ----------
    spline_dict : dict
        Frame index (0, 1, ...) -> list of [y, x] spline points.
    label_stack : ndarray (F, H, W)
        Segmentation masks, one per frame.
    crop_n : int
        Oriented splines are truncated to this many points.
    border_radius : int
        Radius used for the mask border contact feature.
    ang_radius : int
        Ray length used for the angular coverage feature.
    continuity_weight : float in [0, 1]
        Weight of the continuity score in the blend
        (0 = pure shape, 1 = pure continuity).

    Returns
    -------
    out_dict : dict
        Frame index -> oriented, cropped spline.  Splines with fewer than
        3 points are copied through unchanged.
    shape_scores : ndarray
        Pass-1 shape score per frame (0 for frames that were not scored).
    """
    n_frames = len(spline_dict)

    # ---- Pass 1: per-frame shape scores (independent, no temporal info) ----
    shape_scores = np.zeros(n_frames)
    for i in tqdm.tqdm(range(n_frames), desc="Pass 1: shape features"):
        frame_pts = np.asarray(spline_dict[i])
        if len(frame_pts) < 3:
            continue
        frame_mask = label_stack[i].astype(bool)
        shape_scores[i] = detect_nose_endpoint(frame_pts, frame_mask,
                                               border_radius=border_radius,
                                               ang_radius=ang_radius)[1]

    # Global orientation: sign of the median score over the confident frames
    confident = shape_scores[np.abs(shape_scores) > 0.05]
    if len(confident) > 0:
        global_sign = float(np.sign(np.median(confident)))
    else:
        global_sign = 1.0
    print(f"Global orientation sign: {global_sign:+.0f}  "
          f"({np.sum(confident > 0)}/{len(confident)} confident frames agree start=nose)")

    # ---- Pass 2: orient with blended shape + continuity ----
    out_dict = {}
    last_nose = None

    for i in range(n_frames):
        spline = list(spline_dict[i])  # shallow copy, the input dict is left untouched
        pts = np.asarray(spline)

        if len(pts) < 3:
            out_dict[i] = spline
            continue

        shape_score = shape_scores[i]

        if last_nose is None:
            # first usable frame: trust a confident shape score, otherwise
            # fall back to the global orientation
            if abs(shape_score) > 0.05:
                blended = shape_score
            else:
                blended = global_sign * 0.1
        else:
            # continuity score: positive when the first point is the one
            # closer to the previous nose
            dist_first = float(np.linalg.norm(pts[0] - last_nose))
            dist_last = float(np.linalg.norm(pts[-1] - last_nose))
            dist_sum = dist_first + dist_last
            cont_score = (dist_last - dist_first) / dist_sum if dist_sum > 1e-8 else 0.0

            blended = ((1 - continuity_weight) * shape_score
                       + continuity_weight * cont_score)

        if blended < 0:
            spline = spline[::-1]

        spline = spline[:crop_n]
        last_nose = np.array(spline[0])
        out_dict[i] = spline

    return out_dict, shape_scores


# ---- Run automated orientation ----
# continuity_weight=1.0 is the "pure continuity" end of the blend: once the first
# usable frame is oriented, the per-frame shape scores no longer influence the result.
oriented_auto, scores_auto = orient_all_auto(
    spline_dict,
    label,
    crop_n=350,
    continuity_weight=1.0,
)
print(f"\nOriented {len(oriented_auto)} frames automatically.")

Pass 1: shape features: 100%|██████████| 751/751 [00:00<00:00, 1639.32it/s]


Global orientation sign: +1  (360/687 confident frames agree start=nose)

Oriented 751 frames automatically.


In [13]:
# Visualize the automated orientation result
%matplotlib qt

fig, axes = plt.subplots(1, 2, figsize=(16, 7))
cmap_vis = plt.get_cmap('viridis')
n_total = len(oriented_auto)

# Left: all nose positions overlaid
axes[0].imshow(np.ones((512, 512)), cmap='gray', vmin=0, vmax=1)
for i in range(n_total):
    pts_i = np.array(oriented_auto[i])
    if len(pts_i) < 2:
        continue
    color = cmap_vis(i / max(n_total - 1, 1))
    axes[0].plot(pts_i[:, 1], pts_i[:, 0], color=color, alpha=0.15, linewidth=0.5)
    axes[0].scatter(pts_i[0, 1], pts_i[0, 0], color=color, s=8)
axes[0].set_title('Automated orientation — all frames\n(dots = nose)')

# Right: nose trajectory over time
nose_y = [np.array(oriented_auto[i])[0, 0] if len(oriented_auto[i]) >= 2 else np.nan
           for i in range(n_total)]
nose_x = [np.array(oriented_auto[i])[0, 1] if len(oriented_auto[i]) >= 2 else np.nan
           for i in range(n_total)]

axes[1].plot(nose_x, label='Nose X', alpha=0.7)
axes[1].plot(nose_y, label='Nose Y', alpha=0.7)
axes[1].set_xlabel('Frame')
axes[1].set_ylabel('Position (px)')
axes[1].set_title('Nose position over time\n(smooth = good, jumps = orientation flip errors)')
axes[1].legend()

plt.tight_layout()
plt.show()

# Check for sudden jumps in nose position (potential orientation errors)
nose_coords = np.column_stack([nose_y, nose_x])
nose_jumps = np.linalg.norm(np.diff(nose_coords, axis=0), axis=1)
large_jumps = np.where(nose_jumps > 50)[0]
if len(large_jumps) > 0:
    print(f"\nWarning: {len(large_jumps)} frames with nose jumps > 50px.")
    print(f"Frames: {large_jumps[:20]}{'...' if len(large_jumps) > 20 else ''}")
    print("These may be orientation flip errors — inspect visually.")
else:
    print("\nNo large nose jumps detected — orientation looks consistent.")


Frames: [ 1  2  4  5  6  7  8  9 10 15 16 17 18 23 24 25 28 29 30 32]...
These may be orientation flip errors — inspect visually.


### Tuning parameters

| Parameter | Default | Effect |
|-----------|---------|--------|
| `border_radius` | 40 | Radius (px) for mask border contact search. Increase if body exits far from spline endpoint. |
| `ang_radius` | 20 | Ray length for angular coverage. Increase for thick worms; decrease for thin. |
| `continuity_weight` | 0.4 | Blend of shape vs continuity. 0 = pure shape, 1 = pure continuity. Increase if shape features are noisy. |

### How the features work for partial-body-in-frame

- **Border distance**: The body-exit endpoint is typically near the image edge (distance ~0), while the nose is interior (distance ~0.1–0.5). Strong signal unless the worm is very centered.
- **Mask border contact**: Near the body-exit point, the mask runs right up to the image boundary. Near the nose, it doesn't. This is the most discriminative feature (weighted 0.45).
- **Angular coverage**: At the nose, you can shoot rays in all directions and they hit mask (body wraps around). At the body-exit, rays toward the border miss (mask is truncated). Works even when the nose itself isn't fully in frame.

### Pipeline integration

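Replace step 7 (orient) in `run_pipeline_flvc_transfer.ipynb` with `orient_all_auto`. It takes `spline_dict` and the `label` stack and returns the oriented splines (plus the per-frame shape scores), with no manual nose annotation needed. The final cell of this notebook then writes the oriented splines to `oriented_auto.json`.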

In [None]:
# Write the automatically oriented splines to oriented_auto.json in the recording folder.
# json stores the integer frame keys as strings.
# NOTE(review): the original comment says this matches the orient.py / orient_v2.py
# output format — confirm before using it as a drop-in replacement.
out_path = os.path.join(pth, 'oriented_auto.json')

with open(out_path, 'w') as f:
    f.write(json.dumps(oriented_auto, indent=4))

print(f"Saved oriented splines to: {out_path}")