Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ class InterferometerSparseOperator:
batch_size: int
w_dtype: "jax.numpy.dtype"
Khat: "jax.Array" # (2y, 2x), complex
col_offsets: "jax.Array" # (batch_size,) int32
"""
Cached FFT operator state for fast interferometer curvature-matrix assembly.

Expand Down Expand Up @@ -672,168 +673,120 @@ def from_nufft_precision_operator(
batch_size=int(batch_size),
w_dtype=nufft_precision_operator.dtype,
Khat=Khat,
col_offsets=jnp.arange(int(batch_size), dtype=jnp.int32),
)

def curvature_matrix_via_sparse_operator_from(
self,
pix_indexes_for_sub_slim_index: np.ndarray,
pix_weights_for_sub_slim_index: np.ndarray,
pix_pixels: int,
fft_index_for_masked_pixel: np.ndarray,
):
def apply_operator(self, Fbatch_flat):
"""
Assemble the curvature matrix C = Aᵀ W A using sparse triplets and the FFT W~ operator.

This method computes the mapper (pixelization) curvature matrix without
forming a dense mapping matrix. Instead, it uses fixed-length mapping
arrays (pixel indexes + weights per masked pixel) which define a sparse
mapping operator A in COO-like form.

Algorithm outline
-----------------
Let S be the number of source pixels and M be the number of rectangular
real-space pixels.

1) Build a fixed-length COO stream from the mapping arrays:
rows_rect[k] : rectangular pixel index (0..M-1)
cols[k] : source pixel index (0..S-1)
vals[k] : mapping weight
Invalid mappings (cols < 0 or cols >= S) are masked out.

2) Process source-pixel columns in blocks of width `batch_size`:
- Scatter the block’s source columns into a dense (M, batch_size) array F.
- Apply the W~ operator by FFT:
G = apply_W(F)
- Project back with Aᵀ via segmented reductions:
C[:, start:start+B] = Aᵀ G

3) Symmetrize the result:
C <- 0.5 * (C + Cᵀ)
Apply the interferometer W~ operator to a batch of vectors.

Given an input matrix of shape (M, B) on the rectangular real-space
grid (M = y_shape * x_shape), this method computes

G = W~ Fbatch_flat

via FFT-based convolution with the cached `Khat` kernel:

apply_W(F) = Re( IFFT( FFT(F_pad) * Khat ) )[:y, :x]

where `F_pad` is the (2y, 2x) zero-padded version of `F`.

Parameters
----------
pix_indexes_for_sub_slim_index
Integer array of shape (M_masked, Pmax).
For each masked (slim) image pixel, stores the source-pixel indices
involved in the interpolation / mapping stencil. Invalid entries
should be set to -1.
pix_weights_for_sub_slim_index
Floating array of shape (M_masked, Pmax).
Weights corresponding to `pix_indexes_for_sub_slim_index`.
These should already include any oversampling normalisation (e.g.
sub-pixel fractions) required by the mapper.
pix_pixels
Number of source pixels, S.
fft_index_for_masked_pixel
Integer array of shape (M_masked,).
Maps each masked (slim) image pixel index to its corresponding
rectangular-grid flat index (0..M-1). This embeds the masked pixel
ordering into the FFT-friendly rectangular grid.
Fbatch_flat
Array of shape (M, B) representing B vectors on the rectangular grid.

