Cannot run demo, possible incompatibility with latest Jax #12
Comments
I now realise that your pip installation asks for a specific jax version, which is a bit problematic for me: I am running on an M1 and installed jax via conda-forge, so I am not sure I can match a compatible version. I will try and let you know if I succeed.
I managed to downgrade jax, but there is no jaxlib 0.1.60 available on conda-forge, which seems like it could be the source of this error I get when trying to load objax 1.31:
TypeError Traceback (most recent call last)
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:1, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/__init__.py:17, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/_patch_jax.py:20, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/__init__.py:93, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/image/__init__.py:18, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/image/scale.py:20, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/lax/__init__.py:324, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/lax/fft.py:87, in <module>
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:184, in jit(fun, static_argnums, device, backend, donate_argnums)
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:370, in _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
TypeError: jit(): incompatible function arguments. The following argument types are supported:
Invoked with: <function _rfft_transpose at 0x7f93913f20d0>, <function _cpp_jit.<locals>.cache_miss at 0x7f93913f2160>, <function _cpp_jit.<locals>.get_device_info at 0x7f93913f21f0>, <function _cpp_jit.<locals>.get_jax_enable_x64 at 0x7f93913f2280>, <function _cpp_jit.<locals>.get_jax_disable_jit_flag at 0x7f93913f2310>, (0, 2)
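A `jit(): incompatible function arguments` error raised while `jax` itself is still being imported usually points to a mismatch between the pure-Python `jax` package and the compiled `jaxlib` backend it calls into, which fits the missing jaxlib 0.1.60 mentioned in the comment. A quick way to confirm which versions are actually installed in the active environment is a small standard-library check like the sketch below. It assumes Python 3.8+ (for `importlib.metadata`); the helper name `installed_versions` is hypothetical, and the package list is simply the four packages that appear in this thread.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "objax", "bayesnewton")):
    """Return {distribution name: version string, or None if not installed}.

    Reads installed package metadata only, so it works even when
    `import jax` itself fails with a jax/jaxlib mismatch.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Reading metadata instead of importing the packages is the point of this design: it still reports versions in exactly the broken state shown in the traceback above.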
Hi, sorry for the slow response, and apologies that you've been having issues with the package versions. This is indeed frustrating. I would love to update the package to work with the most recent versions, but I don't currently have the spare time. I am using an M1 Mac and things are working OK for me, but I'm not using conda-forge. The
This issue should be fixed by #18 (updated to the current release of jax).
Dear all,
I am trying to run the demo examples, but I run into the following error:
ImportError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
2 import objax
3 import numpy as np
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
----> 1 from . import (
2 kernels,
3 utils,
4 ops,
5 likelihoods,
6 models,
7 basemodels,
8 inference,
9 cubature
10 )
13 def build_model(model, inf, name='GPModel'):
14 return type(name, (inf, model), {})
File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in <module>
3 import jax.numpy as np
4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
----> 5 from jax.ops import index_add, index
6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
7 from warnings import warn
ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)
I think it's related to this from the JAX website:
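The `ImportError` occurs because `jax.ops.index_add` (together with `jax.ops.index` and related helpers such as `index_update`) was deprecated and later removed from newer JAX releases in favor of the `.at` indexed-update property on arrays. A minimal migration sketch, assuming a recent JAX install, is below. The old call is shown only as a comment and its arguments are illustrative (not copied from bayesnewton), and the `index_add` wrapper at the end is a hypothetical shim, not part of JAX or bayesnewton.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Older JAX (as imported in bayesnewton/kernels.py above), illustratively:
#   from jax.ops import index_add, index
#   y = index_add(x, index[1:3], 2.0)
# That import fails on newer JAX; the `.at` property replaces it.
y = x.at[1:3].add(2.0)  # out-of-place add; x itself is left unchanged
z = x.at[0].set(7.0)    # replacement for the old jax.ops.index_update


# Hypothetical shim (not part of JAX or bayesnewton) for code that still
# calls index_add(arr, idx, val) with the old argument order.
def index_add(arr, idx, val):
    return arr.at[idx].add(val)


# A tuple index such as (slice(1, 3),) behaves like the old index[1:3].
w = index_add(x, (slice(1, 3),), 2.0)
```

Because JAX arrays are immutable, `.at[...].add(...)` returns a new array rather than modifying `x` in place, which matches the functional behavior of the old `index_add`.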