Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parameter-shift bug in JAX interface #2255

Merged
merged 6 commits into from
Mar 2, 2022
Merged

Fix parameter-shift bug in JAX interface #2255

merged 6 commits into from
Mar 2, 2022

Conversation

josh146
Copy link
Member

@josh146 josh146 commented Mar 1, 2022

Context: Applying the parameter-shift rule with the JAX interface was failing, as VJP tapes sent for execution were not unwrapping their parameters.

Description of the Change: Fixes the bug by unwrapping VJP tapes prior to execution.

Benefits: Works as expected on devices that do not support DeviceArray objects.

Possible Drawbacks: This will be a blocker to supporting higher-order derivatives with JAX, since we are converting device array's to NumPy arrays. Although I don't believe this was supported previously yet anyway.

Related GitHub Issues: Fixes #2242

@josh146 josh146 requested a review from antalszava March 1, 2022 06:33
@josh146 josh146 added the bug 🐛 Something isn't working label Mar 1, 2022
@codecov
Copy link

codecov bot commented Mar 1, 2022

Codecov Report

Merging #2255 (2437ee2) into master (f400217) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #2255   +/-   ##
=======================================
  Coverage   99.27%   99.27%           
=======================================
  Files         232      232           
  Lines       18712    18713    +1     
=======================================
+ Hits        18577    18578    +1     
  Misses        135      135           
Impacted Files Coverage Δ
pennylane/interfaces/batch/jax.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f400217...2437ee2. Read the comment docs.

@josh146
Copy link
Member Author

josh146 commented Mar 1, 2022

@antalszava I can't think of a good test to add to ensure this is fixed 🤔

Copy link
Contributor

@antalszava antalszava left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks 💯

@@ -427,6 +429,7 @@ def cost_fn(a, b):

expected = [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)]
assert np.allclose(res, expected, atol=0.1, rtol=0)
assert all(not isinstance(p, jnp.ndarray) for p in spy.call_args[0][0][0].get_parameters())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@josh146 this would fail on current master 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice!! Thanks @antalszava, this is a smart way of testing it

@antalszava
Copy link
Contributor

[sc-15391]

@josh146 josh146 merged commit cd4cd43 into master Mar 2, 2022
@josh146 josh146 deleted the fix-jax-python branch March 2, 2022 10:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug 🐛 Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] Taking gradients with interface='jax' or interface='jax-python' and backend=`qiskit.aer´ fails
2 participants