Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoarray/inversion/regularization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .gaussian_kernel import GaussianKernel
from .exponential_kernel import ExponentialKernel
from .matern_kernel import MaternKernel
from .matern_adaptive_brightness_kernel import MaternAdaptiveBrightnessKernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import annotations
import numpy as np
from typing import TYPE_CHECKING

from autoarray.inversion.regularization.matern_kernel import MaternKernel

if TYPE_CHECKING:
from autoarray.inversion.linear_obj.linear_obj import LinearObj

from autoarray.inversion.regularization.matern_kernel import matern_kernel


def matern_cov_matrix_from(
scale: float,
nu: float,
pixel_points,
weights=None,
xp=np,
):
"""
Construct the regularization covariance matrix (N x N) using a Matérn kernel,
optionally modulated by per-pixel weights.

If `weights` is provided (shape [N]), the covariance is:
C_ij = K(d_ij; scale, nu) * w_i * w_j
with a small diagonal jitter added for numerical stability.

Parameters
----------
scale
Typical correlation length of the Matérn kernel.
nu
Smoothness parameter of the Matérn kernel.
pixel_points
Array-like of shape [N, 2] with (y, x) coordinates (or any 2D coords; only distances matter).
weights
Optional array-like of shape [N]. If None, treated as all ones.
xp
Backend (numpy or jax.numpy).

Returns
-------
covariance_matrix
Array of shape [N, N].
"""

# --------------------------------
# Pairwise distances (broadcasted)
# --------------------------------
diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2)
d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N)

# --------------------------------
# Base Matérn covariance
# --------------------------------
covariance_matrix = matern_kernel(d_ij, l=scale, v=nu, xp=xp) # (N, N)

# --------------------------------
# Apply weights: C_ij *= w_i * w_j
# (broadcasted outer product, JAX-safe)
# --------------------------------
if weights is not None:
w = xp.asarray(weights)
# Ensure shape (N,) -> outer product (N,1)*(1,N) -> (N,N)
covariance_matrix = covariance_matrix * (w[:, None] * w[None, :])

# --------------------------------
# Add diagonal jitter (JAX-safe)
# --------------------------------
pixels = pixel_points.shape[0]
covariance_matrix = covariance_matrix + 1e-8 * xp.eye(pixels)

return covariance_matrix


class MaternAdaptiveBrightnessKernel(MaternKernel):
def __init__(
self,
coefficient: float = 1.0,
scale: float = 1.0,
nu: float = 0.5,
rho: float = 1.0,
):
"""
Regularization which uses a Matern smoothing kernel to regularize the solution with regularization weights
that adapt to the brightness of the source being reconstructed.

For this regularization scheme, every pixel is regularized with every other pixel. This contrasts many other
schemes, where regularization is based on neighboring (e.g. do the pixels share a Delaunay edge?) or computing
derivatives around the center of the pixel (where nearby pixels are regularization locally in similar ways).

This makes the regularization matrix fully dense and therefore may change the run times of the solution.
It also leads to more overall smoothing which can lead to more stable linear inversions.

For the weighted regularization scheme, each pixel is given an 'effective regularization weight', which is
applied when each set of pixel neighbors are regularized with one another. The motivation of this is that
different regions of a pixelization's mesh require different levels of regularization (e.g., high smoothing where the
no signal is present and less smoothing where it is, see (Nightingale, Dye and Massey 2018)).

This scheme is not used by Vernardos et al. (2022): https://arxiv.org/abs/2202.09378, but it follows
a similar approach.

A full description of regularization and this matrix can be found in the parent `AbstractRegularization` class.

Parameters
----------
coefficient
The regularization coefficient which controls the degree of smooth of the inversion reconstruction.
scale
The typical scale (correlation length) of the Matérn regularization kernel.
nu
Controls the smoothness (differentiability) of the Matérn kernel; ``nu=0.5`` corresponds to an
exponential (Ornstein–Uhlenbeck) kernel, while a Gaussian covariance is obtained in the limit
as ``nu`` approaches infinity.
rho
Controls how strongly the kernel weights adapt to pixel brightness. Larger values make bright pixels
receive significantly higher weights (and faint pixels lower weights), while smaller values produce a
more uniform weighting. Typical values are of order unity (e.g. 0.5–2.0).
"""
super().__init__(coefficient=coefficient, scale=scale, nu=nu)
self.rho = rho

