## Computing the Worm Spline (Julia)

The spline computation runs in **Julia** using the Flavell lab's `BehaviorDataNIR.jl` package. It:
1. Segments the worm from each NIR frame using a UNet2D model
2. Computes the medial axis (skeleton) of the worm
3. Fits a smooth spline along the medial axis
4. Detects omega turns and self-intersections
5. Recomputes a cleaned spline
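
As a rough illustration of what step 3 produces, the sketch below resamples an ordered skeleton to a fixed number of evenly spaced points. It is a simplified stand-in that uses linear interpolation along arc length, not the smooth spline fit in `BehaviorDataNIR.jl`; `resample_polyline` is a hypothetical helper introduced only for this example.

```julia
# Illustrative stand-in for step 3: resample an ordered skeleton to
# n + 1 evenly spaced points by linear interpolation along arc length.
# `resample_polyline` is hypothetical and not part of BehaviorDataNIR.jl.
function resample_polyline(xs::AbstractVector{<:Real}, ys::AbstractVector{<:Real}, n::Integer)
    @assert length(xs) == length(ys) >= 2
    # Cumulative arc length at each input point, starting at 0
    seg = hypot.(diff(xs), diff(ys))
    s = vcat(0.0, cumsum(seg))
    targets = range(0, s[end], length = n + 1)
    out_x = Vector{Float64}(undef, n + 1)
    out_y = Vector{Float64}(undef, n + 1)
    j = 1
    for (k, t) in enumerate(targets)
        # Advance to the segment [s[j], s[j+1]] that contains t
        while j < length(s) - 1 && s[j + 1] < t
            j += 1
        end
        # Guard against repeated points (zero-length segments)
        w = s[j + 1] == s[j] ? 0.0 : (t - s[j]) / (s[j + 1] - s[j])
        out_x[k] = xs[j] + w * (xs[j + 1] - xs[j])
        out_y[k] = ys[j] + w * (ys[j + 1] - ys[j])
    end
    return out_x, out_y
end

# An L-shaped skeleton of total length 4, resampled to 4 intervals (5 points)
rx, ry = resample_polyline([0, 2, 2], [0, 0, 2], 4)
println(collect(zip(rx, ry)))
```

With `n = 1000` this gives 1001 points per frame, which matches the `num_center_pts + 1` columns allocated for `x_array` and `y_array` later in the notebook.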

Julia version: 1.8.1 (it may appear in the list of Jupyter kernels).

In [None]:
# using PyCall
# println(PyCall.python)

In [None]:
# import Pkg; Pkg.add("Conda")

In [None]:
# using Conda
# Conda.add("scikit-image", channel="conda-forge")
# Conda.add("scikit-learn", channel="conda-forge")
# Conda.add("networkx")

In [None]:
# using PyCall
# # Try to import them
# ski = pyimport("skimage.morphology")
# skl = pyimport("sklearn.neighbors")
# nx = pyimport("networkx")
# println("All packages imported successfully!")

# SETUP

In [10]:
# Expand the path to include all Conda binary locations
conda_env = "C:\\Users\\munib\\miniconda3\\envs\\nir-pipeline"
new_paths = [
    joinpath(conda_env, "bin"),
    joinpath(conda_env, "Library", "bin"),
    joinpath(conda_env, "Scripts")
]
ENV["PATH"] = join(new_paths, ";") * ";" * ENV["PATH"]

using PyCall
# Re-verify PyCall is definitely looking at the right python
println("PyCall using: ", PyCall.python) 

torch = pyimport("torch")

PyCall using: C:\Users\munib\miniconda3\envs\nir-pipeline\python.exe


PyObject <module 'torch' from 'C:\\Users\\munib\\miniconda3\\envs\\nir-pipeline\\lib\\site-packages\\torch\\__init__.py'>
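
The `PATH` edit above hard-codes `;` as the separator and prepends all three directories whether or not they exist. A slightly more defensive variant might look like the sketch below; `prepend_to_path` is a hypothetical helper, and the temporary directory only stands in for the Conda environment.

```julia
# Hypothetical helper: prepend only the directories that exist, using the
# platform's PATH separator (";" on Windows, ":" elsewhere).
function prepend_to_path(current::AbstractString, dirs::Vector{String})
    sep = Sys.iswindows() ? ";" : ":"
    existing = filter(isdir, dirs)
    isempty(existing) && return String(current)
    return join(existing, sep) * sep * current
end

# Usage with a temporary directory standing in for the Conda environment
demo_env = mktempdir()
mkdir(joinpath(demo_env, "Scripts"))   # "bin" is deliberately not created
demo_path = prepend_to_path("OLD", [joinpath(demo_env, "bin"),
                                    joinpath(demo_env, "Scripts")])
println(demo_path)
```

In the notebook this would be called as `ENV["PATH"] = prepend_to_path(ENV["PATH"], new_paths)`.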

In [11]:
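using Pkg
# Point PyCall at the Conda environment's Python, then rebuild PyCall.
# Restart the Julia kernel after the build so the new setting takes effect.
ENV["PYTHON"] = "C:\\Users\\munib\\miniconda3\\envs\\nir-pipeline\\python.exe"
Pkg.build("PyCall")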

[32m[1m    Building[22m[39m Conda ─→ `C:\Users\munib\.julia-flv\scratchspaces\44cfe95a-1eb2-52ea-b672-e2afdf69b78f\8f06b0cfa4c514c7b9546756dbae91fcfbc92dc9\build.log`
[32m[1m    Building[22m[39m PyCall → `C:\Users\munib\.julia-flv\scratchspaces\44cfe95a-1eb2-52ea-b672-e2afdf69b78f\9816a3826b0ebf49ab4926e2b18842ad8b5c8f04\build.log`


In [12]:
using PyCall
PyCall.python

"C:\\Users\\munib\\miniconda3\\envs\\nir-pipeline\\python.exe"

In [13]:
torch = pyimport("torch")

println("Is CUDA available? ", torch.cuda.is_available())
println("GPU Name: ", torch.cuda.get_device_name(0))

Is CUDA available? true
GPU Name: NVIDIA GeForce RTX 5080
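
Calling `torch.cuda.get_device_name(0)` raises an error on a machine without a CUDA device, so a guarded check can be safer when the notebook moves between machines. The sketch below assumes only that the object passed in exposes `cuda.is_available()`, `cuda.device_count()`, and `cuda.get_device_name(i)`, as the PyTorch module does; `describe_device` and `fake_torch` are hypothetical names used for illustration.

```julia
# Hypothetical helper: report the device without assuming a GPU is present.
# `torch` can be the module returned by pyimport("torch") or any object
# with the same `cuda` interface.
function describe_device(torch)
    if torch.cuda.is_available() && torch.cuda.device_count() > 0
        return "GPU: " * string(torch.cuda.get_device_name(0))
    end
    return "CPU only"
end

# Stand-in object so the sketch runs without PyTorch installed
fake_torch = (cuda = (is_available = () -> false,
                      device_count = () -> 0,
                      get_device_name = i -> "none"),)
println(describe_device(fake_torch))
```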


# CONFIG

In [14]:
# --- GPU setup (set to "" to disable GPU) ---
ENV["CUDA_VISIBLE_DEVICES"] = "0"   # Use GPU 0; set to "" for CPU
ENV["JULIA_IO_BUFFER"] = "0"

# --- Load packages ---
using ImageDataIO, BehaviorDataNIR, UNet2D, H5Zblosc
using HDF5, PyPlot, FileIO

In [7]:
# Check if the globals were actually initialized
println("py_ski_morphology: ", BehaviorDataNIR.py_ski_morphology)
println("py_skl_neighbors: ", BehaviorDataNIR.py_skl_neighbors)
println("py_nx: ", BehaviorDataNIR.py_nx)

py_ski_morphology: PyObject <module 'skimage.morphology' from 'C:\\Users\\munib\\miniconda3\\envs\\nir-pipeline\\lib\\site-packages\\skimage\\morphology\\__init__.py'>
py_skl_neighbors: PyObject <module 'sklearn.neighbors' from 'C:\\Users\\munib\\miniconda3\\envs\\nir-pipeline\\lib\\site-packages\\sklearn\\neighbors\\__init__.py'>
py_nx: PyObject <module 'networkx' from 'C:\\Users\\munib\\miniconda3\\envs\\nir-pipeline\\lib\\site-packages\\networkx\\__init__.py'>


In [15]:
# ============================================================
# USER CONFIG — EDIT THESE
# ============================================================
NAME = "date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm002"  

# Windows paths:
PATH_H5   = "C:\\Users\\munib\\POSTDOC\\DATA\\g5ht-free\\20251028\\$(NAME).h5"
PATH_JLD2 = "C:\\Users\\munib\\POSTDOC\\DATA\\g5ht-free\\20251028\\$(NAME)_data_dict.jld2"

# Linux paths (if running on server):
# PATH_H5   = "/data3/albert/2024/$(NAME).h5"
# PATH_JLD2 = "/home/user/data/$(NAME)_data_dict.jld2"

# Path to segmentation model weights
# (copy worm_segmentation_best_weights_0310.pt to a known location)
path_weight = "C:\\Users\\munib\\POSTDOC\\CODE\\g5ht-pipeline\\nir\\worm_segmentation_best_weights_0310.pt"

"C:\\Users\\munib\\POSTDOC\\CODE\\g5ht-pipeline\\nir\\worm_segmentation_best_weights_0310.pt"

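Because the input paths are hard-coded, a quick pre-flight check can catch a typo before the long computation starts. This is a minimal sketch; `missing_inputs` is a hypothetical helper, and the demo uses temporary files rather than the real data. Only the inputs (`PATH_H5` and `path_weight`) would be checked, since `PATH_JLD2` is an output.

```julia
# Hypothetical pre-flight check: return the labels of files that do not exist.
function missing_inputs(files::Dict{String,String})
    return sort([label for (label, path) in files if !isfile(path)])
end

# Demo with one real temporary file and one path that does not exist
real_file, io = mktemp()
close(io)
demo_missing = missing_inputs(Dict(
    "h5" => real_file,
    "weights" => joinpath(tempdir(), "no_such_weights_file.pt"),
))
println("Missing inputs: ", demo_missing)
```

In the notebook the call would be `missing_inputs(Dict("h5" => PATH_H5, "weights" => path_weight))`, and an empty result means both files were found.
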
In [16]:
# ============================================================
# GET FRAME COUNT
# ============================================================
# Simple approach (the do-block closes the file handle when it finishes):
MAX_T_NIR = h5open(PATH_H5, "r") do f
    size(f["img_nir"])[3]
end

# Safer approach — finds last readable frame (use if h5 file is truncated):
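# function last_good_frame(path_h5)
#     # Open the file once and close it automatically when done
#     h5open(path_h5, "r") do f
#         dset = f["img_nir"]
#         out = 0
#         for i = 1:size(dset)[3]
#             try
#                 dset[:, :, i]   # attempt to read frame i
#             catch
#                 break           # stop at the first unreadable frame
#             end
#             out = i
#         end
#         return out
#     end
# end
# MAX_T_NIR = last_good_frame(PATH_H5)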

println("Total NIR frames: $MAX_T_NIR")

Total NIR frames: 7974


In [17]:
# ============================================================
# CREATE `param` DICTIONARY
# ============================================================
# These are the algorithm parameters — usually no need to change
param = Dict()

# --- For 1st compute_worm_spline!() ---
param["num_center_pts"]           = 1000    # Number of points along the spline
param["img_label_size"]           = (480, 360)  # Expected NIR image size
param["nose_confidence_threshold"] = 0.99   # UNet confidence for nose detection
param["nose_crop_threshold"]      = 20      # Crop threshold for nose region

# --- For compute_worm_thickness() ---
param["min_len_percent"]          = 90      # Min worm length percentile
param["max_len_percent"]          = 98      # Max worm length percentile

# --- For 2nd compute_worm_spline!() (after thickness detection) ---
param["worm_thickness_pad"]       = 3       # Padding around worm thickness
param["boundary_thickness"]       = 5       # Boundary thickness for detection
param["close_pts_threshold"]      = 30      # Threshold for close points
param["trim_head_tail"]           = 15      # Trim head/tail pixels
param["max_med_axis_delta"]       = Inf     # Max medial axis displacement

Inf
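
Two quick checks on these parameters can be sketched as follows. The first shows why `num_center_pts = 1000` leads to 1001 spline points: the spline is evaluated on the grid `0:1/num_center_pts:1`. The second is a hypothetical `check_param` helper; its valid ranges (percentiles within 0 to 100, confidence within (0, 1]) are assumptions, not limits documented by the package.

```julia
# Sketch: the spline is sampled on 0:1/n:1, which yields n + 1 points.
n_pts = 1000
grid = 0:(1 / n_pts):1
println("points per frame: ", length(grid))

# Hypothetical consistency checks on a parameter dictionary.
# The accepted ranges below are assumptions made for illustration.
function check_param(p::Dict)
    problems = String[]
    p["num_center_pts"] > 0 || push!(problems, "num_center_pts must be positive")
    0 <= p["min_len_percent"] < p["max_len_percent"] <= 100 ||
        push!(problems, "need 0 <= min_len_percent < max_len_percent <= 100")
    0 < p["nose_confidence_threshold"] <= 1 ||
        push!(problems, "nose_confidence_threshold must be in (0, 1]")
    return problems
end

demo_param = Dict("num_center_pts" => 1000, "min_len_percent" => 90,
                  "max_len_percent" => 98, "nose_confidence_threshold" => 0.99)
println(check_param(demo_param))
```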

In [18]:
# ============================================================
# CREATE `data_dict` DICTIONARY
# ============================================================
# This is the main output dictionary. Here are ALL the keys
# and what they store:

data_dict = Dict()

# --- Medial axis results (Dict of per-frame data) ---
data_dict["med_axis_dict"] = Dict()       # Medial axis points per frame
data_dict["med_axis_dict"][0] = nothing    # Initialize with sentinel

# --- Point ordering (Dict of per-frame data) ---
data_dict["pts_order_dict"] = Dict()      # Ordered skeleton points per frame
data_dict["pts_order_dict"][0] = nothing   # Initialize with sentinel

# --- Omega turn detection (Dict of Bool per frame) ---
data_dict["is_omega"] = Dict()            # Whether frame has omega turn

# --- Spline coordinates: (n_frames, num_center_pts+1) matrices ---
data_dict["x_array"] = zeros(MAX_T_NIR, param["num_center_pts"] + 1)
data_dict["y_array"] = zeros(MAX_T_NIR, param["num_center_pts"] + 1)

# --- Per-frame scalar metrics ---
data_dict["nir_worm_angle"] = zeros(MAX_T_NIR)  # Worm body angle per frame
data_dict["eccentricity"]   = zeros(MAX_T_NIR)  # Worm eccentricity per frame

# --- Error tracking ---
error_dict = Dict()

Dict{Any, Any}()
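
Because `x_array` and `y_array` start out as zeros, a row that is still all zeros after the computation suggests the frame was never filled in. The sketch below lists such rows on a toy matrix; `unfilled_frames` is a hypothetical helper, and it assumes a real spline never consists entirely of zeros.

```julia
# Sketch: list frames whose spline row is still all zeros (never filled in).
unfilled_frames(x::AbstractMatrix) =
    [i for i in axes(x, 1) if all(iszero, view(x, i, :))]

# Toy 4-frame array: frames 2 and 4 have data, frames 1 and 3 do not
demo_x = zeros(4, 3)
demo_x[2, :] .= [1.0, 2.0, 3.0]
demo_x[4, :] .= [0.0, 5.0, 0.0]
println(unfilled_frames(demo_x))
```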

# MAIN

In [20]:
# ============================================================
# RUN SPLINE COMPUTATION (4 stages)
# ============================================================

# Stage 1: Load the UNet2D segmentation model
println("Loading model...")
using PyCall
# Define the CPU device using PyCall to talk to torch
# device = pyimport("torch").device("cpu")
# worm_seg_model = create_model(1, 1, 16, path_weight, device=device)
worm_seg_model = create_model(1, 1, 16, path_weight)

# Explicitly move worm_seg_model to CPU (should already be on CPU, but just to be sure)
# worm_seg_model.to(device)

Loading model...


PyObject UNet2D(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(32, eps=1e-05, 

In [None]:
# # find the path to the compute_worm_spline!() function
# @which compute_worm_spline!(
#     param, PATH_H5, worm_seg_model, nothing,
#     data_dict["med_axis_dict"], data_dict["pts_order_dict"],
#     data_dict["is_omega"], data_dict["x_array"], data_dict["y_array"],
#     data_dict["nir_worm_angle"], data_dict["eccentricity"],
#     timepts=1:MAX_T_NIR
# )

In [None]:
# Stage 2: Initial spline computation, around X minutes
println("Computing initial spline (pass 1)...")
error_dict["worm_spline_errors_1"] = compute_worm_spline!(
    param, PATH_H5, worm_seg_model, nothing,
    data_dict["med_axis_dict"], data_dict["pts_order_dict"],
    data_dict["is_omega"], data_dict["x_array"], data_dict["y_array"],
    data_dict["nir_worm_angle"], data_dict["eccentricity"]
)

In [None]:
error_dict["worm_spline_errors_1"][4986]
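
Rather than inspecting one frame at a time, the error dictionary can be summarized in one pass. The sketch below assumes the dictionary maps frame indices to the caught exceptions, which is consistent with the indexing above but is not confirmed by the package documentation; `summarize_errors` and the demo data are hypothetical.

```julia
# Sketch: summarize a frame => error dictionary by exception type.
# Assumption: keys are frame indices and values are the caught exceptions.
function summarize_errors(errs::AbstractDict)
    counts = Dict{String,Int}()
    for err in values(errs)
        name = string(typeof(err))
        counts[name] = get(counts, name, 0) + 1
    end
    return sort(collect(keys(errs))), counts
end

# Made-up errors for three frames
demo_errs = Dict(12 => BoundsError([1, 2], 5),
                 40 => ErrorException("no skeleton found"),
                 41 => BoundsError([1, 2], 3))
failed_frames, error_counts = summarize_errors(demo_errs)
println("Failed frames: ", failed_frames)
println("By type: ", error_counts)
```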

### Diagnostic: run a single failing frame WITHOUT try/catch to get the full stack trace

In [22]:
# ============================================================
# DIAGNOSTIC: Run one failing frame WITHOUT try/catch
# Change `test_idx` to any frame that appeared in the error dict
# ============================================================
test_idx = 1  # start with frame 1 (uses sentinel nothing values from key 0)

# --- Reproduce the exact setup compute_worm_spline! does ---
spline_interval = 1/param["num_center_pts"]
img_label_test = zeros(Int32, param["img_label_size"][1], param["img_label_size"][2])

f = h5open(PATH_H5)
pos_feature, pos_feature_unet = BehaviorDataNIR.read_pos_feature(f)
close(f)

println("pos_feature_unet size: ", size(pos_feature_unet))

pts = pos_feature_unet[:,:,test_idx]
pts_n = pts[1, :]
pts_n[3] = minimum(pts[:,3])
println("pts_n: ", pts_n)

f = h5open(PATH_H5)
img_raw = f["img_nir"][:,:,test_idx]
close(f)
println("img_raw size: ", size(img_raw), " type: ", typeof(img_raw))

# Step 1: Segment
println("\n--- Step 1: segment_worm! ---")
img_raw_ds, img_bin = BehaviorDataNIR.segment_worm!(worm_seg_model, img_raw, img_label_test)
println("segment_worm! succeeded. img_bin type: ", typeof(img_bin), " size: ", size(img_bin))

img_bin = Int32.(img_bin)

# Step 2: Medial axis (this is likely where it fails)
println("\n--- Step 2: medial_axis ---")
med_xs, med_ys, pts_order, is_omega = BehaviorDataNIR.medial_axis(
    param, img_bin, pts_n,
    prev_med_axis=data_dict["med_axis_dict"][test_idx-1],
    prev_pts_order=data_dict["pts_order_dict"][test_idx-1],
    worm_thickness=nothing
)
println("medial_axis succeeded! is_omega=$is_omega, n_pts=$(length(med_xs))")

# Step 3: Fit spline
println("\n--- Step 3: fit_spline ---")
spl_data, spl = BehaviorDataNIR.fit_spline(param, med_xs, med_ys, pts_n, n_subsample=15)
spl_pts = spl(0:spline_interval:1, 1:2)
println("fit_spline succeeded! spl_pts size: ", size(spl_pts))

println("\n=== All steps passed for frame $test_idx ===")

pos_feature_unet size: (3, 3, 7974)
pts_n: Float32[242.0, 144.0, 0.9983859]
img_raw size: (968, 732) type: Matrix{UInt8}

--- Step 1: segment_worm! ---
segment_worm! succeeded. img_bin type: BitMatrix size: (480, 360)

--- Step 2: medial_axis ---
medial_axis succeeded! is_omega=false, n_pts=363

--- Step 3: fit_spline ---
fit_spline succeeded! spl_pts size: (1001, 2)

=== All steps passed for frame 1 ===
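
The `pts_n` lines in the diagnostic take the first feature row and replace its confidence with the lowest confidence across all features. The toy example below reproduces that step on made-up numbers; the interpretation of the columns as x, y, and confidence is inferred from the printed `pts_n` and is an assumption.

```julia
# Toy 3x3 feature matrix: one row per feature; columns are assumed to be
# x, y, confidence. The values are invented for illustration.
pts_demo = Float32[242 144 0.99;
                   250 160 0.95;
                   260 180 0.97]
nose = pts_demo[1, :]                # slicing copies, so pts_demo is not modified
nose[3] = minimum(pts_demo[:, 3])    # use the lowest confidence of the three rows
println(nose)
```

Since the slice is a copy, the original array keeps its per-feature confidences.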


In [None]:
# Stage 3: Detect self-intersections using worm thickness
println("Detecting self-intersection...")
# The second return value is named `thickness_count` to avoid shadowing Base.count
data_dict["worm_thickness"], thickness_count = compute_worm_thickness(
    param, PATH_H5, worm_seg_model,
    data_dict["med_axis_dict"], data_dict["is_omega"]
)

In [None]:
# Stage 4: Recompute spline with thickness correction
println("Recomputing spline (pass 2)...")
error_dict["worm_spline_errors_2"] = compute_worm_spline!(
    param, PATH_H5, worm_seg_model, data_dict["worm_thickness"],
    data_dict["med_axis_dict"], data_dict["pts_order_dict"],
    data_dict["is_omega"], data_dict["x_array"], data_dict["y_array"],
    data_dict["nir_worm_angle"], data_dict["eccentricity"],
    timepts=1:MAX_T_NIR
)

# ============================================================
# SAVE RESULTS
# ============================================================
println("Saving data_dict to: $PATH_JLD2")
save(PATH_JLD2, "data_dict", data_dict)
println("Done!")

In [None]:
# Quick sanity check — plot every 100th frame's spline
using PyPlot
figure(figsize=(10, 8))
step = 100
for i = 1:step:MAX_T_NIR
    plot(data_dict["x_array"][i, :], data_dict["y_array"][i, :])
end
title("Worm splines (every $(step)th frame)")
xlabel("x (pixels)")
ylabel("y (pixels)")
savefig("spline_check.png", dpi=150)
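
A numeric check can complement the plot: the arc length of each spline should stay roughly constant across frames, so outliers point to bad fits. This is a minimal sketch on toy data; `spline_lengths` is a hypothetical helper, not a function from `BehaviorDataNIR.jl`.

```julia
# Sketch: arc length of each frame's spline, one value per row.
function spline_lengths(x::AbstractMatrix, y::AbstractMatrix)
    @assert size(x) == size(y)
    dx = diff(x, dims = 2)   # differences between consecutive points
    dy = diff(y, dims = 2)
    return vec(sum(hypot.(dx, dy), dims = 2))
end

# Two toy frames: a straight line of length 2, and a single 3-4-5 step
# followed by a repeated point
demo_xs = [0.0 1.0 2.0; 0.0 3.0 3.0]
demo_ys = [0.0 0.0 0.0; 0.0 4.0 4.0]
println(spline_lengths(demo_xs, demo_ys))
```

In the notebook this would be called as `spline_lengths(data_dict["x_array"], data_dict["y_array"])`, with lengths in pixels.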

### `data_dict` Keys After Spline Computation

After running the spline script, the saved `.jld2` file contains:

| Key | Type | Shape | Description |
|-----|------|-------|-------------|
| `x_array` | `Matrix{Float64}` | `(MAX_T_NIR, 1001)` | X coords of spline points per frame |
| `y_array` | `Matrix{Float64}` | `(MAX_T_NIR, 1001)` | Y coords of spline points per frame |
| `nir_worm_angle` | `Vector{Float64}` | `(MAX_T_NIR,)` | Overall worm body angle per frame |
| `eccentricity` | `Vector{Float64}` | `(MAX_T_NIR,)` | Worm shape eccentricity per frame |
| `med_axis_dict` | `Dict` | per-frame | Medial axis points (raw skeleton) |
| `pts_order_dict` | `Dict` | per-frame | Ordered skeleton points |
| `is_omega` | `Dict` (frame index => `Bool`) | per-frame | Whether frame shows omega turn |
| `worm_thickness` | value | scalar | Estimated worm thickness in pixels |