Returns
-------
jax.Array
Curvature matrix of shape (S, S), symmetric.
ndarray
Array of shape (M, B) equal to W~ applied to the batch.
"""
import jax.numpy as jnp

y_shape, x_shape = self.y_shape, self.x_shape
M = y_shape * x_shape
Khat = self.Khat

Notes
-----
- The inner computation is written in JAX and is intended to be jitted.
For best performance, keep `batch_size` fixed (static) across calls.
- Choosing `batch_size` as a divisor of S avoids a smaller tail block,
but correctness does not require that if the implementation masks the tail.
- This method uses FFTs on padded (2y, 2x) arrays; memory use scales with
batch_size and grid size.
B = Fbatch_flat.shape[1]
F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape))
F_pad = jnp.pad(F_img, ((0, 0), (0, y_shape), (0, x_shape)))
Fhat = jnp.fft.fft2(F_pad)
Ghat = Fhat * Khat[None, :, :]
G_pad = jnp.fft.ifft2(Ghat)
G = jnp.real(G_pad[:, :y_shape, :x_shape])
return G.reshape((B, M)).T

def curvature_matrix_diag_from(self, rows, cols, vals, *, S: int):
"""
Compute the diagonal (mapper-mapper) curvature matrix block F = Aᵀ W~ A.

This method mirrors `ImagingSparseOperator.curvature_matrix_diag_from`
and is the structural counterpart for the interferometer W~ operator.

Given a sparse mapping operator A in COO triplet form (rows, cols, vals)
with `S` source pixels, it computes

F = Aᵀ W~ A

in column blocks of width `batch_size`:

1) Assemble Fbatch = A[:, start:start+B] on the rectangular grid via scatter-add.
2) Apply W~ to the block via FFT: Gbatch = W~(Fbatch).
3) Project back with Aᵀ via segment_sum over `cols`.

Parameters
----------
rows, cols, vals
COO triplets encoding the sparse mapping operator A.
- `rows`: rectangular-grid pixel indices (flat) in [0, M), shape (nnz,)
- `cols`: source pixel indices in [0, S), shape (nnz,)
- `vals`: mapping weights (interpolation + any sub-fraction normalisation),
shape (nnz,)
These should already be produced by `mapper.sparse_triplets_curvature`.
S
Number of source pixels / parameters for this mapper.

Returns
-------
ndarray
Curvature matrix of shape (S, S), symmetric.
"""
import jax.numpy as jnp
from jax import lax
from jax.ops import segment_sum

# -------------------------
# Pull static quantities from state
# -------------------------
y_shape = self.y_shape
x_shape = self.x_shape
rows = jnp.asarray(rows, dtype=jnp.int32)
cols = jnp.asarray(cols, dtype=jnp.int32)
vals = jnp.asarray(vals, dtype=jnp.float64)

M = self.M
batch_size = self.batch_size
Khat = self.Khat
w_dtype = self.w_dtype

# -------------------------
# Basic shape checks (NumPy side, safe)
# -------------------------
M_masked, Pmax = pix_indexes_for_sub_slim_index.shape
S = int(pix_pixels)

# -------------------------
# JAX core (unchanged COO logic)
# -------------------------
def _curvature_rect_jax(
pix_idx: jnp.ndarray, # (M_masked, Pmax)
pix_wts: jnp.ndarray, # (M_masked, Pmax)
rect_map: jnp.ndarray, # (M_masked,)
) -> jnp.ndarray:
rect_map = jnp.asarray(rect_map)

nnz_full = M_masked * Pmax

# Flatten mapping arrays into a fixed-length COO stream
rows_mask = jnp.repeat(
jnp.arange(M_masked, dtype=jnp.int32), Pmax
) # (nnz_full,)
cols = pix_idx.reshape((nnz_full,)).astype(jnp.int32)
vals = pix_wts.reshape((nnz_full,)).astype(w_dtype)

# Validity mask
valid = (cols >= 0) & (cols < S)

# Embed masked rows into rectangular rows
rows_rect = rect_map[rows_mask].astype(jnp.int32)

# Make cols / vals safe
cols_safe = jnp.where(valid, cols, 0)
vals_safe = jnp.where(valid, vals, 0.0)

def apply_operator_fft_batch(Fbatch_flat: jnp.ndarray) -> jnp.ndarray:
B = Fbatch_flat.shape[1]
F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape))
F_pad = jnp.pad(
F_img, ((0, 0), (0, y_shape), (0, x_shape))
) # (B,2y,2x)
Fhat = jnp.fft.fft2(F_pad)
Ghat = Fhat * Khat[None, :, :]
G_pad = jnp.fft.ifft2(Ghat)
G = jnp.real(G_pad[:, :y_shape, :x_shape])
return G.reshape((B, M)).T # (M,B)

def compute_block(start_col: int) -> jnp.ndarray:
in_block = (cols_safe >= start_col) & (
cols_safe < start_col + batch_size
)
in_use = valid & in_block
B = self.batch_size

bc = jnp.where(in_use, cols_safe - start_col, 0).astype(jnp.int32)
v = jnp.where(in_use, vals_safe, 0.0)
n_blocks = (S + B - 1) // B
S_pad = n_blocks * B

Fbatch = jnp.zeros((M, batch_size), dtype=w_dtype)
Fbatch = Fbatch.at[rows_rect, bc].add(v)
C0 = jnp.zeros((S, S_pad), dtype=jnp.float64)

Gbatch = apply_operator_fft_batch(Fbatch)
G_at_rows = Gbatch[rows_rect, :]
def body(block_i, C):
start = block_i * B

contrib = vals_safe[:, None] * G_at_rows
return segment_sum(contrib, cols_safe, num_segments=S)
in_block = (cols >= start) & (cols < (start + B))
bc = jnp.where(in_block, cols - start, 0).astype(jnp.int32)
v = jnp.where(in_block, vals, 0.0)

# Assemble curvature
C = jnp.zeros((S, S), dtype=w_dtype)
for start in range(0, S, batch_size):
Cblock = compute_block(start)
width = min(batch_size, S - start)
C = C.at[:, start : start + width].set(Cblock[:, :width])
F = jnp.zeros((M, B), dtype=jnp.float64)
F = F.at[rows, bc].add(v)

return 0.5 * (C + C.T)
G = self.apply_operator(F) # (M, B)

return _curvature_rect_jax(
pix_indexes_for_sub_slim_index,
pix_weights_for_sub_slim_index,
fft_index_for_masked_pixel,
)
contrib = vals[:, None] * G[rows, :]
Cblock = segment_sum(contrib, cols, num_segments=S) # (S, B)

width = jnp.minimum(B, jnp.maximum(0, S - start))
Cblock = Cblock * (self.col_offsets < width)[None, :]

return lax.dynamic_update_slice(C, Cblock, (0, start))

C_pad = lax.fori_loop(0, n_blocks, body, C0)
C = C_pad[:, :S]
return 0.5 * (C + C.T)
48 changes: 19 additions & 29 deletions autoarray/inversion/inversion/interferometer/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
AbstractInversionInterferometer,
)
from autoarray.inversion.linear_obj.linear_obj import LinearObj
from autoarray.inversion.mesh.mesh.delaunay import Delaunay
from autoarray.inversion.mappers import mapper_util
from autoarray.settings import Settings
from autoarray.inversion.mappers.abstract import Mapper
from autoarray.structures.visibilities import Visibilities

from autoarray.inversion.inversion.interferometer import inversion_interferometer_util


class InversionInterferometerSparse(AbstractInversionInterferometer):
def __init__(
Expand Down Expand Up @@ -99,35 +97,27 @@ def curvature_matrix_diag(self) -> np.ndarray:

This function computes the diagonal terms of F using the sparse linear algebra formalism.
"""

mapper = self.cls_list_from(cls=Mapper)[0]

