Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for JAX 0.2.24 breaking changes #1769

Merged
merged 2 commits into from
Oct 19, 2021
Merged

Fix for JAX 0.2.24 breaking changes #1769

merged 2 commits into from
Oct 19, 2021

Conversation

josh146
Copy link
Member

@josh146 josh146 commented Oct 19, 2021

Context: JAX 0.2.24 was just released, that has the following breaking change:

jax.numpy.take and jax.numpy.take_along_axis now require array-like inputs

Description of the Change: Modifies qml.math.take when using JAX to take this into account.

Benefits: JAX 0.2.24 works with PL

Possible Drawbacks: None.

Related GitHub Issues: n/a

@josh146 josh146 added the bug 🐛 Something isn't working label Oct 19, 2021
@github-actions
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@@ -385,7 +385,9 @@ def _to_numpy_jax(x):
ar.register_function(
"jax",
"take",
lambda x, indices, axis=None: _i("jax").numpy.take(x, indices, axis=axis, mode="wrap"),
lambda x, indices, axis=None: _i("jax").numpy.take(
x, np.array(indices), axis=axis, mode="wrap"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@josh146 I'm wondering if this is the spot to tackle the failing tests, or explicitly convert the lists to arrays in the tests themselves.

With the change here, would that not be an issue that we deviate from how jnp.take works starting from 0.2.24? I.e., we accept list indices, whereas jnp.take doesn't:

In [2]: from jax import numpy as jnp

In [3]: arr = jnp.array([0,1,2])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [4]: jnp.take(arr, [0])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-3d2d7ab5c969> in <module>
----> 1 jnp.take(arr, [0])

~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in take(a, indices, axis, out, mode)
   5419 @_wraps(np.take, skip_params=['out'])
   5420 def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
-> 5421   return _take(a, indices, None if axis is None else operator.index(axis), out,
   5422                mode)
   5423 

    [... skipping hidden 13 frame]

~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _take(a, indices, axis, out, mode)
   5426   if out is not None:
   5427     raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
-> 5428   _check_arraylike("take", a, indices)
   5429   a = asarray(a)
   5430   indices = asarray(indices)

~/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    561                     if not _arraylike(arg))
    562     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 563     raise TypeError(msg.format(fun_name, type(arg), pos))
    564 
    565 def _check_no_float0s(fun_name, *args):

TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 1.

In [5]: jnp.take(arr, jnp.array([0]))
Out[5]: DeviceArray([0], dtype=int32)

Copy link
Member Author

@josh146 josh146 Oct 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this is intentional! The reason behind qml.math is to ensure that the API is standardized between all autodiff frameworks.

Since

qml.math.take(x: tensor_like, indices: list or tensor_like)

is currently supported in autograd/tf/torch, we have to ensure that it also works for JAX.

Otherwise, if qml.math.take behaved differently for the different frameworks, it wouldn't serve its purpose!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that makes sense! Thanks :)

@codecov
Copy link

codecov bot commented Oct 19, 2021

Codecov Report

Merging #1769 (74ec867) into master (5aa93dd) will increase coverage by 2.06%.
The diff coverage is n/a.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1769      +/-   ##
==========================================
+ Coverage   96.84%   98.90%   +2.06%     
==========================================
  Files         206      206              
  Lines       15388    15388              
==========================================
+ Hits        14902    15219     +317     
+ Misses        486      169     -317     
Impacted Files Coverage Δ
pennylane/math/single_dispatch.py 99.48% <ø> (+11.73%) ⬆️
pennylane/interfaces/batch/__init__.py 100.00% <0.00%> (+0.96%) ⬆️
pennylane/devices/default_qubit.py 100.00% <0.00%> (+1.22%) ⬆️
pennylane/beta/devices/default_tensor.py 96.93% <0.00%> (+1.70%) ⬆️
pennylane/interfaces/batch/tensorflow.py 100.00% <0.00%> (+2.22%) ⬆️
pennylane/beta/devices/default_tensor_tf.py 90.62% <0.00%> (+3.12%) ⬆️
pennylane/interfaces/batch/torch.py 100.00% <0.00%> (+3.27%) ⬆️
pennylane/interfaces/torch.py 100.00% <0.00%> (+3.29%) ⬆️
pennylane/devices/default_qubit_tf.py 90.00% <0.00%> (+3.33%) ⬆️
pennylane/collections/qnode_collection.py 100.00% <0.00%> (+3.50%) ⬆️
... and 14 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5aa93dd...74ec867. Read the comment docs.

Copy link
Contributor

@antalszava antalszava left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! 🎉

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug 🐛 Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants