Conversation
…r jax.jit Adds a positional fallback in ``GalaxiesToInversion.image_plane_mesh_grid_list`` for the case where ``adapt_images.galaxy_image_plane_mesh_grid_dict[galaxy]`` returns None because the Galaxy instance is a fresh object produced by a pytree unflatten cycle rather than the one stored as a dict key. When the dict contains exactly one entry, that entry is returned by insertion order — correct for the one-pixelised-source case (Delaunay / Hilbert image-mesh fits) which is what the fit-imaging-pytree-delaunay PoC exercises. This is a deliberately narrow workaround; the principled fix is to attach a pytree_token to Galaxy (mirroring LightProfileLinear) so __hash__/__eq__ survive pytree round-trips. That fix is tracked separately in admin_jammy/prompt/autolens/galaxy_pytree_token.md. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2 tasks
Collaborator
Author
|
Workspace PR: PyAutoLabs/autolens_workspace_test#38 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Enables
jax.jit(analysis.fit_from)to succeed for models with aDelaunaypixelization source +
Hilbertimage-mesh +AdaptSplitregularization. Withoutthis change the JIT path raises
AttributeError: 'NoneType' object has no attribute 'array'insideGalaxiesToInversion.image_plane_mesh_grid_list, because theGalaxykeys inadapt_images.galaxy_image_plane_mesh_grid_dictare theoriginal Python objects and the traced
Galaxyis a fresh instance produced bypytree unflatten — identity-keyed dict lookup fails.
API Changes
Behaviour-only change in
GalaxiesToInversion.image_plane_mesh_grid_list: whenthe galaxy-identity-keyed mesh-grid dict lookup returns
Noneand the dictcontains exactly one entry, that entry is now used by insertion order.
No public signature changes. NumPy path is unaffected (it never hit the fallback).
See full details below.
Test Plan
pytest test_autogalaxy/ -q --ignore=test_autogalaxy/aggregator— 836 passedscripts/jax_likelihood_functions/imaging/delaunay_pytree.pyPASSes —NumPy and JIT log_likelihoods agree to rtol ~4e-15
Upstream PR
None — this ships as a standalone library change.
Follow-up
The one-entry fallback is a deliberately narrow workaround. It is correct for the
current PoC (single pixelised source) but silently picks the wrong grid when there
are two or more pixelised sources. The principled fix — attach
pytree_tokentoGalaxymirroringLightProfileLinearso identity survives pytree round-trips —is filed at
admin_jammy/prompt/autolens/galaxy_pytree_token.mdin the workspacerepo and should land before any multi-source-pixelization JIT variant is added to
the queue (e.g.
fit_imaging_pytree_delaunay_mge.md— item 10 in the queue).Full API Changes (for automation & release notes)
Changed Behaviour
GalaxiesToInversion.image_plane_mesh_grid_list—when
adapt_images.galaxy_image_plane_mesh_grid_dict[galaxy]returnsNone(because the
Galaxyinstance has been reconstructed by a pytree unflattencycle and no longer has identity-equality with the dict's keys), the method
now falls back to the single value in the dict when the dict has length 1.
Previously this case appended
Noneto the result list, leading to a crashdownstream when the mesh grid was accessed.
Rationale
jax.jit(analysis.fit_from).galaxy_pytree_token.mdfollow-up prompt.🤖 Generated with Claude Code