Fixing JAX interface for adjoint diff #1349
Conversation
Hello. You may have forgotten to update the changelog!
Codecov Report

@@           Coverage Diff           @@
##           master   #1349   +/-  ##
=====================================
  Coverage   98.16%   98.16%
=====================================
  Files         154      154
  Lines       11544    11548     +4
=====================================
+ Hits        11332    11336     +4
  Misses        212      212
pennylane/qnode.py (outdated)

        jnp.array(0.1, dtype=jnp.float64)
    except UserWarning:
        return False
    return True
I don't know how to test this line, given float64 mode must be configured on startup.
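One common way to exercise startup-only configuration is to probe it from a fresh interpreter. A sketch of a hypothetical test helper (it assumes JAX reads the `JAX_ENABLE_X64` environment variable at import time; `run_with_x64` is not part of this PR):

```python
import os
import subprocess
import sys


def run_with_x64(enabled):
    """Report the x64 flag as seen by a fresh Python process.

    Hypothetical helper: startup-only settings can't be toggled
    in-process, but a subprocess starts with a clean configuration.
    """
    env = dict(os.environ, JAX_ENABLE_X64="True" if enabled else "False")
    probe = "import jax; print(bool(jax.config.jax_enable_x64))"
    result = subprocess.run(
        [sys.executable, "-c", probe],
        env=env, capture_output=True, text=True,
    )
    # If JAX isn't installed, the probe prints nothing; treat as False.
    return result.stdout.strip() == "True"
```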
Thanks for catching this @albi3ro! Just two suggestions re: improving the precision check and moving the location of the warning logic, but once addressed this should be good for merging.
pennylane/qnode.py (outdated)

import jax.numpy as jnp

with warnings.catch_warnings():
    warnings.filterwarnings("error")
    try:
        jnp.array(0.1, dtype=jnp.float64)
    except UserWarning:
        return False
    return True
Nice workaround, but it feels a bit too reliant on implementation details of JAX and how it raises warnings.
I would suggest instead simply querying the JAX configuration:
>>> from jax.config import config
>>> config.x64_enabled
False
>>> config.update("jax_enable_x64", True)
>>> config.x64_enabled
True
Also, I would recommend moving this to interfaces/jax.py, so that all JAX imports are isolated in a single place. The warning can then be raised by the tape mixin class querying the jacobian_options dictionary.
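A rough sketch of what that could look like. The class and method names below are illustrative, not PennyLane's actual API, and the precision check is stubbed out:

```python
import warnings


def _jax_float64_support():
    # Stub for the real precision check; pretend float64 is disabled.
    return False


class JAXInterface:
    """Illustrative mixin: the interface owns the precision warning."""

    def __init__(self, jacobian_options=None):
        self.jacobian_options = jacobian_options or {}

    def apply(self):
        # Warn only when finite differences are actually requested.
        if self.jacobian_options.get("method") == "numeric" and not _jax_float64_support():
            warnings.warn(
                "float64 support not enabled for jax. "
                "May cause inaccuracies for diff_method='finite-diff'",
                UserWarning,
            )
```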
pennylane/qnode.py
Outdated
if interface == "jax": | ||
if not _jax_float64_support(): | ||
warnings.warn( | ||
"float64 support not enabled for jax. " | ||
"May cause inaccuracies for `diff_method='finite-diff'`", | ||
UserWarning, | ||
) | ||
|
||
return JacobianTape, interface, device, {"method": "numeric"} |
Yep, I would say this definitely feels like something that lives on the interface rather than the QNode :) Especially because it simply raises a warning, rather than changing the diff method.
JacobianTape copy duplicates jacobian_options
.github/CHANGELOG.md (outdated)

* Fixes drawing QNodes that contain multiple measurements on a single wire.
  [(#1353)](https://github.com/PennyLaneAI/pennylane/pull/1353)
pennylane/tape/jacobian_tape.py (outdated)

@@ -533,7 +538,8 @@ def jacobian(self, device, params=None, **options):
    # First order (forward) finite-difference will be performed.
    # Compute the value of the tape at the current parameters here. This ensures
    # this computation is only performed once, for all parameters.
    options["y0"] = np.asarray(self.execute_device(params, device))
    params_f64 = np.array(params, dtype=np.float64)
One of the jax device tests doesn't seem to like this line 🤔
https://github.com/PennyLaneAI/pennylane/pull/1349/checks?check_run_id=2697038221#step:10:63
On another note: wouldn't this cast be a bit too specific for such a general place? It would potentially cast the parameters every time, regardless of the interface. Would we only like to do it for the jax interface?
@albi3ro I think I agree with @antalszava in this case. While there is a good argument for enforcing finite difference to always use float64, I feel like this should be an interface level option/default, rather than enforced in the Jacobian method
@@ -55,6 +55,23 @@ def test_parameter_info(self, make_tape):
    }


class TestTapeCopying:
Might be worth adding an integration test similar to the original example in the #1345 issue to ensure that the results using the jax interface are correct.
Hi @albi3ro, it must have been quite some digging to work this out, nice one! 🥇 One device test seems to be failing, and I'm also wondering whether, as Josh suggested, it could be good to add the fix specifically to the jax interface code.
The one test is failing because of an issue I just raised, Issue #1369. When a parameter is a matrix instead of a float, params is a deprecated ragged array and can't be cast to float64. While this problem arose with the jax interface, it will actually occur any time we use finite differences with matrix parameters.
Thanks @albi3ro!
> When a parameter is a matrix instead of float, params is a deprecated ragged array and can't be cast to float64.

I think this might be a bug with the test not specifying requires_grad=False on the matrix, as opposed to a bug in the code; see #1369 (comment).
I have two requests:

1. I feel like the general float64 fix within the JacobianTape is out of scope for this PR, and should be considered separately. Instead, I think it should be sufficient to just make the JAX interface use float64 in the finite-diff tests (with maybe small modifications to jax_ops.py).

2. Does this mean that there are no adjoint tests for the JAX interface? We should probably add them now to ensure that this bug is fixed :)
I added the finite-diff change to this PR because I thought it was going to be easy. But there needs to be a better solution than just the easy one, so that is getting moved to its own PR. I had the adjoint test before, but it got removed when I fire-and-brimstoned the PR after a bad merge that messed everything up. It's back. Everything should be good to go now.
Looks good @albi3ro 🙂
@@ -119,7 +119,7 @@ def get_parameters(self, trainable_only=True, return_arraybox=False):
    qml.RY(0.543, wires=0)
    qml.CNOT(wires=[0, 'a'])
    qml.RX(0.133, wires='a')
-   expval(qml.PauliZ(wires=[0]))
+   qml.expval(qml.PauliZ(wires=[0]))
oops, good catch
grad_adjoint = jax.grad(qnode_adjoint)(params1, params2)
grad_backprop = jax.grad(qnode_backprop)(params1, params2)

assert np.allclose(grad_adjoint, grad_backprop)
Good to have this test in 💪
Closes Issue #1345.

Currently, the jax interface fails to return correct results with the adjoint differentiation method. The jax interface makes a copy of the current tape to take its derivative, but the tape copy didn't copy over the differentiation information.

I have attempted to rectify this by improving the tape copy method, but now trying to take the derivative of the circuit causes my kernel to crash.
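The shape of the fix can be illustrated with a minimal sketch. This is a toy stand-in for a tape class, not PennyLane's actual implementation; the point is only that the copy must carry the differentiation settings over:

```python
import copy as copy_module


class Tape:
    """Toy stand-in for a quantum tape holding differentiation settings."""

    def __init__(self, operations=None, jacobian_options=None):
        self.operations = operations or []
        self.jacobian_options = jacobian_options or {}

    def copy(self):
        new = Tape(operations=list(self.operations))
        # The bug: without this line, the copy silently drops the
        # differentiation information (e.g. {"method": "adjoint_jacobian"}),
        # so the copied tape falls back to the default diff method.
        new.jacobian_options = copy_module.copy(self.jacobian_options)
        return new
```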