Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions autoarray/inversion/pixelization/mappers/mapper_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import numpy as np
from scipy.spatial import cKDTree
from typing import Tuple
Expand Down Expand Up @@ -144,6 +145,97 @@ def data_slim_to_pixelization_unique_from(
return data_to_pix_unique, data_weights, pix_lengths


def rectangular_mappings_weights_via_interpolation_from(
shape_native: Tuple[int, int],
source_plane_data_grid: jnp.ndarray,
source_plane_mesh_grid: jnp.ndarray,
):
"""
Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid.

Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function
determines for each irregular point:
- the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and
- the bilinear interpolation weights with respect to those pixels.

The function supports JAX and is compatible with JIT compilation.

Parameters
----------
shape_native
The shape (Ny, Nx) of the original rectangular mesh grid before flattening.
source_plane_data_grid
The irregular grid of (y, x) points to interpolate.
source_plane_mesh_grid
The flattened regular rectangular mesh grid of (y, x) coordinates.

Returns
-------
mappings : jnp.ndarray of shape (N, 4)
Indices of the four nearest rectangular mesh pixels in the flattened mesh grid.
Order is: top-left, top-right, bottom-left, bottom-right.
weights : jnp.ndarray of shape (N, 4)
Bilinear interpolation weights corresponding to the four nearest mesh pixels.

Notes
-----
- Assumes the mesh grid is uniformly spaced.
- The weights sum to 1 for each irregular point.
- Uses bilinear interpolation in the (y, x) coordinate system.
"""
source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2)

# Assume mesh is shaped (Ny, Nx, 2)
Ny, Nx = source_plane_mesh_grid.shape[:2]

# Get mesh spacings and lower corner
y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,)
x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,)

dy = y_coords[1] - y_coords[0]
dx = x_coords[1] - x_coords[0]

y_min = y_coords[0]
x_min = x_coords[0]

# shape (N_irregular, 2)
irregular = source_plane_data_grid

# Compute normalized mesh coordinates (floating indices)
fy = (irregular[:, 0] - y_min) / dy
fx = (irregular[:, 1] - x_min) / dx

# Integer indices of top-left corners
ix = jnp.floor(fx).astype(jnp.int32)
iy = jnp.floor(fy).astype(jnp.int32)

# Clip to stay within bounds
ix = jnp.clip(ix, 0, Nx - 2)
iy = jnp.clip(iy, 0, Ny - 2)

# Local coordinates inside the cell (0 <= tx, ty <= 1)
tx = fx - ix
ty = fy - iy

# Bilinear weights
w00 = (1 - tx) * (1 - ty)
w10 = tx * (1 - ty)
w01 = (1 - tx) * ty
w11 = tx * ty

weights = jnp.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4)

# Compute indices of 4 surrounding pixels in the flattened mesh
i00 = iy * Nx + ix
i10 = iy * Nx + (ix + 1)
i01 = (iy + 1) * Nx + ix
i11 = (iy + 1) * Nx + (ix + 1)

mappings = jnp.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4)

return mappings, weights


