From cab85c972b7f724168a4aee622045950379289ac Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 10 May 2026 14:01:51 +0100 Subject: [PATCH] perf: batched transform_mapping_matrix in TransformerNUFFT (single nufft2d2 call) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous implementation looped over each source-pixel column in Python, scattering each into a (N_y, N_x) image and calling _forward_native separately. Under jax.jit this fully unrolled into n_src distinct nufft2d2 invocations in one trace, ballooning the JIT graph for pixelization-heavy fits — most visibly the rectangular_dspl JAX-likelihood script, where two source planes with mesh_shape=(30,30) gave 1800 inlined NUFFTs and caused 11+ minute slow-compile warnings (and OOM on 16 GB hosts). The replacement scatters all columns into a single (n_src, N_y, N_x) batched array and calls nufft2d2 once with batched f. nufftax natively supports the batched form. Numerical results unchanged (existing test_transformer.py passes); the JIT graph drops from O(n_src) NUFFTs to a single batched NUFFT call. Co-Authored-By: Claude Opus 4.7 (1M context) --- autoarray/operators/transformer.py | 53 +++++++++++++++++++----------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/autoarray/operators/transformer.py b/autoarray/operators/transformer.py index 3e9e80cd..f161d326 100644 --- a/autoarray/operators/transformer.py +++ b/autoarray/operators/transformer.py @@ -678,28 +678,43 @@ def transform_mapping_matrix(self, mapping_matrix, xp=np): """ Apply the forward NUFFT to each column of a mapping matrix. - Each column is scattered back to the native 2D image grid using the - mask's `slim_to_native_tuple`, then passed through `_forward_native`. + All columns are scattered into a single batched native-shape image + of shape ``(n_src, N_y, N_x)`` and passed through nufft2d2 in one + call (nufft2d2 supports batched ``f``). This avoids the + per-column Python loop that, under ``jax.jit``, would unroll into + ``n_src`` separate NUFFT invocations and blow up the JIT graph + for pixelization-heavy fits (notably double-source-plane). """ - n_uv = self.uv_wavelengths.shape[0] n_src = mapping_matrix.shape[1] - slim_to_native = self.real_space_mask.slim_to_native_tuple - native_shape = self.real_space_mask.shape_native + rows, cols = self.real_space_mask.slim_to_native_tuple + n_y, n_x = self.real_space_mask.shape_native if xp.__name__.startswith("jax"): import jax.numpy as jnp - out = jnp.zeros((n_uv, n_src), dtype=jnp.complex128) - for k in range(n_src): - image_2d = jnp.zeros(native_shape, dtype=mapping_matrix.dtype) - image_2d = image_2d.at[slim_to_native].set(mapping_matrix[:, k]) - vis = self._forward_native(image_2d, xp=xp) - out = out.at[:, k].set(vis) - return out - - out = np.zeros((n_uv, n_src), dtype=np.complex128) - for k in range(n_src): - image_2d = np.zeros(native_shape, dtype=mapping_matrix.dtype) - image_2d[slim_to_native] = mapping_matrix[:, k] - out[:, k] = self._forward_native(image_2d, xp=xp) - return out + mm_T = jnp.asarray(mapping_matrix).T.astype(jnp.complex128) + source_images = jnp.zeros((n_src, n_y, n_x), dtype=jnp.complex128) + source_images = source_images.at[ + jnp.arange(n_src)[:, None], + jnp.asarray(rows)[None, :], + jnp.asarray(cols)[None, :], + ].set(mm_T) + flipped = source_images[:, ::-1, :] + x = jnp.asarray(self._x) + y = jnp.asarray(self._y) + shift = jnp.asarray(self._shift) + # nufft2d2 returns shape (n_trans, M); transpose to (M, n_src). + vis_batched = ( + _nufftax.nufft2d2(x, y, flipped, self.eps, -1) * shift[None, :] + ) + return vis_batched.T + + mm_T = np.asarray(mapping_matrix).T.astype(np.complex128) + source_images = np.zeros((n_src, n_y, n_x), dtype=np.complex128) + source_images[np.arange(n_src)[:, None], rows[None, :], cols[None, :]] = mm_T + flipped = source_images[:, ::-1, :] + vis_batched = ( + _nufftax.nufft2d2(self._x, self._y, flipped, self.eps, -1) + * self._shift[None, :] + ) + return np.array(np.asarray(vis_batched).T)