Skip to content

feat: default use_jax_for_visualization to follow use_jax in Analysis.__init__#1278

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/use-jax-for-vis-default
May 16, 2026
Merged

feat: default use_jax_for_visualization to follow use_jax in Analysis.__init__#1278
Jammy2211 merged 1 commit into
mainfrom
feature/use-jax-for-vis-default

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Phase 2 of the JAX visualization roadmap (z_features/jax_visualization.md).
Change the default of use_jax_for_visualization on Analysis.__init__ from
False to None — a sentinel meaning "follow use_jax". Users who set
use_jax=True now automatically get the JIT-cached visualization path. The
explicit opt-out (use_jax_for_visualization=False) still works for callers
who want JAX likelihoods with the eager NumPy plotter. The
PYAUTO_DISABLE_JAX=1 env override and the JAX-not-installed fallback both
continue to hard-disable both flags.

Closes #1275.

API Changes

  • Analysis.__init__(use_jax_for_visualization=...) default changes from
    FalseNone. Type annotation widens from bool to Optional[bool].
  • When use_jax=True and no explicit visualization flag is given, the JIT
    visualization path is now ON by default (previously had to be opted into).
  • All existing call sites that pass True or False explicitly are
    unaffected. Existing scripts that set only use_jax=True will now also
    hit the JIT viz path — that is the intended Phase 2 behaviour change.

See full details below.

Test Plan

  • pytest test_autofit/analysis/test_use_jax_for_visualization.py
    existing tests still pass, 1 new numpy-only case added covering
    PYAUTO_DISABLE_JAX=1 over the new sentinel default.
  • Three JAX-conditional sentinel-resolution assertions added to
    autofit_workspace_test/scripts/jax_assertions/fitness_dispatch.py
    (ships in the workspace follow-up PR) — all pass locally.
  • Reviewer: confirm no production workspace script today sets
    use_jax_for_visualization= explicitly (audited 2026-05-08 / re-verified
    2026-05-16 — zero hits across autolens_workspace, autogalaxy_workspace,
    autofit_workspace).
Full API Changes (for automation & release notes)

Changed Signature

  • autofit.non_linear.analysis.analysis.Analysis.__init__:
    • use_jax_for_visualization: bool = Falseuse_jax_for_visualization: Optional[bool] = None

Changed Behaviour

  • Analysis(use_jax=True) — previously: _use_jax_for_visualization resolved
    to False. Now: resolves to True (sentinel None follows use_jax).
  • Analysis(use_jax=True, use_jax_for_visualization=False) — unchanged
    (explicit opt-out continues to disable JIT viz).
  • Analysis(use_jax=False, use_jax_for_visualization=True) — unchanged
    (still logs warning and forces both off).
  • PYAUTO_DISABLE_JAX=1 env override — unchanged (still forces both off).
  • JAX-not-installed fallback — unchanged (still forces both off with warning).

Migration

  • No required migration. Scripts that currently set both flags explicitly
    keep their existing behaviour. Scripts that pass only use_jax=True will
    pick up the JIT visualization path automatically — this is the desired
    Phase 2 behaviour. To preserve old behaviour, pass
    use_jax_for_visualization=False explicitly.

Roadmap context

🤖 Generated with Claude Code

Phase 2 of the JAX visualization roadmap. Change the default of
use_jax_for_visualization from False to None (Optional[bool]) so that
users who set use_jax=True automatically pick up the JIT visualization
path. Explicit opt-out via use_jax_for_visualization=False still works.
PYAUTO_DISABLE_JAX=1 and the JAX-not-installed fallback continue to
force both flags off.

Closes #1275
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

Workspace PR: PyAutoLabs/autofit_workspace_test#28

@Jammy2211 Jammy2211 merged commit 035f93a into main May 16, 2026
3 of 5 checks passed
@Jammy2211 Jammy2211 deleted the feature/use-jax-for-vis-default branch May 16, 2026 10:22
Jammy2211 added a commit that referenced this pull request May 17, 2026
revert: default use_jax_for_visualization to False (reverts #1278)
Jammy2211 pushed a commit that referenced this pull request May 17, 2026
…inline)

The Tests workflow has been red on `main` since 2026-05-16 09:48 UTC,
the moment PR #1277 merged the `autofit[nss]` install extra. The 12
tests under `test_autofit/non_linear/search/nest/nss/` hit the script's
own ImportError guard on `af.NSS()` because the CI install step only
installs `[optional]`.

A naïve fix (combine `[optional,nss]` into one pip install) hit a real
dependency conflict — both extras pin `blackjax` but to different
versions:

  - [optional] pins `blackjax>=1.2.0` (mainline, PyPI)
  - [nss]      pins handley-lab fork @ ef45acd2 (~0.1.0b1.dev85+)

The fork carries the `blackjax.ns.adaptive.init` entrypoint that
mainline 1.2.x lacks, so it's not a "use the older one" merge — they're
genuinely incompatible. pip rightly refuses `[optional,nss]` with
`ResolutionImpossible`.

Resolution: split into parallel jobs.

- `unittest`: installs `[optional]` and runs the full test suite
  EXCLUDING `test_autofit/non_linear/search/nest/nss/`. This keeps
  `test_blackjax_nuts.py` (which requires mainline blackjax 1.2+) green.
- `unittest_nss` (new): installs `[nss]` alone in a fresh env and runs
  ONLY the NSS test suite. Matrix is python 3.12 / 3.13 to match
  the main `unittest` job.

Both jobs need to pass for the `Tests` workflow to be green. The
existing `nss_install_smoke.yml` workflow stays as the weekly-cron
upstream-drift canary on the [nss] git pins; this new job is the
PR/push-time test gate.

Fixes the failing Tests check observed on runs 25958838630 (#1277
merge), 25959492149 (#1278 merge), 25998822676 (#1280 revert merge),
and 25999018708 (the first attempt at this fix that revealed the
extras conflict).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: default use_jax_for_visualization to follow use_jax in Analysis.__init__

1 participant