feat(dynesty): support JAX-jitted likelihoods via use_jax_jit #1243
Merged
Conversation
dynesty 2.1.5's NestedSampler has no `vectorized` parameter: it calls the likelihood one sample at a time, so Nautilus's vmap-batching approach doesn't apply. Use `jax.jit` on its own instead: JAX's compiled-function cache reuses the compiled likelihood across calls, giving a fast CPU/GPU evaluation path for nested sampling without requiring autodiff.

Changes:

- `Fitness.__init__` accepts `use_jax_jit: bool = False`. When set, `self._call = self._jit` (parallel to the existing `_vmap` dispatch). vmap takes precedence if both flags are somehow set. (A minimal sketch of this dispatch follows this message.)
- `Fitness.call_wrap` casts the jit-path return value to a Python `float`. dynesty's `logz` accumulators and HDF5 savestate require numpy/Python scalars, not raw JAX `Array`s. The vmap path is untouched; Nautilus accepts JAX arrays at its `vectorized=True` interface.
- `Fitness.__getstate__` / `__setstate__` re-enabled (previously commented out) and extended to strip `_call`, `_jit`, `_vmap`, `_grad` from the pickle. dynesty's `run_nested(checkpoint_file=...)` pickles the loglikelihood; JAX-compiled callables hold C++ XLA state that doesn't roundtrip through pickle. `__setstate__` re-derives the dispatch on resume so the cached_property recompiles lazily on the first call.
- `AbstractDynesty.__init__` accepts `use_jax_jit: bool = True`. In `_fit`, the upfront `Fitness(...)` construction passes `use_jax_jit=(analysis._use_jax and self.use_jax_jit)`. Default-on when JAX is enabled; the user can disable it via the search-class flag.
- The existing no-pool fallback (triggered when `force_x1_cpu` or `analysis._use_jax`) now branches the log message three ways: JAX path, force_x1_cpu, OS-multiprocessing fallback. The original message wrongly attributed JAX/force_x1_cpu fallbacks to "OS does not support multiprocessing".

Tests: 5 new unit tests in `test_fitness_jax_dispatch.py` cover the dispatch logic and pickle round-trip. They do not import jax (per project policy: library unit tests stay numpy-only). `test_dict` fixture updated for the new `use_jax_jit` arg on `DynestyStatic`.

Verification: companion script `Dynesty_jax.py` will land in `autofit_workspace_test`; runs end-to-end with `log_Z ≈ -54, dlogz < 0.5` on the standard 1D Gaussian dataset. `Nautilus_jax.py` remains green (vmap path unchanged).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
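A minimal sketch of the dispatch and scalar-cast pattern described above, assuming a simplified stand-in for PyAutoFit's `Fitness` class: the `FitnessSketch` name, its constructor signature, and the lazy `import jax` are illustrative assumptions, not the real `fitness.py` code.

```python
from functools import cached_property


class FitnessSketch:
    """Illustrative stand-in: dispatch between plain, jit, and vmap paths."""

    def __init__(self, function, use_jax_vmap=False, use_jax_jit=False):
        self._function = function
        self._use_jax_vmap = use_jax_vmap
        self._use_jax_jit = use_jax_jit
        # vmap takes precedence if both flags are somehow set.
        if use_jax_vmap:
            self._call = self._vmap
        elif use_jax_jit:
            self._call = self._jit
        else:
            self._call = self._function

    @cached_property
    def _jit(self):
        import jax  # lazy import keeps the plain-numpy path jax-free

        return jax.jit(self._function)

    @cached_property
    def _vmap(self):
        import jax

        return jax.jit(jax.vmap(self._function))

    def call_wrap(self, parameters):
        result = self._call(parameters)
        if self._use_jax_jit and not self._use_jax_vmap:
            # jit path returns a 0-d jax.Array; dynesty's logz accumulators
            # and HDF5 savestate need a Python/numpy scalar.
            return float(result)
        # vmap path untouched: Nautilus accepts JAX arrays at vectorized=True.
        return result
```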
Summary
- Adds `jax.jit` likelihood support to `DynestyStatic`/`DynestyDynamic`. When `analysis._use_jax=True`, dynesty now runs against a `jax.jit`-compiled likelihood, mirroring (but not duplicating) the existing Nautilus JAX-vmap path.
- Re-enables `Fitness.__getstate__`/`__setstate__` to strip JAX-compiled callables before pickle. Required because dynesty's `run_nested(checkpoint_file=...)` pickles the loglikelihood, and JAX-compiled functions hold C++ XLA state that doesn't roundtrip.
Why

Follow-up to PyAutoFit#1240 (Nautilus_jax merged). The user wants the same fast-CPU/GPU likelihood path on dynesty for inference workloads where Nautilus isn't preferred. Explicitly no autodiff: just `jax.jit`, no `jax.grad` or HMC plumbing.

dynesty 2.1.5 has no `vectorized` parameter (verified against the installed package), so Nautilus's vmap-batching approach doesn't apply. `jax.jit` alone wins here: JAX caches compiled functions by argument signature, so dynesty's one-sample-at-a-time call pattern still benefits from compilation. A toy illustration follows.
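To make the caching claim concrete, a toy example; the quadratic likelihood below is a stand-in, not the workspace model:

```python
import jax
import jax.numpy as jnp


@jax.jit
def log_likelihood(params):
    # Toy stand-in likelihood; any pure JAX function behaves the same way.
    return -0.5 * jnp.sum(params**2)


x = jnp.array([0.1, 0.2, 0.3])
log_likelihood(x)        # first call: traces and compiles
log_likelihood(x + 1.0)  # same shape/dtype: hits the compiled-function cache
```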
Hard blocker uncovered during planning

`Fitness.__getstate__`/`__setstate__` were commented out at `fitness.py:385-393` (with a TODO referencing `_call` and `_grad`). Once `self._call = self._jit` is assigned, the `cached_property` materializes a `jax.jit`-compiled function, which is not pickleable. Without the fix, dynesty's first checkpoint write would crash. Re-enabled and extended to also cover `_jit` and `_vmap`. `__setstate__` re-derives the dispatch on resume so the `cached_property` recompiles lazily; a sketch follows.
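A hedged sketch of the pickle handling, again on a minimal stand-in rather than the real `fitness.py` code: the `PickleSafeFitness` and `toy_loglike` names and the `_derive_dispatch` helper are assumptions for illustration.

```python
import pickle
from functools import cached_property


def toy_loglike(parameters):
    # Module-level so pickle can serialize it by reference.
    return -0.5 * sum(p * p for p in parameters)


class PickleSafeFitness:
    def __init__(self, function, use_jax_jit=False):
        self._function = function
        self._use_jax_jit = use_jax_jit
        self._derive_dispatch()

    def _derive_dispatch(self):
        self._call = self._jit if self._use_jax_jit else self._function

    @cached_property
    def _jit(self):
        import jax

        return jax.jit(self._function)

    def __getstate__(self):
        # Strip compiled callables and the dispatch reference: jit wrappers
        # hold C++ XLA state that doesn't roundtrip. cached_property values
        # live in __dict__ under the property name, so pop() clears them too.
        state = self.__dict__.copy()
        for key in ("_call", "_jit", "_vmap", "_grad"):
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Re-derive the dispatch from the stored flag; the cached_property
        # recompiles lazily on the first likelihood call after restore.
        self._derive_dispatch()


fitness = PickleSafeFitness(toy_loglike, use_jax_jit=True)
restored = pickle.loads(pickle.dumps(fitness))  # as dynesty checkpointing does
```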
API

New `use_jax_jit: bool = True` keyword on the dynesty search classes; default-on when JAX is enabled, set it to `False` to opt out (usage sketched below). Stdout includes `JAX: Applying jit to likelihood function` (analogous to Nautilus's `JAX: Applying vmap and jit...`) and `Running Dynesty with JAX-jitted likelihood (single CPU, no pool).` from the new log branch.
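Hedged usage sketch: the `use_jax_jit` kwarg is from this PR, but the search name and the commented model/analysis setup are schematic, not verbatim workspace code.

```python
import autofit as af

# use_jax_jit defaults to True; it only takes effect when the analysis runs
# with JAX enabled (analysis._use_jax=True). Pass False to opt out.
search = af.DynestyStatic(name="gaussian_x1_jax", use_jax_jit=True)

# result = search.fit(model=model, analysis=analysis)
```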
Tests

- New unit tests in `test_fitness_jax_dispatch.py` cover dispatch + pickle round-trip. Per project policy these don't import jax (library unit tests stay numpy-only); a sketch of the style follows this list.
- `test_dict` fixture updated for the new `use_jax_jit` kwarg on `DynestyStatic`.
- `test_autofit/non_linear/` tests pass.
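A hedged sketch of that numpy-only test style, written against the `PickleSafeFitness` stand-in from above rather than the real `Fitness` class; the test names and the `fitness_sketch` module are hypothetical.

```python
import pickle

from fitness_sketch import PickleSafeFitness  # hypothetical module: stand-in above


def plain_loglike(parameters):
    return -0.5 * sum(p * p for p in parameters)


def test_default_dispatch_is_plain_function():
    # use_jax_jit=False (the default) never touches the jax import path.
    fitness = PickleSafeFitness(plain_loglike)
    assert fitness._call is plain_loglike


def test_getstate_strips_compiled_callables():
    fitness = PickleSafeFitness(plain_loglike)
    state = fitness.__getstate__()
    for key in ("_call", "_jit", "_vmap", "_grad"):
        assert key not in state


def test_pickle_round_trip_re_derives_dispatch():
    fitness = PickleSafeFitness(plain_loglike)
    restored = pickle.loads(pickle.dumps(fitness))
    assert restored._call is plain_loglike
```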
An `autofit_workspace_test` PR will follow with `scripts/searches/Dynesty_jax.py` (mirroring `Nautilus_jax.py`). Verified locally: `log_Z ≈ -54.0`, `dlogz ≈ 0.014` on the 1D Gaussian dataset, ~3.5s runtime. `Nautilus_jax.py` also re-verified for vmap-path regression, still green (`log_Z ≈ -54.13`, `N_eff = 366`).
Test plan

- `pytest test_autofit/non_linear` passes (244/244)
- `Dynesty_jax.py` runs end-to-end and emits the JAX-jit log line
- `Nautilus_jax.py` still passes (vmap regression check)
- `autofit_workspace_test` PR opened after this merges

🤖 Generated with Claude Code