Skip to content

Feature/deflections operate jax#285

Merged
Jammy2211 merged 9 commits intomainfrom
feature/deflections_operate_jax
Mar 4, 2026
Merged

Feature/deflections operate jax#285
Jammy2211 merged 9 commits intomainfrom
feature/deflections_operate_jax

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

This pull request refactors the lensing Jacobian and Hessian computation logic in autogalaxy/operate/deflections.py to make Hessian-based calculations the default, unify the API for NumPy and JAX backends, and remove legacy Jacobian-based methods. It also updates documentation to clarify JAX usage and workspace script style. The changes improve flexibility, performance, and code clarity, especially for JAX integration.

Lensing Jacobian & Hessian Refactor

  • Replaced all Jacobian-based calculations with Hessian-based equivalents, making Hessian the default for convergence, shear, and magnification computations. Removed legacy Jacobian methods and the precompute_jacobian decorator. (autogalaxy/operate/deflections.py) [1] [2] [3] [4]
  • Unified the API for Hessian and Jacobian methods to support both NumPy and JAX via the xp parameter. JAX is now only imported locally when requested, and Hessian computation uses either finite-difference (NumPy) or auto-differentiation (jax.jacfwd) for JAX. (autogalaxy/operate/deflections.py) [1] [2] [3] [4]

JAX Integration & Documentation

  • Updated documentation in CLAUDE.md to clarify that NumPy is the default backend and JAX is opt-in, imported only within functions when needed. Provided clear instructions for adding JAX support to new functions and described the testing strategy for JAX vs. NumPy paths. (CLAUDE.md)
  • Added documentation about workspace script style, emphasizing the use of docstring blocks for commentary and section headers. (CLAUDE.md)

Miscellaneous

  • Updated shear field imports to remove unused legacy imports. (autogalaxy/operate/deflections.py)
  • Simplified test code by removing unnecessary buffer parameter for Hessian-based convergence tests. (test_autogalaxy/operate/test_deflections.py)

Jammy2211 and others added 2 commits March 2, 2026 11:47
Replaces the finite-difference hessian_from with a dual-path implementation:
- xp=np (default): delegates to _hessian_via_finite_difference, no JAX import
- xp=jnp: delegates to _hessian_via_jax which uses jax.jacfwd on a new
  deflections_yx_scalar helper, supporting both Grid2D and Grid2DIrregular

Also fixes shear_yx_2d_via_hessian_from to use grid.shape[0] instead of
grid.shape_slim (incompatible with Grid2DIrregular), and makes
magnification_2d_via_hessian_from return a raw jax.Array when xp=jnp
to avoid wrapping a traced value in ArrayIrregular.

Updates CLAUDE.md to document the NumPy-default / JAX-opt-in design pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove precompute_jacobian decorator, jacobian_from (np.gradient), convergence_2d_via_jacobian_from,
and shear_yx_2d_via_jacobian_from. All were redundant with the hessian path since A = I - H.

Rewire tangential_eigen_value_from, radial_eigen_value_from, and magnification_2d_from
to call hessian_from directly.

Restore jacobian_from as a thin public wrapper over hessian_from that returns
[[1-hxx, -hxy], [-hyx, 1-hyy]], supporting both xp=np and xp=jnp.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request refactors OperateDeflections to make Hessian-based derivatives the primary path for lensing quantities, introduces a JAX auto-diff Hessian path selected via an xp backend parameter, and removes legacy Jacobian-precompute plumbing. It also updates repository guidance around JAX usage and workspace scripting style, and adjusts tests to match the new API.

Changes:

  • Replaced Jacobian-based convergence/shear computations with Hessian-based equivalents; removed precompute_jacobian and legacy Jacobian-via-gradient methods.
  • Added a JAX Hessian computation path using jax.jacfwd, selected via xp, and introduced jacobian_from computed as A = I - H.
  • Updated docs (CLAUDE.md) about the NumPy-default / JAX-opt-in xp pattern and workspace script prose style; updated/trimmed deflections tests accordingly.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.

File Description
autogalaxy/operate/deflections.py Refactors Hessian/Jacobian computation, adds JAX Hessian path via xp, removes legacy Jacobian utilities.
test_autogalaxy/operate/test_deflections.py Updates tests to reflect Hessian-first API and revised Jacobian behavior.
CLAUDE.md Clarifies JAX opt-in conventions (xp pattern) and workspace script commenting style.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 389 to 391
hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from(grid=grid)

return aa.ArrayIrregular(values=0.5 * (hessian_yy + hessian_xx))
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convergence_2d_via_hessian_from no longer accepts an xp parameter and always calls hessian_from with its default xp=np. This prevents convergence from being computed via the JAX Hessian path, which is inconsistent with hessian_from(..., xp=...) / magnification_2d_via_hessian_from(..., xp=...). Consider adding xp=np here and passing it through to hessian_from.

Copilot uses AI. Check for mistakes.
Comment on lines 421 to 427
hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from(grid=grid)

gamma_1 = 0.5 * (hessian_xx - hessian_yy)
gamma_2 = hessian_xy

shear_yx_2d = np.zeros(shape=(grid.shape_slim, 2))
shear_yx_2d = np.zeros(shape=(grid.shape[0], 2))

Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shear_yx_2d_via_hessian_from always uses NumPy (np.zeros) and calls hessian_from without passing xp, so it cannot use the JAX derivative path even if the caller is operating with xp=jnp. If JAX support is intended here, add an xp=np parameter, pass it through to hessian_from, and allocate the output with xp.zeros (or otherwise ensure JAX-compatible arrays).

