fix(interferometer): correct sparse curvature for Pmax > 1 (Delaunay)#316
Merged
Merged
Conversation
Replaces the InterferometerSparseOperator curvature path with one that mirrors ImagingSparseOperator.curvature_matrix_diag_from(rows, cols, vals, *, S) and uses extent-flat row indices for the W~ operator's (2*y_ext, 2*x_ext) grid. The previous path used native-flat indices from fft_index_for_masked_pixel, so most JAX scatter writes silently fell out-of-bounds for any real_space_mask with extent < native shape. This fixed both the 34% Frobenius gap on Delaunay (Pmax=3 barycentric, issue #314) and the previously-documented ~0.4% Pmax=1 "numerical reformulation" gap — both reduce to the same indexing bug. Removes the defensive NotImplementedError guard added in #315, and converts the raise-test into a sparse-vs-mapping parity assertion at rtol=1e-4. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Collaborator
Author
|
Workspace PR: PyAutoLabs/autolens_workspace_test#98 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes the silent ~34% Frobenius mismatch between the interferometer sparse-operator curvature matrix and the mapping path on Delaunay (Pmax = 3 barycentric) meshes (issue #314). Replaces the defensive
NotImplementedErrorguard from #315 with a real math fix. As a bonus, the same bug was producing a quieter ~0.4% gap on Pmax = 1 (rectangular) — both reduce to the same indexing mistake and are fixed together.API Changes
The interferometer sparse operator now exposes the same API surface as the imaging operator:
apply_operator(F)+curvature_matrix_diag_from(rows, cols, vals, *, S). The oldcurvature_matrix_via_sparse_operator_from(pix_indexes, pix_weights, pix_pixels, fft_index)shape-array signature is gone. TheNotImplementedErrorguard onInversionInterferometerSparse.curvature_matrix_diag(added in #315) is also gone. NewMask2D.extent_index_for_masked_pixelcached_property exposes the slim → unmasked-extent-flat index — the interferometer counterpart offft_index_for_masked_pixel.See full details below.
Test Plan
pytest test_autoarray/inversion/inversion/interferometer/test_interferometer.py— the newtest__curvature_matrix__interferometer_sparse_operator__delaunay__identical_to_mappingasserts sparse-vs-mapping parity at rtol=1e-4 on a Delaunay mapper.pytest test_autoarray/inversion/(full inversion suite) — 166 passed locally.pytest test_autoarray/mask/(mask suite) — 105 passed locally, covers the newextent_index_for_masked_pixelvia the integration parity test.autolens_workspace_test/scripts/jax_assertions/sparse_operators.pyand the now-closed-gap literal inrectangular_sparse.py.Root cause
The interferometer's W~ operator lives on the unmasked-extent rectangular grid:
M = shape_native_masked_pixels[0] * shape_native_masked_pixels[1]. The previous implementation usedmask.fft_index_for_masked_pixelas row indices for the(M, batch_size)scatter buffer — butfft_index_for_masked_pixelreturns native-flat indices in[0, native_y * native_x). For any real-space mask where the unmasked extent is smaller than the native shape (i.e. any non-trivial circular mask), most rows fell out of bounds; JAX's default.at[i, j].add(v)silently drops out-of-bounds writes, so most contributions vanished. The result was very-nearly-zero curvature for Delaunay (34% Frobenius from the correct value) and a quieter mostly-correct-but-not-quite curvature for Pmax = 1 (the documented 0.4% gap).The fix adds
Mask2D.extent_index_for_masked_pixel(slim → extent-flat), andInversionInterferometerSparse.curvature_matrix_diagplumbs it throughmapper_util.sparse_triplets_from(...)to produce extent-flat row indices that land in the operator's actual(M, B)scatter buffer. The sparse-operator class itself is then a near-clone ofImagingSparseOperator(same scatter /apply_operator/segment_sum/ width-mask /S_padshape), differing only in the actual W~ apply — FFT-conv withKhatfor interferometer, vsH^T N^{-1} Hfor imaging.Scripts Changed
None in this library PR. Workspace follow-up coming in a separate
/ship_workspacePR forautolens_workspace_test:scripts/jax_assertions/sparse_operators.py— switch the Pmax=1 call to the newcurvature_matrix_diag_from(rows, cols, vals, S)API.scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py— Path A literal-3152.03184792→-3164.286252(now matches Path B and Path C exactly); the Path B docstring's "~0.4% numerical reformulation" comment is replaced with the actual behaviour (mathematically exact, agrees to ~1e-13).notebooks/jax_likelihood_functions/interferometer/rectangular_sparse.ipynb— corresponding literal update.Full API Changes (for automation & release notes)
Added
Mask2D.extent_index_for_masked_pixel: np.ndarray—cached_property. 1D int32 array of shape(N_unmasked,). Maps slim masked-pixel index to the flat row-major index on the unmasked-extent rectangular FFT grid (extent_y * extent_x). Use this for the interferometer sparse path. The existingfft_index_for_masked_pixel(native-flat) remains correct for the imaging path.InterferometerSparseOperator.apply_operator(Fbatch_flat) -> jax.Array— FFT-conv method (Re(IFFT(FFT(F_pad) * Khat))[:y, :x]). Previously inlined insidecurvature_matrix_via_sparse_operator_from.InterferometerSparseOperator.col_offsets: jax.Array—(batch_size,)int32. Width-mask helper, populated byfrom_nufft_precision_operator.InterferometerSparseOperator.curvature_matrix_diag_from(rows, cols, vals, *, S) -> jax.Array— replaces the old method; see Renamed / Signature Change.Removed
InterferometerSparseOperator.curvature_matrix_via_sparse_operator_from(pix_indexes_for_sub_slim_index, pix_weights_for_sub_slim_index, pix_pixels, fft_index_for_masked_pixel)— see Renamed / Signature Change.NotImplementedErrorguard onInversionInterferometerSparse.curvature_matrix_diag(and thefrom autoarray.inversion.mesh.mesh.delaunay import Delaunayimport that supported it) — guard is no longer needed; the math is now correct.Renamed / Signature Change
InterferometerSparseOperator.curvature_matrix_via_sparse_operator_from(pix_indexes_for_sub_slim_index, pix_weights_for_sub_slim_index, pix_pixels, fft_index_for_masked_pixel)→InterferometerSparseOperator.curvature_matrix_diag_from(rows, cols, vals, *, S). Same name and call shape asImagingSparseOperator.curvature_matrix_diag_from. Triplets must use extent-flat row indexing; callers should produce them viamapper_util.sparse_triplets_from(..., fft_index_for_masked_pixel=mask.extent_index_for_masked_pixel, return_rows_slim=False).Changed Behaviour
InversionInterferometerSparse.curvature_matrix_diagnow returns the correct curvature for arbitrary Pmax (including Delaunay barycentric). Previously raisedNotImplementedErroron Delaunay (PR fix(interferometer-sparse): guard against Delaunay mappers (issue #314) #315) and silently produced an indexed-out-of-bounds result with a quieter ~0.4% gap on Pmax = 1.Migration
If you held a direct reference to
InterferometerSparseOperator.curvature_matrix_via_sparse_operator_from, replace the call with:```python
Before:
op.curvature_matrix_via_sparse_operator_from(
pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index,
pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index,
pix_pixels=mapper.params,
fft_index_for_masked_pixel=mask.fft_index_for_masked_pixel,
)
After:
from autoarray.inversion.mappers import mapper_util
rows, cols, vals = mapper_util.sparse_triplets_from(
pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index,
pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index,
slim_index_for_sub=mapper.slim_index_for_sub_slim_index,
fft_index_for_masked_pixel=mask.extent_index_for_masked_pixel, # extent-flat, not native-flat
sub_fraction_slim=mapper.over_sampler.sub_fraction.array,
return_rows_slim=False,
)
op.curvature_matrix_diag_from(rows, cols, vals, S=mapper.params)
```
This pattern mirrors
ImagingSparseOperator.curvature_matrix_diag_from's call site — seeinversion/inversion/imaging/sparse.py:288for the imaging counterpart.Closes #314. Supersedes the defensive guard from #315.
🤖 Generated with Claude Code