perf: unify JAX visualization with likelihood JIT path #1296

@Jammy2211

Overview

Quick-update visualization currently compiles JAX separately from the search's likelihood function, paying a 20–30s penalty on the first quick update. cProfile shows 234 individual XLA compilations in the model_data access path — the profile methods dispatch to JAX individually via decorators rather than composing into one graph. Removing use_jax_for_visualization and having visualization reuse the search's cached JIT function should eliminate this cost entirely.

Plan

  • Remove use_jax_for_visualization flag — visualization follows use_jax. If the search uses JAX, visualization does too. One less knob for users.
  • Cache jax.jit(analysis.fit_from) on the Fitness instance alongside the existing _vmap. The search uses _vmap (vector → scalar) for likelihood; the quick update uses the cached _jit_fit_from (instance → FitImaging) for visualization. Both share XLA's compilation cache.
  • Investigate the 5s steady-state cost when use_jax_for_visualization=True — subsequent calls take 5–6s instead of sub-second, suggesting JIT cache misses (pytree structure changes, Python-float vs jax.Array leaves, or side-effect branching in fit_from).
  • Wire up manage_quick_update to pass a pre-computed FitImaging to a new visualize_quick_update(paths, fit) method, separating "compute the fit" from "render the fit" so the Fitness owns the JIT-cached computation.
  • Extend the profiling script (autolens_profiling/quick_update/imaging.py) to cover the unified JIT path and verify sub-second steady-state.
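The split between "compute the fit" and "render the fit" in the fourth bullet can be sketched in plain Python. This is only an illustration under assumptions: `ToyFit`, `ToyAnalysis`, `ToyFitness` and `toy_compute_fit` are hypothetical stand-ins, and only the method names `manage_quick_update` and `visualize_quick_update(paths, fit)` come from the plan above.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyFit:
    """Hypothetical stand-in for FitImaging: holds already-computed results."""
    model_data: List[float]
    residuals: List[float]


@dataclass
class ToyAnalysis:
    """Hypothetical analysis that only renders; it never recomputes the fit."""
    rendered: List[str] = field(default_factory=list)

    def visualize_quick_update(self, paths: str, fit: ToyFit) -> None:
        # Rendering consumes the pre-computed fit and does no model evaluation.
        worst = max(abs(r) for r in fit.residuals)
        self.rendered.append(f"{paths}/subplot_fit.png (max residual {worst:.2f})")


@dataclass
class ToyFitness:
    """Hypothetical owner of the (possibly JIT-cached) fit computation."""
    analysis: ToyAnalysis
    compute_fit: Callable[[List[float]], ToyFit]

    def manage_quick_update(self, paths: str, vector: List[float]) -> None:
        fit = self.compute_fit(vector)                    # compute once, here
        self.analysis.visualize_quick_update(paths, fit)  # render only


def toy_compute_fit(vector: List[float]) -> ToyFit:
    """Toy model: double each parameter and compare against fixed data."""
    data = [1.0, 2.0, 3.0]
    model = [v * 2.0 for v in vector]
    return ToyFit(model_data=model, residuals=[d - m for d, m in zip(data, model)])


analysis = ToyAnalysis()
fitness = ToyFitness(analysis=analysis, compute_fit=toy_compute_fit)
fitness.manage_quick_update("output", [0.5, 1.0, 1.0])
```

With this shape, swapping `compute_fit` for a cached JIT function would not require touching the rendering code.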

Profiling baselines (HST, 15k pixels, CPU)

| Scenario | First call | Subsequent |
| --- | --- | --- |
| use_jax_for_visualization=False (current) | 22s (234 JIT compiles) | 0.5s |
| use_jax_for_visualization=True (separate JIT) | 31s (1 compile) | 5–6s |
| Target: reuse search's JIT | 0s (already compiled) | <1s |

The 35s matplotlib rendering cost (subplot_fit 12-panel figure) is separate and tracked independently.

Detailed implementation plan

Affected Repositories

  • PyAutoFit (primary — Fitness, Analysis, pytrees)
  • PyAutoGalaxy (perform_quick_update, AnalysisDataset)
  • PyAutoLens (AnalysisImaging, Visualizer, fit_imaging_plots)
  • autolens_workspace_test (update modeling_visualization_jit.py)
  • autolens_profiling (extend quick_update/imaging.py)

Branch Survey

| Repository | Current Branch | Dirty? |
| --- | --- | --- |
| PyAutoFit | main | live_viewer.py (uncommitted fix) |
| PyAutoGalaxy | main | CLAUDE.md only |
| PyAutoLens | main | CLAUDE.md, README |

Suggested branch: feature/unify-jax-visualization

Implementation Steps

Phase 1 — Remove use_jax_for_visualization

  1. PyAutoFit/autofit/non_linear/analysis/analysis.py:

    • Remove use_jax_for_visualization from Analysis.__init__ signature and self._use_jax_for_visualization.
    • Simplify fit_for_visualization: when self._use_jax is True, use the JIT path; when False, use plain self.fit_from.
    • Remove the _jitted_fit_from lazy cache from fit_for_visualization (this moves to Fitness in Phase 2).
  2. PyAutoGalaxy/autogalaxy/analysis/analysis/analysis.py:

    • Remove use_jax_for_visualization from any __init__ signatures.
    • Update perform_quick_update if it references the flag.
  3. PyAutoLens/autolens/imaging/model/analysis.py:

    • Remove from AnalysisImaging.__init__ signature.
  4. All workspace scripts that pass use_jax_for_visualization=True (grep for it in autolens_workspace_test, autogalaxy_workspace_test).
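The grep in step 4 could also be done with a small script, which is handy when the result needs to be checked in CI. The sketch below is an assumption about how one might do it: `find_flag_usages` is a hypothetical helper, and the directory names are simply the repository names mentioned above, treated as local checkouts.

```python
from pathlib import Path
from typing import List, Tuple

FLAG = "use_jax_for_visualization"


def find_flag_usages(root: Path, flag: str = FLAG) -> List[Tuple[Path, int, str]]:
    """Return (file, line number, line text) for every .py line mentioning the flag."""
    hits = []
    for path in sorted(root.rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="ignore")
        for number, line in enumerate(text.splitlines(), start=1):
            if flag in line:
                hits.append((path, number, line.strip()))
    return hits


if __name__ == "__main__":
    # The checkout locations are assumptions; point these at your local clones.
    for repo in ("autolens_workspace_test", "autogalaxy_workspace_test"):
        root = Path(repo)
        if not root.is_dir():
            continue
        for path, number, line in find_flag_usages(root):
            print(f"{path}:{number}: {line}")
```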

Phase 2 — Cache jax.jit(fit_from) on Fitness

  1. PyAutoFit/autofit/non_linear/fitness.py:

    • Add a @cached_property _jit_fit_from that wraps jax.jit(self.analysis.fit_from), similar to the existing _vmap / _jit.
    • Ensure register_model(self.model) (from autofit.jax.pytrees) is called during Fitness.__init__ when use_jax_vmap or use_jax_jit is True — pytree registration must happen before the first JIT trace.
    • In manage_quick_update: when self._xp is JAX, build the instance via model.instance_from_vector(vector, xp=jnp), call self._jit_fit_from(instance) to get a FitImaging, then pass it to a new analysis.visualize_quick_update(paths, fit).
  2. PyAutoGalaxy/autogalaxy/analysis/analysis/analysis.py:

    • Add visualize_quick_update(self, paths, fit) — takes a pre-computed fit object and renders it (critical curves + subplot_fit). Replaces the current perform_quick_update flow where the analysis both computes and renders.
  3. PyAutoLens/autolens/imaging/model/visualizer.py:

    • Update Visualizer.visualize to accept a pre-computed fit instead of calling fit_for_visualization internally. The quick-update branch already receives a FitImaging.
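A minimal sketch of the caching pattern in step 1 is shown below. It is not the PyAutoFit implementation: `ToyFit`, `ToyAnalysis`, `ToyFitness` and `log_likelihood_from_vector` are hypothetical, and a `NamedTuple` is used as the fit object only because JAX already treats it as a pytree. The sketch shows the lazy `cached_property` wiring; it does not demonstrate whether the two compiled callables share compiled code.

```python
from functools import cached_property
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyFit(NamedTuple):
    """Hypothetical stand-in for FitImaging; NamedTuples are pytrees, so jit can return them."""
    model_data: jax.Array
    residual_map: jax.Array


class ToyAnalysis:
    """Hypothetical analysis whose likelihood and fit_from share the same model code."""

    def __init__(self, data: jax.Array):
        self.data = data

    def fit_from(self, instance: dict) -> ToyFit:
        radii = jnp.arange(self.data.shape[0])
        model_data = instance["intensity"] * jnp.exp(-radii / instance["scale"])
        return ToyFit(model_data=model_data, residual_map=self.data - model_data)

    def log_likelihood_from_vector(self, vector: jax.Array) -> jax.Array:
        fit = self.fit_from({"intensity": vector[0], "scale": vector[1]})
        return -0.5 * jnp.sum(fit.residual_map**2)


class ToyFitness:
    """Hypothetical Fitness: both compiled callables are built lazily and cached."""

    def __init__(self, analysis: ToyAnalysis):
        self.analysis = analysis

    @cached_property
    def _vmap(self):
        # Batch of parameter vectors -> batch of scalars, used by the search.
        return jax.vmap(jax.jit(self.analysis.log_likelihood_from_vector))

    @cached_property
    def _jit_fit_from(self):
        # Instance -> fit object, used by the quick update.
        return jax.jit(self.analysis.fit_from)


fitness = ToyFitness(ToyAnalysis(data=jnp.ones(8)))
log_likelihoods = fitness._vmap(jnp.array([[1.0, 2.0], [0.5, 4.0]]))
fit = fitness._jit_fit_from({"intensity": jnp.asarray(1.0), "scale": jnp.asarray(2.0)})
```

Because `cached_property` stores the wrapped function on the instance, every quick update reuses the same `jax.jit` object instead of creating a fresh one with an empty cache.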

Phase 3 — Investigate 5s steady-state

  1. Profile with jax.make_jaxpr(analysis.fit_from)(instance) to check if retracing occurs.
  2. Audit fit_from → tracer_via_instance_from → FitImaging.__init__ for Python branching on traced values.
  3. Check model.instance_from_vector(vector, xp=jnp) produces consistent pytree structure across calls.
  4. Check register_model timing — must happen before first _jit_fit_from call.
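One inexpensive way to check for retracing, alongside the steps above, is a Python-side counter inside the jitted function: Python side effects run only while JAX traces, so the counter stays flat on cache hits. The sketch below uses a hypothetical `toy_fit_from` and nested dicts in place of the real model instance; the same counter could be placed temporarily in the real `fit_from` to test the other suspected causes.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def toy_fit_from(instance):
    """Hypothetical stand-in for fit_from; the counter increments only while tracing."""
    global trace_count
    trace_count += 1
    leaves = jax.tree_util.tree_leaves(instance)
    return sum(jnp.sum(leaf) for leaf in leaves)


# Same pytree structure, different values: should hit the cache.
same_a = {"bulge": {"intensity": jnp.asarray(1.0)}}
same_b = {"bulge": {"intensity": jnp.asarray(2.0)}}
# An extra component changes the pytree structure: forces a new trace.
extra = {
    "bulge": {"intensity": jnp.asarray(1.0)},
    "disk": {"intensity": jnp.asarray(3.0)},
}

toy_fit_from(same_a)
traces_after_first = trace_count
toy_fit_from(same_b)
traces_after_same_structure = trace_count
toy_fit_from(extra)
traces_after_new_structure = trace_count
```

If the counter grows on every quick update in the real code path, the steady-state cost is retracing rather than execution time.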

Phase 4 — Warmup and wiring

  1. In Fitness.__init__, after pytree registration, optionally warm up _jit_fit_from with a prior-medians instance so the first quick update pays zero compile cost.
  2. Add logging: "Warming up visualization JIT..." before the warmup call.
  3. Update autolens_profiling/quick_update/imaging.py to profile the unified path.
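The warmup in steps 1 and 2 might look roughly like the sketch below. `toy_fit_from`, `warm_up` and the `prior_medians` dict are hypothetical; the sketch assumes the warmup instance has the same pytree structure and dtypes as the instances produced later, since otherwise the warmup would compile a graph that is never reused.

```python
import logging
import time

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


def toy_fit_from(instance):
    """Hypothetical stand-in for analysis.fit_from."""
    grid = jnp.linspace(0.0, 1.0, 1024)
    return instance["intensity"] * jnp.exp(-grid / instance["scale"])


def warm_up(jitted, example_instance):
    """Call the jitted function once so compilation happens before the first real use."""
    logger.info("Warming up visualization JIT...")
    start = time.perf_counter()
    # block_until_ready makes the timing include compilation and execution.
    jax.block_until_ready(jitted(example_instance))
    return time.perf_counter() - start


jit_fit_from = jax.jit(toy_fit_from)

# Assumption: prior medians are available as arrays matching later instances.
prior_medians = {"intensity": jnp.asarray(1.0), "scale": jnp.asarray(0.5)}
warmup_seconds = warm_up(jit_fit_from, prior_medians)

# A later call with new values but the same structure reuses the compiled graph.
start = time.perf_counter()
model_data = jax.block_until_ready(
    jit_fit_from({"intensity": jnp.asarray(2.0), "scale": jnp.asarray(0.3)})
)
steady_seconds = time.perf_counter() - start
```

Logging the two timings side by side in the profiling script would make it easy to see whether the warmup actually removed the compile cost from the first quick update.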

Key Files

| File | Changes |
| --- | --- |
| PyAutoFit/autofit/non_linear/fitness.py | Add _jit_fit_from cached_property, update manage_quick_update |
| PyAutoFit/autofit/non_linear/analysis/analysis.py | Remove use_jax_for_visualization, simplify fit_for_visualization |
| PyAutoFit/autofit/jax/pytrees.py | Ensure register_model is called at the right time |
| PyAutoGalaxy/autogalaxy/analysis/analysis/analysis.py | Add visualize_quick_update(paths, fit) |
| PyAutoLens/autolens/imaging/model/visualizer.py | Accept pre-computed fit in quick-update path |
| PyAutoLens/autolens/imaging/model/analysis.py | Remove flag from AnalysisImaging.__init__ |
| autolens_profiling/quick_update/imaging.py | Extend profiling for unified JIT path |

Testing

  • pytest test_autofit/non_linear/ — existing tests must pass
  • autolens_workspace_test/scripts/imaging/modeling_visualization_jit.py — update to remove use_jax_for_visualization=True (now implicit)
  • autolens_profiling/quick_update/imaging.py — verify sub-second steady-state with the unified path
  • Smoke tests across all workspaces

Original Prompt

Unify JAX visualization with likelihood JIT path

Remove use_jax_for_visualization and have quick-update visualization
reuse the same JIT-compiled function as the likelihood evaluation. The
current architecture compiles visualization separately, paying a 20–30s
penalty on the first quick update that the user perceives as "the quick
update is slow."

Problem statement

When use_jax=True and use_jax_for_visualization=False (the current
default), the quick-update visualization path calls fit_for_visualization
which runs analysis.fit_from(instance) through plain Python. But
the profile methods (Sersic, Isothermal, etc.) still dispatch to JAX
internally because analysis._use_jax=True. Each profile evaluation
triggers its own small jax.jit compilation via the
@aa.grid_dec.transform / @aa.grid_dec.to_array decorator chain.

cProfile of the first model_data access (15k masked pixels, HST):

234 calls to jax._src.pjit.cache_miss  → 12.97s
216 calls to backend_compile_and_load  →  9.59s

This is 234 individual XLA compilations instead of one composed graph.
Subsequent quick updates are fast (~0.5s) because the per-function JIT
cache is warm — but only for the same parameter shapes. The first
quick update in every process pays the full 20s+ cost.

Setting use_jax_for_visualization=True compiles a single
jax.jit(analysis.fit_from) — better in principle, but:

  1. It creates its own JIT function, separate from the search's
    Fitness._vmap = jax.vmap(jax.jit(self.call)). No cache sharing.
  2. First compile takes ~31s (one big graph vs 234 small ones).
  3. Subsequent calls are ~5s, not sub-second — suggests JIT cache misses
    (possibly pytree structure changes between calls).
  4. Requires register_model(model) from autofit.jax.pytrees before
    use — the ModelInstance / Galaxy types must be pytree-registered
    or jax.jit rejects them.
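Point 4 can be illustrated with JAX's own registration API. The sketch below does not show what `register_model` does internally; `ToyGalaxy` and `brightness_at` are hypothetical, and the example only shows that a custom class becomes usable as a `jax.jit` argument once it has been registered as a pytree node.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class ToyGalaxy:
    """Hypothetical model component; not a real PyAutoGalaxy class."""

    def __init__(self, intensity, effective_radius):
        self.intensity = intensity
        self.effective_radius = effective_radius


def _flatten(galaxy):
    children = (galaxy.intensity, galaxy.effective_radius)  # traced leaves
    aux_data = None  # no static metadata in this toy class
    return children, aux_data


def _unflatten(aux_data, children):
    return ToyGalaxy(*children)


# Must run before the first jit call that receives a ToyGalaxy.
tree_util.register_pytree_node(ToyGalaxy, _flatten, _unflatten)


@jax.jit
def brightness_at(galaxy, radius):
    """Exponential profile evaluated at a single radius."""
    return galaxy.intensity * jnp.exp(-radius / galaxy.effective_radius)


galaxy = ToyGalaxy(jnp.asarray(2.0), jnp.asarray(1.0))
value = brightness_at(galaxy, jnp.asarray(0.0))
```

Anything placed in `aux_data` becomes part of the pytree structure, so values that vary between calls belong in `children` to avoid retracing.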

Profiling numbers (HST, 15k pixels, CPU, no GPU)

| Scenario | First call | Subsequent |
| --- | --- | --- |
| use_jax_for_visualization=False (current default) | 22s (234 JIT compiles) | 0.5s |
| use_jax_for_visualization=True (separate JIT) | 31s (1 compile) | 5–6s |
| Target: reuse search's JIT | 0s (already compiled) | <1s |

The 35s matplotlib rendering cost (subplot_fit with 12 panels) is a
separate issue being tracked independently.

🤖 Generated with Claude Code