Copilot uses AI. Check for mistakes.
Comment on lines 393 to 396
def shear_yx_2d_via_hessian_from(
self, grid, buffer: float = 0.01
self, grid
) -> ShearYX2DIrregular:
"""
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring for shear_yx_2d_via_hessian_from describes returning a ShearYX2D structure, but the function is annotated and implemented to return ShearYX2DIrregular. Update the docstring (and/or return type) so the documented API matches the implementation.

Copilot uses AI. Check for mistakes.
Comment on lines +461 to +462
if xp is not np:
return xp.array(1.0 / det_A)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

magnification_2d_via_hessian_from returns different types depending on xp: an aa.ArrayIrregular for NumPy but a raw xp.array for JAX. This makes the method’s return type inconsistent and contradicts the -> aa.ArrayIrregular annotation. Either always wrap in the same AutoArray structure (with JAX values inside), or update the annotation / docstring to reflect a union return type and ensure downstream callers handle both.

Suggested change
if xp is not np:
return xp.array(1.0 / det_A)

Copilot uses AI. Check for mistakes.
Comment on lines 172 to 187
@@ -193,105 +180,153 @@ def tangential_eigen_value_from(self, grid, jacobian=None) -> aa.Array2D:
grid
The 2D grid of (y,x) arc-second coordinates the deflection angles and tangential eigen values are computed
on.
jacobian
A precomputed lensing jacobian, which is passed throughout the `CalcLens` functions for efficiency.
"""
convergence = self.convergence_2d_via_jacobian_from(
grid=grid, jacobian=jacobian
)

shear_yx = self.shear_yx_2d_via_jacobian_from(grid=grid, jacobian=jacobian)
convergence = self.convergence_2d_via_hessian_from(grid=grid)
shear_yx = self.shear_yx_2d_via_hessian_from(grid=grid)

return aa.Array2D(values=1 - convergence - shear_yx.magnitudes, mask=grid.mask)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These eigenvalue helpers no longer accept xp, so they always go through the NumPy path (via convergence_2d_via_hessian_from / shear_yx_2d_via_hessian_from). If JAX support via the xp pattern is intended for eigenvalue / critical-curve calculations, consider adding xp=np here and passing it through to the underlying convergence/shear (and ultimately hessian_from).

Copilot uses AI. Check for mistakes.
Comment on lines 189 to 203
def radial_eigen_value_from(self, grid) -> aa.Array2D:
"""
Returns the radial eigen values of lensing jacobian, which are given by the expression:

radial_eigen_value = 1 - convergence + shear
`radial_eigen_value = 1 - convergence + shear`

Parameters
----------
grid
The 2D grid of (y,x) arc-second coordinates the deflection angles and radial eigen values are computed on.
jacobian
A precomputed lensing jacobian, which is passed throughout the `CalcLens` functions for efficiency.
"""
convergence = self.convergence_2d_via_jacobian_from(
grid=grid, jacobian=jacobian
)

shear = self.shear_yx_2d_via_jacobian_from(grid=grid, jacobian=jacobian)
convergence = self.convergence_2d_via_hessian_from(grid=grid)
shear = self.shear_yx_2d_via_hessian_from(grid=grid)

return aa.Array2D(values=1 - convergence + shear.magnitudes, mask=grid.mask)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

radial_eigen_value_from also lacks an xp parameter and therefore can’t take advantage of the JAX Hessian path. If the intention of this refactor is that Hessian-based calculations are JAX-acceleratable via xp, it would be more consistent to thread xp through here as well.

Copilot uses AI. Check for mistakes.
Comment on lines +291 to +298
def _hessian_single(y_scalar, x_scalar):
return jnp.stack(
jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1))(
y_scalar, x_scalar, pixel_scales
)
)

h = jnp.vectorize(_hessian_single, signature="(),()->(i,i)")(y, x)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The JAX Hessian path uses jnp.vectorize over per-point jax.jacfwd calls. In JAX this is typically a Python-level loop and can be very slow / inhibit jit performance. Prefer a jax.vmap-based implementation (e.g., vmap a jacfwd’d scalar function, or jacfwd a batched function) to get compiled vectorization.

Suggested change
def _hessian_single(y_scalar, x_scalar):
return jnp.stack(
jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1))(
y_scalar, x_scalar, pixel_scales
)
)
h = jnp.vectorize(_hessian_single, signature="(),()->(i,i)")(y, x)
jac_fn = jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1))
def _hessian_single(y_scalar, x_scalar):
return jnp.stack(jac_fn(y_scalar, x_scalar, pixel_scales))
# Use jax.vmap for efficient batched evaluation over the grid
h = jax.vmap(_hessian_single, in_axes=(0, 0))(y, x)

Copilot uses AI. Check for mistakes.
Jammy2211 and others added 7 commits March 2, 2026 14:46
…_2d hessian methods

Thread xp=np through convergence_2d_via_hessian_from, shear_yx_2d_via_hessian_from,
tangential_eigen_value_from, radial_eigen_value_from, and magnification_2d_from so all
hessian-derived quantities support the JAX path consistently.

For each method, xp is passed through to hessian_from; when xp is not numpy the result
is returned as a raw jax.Array rather than an autoarray wrapper (which cannot be
constructed during a jax.jit trace).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…o LensCalc

- Rename operate/deflections.py -> operate/lens_calc.py
- Rename class OperateDeflections -> LensCalc throughout
- LensCalc.__init__ now accepts optional potential_2d_from callable
- from_mass_obj and from_tracer capture potential_2d_from automatically
- fermat_potential_from moved into LensCalc (removed from MassProfile, Galaxy, Galaxies)
- Update all imports, call sites, tests, and docs in autogalaxy

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit d348d13 into main Mar 4, 2026
8 checks passed
@Jammy2211 Jammy2211 deleted the feature/deflections_operate_jax branch April 2, 2026 11:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants