New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jax-jit can use device vjp #4935
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #4935 +/- ##
==========================================
- Coverage 99.50% 99.49% -0.01%
==========================================
Files 390 390
Lines 35538 35228 -310
==========================================
- Hits 35362 35051 -311
- Misses 176 177 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only had a few minor comments, otherwise looks good to me!
Again, amazing to see the cleaned up logic! 💯
Those test changes though 😅
Do we want a changelog entry remarking that gradients are not precalculated if not requested (c.f. changed test in test_jax_jit.py
)?
tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py
Outdated
Show resolved
Hide resolved
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm really bothered by this addition of a device kwarg for some reason... because we have DeviceDerivatives but it's something else? idk. anyway this is really great, will be so close after this!
tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one more questio, but assuming it isn't a problem, sounds good!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 🎉
[sc-50418] This PR binds device-provided VJPs to the JAX-JIT interface. It also switches the jax-jit interface to the "Jacobian Product Calculator" system of defining how to take Jacobians/ JVPs/ VJPs. Since we need to place a pure-callback around the jacobian calculation, but we want to calculate results and jacobians at the same time when feasible, this PR adds the `JacobianProductCalculator.execute_and_compute_jacobian` method. This has the benefit over calling `execute_fn` and `compute_jacobian` separately, as it is always `grad_on_execution=True` when actually taking derivatives, and always `grad_on_execution=False` when purely calculating results. --------- Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
[sc-50418]
This PR binds device-provided VJPs to the JAX-JIT interface. It also switches the jax-jit interface to the "Jacobian Product Calculator" system of defining how to take Jacobians/ JVPs/ VJPs.
Since we need to place a pure-callback around the jacobian calculation, but we want to calculate results and jacobians at the same time when feasible, this PR adds the
JacobianProductCalculator.execute_and_compute_jacobian
method. This has the benefit over callingexecute_fn
andcompute_jacobian
separately, as it is alwaysgrad_on_execution=True
when actually taking derivatives, and alwaysgrad_on_execution=False
when purely calculating results.