
jax.numpy.unique: "Abstract tracer value encountered where concrete value is expected." error #89

Closed
brandonwillard opened this issue Oct 7, 2020 · 3 comments · Fixed by #90
Labels
bug Something isn't working important JAX Involves JAX transpilation

Comments

@brandonwillard
Member

The most recent jax/jaxlib update (i.e., 0.2.1) has removed support for "symbolic" inputs to jax.numpy.unique. At the very least, we need to fix the tests for the corresponding Op.

@brandonwillard brandonwillard added bug Something isn't working JAX Involves JAX transpilation important labels Oct 7, 2020
@brandonwillard
Member Author

brandonwillard commented Oct 7, 2020

This is essentially the same problem as #68 and #43. Removing these capabilities seems to be an ongoing process in JAX, so we should probably start considering a different approach.

Correction: these particular changes simply raise the underlying errors sooner rather than later, which is a good thing. The changes related to #43, on the other hand, involve potential losses of functionality. It's not clear whether this case actually involves the latter, though.

@brandonwillard
Member Author

brandonwillard commented Oct 7, 2020

Looks like the relevant change was this simple guard.

@brandonwillard
Member Author

OK, the underlying issue I'm seeing here can be demonstrated with the following MWE:

import numpy as np
import jax

from functools import partial


a = np.arange(6).reshape((3, 2))

def unique(x):
    return jax.numpy.unique(x)


jax.jit(partial(unique, a))()
---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-1-1016b14aa37a> in <module>
     12 
---> 13 jax.jit(partial(unique, a))()

<ipython-input-1-1016b14aa37a> in unique(x)
      9 def unique(x):
---> 10     return jax.numpy.unique(x)
     11 

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in unique(ar, return_index, return_inverse, return_counts, axis)
   3863   if axis is None:
-> 3864     ret = _unique1d(ar, return_index, return_inverse, return_counts)
   3865     if len(ret) == 1:

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _unique1d(ar, return_index, return_inverse, return_counts)
   3840 
-> 3841   ret = (aux[mask],)
   3842   if return_index:

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _rewriting_take(arr, idx)
   3879   arr = asarray(arr)
-> 3880   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
   3881   return _gather(arr, treedef, static_idx, dynamic_idx)

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _split_index_for_jit(idx)
   3939   # indexing logic to handle them.
-> 3940   idx = _expand_bool_indices(idx)
   3941 

~/apps/anaconda3/envs/theano-3.7/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _expand_bool_indices(idx)
   4198         # concrete
-> 4199         raise IndexError("Array boolean indices must be concrete.")
   4200       else:

FilteredStackTrace: IndexError: Array boolean indices must be concrete.

The input is entirely concrete, even inside jax.numpy.unique, so this is very confusing.
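One plausible explanation, offered here as an assumption rather than something confirmed in this thread: JAX 0.2.0 enabled "omnistaging" by default, so every jax.numpy operation inside a jit-traced function is staged out, even when its inputs are concrete constants closed over via partial. The sorted array and the boolean mask that _unique1d builds from it would then be tracers, and indexing with a traced boolean mask raises exactly this IndexError. The sketch below mimics that pattern; boolean_filter is a hypothetical helper, not part of JAX.

```python
import numpy as np
import jax
import jax.numpy as jnp

a = np.arange(6).reshape((3, 2))


def boolean_filter(x):
    # Hypothetical helper mimicking the pattern in _unique1d: sort, build a
    # boolean mask from the sorted values, then index with that mask.
    aux = jnp.sort(jnp.ravel(x))
    mask = aux > 2
    return aux[mask]


# Outside of jit, every intermediate is a concrete array, so this works.
eager_result = boolean_filter(a)

# Inside jit, jnp.sort and the comparison are staged out, so `mask` is a
# tracer even though `a` itself is a concrete NumPy constant.
try:
    jax.jit(lambda: boolean_filter(a))()
    jit_failed = False
except IndexError as err:
    jit_failed = True
    print(type(err).__name__, err)
```

The eager call returns the values 3, 4 and 5, while the jitted call fails because the output length depends on the mask's values, which are unknown at trace time.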

Furthermore, the error can be reproduced more directly with just jax.jit(partial(jax.numpy.unique, a))(). The only way I can get it to work is with a non-jitted call like jax.numpy.unique(a), which would imply that jax.numpy.unique is about as useful as numpy.unique itself.
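For reference, later JAX releases (not the 0.2.1 discussed here) added keyword-only size and fill_value parameters to jax.numpy.unique, which fix the output length at trace time so the call can run under jit. A minimal sketch, assuming a recent JAX version; unique_fixed_size is a hypothetical name:

```python
import numpy as np
import jax
import jax.numpy as jnp

a = np.array([[3, 1], [1, 2], [3, 3]])


@jax.jit
def unique_fixed_size(x):
    # `size` fixes the output length at trace time; unused slots are padded
    # with `fill_value`, so the output shape no longer depends on the data.
    return jnp.unique(x, size=4, fill_value=-1)


result = unique_fixed_size(a)
print(result)
```

Here the three distinct values 1, 2 and 3 are followed by one -1 of padding. The trade-off is that the caller must choose an upper bound on the number of unique values in advance.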

Also, as I understand it, static_argnums won't help here, because the jitted function simply has no arguments.
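Since there are no arguments to mark as static, one common workaround (a sketch, not something proposed in this thread) is to split the computation: run jax.numpy.unique eagerly, outside any trace, and pass its concrete result into a jitted function that only does fixed-shape work. The function name scaled is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

a = np.arange(6).reshape((3, 2))

# Step 1 (eager): the data-dependent-shape operation runs outside any trace,
# so its inputs and intermediates are all concrete.
u = jnp.unique(a)


# Step 2 (jitted): the fixed-shape work receives `u` as an ordinary argument.
@jax.jit
def scaled(values):
    return values * 2


result = scaled(u)
```

Note that jit will retrace scaled whenever the number of unique values, and hence the shape of u, changes.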
