feat: multi-component density CDF warp for adaptive rectangular source mesh #322

@Jammy2211

Overview

The existing JAX-compatible adaptive rectangular source-plane mesh in PyAutoArray
warps source-plane coordinates through a separable per-axis empirical CDF
(implemented as a degree-11 polynomial + cubic-Hermite spline inverse in
InterpolatorRectangularSpline). Scientifically this works, but it recovers detailed
source structure only at high resolution (~4000+ pixels). This task investigates
whether a multi-component, physically-motivated density field driving the same
CDF warp can match Delaunay-like effective resolution at 500–1500 pixels while
preserving fixed rectangular topology, fixed array shapes, full JAX/JIT
compatibility, and differentiability.

The hypothesis: the adaptive rectangular topology is not the problem — pixel
distribution is. Replacing the single density signal (point-density or
adapt-image weights) with a weighted sum of magnification, source-brightness,
residual-gradient, and caustic-proximity density bases should concentrate pixels
where they buy scientific information, without ever changing topology.
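
To make the density-to-CDF step concrete, here is a minimal NumPy sketch of a weighted per-axis empirical CDF. It is not the library's implementation (which, per the overview, uses a polynomial fit with a cubic-Hermite spline inverse). The function name weighted_cdf, the synthetic points, and the peaked weight map are all hypothetical and chosen only for illustration.

```python
import numpy as np

def weighted_cdf(coords, weights, query):
    """Weighted empirical CDF of 1D `coords`, evaluated at `query` (piecewise linear)."""
    order = np.argsort(coords)
    cum = np.cumsum(weights[order])          # cumsum-after-sort: per-point weights
    return np.interp(query, coords[order], cum / cum[-1])

rng = np.random.default_rng(0)
# Hypothetical traced source-plane points, clustered around the origin.
points = rng.normal(0.0, 0.3, size=(2000, 2))
x = points[:, 1]

uniform = np.ones(len(points))
# Illustrative weight map that up-weights points close to the origin.
peaked = 1.0 + 20.0 * np.exp(-np.sum(points**2, axis=1) / 0.02)

# Fraction of the warped unit interval given to the strip |x| < 0.1.
span_uniform = weighted_cdf(x, uniform, 0.1) - weighted_cdf(x, uniform, -0.1)
span_peaked = weighted_cdf(x, peaked, 0.1) - weighted_cdf(x, peaked, -0.1)
print(f"uniform: {span_uniform:.3f}  peaked: {span_peaked:.3f}")
```

A larger span means more of the fixed rectangular grid is spent on that strip. This is the only lever the warp has, which is why the choice of weight map matters so much at low pixel counts.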

Plan

  • Audit the existing CDF implementation end-to-end and document what the current single-signal warp actually does, including the per-axis separability assumption and the bilinear interpolation operator's conditioning at low pixel counts.
  • Define a composable density-component interface so multiple weighted bases (floor + w1·rho_mag + w2·rho_brightness + w3·rho_residual + w4·rho_caustic) can drive the same CDF transform with no change to topology or array shapes.
  • Implement candidate density bases as JAX-pure functions: magnification density (from deflection-field Jacobian), source-brightness density (from current reconstruction), residual-gradient density (from data-model residuals), caustic-proximity density (from |μ⁻¹| → 0 surfaces).
  • Investigate the separability assumption — assess whether per-axis marginal CDFs of a multi-component density are sufficient, or whether low-rank / outer-product factorisations recover anisotropic concentration without breaking JAX.
  • Build a developer-side benchmark comparing reconstruction quality vs Delaunay at matched compute budgets across a small grid of pixel counts (250, 500, 1000, 1500, 4000).
  • Assess bottlenecks: bilinear operator conditioning at low N, sparsity / curvature-matrix structure, whether NNLS or positivity-constrained solves dominate, whether matrix-free CG becomes preferable.
  • Decide on a shipping artefact: if any combination wins, propose a new mesh class (e.g. RectangularMultiComponentAdapt) alongside the existing variants; do not modify the working spline classes.

Detailed implementation plan

Affected Repositories

  • PyAutoArray (primary, library)
  • (potential follow-up) autolens_workspace_developer — benchmark / developer script, separate task

Work Classification

Library (research + prototype). Workspace developer script is a follow-up task once a candidate basis combination demonstrates wins.

Branch Survey

Repository      Current Branch   Dirty?
./PyAutoArray   main             clean

Suggested branch: feature/rectangular-adapt-cdf
Worktree root: ~/Code/PyAutoLabs-wt/rectangular-adapt-cdf/ (created later by /start_library)

Phase 1 — Audit existing CDF (no code changes)

In notes that accompany autoarray/inversion/mesh/interpolator/rectangular_spline.py
(a markdown file under PyAutoArray/files/ is fine; do not pollute the package),
document exactly:

  • How mesh_weight_map flows from RectangularSplineAdaptImage.mesh_weight_map_from
    through create_transforms_spline into the per-axis CDF (_build_inv_poly_*).
  • The two normalisation steps: (mean / min-std) in
    adaptive_rectangular_transformed_grid_from_spline and the inner unit-square
    rescale.
  • The bilinear scatter via (N-3) * transform(scaled) + 1 and the
    MeshGeometryRectangular pcolormesh path.
  • Confirm whether mesh_weight_map is genuinely treated as a per-point weight
    (the cumsum-after-sort in _build_inv_poly_jax_impl says yes).
  • Establish baseline measurements at 500 / 1000 / 4000 source pixels on one
    reference dataset (existing test_autolens / autolens_workspace example) — record
    reconstruction χ², log-evidence, peak per-pixel residual.
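
As an aid for the audit, the sketch below shows one possible reading of the (N-3) * transform(scaled) + 1 expression: a warped coordinate in [0, 1] becomes a fractional pixel index in [1, N-2], which is then split into four bilinear weights. That N is the number of pixels per side, the row-major flattening, and the helper name bilinear_indices_and_weights are all assumptions to be checked against the real code, not a description of it.

```python
import numpy as np

def bilinear_indices_and_weights(t_y, t_x, n_side):
    """Four neighbouring flat pixel indices and bilinear weights for warped coords in [0, 1]."""
    # Fractional pixel index in [1, n_side - 2] (assumed reading of the expression).
    f_y = (n_side - 3) * t_y + 1.0
    f_x = (n_side - 3) * t_x + 1.0
    # Lower corner of the enclosing cell, clipped so the +1 neighbour always exists.
    i_y = np.clip(np.floor(f_y).astype(int), 0, n_side - 2)
    i_x = np.clip(np.floor(f_x).astype(int), 0, n_side - 2)
    d_y, d_x = f_y - i_y, f_x - i_x
    indices = np.stack(
        [i_y * n_side + i_x, i_y * n_side + i_x + 1,
         (i_y + 1) * n_side + i_x, (i_y + 1) * n_side + i_x + 1],
        axis=1,
    )
    weights = np.stack(
        [(1 - d_y) * (1 - d_x), (1 - d_y) * d_x, d_y * (1 - d_x), d_y * d_x],
        axis=1,
    )
    return indices, weights

rng = np.random.default_rng(1)
t = rng.uniform(0.0, 1.0, size=(5000, 2))   # stand-in for warped coordinates
n_side = 22                                  # roughly 500 source pixels
indices, weights = bilinear_indices_and_weights(t[:, 0], t[:, 1], n_side)

# Total weight landing on each pixel; pixels with zero total are unconstrained by the data.
hits = np.bincount(indices.ravel(), weights=weights.ravel(), minlength=n_side**2)
n_unconstrained = int(np.sum(hits == 0.0))
print("pixels receiving no weight:", n_unconstrained)
```

Counting pixels that receive no weight is a crude first diagnostic for the baseline measurements. It says nothing about near-degenerate columns, which would need the singular values of the mapping matrix.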

Phase 2 — Composable density-component interface

Add a small, pure-JAX module: autoarray/inversion/mesh/interpolator/density_components.py.

Each component is a callable (traced_points, context, xp) -> per_point_weight
where context carries whatever auxiliary fields the component needs
(e.g. magnification map, current reconstruction, residuals). Components must:

  • Return an array of shape (N,) that is strictly positive and finite; positivity can be enforced with xp.clip(w, eps, None).
  • Be JAX-traceable end-to-end; no scipy callbacks, no kNN, no dynamic shapes.
  • Be independently testable via a unit test in
    test_autoarray/inversion/pixelization/interpolator/test_density_components.py.

Define a composition helper:

def compose_density(components, weights, floor: float, xp=np):
    """rho = floor + sum_k weights[k] * components[k](traced_points, context, xp)."""

The composed weight feeds the existing mesh_weight_map slot in
create_transforms_spline — no change required to the spline machinery itself.
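
A minimal sketch of how the helper and two components might look follows. It assumes that traced_points and context are passed to compose_density explicitly (the signature above does not show them), that context is a plain dict, and that the adapt image is stored under an "adapt_image" key. None of these choices is fixed by this issue.

```python
import numpy as np

def uniform_density(traced_points, context, xp=np):
    """Constant weight for every traced point."""
    return xp.ones(traced_points.shape[0])

def brightness_density(traced_points, context, xp=np, eps=1e-8):
    """Per-point adapt-image weight; the "adapt_image" key is a hypothetical convention."""
    return xp.clip(context["adapt_image"], eps, None)

def compose_density(components, weights, floor, traced_points, context, xp=np, eps=1e-8):
    """rho = floor + sum_k weights[k] * components[k](traced_points, context, xp)."""
    rho = xp.full(traced_points.shape[0], float(floor))
    # Python loop over a static list: unrolled at trace time, so shapes stay fixed.
    for w, component in zip(weights, components):
        rho = rho + w * component(traced_points, context, xp)
    return xp.clip(rho, eps, None)

traced_points = np.zeros((4, 2))
context = {"adapt_image": np.array([0.0, 1.0, 2.0, 3.0])}
rho = compose_density(
    components=[uniform_density, brightness_density],
    weights=[0.5, 2.0],
    floor=0.1,
    traced_points=traced_points,
    context=context,
)
print(rho)
```

Because only array-module calls shared by NumPy and jax.numpy are used, the same functions should be usable with xp=jnp, although that has not been verified here.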

Phase 3 — Candidate density bases

  1. Magnification density (magnification_density_from): from |det(J)|⁻¹
    on the source plane; high-μ regions get higher weight. Requires the
    image-plane Jacobian which already exists in PyAutoGalaxy/PyAutoLens — for
    PyAutoArray-side testing we mock or accept a precomputed magnification map
    as part of context.
  2. Source-brightness density (brightness_density_from): the existing
    adapt-image path, refactored into a component so it composes uniformly.
  3. Residual-gradient density (residual_gradient_density_from): finite
    differences of (data − model) on the source plane; concentrates pixels where
    the current reconstruction is failing. Requires the previous-iteration
    reconstruction; first-iteration falls back to brightness or uniform.
  4. Caustic-proximity density (caustic_proximity_density_from): proximity
    to |μ⁻¹| → 0 surfaces; mathematically a smoothed inverse distance to the
    caustic curve. Requires the magnification field; same context dependency
    as (1).

For (1) and (4), avoid creating a hard cross-repo dependency — PyAutoArray must
not import autogalaxy/autolens. The component signature accepts the
magnification map as a precomputed input. The PyAutoGalaxy/Lens side will be
where this is wired in later (follow-up task).
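
The two magnification-driven bases could be sketched as below, taking a precomputed per-point magnification from context so that nothing from autogalaxy/autolens is imported. The "magnification", "mu_cap" and "caustic_width" keys and their defaults are hypothetical. The caustic term uses |μ⁻¹| itself as a stand-in for distance to the caustic, which is a simplification of the smoothed inverse distance described in (4).

```python
import numpy as np

def magnification_density_from(traced_points, context, xp=np, eps=1e-8):
    """Weight proportional to |mu|, capped so near-caustic points cannot dominate."""
    mu = xp.abs(context["magnification"])
    mu = xp.clip(mu, eps, context.get("mu_cap", 50.0))   # hypothetical cap
    return mu / xp.mean(mu)

def caustic_proximity_density_from(traced_points, context, xp=np, eps=1e-8):
    """Lorentzian in |1/mu|: close to 1 where |1/mu| -> 0, falling off smoothly elsewhere."""
    inv_mu = 1.0 / xp.clip(xp.abs(context["magnification"]), eps, None)
    width = context.get("caustic_width", 0.05)           # hypothetical smoothing scale
    return xp.clip(width**2 / (inv_mu**2 + width**2), eps, None)

traced_points = np.zeros((4, 2))
context = {"magnification": np.array([1.0, -2.0, 10.0, 1000.0])}
rho_mag = magnification_density_from(traced_points, context)
rho_caustic = caustic_proximity_density_from(traced_points, context)
print(rho_mag, rho_caustic)
```

Both functions are smooth in the magnification except at the clip boundaries, so gradients with respect to lens parameters would flow through the precomputed map rather than through these components.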

Phase 4 — Investigate separability / low-rank factorisation

The existing implementation marginalises per-axis. For a single Gaussian-like
density this is fine; for a multi-component density with anisotropic structure
(e.g. caustic-proximity along a curve), separability may lose information.

Investigate:

  • Whether the marginal x and y CDFs of the composite density recover sufficient
    anisotropic concentration when validated against a non-separable reference
    (full 2D CDF computed offline).
  • Whether a rank-1 outer-product factorisation rho(x,y) ≈ a(x) * b(y) + low-rank correction
    is enough.
  • Whether iterative proportional fitting (1–2 passes only, JAX-compatible)
    closes the gap without breaking jit shapes.

This is a measurement phase — no permanent API yet.
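
One way to quantify what separability loses is to compare a gridded 2D density with the outer product of its own marginals. The sketch below does this offline in NumPy for two synthetic densities: a Gaussian blob and a thin ring standing in for a caustic-like curve. The function names, grid size and test densities are illustrative assumptions only.

```python
import numpy as np

def separable_approximation(rho):
    """Outer product of the y and x marginals, normalised to the same total mass."""
    marg_y = rho.sum(axis=1)
    marg_x = rho.sum(axis=0)
    return np.outer(marg_y, marg_x) / rho.sum()

def separability_error(rho):
    """Relative Frobenius-norm error of the marginal (rank-1) approximation."""
    approx = separable_approximation(rho)
    return np.linalg.norm(rho - approx) / np.linalg.norm(rho)

y, x = np.mgrid[-1:1:128j, -1:1:128j]
blob = np.exp(-(x**2 + y**2) / 0.1)                      # exactly separable
ring = np.exp(-((np.hypot(x, y) - 0.6) ** 2) / 0.005)    # anisotropic, curve-like

err_blob = separability_error(blob)
err_ring = separability_error(ring)
print(f"blob: {err_blob:.2e}  ring: {err_ring:.2f}")
```

The blob is reproduced to rounding error, while the ring is not, because the product of marginals spreads mass across the ring's interior and corners. The singular values of the gridded density (np.linalg.svd) would indicate how many extra low-rank terms a correction needs.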

Phase 5 — Benchmark vs Delaunay (developer script, follow-up task)

Defer the benchmark script itself to a follow-up task in
autolens_workspace_developer. This issue only carries the library-side
prototype and unit tests. Open the follow-up prompt once the density-component
API has stabilised.

Phase 6 — Decide shipping artefact

If a candidate composition demonstrably matches Delaunay at 1000 pixels:

  • Promote it to a new mesh class RectangularMultiComponentAdapt (alongside,
    never replacing, RectangularSplineAdaptImage).
  • Wire it into pixelization.py and the mesh __init__.py re-exports.
  • Add to the existing test_rectangular_spline.py parity test suite.

If results are inconclusive: ship the density-component framework as a library
primitive (it has standalone value for further experimentation) and document
the negative findings.

Constraints (reiterated for reviewers)

  • Fixed rectangular topology, fixed array shapes, JAX/JIT-compatible.
  • Differentiable where practical.
  • No Delaunay, no kNN/Wendland, no RBF, no scipy callbacks, no dynamic topology.
  • Library unit tests stay numpy-only (per feedback_no_jax_in_unit_tests);
    cross-xp validation lives in workspace_test.

Key Files

  • PyAutoArray/autoarray/inversion/mesh/interpolator/rectangular_spline.py:
    current CDF implementation; the audit anchor.
  • PyAutoArray/autoarray/inversion/mesh/interpolator/rectangular.py:
    the linear-CDF baseline; useful for grad-jump comparisons.
  • PyAutoArray/autoarray/inversion/mesh/mesh/rectangular_spline_adapt_image.py
    and .../rectangular_spline_adapt_density.py: call sites for the
    composed-density mesh entry points.
  • (new) PyAutoArray/autoarray/inversion/mesh/interpolator/density_components.py:
    composable bases + composition helper.
  • (new) PyAutoArray/test_autoarray/inversion/pixelization/interpolator/test_density_components.py:
    unit tests for each component and the composer.
  • (potentially new, only if Phase 6 ships a class) .../mesh/mesh/rectangular_multi_component_adapt.py.

Out of scope

  • Any change to InterpolatorRectangularSpline arithmetic — it's working; we
    only feed it a richer weight map.
  • Any change to NNLS / curvature / regularisation code in this issue. If
    Phase 4 shows the bilinear operator becomes ill-conditioned at low N, file
    a separate prompt.
  • Workspace tutorial updates and the developer benchmark script — separate
    follow-up tasks.

Original Prompt

We have an existing JAX-compatible adaptive rectangular source-plane implementation already in the codebase. The current implementation uses a CDF-style adaptive coordinate transform where rectangular pixels become progressively smaller in regions of interest while preserving a fixed rectangular topology. The implementation works scientifically, but currently requires relatively high source resolutions (~4000+ pixels) to recover detailed source structure.

Your task is NOT to redesign this from scratch. Instead:

  1. Inspect the existing implementation carefully.
  2. Understand exactly how the current adaptive coordinate transform works.
  3. Build on top of the existing approach to investigate more sophisticated adaptive-density formulations that retain:
    • fixed rectangular topology,
    • fixed array shapes,
    • JAX/JIT compatibility,
    • differentiability where possible,
    • no Delaunay triangulation,
    • no scipy spatial callbacks,
    • no dynamic topology changes.

The key conceptual direction is:

Instead of adapting mesh connectivity (like Delaunay), adapt the coordinate system itself via smooth density-driven coordinate warps.

We believe this may preserve many advantages of adaptive Delaunay source planes (high effective resolution in important regions with relatively few pixels) while remaining far more accelerator/JAX friendly.

The current implementation likely already resembles:

density -> cumulative distribution -> adaptive rectangular edges

We now want to generalize this.

Please investigate architectures where the adaptive density field is constructed from multiple weighted components, for example:

rho(x,y) = floor
         + w1 * magnification_density
         + w2 * source_brightness_density
         + w3 * residual_gradient_density
         + w4 * caustic_proximity_density

Key goals:

  • Concentrate source-plane resolution where scientifically useful.
  • Keep total source pixel count relatively low (~500-1500 if possible).
  • Preserve full JAX compatibility.
  • Maintain smooth coordinate warps rather than topology changes.
  • Avoid the scientific/topological failure modes encountered with kNN/Wendland-style meshless interpolation.
  • Keep the implementation differentiable where practical.

Important:
The adaptive rectangular topology itself is NOT the problem. The likely problem is how intelligently the pixels are distributed.

Please specifically investigate:

  1. Whether multiple density bases can be combined cleanly.
  2. Whether separable x/y marginal CDFs are sufficient.
  3. Whether low-rank or separable adaptive density fields are viable.
  4. Whether bilinear interpolation on warped grids gives sufficiently smooth gradients.
  5. Whether the source-plane interpolation operator remains well-conditioned at low pixel counts.
  6. Whether adaptive rectangular grids can recover Delaunay-like effective resolution while remaining JAX-native.
  7. Whether gradients wrt adaptivity weights or lens parameters remain tractable.
  8. Whether the implementation can remain matrix-free or sparse-friendly.

Please also assess:

  • likely bottlenecks,
  • memory scaling,
  • sparsity structure,
  • curvature matrix structure,
  • whether NNLS or positivity-constrained solves become dominant,
  • and whether matrix-free iterative methods become preferable.

Do NOT spend time pursuing:

  • pure JAX Delaunay triangulation,
  • kNN interpolation variants,
  • RBF/Wendland meshless methods,
  • or dynamic topology approaches.

The current hypothesis is that:
"adaptive coordinates with fixed topology" may be the correct JAX-native formulation for adaptive source reconstruction.

Start by locating and understanding the existing adaptive rectangular implementation in detail before proposing modifications.
