jax.numpy.unique: "Abstract tracer value encountered where concrete value is expected." error #89
This is essentially the same problem as #68 and #43. It seems like removing these capabilities is an ongoing… Correction: these exact changes just throw the underlying errors sooner rather than later, which is a good thing, but the changes related to #43 involve a potential loss of functionality. It's not clear whether this case actually involves the latter, though.
Looks like the relevant change was this simple guard.
OK, the underlying issue I'm seeing here can be demonstrated with the following MWE:

```python
import numpy as np
import jax
from functools import partial

a = np.arange(6).reshape((3, 2))

def unique(x):
    return jax.numpy.unique(x)

jax.jit(partial(unique, a))()
```

```
---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-1-1016b14aa37a> in <module>
     12
---> 13 jax.jit(partial(unique, a))()

<ipython-input-1-1016b14aa37a> in unique(x)
      9 def unique(x):
---> 10     return jax.numpy.unique(x)
     11

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in unique(ar, return_index, return_inverse, return_counts, axis)
   3863 if axis is None:
-> 3864 ret = _unique1d(ar, return_index, return_inverse, return_counts)
   3865 if len(ret) == 1:

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _unique1d(ar, return_index, return_inverse, return_counts)
   3840
-> 3841 ret = (aux[mask],)
   3842 if return_index:

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _rewriting_take(arr, idx)
   3879 arr = asarray(arr)
-> 3880 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
   3881 return _gather(arr, treedef, static_idx, dynamic_idx)

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _split_index_for_jit(idx)
   3939 # indexing logic to handle them.
-> 3940 idx = _expand_bool_indices(idx)
   3941

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _expand_bool_indices(idx)
   4198 # concrete
-> 4199 raise IndexError("Array boolean indices must be concrete.")
   4200 else:

FilteredStackTrace: IndexError: Array boolean indices must be concrete.
```

The input is entirely concrete, even inside … Furthermore, we can be direct and reproduce the error with just … Also, as I understand it, …
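The traceback bottoms out in boolean-mask indexing (`aux[mask]` inside `_unique1d`). Below is a minimal sketch of that failure mode, assuming `_unique1d` follows NumPy's sort-then-mask approach; `unique1d_sketch` and `fails_under_jit` are hypothetical helpers, not JAX functions. It shows the same code succeeding eagerly and failing under `jax.jit`, both for a traced argument and for a NumPy constant closed over with `partial` as in the MWE. Everything inside `jit` is traced, so the mask is never concrete and the output length cannot be known at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def unique1d_sketch(ar):
    """Hypothetical sort-then-mask unique, modeled on NumPy's 1-D unique.

    Assumption: jax.numpy's _unique1d works the same way internally.
    """
    aux = jnp.sort(jnp.ravel(jnp.asarray(ar)))
    # True where an element differs from its predecessor; keep the first one.
    mask = jnp.concatenate([jnp.array([True]), aux[1:] != aux[:-1]])
    # Boolean-mask indexing: the result length depends on the data values.
    return aux[mask]


def fails_under_jit(fn):
    """Return True if calling fn() raises JAX's boolean-index error."""
    try:
        fn()
    except IndexError as err:  # JAX's boolean-index error subclasses IndexError
        print(type(err).__name__)
        return True
    return False


a = np.array([[3, 1], [1, 2], [3, 3]])

# Eagerly, the mask is concrete, so the data-dependent shape is fine.
eager_result = unique1d_sketch(a)
print(eager_result)

# Under jit, the mask is traced whether `a` is closed over (as in the MWE)
# or passed as an argument, so both calls raise.
closed_over_fails = fails_under_jit(jax.jit(partial(unique1d_sketch, a)))
argument_fails = fails_under_jit(lambda: jax.jit(unique1d_sketch)(a))
```

Note that the closed-over case fails too. This matches the observation that a fully concrete input does not help once the computation runs inside `jit`.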
The most recent `jax`/`jaxlib` update (i.e. 0.2.1) has removed support for "symbolic" inputs to `jax.numpy.unique`. We need to, at the very least, fix the tests for the corresponding `Op`.
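One possible direction for the `Op` tests is a fixed-size call. This is a sketch under assumptions. Recent JAX releases accept `size` and `fill_value` keyword arguments on `jnp.unique`, but I have not checked which release added them, and they are likely not in 0.2.1. Whether padded output is acceptable for the corresponding `Op` is also an assumption. Passing `size` makes the output shape static, so `jit` can trace the call. Slots beyond the true number of unique values are padded with `fill_value`.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = np.array([[3, 1], [1, 2], [3, 3]])


@jax.jit
def unique_padded(x):
    # x.size is a static Python int even under tracing, so it can serve as the
    # fixed output length (the worst case: every element is unique).
    return jnp.unique(x, size=x.size, fill_value=-1)


result = unique_padded(a)
print(result)
```

The trade-off is that callers must know or bound the number of unique values up front and must mask out the padding afterwards.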