Skip to content

Conversation

@willGraham01
Copy link
Collaborator

@willGraham01 willGraham01 commented Sep 25, 2025

Resolves #44 |

Switches numpy to jax.numpy everywhere in the codebase. We keep numpy.typing.ArrayLike as a typehint because it's still applicable to jax arrays anyway (and jax doesn't have a good typehinting library of it's own).

This was largely a like-for-like switch of np -> jnp everywhere, save for in test_parameter_node.py where I think we had both a bug and something incompatible with jnp.allclose.

  • Bug: Previously we were doing allclose( node.sample(...)[0], [0.3]*10), however this only compares the first element of our generated samples to 10 values in the generated list. The reaplcement removes the 0-index fetch from the samples, effectively now doing allclose( node.sample(...), [0.3]*10).
  • Incompatible: jnp.allclose can't handle when one of the containers is a list, so swapped out [0.3]*10 for jnp.full((10,), 0.3) which builds an identically shaped and filled array.

There's also the usual ruff formatting that it wants to perform now that imports have changed.

@willGraham01 willGraham01 merged commit 16afa77 into main Sep 29, 2025
5 checks passed
@willGraham01 willGraham01 deleted the wgraham/jax-numpy-everywhere branch September 29, 2025 07:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

numpy to jax.numpy everywhere

3 participants