
feat: pilot JAX-jitted visualization via use_jax_for_visualization flag #1227

@Jammy2211

Description


Overview

Currently, PyAutoFit only jits the log_likelihood_function (via Fitness._jit /
_vmap). The visualization path, which runs through abstract_search.perform_visualization
and then analysis.fit_from(instance=...), is pure NumPy. This task pilots JAX-jitted
visualization behind a new, explicit use_jax_for_visualization flag on Analysis. It starts
with a parametric MGE source in autolens_workspace_test to surface the failure modes
(pytree registration gaps, the JAX ↔ matplotlib boundary, tracer_linear_light_profiles_to_light_profiles).
The end goal is for visualization to be jitted automatically when use_jax=True is set on the
Analysis, but the explicit flag keeps the pilot scoped.
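Here is a minimal sketch of the split described above. It uses toy stand-ins rather than PyAutoFit or PyAutoLens code: fit_from, log_likelihood_function and the single-Gaussian model are all hypothetical. The sketch shows that jax.jit can compile a function that returns a whole fit (a dict of arrays, which JAX treats as a pytree) just as it compiles the scalar likelihood. Jitting visualization therefore mostly means making the real fit object traversable in the same way.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: a 1D signal, a constant noise map, and a single-Gaussian
# model with parameters (centre, sigma, intensity). Not PyAutoFit/PyAutoLens code.
x = jnp.linspace(-1.0, 1.0, 50)
noise_map = jnp.full_like(x, 0.1)
data = 2.0 * jnp.exp(-0.5 * (x / 0.3) ** 2)


def fit_from(params):
    """Return the full fit as a pytree (dict of arrays), standing in for a fit object."""
    centre, sigma, intensity = params
    model_data = intensity * jnp.exp(-0.5 * ((x - centre) / sigma) ** 2)
    residual_map = data - model_data
    chi_squared_map = (residual_map / noise_map) ** 2
    return {
        "model_data": model_data,
        "residual_map": residual_map,
        "chi_squared_map": chi_squared_map,
    }


def log_likelihood_function(params):
    """Scalar reduction of the fit: the part that is jitted today."""
    return -0.5 * jnp.sum(fit_from(params)["chi_squared_map"])


jitted_likelihood = jax.jit(log_likelihood_function)  # analogous to the current path
jitted_fit = jax.jit(fit_from)  # analogous to what the pilot wants for visualization

params = jnp.array([0.0, 0.3, 2.0])  # the parameters used to generate `data`
fit = jitted_fit(params)
```

In this toy, the dict plays the role of the fit object. A real fit class would need pytree registration (or an equivalent flatten/unflatten) for the same call to work.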

Plan

  • Write a pilot visualization_jax.py in autolens_workspace_test (MGE parametric source
    only) that calls VisualizerImaging.visualize through a jitted fit, and collect the
    resulting failures.
  • Add a use_jax_for_visualization flag to Analysis.__init__ and reconcile it with the
    existing supports_jax_visualization capability property.
  • Route perform_visualization / SearchUpdater.visualize through a jitted
    fit_from when the flag is set.
  • Iterate through the JAX errors one by one (pytree registration, if xp is np: guards,
    the matplotlib boundary), building on the recent pytree registration PR.
  • Validate that the jitted output matches the NumPy baseline for the MGE parametric source.
Detailed implementation plan

Affected Repositories

  • rhayes777/PyAutoFit (primary)
  • Jammy2211/PyAutoLens
  • Jammy2211/autolens_workspace_test

Work Classification

Library (with workspace follow-up in autolens_workspace_test)

Branch Survey

Repository                  Current Branch   Dirty?
./PyAutoFit                 main             clean
./PyAutoLens                main             clean
./autolens_workspace_test   main             clean

Suggested branch: feature/jax-visualization
Worktree root: ~/Code/PyAutoLabs-wt/jax-visualization/ (created later by /start_library)

Implementation Steps

Phase 1 — pilot (autolens_workspace_test)

  1. Add scripts/imaging/visualization_jax.py. It structurally mirrors visualization.py
     but uses a parametric MGE source only, constructs AnalysisImaging with use_jax=True
     and use_jax_for_visualization=True, and calls VisualizerImaging.visualize through the
     jitted path. Use al.model_util.mge_model_from (see
     scripts/jax_likelihood_functions/imaging/mge.py).
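To make the pilot concrete, here is a hedged, self-contained stand-in for the kind of computation a parametric MGE source places inside the jitted fit. It evaluates a sum of concentric, circular Gaussians with fixed widths on a 2D grid. mge_image_from and its arguments are hypothetical and are not al.model_util.mge_model_from. The real model also solves for linear intensities and includes lensing, both of which this sketch leaves out.

```python
import jax
import jax.numpy as jnp


def mge_image_from(intensities, sigmas, centre, grid_y, grid_x):
    """Hypothetical toy MGE: a sum of concentric, circular 2D Gaussians.

    intensities, sigmas: shape (n_gaussians,)
    centre: shape (2,), ordered (y, x)
    grid_y, grid_x: 2D arrays of pixel coordinates
    """
    r2 = (grid_y - centre[0]) ** 2 + (grid_x - centre[1]) ** 2
    # Broadcast the squared-radius map against the Gaussian axis: (n, ny, nx).
    gaussians = jnp.exp(-0.5 * r2[None, :, :] / sigmas[:, None, None] ** 2)
    return jnp.sum(intensities[:, None, None] * gaussians, axis=0)


y = jnp.linspace(-2.0, 2.0, 64)
x = jnp.linspace(-2.0, 2.0, 64)
grid_y, grid_x = jnp.meshgrid(y, x, indexing="ij")

# Fixed widths (as an MGE basis would use); intensities and centre are the inputs.
sigmas = jnp.logspace(-1.0, 0.5, 10)
intensities = jnp.ones(10)
centre = jnp.array([0.0, 0.0])

mge_image_jit = jax.jit(mge_image_from)
image = mge_image_jit(intensities, sigmas, centre, grid_y, grid_x)
```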

Phase 2 — interface polish (PyAutoFit)
  2. autofit/non_linear/analysis/analysis.py: add a use_jax_for_visualization: bool = False
     kwarg to Analysis.__init__ and store it as self._use_jax_for_visualization.
  3. Audit the supports_jax_visualization property (currently a class-level capability flag).
     Decide whether it becomes derived from _use_jax_for_visualization or stays a
     separate "this analysis CAN do jitted viz" flag, and document the distinction.
  4. autofit/non_linear/search/abstract_search.py: in perform_visualization (~line 976)
     and the force_visualize_overwrite branch (~line 753), detect the new flag and
     wrap the fit construction in jax.jit.
  5. autofit/non_linear/fitness.py, lines 138–141: keep the convert_jax logic consistent
     with the new flag.
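Steps 2 to 4 could fit together as in the sketch below. It uses hypothetical classes rather than the real autofit.Analysis and search code. The per-instance use_jax_for_visualization request is combined with the class-level supports_jax_visualization capability. The routing function wraps fit_from in jax.jit only when both hold. The property name jit_visualization and the function perform_visualization below are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


class Analysis:
    """Hypothetical sketch; not the real autofit.Analysis."""

    # Class-level capability: "this analysis type CAN produce a jittable fit".
    supports_jax_visualization = False

    def __init__(self, use_jax=False, use_jax_for_visualization=False):
        self._use_jax = use_jax
        self._use_jax_for_visualization = use_jax_for_visualization

    @property
    def jit_visualization(self):
        # Jit only if the user asked for it AND this analysis type supports it.
        return self._use_jax_for_visualization and self.supports_jax_visualization

    def fit_from(self, instance):
        raise NotImplementedError


class ToyAnalysis(Analysis):
    """Hypothetical analysis whose fit is a pytree of JAX arrays."""

    supports_jax_visualization = True

    def fit_from(self, instance):
        return {"model_data": instance * jnp.arange(3.0)}


def perform_visualization(analysis, instance):
    """Hypothetical stand-in for the routing step in abstract_search."""
    fit_from = analysis.fit_from
    if analysis.jit_visualization:
        fit_from = jax.jit(fit_from)
    return fit_from(instance)


toy = ToyAnalysis(use_jax=True, use_jax_for_visualization=True)
fit = perform_visualization(toy, 2.0)
```

Requiring both flags keeps supports_jax_visualization meaningful as a capability check (the question in step 3). An alternative is to raise in __init__ when the flag is requested on an analysis that does not support it. Real code would probably also build the jitted function once per analysis, so that repeated visualization calls during a search do not risk re-tracing and recompiling.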

Phase 3 — iterate on JAX errors (PyAutoLens / PyAutoArray / PyAutoGalaxy)
  6. Run the pilot and collect the tracebacks. Expected first failure:
     tracer = fit.tracer_linear_light_profiles_to_light_profiles at
     autolens/imaging/model/visualizer.py:98, because the derived profiles aren't pytree-registered.
  7. Work through the failures one by one: pytree registration for the missing types,
     if xp is np: guards where arrays flow out of JIT, and np.asarray at the
     matplotlib boundary.
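The first expected failure (step 6) comes from returning profile objects that JAX does not know how to flatten. The sketch below shows a minimal fix, assuming the standard jax.tree_util.register_pytree_node_class mechanism; the recent PyAutoFit pytree PR may use its own helper instead. The Gaussian class and the solve_profile function are hypothetical stand-ins for a derived light profile and for tracer_linear_light_profiles_to_light_profiles. Once the class is registered, its objects can cross the jit boundary, and np.asarray converts their arrays back to NumPy for plotting.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Gaussian:
    """Hypothetical light profile standing in for a derived profile type."""

    def __init__(self, centre, sigma, intensity):
        self.centre = centre
        self.sigma = sigma
        self.intensity = intensity

    def image_from(self, x):
        return self.intensity * jnp.exp(-0.5 * ((x - self.centre) / self.sigma) ** 2)

    def tree_flatten(self):
        children = (self.centre, self.sigma, self.intensity)  # traced leaves
        aux_data = None  # no static metadata in this toy
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def solve_profile(x, data):
    # Toy "linear -> standard profile" step: solve the best-fit intensity of a
    # unit-intensity basis, then return a new profile object out of the jit.
    centre, sigma = jnp.asarray(0.0), jnp.asarray(0.3)
    basis = Gaussian(centre, sigma, jnp.asarray(1.0)).image_from(x)
    intensity = jnp.sum(basis * data) / jnp.sum(basis * basis)
    return Gaussian(centre, sigma, intensity)


x = jnp.linspace(-1.0, 1.0, 50)
data = 2.0 * jnp.exp(-0.5 * (x / 0.3) ** 2)
profile = solve_profile(x, data)

# At the matplotlib boundary, pull arrays off the device as plain NumPy.
image = np.asarray(profile.image_from(x))
```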

Phase 4 — validation
  8. Confirm that the jitted visualization_jax.py produces fit.png and tracer.png that
     match the NumPy baseline numerically. Add assertions analogous to those in the
     existing visualization.py.
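For step 8, here is a hedged sketch of the numerical comparison. The same toy model is written once against an xp module. It is evaluated with NumPy as the baseline and with jax.numpy under jax.jit, and the two results are compared with np.testing.assert_allclose. model_image_from is hypothetical, and the xp argument is only assumed to resemble the xp convention mentioned in step 7. JAX defaults to float32 unless x64 is enabled, so the comparison uses a tolerance rather than exact equality. Comparing the rendered PNGs themselves is a separate question that this sketch does not cover.

```python
import jax
import jax.numpy as jnp
import numpy as np


def model_image_from(intensities, sigmas, r2, xp=np):
    """Hypothetical toy MGE written once; xp selects NumPy or jax.numpy."""
    gaussians = xp.exp(-0.5 * r2[None, :, :] / sigmas[:, None, None] ** 2)
    return xp.sum(intensities[:, None, None] * gaussians, axis=0)


yy, xx = np.meshgrid(np.linspace(-2, 2, 32), np.linspace(-2, 2, 32), indexing="ij")
r2 = yy**2 + xx**2
sigmas = np.logspace(-1.0, 0.5, 8)
intensities = np.linspace(1.0, 0.1, 8)

# NumPy baseline (float64) vs jitted JAX (float32 by default).
baseline = model_image_from(intensities, sigmas, r2, xp=np)
jitted = jax.jit(lambda i, s, r: model_image_from(i, s, r, xp=jnp))
result = np.asarray(jitted(intensities, sigmas, r2))

np.testing.assert_allclose(result, baseline, rtol=1e-5, atol=1e-6)
```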

Key Files

  • PyAutoFit/autofit/non_linear/analysis/analysis.py — add flag
  • PyAutoFit/autofit/non_linear/search/abstract_search.py — wire JIT into viz path
  • PyAutoFit/autofit/non_linear/fitness.py — keep convert_jax consistent
  • PyAutoLens/autolens/imaging/model/visualizer.py — survey what needs JIT support
  • autolens_workspace_test/scripts/imaging/visualization_jax.py — new pilot script

Open Questions

  • Matplotlib at the JIT boundary: do any plot calls happen inside traced code? If so,
    a structural change is required.
  • Should supports_jax_visualization survive, or be replaced by the new flag?
  • Final form: auto-derive from use_jax=True, or keep it as an explicit opt-in?
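On the matplotlib question, the small sketch below shows why plot calls cannot sit inside traced code. Inside jax.jit, arrays are abstract tracers. Converting one to a NumPy array, which plotting ultimately needs, raises jax.errors.TracerArrayConversionError. plot_inside is a hypothetical stand-in for a plotting call. The safe pattern is to jit only the numerics and convert with np.asarray after the jitted call returns.

```python
import jax
import jax.numpy as jnp
import numpy as np


def plot_inside(x):
    # Hypothetical stand-in for a plot call: needs concrete NumPy values,
    # which a tracer cannot provide.
    np.asarray(x)
    return x


try:
    jax.jit(plot_inside)(jnp.ones(3))
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

# Safe pattern: jit the numerics, then convert (and plot) outside the jit.
values = np.asarray(jax.jit(lambda x: 2.0 * x)(jnp.ones(3)))
```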

Original Prompt


Visualization is performed during and after a model-fit in the method @PyAutoFit/autofit/non_linear/search/abstract_search.py:

if self.force_visualize_overwrite:
    self.perform_visualization(
        model=model,
        analysis=analysis,
        samples_summary=samples_summary,
        during_analysis=False,
    )

Currently, visualization does not use JAX, so it does not benefit from jax.jit to speed up calculations.

In @PyAutoLens/autolens/imaging/model/visualizer.py, we can see an example of how visualization is performed.
In particular, it calls fit = analysis.fit_from(instance=instance), which does not use
JAX because only the log_likelihood_function is jitted in @PyAutoFit/autofit/non_linear/fitness.py.

I want autofit to support JAX-jitted visualization, and for this to happen when use_jax is passed to Analysis,
but for now let's use a use_jax_for_visualization flag to make it explicit that we are only using JAX for visualization.

Recent updates have added pytree registration to autofit and the source code; look up autofit's recent PR on this
and the imaging examples in @autolens_workspace_test/scripts/jax_likelihood_functions/imaging.

Therefore, can you assess how feasible this is and, in @autolens_workspace_test/scripts/imaging, read visualization.py
and produce an example visualization_jax.py which tries to achieve this, calling only
VisualizerImaging's visualize method for now. Let's only do this for an MGE parametric source, for simplicity.
I expect we'll first see some JAX issues due to certain internal calls that do not support JAX (e.g. tracer = fit.tracer_linear_light_profiles_to_light_profiles);
these require pytree registration.

That is fine, I want us to get to this point and then we can start to work through the issues one by one,
and make sure that the visualization is working with JAX.

Do you foresee an issue with the combination of JAX and matplotlib?

This plan should also make sure we are fully confident in the interface between PyAutoFit's JAX support, visualization
in the search, and this layer; if it can be polished before doing a lot of work, please do.
