Merged
50 changes: 38 additions & 12 deletions autoarray/operators/convolver.py
@@ -133,6 +133,13 @@ class determines how masked real-space data are embedded into a padded array,

self.fft_kernel = np.fft.rfft2(self.kernel.native.array, s=self.fft_shape)
self.fft_kernel_mapping = np.expand_dims(self.fft_kernel, 2)
# Pre-cached complex64 view for the use_mixed_precision=True path of
# convolved_image_from. The cast is done once here so the FFT branch does
# not repeat the astype on every call; repeating it would allocate a fresh
# numpy buffer each time, which on CPU costs more than the fp32 FFT saves.
# convolved_mapping_matrix_from intentionally does NOT use a complex64
# kernel; see that method's body for why.
self.fft_kernel_c64 = self.fft_kernel.astype(np.complex64)
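The init-time cast above can be sketched in isolation (a minimal illustration, not the real `ConvolverState` class; `KernelCache` and its arguments are hypothetical names):

```python
import numpy as np

class KernelCache:
    """Caches both precisions of the kernel FFT at construction time."""

    def __init__(self, kernel: np.ndarray, fft_shape: tuple):
        # complex128 forward FFT of the PSF kernel (default-precision path).
        self.fft_kernel = np.fft.rfft2(kernel, s=fft_shape)
        # One-time downcast, reused by every mixed-precision call instead of
        # a per-call astype that would allocate a fresh buffer each time.
        self.fft_kernel_c64 = self.fft_kernel.astype(np.complex64)

cache = KernelCache(np.ones((3, 3)), fft_shape=(8, 8))
assert cache.fft_kernel.dtype == np.complex128
assert cache.fft_kernel_c64.dtype == np.complex64
```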


class Convolver:
@@ -532,17 +539,23 @@ def convolved_image_from(

state = self.state_from(mask=image.mask)

# When use_mixed_precision is on, the FFT runs in complex64 end-to-end:
# the input cube is allocated as float32, rfft2 emits complex64, the
# precomputed (complex128) kernel is cast on the fly, and irfft2
# returns float32 natively. No trailing astype is needed.
real_dtype = jnp.float32 if use_mixed_precision else jnp.float64

# Build combined native image in the FFT dtype
image_both_native = xp.zeros(state.fft_shape, dtype=jnp.float64)
image_both_native = xp.zeros(state.fft_shape, dtype=real_dtype)

image_both_native = image_both_native.at[state.mask.slim_to_native_tuple].set(
jnp.asarray(image.array, dtype=jnp.float64)
jnp.asarray(image.array, dtype=real_dtype)
)

if blurring_image is not None:
image_both_native = image_both_native.at[
state.blurring_mask.slim_to_native_tuple
].set(jnp.asarray(blurring_image.array, dtype=jnp.float64))
].set(jnp.asarray(blurring_image.array, dtype=real_dtype))
else:
warnings.warn(
"No blurring_image provided. Only the direct image will be convolved. "
@@ -554,9 +567,14 @@ def convolved_image_from(
image_both_native, s=state.fft_shape, axes=(0, 1)
)

# Pick the precomputed kernel matching the FFT dtype. ConvolverState
# caches both complex128 (default) and complex64 (mixed precision) at
# init time, so this is a constant lookup rather than a per-call cast.
fft_kernel = state.fft_kernel_c64 if use_mixed_precision else state.fft_kernel

# Multiply by PSF in Fourier space and invert
blurred_image_full = xp.fft.irfft2(
state.fft_kernel * fft_image_native, s=state.fft_shape, axes=(0, 1)
fft_kernel * fft_image_native, s=state.fft_shape, axes=(0, 1)
)
ky, kx = self.kernel.shape_native # (21, 21)
off_y = (ky - 1) // 2
@@ -572,15 +590,11 @@ def convolved_image_from(
blurred_image_full, start_indices, state.fft_shape
)

# Return slim form; optionally cast for downstream stability
# Return slim form; dtype already matches use_mixed_precision via the
# FFT path, so no explicit downcast.
blurred_slim = blurred_image_native[state.mask.slim_to_native_tuple]

blurred_image = Array2D(values=blurred_slim, mask=image.mask)

if use_mixed_precision:
blurred_image = blurred_image.astype(jnp.float32)

return blurred_image
return Array2D(values=blurred_slim, mask=image.mask)
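The end-to-end single-precision claim in the comments above can be checked with a small stand-alone sketch. It uses `scipy.fft`, which, like `jnp.fft`, preserves single precision (`numpy.fft` always computes in double), and a delta-function PSF so the convolution should return the input unchanged; all names and shapes here are illustrative:

```python
import numpy as np
from scipy import fft as sfft

fft_shape = (16, 16)

# float32 input image with a small box of flux.
image = np.zeros(fft_shape, dtype=np.float32)
image[4:8, 4:8] = 1.0

# Delta-function PSF: its FFT is all ones, so convolution is the identity.
kernel = np.zeros(fft_shape, dtype=np.float32)
kernel[0, 0] = 1.0

fft_image = sfft.rfft2(image, s=fft_shape)                      # complex64
fft_kernel_c64 = sfft.rfft2(kernel, s=fft_shape).astype(np.complex64)
blurred = sfft.irfft2(fft_kernel_c64 * fft_image, s=fft_shape)  # float32

assert fft_image.dtype == np.complex64
assert blurred.dtype == np.float32      # no trailing astype needed
assert np.allclose(blurred, image, atol=1e-4)
```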

def convolved_mapping_matrix_from(
self,
@@ -677,7 +691,19 @@ def convolved_mapping_matrix_from(
# -------------------------------------------------------------------------
# Mixed precision handling
# -------------------------------------------------------------------------
fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128
# mapping_matrix_native_from honors use_mixed_precision and produces a
# fp32 native cube. rfft2 of that cube emits complex64. We deliberately
# multiply by the complex128 precomputed kernel below, which upcasts
# the product back to complex128 so the irfft2 returns float64. This
# asymmetry is intentional: pixelization meshes with K >> 40 source
# pixels accumulate enough fp32 round-off through the NNLS active-set
# / log-determinant that the figure_of_merit drifts by O(1) units
# (verified on the delaunay_mge regression). The fp32 input cube and
# complex64 forward FFT still buy us a faster scatter and slightly
# cheaper rfft2; keeping the kernel multiply in complex128 preserves
# the precision the downstream linear algebra needs.
# convolved_image_from (used by light profiles) takes the full fp32
# path because its 40-column linear systems are well-conditioned.
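NumPy's (and, with x64 enabled, JAX's) promotion rules give the upcast described above for free; a tiny sketch with illustrative shapes, where the cube's last axis stands in for the K source pixels:

```python
import numpy as np

# complex64 forward FFT of the fp32 mapping-matrix cube, shape (H, W, K)...
fft_cube_c64 = np.ones((4, 3, 2), dtype=np.complex64)
# ...multiplied by the precomputed complex128 kernel, broadcast over K.
fft_kernel_c128 = np.ones((4, 3, 1), dtype=np.complex128)

product = fft_kernel_c128 * fft_cube_c64
assert product.dtype == np.complex128   # upcast: irfft2 of this is float64
```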

# -------------------------------------------------------------------------
# Build native cube on the *native mask grid*
51 changes: 46 additions & 5 deletions autoarray/settings.py
@@ -24,11 +24,52 @@ def __init__(
Parameters
----------
use_mixed_precision
If `True`, the linear algebra calculations of the inversion are performed using single precision on a
targeted subset of functions which provide significant speed up when using a GPU (x4), reduces VRAM
use and are expected to have minimal impact on the accuracy of the results. If `False`, all linear algebra
calculations are performed using double precision, which is the default and is more accurate but
slower on a GPU.
If `True`, a targeted subset of the inversion's linear algebra runs in single precision (float32 /
complex64) instead of double precision (float64 / complex128). This is intended to reduce VRAM use and
speed up the FFT-heavy and bandwidth-bound steps on GPU and CPU; only the JAX (`xp=jnp`) paths honor
the flag — the NumPy backend always runs in fp64.

Paths that honor the flag:

- PSF FFT convolution in :meth:`Convolver.convolved_image_from` (the light-profile blurring path,
used by linear MGE bases and similar): the input image, kernel multiply and inverse FFT all run in
complex64 / float32 end to end. This is the headline GPU win for MGE imaging pipelines.
- PSF FFT convolution in :meth:`Convolver.convolved_mapping_matrix_from` (the pixelization mapping
matrix path): the input cube is fp32 and the forward ``rfft2`` runs in complex64, but the kernel
multiply intentionally upcasts back to complex128 so the inverse FFT and downstream linear
algebra stay fp64. Pixelization meshes with K ≫ 40 source pixels accumulate enough fp32
round-off through NNLS / log-determinant to shift ``figure_of_merit`` by O(1) units; the upcast
preserves precision while the cheaper fp32 scatter and forward FFT are kept.
- The mapping matrix native cube allocation in
:func:`autoarray.inversion.mappers.mapper_util.mapping_matrix_from` — output dtype becomes fp32.
- The internal compute dtype of the curvature matrix accumulation in
:func:`autoarray.inversion.inversion.inversion_util.curvature_matrix_via_mapping_matrix_from` —
the noise-weighted ``A.T @ A`` is formed in fp32 then cast to fp64 for downstream stability.
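The curvature-matrix bullet above can be sketched as follows (function and argument names are illustrative, not the real `inversion_util` signature):

```python
import numpy as np

def curvature_matrix_sketch(mapping_matrix, noise_map, use_mixed_precision):
    """Form the noise-weighted A.T @ A in the compute dtype, return fp64."""
    compute_dtype = np.float32 if use_mixed_precision else np.float64
    a = (mapping_matrix / noise_map[:, None]).astype(compute_dtype)
    curvature = a.T @ a  # accumulation runs in the compute dtype
    # Cast up before the solve / log-determinant for downstream stability.
    return curvature.astype(np.float64)

rng = np.random.default_rng(0)
curv = curvature_matrix_sketch(rng.normal(size=(100, 5)), np.ones(100), True)
assert curv.dtype == np.float64
assert np.allclose(curv, curv.T)
```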

Empirical platform notes:

- **GPU**: a full-pipeline single-JIT call runs at roughly the same speed as the fp64 baseline;
vmap-batched evaluation (the production sampler hot path) shows a 25–30% speedup on RTX 2060-class hardware.
- **CPU**: the per-call FFT itself is ~1.6× faster in fp32, but JAX/XLA's CPU FFT lowering does
not always re-compose well across ~40-call MGE-basis pipelines, so the single-JIT measurement
can be neutral or slightly slower than fp64. vmap remains comparable to or slightly faster than
fp64. The flag is most beneficial for GPU users.

Paths that intentionally stay in fp64:

- The NNLS reconstruction (jaxnnls / Cholesky factor + cho_solve) in
:func:`autoarray.inversion.inversion.inversion_util.reconstruction_positive_only_from`. Active-set
and PDIP solvers are sensitive to fp32 noise on ill-conditioned source meshes.
- The log-determinant of the curvature regularization matrix used by ``figure_of_merit``: condition
numbers can exceed 1e6 on fine pixelizations and fp32 silently loses 1+ digit there.
- Light profile evaluation on the (over-)sampled grid; only the resulting mapping matrix is downcast.
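The log-determinant sensitivity can be illustrated with a synthetic matrix whose spectrum (and hence exact log-determinant) is known; the condition number of 1e6 mirrors the fine-pixelization regime mentioned above:

```python
import numpy as np

# SPD matrix with eigenvalues spanning six decades (condition number 1e6).
eigvals = np.logspace(0, -6, 8)
q, _ = np.linalg.qr(np.random.default_rng(1).normal(size=(8, 8)))
m = (q * eigvals) @ q.T

exact = np.log(eigvals).sum()  # exact log-determinant from the spectrum

sign64, logdet64 = np.linalg.slogdet(m.astype(np.float64))
sign32, logdet32 = np.linalg.slogdet(m.astype(np.float32))

# fp64 recovers the exact value to high accuracy; in fp32, the smallest
# eigenvalue carries a cond * eps_fp32 ~ 0.1 relative rounding error,
# which can shift the sum of logs by O(0.1).
assert sign64 == 1.0
assert abs(logdet64 - exact) < 1e-6
assert logdet32.dtype == np.float32
```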

Empirical numerical impact on the MGE imaging regression (HST-shaped, 15k masked pixels, 40 linear
Gaussians): Δlog-likelihood ≈ 1e-4 absolute at log-likelihood ≈ 27,400, well below the natural χ²
sampling noise floor (σ ≈ √(2N) ≈ 175). Pixelization paths with K ≫ 40 source pixels are more
sensitive; verify on representative integration tests before enabling the flag for production fits.

If `False` (default), all paths run in fp64.
use_positive_only_solver
Whether to use a positive-only linear system solver, which requires that every reconstructed value is
positive but is computationally much slower than the default solver (which allows for positive and