
feat(dynesty): support JAX-jitted likelihoods via use_jax_jit#1243

Merged
Jammy2211 merged 1 commit into main from feature/dynesty-jax on Apr 30, 2026

Conversation

@Jammy2211
Collaborator

Summary

  • Adds jax.jit likelihood support to DynestyStatic / DynestyDynamic. When analysis._use_jax=True, dynesty now runs against a jax.jit-compiled likelihood, mirroring (but not duplicating) the existing Nautilus JAX-vmap path.
  • Re-enables Fitness.__getstate__ / __setstate__ to strip JAX-compiled callables before pickle — required because dynesty's run_nested(checkpoint_file=...) pickles the loglikelihood, and JAX-compiled functions hold C++ XLA state that doesn't roundtrip.
  • Three-way branch in dynesty's no-pool fallback log message (JAX / force_x1_cpu / OS multiprocessing) — the previous single message was misleading.
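The dispatch in the first bullet can be sketched as follows. The class and attribute names mirror the PR description, but the bodies are plain-Python stand-ins (no JAX import), not PyAutoFit's actual implementation:

```python
import functools

class Fitness:
    """Sketch of the jit/vmap dispatch; illustrative, not PyAutoFit code."""

    def __init__(self, fn, use_jax_vmap=False, use_jax_jit=False):
        self._fn = fn
        if use_jax_vmap:            # vmap takes precedence if both flags are set
            self._call = self._vmap
        elif use_jax_jit:           # dynesty path: jit only
            self._call = self._jit  # materializes the cached_property
        else:
            self._call = fn

    @functools.cached_property
    def _jit(self):
        # stand-in for jax.jit(self._fn); the real compiled callable
        # holds XLA state and is not pickleable (hence the __getstate__ fix)
        return self._fn

    @functools.cached_property
    def _vmap(self):
        # stand-in for jax.jit(jax.vmap(self._fn))
        return lambda xs: [self._fn(x) for x in xs]

    def __call__(self, x):
        return self._call(x)

fitness = Fitness(lambda x: -0.5 * x * x, use_jax_jit=True)
assert fitness(2.0) == -2.0
```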

Why

Follow-up to PyAutoFit#1240 (Nautilus_jax, merged). This brings the same fast CPU/GPU likelihood path to dynesty for inference workloads where Nautilus isn't preferred. Explicitly no autodiff: just jax.jit, no jax.grad or HMC plumbing.

dynesty 2.1.5 has no vectorized parameter (verified against the installed package), so Nautilus's vmap-batching approach doesn't apply. jax.jit alone wins here: JAX caches compiled functions by argument signature, so dynesty's one-sample-at-a-time call pattern still benefits from compilation.
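The compilation-cache claim above can be demonstrated directly with plain JAX (this is a standalone illustration, not PyAutoFit code): a side effect in the likelihood body runs only while JAX traces the function, so it counts compilations.

```python
import jax
import jax.numpy as jnp

calls = {"traced": 0}

def log_likelihood(params):
    # runs only during tracing, so it counts compilations
    calls["traced"] += 1
    return -0.5 * jnp.sum(params ** 2)

jitted = jax.jit(log_likelihood)

# dynesty-style call pattern: one sample at a time, same shape/dtype each call
for _ in range(100):
    jitted(jnp.ones(3))

assert calls["traced"] == 1  # compiled once, cache reused for the other 99 calls
```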

Hard blocker uncovered during planning

Fitness.__getstate__ / __setstate__ were commented out at fitness.py:385-393 (with a TODO referencing _call and _grad). Once self._call = self._jit is assigned, the cached_property materializes a jax.jit-compiled function — not pickleable. Without the fix, dynesty's first checkpoint write would crash. Re-enabled and extended to also cover _jit and _vmap. __setstate__ re-derives the dispatch on resume so the cached_property recompiles lazily.
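The fix can be sketched as below. The attribute names (`_call`, `_jit`, `_vmap`, `_grad`) follow the PR, but the class body is an illustrative stand-in: a non-pickleable callable plays the role of the jax.jit-compiled function.

```python
import functools
import pickle

def gauss(x):
    return -0.5 * x * x

class Fitness:
    """Sketch of the pickle fix; illustrative, not PyAutoFit code."""

    def __init__(self, fn):
        self._fn = fn

    @functools.cached_property
    def _jit(self):
        # stand-in for jax.jit(self._fn); the real compiled callable
        # holds C++ XLA state and does not survive pickle
        return lambda x: self._fn(x)

    def __getstate__(self):
        state = self.__dict__.copy()
        for key in ("_call", "_jit", "_vmap", "_grad"):
            state.pop(key, None)  # strip compiled callables before pickling
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # nothing rebuilt eagerly: the cached_property recompiles
        # lazily on the first likelihood call after resume

f = Fitness(gauss)
f._call = f._jit                            # materializes the unpicklable callable
restored = pickle.loads(pickle.dumps(f))    # succeeds: callables were stripped
assert restored._jit(3.0) == -4.5           # recompiled lazily on first access
```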

API

```python
from autofit.jax.pytrees import enable_pytrees, register_model

enable_pytrees()
model = af.Model(af.ex.Gaussian)
register_model(model)
analysis = af.ex.Analysis(data=data, noise_map=noise_map, use_jax=True)

# Default: jit auto-on when analysis._use_jax. Disable with use_jax_jit=False.
search = af.DynestyStatic(nlive=30)
result = search.fit(model=model, analysis=analysis)
```

Stdout includes `JAX: Applying jit to likelihood function` (analogous to Nautilus's `JAX: Applying vmap and jit...`) and `Running Dynesty with JAX-jitted likelihood (single CPU, no pool).` from the new log branch.

Tests

  • 5 new unit tests in test_fitness_jax_dispatch.py cover dispatch + pickle round-trip. Per project policy these don't import jax (library unit tests stay numpy-only).
  • test_dict fixture updated for the new use_jax_jit kwarg on DynestyStatic.
  • All 244 tests in test_autofit/non_linear/ pass.

Companion workspace test

autofit_workspace_test PR will follow with scripts/searches/Dynesty_jax.py (mirroring Nautilus_jax.py). Verified locally: log_Z ≈ -54.0, dlogz ≈ 0.014 on the 1D Gaussian dataset, ~3.5s runtime. Nautilus_jax.py also re-verified for vmap-path regression — still green (log_Z ≈ -54.13, N_eff = 366).

Test plan

  • pytest test_autofit/non_linear passes (244/244)
  • Dynesty_jax.py runs end-to-end and emits the JAX-jit log line
  • Nautilus_jax.py still passes (vmap regression check)
  • CI green
  • Companion autofit_workspace_test PR opened after this merges

🤖 Generated with Claude Code

dynesty 2.1.5's NestedSampler has no `vectorized` parameter — it calls
the likelihood one sample at a time, so Nautilus's vmap-batching
approach doesn't apply. Use `jax.jit` on its own instead: JAX's
compiled-function cache reuses the compiled likelihood across calls,
giving a fast CPU/GPU evaluation path for nested sampling without
requiring autodiff.

Changes:

- `Fitness.__init__` accepts `use_jax_jit: bool = False`. When set,
  `self._call = self._jit` (parallel to the existing `_vmap` dispatch).
  vmap takes precedence if both flags are somehow set.

- `Fitness.call_wrap` casts the jit-path return value to a Python
  `float`. dynesty's `logz` accumulators and HDF5 savestate require
  numpy/Python scalars, not raw JAX `Array`s. The vmap path is
  untouched — Nautilus accepts JAX arrays at its `vectorized=True`
  interface.

- `Fitness.__getstate__` / `__setstate__` re-enabled (previously
  commented out) and extended to strip `_call`, `_jit`, `_vmap`,
  `_grad` from the pickle. dynesty's `run_nested(checkpoint_file=...)`
  pickles the loglikelihood; JAX-compiled callables hold C++ XLA state
  that doesn't roundtrip through pickle. `__setstate__` re-derives the
  dispatch on resume so the cached_property recompiles lazily on the
  first call.

- `AbstractDynesty.__init__` accepts `use_jax_jit: bool = True`. In
  `_fit`, the upfront `Fitness(...)` construction passes
  `use_jax_jit=(analysis._use_jax and self.use_jax_jit)`. Default-on
  when JAX is enabled; user can disable via the search-class flag.

- The existing no-pool fallback (triggered when `force_x1_cpu` or
  `analysis._use_jax`) now branches the log message three ways: JAX
  path, force_x1_cpu, OS-multiprocessing fallback. The original
  message wrongly attributed JAX/force_x1_cpu fallbacks to "OS does
  not support multiprocessing".
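The `call_wrap` cast in the second change above can be sketched as follows (an illustrative stand-in, not PyAutoFit's actual `call_wrap`): the jit path returns a 0-d JAX `Array`, but dynesty's `logz` accumulators and HDF5 checkpoints need a plain Python scalar.

```python
import numpy as np

def call_wrap(log_likelihood, params, use_jax_jit=False):
    # sketch: cast the jit-path return to a Python float so dynesty's
    # accumulators and HDF5 savestate see a plain scalar, not a JAX Array
    value = log_likelihood(params)
    if use_jax_jit:
        return float(np.asarray(value))  # 0-d array -> Python float
    return value                          # vmap path untouched

out = call_wrap(lambda p: np.float64(-54.0), None, use_jax_jit=True)
assert type(out) is float and out == -54.0
```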

Tests: 5 new unit tests in `test_fitness_jax_dispatch.py` cover the
dispatch logic and pickle round-trip. They do not import jax (per
project policy: library unit tests stay numpy-only). `test_dict`
fixture updated for the new `use_jax_jit` arg on `DynestyStatic`.

Verification: companion script `Dynesty_jax.py` will land in
`autofit_workspace_test`; runs end-to-end with
`log_Z ≈ -54, dlogz < 0.5` on the standard 1D Gaussian dataset.
`Nautilus_jax.py` remains green (vmap path unchanged).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 48831e4 into main Apr 30, 2026
3 checks passed
@Jammy2211 Jammy2211 deleted the feature/dynesty-jax branch April 30, 2026 18:58
