Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax-jit can use device vjp #4935

Merged
merged 15 commits into from Dec 18, 2023
Merged

Jax-jit can use device vjp #4935

merged 15 commits into from Dec 18, 2023

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Dec 11, 2023

[sc-50418]

This PR binds device-provided VJPs to the JAX-JIT interface. It also switches the jax-jit interface to the "Jacobian Product Calculator" system of defining how to take Jacobians/ JVPs/ VJPs.

Since we need to place a pure-callback around the jacobian calculation, but we want to calculate results and jacobians at the same time when feasible, this PR adds the JacobianProductCalculator.execute_and_compute_jacobian method. This has the benefit over calling execute_fn and compute_jacobian separately, as it is always grad_on_execution=True when actually taking derivatives, and always grad_on_execution=False when purely calculating results.

Copy link

codecov bot commented Dec 11, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (a5c5ab9) 99.50% compared to head (c34e338) 99.49%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #4935      +/-   ##
==========================================
- Coverage   99.50%   99.49%   -0.01%     
==========================================
  Files         390      390              
  Lines       35538    35228     -310     
==========================================
- Hits        35362    35051     -311     
- Misses        176      177       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@albi3ro albi3ro marked this pull request as ready for review December 13, 2023 16:11
@albi3ro albi3ro requested a review from rmoyard December 13, 2023 19:02
Copy link
Contributor

@dwierichs dwierichs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only had a few minor comments, otherwise looks good to me!
Again, amazing to see the cleaned up logic! 💯

Those test changes though 😅

Do we want a changelog entry remarking that gradients are not precalculated if not requested (c.f. changed test in test_jax_jit.py)?

pennylane/interfaces/jacobian_products.py Show resolved Hide resolved
pennylane/interfaces/jacobian_products.py Outdated Show resolved Hide resolved
pennylane/interfaces/jacobian_products.py Show resolved Hide resolved
pennylane/interfaces/torch.py Show resolved Hide resolved
pennylane/interfaces/jax_jit.py Outdated Show resolved Hide resolved
tests/interfaces/test_jacobian_products.py Outdated Show resolved Hide resolved
tests/interfaces/test_jacobian_products.py Outdated Show resolved Hide resolved
tests/interfaces/test_jacobian_products.py Outdated Show resolved Hide resolved
tests/interfaces/test_jax_jit.py Outdated Show resolved Hide resolved
albi3ro and others added 2 commits December 14, 2023 09:41
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
@albi3ro
Copy link
Contributor Author

albi3ro commented Dec 14, 2023

Ran some performance numbers 🚀

Screenshot 2023-12-14 at 9 50 52 AM

Shows similar behaviour to the timings for torch and autograd.

Interestingly enough, we have support for jacobians with jit, even though we don't with non-jit. I think this is because we aren't setting vectorized=True in the pure-callback around the vjp calculation. Something to revisit once the device derivative supports a broadcast dimension.

Copy link
Contributor

@timmysilv timmysilv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm really bothered by this addition of a device kwarg for some reason... because we have DeviceDerivatives but it's something else? idk. anyway this is really great, will be so close after this!

Copy link
Contributor

@timmysilv timmysilv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one more questio, but assuming it isn't a problem, sounds good!

pennylane/interfaces/execution.py Show resolved Hide resolved
Copy link
Contributor

@dwierichs dwierichs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 🎉

@albi3ro albi3ro merged commit f496a63 into master Dec 18, 2023
35 checks passed
@albi3ro albi3ro deleted the jax-jit-device-vjp branch December 18, 2023 14:42
mudit2812 pushed a commit that referenced this pull request Jan 19, 2024
[sc-50418]

This PR binds device-provided VJPs to the JAX-JIT interface. It also
switches the jax-jit interface to the "Jacobian Product Calculator"
system of defining how to take Jacobians/ JVPs/ VJPs.

Since we need to place a pure-callback around the jacobian calculation,
but we want to calculate results and jacobians at the same time when
feasible, this PR adds the
`JacobianProductCalculator.execute_and_compute_jacobian` method. This
has the benefit over calling `execute_fn` and `compute_jacobian`
separately, as it is always `grad_on_execution=True` when actually
taking derivatives, and always `grad_on_execution=False` when purely
calculating results.

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants