Skip to content

fix: AdaptImages galaxy-identity mismatch across jax.jit boundary#370

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/adapt-images-pytree-fix
Apr 26, 2026
Merged

fix: AdaptImages galaxy-identity mismatch across jax.jit boundary#370
Jammy2211 merged 1 commit into
mainfrom
feature/adapt-images-pytree-fix

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Fixes AdaptImages lookups crashing across the jax.jit boundary for any model that uses adapt images (Adapt regularization, RectangularAdaptImage, Delaunay, Hilbert). After Model.instance_unflatten rebuilds galaxies via self.cls(*constructor_arguments) they get fresh .id values, so galaxy_image_dict[galaxy] (keyed by Galaxy instances, hashed via int(self.id)) misses and downstream mesh_weight_map_from(adapt_data=None) crashes with AttributeError.

The fix carries a path-tuple list (AdaptImages.galaxy_path_list) parallel to the analysis-time galaxies list. AdaptImages rides as aux through the FitImaging pytree, so the path list survives unflatten unchanged. New helpers image_for_galaxy(galaxy, galaxies) / image_plane_mesh_grid_for_galaxy try the existing by-instance fast path first, then fall back to identity-positional lookup against galaxy_path_list and galaxy_name_image_dict (already path-keyed, also aux-stable).

The previous one-element fallback at to_inversion.py:428-442 (mesh grid) is dropped — it was a workaround for the same root cause and only happened to cover single-pixelization fits.

API Changes

AdaptImages gained a galaxy_path_list attribute and two new lookup helpers (image_for_galaxy, image_plane_mesh_grid_for_galaxy). updated_via_instance_from accepts an optional galaxies arg used to align the path list with the analysis-time galaxy ordering. GalaxiesToInversion accepts an optional path_galaxies arg, defaulting to its own galaxies (the autogalaxy single-plane case stays unchanged). Analysis.adapt_images_via_instance_from forwards an optional galaxies arg. All new arguments are optional and default to None — existing callers and tests pass unchanged. See full details below.

Test Plan

  • Existing test_autogalaxy/ suite remains green (840 pass locally).
  • New test__image_for_galaxy__resolves_after_galaxy_identity_changes and test__image_plane_mesh_grid_for_galaxy__resolves_after_galaxy_identity_changes in test_autogalaxy/analysis/test_adapt_images.py exercise the post-unflatten lookup with fresh-Galaxy instances at the same paths.
  • Follow-up PR on autogalaxy_workspace_test re-enables rectangular.py (adapt variant), rectangular_mge.py, delaunay.py, delaunay_mge.py under jax_likelihood_functions/imaging/.
Full API Changes (for automation & release notes)

Added

  • AdaptImages.galaxy_path_list: Optional[List[str]] — path-tuple list parallel to the analysis-time galaxies list, populated by updated_via_instance_from.
  • AdaptImages.image_for_galaxy(galaxy, galaxies=None) -> Optional[Array2D] — JIT-safe lookup helper. Tries galaxy_image_dict[galaxy] first, falls back to galaxy_name_image_dict[path] via positional alignment.
  • AdaptImages.image_plane_mesh_grid_for_galaxy(galaxy, galaxies=None) -> Optional[Grid2DIrregular] — companion helper for the mesh-grid path.

Changed Signature

  • AdaptImages.__init__ — new optional kwarg galaxy_path_list: Optional[List[str]] = None.
  • AdaptImages.updated_via_instance_from(instance, mask=None) -> updated_via_instance_from(instance, mask=None, galaxies=None). When galaxies is provided, the resulting galaxy_path_list is aligned with that list.
  • GalaxiesToInversion.__init__ — new optional kwarg path_galaxies: Optional[List[Galaxy]] = None. Defaults to the constructor's galaxies argument; only autolens currently supplies a different value (the full tracer.galaxies list).
  • Analysis.adapt_images_via_instance_from(instance) -> adapt_images_via_instance_from(instance, galaxies=None).

Changed Behaviour

  • GalaxiesToInversion.image_plane_mesh_grid_list no longer falls back to the single-mesh-grid value when the by-instance lookup misses. Lookup now resolves correctly via image_plane_mesh_grid_for_galaxy for any number of pixelized galaxies.
  • GalaxiesToInversion.mapper_galaxy_dict adapt-image lookup uses image_for_galaxy instead of galaxy_image_dict[galaxy] directly.

Migration

  • Internal — no migration needed for downstream code. Analysis.fit_from is updated to pass the freshly-built galaxies list to adapt_images_via_instance_from. External callers that already pass instance continue to work.

🤖 Generated with Claude Code

Resolves dict-keyed-by-Galaxy-instance crash for adapt-image models after
jax.jit unflatten produces fresh Galaxy objects with new .id values.
Adds galaxy_path_list parallel to the analysis-time galaxies list and
two helpers (image_for_galaxy, image_plane_mesh_grid_for_galaxy) that
fall back to path-tuple keying via galaxy_name_image_dict. Drops the
single-mesh-grid fallback at to_inversion.py:428-442 — replaced by the
proper fix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

Workspace PR: PyAutoLabs/autogalaxy_workspace_test#12

@Jammy2211 Jammy2211 merged commit 5d1f3d5 into main Apr 26, 2026
5 checks passed
@Jammy2211 Jammy2211 deleted the feature/adapt-images-pytree-fix branch April 26, 2026 19:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant