
update bessel algo #180

Closed
EiffL wants to merge 3 commits into main from besssel_improv

Conversation

Member

@EiffL EiffL commented Feb 8, 2026

Rewrite of the Bessel kv function with fixed-length iterations, implemented as proper JAX primitives.

This implementation was prototyped here: https://gist.github.com/EiffL/95e00c160dc42a58cb44b0bf8aab83ea

Member Author

EiffL commented Feb 8, 2026

How this PR addresses #161

Issue #161 identifies that Moffat profile drawing is slow, with gradients being 40x more expensive than the forward pass. The detailed analysis by @EiffL pinpoints several contributing factors in the kv() Bessel function implementation (listed as "Fix 4" in the issue). This PR rewrites kv() as a proper JAX primitive, which addresses several of those root causes:

1. No more unnecessary forward-pass work for gradients

The old implementation used @jax.custom_vjp, which unconditionally computes 3 Bessel evaluations (K_v, K_{v-1}, K_{v+1}) in every forward pass and stores them as residuals for the backward pass — even when no gradient is needed (e.g., when evaluating the PSF in a convolution). From the issue:

The kv() Bessel function (bessel.py:550-554) always computes 3 Bessel evaluations in the forward pass to store residuals for the backward.

The new implementation uses ad.defjvp on the primitive instead. The JVP rule is only invoked during differentiation — the forward pass computes only the single K_v evaluation it actually needs. The K_{v-1} and K_{v+1} terms are computed lazily, only when JAX's AD system actually requests the derivative.

2. Fixed-size fori_loop replaces while_loop → smaller HLO

The old code used jax.lax.while_loop for convergence-based iteration in the Temme series and Steed's continued fraction algorithms. while_loop generates larger, harder-to-optimize HLO because XLA cannot predict the iteration count at compile time.

The new code uses jax.lax.fori_loop with empirically-determined fixed iteration counts (15 for Temme series, 80 for Steed CF), using jnp.where(converged, old_val, new_val) to no-op after convergence. This lets XLA:

  • Know the exact loop trip count at compile time
  • Potentially unroll or pipeline the loops
  • Generate more compact HLO graphs

This is relevant to the observation in the issue that "taking gradients of the moffat and exponential convolution creates a very big HLO."

3. Native element-wise operation replaces vmap over scalar core

The old implementation defined a scalar _kv_scalar function and used jax.vmap to vectorize it over array inputs. The new implementation operates element-wise natively using _up_and_broadcast (the same pattern used by igamma, betainc, and other special functions in jax._src.lax.special). This removes one layer of vmap tracing overhead from the evaluation stack — relevant to @beckermr's point that "XLA is not smart enough to do all of its optimizations with vmap."
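A minimal illustration of the two styles with a toy kernel (`_up_and_broadcast` is a JAX-internal helper, so it is not reproduced here): both give identical outputs, but the native version avoids an extra transformation layer.

```python
import jax
import jax.numpy as jnp

def _toy_scalar(x):
    # Old pattern: a scalar "core" intended for element-by-element use.
    return jnp.exp(-x) * jnp.sqrt(x)

toy_vmapped = jax.vmap(_toy_scalar)  # vectorized through a vmap wrapper

def toy_native(x):
    # New pattern: element-wise primitives broadcast over arrays directly.
    return jnp.exp(-x) * jnp.sqrt(x)

x = jnp.linspace(0.1, 5.0, 8)
same = jnp.allclose(toy_vmapped(x), toy_native(x))  # identical results
```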

4. Proper JAX primitive integration

Registering kv as a JAX primitive via standard_naryop + mlir.register_lowering + ad.defjvp means it integrates with JAX's compiler infrastructure the same way built-in special functions do, rather than being an opaque Python-level overlay. This gives XLA maximum visibility into the operation for optimization.

What this doesn't fix

This PR addresses Fix 4 from the issue analysis. The other identified improvements — stopping gradient flow through non-parameter PSF objects (Fix 2) and vectorizing k-space evaluations in draw_by_kValue (Fix 3) — are orthogonal and would need separate PRs.

Accuracy & performance

  • Max relative error vs scipy.special.kv: ~1.2e-13 across a wide (v, x) grid
  • All 28 existing gradient tests pass
  • Gradient computation uses the analytical recurrence relation: ∂K_ν/∂x = -½(K_{ν-1}(x) + K_{ν+1}(x))
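The recurrence can be sanity-checked directly against SciPy (assuming `scipy` is available; `kvp` is SciPy's analytic derivative of K_v):

```python
import numpy as np
from scipy.special import kv, kvp

# Verify dK_v/dx = -(K_{v-1}(x) + K_{v+1}(x)) / 2 on a (v, x) grid
v = 1.5
x = np.linspace(0.5, 10.0, 50)
lhs = kvp(v, x)                             # analytic derivative from SciPy
rhs = -0.5 * (kv(v - 1, x) + kv(v + 1, x))  # recurrence used by the JVP rule
ok = np.allclose(lhs, rhs, rtol=1e-10)
```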


codspeed-hq bot commented Feb 8, 2026

CodSpeed Performance Report

Merging this PR will degrade performance by 81.89%

Comparing besssel_improv (d95c078) with main (51df286)

⚠️ Unknown Walltime execution environment detected

Using the Walltime instrument on standard Hosted Runners will lead to inconsistent data.

For the most accurate results, we recommend using CodSpeed Macro Runners: bare-metal machines fine-tuned for performance measurement consistency.

Summary

⚡ 2 improved benchmarks
❌ 5 regressed benchmarks
✅ 29 untouched benchmarks

⚠️ Please fix the performance issues or acknowledge them on CodSpeed.

Performance Changes

| Mode | Benchmark | BASE | HEAD | Efficiency |
|---|---|---|---|---|
| Simulation | test_benchmark_moffat_conv_grad[run] | 13.6 s | 60.6 s | -77.61% |
| Simulation | test_benchmark_moffat_conv[run] | 10.4 s | 57.4 s | -81.89% |
| WallTime | test_benchmark_moffat_conv[run] | 681.8 ms | 2,771.1 ms | -75.4% |
| WallTime | test_benchmark_moffat_init[run] | 135.9 µs | 92.1 µs | +47.65% |
| WallTime | test_benchmark_spergel_calcfluxrad[run] | 192.4 µs | 303.5 µs | -36.62% |
| WallTime | test_benchmarks_lanczos_interp[xval-no_conserve_dc-run] | 113.2 µs | 80.3 µs | +41.02% |
| WallTime | test_benchmark_moffat_conv_grad[run] | 1.3 s | 3.3 s | -60.85% |

@EiffL EiffL linked an issue Feb 8, 2026 that may be closed by this pull request
Member Author

EiffL commented Feb 9, 2026

I haven't yet confirmed that this PR helps with the original problem of inefficient Moffats and gradients, so I'll keep it as a draft for now. I should be able to run that test tonight.

Member Author

EiffL commented Feb 9, 2026

Baseline Comparison

| Metric | main | besssel_improv | Change |
|---|---|---|---|
| Fwd HLO | 5,472 | 5,169 | -5.5% |
| Grad HLO | 6,349 | 9,384 | +47.8% (worse) |
| Fwd time | 0.647 ms | 1.161 ms | 1.8x slower |
| Grad time | 23.0 ms | 135.9 ms | 5.9x slower |
| Grad/Fwd ratio | 35.6x | 117.1x | 3.3x worse |

Member Author

EiffL commented Feb 9, 2026

sigh, not helping at all, actually making things worse

@EiffL EiffL closed this Feb 9, 2026
@beckermr beckermr reopened this Feb 9, 2026
Collaborator

beckermr commented Feb 9, 2026

I am reopening this to use the new benchmarks I merged in #187 and #186 to see what this PR does.

Collaborator

beckermr commented Feb 9, 2026

The benchmarks appear to reliably indicate this PR is worse - that should help with debugging!

@beckermr beckermr closed this Feb 9, 2026


Development

Successfully merging this pull request may close these issues.

Moffat profile drawing returns NaN (demo2.py)
