feat: KNN-barycentric interpolator — drop-in JAX-native Delaunay alternative #317

@Jammy2211

Description

Overview

A pure-JAX wildcard for replacing scipy.spatial.Delaunay in PyAutoArray's source-plane interpolation. Pick the 3 nearest mesh vertices in the source plane (existing brute-force kNN) and compute exact barycentric weights on that triangle, replacing the Wendland C4 kernel weights of the existing InterpolatorKNearestNeighbor. If validation passes (log-evidence within rtol=1e-4 of true Delaunay at the production HST fiducial), this is a drop-in replacement that eliminates the scipy callback bottleneck (~16.87 ms/element on A100, 24% of the production likelihood); a pass at only rtol=1e-3 makes it an opt-in alternative. Code is small (~50-100 lines); the hard part is the science validation.

Plan

  • Implement a pure-JAX interpolator that picks the 3 nearest mesh vertices in source plane (k=3) and computes exact barycentric weights on that triangle.
  • For queries outside their nearest-3 triangle (negative barycentric component): start with clip + renormalize; escalate to "take more neighbors, pick 3 with non-negative coords" only if validation fails.
  • Register as a new KNNBarycentric(Delaunay) mesh class so users opt in by swapping mesh — no pipeline-code changes.
  • Validate by comparing log_evidence against InterpolatorDelaunay at the HST imaging fiducial (EXPECTED_LOG_EVIDENCE_HST = 26288.321397232066). Targets: rtol=1e-4 → drop-in default; rtol=1e-3 → opt-in alternative; rtol=1e-2 fail → abandon and pursue split-callback approach instead.
  • Measure speedup via the existing z_projects/profiling/scripts/delaunay_vmap_probe.py harness at PROBE_BATCH_SIZE=20. Expected: 69.5 → ~56 ms per element (~1.25× overall).
Detailed implementation plan

Affected Repositories

  • PyAutoArray (primary, library)
  • autolens_workspace_developer (regression follow-up)
  • autolens_workspace_test (smoke follow-up)

Work Classification

Library (then workspace follow-up)

Branch Survey

Repository | Current Branch | Dirty?
./PyAutoArray | main | clean
./autolens_workspace_developer | main | dirty (pre-existing unrelated work; worktree gives clean origin/main copy)
./autolens_workspace_test | main | clean

Suggested branch: feature/knn-barycentric
Worktree root: ~/Code/PyAutoLabs-wt/knn-barycentric/ (created later by /start_library)

Coexistence note: Active task jit-regression-drift claims autolens_workspace_developer + autolens_profiling for drift triage of existing jit/ scripts. The new regression script delaunay_knn_barycentric.py is additive and won't collide with the constants being triaged, but if both land in the same window the drift PR should merge first.

Implementation Steps

PyAutoArray (library — primary work)

  1. autoarray/inversion/mesh/interpolator/knn.py

    • Add barycentric_weights_from_3_nearest(query_points, mesh_points, nearest_3_indices, xp):
      • Gathers the 3 vertices for each query (shape (Q, 3, 2))
      • Computes signed barycentric coordinates (reuse triangle_area_xp from delaunay.py)
      • Clips negatives to 0, renormalizes so the row sums to 1
      • Zeros weights on degenerate triangles (matches existing Delaunay code path)
      • Returns weights shape (Q, 3)
    • Add InterpolatorKNNBarycentric(InterpolatorKNearestNeighbor):
      • Overrides _mappings_sizes_weights to call get_interpolation_weights with k_neighbors=3 and pipe the indices through barycentric_weights_from_3_nearest instead of Wendland weights
      • _mappings_sizes_weights_split: same swap on the split-points path (regularization requires it)
  2. autoarray/inversion/mesh/mesh/knn.py

    • Add KNNBarycentric(Delaunay) class mirroring KNearestNeighbor:
      • No k_neighbors / radius_scale knobs (hard-coded k=3)
      • Keeps areas_factor for split regularization
      • interpolator_cls returns InterpolatorKNNBarycentric
  3. autoarray/inversion/mesh/__init__.py — export KNNBarycentric.

  4. Unit tests in test_autoarray/inversion/mesh/interpolator/test_knn_barycentric.py:

    • Triangle interior: weights match pixel_weights_delaunay_from to fp64 precision when the 3 nearest are the containing Delaunay triangle (constructed by hand)
    • Boundary (query outside triangle of 3 nearest): clipped weights ≥ 0, sum to 1
    • Degenerate (collinear 3 nearest): no NaN, weights finite
    • Stays numpy-only — no JAX in library unit tests
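The weight computation described in step 1 can be sketched in plain NumPy as below. This is a minimal illustration under stated assumptions, not the planned implementation: it drops the xp backend argument, computes signed areas inline instead of reusing triangle_area_xp (whose signature is not shown in this issue), and introduces a hypothetical area_tol threshold for detecting degenerate triangles.

```python
import numpy as np


def barycentric_weights_from_3_nearest(query_points, mesh_points,
                                       nearest_3_indices, area_tol=1e-12):
    """Clipped, renormalized barycentric weights on each query's 3 vertices.

    query_points: (Q, 2), mesh_points: (M, 2), nearest_3_indices: (Q, 3).
    Returns weights of shape (Q, 3). `area_tol` is a hypothetical knob.
    """
    tri = mesh_points[nearest_3_indices]          # gather vertices: (Q, 3, 2)
    a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
    p = query_points

    def signed_area2(u, v, w):
        # Twice the signed area of triangle (u, v, w), one value per row.
        return ((v[:, 0] - u[:, 0]) * (w[:, 1] - u[:, 1])
                - (w[:, 0] - u[:, 0]) * (v[:, 1] - u[:, 1]))

    total = signed_area2(a, b, c)
    degenerate = np.abs(total) <= area_tol
    safe_total = np.where(degenerate, 1.0, total)  # avoid division by zero

    # Signed barycentric coordinates: replace one vertex at a time by p.
    bary = np.stack([signed_area2(p, b, c),
                     signed_area2(a, p, c),
                     signed_area2(a, b, p)], axis=1) / safe_total[:, None]

    # Clip negatives to 0, then renormalize each row to sum to 1.
    clipped = np.clip(bary, 0.0, None)
    row_sum = clipped.sum(axis=1, keepdims=True)
    weights = clipped / np.where(row_sum > 0.0, row_sum, 1.0)

    # Zero the weights on degenerate (collinear) triangles.
    return np.where(degenerate[:, None], 0.0, weights)


mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 0.0]])
queries = np.array([[0.25, 0.25],   # inside triangle (0, 1, 2)
                    [1.0, 1.0],     # outside triangle (0, 1, 2)
                    [1.0, 0.1]])    # paired with collinear vertices (0, 1, 3)
indices = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 3]])
weights = barycentric_weights_from_3_nearest(queries, mesh, indices)
```

The three example rows correspond to the interior, boundary and degenerate unit-test cases listed above; the same array operations should translate to jax.numpy without Python-level branching.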

autolens_workspace_developer (regression — follow-up)

  1. jax_profiling/jit/imaging/delaunay_knn_barycentric.py — parallel of delaunay.py:
    • Same dataset, same fiducial, same model setup
    • Mesh swapped to KNNBarycentric
    • Assertion: log_evidence within rtol=1e-3 of EXPECTED_LOG_EVIDENCE_HST = 26288.321397232066
    • If rtol=1e-4 also passes, note it in a comment — that's the drop-in-default condition

autolens_workspace_test (smoke — follow-up)

  1. scripts/jax_assertions/knn_barycentric.py — mirror nnls.py pattern:
    • Well-conditioned case: finite output, weights sum to 1 per row
    • Ill-conditioned (caustic-crossing lens model from existing test suite): assert tolerance vs Delaunay's source reconstruction

Speed measurement (sanity check, no PR)

  1. Re-run z_projects/profiling/scripts/delaunay_vmap_probe.py at PROBE_BATCH_SIZE=20 with mesh swapped from Delaunay to KNNBarycentric. Compare per-call timings to existing output/imaging/delaunay/hpc_a100_fp64_vmap_probe.json. Target: ~13 ms saving per element, 69.5 → ~56 ms.

Decision gate (validation outcome)

Outcome | Ship as
rtol=1e-4 passes | Default Delaunay-equivalent mesh; can deprecate scipy path
rtol=1e-3 passes, rtol=1e-4 fails | Opt-in alternative via config flag, default off
rtol=1e-2 fails | Abandon; fall back to split-callback approach in delaunay_research.md
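The gate can be expressed as a small classifier over the relative error of the measured log-evidence. This is a sketch: decision_from_log_evidence and its return labels are hypothetical names, the relative error is taken with zero absolute tolerance, and the band between rtol=1e-3 and rtol=1e-2 is returned as "unspecified" because the gate above does not assign it an outcome.

```python
EXPECTED_LOG_EVIDENCE_HST = 26288.321397232066  # fiducial quoted in this issue


def decision_from_log_evidence(log_evidence,
                               expected=EXPECTED_LOG_EVIDENCE_HST):
    """Map a measured log-evidence onto the decision-gate outcomes.

    Hypothetical helper; labels are illustrative, not part of any codebase.
    """
    rel_err = abs(log_evidence - expected) / abs(expected)
    if rel_err <= 1e-4:
        return "default"       # drop-in default, scipy path can be deprecated
    if rel_err <= 1e-3:
        return "opt-in"        # alternative behind a config flag, default off
    if rel_err <= 1e-2:
        return "unspecified"   # the gate does not define this band
    return "abandon"           # fall back to the split-callback approach


outcome_exact = decision_from_log_evidence(EXPECTED_LOG_EVIDENCE_HST)
outcome_opt_in = decision_from_log_evidence(EXPECTED_LOG_EVIDENCE_HST * (1 + 5e-4))
outcome_gap = decision_from_log_evidence(EXPECTED_LOG_EVIDENCE_HST * (1 + 5e-3))
outcome_fail = decision_from_log_evidence(EXPECTED_LOG_EVIDENCE_HST * 1.05)
```

At this fiducial, rtol=1e-4 corresponds to an absolute log-evidence difference of about 2.6 and rtol=1e-3 to about 26.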

Key Files

New code (PyAutoArray):

  • autoarray/inversion/mesh/interpolator/knn.py — add barycentric_weights_from_3_nearest + InterpolatorKNNBarycentric
  • autoarray/inversion/mesh/mesh/knn.py — add KNNBarycentric mesh class
  • autoarray/inversion/mesh/__init__.py — export
  • test_autoarray/inversion/mesh/interpolator/test_knn_barycentric.py — unit tests

New code (autolens_workspace_developer):

  • jax_profiling/jit/imaging/delaunay_knn_barycentric.py — regression script at rtol=1e-3

New code (autolens_workspace_test):

  • scripts/jax_assertions/knn_barycentric.py — smoke / cross-conditioning

Read-only reference:

  • autoarray/inversion/mesh/interpolator/delaunay.py: pixel_weights_delaunay_from, triangle_area_xp (reuse), InterpolatorDelaunay
  • autoarray/inversion/mesh/interpolator/knn.py — current get_interpolation_weights, InterpolatorKNearestNeighbor
  • autoarray/inversion/mesh/mesh/knn.py — existing KNearestNeighbor(Delaunay) pattern
  • autoarray/inversion/mappers/abstract.py:255 (Mapper.mapping_matrix call site)
  • z_projects/profiling/scripts/delaunay_vmap_probe.py — speed harness
  • PyAutoPrompt/autoarray/delaunay_research.md — split-callback fallback strategy
  • z_projects/profiling/FINDINGS_nnls_v2.md — full investigation background

Out of scope

  • Pure-JAX Delaunay (separate research project — delaunay_research.md, ~3-6 months)
  • Multiprocessing the scipy callback (separate optimization in delaunay_research.md)
  • Changes to NNLS, log_ev, PSF FFT (not bottlenecks at production batch=20)
  • Sampler-level Delaunay caching (PyAutoFit concern)

Original Prompt

InterpolatorKNearestNeighbor variant: barycentric weights on top-3 nearest

A pure-JAX wildcard for replacing scipy.spatial.Delaunay in PyAutoArray's
source-plane interpolation. The idea is small and testable, the
potential payoff is large — if it holds the log-evidence to rtol=1e-4
against true Delaunay, it's a drop-in replacement that eliminates the
scipy callback bottleneck entirely.

Background from the nnls-vmap-speedup investigation

(Full findings: z_projects/profiling/FINDINGS_nnls_v2.md. Companion
research doc: PyAutoPrompt/autoarray/delaunay_research.md. Closed
issue: #307.)

At production batch=20 on A100, the Delaunay imaging likelihood costs
69.5 ms per element. Decomposition:

scipy.spatial.Delaunay via pure_callback = 16.87 ms (24%)  <-- this prompt
other JAX-traced inversion setup         = ~25 ms (36%)
PSF FFT convolution                      = ~9 ms (13%)
log_ev (slogdet + matmul)                = 12 ms (17%)
NNLS reconstruction                      = 6.2 ms (9%)
misc                                     = ~0.5 ms (1%)

pure_callback with vmap_method="sequential" invokes scipy serially
per batch element, so 16.87 × 20 = 337 ms of wall time per batched
likelihood call is held on a single CPU running scipy/Qhull. The
barycentric weight computation (0.01 ms) and the sparse mapping matrix
scatter (0.15 ms) are essentially free; the 16.87 ms IS the scipy work.

Of that 16.87 ms, ~52% is find_simplex (point location for 15361
query points), called twice, and ~14% is the actual Delaunay
triangulation on 1231 mesh points. The bottleneck is point location
on CPU, not the triangulation per se.

The science context the user supplied

PyAutoArray's premier interpolator is InterpolatorDelaunay. It uses
scipy.spatial.Delaunay to triangulate the source-plane mesh vertices
and find which triangle each over-sampled image-plane data pixel falls
in. The barycentric coordinates of the data pixel inside its triangle
give the three interpolation weights to the three triangle vertices.

InterpolatorKNearestNeighbor was a previous JAX-friendly attempt at
the same problem (autoarray/inversion/mesh/interpolator/knn.py). It
uses brute-force k-nearest-neighbor search in JAX (no scipy callback)
and weighs the neighbors with a Wendland C4 kernel (smoothed
inverse-distance, compact support).

Per the user, kNN performed scientifically much worse than Delaunay
because:

  • The Wendland kernel has knobs (k_neighbors, radius_scale) that
    are hard to set to work for all lenses.
  • Delaunay adapts cleanly to local mesh density; kNN smears across
    density gradients.
  • At caustic crossings (folds in the source-plane mapping), Delaunay's
    triangulation correctly spans the fold; kNN smears across it.

So Delaunay is scientifically the right method but the least JAX-friendly.

The wildcard idea

Use kNN to pick the top-3 nearest neighbors in source plane, then
compute exact barycentric coordinates on the triangle those 3 form.

This replaces Wendland kernel weights with locally-exact barycentric
weights: the same weights Delaunay uses, just on a triangle chosen
by Euclidean nearest-neighbor instead of by Delaunay triangulation.

When the 3 nearest mesh vertices happen to be the 3 vertices of the
containing Delaunay triangle (the common case for "interior" query
points), the result is bit-identical to Delaunay. When they aren't
(boundary points, density-gradient regions, caustic-crossing points),
the result is an approximation — but one that still respects local
mesh topology rather than smearing with a global kernel.
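The selection step, picking the 3 nearest mesh vertices per query, can be sketched in NumPy as follows. This is a dense stand-in that builds the full (Q, M) distance matrix at once; it is not the existing blocked search in knn.py, and three_nearest_indices is a hypothetical name.

```python
import numpy as np


def three_nearest_indices(query_points, mesh_points):
    """Indices of the 3 nearest mesh vertices per query, nearest first.

    query_points: (Q, 2), mesh_points: (M, 2) with M >= 3. Returns (Q, 3).
    """
    diff = query_points[:, None, :] - mesh_points[None, :, :]   # (Q, M, 2)
    dist2 = np.einsum("qmd,qmd->qm", diff, diff)                 # (Q, M)
    # argpartition places the 3 smallest distances first, in arbitrary order.
    part = np.argpartition(dist2, 2, axis=1)[:, :3]
    # Sort just those 3 candidates by distance.
    order = np.argsort(np.take_along_axis(dist2, part, axis=1), axis=1)
    return np.take_along_axis(part, order, axis=1)


mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
queries = np.array([[0.3, 0.2], [4.0, 3.5]])
nearest = three_nearest_indices(queries, mesh)
```

For the sizes quoted in this issue (15361 queries, 1231 mesh points) the dense distance matrix has about 1.9e7 entries per batch element, which is why a blocked search may still be preferable under vmap.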

The potentially-killer subtlety: when the query point is OUTSIDE the
triangle formed by its 3 nearest neighbors, the barycentric coords
have at least one negative component. Three options:

  1. Clip and renormalize: max(bary, 0), then normalize so they sum
    to 1. Gives a valid convex combination but slightly distorts.
  2. Fall back to kNN-Wendland for those query points. Hybrid.
  3. Take more neighbors (k=4, 5, 6) and pick the 3 with non-negative
    barycentric coords. Most expressive but more complex.

Start with option 1. If log-evidence rtol fails, escalate to option 3.
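For reference, option 3 might look like the single-query sketch below. The function name, the tol parameter and the "first containing triple wins" rule are all assumptions made for illustration. It tries triples in nearest-first order, so the plain 3 nearest are tested first.

```python
import itertools

import numpy as np


def cross2(u, v):
    # z-component of the 2D cross product.
    return u[0] * v[1] - u[1] * v[0]


def pick_containing_triple(query, mesh_points, k=6, tol=1e-12):
    """Return (indices, weights) of the first triple among the k nearest
    vertices whose triangle contains `query`, or None if there is none.
    """
    dist2 = np.sum((mesh_points - query) ** 2, axis=1)
    nearest = np.argsort(dist2)[:k]
    for triple in itertools.combinations(nearest, 3):
        a, b, c = mesh_points[list(triple)]
        area2 = cross2(b - a, c - a)
        if abs(area2) <= tol:
            continue  # skip degenerate (collinear) triples
        weights = np.array([cross2(b - query, c - query),
                            cross2(query - a, c - a),
                            cross2(b - a, query - a)]) / area2
        if np.all(weights >= -tol):
            return np.array(triple), weights
    return None


mesh = np.array([[0.4, 0.05], [0.6, 0.0], [0.5, -0.05],
                 [0.5, 0.8], [2.0, 2.0]])
query = np.array([0.5, 0.3])   # lies above all of its 3 nearest vertices
indices, weights = pick_containing_triple(query, mesh, k=4)
```

Here the 3 nearest vertices (0, 1, 2) do not contain the query, and the search settles on (0, 1, 3). A vmap-friendly version would presumably need to evaluate all C(k, 3) triples with fixed shapes and select by mask rather than loop with an early exit.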

Why this might work where pure kNN-Wendland didn't

  • Barycentric weights on 3 nearest are locally exact when the 3 are
    the correct triangle vertices. The Wendland kernel never is.
  • The "knobs" go away: no radius_scale, no kernel shape. Just k=3.
  • At caustic crossings, the 3 nearest still bracket the local source-plane
    structure better than a smoothly-decaying kernel.

Where it might still fail:

  • Boundary points (mesh edge) where the 3 nearest don't form a
    containing triangle, AND clip-renormalize doesn't approximate well.
  • Highly anisotropic mesh density where the 3 nearest are nearly
    collinear (degenerate triangle → numerical instability in barycentric).

Both failure modes are testable.

Performance expectations

InterpolatorKNearestNeighbor is already pure JAX with lax.fori_loop
over blocks of mesh points (brute-force kNN, no scipy callback). Adding
barycentric on top-3 is a small post-processing step on the kNN output.

Estimated runtime under vmap=20 on A100:

  • kNN search (current code, k=3 instead of k=10): ~1-3 ms per element
    (brute-force, 1231 mesh × 15361 queries × 3 = 5.7e7 ops, easy GPU work)
  • Barycentric weight computation on the 3 picked vertices: ~0.1 ms
    per element
  • Total: ~2-4 ms per element

Replacing InterpolatorDelaunay's 16.87 ms per element callback with
~3 ms pure JAX gives ~13 ms saving per element, full pipeline drops
from 69.5 → 56 ms per element, batch_time 1.4 → 1.12 sec.
~1.25× speedup overall, just from this swap.

If it works, it's also vmap-parallel (no sequential callback), so
larger batch sizes will continue to amortize well — unlike the scipy
callback which is sequential-per-element forever.
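The budget arithmetic above can be re-derived from the decomposition table. In this sketch the 3.0 ms replacement cost is the midpoint of the ~2-4 ms estimate, not a measurement, and the dictionary keys are illustrative labels.

```python
# Per-element budget (ms) at batch=20 on A100, from the decomposition above.
budget_ms = {
    "scipy_delaunay_callback": 16.87,
    "other_inversion_setup": 25.0,
    "psf_fft_convolution": 9.0,
    "log_ev": 12.0,
    "nnls": 6.2,
    "misc": 0.5,
}

batch_size = 20
knn_barycentric_ms = 3.0  # assumed: midpoint of the ~2-4 ms estimate

total_ms = sum(budget_ms.values())
new_total_ms = (total_ms - budget_ms["scipy_delaunay_callback"]
                + knn_barycentric_ms)
saving_ms = total_ms - new_total_ms
speedup = total_ms / new_total_ms

print(f"per element: {total_ms:.2f} -> {new_total_ms:.2f} ms "
      f"(saving {saving_ms:.2f} ms, {speedup:.2f}x)")
print(f"per batch:   {batch_size * total_ms / 1000:.2f} -> "
      f"{batch_size * new_total_ms / 1000:.2f} s")
```

The rounded component figures sum to about 69.6 ms, consistent with the quoted 69.5 ms, and the swap lands near 56 ms per element and 1.25× overall.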

Decision criteria up front

This is a science-validated speedup wildcard. Code is small (~50-100
lines new). The hard part is the validation: does kNN-barycentric give
the same scientific answer as Delaunay across the production lens
modeling matrix?

If yes — best possible outcome: drop scipy.spatial.Delaunay entirely,
get ~1.25× speedup on Delaunay production, zero external dep on scipy
in the inversion path.

If no — small sunk cost, fall back to the more conservative
split-the-callback approach in delaunay_research.md.

Either way the data tells us something useful.
