From bf900af4b36217c01fca96bbb001c0fbc50c242b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 18 Apr 2026 20:48:58 +0100 Subject: [PATCH] Rewrap point fit data + magnifications for JAX backend parity Two fixes in AbstractFitPoint so point-source fits work under use_jax=True: 1. __init__ rewraps observed positions (`data`) with the analysis backend when use_jax=True. Datasets are loaded as numpy-backed Grid2DIrregular from JSON, so without rewrapping the fit's data._xp stays np even when xp=jnp is passed in, and downstream deflection-grid propagation fails. 2. magnifications_at_positions rewraps the raw jax.Array returned by LensCalc.magnification_2d_via_hessian_from (which skips ArrayIrregular wrapping on the JAX path) so callers can use .array uniformly across backends without hitting AttributeError on jax tracers. Together with PyAutoArray Grid2DIrregular xp propagation, this unblocks full JIT tracing of FitPositionsSource. Co-Authored-By: Claude Opus 4.7 --- autolens/point/fit/abstract.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/autolens/point/fit/abstract.py b/autolens/point/fit/abstract.py index 35f26078c..281d197f8 100644 --- a/autolens/point/fit/abstract.py +++ b/autolens/point/fit/abstract.py @@ -74,6 +74,8 @@ def __init__( """ self.name = name + if xp is not np and data._xp is not xp: + data = aa.Grid2DIrregular(values=data.array, xp=xp) self._data = data self._noise_map = noise_map self.tracer = tracer @@ -131,9 +133,12 @@ def magnifications_at_positions(self) -> aa.ArrayIrregular: use_multi_plane=use_multi_plane, plane_j=plane_j, ) - return abs( - od.magnification_2d_via_hessian_from(grid=self.positions, xp=self._xp) + magnifications = od.magnification_2d_via_hessian_from( + grid=self.positions, xp=self._xp ) + if self.use_jax: + magnifications = aa.ArrayIrregular(values=magnifications) + return abs(magnifications) @property def source_plane_coordinate(self) -> Tuple[float, float]: