Original Prompt
InterpolatorKNearestNeighbor variant: barycentric weights on top-3 nearest
A pure-JAX wildcard for replacing scipy.spatial.Delaunay in PyAutoArray's
source-plane interpolation. The idea is small and testable, the
potential payoff is large — if it holds the log-evidence to rtol=1e-4
against true Delaunay, it's a drop-in replacement that eliminates the
scipy callback bottleneck entirely.
Background from the nnls-vmap-speedup investigation
(Full findings: z_projects/profiling/FINDINGS_nnls_v2.md. Companion
research doc: PyAutoPrompt/autoarray/delaunay_research.md. Closed
issue: #307.)
At production batch=20 on A100, the Delaunay imaging likelihood costs
69.5 ms per element. Decomposition:
scipy.spatial.Delaunay via pure_callback = 16.87 ms (24%) <-- this prompt
other JAX-traced inversion setup = ~25 ms (36%)
PSF FFT convolution = ~9 ms (13%)
log_ev (slogdet + matmul) = 12 ms (17%)
NNLS reconstruction = 6.2 ms (9%)
misc = ~0.5 ms (1%)
pure_callback with vmap_method="sequential" invokes scipy serially
per batch element, so 16.87 × 20 = 337 ms of wall time per batched
likelihood call is spent on a single CPU running scipy/Qhull. The
barycentric weight computation (0.01 ms) and the sparse mapping matrix
scatter (0.15 ms) are essentially free: the 16.87 ms IS the scipy work.
Of that 16.87 ms, ~52% is find_simplex (point location for 15361
query points), called twice, and ~14% is the actual Delaunay
triangulation on 1231 mesh points. The bottleneck is point location
on CPU, not the triangulation per se.
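As a quick check on the numbers above, the per-element budget and the serial callback cost can be recomputed in a few lines. This is only an arithmetic sketch: the dictionary keys are labels copied from the decomposition, the values are the quoted timings, and nothing here touches PyAutoArray.

```python
# Per-element cost breakdown (ms) quoted above for batch=20 on an A100.
budget_ms = {
    "scipy.spatial.Delaunay via pure_callback": 16.87,
    "other JAX-traced inversion setup": 25.0,
    "PSF FFT convolution": 9.0,
    "log_ev (slogdet + matmul)": 12.0,
    "NNLS reconstruction": 6.2,
    "misc": 0.5,
}
batch_size = 20

total_ms = sum(budget_ms.values())
callback_ms = budget_ms["scipy.spatial.Delaunay via pure_callback"]

# With vmap_method="sequential" the callback runs once per batch element,
# so its cost adds up linearly across the batch.
serial_callback_wall_ms = callback_ms * batch_size

for name, ms in budget_ms.items():
    print(f"{name:42s} {ms:6.2f} ms  {100 * ms / total_ms:5.1f}%")
print(f"total per element: {total_ms:.2f} ms")
print(f"serial callback wall time per batched call: {serial_callback_wall_ms:.1f} ms")
```

The component sum lands slightly above the quoted 69.5 ms because two of the entries are approximate ("~25 ms", "~9 ms").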
The science context the user supplied
PyAutoArray's premier interpolator is InterpolatorDelaunay. It uses
scipy.spatial.Delaunay to triangulate the source-plane mesh vertices
and find which triangle each over-sampled image-plane data pixel falls
in. The barycentric coordinates of the data pixel inside its triangle
give the three interpolation weights to the three triangle vertices.
InterpolatorKNearestNeighbor was a previous JAX-friendly attempt at
the same problem (autoarray/inversion/mesh/interpolator/knn.py). It
uses brute-force k-nearest-neighbor search in JAX (no scipy callback)
and weights the neighbors with a Wendland C4 kernel (smoothed
inverse-distance, compact support).
Per the user, kNN performed scientifically much worse than Delaunay
because:
- The Wendland kernel has knobs (k_neighbors, radius_scale) that
are hard to set to work for all lenses.
- Delaunay adapts cleanly to local mesh density; kNN smears across
density gradients.
- At caustic crossings (folds in the source-plane mapping), Delaunay's
triangulation correctly spans the fold; kNN smears across it.
So Delaunay is scientifically the right method but the least JAX-friendly.
The wildcard idea
Use kNN to pick the top-3 nearest neighbors in source plane, then
compute exact barycentric coordinates on the triangle those 3 form.
This replaces Wendland kernel weights with locally-exact barycentric
weights — the same weights Delaunay uses, just on a triangle chosen
by Euclidean nearest-neighbor instead of by Delaunay triangulation.
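A minimal NumPy sketch of the idea, assuming 2D source-plane coordinates stored as (N, 2) arrays. The function names below are hypothetical and are not the PyAutoArray API; the brute-force search is a stand-in for the existing kNN code, and the same array operations exist in jax.numpy.

```python
import numpy as np

def three_nearest_indices(query_points, mesh_points):
    """Brute-force indices of the 3 nearest mesh vertices, shape (Q, 3)."""
    diff = query_points[:, None, :] - mesh_points[None, :, :]  # (Q, M, 2)
    dist_sq = np.sum(diff**2, axis=-1)                         # (Q, M)
    return np.argsort(dist_sq, axis=1)[:, :3]

def cross2(u, v):
    """z-component of the 2D cross product, row by row."""
    return u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]

def signed_barycentric_weights(query_points, mesh_points, nearest_3_indices):
    """Barycentric weights of each query w.r.t. its 3 chosen vertices, shape (Q, 3).

    Weights sum to 1; a negative entry means the query lies outside the triangle.
    """
    tri = mesh_points[nearest_3_indices]        # (Q, 3, 2)
    a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
    q = query_points
    twice_area = cross2(b - a, c - a)           # signed, zero if collinear
    w_a = cross2(b - q, c - q) / twice_area
    w_b = cross2(c - q, a - q) / twice_area
    w_c = cross2(a - q, b - q) / twice_area
    return np.stack([w_a, w_b, w_c], axis=1)

# Toy example: one query inside the triangle formed by its 3 nearest vertices.
mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
query = np.array([[0.2, 0.3]])
idx = three_nearest_indices(query, mesh)
weights = signed_barycentric_weights(query, mesh, idx)
reconstructed = np.einsum("qk,qkd->qd", weights, mesh[idx])
```

In this toy case the weighted sum of the three vertices reproduces the query point, which is the defining property of barycentric weights. No guard against a zero-area triangle is included here.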
When the 3 nearest mesh vertices happen to be the 3 vertices of the
containing Delaunay triangle (the common case for "interior" query
points), the result should match Delaunay to floating-point precision. When they aren't
(boundary points, density-gradient regions, caustic-crossing points),
the result is an approximation — but one that still respects local
mesh topology rather than smearing with a global kernel.
The potentially-killer subtlety: when the query point is OUTSIDE the
triangle formed by its 3 nearest neighbors, the barycentric coords
have a negative component. Three options:
1. Clip and renormalize: max(bary, 0), then normalize so they sum
to 1. Gives a valid convex combination but slightly distorts.
2. Fall back to kNN-Wendland for those query points. Hybrid.
3. Take more neighbors (k=4, 5, 6) and pick the 3 with non-negative
barycentric coords. Most expressive but more complex.
Start with option 1. If log-evidence rtol fails, escalate to option 3.
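Option 1 is small enough to sketch directly. The helper name and the eps guard are illustrative assumptions, written in NumPy; the same calls are available in jax.numpy.

```python
import numpy as np

def clip_and_renormalize(bary, eps=1e-12):
    """Clamp negative barycentric weights to zero and rescale rows to sum to 1."""
    clipped = np.maximum(bary, 0.0)
    total = np.sum(clipped, axis=-1, keepdims=True)
    # If the input rows sum to 1, total is at least 1; eps only guards bad input.
    return clipped / np.maximum(total, eps)

bary = np.array([
    [0.5, 0.3, 0.2],    # query inside its triangle: unchanged
    [1.2, -0.3, 0.1],   # query outside: negative weight removed
])
fixed = clip_and_renormalize(bary)
```

The second row illustrates the distortion: the clipped weights form a valid convex combination, but they no longer reproduce the original query position, so the interpolated value is effectively taken at a point on the triangle rather than at the query itself.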
Why this might work where pure kNN-Wendland didn't
- Barycentric weights on 3 nearest are locally exact when the 3 are
the correct triangle vertices. The Wendland kernel never is.
- The "knobs" go away: no radius_scale, no kernel shape. Just k=3.
- At caustic crossings, the 3 nearest still bracket the local source-plane
structure better than a smoothly-decaying kernel.
Where it might still fail:
- Boundary points (mesh edge) where the 3 nearest don't form a
containing triangle, AND clip-renormalize doesn't approximate well.
- Highly anisotropic mesh density where the 3 nearest are nearly
collinear (degenerate triangle → numerical instability in barycentric).
Both failure modes are testable.
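One way to make both failure modes countable is to flag, per query point, whether it falls outside its 3-nearest triangle and whether that triangle is close to collinear. The sketch below is an assumption about how such a diagnostic could look: the function names, the shape ratio and the min_ratio threshold are all hypothetical, not existing PyAutoArray code.

```python
import numpy as np

def triangle_shape_ratio(triangles):
    """2*|area| / (longest edge)^2 for triangles of shape (Q, 3, 2).

    Scale-invariant: 0 for collinear vertices, about 0.87 for equilateral.
    """
    a, b, c = triangles[:, 0], triangles[:, 1], triangles[:, 2]
    ab, ac, bc = b - a, c - a, c - b
    twice_area = np.abs(ab[:, 0] * ac[:, 1] - ab[:, 1] * ac[:, 0])
    longest_sq = np.maximum.reduce(
        [np.sum(ab**2, axis=1), np.sum(ac**2, axis=1), np.sum(bc**2, axis=1)]
    )
    # Guard against three coincident vertices (all edge lengths zero).
    return twice_area / np.maximum(longest_sq, np.finfo(float).tiny)

def failure_mode_flags(bary, triangles, min_ratio=1e-6):
    """Boolean masks: query outside its triangle, and near-degenerate triangle."""
    outside = np.any(bary < 0.0, axis=1)
    degenerate = triangle_shape_ratio(triangles) < min_ratio
    return outside, degenerate

triangles = np.array([
    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],   # well-shaped right triangle
    [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]],   # collinear vertices
])
bary = np.array([[0.5, 0.3, 0.2], [1.2, -0.3, 0.1]])
ratios = triangle_shape_ratio(triangles)
outside, degenerate = failure_mode_flags(bary, triangles)
```

Counting the two masks over the query points of a fiducial dataset would show how often each failure mode actually occurs before any log-evidence comparison is run.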
Performance expectations
InterpolatorKNearestNeighbor is already pure JAX with lax.fori_loop
over blocks of mesh points (brute-force kNN, no scipy callback). Adding
barycentric on top-3 is a small post-processing step on the kNN output.
Estimated runtime under vmap=20 on A100:
- kNN search (current code, k=3 instead of k=10): ~1-3 ms per element
(brute-force, 1231 mesh × 15361 queries × 3 = 5.7e7 ops, easy GPU work)
- Barycentric weight computation on the 3 picked vertices: ~0.1 ms
per element
- Total: ~2-4 ms per element
Replacing InterpolatorDelaunay's 16.87 ms per element callback with
~3 ms pure JAX gives ~13 ms saving per element, full pipeline drops
from 69.5 → 56 ms per element, batch_time 1.4 → 1.12 sec.
~1.25× speedup overall, just from this swap.
If it works, it's also vmap-parallel (no sequential callback), so
larger batch sizes will continue to amortize well — unlike the scipy
callback which is sequential-per-element forever.
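The projection above is simple arithmetic and can be reproduced as follows. The 3.0 ms figure is an assumed midpoint of the ~2-4 ms estimate, not a measurement.

```python
mesh_points = 1231
query_points = 15361
k = 3

# Operation count quoted for the brute-force kNN search.
knn_ops = mesh_points * query_points * k

current_ms = 69.5        # full likelihood per element today
callback_ms = 16.87      # scipy.spatial.Delaunay via pure_callback
knn_bary_ms = 3.0        # ASSUMED midpoint of the ~2-4 ms estimate
batch_size = 20

projected_ms = current_ms - callback_ms + knn_bary_ms
speedup = current_ms / projected_ms
batch_before_s = current_ms * batch_size / 1000.0
batch_after_s = projected_ms * batch_size / 1000.0

print(f"kNN ops: {knn_ops:.2e}")
print(f"per element: {current_ms:.1f} -> {projected_ms:.2f} ms ({speedup:.2f}x)")
print(f"batch time: {batch_before_s:.2f} -> {batch_after_s:.2f} s")
```

The projected speedup is bounded by the callback's 24% share of the budget: even a zero-cost replacement would give at most about 1.32× overall.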
Decision criteria up front
This is a science-validated speedup wildcard. Code is small (~50-100
lines new). The hard part is the validation: does kNN-barycentric give
the same scientific answer as Delaunay across the production lens
modeling matrix?
If yes — best possible outcome: drop scipy.spatial.Delaunay entirely,
get ~1.25× speedup on Delaunay production, zero external dep on scipy
in the inversion path.
If no — small sunk cost, fall back to the more conservative
split-the-callback approach in delaunay_research.md.
Either way the data tells us something useful.
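The decision can be phrased as a small gate on the relative log-evidence error, using the tiers stated in the plan (rtol=1e-4 drop-in default, rtol=1e-3 opt-in alternative, failing rtol=1e-2 means abandon). The function name is hypothetical, and the outcome for errors between 1e-3 and 1e-2 is not specified in the plan, so the sketch reports it as undecided.

```python
EXPECTED_LOG_EVIDENCE_HST = 26288.321397232066  # fiducial value quoted in the plan

def validation_tier(log_evidence, expected=EXPECTED_LOG_EVIDENCE_HST):
    """Map a kNN-barycentric log-evidence onto the plan's decision tiers."""
    rel_err = abs(log_evidence - expected) / abs(expected)
    if rel_err <= 1e-4:
        return "drop-in default"
    if rel_err <= 1e-3:
        return "opt-in alternative"
    if rel_err <= 1e-2:
        # ASSUMPTION: the plan does not say what happens in this band.
        return "undecided"
    return "abandon, pursue split-callback"

# Synthetic inputs only; these are not measured results.
tier_exact = validation_tier(EXPECTED_LOG_EVIDENCE_HST)
tier_close = validation_tier(EXPECTED_LOG_EVIDENCE_HST * (1 + 5e-4))
tier_far = validation_tier(EXPECTED_LOG_EVIDENCE_HST * (1 + 5e-2))
```

For a log-evidence of about 26288, rtol=1e-4 corresponds to an absolute difference of roughly 2.6 and rtol=1e-3 to roughly 26.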
Overview
A pure-JAX wildcard for replacing scipy.spatial.Delaunay in PyAutoArray's source-plane interpolation. Pick the 3 nearest mesh vertices in source plane (existing brute-force kNN) and compute exact barycentric weights on that triangle, replacing the Wendland C4 kernel weights of the existing InterpolatorKNearestNeighbor. If validation passes (log-evidence rtol=1e-3 vs true Delaunay at the production HST fiducial), this is a drop-in replacement that eliminates the scipy callback bottleneck (~16.87 ms/element on A100, 24% of the production likelihood). Code is small (~50-100 lines); the hard part is the science validation.
Plan
- Add a KNNBarycentric(Delaunay) mesh class so users opt in by swapping mesh; no pipeline-code changes.
- Validate log_evidence against InterpolatorDelaunay at the HST imaging fiducial (EXPECTED_LOG_EVIDENCE_HST = 26288.321397232066). Targets: rtol=1e-4 → drop-in default; rtol=1e-3 → opt-in alternative; rtol=1e-2 fail → abandon and pursue split-callback approach instead.
- Measure speed with the z_projects/profiling/scripts/delaunay_vmap_probe.py harness at PROBE_BATCH_SIZE=20. Expected: 69.5 → ~56 ms per element (~1.25× overall).
Detailed implementation plan
Affected Repositories
Work Classification
Library (then workspace follow-up)
Branch Survey
Suggested branch: feature/knn-barycentric
Worktree root: ~/Code/PyAutoLabs-wt/knn-barycentric/ (created later by /start_library)
Coexistence note: Active task jit-regression-drift claims autolens_workspace_developer + autolens_profiling for drift triage of existing jit/ scripts. New regression scripts under delaunay_knn_barycentric.py are additive and won't collide with the constants being triaged, but the merge order should sequence the drift PR first if both land in the same window.
Implementation Steps
PyAutoArray (library — primary work)
- autoarray/inversion/mesh/interpolator/knn.py
  - barycentric_weights_from_3_nearest(query_points, mesh_points, nearest_3_indices, xp):
    - [...] (Q, 3, 2))
    - [...] triangle_area_xp from delaunay.py)
    - [...] (Q, 3)
  - InterpolatorKNNBarycentric(InterpolatorKNearestNeighbor):
    - [...] _mappings_sizes_weights to call get_interpolation_weights with k_neighbors=3 and pipe the indices through barycentric_weights_from_3_nearest instead of Wendland weights
    - _mappings_sizes_weights_split: same swap on the split-points path (regularization requires it)
- autoarray/inversion/mesh/mesh/knn.py
  - KNNBarycentric(Delaunay) class mirroring KNearestNeighbor:
    - No k_neighbors / radius_scale knobs (hard-coded k=3)
    - [...] areas_factor for split regularization
    - interpolator_cls returns InterpolatorKNNBarycentric
- autoarray/inversion/mesh/__init__.py - export KNNBarycentric.
Unit tests - test_autoarray/inversion/mesh/interpolator/test_knn_barycentric.py:
- [...] pixel_weights_delaunay_from to fp64 precision when the 3 nearest are the containing Delaunay triangle (constructed by hand)
autolens_workspace_developer (regression, follow-up)
- jax_profiling/jit/imaging/delaunay_knn_barycentric.py - parallel of delaunay.py:
  - [...] KNNBarycentric
  - [...] log_evidence within rtol=1e-3 of EXPECTED_LOG_EVIDENCE_HST = 26288.321397232066
  - If rtol=1e-4 also passes, note it in a comment; that's the drop-in-default condition
autolens_workspace_test (smoke, follow-up)
- scripts/jax_assertions/knn_barycentric.py - mirror nnls.py pattern:
Speed measurement (sanity check, no PR)
- Run z_projects/profiling/scripts/delaunay_vmap_probe.py at PROBE_BATCH_SIZE=20 with mesh swapped from Delaunay to KNNBarycentric. Compare per-call timings to existing output/imaging/delaunay/hpc_a100_fp64_vmap_probe.json. Target: ~13 ms saving per element, 69.5 → ~56 ms.
Decision gate (validation outcome)
- [...] delaunay_research.md
Key Files
New code (PyAutoArray):
- autoarray/inversion/mesh/interpolator/knn.py - add barycentric_weights_from_3_nearest + InterpolatorKNNBarycentric
- autoarray/inversion/mesh/mesh/knn.py - add KNNBarycentric mesh class
- autoarray/inversion/mesh/__init__.py - export
- test_autoarray/inversion/mesh/interpolator/test_knn_barycentric.py - unit tests
New code (autolens_workspace_developer):
- jax_profiling/jit/imaging/delaunay_knn_barycentric.py - regression script at rtol=1e-3
New code (autolens_workspace_test):
- scripts/jax_assertions/knn_barycentric.py - smoke / cross-conditioning
Read-only reference:
- autoarray/inversion/mesh/interpolator/delaunay.py - pixel_weights_delaunay_from, triangle_area_xp (reuse), InterpolatorDelaunay
- autoarray/inversion/mesh/interpolator/knn.py - current get_interpolation_weights, InterpolatorKNearestNeighbor
- autoarray/inversion/mesh/mesh/knn.py - existing KNearestNeighbor(Delaunay) pattern
- autoarray/inversion/mappers/abstract.py:255 - Mapper.mapping_matrix call site
- z_projects/profiling/scripts/delaunay_vmap_probe.py - speed harness
- PyAutoPrompt/autoarray/delaunay_research.md - split-callback fallback strategy
- z_projects/profiling/FINDINGS_nnls_v2.md - full investigation background
Out of scope
- [...] delaunay_research.md, ~3-6 months)
- [...] delaunay_research.md)