
feat: wire AnalysisEllipse for JAX via fit_from + pytree registration #411

@Jammy2211

Description


Overview

The keystone of the ellipse_fitting_jax feature (step 7 of 7). Prompts 4-6 made every piece JAX-traceable; this prompt wires AnalysisEllipse so jax.jit(analysis.fit_from)(instance) works end to end. Adds use_jax: bool = True flag, a fit_from method, and a _register_fit_ellipse_pytrees() helper modelled on AnalysisImaging. Flips the prompt-2 workspace_test scripts (scripts/jax_likelihood_functions/ellipse/{fit.py,multipoles.py}) from TODO placeholders to real JIT round-trip + fitness._vmap batch checks against the locked-in numpy reference numbers (rtol=1e-4).

Plan

  • Library (PyAutoGalaxy):
    • AnalysisEllipse.__init__ gains use_jax: bool = True (default True, mirroring AnalysisImaging); passed through super().__init__(use_jax=use_jax).
    • Add a fit_from(instance: af.ModelInstance) -> FitEllipse method mirroring AnalysisImaging.fit_from. It calls _register_fit_ellipse_pytrees() once when self._use_jax, and returns a single fit object whose figure of merit is the sum across the instance's ellipse list (today, fit_list_from builds one FitEllipse per ellipse).
    • log_likelihood_function delegates to self.fit_from(instance).figure_of_merit. The existing fit_list_from stays — it's used by VisualizerEllipse.visualize.
    • _register_fit_ellipse_pytrees(): register FitEllipse (no_flatten=("dataset",)), Ellipse, EllipseMultipole, EllipseMultipoleScaled via register_instance_pytree. Idempotent (registry-guard like the imaging analysis).
  • Workspace (autogalaxy_workspace_test):
    • Replace the # TODO(7_analysis_ellipse_jax.md) placeholders in scripts/jax_likelihood_functions/ellipse/{fit.py,multipoles.py} with the JIT round-trip + fitness._vmap blocks. Assert agreement to rtol=1e-4 against the numpy references locked in by prompt 2.
  • Unit test (PyAutoGalaxy): one numpy-only test asserting AnalysisEllipse(dataset, use_jax=False).log_likelihood_function(instance) is unchanged. JAX parity is verified at the workspace_test level, per the CLAUDE.md rule "Never use JAX in unit tests".
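The plan leaves the shape of the aggregate open. A minimal sketch of the summing wrapper, assuming a hypothetical FitEllipseSummed that only needs each per-ellipse fit to expose figure_of_merit, log_likelihood and chi_squared (the _StubFit class below is a stand-in, not the real FitEllipse):

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class _StubFit:
    """Hypothetical stand-in for a per-ellipse FitEllipse (illustration only)."""

    figure_of_merit: float
    log_likelihood: float
    chi_squared: float


class FitEllipseSummed:
    """Hypothetical aggregate: exposes per-ellipse quantities summed over the list."""

    def __init__(self, fit_list: Sequence):
        self.fit_list: List = list(fit_list)

    @property
    def figure_of_merit(self):
        # Plain sum() starts from 0 and adds each per-ellipse value.
        return sum(f.figure_of_merit for f in self.fit_list)

    @property
    def log_likelihood(self):
        return sum(f.log_likelihood for f in self.fit_list)

    @property
    def chi_squared(self):
        return sum(f.chi_squared for f in self.fit_list)


fits = [_StubFit(-10.0, -10.0, 20.0), _StubFit(-2.5, -2.5, 5.0)]
summed = FitEllipseSummed(fits)
print(summed.figure_of_merit)  # -12.5
```

Because the sums use plain Python sum(), the same wrapper should work whether the per-ellipse values are NumPy floats or JAX scalars; registering it as a pytree is a separate step.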
Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary — library changes)
  • autogalaxy_workspace_test (workspace follow-up — flip prompt-2 scripts to JIT)

Work Classification

Both (library first via /ship_library, then workspace via /ship_workspace).

Branch Survey

Repository                    Current Branch   Dirty?
./PyAutoGalaxy                main             clean
./autogalaxy_workspace_test   main             clean

Suggested branch: feature/analysis-ellipse-jax (same name in both repos)
Worktree root: ~/Code/PyAutoLabs-wt/analysis-ellipse-jax/ (created later by /start_library)

Implementation Steps

Library phase (PyAutoGalaxy):

  1. autogalaxy/ellipse/model/analysis.py:

    • Add use_jax: bool = True to AnalysisEllipse.__init__. Pass through super().__init__(use_jax=use_jax).
    • Add fit_from(self, instance: af.ModelInstance) -> FitEllipse. Build one FitEllipse per ellipse using fit_list_from's pattern, BUT collapse them into a single object whose .figure_of_merit equals sum(f.figure_of_merit for f in fit_list). See how AnalysisImaging.fit_from (lines 127-144) handles this: it returns a single FitImaging, not a list, because there is only one dataset. For ellipses there can be several per instance, so the cleanest approach is probably a small FitEllipseSummed class in fit_ellipse.py that holds the list and exposes .figure_of_merit, .log_likelihood and .chi_squared as sums. Register THAT class as the pytree (with no_flatten=("dataset",)).
    • log_likelihood_function(instance) delegates to self.fit_from(instance).figure_of_merit. The existing fit_list_from keeps its old behaviour and is still called by VisualizerEllipse.visualize.
  2. _register_fit_ellipse_pytrees() static method on AnalysisEllipse:

    • Use a module-level boolean guard (_REGISTERED = False) or a try/except registry probe for idempotency.
    • Register:
      • FitEllipse with no_flatten=("dataset",): dataset carries the interpolator state and shouldn't be flattened across JIT calls.
      • Ellipse via generic register_instance_pytree(Ellipse).
      • EllipseMultipole and EllipseMultipoleScaled via generic registration.
      • FitEllipseSummed (or whatever the aggregate type is called) with no_flatten=("dataset",).
    • Optional helper: add register_ellipses_pytree() in autogalaxy/analysis/jax_pytrees.py mirroring register_galaxies_pytree. Decide based on whether the generic registration suffices.
  3. test_autogalaxy/ellipse/test_analysis.py: add ONE numpy-only test:

    • Construct a tiny imaging_7x7-style fixture (reuse the existing one) + a fixed Ellipse model.
    • analysis = ag.AnalysisEllipse(dataset=dataset, use_jax=False).
    • instance = model.instance_from_prior_medians().
    • Compute lh = analysis.log_likelihood_function(instance).
    • Assert against a hard-coded reference value (capture via print(repr(lh)) once, paste in).
    • No JAX imports.

Workspace phase (autogalaxy_workspace_test):

  1. scripts/jax_likelihood_functions/ellipse/fit.py:

    • Below the existing numpy reference printing, add the JIT round-trip block modelled on scripts/jax_likelihood_functions/imaging/lp.py:107-129:
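      import jax
      import jax.numpy as jnp
      import numpy as np
      import autogalaxy as ag
      from autofit.jax.pytrees import enable_pytrees, register_model

      # Let autofit's model/instance objects be traced as JAX pytrees.
      enable_pytrees()
      register_model(model)

      # JAX-enabled analysis; jit-compile fit_from so the whole fit is traced.
      analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True)
      fit_jit_fn = jax.jit(analysis_jit.fit_from)
      fit = fit_jit_fn(instance)

      # The traced path must return JAX arrays, not NumPy scalars.
      assert isinstance(fit.log_likelihood, jnp.ndarray), \
          f"expected jax.Array, got {type(fit.log_likelihood)}"
      # total_log_likelihood is the NumPy reference computed earlier in the script.
      np.testing.assert_allclose(
          float(fit.log_likelihood), total_log_likelihood, rtol=1e-4
      )
      print("PASS: jit(fit_from) round-trip matches NumPy reference.")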
    • Add a fitness._vmap block mirroring imaging/lp.py:74-98 — wrap autofit's Fitness in vmap with batch_size=50, evaluate, print VRAM/timing.
  2. scripts/jax_likelihood_functions/ellipse/multipoles.py: same treatment as fit.py.
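The JIT round-trip and the fitness._vmap block check different things: that one traced call reproduces the NumPy number, and that the same function still works when batched. A toy illustration of both checks with plain jax.jit and jax.vmap, assuming a hypothetical straight-line Gaussian likelihood in place of AnalysisEllipse (it does not use autofit's Fitness or its _vmap):

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # match NumPy's float64 reference

# Hypothetical toy data: a straight line with constant noise.
x = np.linspace(0.0, 1.0, 50)
data = 2.0 * x + 0.5
noise = 0.1


def log_likelihood_np(params):
    slope, intercept = params
    residual = data - (slope * x + intercept)
    return -0.5 * np.sum((residual / noise) ** 2)


def log_likelihood_jax(params):
    slope, intercept = params[0], params[1]
    residual = data - (slope * x + intercept)
    return -0.5 * jnp.sum((residual / noise) ** 2)


params = np.array([1.9, 0.6])
reference = log_likelihood_np(params)

# 1) JIT round-trip: the traced result is a JAX array and matches NumPy.
ll_jit = jax.jit(log_likelihood_jax)(jnp.asarray(params))
assert isinstance(ll_jit, jnp.ndarray)
np.testing.assert_allclose(float(ll_jit), reference, rtol=1e-4)

# 2) Batched evaluation: vmap over 50 parameter vectors, compared to a loop.
batch = np.random.default_rng(0).normal([2.0, 0.5], 0.1, size=(50, 2))
ll_batch = jax.jit(jax.vmap(log_likelihood_jax))(jnp.asarray(batch))
ll_loop = np.array([log_likelihood_np(p) for p in batch])
np.testing.assert_allclose(np.asarray(ll_batch), ll_loop, rtol=1e-4)
print("PASS: jit and vmap agree with NumPy")
```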

Key Files

  • PyAutoGalaxy/autogalaxy/ellipse/model/analysis.py: use_jax plumbing + fit_from + _register_fit_ellipse_pytrees.
  • PyAutoGalaxy/autogalaxy/ellipse/fit_ellipse.py — add FitEllipseSummed aggregate (or equivalent).
  • PyAutoGalaxy/autogalaxy/analysis/jax_pytrees.py — optional register_ellipses_pytree shim.
  • PyAutoGalaxy/test_autogalaxy/ellipse/test_analysis.py — one new numpy-only test.
  • autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/fit.py — flip from TODO to real JIT round-trip + vmap.
  • autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/multipoles.py — same.

Testing Approach

  • pytest test_autogalaxy/ -v — 870/870 (was 869, +1 new analysis test).
  • workspace_test scripts: python scripts/jax_likelihood_functions/ellipse/{fit,multipoles}.py; the JIT round-trip prints PASS and the vmap block runs and reports its timing.
  • bash run_all_scripts.sh from autogalaxy_workspace_test — full smoke green.
  • Library-first merge gate: library PR merges first, workspace PR follows.
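The new numpy-only test follows a capture-then-pin pattern: compute the value once, paste its repr, then assert against it from then on. A generic sketch with a hypothetical stand-in likelihood; the real test would instead call ag.AnalysisEllipse(dataset=dataset, use_jax=False).log_likelihood_function(instance) on the existing fixtures:

```python
import math

import numpy as np


def stand_in_log_likelihood(data, model, noise):
    """Hypothetical NumPy-only likelihood standing in for log_likelihood_function."""
    chi_squared = np.sum(((data - model) / noise) ** 2)
    return -0.5 * float(chi_squared)


def test__log_likelihood_function__numpy_value_unchanged():
    data = np.array([1.0, 2.0, 3.0])
    model = np.array([1.0, 1.5, 3.5])
    noise = np.array([0.5, 0.5, 0.5])

    lh = stand_in_log_likelihood(data, model, noise)

    # Reference captured once via print(repr(lh)) and pasted in; no JAX imports.
    assert math.isclose(lh, -1.0, rel_tol=1e-8)
```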

Original Prompt


Step 7 of the ellipse-JAX series — the keystone. Prompts 4-6 made every piece JAX-traceable; this prompt wires AnalysisEllipse so jax.jit(analysis.fit_from)(instance) works end to end. The template is AnalysisImaging in @PyAutoGalaxy/autogalaxy/imaging/model/analysis.py:30-187, which has all the moving parts (use_jax: bool = True, _register_fit_imaging_pytrees(), super().__init__(use_jax=use_jax)).
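The moving parts named here fit together roughly as follows. This is a pure-Python skeleton under stated assumptions: AnalysisBase and AnalysisEllipseSketch are hypothetical stand-ins, the registration method only counts calls, and the ellipse stubs carry a precomputed figure_of_merit so the sketch stays self-contained:

```python
from types import SimpleNamespace


class AnalysisBase:
    """Hypothetical stand-in for the shared analysis parent class."""

    def __init__(self, use_jax: bool = True):
        self._use_jax = use_jax


class AnalysisEllipseSketch(AnalysisBase):
    """Hypothetical skeleton of the wiring: flag, fit_from, delegation."""

    registration_calls = 0  # counts registration calls, for illustration only

    def __init__(self, dataset, use_jax: bool = True):
        super().__init__(use_jax=use_jax)
        self.dataset = dataset

    @classmethod
    def _register_fit_ellipse_pytrees(cls):
        # A real version would register pytrees behind an idempotency guard.
        cls.registration_calls += 1

    def fit_from(self, instance):
        if self._use_jax:
            self._register_fit_ellipse_pytrees()
        # Stand-in for building the summed fit over instance.ellipses.
        total = sum(e.figure_of_merit for e in instance.ellipses)
        return SimpleNamespace(figure_of_merit=total)

    def log_likelihood_function(self, instance):
        return self.fit_from(instance).figure_of_merit


instance = SimpleNamespace(
    ellipses=[SimpleNamespace(figure_of_merit=-3.0), SimpleNamespace(figure_of_merit=-4.0)]
)
numpy_analysis = AnalysisEllipseSketch(dataset=None, use_jax=False)
print(numpy_analysis.log_likelihood_function(instance))  # -7.0
print(AnalysisEllipseSketch.registration_calls)  # 0
```

With use_jax=False the registration hook is never touched, which is what the numpy-only unit test relies on.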

Please:

  1. In @PyAutoGalaxy/autogalaxy/ellipse/model/analysis.py:

    • Add use_jax: bool = True to AnalysisEllipse.__init__ and pass it through super().__init__(use_jax=use_jax). Default True matches AnalysisImaging.
    • Add a fit_from(instance: af.ModelInstance) -> FitEllipse method (today only fit_list_from exists). It should mirror AnalysisImaging.fit_from: build the FitEllipse (or list of FitEllipse collapsed into a sum-figure-of-merit wrapper), call _register_fit_ellipse_pytrees() once when self._use_jax, return the resulting FitEllipse.
    • Update log_likelihood_function to call self.fit_from(instance).figure_of_merit (or sum the list, matching the existing logic). The existing fit_list_from stays — it's used by VisualizerEllipse.visualize in @PyAutoGalaxy/autogalaxy/ellipse/model/visualizer.py:64.
  2. Implement _register_fit_ellipse_pytrees() modelled on AnalysisImaging._register_fit_imaging_pytrees() (lines 168-187). Register:

    • FitEllipse with no_flatten=("dataset",). The interp cached property reconstructs from dataset so it's safe to skip flattening.
    • Ellipse (generic flatten via register_instance_pytree).
    • EllipseMultipole and EllipseMultipoleScaled (generic flatten).
    • Reuse the helper from autoarray.abstract_ndarray.register_instance_pytree. Make the function idempotent — match the registry-guard pattern in the imaging analysis.
    • Optionally add a thin register_ellipses_pytree() shim in @PyAutoGalaxy/autogalaxy/analysis/jax_pytrees.py to mirror register_galaxies_pytree. Generic registration may be enough, since Ellipses are stored on instance.ellipses as a list, not a custom container.
  3. Flip the workspace_test scripts from prompt 2 (@autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/{fit.py, multipoles.py}) to exercise the JIT path:

    • Replace the # TODO(7_analysis_ellipse_jax.md) placeholder with the actual JIT round-trip block, modelled on @autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/lp.py:107-129: analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True); fit_jit_fn = jax.jit(analysis_jit.fit_from); fit = fit_jit_fn(instance).
    • Assert np.testing.assert_allclose(float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4) against the numpy reference computed earlier in the script.
    • Assert isinstance(fit.log_likelihood, jnp.ndarray).
    • Add a fitness._vmap batch-evaluation block too, mirroring imaging/lp.py:74-98. This catches issues that only surface under jax.vmap.
  4. Add a unit test in @PyAutoGalaxy/test_autogalaxy/ellipse/test_analysis.py that constructs AnalysisEllipse(dataset, use_jax=False) and asserts the existing numpy log_likelihood_function value is unchanged for a known instance. The JAX-path checks live in the workspace_test scripts, per the @PyAutoGalaxy/CLAUDE.md rule "Never use JAX in unit tests".

  5. Test bar:

    • python -m pytest test_autogalaxy/ -v passes (no regressions in the imaging/interferometer paths).
    • The two workspace_test scripts run cleanly and the JIT path matches the numpy reference to rtol=1e-4.
    • bash run_all_scripts.sh from @autogalaxy_workspace_test/ is green.
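no_flatten=("dataset",) keeps the dataset out of the traced leaves. A hand-rolled sketch of that idea using jax.tree_util.register_pytree_node, where the dataset travels as static aux_data; ToyDataset and ToyFit are hypothetical, and this is not how autoarray's register_instance_pytree is implemented internally (its internals are not shown here):

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyDataset:
    """Hypothetical dataset: holds NumPy data and is never traced by JAX."""

    def __init__(self, data, noise):
        self.data = np.asarray(data)
        self.noise = np.asarray(noise)


class ToyFit:
    """Hypothetical fit: `model` is a traced leaf, `dataset` is static."""

    def __init__(self, dataset, model):
        self.dataset = dataset
        self.model = model

    @property
    def log_likelihood(self):
        residual = (self.dataset.data - self.model) / self.dataset.noise
        return -0.5 * jnp.sum(residual**2)


def _flatten(fit):
    # Children are traced; aux_data (the dataset) is passed through untouched.
    return (fit.model,), fit.dataset


def _unflatten(dataset, children):
    (model,) = children
    return ToyFit(dataset, model)


jax.tree_util.register_pytree_node(ToyFit, _flatten, _unflatten)

dataset = ToyDataset([1.0, 2.0, 3.0], [0.5, 0.5, 0.5])


@jax.jit
def fit_from(model):
    # Returning a ToyFit from a jitted function works because it is a registered pytree.
    return ToyFit(dataset, model)


fit = fit_from(jnp.array([1.0, 1.5, 3.5]))
print(float(fit.log_likelihood))  # -1.0
print(fit.dataset is dataset)  # True
```

One consequence worth keeping in mind: static aux_data is not traced, so if such a fit were passed into a jitted function, a different dataset object would trigger a retrace rather than reuse the compiled code.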

After this lands, ellipse modeling can run inside Drawer / Nautilus / any other JAX-compatible search the same way AnalysisImaging does today. Note that Drawer itself still needs a small fix to pass use_jax_jit=True through to Fitness (out of scope for this series — see the z_features/ellipse_fitting_jax.md "see also" note).