def covariance_kernel_weights_from(
self, linear_obj: LinearObj, xp=np
) -> np.ndarray:
"""
Returns per-pixel kernel weights that adapt to the reconstructed pixel brightness.
"""
# Assumes linear_obj.pixel_signals_from is xp-aware elsewhere in the codebase.
pixel_signals = linear_obj.pixel_signals_from(signal_scale=1.0, xp=xp)

max_signal = xp.max(pixel_signals)
max_signal = xp.maximum(max_signal, 1e-8) # avoid divide-by-zero (JAX-safe)

return xp.exp(-self.rho * (1.0 - pixel_signals / max_signal))

def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
kernel_weights = self.covariance_kernel_weights_from(
linear_obj=linear_obj, xp=xp
)

# Follow the xp pattern used in the Matérn kernel module (often `.array` for grids).
pixel_points = linear_obj.source_plane_mesh_grid.array

covariance_matrix = matern_cov_matrix_from(
scale=self.scale,
pixel_points=pixel_points,
nu=self.nu,
weights=kernel_weights,
xp=xp,
)

return self.coefficient * xp.linalg.inv(covariance_matrix)

def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
"""
Returns the regularization weights of this regularization scheme.
"""
return 1.0 / self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp)
20 changes: 17 additions & 3 deletions autoarray/inversion/regularization/matern_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ def kv_xp(v, z, xp=np):
)


def gamma_xp(x, xp=np):
"""
XP-compatible Gamma(x).
"""
if xp is np:
import scipy.special as sc

return sc.gamma(x)
else:
import jax.scipy.special as jsp

return jsp.gamma(x)


def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):
"""
XP-compatible Matérn kernel.
Expand All @@ -55,7 +69,7 @@ def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):

z = xp.sqrt(2.0 * v) * r / l

part1 = 2.0 ** (1.0 - v) / math.gamma(v) # scalar constant
part1 = 2.0 ** (1.0 - v) / gamma_xp(v, xp) # scalar constant
part2 = z**v
part3 = kv_xp(v, z, xp)

Expand Down Expand Up @@ -141,8 +155,8 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5
"""

self.coefficient = coefficient
self.scale = float(scale)
self.nu = float(nu)
self.scale = scale
self.nu = nu
super().__init__()

def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
Expand Down
12 changes: 12 additions & 0 deletions autoarray/preloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
linear_light_profile_blurred_mapping_matrix=None,
use_voronoi_areas: bool = True,
areas_factor: float = 0.5,
skip_areas: bool = False,
Copy link

Copilot AI Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The newly added skip_areas parameter is not documented in the Parameters section of the Preloads docstring. Please add documentation explaining what this parameter controls, when it should be set to True, and how it affects the Delaunay triangulation computation (specifically that it skips Voronoi area calculations and split point computations).

