Overview
Today, enabling JAX on a point-source Analysis requires two separate user actions: passing xp=jnp to PointSolver.for_grid AND passing use_jax=True to AnalysisPoint. This is inconsistent with AnalysisImaging / AnalysisInterferometer, where use_jax=True on the Analysis is the only knob the user touches.
The fix refactors PointSolver to be stateless with respect to the array module. xp flows as a per-call argument to .solve(), owned by AnalysisPoint.fit_from — matching the AnalysisImaging → FitImaging pattern. Users write PointSolver.for_grid(grid, ...) without xp= and AnalysisPoint(..., use_jax=True) handles the rest.
Design note (scope expansion from original)
An earlier, simpler plan ("flip solver.use_jax = True post-hoc in AnalysisPoint.__init__") was ruled out during code review. AbstractSolver.initial_triangles is constructed eagerly with CoordinateArrayTriangles (JAX) or CoordinateArrayTrianglesNp (numpy) depending on xp. Flipping the flag after construction would leave the solver holding numpy triangles while _xp returned jnp, silently corrupting state. Making the solver genuinely stateless with respect to xp is the only correctness-preserving design.
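The failure mode can be made concrete with a minimal sketch. It uses hypothetical stand-in classes: TrianglesNp, TrianglesJax and EagerSolver are illustrative only, not PyAutoLens code. The backend is reported as a string so the sketch runs without JAX installed.

```python
class TrianglesNp:
    """Hypothetical stand-in for CoordinateArrayTrianglesNp (numpy backend)."""


class TrianglesJax:
    """Hypothetical stand-in for CoordinateArrayTriangles (JAX backend)."""


class EagerSolver:
    """Hypothetical solver that, like the current design, picks its triangle
    class once, at construction time."""

    def __init__(self, use_jax: bool = False):
        self.use_jax = use_jax
        # Built eagerly: the backend choice is frozen from here on.
        self.initial_triangles = TrianglesJax() if use_jax else TrianglesNp()

    @property
    def _xp(self) -> str:
        # Reports the backend by name so the sketch needs no JAX install.
        return "jnp" if self.use_jax else "np"


solver = EagerSolver(use_jax=False)
solver.use_jax = True  # the rejected post-hoc flip

print(solver._xp)  # jnp
print(type(solver.initial_triangles).__name__)  # TrianglesNp (stale state)
```

After the flip, the two attributes disagree, and nothing raises. That silent mismatch is the corruption described above.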
Plan
- PointSolver stops storing initial_triangles. Instead it stores the geometric limits (y_min, y_max, x_min, x_max, scale) as primitives.
- At .solve(xp=...) time, the correct triangle class is built from those limits. xp threads through _filter_low_magnification, tracer.deflections_between_planes_from, and the final aa.Grid2DIrregular construction.
- AnalysisPoint.fit_from passes xp=self._xp into the solver via the existing FitPointDataset → FitPositions* chain.
- PointSolver.for_grid(..., xp=) drops the xp parameter entirely. Not backwards-compatible.
- All ~38 workspace call sites that pass xp=jnp are cleaned up in the same coordinated release.
Detailed implementation plan
Affected Repositories
- Jammy2211/PyAutoLens — library refactor
- Jammy2211/autolens_workspace — workspace cleanup (~N call sites)
- PyAutoLabs/autolens_workspace_developer — workspace cleanup (~N call sites)
- Jammy2211/autolens_workspace_test — workspace cleanup (~N call sites)
All four repos are on the feature/point-solver-auto-jax branch in the task worktree at ~/Code/PyAutoLabs-wt/point-solver-auto-jax/.
Work Classification
Library + workspace — coordinated ship. Library PR gates workspace PRs.
Implementation Steps
1. autolens/point/solver/shape_solver.py — AbstractSolver
- __init__ accepts y_min, y_max, x_min, x_max, scale, pixel_scale_precision, magnification_threshold, neighbor_degree (drop the initial_triangles and xp params).
- Store the geometry primitives on self. Remove the use_jax / _xp attributes entirely.
- Add a helper, _initial_triangles(xp), that builds the correct triangle class on demand.
- The for_grid / for_limits_and_scale factories drop the xp kwarg.
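The following is a minimal sketch of the refactored constructor and helper, under three assumptions. First, the triangle classes are hypothetical stand-ins for CoordinateArrayTrianglesNp / CoordinateArrayTriangles. Second, the factories are omitted, since they would simply forward the same primitives without an xp kwarg. Third, dispatch treats any module other than numpy as the JAX backend, which is a simplification.

```python
import numpy as np


class _Triangles:
    """Hypothetical stand-in for a triangle mesh over a rectangular region."""

    def __init__(self, y_min, y_max, x_min, x_max, scale):
        self.limits = (y_min, y_max, x_min, x_max)
        self.scale = scale


class TrianglesNp(_Triangles):
    """Stand-in for CoordinateArrayTrianglesNp."""


class TrianglesJax(_Triangles):
    """Stand-in for CoordinateArrayTriangles."""


class AbstractSolver:
    def __init__(
        self,
        y_min, y_max, x_min, x_max, scale,
        pixel_scale_precision, magnification_threshold, neighbor_degree,
    ):
        # Only plain primitives are stored: no initial_triangles, no use_jax,
        # no _xp. The instance is therefore backend-agnostic.
        self.y_min, self.y_max = y_min, y_max
        self.x_min, self.x_max = x_min, x_max
        self.scale = scale
        self.pixel_scale_precision = pixel_scale_precision
        self.magnification_threshold = magnification_threshold
        self.neighbor_degree = neighbor_degree

    def _initial_triangles(self, xp):
        # Choose the mesh class from the array module supplied at call time.
        # Simplification: anything other than numpy is treated as JAX.
        triangles_cls = TrianglesNp if xp is np else TrianglesJax
        return triangles_cls(
            self.y_min, self.y_max, self.x_min, self.x_max, self.scale
        )


solver = AbstractSolver(
    -2.0, 2.0, -2.0, 2.0, 0.5,
    pixel_scale_precision=0.001, magnification_threshold=0.1, neighbor_degree=1,
)
print(type(solver._initial_triangles(np)).__name__)  # TrianglesNp
```

The instance holds only floats, so the same solver can be handed to a numpy fit and a JAX fit with no state to keep in sync.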
2. AbstractSolver.solve_triangles and PointSolver.solve (point_solver.py)
- Accept a required xp param.
- Call self._initial_triangles(xp) at the top of the solve loop.
- Thread xp through _filter_low_magnification, tracer.deflections_between_planes_from(xp=...), and the final aa.Grid2DIrregular(..., xp=xp).
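The threading in step 2 can be sketched with a toy backend-agnostic solve. This is not the triangle-subdivision algorithm:

- ToySolver, ToyTracer and magnification_from are hypothetical.
- The "mesh" is just a regular vertex grid.
- The lens is a uniform mass sheet (deflection kappa * theta, magnification 1 / (1 - kappa)**2), chosen so the numbers are easy to check.

What the sketch does show is every array operation going through the xp passed to solve().

```python
import numpy as np


class ToyTracer:
    """Hypothetical tracer: a uniform mass sheet of convergence kappa."""

    def __init__(self, kappa=0.5):
        self.kappa = kappa

    def deflections_between_planes_from(self, grid, xp=np):
        # For a mass sheet the deflection is proportional to position.
        return self.kappa * grid

    def magnification_from(self, grid, xp=np):
        # mu = 1 / (1 - kappa)**2 everywhere for a mass sheet.
        return xp.full(grid.shape[0], 1.0 / (1.0 - self.kappa) ** 2)


class ToySolver:
    def __init__(self, y_min, y_max, x_min, x_max, scale, magnification_threshold):
        self.y_min, self.y_max = y_min, y_max
        self.x_min, self.x_max = x_min, x_max
        self.scale = scale
        self.magnification_threshold = magnification_threshold

    def _initial_triangles(self, xp):
        # Stand-in for the initial mesh: regular (y, x) vertices built with xp.
        ys = xp.arange(self.y_min, self.y_max + self.scale / 2, self.scale)
        xs = xp.arange(self.x_min, self.x_max + self.scale / 2, self.scale)
        yy, xx = xp.meshgrid(ys, xs, indexing="ij")
        return xp.stack([yy.ravel(), xx.ravel()], axis=-1)

    def _filter_low_magnification(self, tracer, points, xp):
        magnifications = tracer.magnification_from(points, xp=xp)
        return points[xp.abs(magnifications) > self.magnification_threshold]

    def solve(self, tracer, source_plane_coordinate, xp):
        # Rebuilt on every call, so the instance never holds backend state.
        grid = self._initial_triangles(xp)
        deflections = tracer.deflections_between_planes_from(grid, xp=xp)
        source_plane = grid - deflections
        offsets = source_plane - xp.asarray(source_plane_coordinate)
        near = xp.sqrt(xp.sum(offsets**2, axis=-1)) < self.scale
        candidates = self._filter_low_magnification(tracer, grid[near], xp)
        # The real code would wrap this in aa.Grid2DIrregular(..., xp=xp).
        return xp.asarray(candidates)


solver = ToySolver(-2.0, 2.0, -2.0, 2.0, scale=0.5, magnification_threshold=0.1)
positions = solver.solve(ToyTracer(kappa=0.5), (0.0, 0.0), xp=np)
print(positions.shape)  # (9, 2)
```

Take kappa = 0.5 and a source at the origin. The nine grid vertices nearest the origin trace to within one scale of the source, and all of them have magnification 4, so none are filtered out.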
3. autolens/point/fit/positions/image/abstract.py and autolens/point/fit/positions/source/*
- FitPositionsImagePairAll.model_data, FitPositionsImagePairRepeat.model_data and FitPositionsSource.model_data each call self.solver.solve(xp=self._xp).
4. autolens/point/fit/dataset.py — FitPointDataset
- Ensure xp threads from the constructor into the FitPositions* children.
5. autolens/point/model/analysis.py — AnalysisPoint
- Verify that fit_from already passes xp=self._xp to FitPointDataset.
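Together, steps 3 to 5 move ownership of xp to the Analysis. The sketch below strips that chain down, using hypothetical Toy* classes whose constructors are far simpler than the real AnalysisPoint, FitPointDataset and FitPositions* signatures. The recording solver just logs which array module it receives.

```python
import numpy as np


class RecordingSolver:
    """Hypothetical solver that records which array module it was handed."""

    def __init__(self):
        self.seen_xp = []

    def solve(self, xp):
        self.seen_xp.append(xp)
        return xp.zeros((1, 2))


class ToyFitPositions:
    def __init__(self, solver, xp=np):
        self.solver = solver
        self._xp = xp

    @property
    def model_data(self):
        # Step 3: the fit forwards its own xp to the solver on every access.
        return self.solver.solve(xp=self._xp)


class ToyFitPointDataset:
    def __init__(self, solver, xp=np):
        # Step 4: xp flows from the constructor into the FitPositions* child.
        self.positions = ToyFitPositions(solver=solver, xp=xp)


class ToyAnalysisPoint:
    def __init__(self, solver, use_jax=False):
        self.solver = solver
        self.use_jax = use_jax

    @property
    def _xp(self):
        if self.use_jax:
            import jax.numpy as jnp  # imported only when JAX is requested

            return jnp
        return np

    def fit_from(self):
        # Step 5: the Analysis owns the backend choice.
        return ToyFitPointDataset(solver=self.solver, xp=self._xp)


solver = RecordingSolver()
fit = ToyAnalysisPoint(solver=solver, use_jax=False).fit_from()
fit.positions.model_data
print(solver.seen_xp == [np])  # True
```

Nothing in the user-facing construction mentions xp. With use_jax=True, the same _xp property would hand jax.numpy down the chain instead. That path is not exercised here, so the sketch runs without JAX.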
6. Workspace cleanup (autolens_workspace, autolens_workspace_developer, autolens_workspace_test)
- Remove xp=jnp from every PointSolver.for_grid(...) call site.
- Remove any redundant import jax.numpy as jnp that was only there for the solver.
- Keep use_jax=True on AnalysisPoint.
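For the workspace cleanup in step 6, here is a sketch of a line-based rewriter. strip_solver_xp is a hypothetical helper, not an existing tool. It assumes the xp=jnp kwarg sits on its own line inside a multi-line PointSolver.for_grid( call, as in the example under Original Prompt. It drops import jax.numpy as jnp only when no other jnp reference remains. Single-line calls are left for manual review.

```python
import re

_FOR_GRID_OPEN = re.compile(r"PointSolver\.for_grid\(")
_XP_KWARG_LINE = re.compile(r"^\s*xp\s*=\s*jnp\s*,?\s*$")
_JNP_IMPORT_LINE = re.compile(r"^\s*import\s+jax\.numpy\s+as\s+jnp\s*$")


def strip_solver_xp(source: str) -> str:
    """Drop own-line `xp=jnp` kwargs inside PointSolver.for_grid(...) calls,
    then drop `import jax.numpy as jnp` if nothing else references jnp."""
    out = []
    depth = 0  # open-paren depth inside a for_grid(...) call; 0 = outside
    for line in source.splitlines(keepends=True):
        if depth == 0 and _FOR_GRID_OPEN.search(line):
            depth = line.count("(") - line.count(")")
        elif depth > 0:
            if _XP_KWARG_LINE.match(line):
                continue  # removing this line cannot change paren depth
            depth += line.count("(") - line.count(")")
        out.append(line)

    # Remove the import only if no other line still uses jnp.
    without_import = [ln for ln in out if not _JNP_IMPORT_LINE.match(ln)]
    if not any(re.search(r"\bjnp\b", ln) for ln in without_import):
        out = without_import
    return "".join(out)
```

Run it over each repo and review the diff. A line-based pass cannot see calls assembled through **kwargs or through aliases of PointSolver.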
Tests
- test_autolens/point/triangles/test_solver.py: the same PointSolver instance called as .solve(xp=np) and .solve(xp=jnp) yields equivalent magnifications / positions within tolerance.
- test_autolens/point/triangles/test_jax.py: update to use call-time xp.
- test_autolens/point/model/test_analysis_point.py: regression test that AnalysisPoint(dataset, solver, use_jax=True), with solver = PointSolver.for_grid(grid, ...) (no xp=), produces a JAX-traceable fit. This one test builds a real PointSolver rather than a MockPointSolver.
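Here is a sketch of the shape of the equivalence test in test_solver.py. A hypothetical ToySolver stands in for PointSolver so that the sketch stays self-contained. It reuses one instance across backends and exercises only the numpy path when JAX is not installed. The tolerances are illustrative, loose enough to absorb JAX's default float32 precision.

```python
import numpy as np

BACKENDS = [np]
try:
    import jax.numpy as jnp

    BACKENDS.append(jnp)
except ImportError:  # JAX not installed: only the numpy path is exercised
    pass


class ToySolver:
    """Hypothetical stateless solver: nothing backend-specific is stored."""

    def __init__(self, y_min, y_max, scale):
        self.y_min, self.y_max, self.scale = y_min, y_max, scale

    def solve(self, xp):
        # Rebuilt from primitives on every call with the caller's xp.
        ys = xp.arange(self.y_min, self.y_max + self.scale / 2, self.scale)
        return ys**2


def test_same_instance_matches_across_backends():
    solver = ToySolver(y_min=-1.0, y_max=1.0, scale=0.25)
    reference = np.asarray(solver.solve(xp=np))
    for xp in BACKENDS:
        # The same instance is reused for every backend.
        result = np.asarray(solver.solve(xp=xp))
        np.testing.assert_allclose(result, reference, rtol=1e-5, atol=1e-6)
```

In the real test, ToySolver would be replaced by PointSolver.for_grid(grid, ...), and the comparison would cover magnifications and positions.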
Key Files
- autolens/point/solver/shape_solver.py
- autolens/point/solver/point_solver.py
- autolens/point/fit/positions/image/abstract.py
- autolens/point/fit/positions/source/*
- autolens/point/fit/dataset.py
- test_autolens/point/triangles/test_solver.py
- test_autolens/point/model/test_analysis_point.py
- Workspace call sites across all three workspace repos (~38)
Out of Scope
- Pytree registration of FitPointDataset / FitPositions* (tracked by fit_point_pytree.md).
- The source-plane Grid2DIrregular.grid_2d_via_deflection_grid_from xp-propagation blocker.
Risk / trade-offs considered
- Rebuilding the triangles on every .solve() call costs O(n_coords), which is negligible relative to the cost of ray-tracing.
- The @cached_property on CoordinateArrayTrianglesNp.triangles is lost; it only ever saved one access to the initial triangles.
- No backwards compatibility: the library PR and all workspace PRs must ship together.
Original Prompt
Narrow subtask of admin_jammy/prompt/issued/fit_point_pytree.md:
The clunk in the point-source JAX API is that the user has to pass xp=jnp to PointSolver.for_grid AND pass use_jax=True to AnalysisPoint. Other Analysis classes only require use_jax=True. Example of the clunk today:
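```python
# grid and dataset are assumed to be defined earlier in the script.
import jax.numpy as jnp  # imported only to pass xp= to the solver

import autolens as al

solver_jax = al.PointSolver.for_grid(
    grid=grid,
    pixel_scale_precision=0.001,
    magnification_threshold=0.1,
    xp=jnp,
)

analysis = al.AnalysisPoint(
    dataset=dataset,
    solver=solver_jax,
    use_jax=True,
)
```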
Fix the API so passing use_jax=True to AnalysisPoint is sufficient — the solver is auto-configured. No user import of jnp required.