feat: register pytrees for AnalysisInterferometer#376
Merged
Conversation
Mirror AnalysisImaging's pytree registration on the interferometer side so jax.jit(fit_from) can flatten its FitInterferometer return value. Extract the Galaxies flatten/unflatten block (~12 lines, identical across analyses) into autogalaxy.analysis.jax_pytrees.register_galaxies_pytree() so imaging and interferometer share the non-trivial logic without duplication. End-to-end JIT verification (jax.jit(analysis.fit_from) round-trip with NumPy parity) will land in the downstream autogalaxy_workspace_test_jax_likelihood_interferometer task, which is explicitly gated on this PR. Refs #375 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This was referenced May 8, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds JAX pytree registration for
AnalysisInterferometersojax.jit(analysis.fit_from)can flatten itsFitInterferometerreturn value — mirrors the existingAnalysisImagingpattern shipped in #364.The
Galaxiesflatten/unflatten block (~12 lines, identical between imaging and interferometer) is lifted intoautogalaxy.analysis.jax_pytrees.register_galaxies_pytree()so both analyses share the non-trivial logic without duplication. Imaging is refactored to call the shared helper; the per-analysisregister_instance_pytree(Fit*, ...)andregister_instance_pytree(DatasetModel)lines stay inline at each call site so it's still obvious from each method what is being registered.AnalysisQuantityandAnalysisEllipseare out of scope — quantity has no autolens JAX-likelihood equivalent yet (verification path needs separate design) and ellipse is structurally different (returnsList[FitEllipse]with noGalaxiesaggregate, inheritsaf.Analysisdirectly). Both deferred to follow-up issues.End-to-end JIT verification (
jax.jit(analysis.fit_from)round-trip with NumPy parity) lands in the downstreamautogalaxy_workspace_test_jax_likelihood_interferometertask, which is explicitly gated on this PR.Closes #375
API Changes
autogalaxy.analysis.jax_pytrees.register_galaxies_pytree()— shared helper that registersGalaxies(alistsubclass) as a JAX pytree with custom flatten/unflatten. Idempotent.AnalysisInterferometer._register_fit_interferometer_pytrees()— static method registeringFitInterferometer,DatasetModel, andGalaxies. Called fromfit_fromunder the existingself._use_jaxgate.AnalysisImaging._register_fit_imaging_pytrees— body collapsed from 41 to 11 lines by delegating theGalaxiesregistration to the shared helper. Behaviour unchanged.See full details below.
Test Plan
pytest test_autogalaxy/imaging/model/test_analysis_imaging.py— passes.pytest test_autogalaxy/interferometer/model/test_analysis_interferometer.py— passes._register_fit_interferometer_pytrees()runs without error, is idempotent on repeated calls, and ends withFitInterferometer,DatasetModel, andGalaxiesin_pytree_registered_classes.autogalaxy_workspace_test_jax_likelihood_interferometertask, which is gated on this PR).Full API Changes (for automation & release notes)
Added
autogalaxy.analysis.jax_pytrees(new module) exposingregister_galaxies_pytree() -> None. RegistersGalaxiesas a JAX pytree with custom flatten/unflatten. Idempotent via_pytree_registered_classes.autogalaxy.interferometer.model.analysis.AnalysisInterferometer._register_fit_interferometer_pytrees()— staticmethod registeringFitInterferometer (no_flatten=("dataset", "adapt_images", "settings")),DatasetModel, andGalaxies(via the shared helper).AnalysisInterferometer.fit_fromnow calls_register_fit_interferometer_pytrees()before constructing the fit, gated onself._use_jax(mirrors imaging line 146-147).Changed Behaviour
AnalysisImaging._register_fit_imaging_pytrees— internal refactor only. Body collapsed from 41 to 11 lines; theGalaxiesregistration block now delegates toregister_galaxies_pytree(). Net behaviour identical.Migration
use_jax=TrueonAnalysisInterferometer(default). Existing NumPy callers are unaffected.🤖 Generated with Claude Code