Overview
The existing JAX-compatible adaptive rectangular source-plane mesh in PyAutoArray
warps source-plane coordinates through a separable per-axis empirical CDF
(implemented as a degree-11 polynomial + cubic-Hermite spline inverse in
InterpolatorRectangularSpline). Scientifically this works, but recovers detailed
source structure only at high resolutions (~4000+ pixels). This task investigates
whether a multi-component, physically-motivated density field driving the same
CDF warp can match Delaunay-like effective resolution at 500–1500 pixels while
preserving fixed rectangular topology, fixed array shapes, full JAX/JIT
compatibility, and differentiability.
The hypothesis: the adaptive rectangular topology is not the problem — pixel
distribution is. Replacing the single density signal (point-density or
adapt-image weights) with a weighted sum of magnification, source-brightness,
residual-gradient, and caustic-proximity density bases should concentrate pixels
where they buy scientific information, without ever changing topology.
Plan
- Audit the existing CDF implementation end-to-end and document what the current single-signal warp actually does, including the per-axis separability assumption and the bilinear interpolation operator's conditioning at low pixel counts.
- Define a composable density-component interface so that multiple weighted bases (floor + w1·rho_mag + w2·rho_brightness + w3·rho_residual + w4·rho_caustic) can drive the same CDF transform with no change to topology or array shapes.
- Implement candidate density bases as JAX-pure functions: magnification density (from deflection-field Jacobian), source-brightness density (from current reconstruction), residual-gradient density (from data-model residuals), caustic-proximity density (from |μ⁻¹| → 0 surfaces).
- Investigate the separability assumption — assess whether per-axis marginal CDFs of a multi-component density are sufficient, or whether low-rank / outer-product factorisations recover anisotropic concentration without breaking JAX.
- Build a developer-side benchmark comparing reconstruction quality vs Delaunay at matched compute budgets across a small grid of pixel counts (250, 500, 1000, 1500, 4000).
- Assess bottlenecks: bilinear operator conditioning at low N, sparsity / curvature-matrix structure, whether NNLS or positivity-constrained solves dominate, whether matrix-free CG becomes preferable.
- Decide on a shipping artefact: if any combination wins, propose a new mesh class (e.g. RectangularMultiComponentAdapt) alongside the existing variants. Do not modify the working spline classes.
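The "bilinear operator conditioning at low N" bottleneck can be probed with a small sketch. It builds a dense bilinear mapping matrix for random points on a uniform, unwarped n × n grid of nodes in the unit square. It then prints the condition number of F = MᵀM plus a small identity regulariser for a few grid sizes, so the trend can be tracked during the audit. This is plain NumPy under stated assumptions (uniform edges, identity regularisation). The helper names `bilinear_mapping_matrix` and `curvature_condition_number` are hypothetical and are not PyAutoArray's mapper or curvature code.

```python
import numpy as np


def bilinear_mapping_matrix(points, n):
    """Dense (P, n*n) bilinear weights of P points in [0, 1]^2 onto an n x n grid of nodes.

    Hypothetical helper: nodes sit at i / (n - 1); each point spreads its weight over
    the four corners of the cell containing it, so every row sums to 1.
    """
    u = np.clip(points, 0.0, 1.0) * (n - 1)            # continuous node coordinates (x, y)
    i0 = np.clip(np.floor(u).astype(int), 0, n - 2)    # lower-left corner index per axis
    t = u - i0                                          # fractional offset inside the cell
    rows = np.arange(points.shape[0])
    mapping = np.zeros((points.shape[0], n * n))
    for dy in (0, 1):
        for dx in (0, 1):
            wy = t[:, 1] if dy else 1.0 - t[:, 1]
            wx = t[:, 0] if dx else 1.0 - t[:, 0]
            cols = (i0[:, 1] + dy) * n + (i0[:, 0] + dx)   # row-major node index
            np.add.at(mapping, (rows, cols), wy * wx)
    return mapping


def curvature_condition_number(mapping, reg=1e-3):
    """Condition number of F = M^T M + reg * I, a stand-in for the regularised curvature matrix."""
    curvature = mapping.T @ mapping + reg * np.eye(mapping.shape[1])
    return np.linalg.cond(curvature)


rng = np.random.default_rng(0)
points = rng.uniform(size=(2000, 2))
for n in (8, 16, 32):
    m = bilinear_mapping_matrix(points, n)
    print(n * n, "pixels: cond(F) =", f"{curvature_condition_number(m):.3g}")
```

Swapping the uniform `points` for CDF-warped coordinates would give the corresponding numbers for an adaptive grid.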
Detailed implementation plan
Affected Repositories
- PyAutoArray (primary, library)
- (potential follow-up) autolens_workspace_developer — benchmark / developer script, separate task
Work Classification
Library (research + prototype). Workspace developer script is a follow-up task once a candidate basis combination demonstrates wins.
Branch Survey
| Repository | Current Branch | Dirty? |
| --- | --- | --- |
| ./PyAutoArray | main | clean |
Suggested branch: feature/rectangular-adapt-cdf
Worktree root: ~/Code/PyAutoLabs-wt/rectangular-adapt-cdf/ (created later by /start_library)
Phase 1 — Audit existing CDF (no code changes)
Document, in autoarray/inversion/mesh/interpolator/rectangular_spline.py-adjacent
notes (a markdown file under PyAutoArray/files/ is fine — do not pollute the
package), exactly:
- How mesh_weight_map flows from RectangularSplineAdaptImage.mesh_weight_map_from through create_transforms_spline into the per-axis CDF (_build_inv_poly_*).
- The two normalisation steps: (mean / min-std) in adaptive_rectangular_transformed_grid_from_spline and the inner unit-square rescale.
- The bilinear scatter via (N-3) * transform(scaled) + 1 and the MeshGeometryRectangular pcolormesh path.
- Confirm whether mesh_weight_map is genuinely treated as a per-point weight (the cumsum-after-sort in _build_inv_poly_jax_impl says yes).
- Establish baseline measurements at 500 / 1000 / 4000 source pixels on one
reference dataset (existing test_autolens / autolens_workspace example) — record
reconstruction χ², log-evidence, peak per-pixel residual.
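Here is a simplified sketch of the per-axis, weight-driven warp being audited. It assumes the spline machinery is equivalent in spirit to three steps: sort the traced coordinates along one axis, take the normalised cumulative sum of their weights (the cumsum-after-sort), and map each coordinate to its CDF value. Linear interpolation (`np.interp`) stands in for the degree-11 polynomial and cubic-Hermite inverse. The names `weighted_marginal_cdf` and `warp_to_pixel_coordinates` are hypothetical. Reading `(N-3) * transform + 1` as a scaling into interior pixel coordinates is an assumption to confirm during the audit.

```python
import numpy as np


def weighted_marginal_cdf(coords, weights):
    """Return (sorted_coords, cdf) for one axis: cumulative weight after sorting, rescaled to [0, 1]."""
    order = np.argsort(coords)
    sorted_coords = coords[order]
    cumulative = np.cumsum(weights[order])   # cumsum-after-sort: each weight acts per point
    cdf = (cumulative - cumulative[0]) / (cumulative[-1] - cumulative[0])
    return sorted_coords, cdf


def warp_to_pixel_coordinates(points, weights, n_pixels):
    """Map (P, 2) points to continuous per-axis pixel coordinates via separable weighted CDFs."""
    warped = np.empty_like(points, dtype=float)
    for axis in range(2):
        sorted_coords, cdf = weighted_marginal_cdf(points[:, axis], weights)
        u = np.interp(points[:, axis], sorted_coords, cdf)   # CDF value in [0, 1]
        warped[:, axis] = (n_pixels - 3) * u + 1            # assumed reading of (N-3) * transform + 1
    return warped


rng = np.random.default_rng(1)
points = rng.normal(size=(5000, 2))
weights = np.exp(-np.sum(points**2, axis=1))   # e.g. weights peaked on a central bright source
coords = warp_to_pixel_coordinates(points, weights, n_pixels=32)
```

With these weights, points near the centre occupy a larger share of the pixel-coordinate range than they would with uniform weights. That is the concentration effect the composite density is meant to steer.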
Phase 2 — Composable density-component interface
Add a small, pure-JAX module: autoarray/inversion/mesh/interpolator/density_components.py.
Each component is a callable (traced_points, context, xp) -> per_point_weight
where context carries whatever auxiliary fields the component needs
(e.g. magnification map, current reconstruction, residuals). Components must:
- Return shape (N,), strictly positive and finite; enforce the lower bound with xp.clip(w, eps, None).
- Be JAX-traceable end-to-end: no scipy callbacks, no kNN, no dynamic shapes.
- Be independently testable via a unit test in test_autoarray/inversion/pixelization/interpolator/test_density_components.py.
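A minimal sketch of a component that satisfies this contract follows, together with a numpy-only unit test of the kind the test module could hold. The names `uniform_density_from` and `check_component_contract`, and the `EPS` value, are hypothetical. The check covers shape, finiteness and strict positivity.

```python
import numpy as np

EPS = 1e-8  # hypothetical lower clip so every weight is strictly positive


def uniform_density_from(traced_points, context, xp=np):
    """Hypothetical baseline component: equal weight for every traced point."""
    w = xp.ones(traced_points.shape[0])
    return xp.clip(w, EPS, None)


def check_component_contract(component, traced_points, context, xp=np):
    """Assert the component returns shape (N,), finite, strictly positive weights."""
    w = component(traced_points, context, xp)
    assert w.shape == (traced_points.shape[0],)
    assert xp.all(xp.isfinite(w))
    assert xp.all(w > 0)
    return w


def test__uniform_density_from__satisfies_contract():
    traced_points = np.random.default_rng(0).normal(size=(100, 2))
    w = check_component_contract(uniform_density_from, traced_points, context={})
    assert np.allclose(w, 1.0)
```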
Define a composition helper:
def compose_density(components, weights, floor: float, traced_points, context, xp=np):
    """rho = floor + sum_k weights[k] * components[k](traced_points, context, xp)."""
The composed weight feeds the existing mesh_weight_map slot in
create_transforms_spline — no change required to the spline machinery itself.
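A runnable sketch of the composition helper follows. It assumes the helper receives `traced_points` and `context` as arguments, since the docstring references both. With non-negative weights and a positive floor, the result stays strictly positive and can drop into the `mesh_weight_map` slot unchanged. The two toy components and the `magnification` / `adapt_values` context keys are hypothetical stand-ins.

```python
import numpy as np


def compose_density(components, weights, floor, traced_points, context, xp=np):
    """rho = floor + sum_k weights[k] * components[k](traced_points, context, xp)."""
    if len(components) != len(weights):
        raise ValueError("components and weights must have the same length")
    rho = xp.full(traced_points.shape[0], floor)
    for weight, component in zip(weights, components):
        rho = rho + weight * component(traced_points, context, xp)
    return rho


# Hypothetical toy components reading precomputed per-point fields from context.
def magnification_toy(traced_points, context, xp=np):
    return xp.abs(context["magnification"])


def brightness_toy(traced_points, context, xp=np):
    return xp.clip(context["adapt_values"], 1e-8, None)


traced_points = np.zeros((3, 2))
context = {"magnification": np.array([1.0, -4.0, 10.0]), "adapt_values": np.array([0.5, 0.0, 2.0])}
rho = compose_density([magnification_toy, brightness_toy], [0.1, 1.0], 0.01, traced_points, context)
```

Here `rho` is approximately `[0.61, 0.41, 3.01]`. Because `components` is a static Python list, the loop unrolls at trace time under `jax.jit`, and gradients with respect to `weights` flow through the sum.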
Phase 3 — Candidate density bases
1. Magnification density (magnification_density_from): from |det(J)|⁻¹ on the source plane; high-μ regions get higher weight. Requires the image-plane Jacobian, which already exists in PyAutoGalaxy/PyAutoLens. For PyAutoArray-side testing we mock or accept a precomputed magnification map as part of context.
2. Source-brightness density (brightness_density_from): the existing adapt-image path, refactored into a component so it composes uniformly.
3. Residual-gradient density (residual_gradient_density_from): finite differences of (data − model) on the source plane; concentrates pixels where the current reconstruction is failing. Requires the previous-iteration reconstruction; the first iteration falls back to brightness or uniform weights.
4. Caustic-proximity density (caustic_proximity_density_from): proximity to |μ⁻¹| → 0 surfaces; mathematically a smoothed inverse distance to the caustic curve. Requires the magnification field; same context dependency as (1).
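The following are hedged sketches of three of these bases. They assume `context` supplies fields precomputed by the caller. `magnification` holds a signed, finite μ per traced point. `residual_image` holds a 2D data-minus-model image on the uniform image-plane grid, with one entry per traced point in row-major order. Several choices here are illustrative assumptions, not the planned formulation. The Gaussian in |μ⁻¹| is a cheap proximity proxy rather than the smoothed inverse distance described above, and `sigma` is arbitrary. The residual gradient is taken with image-plane central differences, which sidesteps the source-plane neighbourhood structure the plan calls for.

```python
import numpy as np

EPS = 1e-8  # hypothetical lower clip keeping weights strictly positive


def magnification_density_from(traced_points, context, xp=np):
    """|mu| = |det J|^-1 per traced point, read from a precomputed map in context."""
    return xp.clip(xp.abs(context["magnification"]), EPS, None)


def caustic_proximity_density_from(traced_points, context, xp=np, sigma=0.05):
    """Smooth peak where |mu^-1| -> 0: exp(-(mu^-1 / sigma)^2). Illustrative proxy only."""
    inv_mu = 1.0 / context["magnification"]
    return xp.clip(xp.exp(-((inv_mu / sigma) ** 2)), EPS, None)


def residual_gradient_density_from(traced_points, context, xp=np):
    """|grad(data - model)| from image-plane central differences, one value per traced point.

    Falls back to uniform weights when no residual image exists yet (first iteration).
    """
    residual = context.get("residual_image")
    if residual is None:
        return xp.ones(traced_points.shape[0])
    d_row, d_col = xp.gradient(residual)
    grad_mag = xp.sqrt(d_row**2 + d_col**2).reshape(-1)
    return xp.clip(grad_mag, EPS, None)


pts = np.zeros((4, 2))
ctx = {"magnification": np.array([1.0, 2.0, 1000.0, -1000.0]),
       "residual_image": np.arange(4.0).reshape(2, 2)}
w_mag = magnification_density_from(pts, ctx)
w_caustic = caustic_proximity_density_from(pts, ctx)
w_res = residual_gradient_density_from(pts, ctx)
```

Each function keeps the `(traced_points, context, xp)` call shape, so it can be passed directly to a composition helper.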
For (1) and (4), avoid creating a hard cross-repo dependency — PyAutoArray must
not import autogalaxy/autolens. The component signature accepts the
magnification map as a precomputed input. The PyAutoGalaxy/Lens side will be
where this is wired in later (follow-up task).
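On the caller side, the precomputed magnification could come from the lens-equation Jacobian A = I − ∂α/∂θ, with μ = 1 / det A. The sketch below computes it with central differences (`np.gradient`) on a uniform image-plane grid. It uses a toy uniform mass-sheet deflection α = κθ, for which det A = (1 − κ)² exactly. It only illustrates the data handed to PyAutoArray. It is not the PyAutoGalaxy/PyAutoLens implementation, which may compute the Jacobian differently, and `magnification_from_deflections` is a hypothetical name.

```python
import numpy as np


def magnification_from_deflections(alpha_y, alpha_x, pixel_scale):
    """mu = 1 / det(I - d alpha / d theta) on a uniform grid (axis 0 = y, axis 1 = x)."""
    d_ay_dy, d_ay_dx = np.gradient(alpha_y, pixel_scale)
    d_ax_dy, d_ax_dx = np.gradient(alpha_x, pixel_scale)
    det_a = (1.0 - d_ax_dx) * (1.0 - d_ay_dy) - d_ax_dy * d_ay_dx
    return 1.0 / det_a


# Toy lens: uniform mass sheet of convergence kappa, so alpha = kappa * theta.
pixel_scale, kappa = 0.1, 0.5
coords = (np.arange(21) - 10) * pixel_scale
theta_y, theta_x = np.meshgrid(coords, coords, indexing="ij")
mu = magnification_from_deflections(kappa * theta_y, kappa * theta_x, pixel_scale)
context = {"magnification": mu.reshape(-1)}  # flattened per-point map passed into PyAutoArray
```

For this toy lens every pixel has μ = 1 / (1 − 0.5)² = 4. PyAutoArray only ever sees the flattened array, which keeps the dependency one-directional.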
Phase 4 — Investigate separability / low-rank factorisation
The existing implementation marginalises per-axis. For a single Gaussian-like
density this is fine; for a multi-component density with anisotropic structure
(e.g. caustic-proximity along a curve), separability may lose information.
Investigate:
- Whether the marginal x and y CDFs of the composite density recover sufficient
anisotropic concentration when validated against a non-separable reference
(full 2D CDF computed offline).
- Whether a rank-1 outer-product factorisation
rho(x,y) ≈ a(x) * b(y) + low-rank correction
is enough.
- Whether iterative proportional fitting (1–2 passes only, JAX-compatible)
closes the gap without breaking jit shapes.
This is a measurement phase — no permanent API yet.
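One way to run the offline measurement assumes the composite density can be evaluated on a regular grid. It compares two rank-1 approximations. The first is the separable one built from the x and y marginals, which reproduces both marginals exactly. The second is the best rank-1 approximation from an SVD, which is optimal in the Frobenius norm (Eckart–Young) and therefore never worse than the first. A thin ring stands in for a density concentrated along a caustic-like curve. The function name and both test densities are illustrative.

```python
import numpy as np


def separability_errors(rho):
    """Relative Frobenius errors of (marginal outer product, best rank-1 SVD) approximations of rho."""
    total = rho.sum()
    separable = np.outer(rho.sum(axis=1), rho.sum(axis=0)) / total   # reproduces both marginals
    u, s, vt = np.linalg.svd(rho)
    rank1 = s[0] * np.outer(u[:, 0], vt[0])                           # Eckart-Young optimum
    norm = np.linalg.norm(rho)
    return np.linalg.norm(rho - separable) / norm, np.linalg.norm(rho - rank1) / norm


x = np.linspace(-1.0, 1.0, 128)
yy, xx = np.meshgrid(x, x, indexing="ij")
gaussian = np.exp(-(xx**2 + yy**2) / 0.1)                   # separable by construction
ring = np.exp(-((np.hypot(xx, yy) - 0.6) ** 2) / 0.005)    # density concentrated along a curve

for name, rho in (("gaussian", gaussian), ("ring", ring)):
    err_sep, err_rank1 = separability_errors(rho)
    print(f"{name}: marginal-product error {err_sep:.3f}, rank-1 SVD error {err_rank1:.3f}")
```

For the Gaussian, both errors vanish to round-off. For the ring, the gap between the marginal-product error and the SVD error shows how much a better rank-1 factor could recover. The residual that remains after the SVD error bounds what any rank-1 warp can capture.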
Phase 5 — Benchmark vs Delaunay (developer script, follow-up task)
Defer the benchmark script itself to a follow-up task in
autolens_workspace_developer. This issue only carries the library-side
prototype and unit tests. Open the follow-up prompt once the density-component
API has stabilised.
Phase 6 — Decide shipping artefact
If a candidate composition demonstrably matches Delaunay at 1000 pixels:
- Promote it to a new mesh class, RectangularMultiComponentAdapt (alongside, never replacing, RectangularSplineAdaptImage).
- Wire it into pixelization.py and the mesh __init__.py re-exports.
- Add it to the existing test_rectangular_spline.py parity test suite.
If results are inconclusive: ship the density-component framework as a library
primitive (it has standalone value for further experimentation) and document
the negative findings.
Constraints (reiterated for reviewers)
- Fixed rectangular topology, fixed array shapes, JAX/JIT-compatible.
- Differentiable where practical.
- No Delaunay, no kNN/Wendland, no RBF, no scipy callbacks, no dynamic topology.
- Library unit tests stay numpy-only (per feedback_no_jax_in_unit_tests); cross-xp validation lives in workspace_test.
Key Files
- PyAutoArray/autoarray/inversion/mesh/interpolator/rectangular_spline.py: current CDF implementation; the audit anchor.
- PyAutoArray/autoarray/inversion/mesh/interpolator/rectangular.py: the linear-CDF baseline; useful for grad-jump comparisons.
- PyAutoArray/autoarray/inversion/mesh/mesh/rectangular_spline_adapt_image.py and .../rectangular_spline_adapt_density.py: call sites for the composed-density mesh entry points.
- (new) PyAutoArray/autoarray/inversion/mesh/interpolator/density_components.py: composable bases + composition helper.
- (new) PyAutoArray/test_autoarray/inversion/pixelization/interpolator/test_density_components.py: unit tests for each component and the composer.
- (potentially new, only if Phase 6 ships a class) .../mesh/mesh/rectangular_multi_component_adapt.py.
Out of scope
- Any change to InterpolatorRectangularSpline arithmetic. It works; we only feed it a richer weight map.
- Any change to NNLS / curvature / regularisation code in this issue. If
Phase 4 shows the bilinear operator becomes ill-conditioned at low N, file
a separate prompt.
- Workspace tutorial updates and the developer benchmark script — separate
follow-up tasks.
Original Prompt
We have an existing JAX-compatible adaptive rectangular source-plane implementation already in the codebase. The current implementation uses a CDF-style adaptive coordinate transform where rectangular pixels become progressively smaller in regions of interest while preserving a fixed rectangular topology. The implementation works scientifically, but currently requires relatively high source resolutions (~4000+ pixels) to recover detailed source structure.
Your task is NOT to redesign this from scratch. Instead:
- Inspect the existing implementation carefully.
- Understand exactly how the current adaptive coordinate transform works.
- Build on top of the existing approach to investigate more sophisticated adaptive-density formulations that retain:
  - fixed rectangular topology,
  - fixed array shapes,
  - JAX/JIT compatibility,
  - differentiability where possible,
  - no Delaunay triangulation,
  - no scipy spatial callbacks,
  - no dynamic topology changes.
The key conceptual direction is:
Instead of adapting mesh connectivity (like Delaunay), adapt the coordinate system itself via smooth density-driven coordinate warps.
We believe this may preserve many advantages of adaptive Delaunay source planes (high effective resolution in important regions with relatively few pixels) while remaining far more accelerator- and JAX-friendly.
The current implementation likely already resembles:
density -> cumulative distribution -> adaptive rectangular edges
We now want to generalize this.
Please investigate architectures where the adaptive density field is constructed from multiple weighted components, for example:
rho(x,y) = floor + w1 * magnification_density + w2 * source_brightness_density + w3 * residual_gradient_density + w4 * caustic_proximity_density
Key goals:
- Concentrate source-plane resolution where scientifically useful.
- Keep total source pixel count relatively low (~500-1500 if possible).
- Preserve full JAX compatibility.
- Maintain smooth coordinate warps rather than topology changes.
- Avoid the scientific/topological failure modes encountered with kNN/Wendland-style meshless interpolation.
- Keep the implementation differentiable where practical.
Important:
The adaptive rectangular topology itself is NOT the problem. The likely problem is how intelligently the pixels are distributed.
Please specifically investigate:
- Whether multiple density bases can be combined cleanly.
- Whether separable x/y marginal CDFs are sufficient.
- Whether low-rank or separable adaptive density fields are viable.
- Whether bilinear interpolation on warped grids gives sufficiently smooth gradients.
- Whether the source-plane interpolation operator remains well-conditioned at low pixel counts.
- Whether adaptive rectangular grids can recover Delaunay-like effective resolution while remaining JAX-native.
- Whether gradients wrt adaptivity weights or lens parameters remain tractable.
- Whether the implementation can remain matrix-free or sparse-friendly.
Please also assess:
- likely bottlenecks,
- memory scaling,
- sparsity structure,
- curvature matrix structure,
- whether NNLS or positivity-constrained solves become dominant,
- and whether matrix-free iterative methods become preferable.
Do NOT spend time pursuing:
- pure JAX Delaunay triangulation,
- kNN interpolation variants,
- RBF/Wendland meshless methods,
- or dynamic topology approaches.
The current hypothesis is that:
"adaptive coordinates with fixed topology" may be the correct JAX-native formulation for adaptive source reconstruction.
Start by locating and understanding the existing adaptive rectangular implementation in detail before proposing modifications.