diff --git a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py index e2c8b27d..f228d863 100644 --- a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py +++ b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py @@ -554,6 +554,7 @@ class InterferometerSparseOperator: batch_size: int w_dtype: "jax.numpy.dtype" Khat: "jax.Array" # (2y, 2x), complex + col_offsets: "jax.Array" # (batch_size,) int32 """ Cached FFT operator state for fast interferometer curvature-matrix assembly. @@ -672,168 +673,120 @@ def from_nufft_precision_operator( batch_size=int(batch_size), w_dtype=nufft_precision_operator.dtype, Khat=Khat, + col_offsets=jnp.arange(int(batch_size), dtype=jnp.int32), ) - def curvature_matrix_via_sparse_operator_from( - self, - pix_indexes_for_sub_slim_index: np.ndarray, - pix_weights_for_sub_slim_index: np.ndarray, - pix_pixels: int, - fft_index_for_masked_pixel: np.ndarray, - ): + def apply_operator(self, Fbatch_flat): """ - Assemble the curvature matrix C = Aᵀ W A using sparse triplets and the FFT W~ operator. - - This method computes the mapper (pixelization) curvature matrix without - forming a dense mapping matrix. Instead, it uses fixed-length mapping - arrays (pixel indexes + weights per masked pixel) which define a sparse - mapping operator A in COO-like form. - - Algorithm outline - ----------------- - Let S be the number of source pixels and M be the number of rectangular - real-space pixels. - - 1) Build a fixed-length COO stream from the mapping arrays: - rows_rect[k] : rectangular pixel index (0..M-1) - cols[k] : source pixel index (0..S-1) - vals[k] : mapping weight - Invalid mappings (cols < 0 or cols >= S) are masked out. - - 2) Process source-pixel columns in blocks of width `batch_size`: - - Scatter the block’s source columns into a dense (M, batch_size) array F. - - Apply the W~ operator by FFT: - G = apply_W(F) - - Project back with Aᵀ via segmented reductions: - C[:, start:start+B] = Aᵀ G - - 3) Symmetrize the result: - C <- 0.5 * (C + Cᵀ) + Apply the interferometer W~ operator to a batch of vectors. + + Given an input matrix of shape (M, B) on the rectangular real-space + grid (M = y_shape * x_shape), this method computes + + G = W~ Fbatch_flat + + via FFT-based convolution with the cached `Khat` kernel: + + apply_W(F) = Re( IFFT( FFT(F_pad) * Khat ) )[:y, :x] + + where `F_pad` is the (2y, 2x) zero-padded version of `F`. Parameters ---------- - pix_indexes_for_sub_slim_index - Integer array of shape (M_masked, Pmax). - For each masked (slim) image pixel, stores the source-pixel indices - involved in the interpolation / mapping stencil. Invalid entries - should be set to -1. - pix_weights_for_sub_slim_index - Floating array of shape (M_masked, Pmax). - Weights corresponding to `pix_indexes_for_sub_slim_index`. - These should already include any oversampling normalisation (e.g. - sub-pixel fractions) required by the mapper. - pix_pixels - Number of source pixels, S. - fft_index_for_masked_pixel - Integer array of shape (M_masked,). - Maps each masked (slim) image pixel index to its corresponding - rectangular-grid flat index (0..M-1). This embeds the masked pixel - ordering into the FFT-friendly rectangular grid. + Fbatch_flat + Array of shape (M, B) representing B vectors on the rectangular grid. Returns ------- - jax.Array - Curvature matrix of shape (S, S), symmetric. + ndarray + Array of shape (M, B) equal to W~ applied to the batch. + """ + import jax.numpy as jnp + + y_shape, x_shape = self.y_shape, self.x_shape + M = y_shape * x_shape + Khat = self.Khat - Notes - ----- - - The inner computation is written in JAX and is intended to be jitted. - For best performance, keep `batch_size` fixed (static) across calls. - - Choosing `batch_size` as a divisor of S avoids a smaller tail block, - but correctness does not require that if the implementation masks the tail. - - This method uses FFTs on padded (2y, 2x) arrays; memory use scales with - batch_size and grid size. + B = Fbatch_flat.shape[1] + F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape)) + F_pad = jnp.pad(F_img, ((0, 0), (0, y_shape), (0, x_shape))) + Fhat = jnp.fft.fft2(F_pad) + Ghat = Fhat * Khat[None, :, :] + G_pad = jnp.fft.ifft2(Ghat) + G = jnp.real(G_pad[:, :y_shape, :x_shape]) + return G.reshape((B, M)).T + + def curvature_matrix_diag_from(self, rows, cols, vals, *, S: int): """ + Compute the diagonal (mapper-mapper) curvature matrix block F = Aᵀ W~ A. + + This method mirrors `ImagingSparseOperator.curvature_matrix_diag_from` + and is the structural counterpart for the interferometer W~ operator. + + Given a sparse mapping operator A in COO triplet form (rows, cols, vals) + with `S` source pixels, it computes + + F = Aᵀ W~ A + in column blocks of width `batch_size`: + + 1) Assemble Fbatch = A[:, start:start+B] on the rectangular grid via scatter-add. + 2) Apply W~ to the block via FFT: Gbatch = W~(Fbatch). + 3) Project back with Aᵀ via segment_sum over `cols`. + + Parameters + ---------- + rows, cols, vals + COO triplets encoding the sparse mapping operator A. + - `rows`: rectangular-grid pixel indices (flat) in [0, M), shape (nnz,) + - `cols`: source pixel indices in [0, S), shape (nnz,) + - `vals`: mapping weights (interpolation + any sub-fraction normalisation), + shape (nnz,) + These should already be produced by `mapper.sparse_triplets_curvature`. + S + Number of source pixels / parameters for this mapper. + + Returns + ------- + ndarray + Curvature matrix of shape (S, S), symmetric. + """ import jax.numpy as jnp + from jax import lax from jax.ops import segment_sum - # ------------------------- - # Pull static quantities from state - # ------------------------- - y_shape = self.y_shape - x_shape = self.x_shape + rows = jnp.asarray(rows, dtype=jnp.int32) + cols = jnp.asarray(cols, dtype=jnp.int32) + vals = jnp.asarray(vals, dtype=jnp.float64) + M = self.M - batch_size = self.batch_size - Khat = self.Khat - w_dtype = self.w_dtype - - # ------------------------- - # Basic shape checks (NumPy side, safe) - # ------------------------- - M_masked, Pmax = pix_indexes_for_sub_slim_index.shape - S = int(pix_pixels) - - # ------------------------- - # JAX core (unchanged COO logic) - # ------------------------- - def _curvature_rect_jax( - pix_idx: jnp.ndarray, # (M_masked, Pmax) - pix_wts: jnp.ndarray, # (M_masked, Pmax) - rect_map: jnp.ndarray, # (M_masked,) - ) -> jnp.ndarray: - rect_map = jnp.asarray(rect_map) - - nnz_full = M_masked * Pmax - - # Flatten mapping arrays into a fixed-length COO stream - rows_mask = jnp.repeat( - jnp.arange(M_masked, dtype=jnp.int32), Pmax - ) # (nnz_full,) - cols = pix_idx.reshape((nnz_full,)).astype(jnp.int32) - vals = pix_wts.reshape((nnz_full,)).astype(w_dtype) - - # Validity mask - valid = (cols >= 0) & (cols < S) - - # Embed masked rows into rectangular rows - rows_rect = rect_map[rows_mask].astype(jnp.int32) - - # Make cols / vals safe - cols_safe = jnp.where(valid, cols, 0) - vals_safe = jnp.where(valid, vals, 0.0) - - def apply_operator_fft_batch(Fbatch_flat: jnp.ndarray) -> jnp.ndarray: - B = Fbatch_flat.shape[1] - F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape)) - F_pad = jnp.pad( - F_img, ((0, 0), (0, y_shape), (0, x_shape)) - ) # (B,2y,2x) - Fhat = jnp.fft.fft2(F_pad) - Ghat = Fhat * Khat[None, :, :] - G_pad = jnp.fft.ifft2(Ghat) - G = jnp.real(G_pad[:, :y_shape, :x_shape]) - return G.reshape((B, M)).T # (M,B) - - def compute_block(start_col: int) -> jnp.ndarray: - in_block = (cols_safe >= start_col) & ( - cols_safe < start_col + batch_size - ) - in_use = valid & in_block + B = self.batch_size - bc = jnp.where(in_use, cols_safe - start_col, 0).astype(jnp.int32) - v = jnp.where(in_use, vals_safe, 0.0) + n_blocks = (S + B - 1) // B + S_pad = n_blocks * B - Fbatch = jnp.zeros((M, batch_size), dtype=w_dtype) - Fbatch = Fbatch.at[rows_rect, bc].add(v) + C0 = jnp.zeros((S, S_pad), dtype=jnp.float64) - Gbatch = apply_operator_fft_batch(Fbatch) - G_at_rows = Gbatch[rows_rect, :] + def body(block_i, C): + start = block_i * B - contrib = vals_safe[:, None] * G_at_rows - return segment_sum(contrib, cols_safe, num_segments=S) + in_block = (cols >= start) & (cols < (start + B)) + bc = jnp.where(in_block, cols - start, 0).astype(jnp.int32) + v = jnp.where(in_block, vals, 0.0) - # Assemble curvature - C = jnp.zeros((S, S), dtype=w_dtype) - for start in range(0, S, batch_size): - Cblock = compute_block(start) - width = min(batch_size, S - start) - C = C.at[:, start : start + width].set(Cblock[:, :width]) + F = jnp.zeros((M, B), dtype=jnp.float64) + F = F.at[rows, bc].add(v) - return 0.5 * (C + C.T) + G = self.apply_operator(F) # (M, B) - return _curvature_rect_jax( - pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index, - fft_index_for_masked_pixel, - ) + contrib = vals[:, None] * G[rows, :] + Cblock = segment_sum(contrib, cols, num_segments=S) # (S, B) + + width = jnp.minimum(B, jnp.maximum(0, S - start)) + Cblock = Cblock * (self.col_offsets < width)[None, :] + + return lax.dynamic_update_slice(C, Cblock, (0, start)) + + C_pad = lax.fori_loop(0, n_blocks, body, C0) + C = C_pad[:, :S] + return 0.5 * (C + C.T) diff --git a/autoarray/inversion/inversion/interferometer/sparse.py b/autoarray/inversion/inversion/interferometer/sparse.py index 4d5c66c9..5a4a0b45 100644 --- a/autoarray/inversion/inversion/interferometer/sparse.py +++ b/autoarray/inversion/inversion/interferometer/sparse.py @@ -7,13 +7,11 @@ AbstractInversionInterferometer, ) from autoarray.inversion.linear_obj.linear_obj import LinearObj -from autoarray.inversion.mesh.mesh.delaunay import Delaunay +from autoarray.inversion.mappers import mapper_util from autoarray.settings import Settings from autoarray.inversion.mappers.abstract import Mapper from autoarray.structures.visibilities import Visibilities -from autoarray.inversion.inversion.interferometer import inversion_interferometer_util - class InversionInterferometerSparse(AbstractInversionInterferometer): def __init__( @@ -99,35 +97,27 @@ def curvature_matrix_diag(self) -> np.ndarray: This function computes the diagonal terms of F using the sparse linear algebra formalism. """ - mapper = self.cls_list_from(cls=Mapper)[0] - # The interferometer sparse-operator curvature path - # (``InterferometerSparseOperator.curvature_matrix_via_sparse_operator_from``) - # has only been validated against ``Rectangular*`` meshes (single source - # pixel per image pixel, weight=1). When given a ``Delaunay`` mapper - # (three source pixels per image pixel via barycentric interpolation, - # weights summing to 1) the returned curvature matrix disagrees with the - # mapping path by ~34% Frobenius and the regularized matrix loses - # positive-definiteness, raising a numpy ``LinAlgError`` at the Cholesky - # call site in ``Inversion.log_det_curvature_reg_matrix_term``. Guard - # rather than silently mis-computing. - if isinstance(mapper.mesh, Delaunay): - raise NotImplementedError( - "Interferometer.apply_sparse_operator() is not implemented for " - "Delaunay-mesh pixelizations: the sparse curvature math has only " - "been validated against Rectangular meshes (Pmax=1, weight=1) " - "and is structurally wrong for barycentric-interpolated mappers " - "(Pmax=3). For Delaunay interferometer fits, use the plain DFT " - "or NUFFT path (i.e. omit the apply_sparse_operator step). " - "Tracking issue: https://github.com/PyAutoLabs/PyAutoArray/issues/314" - ) + # The interferometer W~ operator lives on the unmasked-extent rectangular + # grid (shape_native_masked_pixels), not the full native grid used by + # the imaging path. Build sparse triplets with extent-flat row indices + # so they match the operator's (M = extent_y * extent_x, B) scatter buffer. + rows, cols, vals = mapper_util.sparse_triplets_from( + pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index, + pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index, + slim_index_for_sub=mapper.slim_index_for_sub_slim_index, + fft_index_for_masked_pixel=self.mask.extent_index_for_masked_pixel, + sub_fraction_slim=mapper.over_sampler.sub_fraction.array, + return_rows_slim=False, + xp=self._xp, + ) - return self.dataset.sparse_operator.curvature_matrix_via_sparse_operator_from( - pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index, - pix_pixels=self.linear_obj_list[0].params, - fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel, + return self.dataset.sparse_operator.curvature_matrix_diag_from( + rows=rows, + cols=cols, + vals=vals, + S=mapper.params, ) @property diff --git a/autoarray/mask/mask_2d.py b/autoarray/mask/mask_2d.py index 66c08749..59c77204 100644 --- a/autoarray/mask/mask_2d.py +++ b/autoarray/mask/mask_2d.py @@ -712,6 +712,37 @@ def fft_index_for_masked_pixel(self) -> np.ndarray: # Convert (y, x) coordinates to flat row-major indices return (ys * width + xs).astype(np.int32) + @cached_property + def extent_index_for_masked_pixel(self) -> np.ndarray: + """ + Return a mapping from masked-pixel (slim) indices to flat indices on + the *unmasked-extent* rectangular FFT grid. + + The unmasked extent is the bounding box of unmasked pixels + (``shape_native_masked_pixels``). This index is the interferometer + counterpart of `fft_index_for_masked_pixel`, which uses the full native + grid: the interferometer W~ kernel is computed on the (extent_y, + extent_x) grid because it is translation-invariant and only the offsets + between pairs of unmasked pixels matter — the surrounding masked region + contributes nothing. + + Returns + ------- + np.ndarray + A 1D array of shape (N_unmasked,) of int32 values in + ``[0, extent_y * extent_x)``, suitable as row indices into the + (extent_y * extent_x, batch) scatter buffer used by + ``InterferometerSparseOperator.curvature_matrix_diag_from``. + """ + ys, xs = np.where(~self) + if ys.size == 0: + return np.zeros((0,), dtype=np.int32) + + y0, x0 = int(np.min(ys)), int(np.min(xs)) + extent_y, extent_x = self.shape_native_masked_pixels + width = int(extent_x) + return ((ys - y0) * width + (xs - x0)).astype(np.int32) + def trimmed_array_from(self, padded_array, image_shape) -> Array2D: """ Map a padded 1D array of values to its original 2D array, trimming all edge values. diff --git a/test_autoarray/inversion/inversion/interferometer/test_interferometer.py b/test_autoarray/inversion/inversion/interferometer/test_interferometer.py index d91fb07e..e9f08c8b 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_interferometer.py +++ b/test_autoarray/inversion/inversion/interferometer/test_interferometer.py @@ -64,21 +64,7 @@ def test__fast_chi_squared( assert inversion.fast_chi_squared == pytest.approx(chi_squared, 1.0e-4) -def test__apply_sparse_operator__delaunay_mapper__raises_not_implemented(): - """``InversionInterferometerSparse.curvature_matrix_diag`` must raise a - ``NotImplementedError`` rather than silently mis-computing when paired - with a Delaunay mapper. - - The interferometer sparse-operator curvature path - (``InterferometerSparseOperator.curvature_matrix_via_sparse_operator_from``) - was only validated against ``Rectangular*`` meshes (single source pixel per - image pixel, weight 1). On a Delaunay mapper (three source pixels per image - pixel via barycentric interpolation, weights summing to 1) the returned - curvature matrix disagrees with the mapping path by ~34% Frobenius norm and - the regularized matrix loses positive-definiteness, raising a numpy - ``LinAlgError`` deep inside ``Inversion.log_det_curvature_reg_matrix_term``. - The guard at the entry point catches the mis-use early with a clear message. - """ +def test__curvature_matrix__interferometer_sparse_operator__delaunay__identical_to_mapping(): mask = aa.Mask2D( mask=[ [True, True, True, True, True, True, True], @@ -123,12 +109,19 @@ def test__apply_sparse_operator__delaunay_mapper__raises_not_implemented(): real_space_mask=mask, transformer_class=aa.TransformerDFT, ) + dataset_sparse = dataset.apply_sparse_operator(use_jax=False) - inversion = aa.Inversion( + inversion_sparse = aa.Inversion( dataset=dataset_sparse, linear_obj_list=[mapper], ) - with pytest.raises(NotImplementedError, match=r"Delaunay"): - _ = inversion.curvature_matrix + inversion_mapping = aa.Inversion( + dataset=dataset, + linear_obj_list=[mapper], + ) + + assert inversion_sparse.curvature_matrix == pytest.approx( + inversion_mapping.curvature_matrix, 1.0e-4 + )