feat: default use_jax_for_visualization to follow use_jax in Analysis.__init__ #1275

@Jammy2211

Overview

Phase 2 of the JAX visualization roadmap (z_features/jax_visualization.md).
Today Analysis.__init__ takes use_jax and use_jax_for_visualization as
two independent boolean flags both defaulting to False. With Phases 0 and 1
shipped, every dataset type across PyAutoLens + PyAutoGalaxy now has a
working JIT visualization path. This task flips the default so
use_jax_for_visualization follows use_jax unless the user explicitly
opts out, making the two-flag model invisible to users.

Plan

  • Change use_jax_for_visualization default from False to None (sentinel
    meaning "follow use_jax") in Analysis.__init__.
  • Resolve the sentinel before the existing "requires use_jax=True" guard so
    explicit-True-without-use_jax still warns.
  • Preserve PYAUTO_DISABLE_JAX=1 env-var override and the JAX-not-installed
    fallback — both should remain hard-off paths.
  • Update the __init__ and fit_for_visualization docstrings.
  • Extend the existing test_use_jax_for_visualization.py with four new cases
    covering sentinel resolution, explicit opt-out, and env-var override under
    the new default.
  • Smoke-test every workspace + workspace_test repo to confirm no script's
    behaviour changes unexpectedly.

Detailed implementation plan

Affected Repositories

  • PyAutoFit (primary, library)

Work Classification

Library

Branch Survey

Repository     Current Branch     Dirty?
./PyAutoFit    main               clean

Suggested branch: feature/use-jax-for-vis-default
Worktree root: ~/Code/PyAutoLabs-wt/use-jax-for-vis-default/ (created later by /start_library)

Blocker — must serialise behind in-flight work

PyAutoFit is currently claimed by nss-checkpointing-and-visualization
(branch feature/nss-checkpointing-and-visualization, worktree at
~/Code/PyAutoLabs-wt/nss-checkpointing-and-visualization). This task is
parked in planned.md until the NSS PR merges.

Phase 0 / Phase 1 prerequisites — all confirmed shipped

Implementation Steps

  1. Signature change: in autofit/non_linear/analysis/analysis.py:36-79,
    change the default of use_jax_for_visualization from False to None
    (typed Optional[bool]). Insert a sentinel-resolution block:

    if use_jax_for_visualization is None:
        use_jax_for_visualization = use_jax

    Placement: after the PYAUTO_DISABLE_JAX env-var check
    (lines 42-45) and the JAX-not-installed fallback (lines 50-69), and
    before the existing if use_jax_for_visualization and not use_jax
    warning guard at line 71. Ordering matters: the env-var path and the
    JAX-missing path already set both flags to False explicitly, so by the
    time the sentinel check runs the value is no longer None and stays
    False. The warning at lines 72-75 must still fire when a user passes
    use_jax_for_visualization=True without use_jax=True; that remains a
    user error worth a warning.

  2. Docstring updates

    • Analysis.__init__ — add a short docstring documenting the new sentinel
      semantics of use_jax_for_visualization=None.
    • fit_for_visualization docstring at lines 82-122: replace
      "use_jax_for_visualization=False (default) — plain self.fit_from..."
      with language reflecting that the default now mirrors use_jax.
  3. Tests — extend test_autofit/analysis/test_use_jax_for_visualization.py
    (existing file, six tests today). Add (skip-if-jax-installed where
    appropriate to match the existing pattern):

    • test_use_jax_true_implicit_visualization_on:
      Analysis(use_jax=True) → _use_jax_for_visualization is True
    • test_explicit_opt_out_when_use_jax_true:
      Analysis(use_jax=True, use_jax_for_visualization=False) →
      _use_jax_for_visualization is False
    • test_explicit_none_is_sentinel:
      Analysis(use_jax=True, use_jax_for_visualization=None) →
      _use_jax_for_visualization is True
    • test_pyauto_disable_jax_forces_both_off_with_sentinel:
      with PYAUTO_DISABLE_JAX=1 set,
      Analysis(use_jax=True, use_jax_for_visualization=None) still
      resolves both flags to False.
  4. Smoke verification: run /smoke_test across both workspace_test repos
    and the three production workspaces. Per audit, zero production scripts
    currently set use_jax_for_visualization= explicitly, so the only
    behaviour change is for scripts that already set use_jax=True; those
    will now also turn on JAX viz. Phase 1 coverage exists for every dataset
    type that gets touched.
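
The resolution order described in step 1 can be sketched as a standalone function. This is only an illustration of the ordering, not the code in analysis.py: resolve_jax_flags, its jax_installed and environ parameters, the warning texts, and the exact env-var comparison (== "1") are all assumptions made for the example.

```python
import logging
import os
from typing import Mapping, Optional, Tuple

logger = logging.getLogger(__name__)


def resolve_jax_flags(
    use_jax: bool = False,
    use_jax_for_visualization: Optional[bool] = None,
    jax_installed: bool = True,
    environ: Optional[Mapping[str, str]] = None,
) -> Tuple[bool, bool]:
    """Hypothetical stand-in for the flag handling in Analysis.__init__."""
    environ = os.environ if environ is None else environ

    # 1. Env-var override: hard-off for both flags (comparison is assumed).
    if environ.get("PYAUTO_DISABLE_JAX") == "1":
        return False, False

    # 2. JAX-not-installed fallback: hard-off for both flags.
    if use_jax and not jax_installed:
        logger.warning("JAX is not installed; disabling use_jax.")
        return False, False

    # 3. Sentinel resolution: None means "follow use_jax".
    if use_jax_for_visualization is None:
        use_jax_for_visualization = use_jax

    # 4. Existing guard: only an explicit True without use_jax gets here
    #    with the two flags disagreeing in this direction.
    if use_jax_for_visualization and not use_jax:
        logger.warning(
            "use_jax_for_visualization=True requires use_jax=True; disabling."
        )
        use_jax_for_visualization = False

    return use_jax, use_jax_for_visualization
```

Because the sentinel is resolved before the guard, an implicit None can never trigger the warning; only an explicit True paired with use_jax=False does.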

Key Files

  • PyAutoFit/autofit/non_linear/analysis/analysis.py — signature + docstring
    changes (lines 36-122)
  • PyAutoFit/test_autofit/analysis/test_use_jax_for_visualization.py — extend
    with four new cases (note: this is the ACTUAL location; the prompt's
    reference to test_autofit/non_linear/analysis/test_analysis.py was stale)
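
As a rough picture of the four new cases, the sketch below runs them against a stub rather than the real class, so it stays runnable without PyAutoFit or JAX. StubAnalysis, the _use_jax attribute name, and the env-var comparison are assumptions for illustration; the real test file may use different fixtures and skip markers.

```python
import logging
import os
import unittest
from typing import Optional
from unittest import mock


class StubAnalysis:
    """Hypothetical stand-in for Analysis; only the flag logic is modelled."""

    def __init__(
        self,
        use_jax: bool = False,
        use_jax_for_visualization: Optional[bool] = None,
    ):
        if os.environ.get("PYAUTO_DISABLE_JAX") == "1":  # assumed check
            use_jax = False
            use_jax_for_visualization = False
        if use_jax_for_visualization is None:  # sentinel: follow use_jax
            use_jax_for_visualization = use_jax
        if use_jax_for_visualization and not use_jax:
            logging.getLogger("stub").warning("requires use_jax=True; disabling")
            use_jax_for_visualization = False
        self._use_jax = use_jax
        self._use_jax_for_visualization = use_jax_for_visualization


class TestSentinelDefault(unittest.TestCase):
    def setUp(self):
        # Snapshot os.environ and restore it after each test, so a real
        # PYAUTO_DISABLE_JAX value cannot leak into or out of a test.
        patcher = mock.patch.dict(os.environ)
        patcher.start()
        self.addCleanup(patcher.stop)
        os.environ.pop("PYAUTO_DISABLE_JAX", None)

    def test_use_jax_true_implicit_visualization_on(self):
        self.assertIs(StubAnalysis(use_jax=True)._use_jax_for_visualization, True)

    def test_explicit_opt_out_when_use_jax_true(self):
        analysis = StubAnalysis(use_jax=True, use_jax_for_visualization=False)
        self.assertIs(analysis._use_jax_for_visualization, False)

    def test_explicit_none_is_sentinel(self):
        analysis = StubAnalysis(use_jax=True, use_jax_for_visualization=None)
        self.assertIs(analysis._use_jax_for_visualization, True)

    def test_pyauto_disable_jax_forces_both_off_with_sentinel(self):
        os.environ["PYAUTO_DISABLE_JAX"] = "1"
        analysis = StubAnalysis(use_jax=True, use_jax_for_visualization=None)
        self.assertIs(analysis._use_jax, False)
        self.assertIs(analysis._use_jax_for_visualization, False)
```

Saved as a module, the cases run with python -m unittest; against the real class, the cases that expect True would additionally need JAX to be importable.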

Original Prompt

Today Analysis.__init__ in PyAutoFit takes two independent flags:

# @PyAutoFit/autofit/non_linear/analysis/analysis.py:36
def __init__(
    self,
    use_jax: bool = False,
    use_jax_for_visualization: bool = False,
    **kwargs,
):

A user who wants the JAX-accelerated visualization path has to remember to
set both flags. Across all production workspaces (autolens_workspace,
autogalaxy_workspace, autofit_workspace) zero scripts set
use_jax_for_visualization=True today (audit, 2026-05-08), even where
use_jax=True would benefit. The flag is effectively dead weight — it
exists to gate the still-evolving JIT visualization path, but once the
underlying coverage is green there is no reason it should not follow
use_jax.

This task changes the default so that whenever use_jax=True, the
jit-cached visualization path is on by default. Users keep the explicit
opt-out via use_jax_for_visualization=False.

Why this matters

This is Phase 2 of z_features/jax_visualization.md. The user-facing
goal stated in the z_feature is:

"I want us to be at a point where all default runs do JAX visualization
and the notion of it being a separate thing is no longer relevant
(unless the user doesn't have JAX installed or has use_jax=False)."

This prompt delivers that.

Blockers — must land first

All Phase 1 coverage prompts must be merged before flipping the default,
otherwise dataset types with no JAX viz smoke coverage will silently
regress for any user who runs them with use_jax=True:

  • autolens_workspace_test/jax_viz_interferometer_coverage.md
  • autolens_workspace_test/jax_viz_point_source_coverage.md
  • autogalaxy_workspace_test/jax_viz_dataset_coverage.md

Phase 0 prerequisites (fit_imaging_pytree.md, the autogalaxy dispatch
swap, the autogalaxy other-datasets pytree registration) must also be
done. Verify all of these in complete.md before starting.

What to change

@PyAutoFit/autofit/non_linear/analysis/analysis.py:36-79 — change the
default of use_jax_for_visualization from False to a sentinel that
follows use_jax. Recommended signature:

def __init__(
    self,
    use_jax: bool = False,
    use_jax_for_visualization: Optional[bool] = None,
    **kwargs,
):
    ...
    if use_jax_for_visualization is None:
        use_jax_for_visualization = use_jax
    ...

None means "follow use_jax"; True/False are explicit opt-in/opt-out.
Resolution must happen before the existing
if use_jax_for_visualization and not use_jax guard at line 71 so the
guard's wording ("requires use_jax=True; disabling...") still applies
correctly when a user explicitly passes True without use_jax.

The PYAUTO_DISABLE_JAX=1 short-circuit at lines 42-45 already forces
both flags to False, so the env-var override path is unaffected — but
double-check that branch still resolves cleanly with the new sentinel.

The "JAX not installed" warning branch at lines 50-69 already sets
use_jax_for_visualization = False. That still works as long as the new
sentinel resolution is placed after this branch: the branch also sets
use_jax = False, so whether or not the sentinel check runs, the
visualization flag ends up False.

Update the docstring at lines 82-122 to reflect the new default behaviour.
The warning at lines 71-76 must remain: passing
use_jax_for_visualization=True, use_jax=False explicitly is still a user
error, so it should log a warning and then disable the flag.

What to verify

  1. Unit tests for the resolution logic. Add cases to
    @PyAutoFit/test_autofit/non_linear/analysis/test_analysis.py (or
    create the file if missing) covering:

    • Analysis() → both off
    • Analysis(use_jax=True) → _use_jax_for_visualization=True
    • Analysis(use_jax=True, use_jax_for_visualization=False) → off (explicit opt-out works)
    • Analysis(use_jax=False, use_jax_for_visualization=True) → off + warning logged
    • Analysis(use_jax=True, use_jax_for_visualization=None) → on
    • PYAUTO_DISABLE_JAX=1 env-var override forces both off regardless
    • JAX-not-installed branch still sets both off
  2. Workspace smoke. Run /smoke_test on each workspace_test repo and
    on the production workspaces. The Phase 1 coverage protects against
    regressions; this is the moment of truth.

  3. No silent-warning regressions. Grep production workspaces and tests
    for use_jax_for_visualization= and confirm no script's behaviour
    changes meaningfully. Any script that previously ran with use_jax=True
    and no explicit use_jax_for_visualization was implicitly opting out;
    after this change it will opt in. That is the intended behaviour but
    should be verified case-by-case if any script's runtime changes
    significantly.
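
For the grep in step 3, a small helper along these lines could list the scripts whose behaviour changes, i.e. those that pass use_jax=True without setting use_jax_for_visualization at all. The function name and the file-level regex heuristic are assumptions for illustration; it does not parse call sites, so hits still need a manual look.

```python
import re
from pathlib import Path
from typing import Iterable, List

# use_jax=True as its own keyword (does not match use_jax_for_visualization=).
USE_JAX_ON = re.compile(r"\buse_jax\s*=\s*True\b")
# Any explicit setting of the visualization flag.
VIZ_FLAG_SET = re.compile(r"\buse_jax_for_visualization\s*=")


def scripts_changing_behaviour(paths: Iterable[Path]) -> List[Path]:
    """Return files that enable use_jax but never mention the viz flag."""
    affected = []
    for path in paths:
        text = path.read_text(encoding="utf-8", errors="ignore")
        if USE_JAX_ON.search(text) and not VIZ_FLAG_SET.search(text):
            affected.append(path)
    return sorted(affected)
```

A call such as scripts_changing_behaviour(Path("autolens_workspace").rglob("*.py")) would then give the candidate list to check for runtime changes, assuming the workspace is checked out under that directory name.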

Out of scope

  • Phase 3 (production workspace tutorial adoption — opting tutorials into
    use_jax=True) is a separate prompt.
  • Phase 4 (subprocess visualization).
  • Phase 5 (live Jupyter / Colab cell).
  • Removing the use_jax_for_visualization parameter entirely. The
    parameter stays — only its default changes.

Reference

  • @PyAutoFit/autofit/non_linear/analysis/analysis.py:36-79 — site of the change
  • @PyAutoFit/autofit/non_linear/analysis/analysis.py:82-122: fit_for_visualization docstring to update
  • complete.md entries jax-visualization, mge-jit-visualization (2026-04-19) — original Phase 0 ship notes
  • PyAutoPrompt/z_features/jax_visualization.md — Phase 2 in the sequenced roadmap