Copilot uses AI. Check for mistakes.
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback

):
"""
Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance
Expand Down Expand Up @@ -81,6 +82,16 @@ def __init__(
inversion, with the other component being the pixelization's pixels. These are fixed when the lens light
is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but
the intensity values will still be solved for during the inversion.
use_voronoi_areas
Whether to use Voronoi areas during Delaunay triangulation. When True, computes areas for each Voronoi
region which can be used in certain regularization schemes. Default is True.
areas_factor
Factor used to scale the Voronoi areas during split point computation. Default is 0.5.
skip_areas
Whether to skip Voronoi area calculations and split point computations during Delaunay triangulation.
When True, the Delaunay interface returns only the minimal set of outputs (points, simplices, mappings)
without computing split_points or splitted_mappings. This optimization is useful for regularization
schemes like Matérn kernels that don't require area-based calculations. Default is False.
"""
self.mapper_indices = None
self.source_pixel_zeroed_indices = None
Expand Down Expand Up @@ -123,3 +134,4 @@ def __init__(

self.use_voronoi_areas = use_voronoi_areas
self.areas_factor = areas_factor
self.skip_areas = skip_areas
2 changes: 1 addition & 1 deletion autoarray/structures/arrays/kernel_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def convolved_image_from(
image,
blurring_image,
jax_method="direct",
use_mixed_precision : bool = False,
use_mixed_precision: bool = False,
xp=np,
):
"""
Expand Down
114 changes: 101 additions & 13 deletions autoarray/structures/mesh/delaunay_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,67 @@ def pix_indexes_for_sub_slim_index_delaunay_from(
return out


def scipy_delaunay_matern(points_np, query_points_np):
"""
Minimal SciPy Delaunay callback for Matérn regularization.

Returns only what’s needed for mapping:
- points (tri.points)
- simplices_padded
- mappings: integer array of pixel indices for each query point,
typically of shape (Q, 3), where each row gives the indices of the
Delaunay mesh vertices ("pixels") associated with that query point.
"""

max_simplices = 2 * points_np.shape[0]

# --- Delaunay mesh ---
tri = Delaunay(points_np)

points = tri.points.astype(points_np.dtype)
simplices = tri.simplices.astype(np.int32)

# --- Pad simplices to fixed shape for JAX ---
simplices_padded = -np.ones((max_simplices, 3), dtype=np.int32)
simplices_padded[: simplices.shape[0]] = simplices

# --- find_simplex for query points ---
simplex_idx = tri.find_simplex(query_points_np).astype(np.int32) # (Q,)

mappings = pix_indexes_for_sub_slim_index_delaunay_from(
source_plane_data_grid=query_points_np,
simplex_index_for_sub_slim_index=simplex_idx,
pix_indexes_for_simplex_index=simplices,
delaunay_points=points_np,
)

return points, simplices_padded, mappings


def jax_delaunay_matern(points, query_points):
"""
JAX wrapper using pure_callback to run SciPy Delaunay on CPU,
returning only the minimal outputs needed for Matérn usage.
"""
import jax
import jax.numpy as jnp

N = points.shape[0]
Q = query_points.shape[0]
max_simplices = 2 * N

points_shape = jax.ShapeDtypeStruct((N, 2), points.dtype)
simplices_padded_shape = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
mappings_shape = jax.ShapeDtypeStruct((Q, 3), jnp.int32)

return jax.pure_callback(
lambda pts, qpts: scipy_delaunay_matern(np.asarray(pts), np.asarray(qpts)),
(points_shape, simplices_padded_shape, mappings_shape),
points,
query_points,
)


class DelaunayInterface:

def __init__(
Expand Down Expand Up @@ -466,33 +527,60 @@ def delaunay(self) -> "scipy.spatial.Delaunay":

use_voronoi_areas = self.preloads.use_voronoi_areas
areas_factor = self.preloads.areas_factor
skip_areas = self.preloads.skip_areas

else:

use_voronoi_areas = True
areas_factor = 0.5
skip_areas = False

if self._xp.__name__.startswith("jax"):
if not skip_areas:

import jax.numpy as jnp
if self._xp.__name__.startswith("jax"):

points, simplices, mappings, split_points, splitted_mappings = jax_delaunay(
points=self.mesh_grid_xy,
query_points=self._source_plane_data_grid_over_sampled,
use_voronoi_areas=use_voronoi_areas,
areas_factor=areas_factor,
)
import jax.numpy as jnp

points, simplices, mappings, split_points, splitted_mappings = (
jax_delaunay(
points=self.mesh_grid_xy,
query_points=self._source_plane_data_grid_over_sampled,
use_voronoi_areas=use_voronoi_areas,
areas_factor=areas_factor,
)
)

else:

points, simplices, mappings, split_points, splitted_mappings = (
scipy_delaunay(
points_np=self.mesh_grid_xy,
query_points_np=self._source_plane_data_grid_over_sampled,
use_voronoi_areas=use_voronoi_areas,
areas_factor=areas_factor,
)
)

else:

points, simplices, mappings, split_points, splitted_mappings = (
scipy_delaunay(
if self._xp.__name__.startswith("jax"):

import jax.numpy as jnp

points, simplices, mappings = jax_delaunay_matern(
points=self.mesh_grid_xy,
query_points=self._source_plane_data_grid_over_sampled,
)

else:

points, simplices, mappings = scipy_delaunay_matern(
points_np=self.mesh_grid_xy,
query_points_np=self._source_plane_data_grid_over_sampled,
use_voronoi_areas=use_voronoi_areas,
areas_factor=areas_factor,
)
)

split_points = None
splitted_mappings = None

return DelaunayInterface(
points=points,
Expand Down
Loading
Loading