
refactor: AnalysisPoint auto-configures PointSolver when use_jax=True #466

@Jammy2211


Overview

Today, enabling JAX on a point-source Analysis requires two separate user actions: passing xp=jnp to PointSolver.for_grid AND passing use_jax=True to AnalysisPoint. This is inconsistent with AnalysisImaging / AnalysisInterferometer, where use_jax=True on the Analysis is the only knob the user touches.

The fix refactors PointSolver to be stateless with respect to the array module. xp flows as a per-call argument to .solve(), owned by AnalysisPoint.fit_from — matching the AnalysisImaging → FitImaging pattern. Users write PointSolver.for_grid(grid, ...) without xp= and AnalysisPoint(..., use_jax=True) handles the rest.

Design note (scope expansion from original)

An earlier, simpler plan ("flip solver.use_jax = True post-hoc in AnalysisPoint.__init__") was ruled out during code review. AbstractSolver.initial_triangles is constructed eagerly, as CoordinateArrayTriangles (JAX) or CoordinateArrayTrianglesNp (numpy) depending on xp. Flipping the flag after construction would leave the solver holding numpy triangles while _xp returned jnp, silently corrupting state. Making the solver genuinely stateless with respect to xp is the only correctness-preserving design.
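A toy reproduction of the rejected post-hoc flip makes the failure mode concrete. Everything in it is hypothetical: EagerSolver, TrianglesNp, TrianglesJax and the np_like / jnp_like namespaces are stand-ins for the pre-refactor AbstractSolver, CoordinateArrayTrianglesNp, CoordinateArrayTriangles and the real array modules. Only the shape of the bug is modelled, not the library code.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the numpy and jax.numpy modules.
np_like = SimpleNamespace(name="numpy")
jnp_like = SimpleNamespace(name="jax.numpy")


class TrianglesNp:
    """Stand-in for CoordinateArrayTrianglesNp."""


class TrianglesJax:
    """Stand-in for CoordinateArrayTriangles."""


class EagerSolver:
    """Mimics the pre-refactor design: triangles are built once, in __init__."""

    def __init__(self, use_jax=False):
        self.use_jax = use_jax
        self.initial_triangles = TrianglesJax() if use_jax else TrianglesNp()

    @property
    def _xp(self):
        return jnp_like if self.use_jax else np_like


solver = EagerSolver(use_jax=False)
solver.use_jax = True  # the rejected "flip it post-hoc" plan

print(solver._xp.name)                          # jax.numpy
print(type(solver.initial_triangles).__name__)  # TrianglesNp (stale)
```

The flag and the eagerly built triangles now disagree, and nothing raises; this is the silent state corruption the design note describes.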

Plan

  • PointSolver stops storing initial_triangles. Instead it stores the geometric limits (y_min, y_max, x_min, x_max, scale) as primitives.
  • At .solve(xp=...) time, the correct triangle class is built from those limits. xp threads through _filter_low_magnification, tracer.deflections_between_planes_from, and the final aa.Grid2DIrregular construction.
  • AnalysisPoint.fit_from passes xp=self._xp into the solver via the existing FitPointDataset → FitPositions* chain.
  • PointSolver.for_grid(..., xp=) drops the xp parameter entirely. This change is not backwards-compatible.
  • All ~38 workspace call sites that pass xp=jnp are cleaned up in the same coordinated release.
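The first two Plan bullets can be sketched as follows: the solver keeps only plain geometric floats and picks the triangle class each time .solve(xp=...) is called. The class names and the `xp is jnp_like` dispatch are hypothetical stand-ins; the real _initial_triangles helper may select the class differently.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for numpy / jax.numpy; the real code receives the modules.
np_like = SimpleNamespace(name="numpy")
jnp_like = SimpleNamespace(name="jax.numpy")


class _TrianglesBase:
    def __init__(self, y_min, y_max, x_min, x_max, scale):
        self.limits = (y_min, y_max, x_min, x_max)
        self.scale = scale


class TrianglesNp(_TrianglesBase):
    """Stand-in for CoordinateArrayTrianglesNp."""


class TrianglesJax(_TrianglesBase):
    """Stand-in for CoordinateArrayTriangles."""


class StatelessSolver:
    """Stores only geometric primitives; no array module is held on self."""

    def __init__(self, y_min, y_max, x_min, x_max, scale):
        self.y_min, self.y_max = y_min, y_max
        self.x_min, self.x_max = x_min, x_max
        self.scale = scale

    def _initial_triangles(self, xp):
        # Choose the triangle class from the array module passed at call time.
        cls = TrianglesJax if xp is jnp_like else TrianglesNp
        return cls(self.y_min, self.y_max, self.x_min, self.x_max, self.scale)

    def solve(self, xp):
        triangles = self._initial_triangles(xp)
        # ... the iterative subdivision / ray-tracing loop would run here ...
        return triangles


solver = StatelessSolver(-3.0, 3.0, -3.0, 3.0, scale=0.5)
print(type(solver.solve(xp=np_like)).__name__)   # TrianglesNp
print(type(solver.solve(xp=jnp_like)).__name__)  # TrianglesJax
```

Because nothing array-module-specific is stored, the same instance can serve a numpy call and a JAX call, which is exactly the property the post-hoc flip could not provide.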
Detailed implementation plan

Affected Repositories

  • Jammy2211/PyAutoLens — library refactor
  • Jammy2211/autolens_workspace — workspace cleanup (~N call sites)
  • PyAutoLabs/autolens_workspace_developer — workspace cleanup (~N call sites)
  • Jammy2211/autolens_workspace_test — workspace cleanup (~N call sites)

All four repos are on feature/point-solver-auto-jax in the task worktree at ~/Code/PyAutoLabs-wt/point-solver-auto-jax/.

Work Classification

Library + workspace — coordinated ship. Library PR gates workspace PRs.

Implementation Steps

1. autolens/point/solver/shape_solver.py: AbstractSolver

  • __init__ accepts y_min, y_max, x_min, x_max, scale, pixel_scale_precision, magnification_threshold, neighbor_degree (drop initial_triangles and xp params).
  • Store geometry primitives on self. Remove use_jax / _xp attributes entirely.
  • Add _initial_triangles(xp) helper that builds the correct triangle class on demand.
  • for_grid / for_limits_and_scale factories drop the xp kwarg.
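A minimal sketch of an xp-free factory in the spirit of step 1. It assumes the solver limits are simply the grid's extent and that scale is the grid's pixel scale; SolverGeometry and this for_grid function are hypothetical illustrations, not the real PointSolver.for_grid.

```python
from dataclasses import dataclass


@dataclass
class SolverGeometry:
    """Hypothetical container for the primitives the refactored __init__ stores."""
    y_min: float
    y_max: float
    x_min: float
    x_max: float
    scale: float
    pixel_scale_precision: float
    magnification_threshold: float = 0.1


def for_grid(grid, pixel_scale, pixel_scale_precision, magnification_threshold=0.1):
    """Sketch of an xp-free factory: only plain floats are derived from the grid.

    Assumes `grid` is an iterable of (y, x) pairs; the real factory may differ.
    """
    ys = [y for y, _ in grid]
    xs = [x for _, x in grid]
    return SolverGeometry(
        y_min=min(ys), y_max=max(ys),
        x_min=min(xs), x_max=max(xs),
        scale=pixel_scale,
        pixel_scale_precision=pixel_scale_precision,
        magnification_threshold=magnification_threshold,
    )


grid = [(-1.0, -1.0), (-1.0, 1.0), (1.0, -1.0), (1.0, 1.0)]
geometry = for_grid(grid, pixel_scale=2.0, pixel_scale_precision=0.001)
print(geometry.y_min, geometry.x_max, geometry.scale)  # -1.0 1.0 2.0
```

The factory signature has no xp keyword, so there is nothing for a user to get wrong at construction time.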

2. AbstractSolver.solve_triangles and PointSolver.solve (point_solver.py)

  • Accept required xp param.
  • Call self._initial_triangles(xp) at the top of the solve loop.
  • Thread xp through _filter_low_magnification, tracer.deflections_between_planes_from(xp=...), and the final aa.Grid2DIrregular(..., xp=xp).
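Step 2's threading can be illustrated with recording stand-ins that show xp reaching the magnification filter, the tracer call and the final grid construction. RecordingTracer, Grid2DIrregularStub, the simplified filter and the solve signature are all hypothetical; the real methods take more arguments than shown.

```python
from types import SimpleNamespace

np_like = SimpleNamespace(name="numpy")       # hypothetical stand-in
jnp_like = SimpleNamespace(name="jax.numpy")  # hypothetical stand-in


class RecordingTracer:
    """Stand-in tracer that records which array module it was called with."""

    def __init__(self):
        self.seen_xp = []

    def deflections_between_planes_from(self, grid, xp):
        self.seen_xp.append(xp)
        return [(0.0, 0.0) for _ in grid]


class Grid2DIrregularStub:
    """Stand-in for aa.Grid2DIrregular(values, xp=...)."""

    def __init__(self, values, xp):
        self.values = values
        self.xp = xp


def _filter_low_magnification(points, magnifications, threshold, xp):
    # xp is accepted so that array ops inside can use the caller's module.
    return [p for p, m in zip(points, magnifications) if abs(m) > threshold]


def solve(tracer, points, magnifications, threshold, xp):
    """Sketch of the call-time xp flow described in step 2."""
    points = _filter_low_magnification(points, magnifications, threshold, xp=xp)
    # Deflections would feed the refinement loop, which is omitted here.
    _deflections = tracer.deflections_between_planes_from(points, xp=xp)
    return Grid2DIrregularStub(points, xp=xp)


tracer = RecordingTracer()
result = solve(tracer, [(0.1, 0.2), (1.0, 1.0)], [5.0, 0.01], threshold=0.1, xp=jnp_like)
print(len(result.values), result.xp.name, tracer.seen_xp[0].name)  # 1 jax.numpy jax.numpy
```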

3. autolens/point/fit/positions/image/abstract.py and autolens/point/fit/positions/source/*

  • FitPositionsImagePairAll.model_data, FitPositionsImagePairRepeat.model_data, FitPositionsSource.model_data — each calls self.solver.solve(xp=self._xp).

4. autolens/point/fit/dataset.py: FitPointDataset

  • Ensure xp threads from the constructor into the FitPositions* children.

5. autolens/point/model/analysis.py: AnalysisPoint

  • Verify fit_from already passes xp=self._xp to FitPointDataset.
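Steps 3 to 5 together describe one ownership chain: the Analysis derives xp from use_jax, and each layer forwards it until model_data calls solver.solve(xp=...). Every class below is a hypothetical stub named after the real class it imitates; the real constructors take more arguments.

```python
from types import SimpleNamespace

np_like = SimpleNamespace(name="numpy")       # hypothetical stand-in
jnp_like = SimpleNamespace(name="jax.numpy")  # hypothetical stand-in


class SolverStub:
    def solve(self, xp):
        return f"positions via {xp.name}"


class FitPositionsStub:
    """Stand-in for FitPositionsImagePairAll / ...Repeat / FitPositionsSource."""

    def __init__(self, solver, xp):
        self.solver = solver
        self._xp = xp

    @property
    def model_data(self):
        return self.solver.solve(xp=self._xp)


class FitPointDatasetStub:
    """Stand-in for FitPointDataset: forwards xp to its FitPositions* child."""

    def __init__(self, solver, xp):
        self.positions = FitPositionsStub(solver=solver, xp=xp)


class AnalysisPointStub:
    """Stand-in for AnalysisPoint: use_jax is the only switch the user sets."""

    def __init__(self, solver, use_jax=False):
        self.solver = solver
        self.use_jax = use_jax

    @property
    def _xp(self):
        return jnp_like if self.use_jax else np_like

    def fit_from(self):
        return FitPointDatasetStub(solver=self.solver, xp=self._xp)


analysis = AnalysisPointStub(solver=SolverStub(), use_jax=True)
print(analysis.fit_from().positions.model_data)  # positions via jax.numpy
```

This mirrors the AnalysisImaging → FitImaging pattern from the Overview: the solver object is shared and unchanged, and only the per-call argument differs.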

6. Workspace cleanup (autolens_workspace, autolens_workspace_developer, autolens_workspace_test)

  • Remove xp=jnp from every PointSolver.for_grid(...) call site.
  • Remove any redundant import jax.numpy as jnp that was only there for the solver.
  • Keep use_jax=True on AnalysisPoint.
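For the workspace cleanup, a small regex helper could locate and strip the keyword across the call sites. This is a hypothetical sketch, not tooling from the repositories: it assumes for_grid arguments contain no nested parentheses and that xp=jnp is never the first argument, so its output should still be reviewed by hand.

```python
import re

# Matches an `xp=jnp` keyword argument plus the comma/whitespace before it,
# e.g. ",\n    xp=jnp" in a multi-line call. A trailing comma is left intact.
XP_KWARG = re.compile(r",\s*xp\s*=\s*jnp\b")
JNP_IMPORT = re.compile(r"^import jax\.numpy as jnp\n", re.MULTILINE)


def strip_solver_xp(source: str) -> str:
    """Remove `xp=jnp` from every PointSolver.for_grid(...) call in `source`."""
    def clean(match):
        return XP_KWARG.sub("", match.group(0))

    return re.sub(r"PointSolver\.for_grid\([^()]*\)", clean, source)


def drop_unused_jnp_import(source: str) -> str:
    """Drop `import jax.numpy as jnp` only if jnp is no longer referenced."""
    body = JNP_IMPORT.sub("", source)
    if re.search(r"\bjnp\b", body):
        return source  # jnp is still used elsewhere; keep the import
    return body


script = (
    "import jax.numpy as jnp\n"
    "solver = al.PointSolver.for_grid(\n"
    "    grid=grid,\n"
    "    magnification_threshold=0.1,\n"
    "    xp=jnp,\n"
    ")\n"
)
print(drop_unused_jnp_import(strip_solver_xp(script)))
```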

Tests

  • test_autolens/point/triangles/test_solver.py — the same PointSolver instance called as .solve(xp=np) and .solve(xp=jnp) yields equivalent magnification / positions within tolerance.
  • test_autolens/point/triangles/test_jax.py — update to use call-time xp.
  • test_autolens/point/model/test_analysis_point.py — regression: AnalysisPoint(dataset, solver, use_jax=True) with solver = PointSolver.for_grid(grid, ...) (no xp=) produces a JAX-traceable fit. Builds a real PointSolver (not MockPointSolver) for this one test.
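The shape of the equivalence test in test_solver.py can be sketched without numpy or JAX installed. The solver here is a toy 1D point-mass lens, not autolens, and the two backends are hypothetical namespaces exposing sqrt; the real test would call the same PointSolver instance with xp=np and xp=jnp. The toy uses the point-mass lens equation, with images at theta = (beta ± sqrt(beta² + 4 theta_e²)) / 2 and magnification mu = 1 / (1 - (theta_e / theta)⁴).

```python
import math
from types import SimpleNamespace

# Hypothetical backends: both expose sqrt, standing in for numpy and jax.numpy.
backend_np = SimpleNamespace(name="numpy", sqrt=math.sqrt)
backend_jnp = SimpleNamespace(name="jax.numpy", sqrt=lambda v: v ** 0.5)


class ToyPointMassSolver:
    """Toy 1D point-mass lens (not autolens)."""

    def __init__(self, theta_e):
        self.theta_e = theta_e  # only geometry is stored, never an array module

    def solve(self, beta, xp):
        root = xp.sqrt(beta ** 2 + 4.0 * self.theta_e ** 2)
        positions = ((beta + root) / 2.0, (beta - root) / 2.0)
        magnifications = tuple(1.0 / (1.0 - (self.theta_e / t) ** 4) for t in positions)
        return positions, magnifications


def test_same_instance_agrees_across_backends():
    solver = ToyPointMassSolver(theta_e=1.0)
    pos_np, mag_np = solver.solve(beta=0.3, xp=backend_np)
    pos_jnp, mag_jnp = solver.solve(beta=0.3, xp=backend_jnp)
    for a, b in zip(pos_np + mag_np, pos_jnp + mag_jnp):
        assert math.isclose(a, b, rel_tol=1e-9)


test_same_instance_agrees_across_backends()
```

Comparing within a tolerance rather than exactly matters in the real test, since numpy and JAX (which defaults to 32-bit floats) need not produce bit-identical results.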

Key Files

  • autolens/point/solver/shape_solver.py
  • autolens/point/solver/point_solver.py
  • autolens/point/fit/positions/image/abstract.py
  • autolens/point/fit/positions/source/*
  • autolens/point/fit/dataset.py
  • test_autolens/point/triangles/test_solver.py
  • test_autolens/point/model/test_analysis_point.py
  • Workspace call sites across all three workspace repos (~38)

Out of Scope

  • Pytree registration of FitPointDataset / FitPositions* — tracked by fit_point_pytree.md.
  • Source-plane Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation blocker.

Risk / trade-offs considered

  • Rebuilding the triangles on every .solve() call costs O(n_coords), which is negligible relative to ray-tracing.
  • The @cached_property on CoordinateArrayTrianglesNp.triangles no longer persists across .solve() calls; it only ever saved one access to the initial triangles anyway.
  • No backwards compatibility: library + all workspace PRs must ship together.

Original Prompt


Narrow subtask of admin_jammy/prompt/issued/fit_point_pytree.md:

The clunk in the point-source JAX API is that the user has to pass xp=jnp to PointSolver.for_grid AND pass use_jax=True to AnalysisPoint. Other Analysis classes only require use_jax=True. Example of the clunk today:

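```python
import autolens as al
import jax.numpy as jnp

# grid and dataset are assumed to be defined earlier in the script.
solver_jax = al.PointSolver.for_grid(
    grid=grid,
    pixel_scale_precision=0.001,
    magnification_threshold=0.1,
    xp=jnp,
)
analysis = al.AnalysisPoint(
    dataset=dataset,
    solver=solver_jax,
    use_jax=True,
)
```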

Fix the API so passing use_jax=True to AnalysisPoint is sufficient — the solver is auto-configured. No user import of jnp required.
