Rewrap point fit data + magnifications for JAX backend parity#442
Merged
Rewrap point fit data + magnifications for JAX backend parity#442
Conversation
Two fixes in AbstractFitPoint so point-source fits work under use_jax=True: 1. __init__ rewraps observed positions (`data`) with the analysis backend when use_jax=True. Datasets are loaded as numpy-backed Grid2DIrregular from JSON, so without rewrapping the fit's data._xp stays np even when xp=jnp is passed in, and downstream deflection-grid propagation fails. 2. magnifications_at_positions rewraps the raw jax.Array returned by LensCalc.magnification_2d_via_hessian_from (which skips ArrayIrregular wrapping on the JAX path) so callers can use .array uniformly across backends without hitting AttributeError on jax tracers. Together with PyAutoArray Grid2DIrregular xp propagation, this unblocks full JIT tracing of FitPositionsSource. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Collaborator
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Two fixes in
AbstractFitPointso point-source fits work end-to-end underuse_jax=True:__init__rewraps observed positions (data) with the analysis backend whenuse_jax=True. Datasets are loaded as numpy-backedGrid2DIrregularfrom JSON, so without rewrapping the fit'sdata._xpstaysnpeven whenxp=jnpis passed in, and downstream deflection-grid propagation fails.magnifications_at_positionsrewraps the rawjax.Arrayreturned byLensCalc.magnification_2d_via_hessian_from(which skipsArrayIrregularwrapping on the JAX path) so callers can use.arrayuniformly across backends without hittingAttributeErroron jax tracers.Together with PyAutoLabs/PyAutoArray#287, this unblocks full JIT tracing of
FitPositionsSource.Upstream PR
API Changes
None — internal changes to
AbstractFitPoint.__init__andAbstractFitPoint.magnifications_at_positions. External behaviour in the numpy path is unchanged; only the JAX path now returns consistent wrapped types and works end-to-end.See full details below.
Test Plan
test_autolens/point/fit/passes (22 passed)autolens_workspace_developer/jax_profiling/point_source/source_plane.pyfull pipeline now JITs end-to-endFull API Changes (for automation & release notes)
Removed
None.
Added
None.
Renamed
None.
Changed Signature
None.
Changed Behaviour
AbstractFitPoint.__init__— whenxpis not numpy and the passeddata._xpdoes not match,datais re-wrapped in a newGrid2DIrregularwith the requestedxp. Numpy callers unaffected.AbstractFitPoint.magnifications_at_positions— on the JAX path the rawjax.Arrayreturned byLensCalc.magnification_2d_via_hessian_fromis now wrapped inArrayIrregularbefore takingabs(...), so the property consistently returns anArrayIrregularacross backends.Migration
None — callers that were numpy-backed are unchanged.
Follows up: PyAutoLabs/PyAutoArray#286
🤖 Generated with Claude Code