-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix for JAX 0.2.24 breaking changes #1769
Conversation
Hello. You may have forgotten to update the changelog!
|
@@ -385,7 +385,9 @@ def _to_numpy_jax(x): | |||
ar.register_function( | |||
"jax", | |||
"take", | |||
lambda x, indices, axis=None: _i("jax").numpy.take(x, indices, axis=axis, mode="wrap"), | |||
lambda x, indices, axis=None: _i("jax").numpy.take( | |||
x, np.array(indices), axis=axis, mode="wrap" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@josh146 I'm wondering if this is the spot to tackle the failing tests, or explicitly convert the lists to arrays in the tests themselves.
With the change here, would that not be an issue that we deviate from how jnp.take
works starting from 0.2.24
? I.e., we accept list indices, whereas jnp.take
doesn't:
In [2]: from jax import numpy as jnp
In [3]: arr = jnp.array([0,1,2])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [4]: jnp.take(arr, [0])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-4-3d2d7ab5c969> in <module>
----> 1 jnp.take(arr, [0])
~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in take(a, indices, axis, out, mode)
5419 @_wraps(np.take, skip_params=['out'])
5420 def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
-> 5421 return _take(a, indices, None if axis is None else operator.index(axis), out,
5422 mode)
5423
[... skipping hidden 13 frame]
~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _take(a, indices, axis, out, mode)
5426 if out is not None:
5427 raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
-> 5428 _check_arraylike("take", a, indices)
5429 a = asarray(a)
5430 indices = asarray(indices)
~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
561 if not _arraylike(arg))
562 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 563 raise TypeError(msg.format(fun_name, type(arg), pos))
564
565 def _check_no_float0s(fun_name, *args):
TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 1.
In [5]: jnp.take(arr, jnp.array([0]))
Out[5]: DeviceArray([0], dtype=int32)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, this is intentional! The reason behind qml.math
is to ensure that the API is standardized between all autodiff frameworks.
Since
qml.math.take(x: tensor_like, indices: list or tensor_like)
is currently supported in autograd/tf/torch, we have to ensure that it also works for JAX.
Otherwise, if qml.math.take
behaved differently for the different frameworks, it wouldn't serve its purpose!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, that makes sense! Thanks :)
Codecov Report
@@ Coverage Diff @@
## master #1769 +/- ##
==========================================
+ Coverage 96.84% 98.90% +2.06%
==========================================
Files 206 206
Lines 15388 15388
==========================================
+ Hits 14902 15219 +317
+ Misses 486 169 -317
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me! 🎉
Context: JAX 0.2.24 was just released, that has the following breaking change:
Description of the Change: Modifies
qml.math.take
when using JAX to take this into account.Benefits: JAX 0.2.24 works with PL
Possible Drawbacks: None.
Related GitHub Issues: n/a