# The interferometer sparse-operator curvature path
# (``InterferometerSparseOperator.curvature_matrix_via_sparse_operator_from``)
# has only been validated against ``Rectangular*`` meshes (single source
# pixel per image pixel, weight=1). When given a ``Delaunay`` mapper
# (three source pixels per image pixel via barycentric interpolation,
# weights summing to 1) the returned curvature matrix disagrees with the
# mapping path by ~34% Frobenius and the regularized matrix loses
# positive-definiteness, raising a numpy ``LinAlgError`` at the Cholesky
# call site in ``Inversion.log_det_curvature_reg_matrix_term``. Guard
# rather than silently mis-computing.
if isinstance(mapper.mesh, Delaunay):
raise NotImplementedError(
"Interferometer.apply_sparse_operator() is not implemented for "
"Delaunay-mesh pixelizations: the sparse curvature math has only "
"been validated against Rectangular meshes (Pmax=1, weight=1) "
"and is structurally wrong for barycentric-interpolated mappers "
"(Pmax=3). For Delaunay interferometer fits, use the plain DFT "
"or NUFFT path (i.e. omit the apply_sparse_operator step). "
"Tracking issue: https://github.com/PyAutoLabs/PyAutoArray/issues/314"
)
# The interferometer W~ operator lives on the unmasked-extent rectangular
# grid (shape_native_masked_pixels), not the full native grid used by
# the imaging path. Build sparse triplets with extent-flat row indices
# so they match the operator's (M = extent_y * extent_x, B) scatter buffer.
rows, cols, vals = mapper_util.sparse_triplets_from(
pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index,
pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index,
slim_index_for_sub=mapper.slim_index_for_sub_slim_index,
fft_index_for_masked_pixel=self.mask.extent_index_for_masked_pixel,
sub_fraction_slim=mapper.over_sampler.sub_fraction.array,
return_rows_slim=False,
xp=self._xp,
)

return self.dataset.sparse_operator.curvature_matrix_via_sparse_operator_from(
pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index,
pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index,
pix_pixels=self.linear_obj_list[0].params,
fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel,
return self.dataset.sparse_operator.curvature_matrix_diag_from(
rows=rows,
cols=cols,
vals=vals,
S=mapper.params,
)

@property
Expand Down
31 changes: 31 additions & 0 deletions autoarray/mask/mask_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,37 @@ def fft_index_for_masked_pixel(self) -> np.ndarray:
# Convert (y, x) coordinates to flat row-major indices
return (ys * width + xs).astype(np.int32)

@cached_property
def extent_index_for_masked_pixel(self) -> np.ndarray:
    """
    Map each unmasked (slim) pixel to its flat index on the unmasked-extent grid.

    The unmasked-extent grid is the rectangular bounding box of the unmasked
    pixels; its shape is read from ``shape_native_masked_pixels`` (defined
    elsewhere on this class). This is the interferometer counterpart of
    `fft_index_for_masked_pixel`, which indexes into the full native grid
    instead: the interferometer W~ kernel is translation-invariant, so only
    the offsets between pairs of unmasked pixels matter and the surrounding
    masked region contributes nothing.

    Indices are row-major within the bounding box, i.e.
    ``(y - y_min) * extent_x + (x - x_min)``, and pixels are listed in the
    order returned by ``np.where(~self)``.

    Returns
    -------
    np.ndarray
        A 1D int32 array with one entry per unmasked pixel, suitable as row
        indices into the ``(extent_y * extent_x, batch)`` scatter buffer used
        by ``InterferometerSparseOperator.curvature_matrix_diag_from``. An
        empty int32 array of shape ``(0,)`` is returned when every pixel is
        masked.
    """
    # Unmasked pixels are the False entries, hence the inversion.
    ys, xs = np.where(~self)

    # Fully masked: return early, because the bounding box (and the min()
    # calls below) are undefined without any unmasked pixel.
    if ys.size == 0:
        return np.zeros((0,), dtype=np.int32)

    # Top-left corner of the bounding box; coordinates are shifted by it so
    # the indices are relative to the extent grid rather than the native grid.
    y0 = int(np.min(ys))
    x0 = int(np.min(xs))

    # Only the extent width is needed for row-major flattening.
    _, extent_x = self.shape_native_masked_pixels
    width = int(extent_x)

    return ((ys - y0) * width + (xs - x0)).astype(np.int32)

def trimmed_array_from(self, padded_array, image_shape) -> Array2D:
"""
Map a padded 1D array of values to its original 2D array, trimming all edge values.
Expand Down
Loading
Loading