Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JVP/VJP support #98

Merged
merged 121 commits into from May 18, 2023
Merged

Add JVP/VJP support #98

merged 121 commits into from May 18, 2023

Conversation

grwlf
Copy link
Contributor

@grwlf grwlf commented Apr 21, 2023

[sc-37338]

In this PR we add support for JVP and VJP operations to the Grad MLIR dialect.

  • Add 1-part JVP and VJP operations at the MLIR level.
  • Implement lowering to grad.grad and linalg.generic operations.
  • Implement einsumLinalgGeneric helper function in MLIR.
  • Provide the required Python JAX API extensions.

Note that this PR includes a workaround for #111 . Otherwise it would fail some tests.

Review process (summary on the time-demanding work):

  • Move MLIR patterns to separate files, adjust headers, etc.
  • Make tests non-parameterized
  • Split the JVP/VJP callee arguments from the (co-)tangent arguments
  • Split the JVP/VJP callee results from the jvp/vjp results
  • Unify the MLIR symbol usage verifiers for grad/jvp/vjp
  • Rebase and test the PR against the Temporary fix for memory issues. #121 workaround

@codecov
Copy link

codecov bot commented Apr 21, 2023

Codecov Report

Merging #98 (34477ab) into main (e40a94e) will increase coverage by 0.08%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main      #98      +/-   ##
==========================================
+ Coverage   98.85%   98.93%   +0.08%     
==========================================
  Files          37       37              
  Lines        6215     6310      +95     
  Branches      317      328      +11     
==========================================
+ Hits         6144     6243      +99     
+ Misses         40       38       -2     
+ Partials       31       29       -2     
Impacted Files Coverage Δ
frontend/catalyst/__init__.py 90.00% <100.00%> (ø)
frontend/catalyst/jax_primitives.py 96.42% <100.00%> (+0.30%) ⬆️
frontend/catalyst/pennylane_extensions.py 97.42% <100.00%> (+1.31%) ⬆️
frontend/catalyst/utils/calculate_grad_shape.py 92.10% <100.00%> (ø)

@grwlf grwlf changed the title Add JVP jax extension Add JVP/VJP support Apr 21, 2023
@grwlf grwlf force-pushed the jvp_vjp branch 2 times, most recently from 6e80e6b to d2c5813 Compare May 10, 2023 08:45
@grwlf grwlf requested review from dime10 and erick-xanadu May 10, 2023 10:40
@grwlf grwlf marked this pull request as ready for review May 10, 2023 10:40
Copy link
Contributor

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Some changes, some comments. I haven't yet looked to deeply into the JVP/VJP algorithm. That will take me a bit longer.

frontend/catalyst/compiler.py Outdated Show resolved Hide resolved
frontend/catalyst/jax_primitives.py Outdated Show resolved Hide resolved
frontend/catalyst/jax_primitives.py Outdated Show resolved Hide resolved
frontend/catalyst/jax_primitives.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jvpvjp.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jvpvjp.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jvpvjp.py Outdated Show resolved Hide resolved
mlir/lib/Gradient/Transforms/jvp_vjp_lowering.cpp Outdated Show resolved Hide resolved
mlir/test/Gradient/JVPTest.mlir Show resolved Hide resolved
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @grwlf, the PR looks great so far! I haven't gone through all of it yet, but I have a few suggestions around the MLIR ops for you.

It might be easier to review once all the debugging code is cleaned up.

frontend/test/pytest/test_jvpvjp.py Outdated Show resolved Hide resolved
mlir/include/Gradient/IR/GradientOps.td Outdated Show resolved Hide resolved
mlir/include/Gradient/IR/GradientOps.td Outdated Show resolved Hide resolved
mlir/lib/Gradient/IR/GradientOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Gradient/IR/GradientOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Gradient/Transforms/jvp_vjp_lowering.cpp Outdated Show resolved Hide resolved
mlir/lib/Gradient/Transforms/jvp_vjp_lowering.cpp Outdated Show resolved Hide resolved
mlir/include/Gradient/IR/GradientOps.td Outdated Show resolved Hide resolved
mlir/include/Gradient/IR/GradientOps.td Outdated Show resolved Hide resolved
mlir/lib/Gradient/Transforms/jvp_vjp_lowering.cpp Outdated Show resolved Hide resolved
@grwlf
Copy link
Contributor Author

grwlf commented May 11, 2023

FYI: I found problems with the handling of scalar arguments, now I'm working on the correction. The main problem here is how to complete the einsum definition if some or all arguments are scalars.

UPDATE 1: I fixed the scalar problem quickly, but after that I faced another problem with numeric differences between PL and C quantum program gradients.

UPDATE 2: The numeric problem was due to the incorrect usage of PL/JAX interface. Should be fine now.

frontend/test/pytest/test_jvpvjp.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jvpvjp.py Outdated Show resolved Hide resolved
frontend/catalyst/compiler.py Outdated Show resolved Hide resolved
runtime/lib/capi/RuntimeCAPI.cpp Outdated Show resolved Hide resolved
mlir/include/Gradient/IR/GradientOps.td Outdated Show resolved Hide resolved
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had some time to look at the rest of the code, nice contribution @grwlf ! 🙂

mlir/lib/Gradient/Transforms/jvpvjp_lowering.cpp Outdated Show resolved Hide resolved
mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp Outdated Show resolved Hide resolved
mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp Outdated Show resolved Hide resolved
mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
frontend/catalyst/pennylane_extensions.py Outdated Show resolved Hide resolved
@PennyLaneAI PennyLaneAI deleted a comment from grwlf May 15, 2023
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @grwlf, this PR should be ready to merge as soon as the last remaining comments are resolved and all debug/commented out code is removed :)

mlir/lib/Gradient/IR/GradientOps.cpp Outdated Show resolved Hide resolved
mlir/include/Gradient/IR/GradientOps.td Outdated Show resolved Hide resolved
mlir/include/Gradient/IR/GradientOps.td Outdated Show resolved Hide resolved
mlir/lib/Gradient/Transforms/gradient_lowering.cpp Outdated Show resolved Hide resolved
@grwlf grwlf force-pushed the jvp_vjp branch 3 times, most recently from 30e592c to 8bf4a38 Compare May 17, 2023 10:59
@grwlf grwlf requested review from dime10 and josh146 May 17, 2023 11:26
doc/changelog.md Outdated Show resolved Hide resolved
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work @grwlf 💯

Looks good to me!

doc/changelog.md Outdated Show resolved Hide resolved
doc/changelog.md Outdated Show resolved Hide resolved
grwlf and others added 27 commits May 18, 2023 15:55
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
@grwlf grwlf merged commit 4721af3 into main May 18, 2023
15 checks passed
@grwlf grwlf deleted the jvp_vjp branch May 18, 2023 16:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants