Overview
Quick-update visualization currently compiles JAX separately from the search's likelihood function, paying a 20–30s penalty on the first quick update. cProfile shows 234 individual XLA compilations in the model_data access path — the profile methods dispatch to JAX individually via decorators rather than composing into one graph. Removing use_jax_for_visualization and having visualization reuse the search's cached JIT function should eliminate this cost entirely.
Plan
- Remove the use_jax_for_visualization flag: visualization follows use_jax. If the search uses JAX, visualization does too. One less knob for users.
- Cache jax.jit(analysis.fit_from) on the Fitness instance alongside the existing _vmap. The search uses _vmap (vector → scalar) for likelihood; the quick update uses the cached _jit_fit_from (instance → FitImaging) for visualization. Both share XLA's compilation cache.
- Investigate the 5s steady-state cost when use_jax_for_visualization=True: subsequent calls take 5–6s instead of sub-second, suggesting JIT cache misses (pytree structure changes, Python-float vs jax.Array leaves, or side-effect branching in fit_from).
- Wire up manage_quick_update to pass a pre-computed FitImaging to a new visualize_quick_update(paths, fit) method, separating "compute the fit" from "render the fit" so the Fitness owns the JIT-cached computation.
- Extend the profiling script (autolens_profiling/quick_update/imaging.py) to cover the unified JIT path and verify sub-second steady-state.
Profiling baselines (HST, 15k pixels, CPU)
| Scenario | First call | Subsequent |
|---|---|---|
| use_jax_for_visualization=False (current) | 22s (234 JIT compiles) | 0.5s |
| use_jax_for_visualization=True (separate JIT) | 31s (1 compile) | 5–6s |
| Target: reuse search's JIT | 0s (already compiled) | <1s |
The 35s matplotlib rendering cost (subplot_fit 12-panel figure) is separate and tracked independently.
Detailed implementation plan
Affected Repositories
- PyAutoFit (primary: Fitness, Analysis, pytrees)
- PyAutoGalaxy (perform_quick_update, AnalysisDataset)
- PyAutoLens (AnalysisImaging, Visualizer, fit_imaging_plots)
- autolens_workspace_test (update modeling_visualization_jit.py)
- autolens_profiling (extend quick_update/imaging.py)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| PyAutoFit | main | live_viewer.py (uncommitted fix) |
| PyAutoGalaxy | main | CLAUDE.md only |
| PyAutoLens | main | CLAUDE.md, README |
Suggested branch: feature/unify-jax-visualization
Implementation Steps
Phase 1 — Remove use_jax_for_visualization
- PyAutoFit/autofit/non_linear/analysis/analysis.py:
  - Remove use_jax_for_visualization from the Analysis.__init__ signature, along with self._use_jax_for_visualization.
  - Simplify fit_for_visualization: when self._use_jax is True, use the JIT path; when False, use plain self.fit_from.
  - Remove the _jitted_fit_from lazy cache from fit_for_visualization (this moves to Fitness in Phase 2).
- PyAutoGalaxy/autogalaxy/analysis/analysis/analysis.py:
  - Remove use_jax_for_visualization from any __init__ signatures.
  - Update perform_quick_update if it references the flag.
- PyAutoLens/autolens/imaging/model/analysis.py:
  - Remove the flag from the AnalysisImaging.__init__ signature.
- All workspace scripts that pass use_jax_for_visualization=True: remove the argument (grep for it in autolens_workspace_test, autogalaxy_workspace_test).
Phase 2 — Cache jax.jit(fit_from) on Fitness
- PyAutoFit/autofit/non_linear/fitness.py:
  - Add a @cached_property _jit_fit_from that wraps jax.jit(self.analysis.fit_from), similar to the existing _vmap / _jit.
  - Ensure register_model(self.model) (from autofit.jax.pytrees) is called during Fitness.__init__ when use_jax_vmap or use_jax_jit is True; pytree registration must happen before the first JIT trace.
  - In manage_quick_update: when self._xp is JAX, build the instance via model.instance_from_vector(vector, xp=jnp), call self._jit_fit_from(instance) to get a FitImaging, then pass it to a new analysis.visualize_quick_update(paths, fit).
- PyAutoGalaxy/autogalaxy/analysis/analysis/analysis.py:
  - Add visualize_quick_update(self, paths, fit): takes a pre-computed fit object and renders it (critical curves + subplot_fit). Replaces the current perform_quick_update flow where the analysis both computes and renders.
- PyAutoLens/autolens/imaging/model/visualizer.py:
  - Update Visualizer.visualize to accept a pre-computed fit instead of calling fit_for_visualization internally. The quick-update branch already receives a FitImaging.
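A minimal sketch of the Phase 2 pattern, assuming a dict-based instance. ToyAnalysis and ToyFitness are hypothetical stand-ins rather than the PyAutoFit classes, and the real fit_from returns a FitImaging, not a dict. Only functools.cached_property and jax.jit are real APIs here.

```python
from functools import cached_property

import jax
import jax.numpy as jnp


class ToyAnalysis:
    """Hypothetical stand-in for an Analysis; not the PyAutoFit class."""

    def __init__(self, data):
        self.data = data

    def fit_from(self, instance):
        # The instance is a plain dict pytree here; the real code uses a ModelInstance.
        x = jnp.arange(self.data.shape[0])
        model_data = instance["intensity"] * jnp.exp(-instance["scale"] * x)
        return {"model_data": model_data, "residuals": self.data - model_data}

    def visualize_quick_update(self, paths, fit):
        # Rendering only: the fit is computed by the caller, never here.
        worst = float(jnp.max(jnp.abs(fit["residuals"])))
        return f"{paths}: max |residual| = {worst:.3f}"


class ToyFitness:
    """Hypothetical stand-in for Fitness; owns the jitted fit function."""

    def __init__(self, analysis):
        self.analysis = analysis

    @cached_property
    def _jit_fit_from(self):
        # Built once per Fitness instance, so every quick update reuses
        # the same jitted callable and therefore its trace cache.
        return jax.jit(self.analysis.fit_from)

    def manage_quick_update(self, paths, instance):
        fit = self._jit_fit_from(instance)  # compute the fit
        return self.analysis.visualize_quick_update(paths, fit)  # render the fit


fitness = ToyFitness(ToyAnalysis(jnp.ones(4)))
instance = {"intensity": jnp.asarray(1.0), "scale": jnp.asarray(0.0)}
summary = fitness.manage_quick_update("output", instance)
```

Keeping the jitted callable on the Fitness rather than rebuilding it per call matters because each new jax.jit wrapper starts with an empty trace cache.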
Phase 3 — Investigate 5s steady-state
- Profile with jax.make_jaxpr(analysis.fit_from)(instance) to check if retracing occurs.
- Audit fit_from → tracer_via_instance_from → FitImaging.__init__ for Python branching on traced values.
- Check model.instance_from_vector(vector, xp=jnp) produces consistent pytree structure across calls.
- Check register_model timing: it must happen before the first _jit_fit_from call.
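One way to confirm or rule out retracing, sketched under the assumption that the instance can be represented as a dict pytree. The fit_from below is a toy stand-in, and trace_log is a hypothetical helper. The counter relies on the fact that Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_log = []


def fit_from(instance):
    # Runs only during tracing, so len(trace_log) counts (re)traces, not calls.
    trace_log.append("trace")
    leaves = jax.tree_util.tree_leaves(instance)
    return sum(jnp.sum(leaf) for leaf in leaves)


jit_fit_from = jax.jit(fit_from)

# Same pytree structure, shapes and dtypes: the second call should hit the cache.
a = {"bulge": jnp.ones(3), "disk": jnp.ones(3)}
b = {"bulge": jnp.zeros(3), "disk": jnp.zeros(3)}
same_structure = jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)

jit_fit_from(a)
jit_fit_from(b)
traces_same_structure = len(trace_log)

# A different pytree structure (one key missing) forces a fresh trace.
c = {"bulge": jnp.ones(3)}
jit_fit_from(c)
traces_after_change = len(trace_log)
```

If the counter increases on every quick update in the real code, comparing jax.tree_util.tree_structure and the leaf dtypes of two consecutive instances should show which part of the instance is changing.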
Phase 4 — Warmup and wiring
- In Fitness.__init__, after pytree registration, optionally warm up _jit_fit_from with a prior-medians instance so the first quick update pays zero compile cost.
- Add logging: "Warming up visualization JIT..." before the warmup call.
- Update autolens_profiling/quick_update/imaging.py to profile the unified path.
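The warmup step could look roughly like this. It is a sketch only: warm_up, the toy fit_from and the dict standing in for a prior-medians instance are all hypothetical. The warmup only pays off if the example instance has the same pytree structure, shapes and dtypes as the instances the search produces later.

```python
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)
trace_log = []


def fit_from(instance):
    # Appends only while JAX traces, so the length counts compilations triggered.
    trace_log.append("trace")
    return instance["intensity"] * jnp.ones(4)


jit_fit_from = jax.jit(fit_from)


def warm_up(jitted_fn, example_instance):
    """Trigger tracing and compilation once, before the first real call."""
    logger.info("Warming up visualization JIT...")
    # block_until_ready waits for asynchronous dispatch to finish.
    jax.block_until_ready(jitted_fn(example_instance))


# Assumed stand-in for a prior-medians instance; explicit dtype keeps the
# leaf type identical to the instances used afterwards.
prior_medians = {"intensity": jnp.asarray(0.5, dtype=jnp.float32)}
warm_up(jit_fit_from, prior_medians)
traces_after_warmup = len(trace_log)

# The first "real" quick update reuses the compiled function.
first_quick_update = jit_fit_from({"intensity": jnp.asarray(1.3, dtype=jnp.float32)})
traces_after_first_update = len(trace_log)
```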
Key Files
| File | Changes |
|---|---|
| PyAutoFit/autofit/non_linear/fitness.py | Add _jit_fit_from cached_property, update manage_quick_update |
| PyAutoFit/autofit/non_linear/analysis/analysis.py | Remove use_jax_for_visualization, simplify fit_for_visualization |
| PyAutoFit/autofit/jax/pytrees.py | Ensure register_model is called at the right time |
| PyAutoGalaxy/autogalaxy/analysis/analysis/analysis.py | Add visualize_quick_update(paths, fit) |
| PyAutoLens/autolens/imaging/model/visualizer.py | Accept pre-computed fit in quick-update path |
| PyAutoLens/autolens/imaging/model/analysis.py | Remove flag from AnalysisImaging.__init__ |
| autolens_profiling/quick_update/imaging.py | Extend profiling for unified JIT path |
Testing
- pytest test_autofit/non_linear/: existing tests must pass
- autolens_workspace_test/scripts/imaging/modeling_visualization_jit.py: update to remove use_jax_for_visualization=True (now implicit)
- autolens_profiling/quick_update/imaging.py: verify sub-second steady-state with the unified path
- Smoke tests across all workspaces
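For the steady-state check, a timing harness along these lines may be enough. It is a sketch, not the contents of the profiling script: time_first_and_steady_state and toy_fit are hypothetical names, and toy_fit only mimics the 15k-pixel array size.

```python
import time

import jax
import jax.numpy as jnp


def time_first_and_steady_state(fn, arg, repeats=5):
    """Return (first-call seconds, list of subsequent-call seconds)."""

    def timed_call():
        start = time.perf_counter()
        # Wait for asynchronous dispatch so the timing covers the real work.
        jax.block_until_ready(fn(arg))
        return time.perf_counter() - start

    first = timed_call()  # includes tracing and XLA compilation
    subsequent = [timed_call() for _ in range(repeats)]
    return first, subsequent


@jax.jit
def toy_fit(instance):
    # Hypothetical stand-in for fit_from on a 15k-pixel dataset.
    grid = jnp.linspace(0.0, 1.0, 15_000)
    return instance["intensity"] * jnp.exp(-instance["scale"] * grid)


first, subsequent = time_first_and_steady_state(
    toy_fit, {"intensity": jnp.asarray(1.0), "scale": jnp.asarray(2.0)}
)
print(f"first call: {first:.3f}s, worst subsequent call: {max(subsequent):.3f}s")
```

Reporting the worst subsequent call rather than the mean makes an occasional retrace visible instead of averaging it away.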
Original Prompt
Unify JAX visualization with likelihood JIT path
Remove use_jax_for_visualization and have quick-update visualization
reuse the same JIT-compiled function as the likelihood evaluation. The
current architecture compiles visualization separately, paying a 20–30s
penalty on the first quick update that the user perceives as "the quick
update is slow."
Problem statement
When use_jax=True and use_jax_for_visualization=False (the current
default), the quick-update visualization path calls fit_for_visualization
which runs analysis.fit_from(instance) through plain Python. But
the profile methods (Sersic, Isothermal, etc.) still dispatch to JAX
internally because analysis._use_jax=True. Each profile evaluation
triggers its own small jax.jit compilation via the
@aa.grid_dec.transform / @aa.grid_dec.to_array decorator chain.
cProfile of the first model_data access (15k masked pixels, HST):
234 calls to jax._src.pjit.cache_miss → 12.97s
216 calls to backend_compile_and_load → 9.59s
This is 234 individual XLA compilations instead of one composed graph.
Subsequent quick updates are fast (~0.5s) because the per-function JIT
cache is warm — but only for the same parameter shapes. The first
quick update in every process pays the full 20s+ cost.
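The difference between per-function compilation and one composed graph can be illustrated with two toy profiles. sersic_like and isothermal_like are hypothetical stand-ins for the decorated profile methods, not the real implementations.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sersic_like(grid):
    # Stand-in for one decorated profile method.
    return jnp.exp(-jnp.sqrt(jnp.sum(grid**2, axis=-1)))


@jax.jit
def isothermal_like(grid):
    # Stand-in for a second decorated profile method.
    return 1.0 / (1.0 + jnp.sum(grid**2, axis=-1))


def model_data(grid):
    # Called eagerly, each jitted profile is compiled and dispatched on its own.
    return sersic_like(grid) + isothermal_like(grid)


# Wrapping the caller traces through both profiles and compiles them together.
model_data_composed = jax.jit(model_data)

axis = jnp.linspace(-1.0, 1.0, 8)
grid = jnp.stack(jnp.meshgrid(axis, axis), axis=-1).reshape(-1, 2)

eager = model_data(grid)
composed = model_data_composed(grid)
```

Setting jax.config.update("jax_log_compiles", True) before running either path logs each compilation, which gives a quick count without cProfile.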
Setting use_jax_for_visualization=True compiles a single
jax.jit(analysis.fit_from) — better in principle, but:
- It creates its own JIT function, separate from the search's Fitness._vmap = jax.vmap(jax.jit(self.call)). No cache sharing.
- First compile takes ~31s (one big graph vs 234 small ones).
- Subsequent calls are ~5s, not sub-second, which suggests JIT cache misses (possibly pytree structure changes between calls).
- Requires register_model(model) from autofit.jax.pytrees before use: the ModelInstance / Galaxy types must be pytree-registered or jax.jit rejects them.
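The registration requirement can be illustrated with JAX's own pytree API. ToyGalaxy is a hypothetical class, and this sketch says nothing about how register_model is implemented in autofit.jax.pytrees. It only shows what registering a custom type as a pytree involves.

```python
import jax
import jax.numpy as jnp


class ToyGalaxy:
    """Hypothetical model component; not the PyAutoGalaxy class."""

    def __init__(self, intensity, redshift):
        self.intensity = intensity  # array leaf, traced by jax.jit
        self.redshift = redshift  # static metadata, must be hashable


def flatten(galaxy):
    # Children are traced leaves; aux data is treated as static.
    return (galaxy.intensity,), galaxy.redshift


def unflatten(redshift, children):
    return ToyGalaxy(children[0], redshift)


# Must run before the first jitted call that receives a ToyGalaxy.
jax.tree_util.register_pytree_node(ToyGalaxy, flatten, unflatten)


@jax.jit
def scaled_intensity(galaxy):
    return 2.0 * galaxy.intensity


galaxy = ToyGalaxy(jnp.asarray(1.5), redshift=0.5)
leaves, treedef = jax.tree_util.tree_flatten(galaxy)
result = scaled_intensity(galaxy)
```

Because the aux data is part of the pytree structure, a change in redshift here would count as a structure change and trigger a retrace, which is one of the cache-miss candidates listed above.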
Profiling numbers (HST, 15k pixels, CPU, no GPU)
| Scenario | First call | Subsequent |
|---|---|---|
| use_jax_for_visualization=False (current default) | 22s (234 JIT compiles) | 0.5s |
| use_jax_for_visualization=True (separate JIT) | 31s (1 compile) | 5–6s |
| Target: reuse search's JIT | 0s (already compiled) | <1s |
The 35s matplotlib rendering cost (subplot_fit with 12 panels) is a
separate issue being tracked independently.