@numba_util.jit()
def pix_indexes_for_sub_slim_index_delaunay_from(
source_plane_data_grid,
Expand Down
28 changes: 15 additions & 13 deletions autoarray/inversion/pixelization/mappers/rectangular.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import jax.numpy as jnp
import numpy as np
from typing import Tuple

from autoconf import cached_property

from autoarray.structures.grids.irregular_2d import Grid2DIrregular
from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper
from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights

from autoarray.geometry import geometry_util
from autoarray.inversion.pixelization.mappers import mapper_util


class MapperRectangular(AbstractMapper):
Expand Down Expand Up @@ -95,19 +97,19 @@ def pix_sub_weights(self) -> PixSubWeights:
dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index`
are equal to 1.0.
"""
mappings = geometry_util.grid_pixel_indexes_2d_slim_from(
grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled),
shape_native=self.source_plane_mesh_grid.shape_native,
pixel_scales=self.source_plane_mesh_grid.pixel_scales,
origin=self.source_plane_mesh_grid.origin,
).astype("int")

mappings = mappings.reshape((len(mappings), 1))
mappings, weights = (
mapper_util.rectangular_mappings_weights_via_interpolation_from(
shape_native=self.shape_native,
source_plane_mesh_grid=self.source_plane_mesh_grid.array,
source_plane_data_grid=Grid2DIrregular(
self.source_plane_data_grid.over_sampled
).array,
)
)

return PixSubWeights(
mappings=mappings,
sizes=np.ones(len(mappings), dtype="int"),
weights=np.ones(
(len(self.source_plane_data_grid.over_sampled), 1), dtype="int"
),
mappings=np.array(mappings),
sizes=4 * np.ones(len(mappings), dtype="int"),
weights=np.array(weights),
)
8 changes: 6 additions & 2 deletions autoarray/operators/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def hull(
# cast JAX arrays to base numpy arrays
grid_convex = np.zeros((len(self.grid), 2))

grid_convex[:, 0] = np.array(self.grid[:, 1])
grid_convex[:, 1] = np.array(self.grid[:, 0])
try:
grid_convex[:, 0] = np.array(self.grid.array[:, 1])
grid_convex[:, 1] = np.array(self.grid.array[:, 0])
except AttributeError:
grid_convex[:, 0] = np.array(self.grid[:, 1])
grid_convex[:, 1] = np.array(self.grid[:, 0])

try:
hull = ConvexHull(grid_convex)
Expand Down
10 changes: 2 additions & 8 deletions test_autoarray/inversion/inversion/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,8 @@ def test__inversion_matrices__x2_mappers(
settings=aa.SettingsInversion(use_positive_only_solver=True),
)

assert (
inversion.operated_mapping_matrix[0:9, 0:9]
== rectangular_mapper_7x7_3x3.mapping_matrix
).all()
assert (
inversion.operated_mapping_matrix[0:9, 9:18]
== delaunay_mapper_9_3x3.mapping_matrix
).all()
assert inversion.operated_mapping_matrix[0:9, 0:9] == pytest.approx(rectangular_mapper_7x7_3x3.mapping_matrix, abs=1.0e-4)
assert inversion.operated_mapping_matrix[0:9, 9:18] == pytest.approx(delaunay_mapper_9_3x3.mapping_matrix, abs=1.0e-4)

operated_mapping_matrix = np.hstack(
[
Expand Down
20 changes: 10 additions & 10 deletions test_autoarray/inversion/pixelization/mappers/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ def test__rectangular_mapper():
(5.0, 5.0), 1.0e-4
)
assert mapper.source_plane_mesh_grid.origin == pytest.approx((0.5, 0.5), 1.0e-4)
assert (
mapper.mapping_matrix
== np.array(
assert mapper.mapping_matrix == pytest.approx(
np.array(
[
[0.0, 0.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
[0.0675, 0.5775, 0.18, 0.0075, -0.065, -0.1425, 0.0, 0.0375, 0.3375],
[0.18, -0.03, 0.0, 0.84, -0.14, 0.0, 0.18, -0.03, 0.0],
[0.0225, 0.105, 0.0225, 0.105, 0.49, 0.105, 0.0225, 0.105, 0.0225],
[0.0, -0.03, 0.18, 0.0, -0.14, 0.84, 0.0, -0.03, 0.18],
[0.0, 0.0, 0.0, -0.03, -0.14, -0.03, 0.18, 0.84, 0.18],
]
)
).all()
),
1.0e-4,
)
assert mapper.shape_native == (3, 3)


Expand Down
24 changes: 11 additions & 13 deletions test_autoarray/inversion/pixelization/mappers/test_rectangular.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,18 @@ def test__pix_indexes_for_sub_slim_index__matches_util():

mapper = aa.Mapper(mapper_grids=mapper_grids, regularization=None)

pix_indexes_for_sub_slim_index_util = np.array(
[
aa.util.geometry.grid_pixel_indexes_2d_slim_from(
grid_scaled_2d_slim=np.array(grid.over_sampled),
shape_native=mesh_grid.shape_native,
pixel_scales=mesh_grid.pixel_scales,
origin=mesh_grid.origin,
).astype("int")
]
).T
mappings, weights = (
aa.util.mapper.rectangular_mappings_weights_via_interpolation_from(
shape_native=(3, 3),
source_plane_mesh_grid=mesh_grid.array,
source_plane_data_grid=aa.Grid2DIrregular(
mapper_grids.source_plane_data_grid.over_sampled
).array,
)
)

assert (
mapper.pix_indexes_for_sub_slim_index == pix_indexes_for_sub_slim_index_util
).all()
assert (mapper.pix_sub_weights.mappings == mappings).all()
assert (mapper.pix_sub_weights.weights == weights).all()


def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7):
Expand Down
Loading