diff --git a/autolens/point/fit/abstract.py b/autolens/point/fit/abstract.py index 35f26078c..281d197f8 100644 --- a/autolens/point/fit/abstract.py +++ b/autolens/point/fit/abstract.py @@ -74,6 +74,8 @@ def __init__( """ self.name = name + if xp is not np and data._xp is not xp: + data = aa.Grid2DIrregular(values=data.array, xp=xp) self._data = data self._noise_map = noise_map self.tracer = tracer @@ -131,9 +133,12 @@ def magnifications_at_positions(self) -> aa.ArrayIrregular: use_multi_plane=use_multi_plane, plane_j=plane_j, ) - return abs( - od.magnification_2d_via_hessian_from(grid=self.positions, xp=self._xp) + magnifications = od.magnification_2d_via_hessian_from( + grid=self.positions, xp=self._xp ) + if self.use_jax: + magnifications = aa.ArrayIrregular(values=magnifications) + return abs(magnifications) @property def source_plane_coordinate(self) -> Tuple[float, float]: