-
Notifications
You must be signed in to change notification settings - Fork 586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix parameter-shift bug in JAX interface #2255
Conversation
Codecov Report
@@ Coverage Diff @@
## master #2255 +/- ##
=======================================
Coverage 99.27% 99.27%
=======================================
Files 232 232
Lines 18712 18713 +1
=======================================
+ Hits 18577 18578 +1
Misses 135 135
Continue to review full report at Codecov.
|
@antalszava I can't think of a good test to add to ensure this is fixed 🤔 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks 💯
@@ -427,6 +429,7 @@ def cost_fn(a, b): | |||
|
|||
expected = [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)] | |||
assert np.allclose(res, expected, atol=0.1, rtol=0) | |||
assert all(not isinstance(p, jnp.ndarray) for p in spy.call_args[0][0][0].get_parameters()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@josh146 this would fail on current master 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very nice!! Thanks @antalszava, this is a smart way of testing it
[sc-15391] |
Context: Applying the parameter-shift rule with the JAX interface was failing, as VJP tapes sent for execution were not unwrapping their parameters.
Description of the Change: Fixes the bug by unwrapping VJP tapes prior to execution.
Benefits: Works as expected on devices that do not support
DeviceArray
objects.Possible Drawbacks: This will be a blocker to supporting higher-order derivatives with JAX, since we are converting device array's to NumPy arrays. Although I don't believe this was supported previously yet anyway.
Related GitHub Issues: Fixes #2242