Fix JAX jit boundary in LensCalc + document decorator/JAX patterns#291
Fix JAX jit boundary in LensCalc + document decorator/JAX patterns#291
Conversation
All six hessian-derived LensCalc methods now guard autoarray wrapping with `if xp is np:` so they return a raw jax.Array on the JAX path. This allows them to be called directly inside jax.jit without the TypeError that occurred when an ArrayIrregular or Array2D was returned as the JIT output. Methods fixed: - convergence_2d_via_hessian_from - shear_yx_2d_via_hessian_from - magnification_2d_via_hessian_from - magnification_2d_from - tangential_eigen_value_from (also adds jnp.sqrt for shear magnitudes) - radial_eigen_value_from (same) CLAUDE.md updated with decorator system overview and the if-xp-is-np guard pattern for functions at the jax.jit boundary. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Fixes JAX jax.jit boundary failures in LensCalc by avoiding returning non-pytree autoarray wrapper types on the JAX backend, and documents the project’s decorator/JAX patterns for future contributors.
Changes:
- Guard six Hessian-derived
LensCalcmethods so NumPy returns autoarray wrappers while JAX returns raw arrays. - Add guidance to
CLAUDE.mdon autoarray decorators,xpbackend usage, and theif xp is np:guard pattern.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| autogalaxy/operate/lens_calc.py | Adds if xp is np: guards to avoid returning autoarray wrapper types from jax.jit-compiled functions. |
| CLAUDE.md | Documents decorator stacking expectations and JAX/autoarray wrapper constraints at the jax.jit boundary. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if xp is np: | ||
| return aa.ArrayIrregular(values=1.0 / det_A) | ||
| return 1.0 / det_A |
There was a problem hiding this comment.
magnification_2d_via_hessian_from is annotated as returning aa.ArrayIrregular, but returns a raw array when xp is not np. Update the return type annotation (or add overloads) to match the conditional wrapper behavior.
| if xp is np: | ||
| return aa.Array2D( | ||
| values=1 - convergence - shear_yx.magnitudes, mask=grid.mask | ||
| ) | ||
| magnitudes = xp.sqrt(shear_yx[:, 0] ** 2 + shear_yx[:, 1] ** 2) | ||
| return 1 - convergence - magnitudes |
There was a problem hiding this comment.
This change is specifically to avoid TypeError at the jax.jit boundary, but there’s no regression test exercising these methods under jax.jit. Since JAX is already available in the test suite (test_autogalaxy/conftest.py imports jax.numpy), consider adding a small unit test that jits a lambda calling each updated method with xp=jnp and asserts it runs/returns a JAX array.
| Prefer simple shell commands. | ||
| Avoid chaining with && or pipes. No newline at end of file |
There was a problem hiding this comment.
The guidance says to avoid chaining with pipes, but the example immediately above uses a pipe (find ... | xargs ...). Either adjust the recommendation (e.g. avoid complex pipelines) or update the example so the instructions are internally consistent.
| Prefer simple shell commands. | |
| Avoid chaining with && or pipes. | |
| Prefer simple, readable shell commands. | |
| Avoid complex command chains with multiple && operators or many pipes; a single straightforward pipe (as above) is fine when it improves clarity. |
| if xp is np: | ||
| return aa.Array2D( | ||
| values=1 - convergence - shear_yx.magnitudes, mask=grid.mask | ||
| ) | ||
| magnitudes = xp.sqrt(shear_yx[:, 0] ** 2 + shear_yx[:, 1] ** 2) | ||
| return 1 - convergence - magnitudes |
There was a problem hiding this comment.
These methods now return a raw JAX array when xp is not np, but the return annotation still claims aa.Array2D. This mismatch can mislead callers and type checkers; consider updating the annotation to a Union[...] (or overloads) that reflects both the NumPy-wrapped and JAX-raw return types.
| if xp is np: | ||
| return aa.Array2D( | ||
| values=1 - convergence + shear.magnitudes, mask=grid.mask | ||
| ) | ||
| magnitudes = xp.sqrt(shear[:, 0] ** 2 + shear[:, 1] ** 2) | ||
| return 1 - convergence + magnitudes |
There was a problem hiding this comment.
Return type annotation is aa.Array2D, but this function now returns a raw array on the JAX path (xp is not np). Update the annotation (e.g., Union[aa.Array2D, <array type>] / overloads) so the public API matches actual behavior.
| if xp is np: | ||
| return aa.Array2D(values=1 / det_A, mask=grid.mask) | ||
| return 1 / det_A |
There was a problem hiding this comment.
magnification_2d_from is annotated to return aa.Array2D, but returns a raw array when xp is not np. Please update the return type annotation (and/or add overloads) to reflect the conditional wrapper behavior.
| if xp is np: | ||
| return aa.ArrayIrregular(values=convergence) | ||
| return convergence |
There was a problem hiding this comment.
convergence_2d_via_hessian_from is annotated as returning aa.ArrayIrregular, but returns a raw array on the JAX path. Adjust the return type annotation (e.g., Union[aa.ArrayIrregular, <array type>] / overloads) so callers can rely on the signature.
| if xp is np: | ||
| return ShearYX2DIrregular(values=shear_yx_2d, grid=grid) | ||
| return shear_yx_2d |
There was a problem hiding this comment.
shear_yx_2d_via_hessian_from is annotated to return ShearYX2DIrregular, but returns a raw (N, 2) array when xp is not np. Please update the return annotation (or add overloads) to reflect the JAX-compatible return type.
Summary
TypeErrorthat occurred when any of the six hessian-derivedLensCalcmethods were called directly insidejax.jitCLAUDE.mdwith decorator system overview and theif xp is np:guard patternProblem
Autoarray types (
ArrayIrregular,ShearYX2DIrregular,Array2D) are not registered as JAX pytrees. The following methods unconditionally wrapped their return values in these types, causingTypeError: ... is not a valid JAX typewhen used as the output of ajax.jit-compiled function:convergence_2d_via_hessian_fromshear_yx_2d_via_hessian_frommagnification_2d_via_hessian_frommagnification_2d_fromtangential_eigen_value_fromradial_eigen_value_fromFix
Each method now guards autoarray wrapping with
if xp is np:, returning the wrapper on the numpy path and a rawjax.Arrayon the JAX path.tangential_eigen_value_fromandradial_eigen_value_fromalso compute shear magnitudes viaxp.sqrton the JAX path (sinceShearYX2DIrregular.magnitudesusesnp.sqrtdirectly and is not usable on a raw jax array).🤖 Generated with Claude Code