Overview
Port autogalaxy/profiles/mass/abstract/cse.py to support JAX by threading the xp=np parameter through all methods, mirroring the MGE module (mge.py) which is already JAX-ready. Currently the CSE module uses pure NumPy and scipy.linalg.lstsq, blocking JAX JIT compilation for any profile that uses CSE decomposition. This is Phase 2 of the mass profiles refactoring epic (#445).
Plan
- Thread xp=np parameter through all CSE static and instance methods
- Replace hardcoded np.sqrt, np.vstack, np.logspace, np.zeros, np.log10 with xp.*
- Add xp is not np branch in _decompose_convergence_via_cse_from using jnp.linalg.lstsq (keep scipy.linalg.lstsq for NumPy path)
- Thread xp=xp through all caller profiles that inherit MassProfileCSE (NFW and other dark matter profiles)
- Run existing unit tests to verify no regressions on NumPy path
- Add CSE-based profiles to autolens_workspace_test/scripts/profiles_jit.py JAX three-step tests
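The threading pattern in the first two plan items can be sketched as below. This is a minimal illustration and not the actual cse.py code: the helper name radial_unit_vectors_from and its arithmetic are hypothetical, chosen only to show hardcoded np.sqrt and np.vstack calls becoming xp.sqrt and xp.vstack behind an xp=np default.

```python
import numpy as np


def radial_unit_vectors_from(y, x, core_radius, xp=np):
    """Hypothetical helper: softened unit vectors, written against `xp` only."""
    # xp.sqrt replaces a hardcoded np.sqrt.
    radius = xp.sqrt(y**2 + x**2 + core_radius**2)
    # xp.vstack replaces a hardcoded np.vstack; the result has shape (N, 2).
    return xp.vstack((y / radius, x / radius)).T


# NumPy path: the default xp=np leaves existing callers unchanged.
y = np.array([3.0, 0.0])
x = np.array([4.0, 1.0])
vectors = radial_unit_vectors_from(y, x, core_radius=0.0)
```

On the JAX path a caller would presumably pass xp=jnp (after import jax.numpy as jnp) and receive a JAX array of the same shape, since both modules expose sqrt and vstack.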
Detailed implementation plan
Affected Repositories
- PyAutoGalaxy (primary)
- autolens_workspace_test (test additions — follow-up)
Work Classification
Library (then workspace test follow-up)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| PyAutoGalaxy | main | modified CLAUDE.md only |
Suggested branch: feature/cse-jax-port
Worktree root: ~/Code/PyAutoLabs-wt/cse-jax-port/
Implementation Steps
- autogalaxy/profiles/mass/abstract/cse.py (the core change):
  - convergence_cse_1d_from(grid_radii, core_radius) → add xp=np (pure arithmetic, no np calls to replace)
  - deflections_via_cse_from(...) → add xp=np, replace np.sqrt→xp.sqrt, np.vstack→xp.vstack
  - _convergence_2d_via_cse_from(grid_radii, **kwargs) → thread xp to convergence_cse_1d_from
  - _deflections_2d_via_cse_from(grid, **kwargs) → thread xp to deflections_via_cse_from, replace np.* with xp.* for grid operations
  - _decompose_convergence_via_cse_from(func, ...) → add xp is not np branch with jnp.linalg.lstsq, keep scipy for NumPy path, replace np.logspace/zeros/log10 with xp.*
- Callers in autogalaxy/profiles/mass/dark/:
  - nfw.py: NFW uses CSE for deflections via _deflections_2d_via_cse_from; thread xp=xp
  - Any other dark profiles mixing in MassProfileCSE: thread xp=xp
- Unit tests: run pytest test_autogalaxy/profiles/mass/ to verify NumPy path unchanged
- Workspace test additions (follow-up PR on autolens_workspace_test):
  - Add NFW CSE path to scripts/profiles_jit.py JAX three-step pattern
  - Run Phase 1 self-consistency tests to confirm no regressions
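The solver branch described for _decompose_convergence_via_cse_from might look roughly like the sketch below. The wrapper name cse_amplitudes_from, the toy design matrix, and the assumed CSE convergence form 1 / (2 (s^2 + r^2)^(3/2)) are illustrative assumptions, not code taken from cse.py. JAX is imported lazily inside the non-NumPy branch so that the NumPy path runs without JAX installed.

```python
import numpy as np
from scipy.linalg import lstsq as scipy_lstsq


def cse_amplitudes_from(design_matrix, target, xp=np):
    """Hypothetical wrapper: least-squares amplitudes for a CSE basis."""
    if xp is not np:
        # JAX path: imported lazily so NumPy-only installs still work.
        import jax.numpy as jnp

        amplitudes, _, _, _ = jnp.linalg.lstsq(design_matrix, target)
    else:
        # NumPy path: keeps the existing scipy.linalg.lstsq call.
        amplitudes, _, _, _ = scipy_lstsq(design_matrix, target)
    return amplitudes


# Toy problem: recover known amplitudes of two CSE components.
radii = np.logspace(-2.0, 1.0, 50)
core_radii = np.array([0.1, 1.0])
# Assumed CSE convergence form, one column per core radius.
design_matrix = 1.0 / (
    2.0 * (core_radii[None, :] ** 2 + radii[:, None] ** 2) ** 1.5
)
true_amplitudes = np.array([2.0, 0.5])
target = design_matrix @ true_amplitudes

amplitudes = cse_amplitudes_from(design_matrix, target)
```

Branching on xp is not np, rather than on an installed-package check, keeps the choice of solver tied to the backend the caller actually requested.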
Key Files
autogalaxy/profiles/mass/abstract/cse.py — CSE mixin (6 methods to port)
autogalaxy/profiles/mass/dark/nfw.py — primary CSE caller
autogalaxy/profiles/mass/dark/abstract.py — DarkProfile base
test_autogalaxy/profiles/mass/dark/test_nfw.py — existing NFW tests
Key Constraint
The CSE decomposition (_decompose_convergence_via_cse_from) is a one-time setup computation, not part of the JIT-traced forward pass. It must NOT be called inside jax.jit. The forward methods that consume cached decomposition results ARE traced and must be pure xp code.
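One way to picture this split between setup and forward pass is the sketch below. The class CachedCSE, its attribute names, and the hardcoded amplitudes are hypothetical and stand in for whatever caching the real mixin uses; the convergence form is the same assumed CSE expression, 1 / (2 (s^2 + r^2)^(3/2)).

```python
import numpy as np


class CachedCSE:
    """Hypothetical holder: the decomposition is computed once, then reused."""

    def __init__(self, decompose):
        self._decompose = decompose  # eager setup callable, never traced
        self._cache = None
        self.setup_calls = 0

    def decomposition(self):
        # Setup step: runs at most once and should be triggered before jit.
        if self._cache is None:
            self.setup_calls += 1
            self._cache = self._decompose()
        return self._cache

    def convergence_from(self, radii, xp=np):
        # Forward pass: pure xp arithmetic on the cached arrays.
        amplitudes, core_radii = self.decomposition()
        terms = amplitudes / (
            2.0 * (core_radii**2 + radii[:, None] ** 2) ** 1.5
        )
        return xp.sum(terms, axis=1)


# Stand-in decomposition result: (amplitudes, core radii).
profile = CachedCSE(lambda: (np.array([2.0, 0.5]), np.array([0.1, 1.0])))
profile.decomposition()  # setup happens here, outside any traced function

radii = np.array([0.5, 1.0])
kappa = profile.convergence_from(radii)
kappa_again = profile.convergence_from(radii)
```

Under JAX, only convergence_from would be wrapped in jax.jit, with xp=jnp, after decomposition() has already been called eagerly.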
Original Prompt
Port the CSE (Cored Steep Ellipsoid) module in PyAutoGalaxy to support JAX.
Make autogalaxy/profiles/mass/abstract/cse.py JAX-compatible by threading the xp=np parameter through all methods, mirroring how the MGE module (mge.py) already supports both NumPy and JAX backends. Replace np.* calls with xp.*, add jnp.linalg.lstsq branch for the decomposition solver, and thread xp=xp through all callers (NFW and other dark matter profiles).