
Cannot run demo, possible incompatibility with latest Jax #12

Closed
daniel-trejobanos opened this issue Apr 20, 2022 · 4 comments

@daniel-trejobanos

Dear all,

I am trying to run the demo examples, but I run into the following error:


ImportError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
2 import objax
3 import numpy as np

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
----> 1 from . import (
2 kernels,
3 utils,
4 ops,
5 likelihoods,
6 models,
7 basemodels,
8 inference,
9 cubature
10 )
13 def build_model(model, inf, name='GPModel'):
14 return type(name, (inf, model), {})

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in <module>
3 import jax.numpy as np
4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
----> 5 from jax.ops import index_add, index
6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
7 from warnings import warn

ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)

I think it's related to this note from the JAX website:

The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
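For reference, a minimal sketch of the migration the note describes, assuming a JAX release that provides the `.at` property (the array and values here are purely illustrative):

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Removed API:  index_add(x, index[1:3], 2.0)
# Replacement:  .at[...].add(...) returns a new array; x itself is unchanged.
added = x.at[1:3].add(2.0)

# Removed API:  index_update(x, index[0], 7.0)
# Replacement:  .at[...].set(...) also returns a new array.
updated = x.at[0].set(7.0)

print(added)
print(updated)
```

Because JAX arrays are immutable, both forms return updated copies rather than modifying `x` in place, which is why the old functions could be replaced one-for-one.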

@daniel-trejobanos
Author

I now realise that your pip installation pins a specific jax version, which is a bit problematic for me: I am running on an M1 and installed jax via conda-forge, so I am not sure I can match a compatible version. I will try and let you know if I succeed.

@daniel-trejobanos
Author

I managed to downgrade jax, but there is no jaxlib 0.1.60 available on conda-forge. That seems like it could be the source of this error, which I get when trying to load objax 1.31:


TypeError Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 import bayesnewton
2 import objax
3 import numpy as np

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
----> 1 from . import (
2 kernels,
3 utils,
4 ops,
5 likelihoods,
6 models,
7 basemodels,
8 inference,
9 cubature
10 )
13 def build_model(model, inf, name='GPModel'):
14 return type(name, (inf, model), {})

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:1, in <module>
----> 1 import objax
2 from jax import vmap
3 import jax.numpy as np

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/__init__.py:17, in <module>
1 # Copyright 2020 Google LLC
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
15 import sys
---> 17 from ._patch_jax import *
19 pass # To avoid reordering imports from above
21 from . import functional

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/objax/_patch_jax.py:20, in <module>
16 __all__ = []
18 from typing import Union, Sequence, Tuple, Callable, Optional
---> 20 import jax.numpy as jn
22 from .typing import JaxArray
23 from .util import re_sign

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/__init__.py:93, in <module>
89 from .version import __version__
91 # These submodules are separate because they are in an import cycle with
92 # jax and rely on the names imported above.
---> 93 from . import image
94 from . import lax
95 from . import nn

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/image/__init__.py:18, in <module>
15 """Common functions for neural network libraries."""
17 # flake8: noqa: F401
---> 18 from jax._src.image.scale import (
19 resize,
20 ResizeMethod,
21 scale_and_translate,
22 )

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/image/scale.py:20, in <module>
17 from typing import Callable, Sequence, Union
19 from jax import jit
---> 20 from jax import lax
21 from jax import numpy as jnp
22 import numpy as np

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/lax/__init__.py:324, in <module>
291 from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
292 _reduce_and, _reduce_window_sum, _reduce_window_max,
293 _reduce_window_min, _reduce_window_prod,
(...)
298 _upcast_fp16_for_computation, _broadcasting_shape_rule,
299 _eye, _tri, _delta, _ones, _zeros, _dilate_shape)
300 from jax._src.lax.control_flow import (
301 associative_scan,
302 cond,
(...)
322 while_p,
323 )
--> 324 from jax._src.lax.fft import (
325 fft,
326 fft_p,
327 )
328 from jax._src.lax.parallel import (
329 all_gather,
330 all_to_all,
(...)
346 xeinsum,
347 )
348 from jax._src.lax.other import (
349 conv_general_dilated_patches
350 )

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/_src/lax/fft.py:87, in <module>
83 n = fft_lengths[-1]
84 return y[..., : n//2 + 1]
86 @partial(jit, static_argnums=1)
---> 87 def _rfft_transpose(t, fft_lengths):
88 # The transpose of RFFT can't be expressed only in terms of irfft. Instead of
89 # manually building up larger twiddle matrices (which would increase the
90 # asymptotic complexity and is also rather complicated), we rely JAX to
91 # transpose a naive RFFT implementation.
92 dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
93 dummy_primal = ShapeDtypeStruct(dummy_shape, _real_dtype(t.dtype))

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:184, in jit(fun, static_argnums, device, backend, donate_argnums)
129 """Sets up fun for just-in-time compilation with XLA.
130
131 Args:
(...)
181 -0.85743 -0.78232 0.76827 0.59566 ]
182 """
183 if FLAGS.experimental_cpp_jit and config.omnistaging_enabled:
--> 184 return _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
185 else:
186 return _python_jit(fun, static_argnums, device, backend, donate_argnums)

File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/api.py:370, in _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
367 return config.read("jax_disable_jit")
369 static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
--> 370 cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
371 get_jax_enable_x64, get_jax_disable_jit_flag,
372 static_argnums_)
374 # TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
375 @wraps(fun)
376 @api_boundary
377 def f_jitted(*args, **kwargs):

TypeError: jit(): incompatible function arguments. The following argument types are supported:
1. (fun: function, cache_miss: function, get_device: function, static_argnums: List[int], static_argnames: List[str] = [], donate_argnums: List[int] = [], cache: jaxlib.xla_extension.CompiledFunctionCache = None) -> object

Invoked with: <function _rfft_transpose at 0x7f93913f20d0>, <function _cpp_jit.<locals>.cache_miss at 0x7f93913f2160>, <function _cpp_jit.<locals>.get_device_info at 0x7f93913f21f0>, <function _cpp_jit.<locals>.get_jax_enable_x64 at 0x7f93913f2280>, <function _cpp_jit.<locals>.get_jax_disable_jit_flag at 0x7f93913f2310>, (0, 2)
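The `TypeError` above is raised where jax's Python code calls into jaxlib's compiled `jit` with arguments that jaxlib does not accept, which is consistent with a jax/jaxlib version mismatch. Since importing jax is exactly what fails, one option is to read the installed versions from package metadata instead. A minimal standard-library sketch (the helper name `installed_versions` is hypothetical):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "objax")):
    """Return {package: version string or None} without importing the packages.

    Reading distribution metadata avoids importing jax, which is the step
    that fails when jax and jaxlib are mismatched.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

The reported versions can then be compared against the pins in the package's requirements before trying to import anything.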

@wil-j-wil
Collaborator

Hi,

Sorry for the slow response and apologies that you've been having issues with the package versions. This is indeed frustrating. I would love to update the package to work with the most recent versions, but I don't currently have the spare time.

I am using an M1 mac and things are working OK for me, but I'm not using condaforge.

The index_update issue should be fairly easy to fix. However, I recall seeing some performance issues when I tried updating objax in the past, and I never managed to debug the issue. I hope to get around to this at some point in the future.
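One low-effort route for an import such as `from jax.ops import index_add, index` in `kernels.py` is a compatibility shim that falls back to `.at` on newer JAX. This is a sketch under the assumption that only `index`, `index_add` and `index_update` are needed; it is not necessarily the approach taken in the eventual fix:

```python
import numpy as onp

try:
    # Older JAX releases still ship the functional index helpers.
    from jax.ops import index, index_add, index_update
except ImportError:
    # Newer JAX releases removed them; rebuild equivalents on top of `.at`.
    # `index[...]` only has to produce an indexing tuple, which NumPy's
    # index_exp already does.
    index = onp.index_exp

    def index_add(x, idx, y):
        """Hypothetical stand-in for the removed jax.ops.index_add."""
        return x.at[idx].add(y)

    def index_update(x, idx, y):
        """Hypothetical stand-in for the removed jax.ops.index_update."""
        return x.at[idx].set(y)
```

Existing call sites like `index_add(x, index[i], y)` would then keep working unchanged on both old and new JAX versions.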

@asolin
Contributor

asolin commented Feb 17, 2023

This issue should be fixed by #18 (updated to the current release of jax).

3 participants