Skip to content

revert: default use_jax_for_visualization to False (reverts #1278)#1280

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/jax-viz-default-broken
May 17, 2026
Merged

revert: default use_jax_for_visualization to False (reverts #1278)#1280
Jammy2211 merged 1 commit into
mainfrom
feature/jax-viz-default-broken

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Reverts PR #1278's change to Analysis.__init__'s use_jax_for_visualization default. The new default (Optional[bool] = None → follow use_jax) caused every Nautilus quick_update under use_jax=True to call jax.jit(self.fit_from)(instance=instance) where instance is a ModelInstance — a Python object that is not pytree-registered. JAX raises a TypeError trying to abstract it.

On real pipeline runs (e.g. z_projects/euclid/scripts/initial_lens_model.py) the exception was swallowed by the visualizer's outer guards; the visible symptom was source-plane FITS written all-zero and Einstein-radius posteriors collapsing to the full prior across every Euclid tile after the May 16 update. Reverting restores the previous behaviour: JIT visualization is opt-in only.

Closes the regression introduced in #1278.

API Changes

Analysis.__init__(use_jax_for_visualization) default flips back from Optional[bool] = None (follow use_jax) to bool = False. The sentinel-resolution block is dropped. Users wanting JIT visualization must pass use_jax_for_visualization=True explicitly. The existing use_jax_for_visualization=True and not use_jax warning remains. No public API surface was added — this is purely a default-value revert. See full details below.

Test Plan

  • pytest test_autofit/analysis/test_use_jax_for_visualization.py passes (6 passed, 1 skipped — verified locally)
  • Analysis(use_jax=True)._use_jax_for_visualization is False (default off)
  • Analysis(use_jax=True, use_jax_for_visualization=True)._use_jax_for_visualization is True (explicit opt-in still works)
  • Companion autofit_workspace_test PR ships next (workspace assertions updated to match new default — assert_use_jax_true_defaults_visualization_off replaces the three sentinel assertions)
Full API Changes (for automation & release notes)

Changed Signature

  • autofit.non_linear.analysis.analysis.Analysis.__init__(use_jax_for_visualization) — default flipped from Optional[bool] = None to bool = False.

Removed

  • Sentinel-resolution branch inside Analysis.__init__:
    if use_jax_for_visualization is None:
        use_jax_for_visualization = use_jax
    No longer needed once the literal default is False.

Migration

🤖 Generated with Claude Code

PR #1278 made `use_jax_for_visualization` default to follow `use_jax`
(`Optional[bool] = None` resolving to `use_jax`). That caused every
Nautilus quick-update under `use_jax=True` to evaluate
`jax.jit(self.fit_from)(instance=instance)` where `instance` is a
`ModelInstance` — a plain Python object that is not pytree-registered.
JAX raises a `TypeError` trying to abstract it.

On real pipeline runs (e.g. `z_projects/euclid/scripts/initial_lens_model.py`)
the exception was swallowed deeper in the visualizer's outer guards;
visible symptom was source-plane FITS files written all-zero and
posteriors collapsing to the full prior on Einstein radius across
every Euclid tile.

Reverts the default to `bool = False`. Drops the sentinel-resolution
block. Explicit opt-in (`use_jax_for_visualization=True`) and the
existing `use_jax_for_visualization=True and not use_jax` warning
both remain. No code in the PyAuto ecosystem relied on the implicit-
on behaviour introduced by #1278.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

Workspace PR: PyAutoLabs/autofit_workspace_test#29

@Jammy2211 Jammy2211 merged commit 2ddfa62 into main May 17, 2026
3 of 5 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax-viz-default-broken branch May 17, 2026 18:16
Jammy2211 pushed a commit that referenced this pull request May 17, 2026
…inline)

The Tests workflow has been red on `main` since 2026-05-16 09:48 UTC,
the moment PR #1277 merged the `autofit[nss]` install extra. The 12
tests under `test_autofit/non_linear/search/nest/nss/` hit the script's
own ImportError guard on `af.NSS()` because the CI install step only
installs `[optional]`.

A naïve fix (combine `[optional,nss]` into one pip install) hit a real
dependency conflict — both extras pin `blackjax` but to different
versions:

  - [optional] pins `blackjax>=1.2.0` (mainline, PyPI)
  - [nss]      pins handley-lab fork @ ef45acd2 (~0.1.0b1.dev85+)

The fork carries the `blackjax.ns.adaptive.init` entrypoint that
mainline 1.2.x lacks, so it's not a "use the older one" merge — they're
genuinely incompatible. pip rightly refuses `[optional,nss]` with
`ResolutionImpossible`.

Resolution: split into parallel jobs.

- `unittest`: installs `[optional]` and runs the full test suite
  EXCLUDING `test_autofit/non_linear/search/nest/nss/`. This keeps
  `test_blackjax_nuts.py` (which requires mainline blackjax 1.2+) green.
- `unittest_nss` (new): installs `[nss]` alone in a fresh env and runs
  ONLY the NSS test suite. Matrix is python 3.12 / 3.13 to match
  the main `unittest` job.

Both jobs need to pass for the `Tests` workflow to be green. The
existing `nss_install_smoke.yml` workflow stays as the weekly-cron
upstream-drift canary on the [nss] git pins; this new job is the
PR/push-time test gate.

Fixes the failing Tests check observed on runs 25958838630 (#1277
merge), 25959492149 (#1278 merge), 25998822676 (#1280 revert merge),
and 25999018708 (the first attempt at this fix that revealed the
extras conflict).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant