Skip to content

Fix JAX array leak from PointSolver into Grid2DIrregular#410

Merged
Jammy2211 merged 1 commit intomainfrom
feature/fix-jax-array-leak-point-solver
Apr 1, 2026
Merged

Fix JAX array leak from PointSolver into Grid2DIrregular#410
Jammy2211 merged 1 commit intomainfrom
feature/fix-jax-array-leak-point-solver

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

  • When PointSolver.solve() runs with a JAX backend, the boolean-indexed solution is a JAX DeviceArray. Wrapping it directly in aa.Grid2DIrregular produced a JAX-backed object that propagated through result.py into PositionsLH.positions and eventually into the visualizer.
  • plotter.py calls np.array(positions.array) which raised ValueError: object __array__ method not producing an array on certain JAX versions.
  • Fix: add np.asarray() before the final Grid2DIrregular construction in PointSolver.solve(). This is safe — solve() is never called inside jax.jit (variable-length boolean indexing prevents it), and all downstream consumers expect numpy-backed coordinates.
  • Also updates the Returns docstring to document the numpy-backed guarantee.

Test plan

  • All 254 existing tests pass (python -m pytest test_autolens/)
  • np.asarray() on a numpy array is a zero-copy no-op — no regression on the numpy path

🤖 Generated with Claude Code

When the solver uses a JAX backend, the final boolean-indexed solution
was a JAX DeviceArray. Wrapping it in Grid2DIrregular without conversion
caused downstream np.array() calls in the visualizer to raise:
  ValueError: object __array__ method not producing an array

Fix: convert solution to numpy via np.asarray() before constructing the
return Grid2DIrregular. Safe because solve() is never called inside
jax.jit (variable-length boolean indexing prevents it).

Also updates the Returns docstring to document the numpy-backed guarantee.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 007f957 into main Apr 1, 2026
8 checks passed
@Jammy2211 Jammy2211 deleted the feature/fix-jax-array-leak-point-solver branch April 2, 2026 11:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant