feat: default use_jax_for_visualization to follow use_jax in Analysis.__init__#1278
Merged
Conversation
Phase 2 of the JAX visualization roadmap. Change the default of use_jax_for_visualization from False to None (Optional[bool]) so that users who set use_jax=True automatically pick up the JIT visualization path. Explicit opt-out via use_jax_for_visualization=False still works. PYAUTO_DISABLE_JAX=1 and the JAX-not-installed fallback continue to force both flags off. Closes #1275
Collaborator
Author
|
Workspace PR: PyAutoLabs/autofit_workspace_test#28 |
4 tasks
Jammy2211
added a commit
that referenced
this pull request
May 17, 2026
revert: default use_jax_for_visualization to False (reverts #1278)
Jammy2211
pushed a commit
that referenced
this pull request
May 17, 2026
…inline) The Tests workflow has been red on `main` since 2026-05-16 09:48 UTC, the moment PR #1277 merged the `autofit[nss]` install extra. The 12 tests under `test_autofit/non_linear/search/nest/nss/` hit the script's own ImportError guard on `af.NSS()` because the CI install step only installs `[optional]`. A naïve fix (combine `[optional,nss]` into one pip install) hit a real dependency conflict — both extras pin `blackjax` but to different versions: - [optional] pins `blackjax>=1.2.0` (mainline, PyPI) - [nss] pins handley-lab fork @ ef45acd2 (~0.1.0b1.dev85+) The fork carries the `blackjax.ns.adaptive.init` entrypoint that mainline 1.2.x lacks, so it's not a "use the older one" merge — they're genuinely incompatible. pip rightly refuses `[optional,nss]` with `ResolutionImpossible`. Resolution: split into parallel jobs. - `unittest`: installs `[optional]` and runs the full test suite EXCLUDING `test_autofit/non_linear/search/nest/nss/`. This keeps `test_blackjax_nuts.py` (which requires mainline blackjax 1.2+) green. - `unittest_nss` (new): installs `[nss]` alone in a fresh env and runs ONLY the NSS test suite. Matrix is python 3.12 / 3.13 to match the main `unittest` job. Both jobs need to pass for the `Tests` workflow to be green. The existing `nss_install_smoke.yml` workflow stays as the weekly-cron upstream-drift canary on the [nss] git pins; this new job is the PR/push-time test gate. Fixes the failing Tests check observed on runs 25958838630 (#1277 merge), 25959492149 (#1278 merge), 25998822676 (#1280 revert merge), and 25999018708 (the first attempt at this fix that revealed the extras conflict). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Phase 2 of the JAX visualization roadmap (
z_features/jax_visualization.md).Change the default of
use_jax_for_visualizationonAnalysis.__init__fromFalsetoNone— a sentinel meaning "followuse_jax". Users who setuse_jax=Truenow automatically get the JIT-cached visualization path. Theexplicit opt-out (
use_jax_for_visualization=False) still works for callerswho want JAX likelihoods with the eager NumPy plotter. The
PYAUTO_DISABLE_JAX=1env override and the JAX-not-installed fallback bothcontinue to hard-disable both flags.
Closes #1275.
API Changes
Analysis.__init__(use_jax_for_visualization=...)default changes fromFalse→None. Type annotation widens frombooltoOptional[bool].use_jax=Trueand no explicit visualization flag is given, the JITvisualization path is now ON by default (previously had to be opted into).
TrueorFalseexplicitly areunaffected. Existing scripts that set only
use_jax=Truewill now alsohit the JIT viz path — that is the intended Phase 2 behaviour change.
See full details below.
Test Plan
pytest test_autofit/analysis/test_use_jax_for_visualization.py—existing tests still pass, 1 new numpy-only case added covering
PYAUTO_DISABLE_JAX=1over the new sentinel default.autofit_workspace_test/scripts/jax_assertions/fitness_dispatch.py(ships in the workspace follow-up PR) — all pass locally.
use_jax_for_visualization=explicitly (audited 2026-05-08 / re-verified2026-05-16 — zero hits across autolens_workspace, autogalaxy_workspace,
autofit_workspace).
Full API Changes (for automation & release notes)
Changed Signature
autofit.non_linear.analysis.analysis.Analysis.__init__:use_jax_for_visualization: bool = False→use_jax_for_visualization: Optional[bool] = NoneChanged Behaviour
Analysis(use_jax=True)— previously:_use_jax_for_visualizationresolvedto
False. Now: resolves toTrue(sentinelNonefollowsuse_jax).Analysis(use_jax=True, use_jax_for_visualization=False)— unchanged(explicit opt-out continues to disable JIT viz).
Analysis(use_jax=False, use_jax_for_visualization=True)— unchanged(still logs warning and forces both off).
PYAUTO_DISABLE_JAX=1env override — unchanged (still forces both off).Migration
keep their existing behaviour. Scripts that pass only
use_jax=Truewillpick up the JIT visualization path automatically — this is the desired
Phase 2 behaviour. To preserve old behaviour, pass
use_jax_for_visualization=Falseexplicitly.Roadmap context
PyAutoFit/PyAutoLens fit_imaging_pytree series).
(Maximum recursion depth exceeded when using EP + hierarchical modelling #500/Use Result object of project when using summed Analysis #506 PyAutoLens; Don't resume a MultiNest run if model.results exists #44/Attributes of hyper combined phase need to default to None #46/Raise error is 'result.model' or 'result.instance' is not called #48 autogalaxy_workspace_test; test_dataset.py failing #87/Feature/aggregator #91
autolens_workspace_test).
🤖 Generated with Claude Code