Skip to content

Rewrap point fit data + magnifications for JAX backend parity#442

Merged
Jammy2211 merged 1 commit intomainfrom
feature/grid-irregular-xp-propagation
Apr 18, 2026
Merged

Rewrap point fit data + magnifications for JAX backend parity#442
Jammy2211 merged 1 commit intomainfrom
feature/grid-irregular-xp-propagation

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Two fixes in AbstractFitPoint so point-source fits work end-to-end under use_jax=True:

  1. __init__ rewraps observed positions (data) with the analysis backend when use_jax=True. Datasets are loaded as numpy-backed Grid2DIrregular from JSON, so without rewrapping the fit's data._xp stays np even when xp=jnp is passed in, and downstream deflection-grid propagation fails.
  2. magnifications_at_positions rewraps the raw jax.Array returned by LensCalc.magnification_2d_via_hessian_from (which skips ArrayIrregular wrapping on the JAX path) so callers can use .array uniformly across backends without hitting AttributeError on jax tracers.

Together with PyAutoLabs/PyAutoArray#287, this unblocks full JIT tracing of FitPositionsSource.

Upstream PR

API Changes

None — internal changes to AbstractFitPoint.__init__ and AbstractFitPoint.magnifications_at_positions. External behaviour in the numpy path is unchanged; only the JAX path now returns consistent wrapped types and works end-to-end.
See full details below.

Test Plan

  • Full PyAutoLens test suite passes (246 passed)
  • test_autolens/point/fit/ passes (22 passed)
  • autolens_workspace_developer/jax_profiling/point_source/source_plane.py full pipeline now JITs end-to-end
Full API Changes (for automation & release notes)

Removed

None.

Added

None.

Renamed

None.

Changed Signature

None.

Changed Behaviour

  • AbstractFitPoint.__init__ — when xp is not numpy and the passed data._xp does not match, data is re-wrapped in a new Grid2DIrregular with the requested xp. Numpy callers unaffected.
  • AbstractFitPoint.magnifications_at_positions — on the JAX path the raw jax.Array returned by LensCalc.magnification_2d_via_hessian_from is now wrapped in ArrayIrregular before taking abs(...), so the property consistently returns an ArrayIrregular across backends.

Migration

None — callers that were numpy-backed are unchanged.

Follows up: PyAutoLabs/PyAutoArray#286

🤖 Generated with Claude Code

Two fixes in AbstractFitPoint so point-source fits work under use_jax=True:

1. __init__ rewraps observed positions (`data`) with the analysis backend
   when use_jax=True. Datasets are loaded as numpy-backed Grid2DIrregular
   from JSON, so without rewrapping the fit's data._xp stays np even when
   xp=jnp is passed in, and downstream deflection-grid propagation fails.

2. magnifications_at_positions rewraps the raw jax.Array returned by
   LensCalc.magnification_2d_via_hessian_from (which skips ArrayIrregular
   wrapping on the JAX path) so callers can use .array uniformly across
   backends without hitting AttributeError on jax tracers.

Together with PyAutoArray Grid2DIrregular xp propagation, this unblocks
full JIT tracing of FitPositionsSource.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

@Jammy2211 Jammy2211 merged commit c355660 into main Apr 18, 2026
5 checks passed
@Jammy2211 Jammy2211 deleted the feature/grid-irregular-xp-propagation branch April 18, 2026 20:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release Tracked for next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant