diff --git a/autoarray/operators/convolver.py b/autoarray/operators/convolver.py index 841be9b1..9bb553cb 100644 --- a/autoarray/operators/convolver.py +++ b/autoarray/operators/convolver.py @@ -133,6 +133,13 @@ class determines how masked real-space data are embedded into a padded array, self.fft_kernel = np.fft.rfft2(self.kernel.native.array, s=self.fft_shape) self.fft_kernel_mapping = np.expand_dims(self.fft_kernel, 2) + # Pre-cached complex64 view for the use_mixed_precision=True path of + # convolved_image_from. Cast once here so the FFT branch does not + # repeat the astype per JIT trace — it would otherwise produce a fresh + # numpy buffer each call, which on CPU costs more than the fp32 FFT + # saves. convolved_mapping_matrix_from intentionally does NOT use a + # complex64 kernel — see that method's body for why. + self.fft_kernel_c64 = self.fft_kernel.astype(np.complex64) class Convolver: @@ -532,17 +539,23 @@ def convolved_image_from( state = self.state_from(mask=image.mask) + # When use_mixed_precision is on, the FFT runs in complex64 end-to-end: + # the input cube is allocated as float32, rfft2 emits complex64, the + # precomputed (complex128) kernel is cast on the fly, and irfft2 + # returns float32 natively. No trailing astype is needed. + real_dtype = jnp.float32 if use_mixed_precision else jnp.float64 + # Build combined native image in the FFT dtype - image_both_native = xp.zeros(state.fft_shape, dtype=jnp.float64) + image_both_native = xp.zeros(state.fft_shape, dtype=real_dtype) image_both_native = image_both_native.at[state.mask.slim_to_native_tuple].set( - jnp.asarray(image.array, dtype=jnp.float64) + jnp.asarray(image.array, dtype=real_dtype) ) if blurring_image is not None: image_both_native = image_both_native.at[ state.blurring_mask.slim_to_native_tuple - ].set(jnp.asarray(blurring_image.array, dtype=jnp.float64)) + ].set(jnp.asarray(blurring_image.array, dtype=real_dtype)) else: warnings.warn( "No blurring_image provided. Only the direct image will be convolved. " @@ -554,9 +567,14 @@ def convolved_image_from( image_both_native, s=state.fft_shape, axes=(0, 1) ) + # Pick the precomputed kernel matching the FFT dtype. ConvolverState + # caches both complex128 (default) and complex64 (mixed precision) at + # init time, so this is a constant lookup rather than a per-call cast. + fft_kernel = state.fft_kernel_c64 if use_mixed_precision else state.fft_kernel + # Multiply by PSF in Fourier space and invert blurred_image_full = xp.fft.irfft2( - state.fft_kernel * fft_image_native, s=state.fft_shape, axes=(0, 1) + fft_kernel * fft_image_native, s=state.fft_shape, axes=(0, 1) ) ky, kx = self.kernel.shape_native # (21, 21) off_y = (ky - 1) // 2 @@ -572,15 +590,11 @@ def convolved_image_from( blurred_image_full, start_indices, state.fft_shape ) - # Return slim form; optionally cast for downstream stability + # Return slim form; dtype already matches use_mixed_precision via the + # FFT path, so no explicit downcast. blurred_slim = blurred_image_native[state.mask.slim_to_native_tuple] - blurred_image = Array2D(values=blurred_slim, mask=image.mask) - - if use_mixed_precision: - blurred_image = blurred_image.astype(jnp.float32) - - return blurred_image + return Array2D(values=blurred_slim, mask=image.mask) def convolved_mapping_matrix_from( self, @@ -677,7 +691,19 @@ def convolved_mapping_matrix_from( # ------------------------------------------------------------------------- # Mixed precision handling # ------------------------------------------------------------------------- - fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128 + # mapping_matrix_native_from honors use_mixed_precision and produces a + # fp32 native cube. rfft2 of that cube emits complex64. We deliberately + # multiply by the complex128 precomputed kernel below, which upcasts + # the product back to complex128 so the irfft2 returns float64. This + # asymmetry is intentional: pixelization meshes with K >> 40 source + # pixels accumulate enough fp32 round-off through the NNLS active-set + # / log-determinant that the figure_of_merit drifts by O(1) units + # (verified on the delaunay_mge regression). The fp32 input cube and + # complex64 forward FFT still buy us a faster scatter and slightly + # cheaper rfft2; keeping the kernel multiply in complex128 preserves + # the precision the downstream linear algebra needs. + # convolved_image_from (used by light profiles) takes the full fp32 + # path because its 40-column linear systems are well-conditioned. # ------------------------------------------------------------------------- # Build native cube on the *native mask grid* diff --git a/autoarray/settings.py b/autoarray/settings.py index 1533d77a..102e8a0c 100644 --- a/autoarray/settings.py +++ b/autoarray/settings.py @@ -24,11 +24,52 @@ def __init__( Parameters ---------- use_mixed_precision - If `True`, the linear algebra calculations of the inversion are performed using single precision on a - targeted subset of functions which provide significant speed up when using a GPU (x4), reduces VRAM - use and are expected to have minimal impact on the accuracy of the results. If `False`, all linear algebra - calculations are performed using double precision, which is the default and is more accurate but - slower on a GPU. + If `True`, a targeted subset of the inversion's linear algebra runs in single precision (float32 / + complex64) instead of double precision (float64 / complex128). This is intended to reduce VRAM use and + speed up the FFT-heavy and bandwidth-bound steps on GPU and CPU; only the JAX (`xp=jnp`) paths honor + the flag — the NumPy backend always runs in fp64. + + Paths that honor the flag: + + - PSF FFT convolution in :meth:`Convolver.convolved_image_from` (the light-profile blurring path, + used by linear MGE bases and similar): the input image, kernel multiply and inverse FFT all run in + complex64 / float32 end to end. This is the headline GPU win for MGE imaging pipelines. + - PSF FFT convolution in :meth:`Convolver.convolved_mapping_matrix_from` (the pixelization mapping + matrix path): the input cube is fp32 and the forward ``rfft2`` runs in complex64, but the kernel + multiply intentionally upcasts back to complex128 so the inverse FFT and downstream linear + algebra stay fp64. Pixelization meshes with K ≫ 40 source pixels accumulate enough fp32 + round-off through NNLS / log-determinant to shift ``figure_of_merit`` by O(1) units; the upcast + preserves precision while the cheaper fp32 scatter and forward FFT are kept. + - The mapping matrix native cube allocation in + :func:`autoarray.inversion.mappers.mapper_util.mapping_matrix_from` — output dtype becomes fp32. + - The internal compute dtype of the curvature matrix accumulation in + :func:`autoarray.inversion.inversion.inversion_util.curvature_matrix_via_mapping_matrix_from` — + the noise-weighted ``A.T @ A`` is formed in fp32 then cast to fp64 for downstream stability. + + Empirical platform notes: + + - **GPU**: full pipeline single-JIT roughly matches the fp64 baseline; vmap-batched evaluation + (the production sampler hot path) shows 25–30% speedup on RTX 2060-class hardware. + - **CPU**: the per-call FFT itself is ~1.6× faster in fp32, but JAX/XLA's CPU FFT lowering does + not always re-compose well across ~40-call MGE-basis pipelines, so the single-JIT measurement + can be neutral or slightly slower than fp64. vmap remains comparable to or slightly faster than + fp64. The flag is most beneficial for GPU users. + + Paths that intentionally stay in fp64: + + - The NNLS reconstruction (jaxnnls / Cholesky factor + cho_solve) in + :func:`autoarray.inversion.inversion.inversion_util.reconstruction_positive_only_from`. Active-set + and PDIP solvers are sensitive to fp32 noise on ill-conditioned source meshes. + - The log-determinant of the curvature regularization matrix used by ``figure_of_merit``: condition + numbers can exceed 1e6 on fine pixelizations and fp32 silently loses 1+ digit there. + - Light profile evaluation on the (over-)sampled grid; only the resulting mapping matrix is downcast. + + Empirical numerical impact on the MGE imaging regression (HST-shaped, 15k masked pixels, 40 linear + Gaussians): Δlog-likelihood ≈ 1e-4 absolute at log-likelihood ≈ 27,400. Well below the natural χ² + sampling noise floor (σ ≈ √(2N) ≈ 175). Pixelization paths with K ≫ 40 source pixels are more + sensitive — verify on representative integration tests before turning on for production fits. + + If `False` (default), all paths run in fp64. use_positive_only_solver Whether to use a positive-only linear system solver, which requires that every reconstructed value is positive but is computationally much slower than the default solver (which allows for positive and