Propagate xp through Grid2DIrregular.grid_2d_via_deflection_grid_from#287
Merged
Propagate xp through Grid2DIrregular.grid_2d_via_deflection_grid_from#287
Conversation
Previously constructed the new Grid2DIrregular without passing xp, so the resulting grid defaulted to _xp=np even when called on a JAX-backed receiver. Downstream calls to xp.square on the values (which were JAX tracers under JIT) raised TracerArrayConversionError. Pass xp=self._xp so the result inherits the receiver's backend. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This was referenced Apr 18, 2026
Collaborator
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Grid2DIrregular.grid_2d_via_deflection_grid_frompreviously constructed the result grid without propagatingxp=self._xp, so a JAX-backed receiver produced a numpy-backed result whose values were still JAX tracers. Downstreamself._xp.square(...)then callednp.squareon a tracer and raisedTracerArrayConversionError, blocking JIT of point-source pipelines (see #286).This one-line change passes
xp=self._xpso the new grid inherits the receiver's backend, and adds a unit test covering both the np and jnp round-trips.API Changes
None — internal change only. Public signature unchanged; behaviour is only affected when the receiver's
_xpis not numpy (in which case previously the result silently downgraded to numpy).See full details below.
Test Plan
test__grid_2d_via_deflection_grid_from__propagates_xpcovers numpy + JAX round-tripsTracerArrayConversionErrorbefore, see stacked PR on PyAutoLens)Full API Changes (for automation & release notes)
Removed
None.
Added
None (unit test added, no public API additions).
Renamed
None.
Changed Signature
None.
Changed Behaviour
Grid2DIrregular.grid_2d_via_deflection_grid_from— the returned grid now has_xpmatching the receiver's_xp. Previously always defaulted to numpy, silently downgrading JAX-backed callers.Migration
None — callers that were numpy-backed are unchanged. Callers that were JAX-backed and worked around the downgrade (e.g. by manually re-wrapping the result with
xp=jnp) can drop the workaround.Follows up: #286
🤖 Generated with Claude Code