Overview
The keystone of the ellipse_fitting_jax feature (step 7 of 7). Prompts 4-6 made every piece JAX-traceable; this prompt wires AnalysisEllipse so jax.jit(analysis.fit_from)(instance) works end to end. It adds a use_jax: bool = True flag, a fit_from method, and a _register_fit_ellipse_pytrees() helper modelled on AnalysisImaging. It also flips the prompt-2 workspace_test scripts (scripts/jax_likelihood_functions/ellipse/{fit.py,multipoles.py}) from TODO placeholders to real JIT round-trip + fitness._vmap batch checks against the locked-in numpy reference numbers (rtol=1e-4).
Plan
- Library (PyAutoGalaxy):
  - AnalysisEllipse.__init__ gains use_jax: bool = True (default True, mirroring AnalysisImaging); passed through super().__init__(use_jax=use_jax).
  - Add a fit_from(instance: af.ModelInstance) -> FitEllipse method mirroring AnalysisImaging.fit_from. Calls _register_fit_ellipse_pytrees() once when self._use_jax. Returns a single FitEllipse that wraps the sum across the instance's ellipse list (the current fit_list_from builds one FitEllipse per ellipse).
  - log_likelihood_function delegates to self.fit_from(instance).figure_of_merit. The existing fit_list_from stays because it is used by VisualizerEllipse.visualize.
  - _register_fit_ellipse_pytrees(): register FitEllipse (no_flatten=("dataset",)), Ellipse, EllipseMultipole, EllipseMultipoleScaled via register_instance_pytree. Idempotent (registry-guard like the imaging analysis).
- Workspace (autogalaxy_workspace_test):
  - Replace the # TODO(7_analysis_ellipse_jax.md) placeholders in scripts/jax_likelihood_functions/ellipse/{fit.py,multipoles.py} with the JIT round-trip + fitness._vmap blocks. Assert agreement to rtol=1e-4 against the numpy references locked in by prompt 2.
- Unit test (PyAutoGalaxy): one numpy-only test asserting AnalysisEllipse(dataset, use_jax=False).log_likelihood_function(instance) is unchanged. JAX parity is verified at the workspace_test level per the CLAUDE.md rule.
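The delegation described in the plan can be sketched in plain Python. This is a minimal illustration, not the PyAutoGalaxy implementation: FitStub, FitSummedSketch and AnalysisSketch are hypothetical stand-ins, and the toy per-ellipse likelihood is invented purely so the sum is easy to check by hand.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FitStub:
    """Hypothetical stand-in for one per-ellipse fit."""

    log_likelihood: float

    @property
    def figure_of_merit(self) -> float:
        return self.log_likelihood


class FitSummedSketch:
    """Hypothetical aggregate that exposes sums over a list of fits."""

    def __init__(self, fit_list: List[FitStub]):
        self.fit_list = fit_list

    @property
    def log_likelihood(self) -> float:
        return sum(fit.log_likelihood for fit in self.fit_list)

    @property
    def figure_of_merit(self) -> float:
        return sum(fit.figure_of_merit for fit in self.fit_list)


class AnalysisSketch:
    """Hypothetical stand-in showing the use_jax flag and the delegation."""

    def __init__(self, use_jax: bool = True):
        self._use_jax = use_jax

    def fit_list_from(self, instance: List[float]) -> List[FitStub]:
        # Toy likelihood per "ellipse"; the real code builds one fit per ellipse.
        return [FitStub(log_likelihood=-0.5 * value**2) for value in instance]

    def fit_from(self, instance: List[float]) -> FitSummedSketch:
        # Collapse the per-ellipse list into a single summed object.
        return FitSummedSketch(self.fit_list_from(instance))

    def log_likelihood_function(self, instance: List[float]) -> float:
        return self.fit_from(instance).figure_of_merit


analysis = AnalysisSketch(use_jax=False)
total = analysis.log_likelihood_function([1.0, 2.0])
```

Keeping fit_list_from untouched and layering fit_from on top means the visualizer path and the likelihood path share the same per-ellipse construction.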
Detailed implementation plan
Affected Repositories
- PyAutoGalaxy (primary — library changes)
- autogalaxy_workspace_test (workspace follow-up — flip prompt-2 scripts to JIT)
Work Classification
Both (library first via /ship_library, then workspace via /ship_workspace).
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoGalaxy | main | clean |
| ./autogalaxy_workspace_test | main | clean |
Suggested branch: feature/analysis-ellipse-jax (same name in both repos)
Worktree root: ~/Code/PyAutoLabs-wt/analysis-ellipse-jax/ (created later by /start_library)
Implementation Steps
Library phase (PyAutoGalaxy):
- autogalaxy/ellipse/model/analysis.py:
  - Add use_jax: bool = True to AnalysisEllipse.__init__. Pass through super().__init__(use_jax=use_jax).
  - Add fit_from(self, instance: af.ModelInstance) -> FitEllipse. Build one FitEllipse per ellipse following fit_list_from's pattern, but collapse the list into a single object whose .figure_of_merit equals sum(f.figure_of_merit for f in fit_list). Look at how AnalysisImaging.fit_from (lines 127-144) handles this: it returns a single FitImaging, not a list, because there is only one dataset. For ellipse fitting there can be multiple ellipses per instance, so the cleanest approach is probably to add a small FitEllipseSummed class in fit_ellipse.py that holds the list and exposes .figure_of_merit, .log_likelihood, .chi_squared as sums. Register that class as the pytree (with no_flatten=("dataset",)).
  - log_likelihood_function(instance) delegates to self.fit_from(instance).figure_of_merit. The existing fit_list_from keeps its old behaviour and is still called by VisualizerEllipse.visualize.
- _register_fit_ellipse_pytrees() static method on AnalysisEllipse:
  - Use a module-level guard flag _REGISTERED = False (or check via a try/except registry probe) for idempotency.
  - Register:
    - FitEllipse with no_flatten=("dataset",). The dataset carries the interpolator state and shouldn't be flattened across JIT calls.
    - Ellipse via generic register_instance_pytree(Ellipse).
    - EllipseMultipole and EllipseMultipoleScaled via generic registration.
    - FitEllipseSummed (or whatever the aggregate type is called) with no_flatten=("dataset",).
  - Optional helper: add register_ellipses_pytree() in autogalaxy/analysis/jax_pytrees.py mirroring register_galaxies_pytree. Decide based on whether the generic registration suffices.
- test_autogalaxy/ellipse/test_analysis.py: add ONE numpy-only test:
  - Construct a tiny imaging_7x7-style fixture (re-use existing) + a fixed Ellipse model.
  - analysis = ag.AnalysisEllipse(dataset=dataset, use_jax=False).
  - instance = model.instance_from_prior_medians().
  - Compute lh = analysis.log_likelihood_function(instance).
  - Assert against a hard-coded reference value (capture via print(repr(lh)) once, paste in).
  - No JAX imports.
Workspace phase (autogalaxy_workspace_test):
- scripts/jax_likelihood_functions/ellipse/fit.py:
  - Below the existing numpy reference printing, add the JIT round-trip block modelled on scripts/jax_likelihood_functions/imaging/lp.py:107-129:

        # Assumes jax, jnp, np, ag, dataset, model, instance and
        # total_log_likelihood are already defined earlier in the script.
        from autofit.jax.pytrees import enable_pytrees, register_model

        enable_pytrees()
        register_model(model)

        analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True)
        fit_jit_fn = jax.jit(analysis_jit.fit_from)
        fit = fit_jit_fn(instance)

        assert isinstance(fit.log_likelihood, jnp.ndarray), \
            f"expected jax.Array, got {type(fit.log_likelihood)}"
        np.testing.assert_allclose(
            float(fit.log_likelihood), total_log_likelihood, rtol=1e-4
        )
        print("PASS: jit(fit_from) round-trip matches NumPy reference.")

  - Add a fitness._vmap block mirroring imaging/lp.py:74-98: wrap autofit's Fitness in vmap with batch_size=50, evaluate, print VRAM/timing.
- scripts/jax_likelihood_functions/ellipse/multipoles.py: same treatment as fit.py.
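The parity pattern used by both scripts (NumPy reference, then jit, then vmap over a batch of 50) can be tried in isolation on a toy likelihood. The sketch below does not use AnalysisEllipse or autofit's Fitness; log_likelihood, the data values and the parameter grid are assumptions made up for illustration. Only jax.jit, jax.vmap and np.testing.assert_allclose are real APIs.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_likelihood(params, data, noise, xp):
    """Toy Gaussian log likelihood of a straight-line model (illustrative only)."""
    model = params[0] * xp.arange(data.shape[0]) + params[1]
    chi_squared = xp.sum(((data - model) / noise) ** 2)
    return -0.5 * chi_squared


data_np = np.array([1.0, 2.0, 3.0, 4.0])
noise_np = np.full(4, 0.5)
data_jax, noise_jax = jnp.asarray(data_np), jnp.asarray(noise_np)

# 1) NumPy reference, standing in for the number locked in by the script.
params_np = np.array([0.9, 1.2])
reference = log_likelihood(params_np, data_np, noise_np, np)

# 2) JIT round-trip: same function, traced with jax.numpy.
jit_fn = jax.jit(lambda p: log_likelihood(p, data_jax, noise_jax, jnp))
jit_value = jit_fn(jnp.asarray(params_np))
assert isinstance(jit_value, jax.Array)
np.testing.assert_allclose(float(jit_value), reference, rtol=1e-4)

# 3) vmap over a batch of 50 parameter vectors, checked row by row.
slopes = np.linspace(0.5, 1.5, 50)
batch_np = np.stack([slopes, np.full(50, 1.2)], axis=1)
batched = jax.jit(jax.vmap(lambda p: log_likelihood(p, data_jax, noise_jax, jnp)))(
    jnp.asarray(batch_np)
)
batch_reference = np.array(
    [log_likelihood(row, data_np, noise_np, np) for row in batch_np]
)
np.testing.assert_allclose(np.asarray(batched), batch_reference, rtol=1e-4)
print("PASS: jit and vmap match the NumPy reference.")
```

The relative tolerance of 1e-4 leaves room for JAX's default single precision, which is one plausible reason the scripts do not demand exact equality with the NumPy numbers.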
Key Files
PyAutoGalaxy/autogalaxy/ellipse/model/analysis.py — use_jax plumbing + fit_from + _register_fit_ellipse_pytrees.
PyAutoGalaxy/autogalaxy/ellipse/fit_ellipse.py — add FitEllipseSummed aggregate (or equivalent).
PyAutoGalaxy/autogalaxy/analysis/jax_pytrees.py — optional register_ellipses_pytree shim.
PyAutoGalaxy/test_autogalaxy/ellipse/test_analysis.py — one new numpy-only test.
autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/fit.py — flip from TODO to real JIT round-trip + vmap.
autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/multipoles.py — same.
Testing Approach
- pytest test_autogalaxy/ -v: 870/870 (was 869, +1 new analysis test).
- Workspace_test scripts: python scripts/jax_likelihood_functions/ellipse/{fit,multipoles}.py. JIT round-trip prints PASS, vmap block produces sensible timing.
- bash run_all_scripts.sh from autogalaxy_workspace_test: full smoke green.
- Library-first merge gate: library PR merges first, workspace PR follows.
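The single new library test follows a locked-reference pattern: compute the value once, paste it in, and assert it never drifts. A minimal stdlib-only sketch of that pattern follows. log_likelihood_stub is a hypothetical stand-in for analysis.log_likelihood_function, and the reference value here is worked out by hand from the toy inputs, not taken from PyAutoGalaxy.

```python
import math


def log_likelihood_stub(data, model, noise):
    """Hypothetical stand-in for the numpy-path log_likelihood_function."""
    chi_squared = sum(((d - m) / n) ** 2 for d, m, n in zip(data, model, noise))
    return -0.5 * chi_squared


# Captured once from a trusted run, then pasted in as the locked reference.
REFERENCE_LOG_LIKELIHOOD = -0.5


def test_log_likelihood_unchanged():
    lh = log_likelihood_stub(
        data=[1.0, 2.0, 3.0], model=[1.0, 2.0, 2.0], noise=[1.0, 1.0, 1.0]
    )
    # A tight tolerance makes any behavioural change in the numpy path visible.
    assert math.isclose(lh, REFERENCE_LOG_LIKELIHOOD, rel_tol=1e-12)


test_log_likelihood_unchanged()
```

No JAX import appears anywhere in the test, which matches the rule that JAX parity is checked only in the workspace_test scripts.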
Original Prompt
Step 7 of the ellipse-JAX series — the keystone. Prompts 4-6 made every piece JAX-traceable; this prompt wires AnalysisEllipse so jax.jit(analysis.fit_from)(instance) works end to end. The template is AnalysisImaging in @PyAutoGalaxy/autogalaxy/imaging/model/analysis.py:30-187, which has all the moving parts (use_jax: bool = True, _register_fit_imaging_pytrees(), super().__init__(use_jax=use_jax)).
Please:
- In @PyAutoGalaxy/autogalaxy/ellipse/model/analysis.py:
  - Add use_jax: bool = True to AnalysisEllipse.__init__ and pass it through super().__init__(use_jax=use_jax). Default True matches AnalysisImaging.
  - Add a fit_from(instance: af.ModelInstance) -> FitEllipse method (today only fit_list_from exists). It should mirror AnalysisImaging.fit_from: build the FitEllipse (or list of FitEllipse collapsed into a sum-figure-of-merit wrapper), call _register_fit_ellipse_pytrees() once when self._use_jax, return the resulting FitEllipse.
  - Update log_likelihood_function to call self.fit_from(instance).figure_of_merit (or sum the list, matching the existing logic). The existing fit_list_from stays because it is used by VisualizerEllipse.visualize in @PyAutoGalaxy/autogalaxy/ellipse/model/visualizer.py:64.
- Implement _register_fit_ellipse_pytrees() modelled on AnalysisImaging._register_fit_imaging_pytrees() (lines 168-187). Register:
  - FitEllipse with no_flatten=("dataset",). The interp cached property reconstructs from dataset so it's safe to skip flattening.
  - Ellipse (generic flatten via register_instance_pytree).
  - EllipseMultipole and EllipseMultipoleScaled (generic flatten).
  - Reuse the helper from autoarray.abstract_ndarray.register_instance_pytree. Make the function idempotent: match the registry-guard pattern in the imaging analysis.
  - Place a thin shim @PyAutoGalaxy/autogalaxy/analysis/jax_pytrees.py::register_ellipses_pytree() if useful to mirror register_galaxies_pytree, but it's optional; generic registration may be enough since Ellipses are stored on instance.ellipses as a list, not a custom container.
- Flip the workspace_test scripts from prompt 2 (@autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/{fit.py, multipoles.py}) to exercise the JIT path:
  - Replace the # TODO(7_analysis_ellipse_jax.md) placeholder with the actual JIT round-trip block, modelled on @autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/lp.py:107-129: analysis_jit = ag.AnalysisEllipse(dataset=dataset, use_jax=True); fit_jit_fn = jax.jit(analysis_jit.fit_from); fit = fit_jit_fn(instance).
  - Assert np.testing.assert_allclose(float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4) against the numpy reference computed earlier in the script.
  - Assert isinstance(fit.log_likelihood, jnp.ndarray).
  - Add a fitness._vmap batch-evaluation block too, mirroring imaging/lp.py:74-98. This catches issues that only surface under jax.vmap.
- Add a unit test in @PyAutoGalaxy/test_autogalaxy/ellipse/test_analysis.py that constructs AnalysisEllipse(dataset, use_jax=False) and asserts the existing numpy log_likelihood_function value is unchanged for a known instance. The JAX-path checks live in the workspace_test scripts, per @PyAutoGalaxy/CLAUDE.md "Never use JAX in unit tests".
- Test bar:
  - python -m pytest test_autogalaxy/ -v passes (no regressions in the imaging/interferometer paths).
  - The two workspace_test scripts run cleanly and the JIT path matches the numpy reference to rtol=1e-4.
  - bash run_all_scripts.sh from @autogalaxy_workspace_test/ is green.
After this lands, ellipse modeling can run inside Drawer / Nautilus / any other JAX-compatible search the same way AnalysisImaging does today. Note that Drawer itself still needs a small fix to pass use_jax_jit=True through to Fitness (out of scope for this series — see the z_features/ellipse_fitting_jax.md "see also" note).