In #56, the jax version was pinned to avoid errors originating in distrax. We should remove that pin once distrax is compatible with the latest jax release. Alternatively, we could drop our dependency on distrax entirely.