Issue #161 identifies that Moffat profile drawing is slow, with gradients being 40x more expensive than the forward pass. The detailed analysis by @EiffL pinpoints several contributing factors in the kv() Bessel function implementation (listed as "Fix 4" in the issue). This PR rewrites kv() as a proper JAX primitive, which addresses several of those root causes:
1. No more unnecessary forward-pass work for gradients
The old implementation used @jax.custom_vjp, which unconditionally computes 3 Bessel evaluations (K_v, K_{v-1}, K_{v+1}) in every forward pass and stores them as residuals for the backward pass — even when no gradient is needed (e.g., when evaluating the PSF in a convolution). From the issue:
> The kv() Bessel function (bessel.py:550-554) always computes 3 Bessel evaluations in the forward pass to store residuals for the backward.
The new implementation uses ad.defjvp on the primitive instead. The JVP rule is only invoked during differentiation — the forward pass computes only the single K_v evaluation it actually needs. The K_{v-1} and K_{v+1} terms are computed lazily, only when JAX's AD system actually requests the derivative.
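The JVP rule rests on the standard Bessel-K derivative recurrence, dK_v/dx = -(K_{v-1}(x) + K_{v+1}(x))/2. Below is a small self-contained sanity check of that identity, using a crude trapezoid quadrature of the integral representation K_v(x) = ∫₀^∞ e^(-x cosh t) cosh(vt) dt as a stand-in oracle. None of this is the PR's code; `kv_quad` and its parameters are purely illustrative.

```python
import math

def kv_quad(v, x, tmax=20.0, n=4000):
    # Crude trapezoid rule for K_v(x) = ∫_0^∞ exp(-x cosh t) cosh(v t) dt.
    # Illustrative oracle only -- not the Temme/Steed algorithm used in the PR.
    h = tmax / n
    s = 0.5 * (math.exp(-x) + math.exp(-x * math.cosh(tmax)) * math.cosh(v * tmax))
    for i in range(1, n):
        t = i * h
        s += math.exp(-x * math.cosh(t)) * math.cosh(v * t)
    return s * h

# Identity the JVP rule uses: dK_v/dx = -(K_{v-1}(x) + K_{v+1}(x)) / 2.
v, x = 1.5, 2.0
analytic = -0.5 * (kv_quad(v - 1.0, x) + kv_quad(v + 1.0, x))
eps = 1e-5
numeric = (kv_quad(v, x + eps) - kv_quad(v, x - eps)) / (2.0 * eps)
print(abs(analytic - numeric) / abs(analytic))  # small: the recurrence matches a finite difference
```

The point of wiring this identity into a JVP rule (rather than a custom VJP) is that the two extra evaluations on the right-hand side are only ever computed when differentiation actually happens.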
2. Fixed iteration counts replace convergence-based while_loop

The old code used jax.lax.while_loop for convergence-based iteration in the Temme series and Steed's continued-fraction algorithms. while_loop generates larger, harder-to-optimize HLO because XLA cannot predict the iteration count at compile time.
The new code uses jax.lax.fori_loop with empirically determined fixed iteration counts (15 for the Temme series, 80 for the Steed CF), using jnp.where(converged, old_val, new_val) to no-op after convergence. This lets XLA:

- Know the exact loop trip count at compile time
- Potentially unroll or pipeline the loops
- Generate more compact HLO graphs
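The fixed-trip-count pattern can be illustrated in miniature with plain Python: run a fixed number of iterations and freeze the state once converged, which is what the jnp.where(converged, old_val, new_val) select does inside the fori_loop body. Newton's method for a square root stands in for the Temme/Steed iterations here; the function and counts are illustrative, not the PR's code.

```python
def fixed_iter_sqrt(a, n_iter=15, tol=1e-12):
    """Newton iteration with a fixed trip count and a convergence freeze."""
    x, converged = a, False
    for _ in range(n_iter):          # trip count is known up front, like fori_loop
        new_x = 0.5 * (x + a / x)    # one Newton step
        step_small = abs(new_x - x) < tol
        x = x if converged else new_x  # where(converged, old, new): no-op once done
        converged = converged or step_small
    return x

print(fixed_iter_sqrt(2.0))  # close to 1.41421356...
```

The extra no-op iterations after convergence are cheap element-wise selects, while the compiler gets a loop whose bounds it can reason about statically.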
This is relevant to the observation in the issue that "taking gradients of the moffat and exponential convolution creates a very big HLO."
3. Native element-wise operation replaces vmap over scalar core
The old implementation defined a scalar _kv_scalar function and used jax.vmap to vectorize it over array inputs. The new implementation operates element-wise natively using _up_and_broadcast (the same pattern used by igamma, betainc, and other special functions in jax._src.lax.special). This removes one layer of vmap tracing overhead from the evaluation stack — relevant to @beckermr's point that "XLA is not smart enough to do all of its optimizations with vmap."
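The shape of this change can be sketched with NumPy, where np.vectorize plays the role of jax.vmap wrapping a scalar core: both styles compute the same values, but the native element-wise version avoids the extra wrapping layer. The functions below are toy illustrations, not the kv implementation.

```python
import numpy as np

def _f_scalar(x):
    # Scalar core, analogous to the old _kv_scalar
    return x * x + 1.0

# Old style: vectorize the scalar core from the outside (vmap analogue)
f_wrapped = np.vectorize(_f_scalar)

def f_native(x):
    # New style: operate element-wise on whole arrays directly
    x = np.asarray(x)
    return x * x + 1.0

xs = np.linspace(0.0, 3.0, 5)
print(np.allclose(f_wrapped(xs), f_native(xs)))  # same results, one fewer layer
```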
4. Proper JAX primitive integration
Registering kv as a JAX primitive via standard_naryop + mlir.register_lowering + ad.defjvp means it integrates with JAX's compiler infrastructure the same way built-in special functions do, rather than being an opaque Python-level overlay. This gives XLA maximum visibility into the operation for optimization.
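In miniature, registering a primitive with a lazy JVP rule looks like the sketch below. This uses the public Primitive / ad.defjvp pattern from the JAX documentation with a toy square operation; it deliberately omits the mlir.register_lowering / standard_naryop machinery the PR uses (so this toy cannot be jit-compiled), and none of the names here are the PR's.

```python
import jax
from jax.interpreters import ad

try:
    from jax.extend.core import Primitive  # newer JAX versions
except ImportError:
    from jax.core import Primitive         # older JAX versions

square_p = Primitive("square_demo")
square_p.def_impl(lambda x: x * x)        # eager (op-by-op) evaluation
square_p.def_abstract_eval(lambda x: x)   # shape/dtype rule: output aval == input aval

# Tangent rule: d(x^2) = 2 x dx. Only invoked when AD actually differentiates,
# which is the same property the kv JVP rule exploits.
ad.defjvp(square_p, lambda g, x: 2.0 * x * g)

def square(x):
    return square_p.bind(x)

print(square(3.0))            # forward pass: impl only, no derivative work
print(jax.grad(square)(3.0))  # derivative work happens here, lazily
```

The real kv registration additionally provides an MLIR lowering so XLA sees the op directly when compiling, which is what gives the compiler "maximum visibility" compared with an opaque Python-level wrapper.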
What this doesn't fix
This PR addresses Fix 4 from the issue analysis. The other identified improvements — stopping gradient flow through non-parameter PSF objects (Fix 2) and vectorizing k-space evaluations in draw_by_kValue (Fix 3) — are orthogonal and would need separate PRs.
Accuracy & performance
Max relative error vs scipy.special.kv: ~1.2e-13 across a wide (v, x) grid
I haven't yet confirmed this PR helps with the original problem of inefficient Moffat profiles and gradients, so I'll keep it as a draft for now. I should be able to run that test tonight.
rewrite of Bessel kv with fixed-length iterations, and as a proper JAX primitive.
this implementation was prototyped here: https://gist.github.com/EiffL/95e00c160dc42a58cb44b0bf8aab83ea