New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add JVP/VJP support #98
Conversation
Codecov Report
@@ Coverage Diff @@
## main #98 +/- ##
==========================================
+ Coverage 98.85% 98.93% +0.08%
==========================================
Files 37 37
Lines 6215 6310 +95
Branches 317 328 +11
==========================================
+ Hits 6144 6243 +99
+ Misses 40 38 -2
+ Partials 31 29 -2
|
6e80e6b
to
d2c5813
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Some changes, some comments. I haven't yet looked to deeply into the JVP/VJP algorithm. That will take me a bit longer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @grwlf, the PR looks great so far! I haven't gone through all of it yet, but I have a few suggestions around the MLIR ops for you.
It might be easier to review once all the debugging code is cleaned up.
FYI: I found problems with the handling of scalar arguments, now I'm working on the correction. The main problem here is how to complete the UPDATE 1: I fixed the scalar problem quickly, but after that I faced another problem with numeric differences between PL and C quantum program gradients. UPDATE 2: The numeric problem was due to the incorrect usage of PL/JAX interface. Should be fine now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had some time to look at the rest of the code, nice contribution @grwlf ! 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @grwlf, this PR should be ready to merge as soon as the last remaining comments are resolved and all debug/commented out code is removed :)
30e592c
to
8bf4a38
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @grwlf 💯
Looks good to me!
[sc-37338]
In this PR we add support for JVP and VJP operations to the Grad MLIR dialect.
grad.grad
andlinalg.generic
operations.einsumLinalgGeneric
helper function in MLIR.Note that this PR includes a workaround for #111 . Otherwise it would fail some tests.
Review process (summary on the time-demanding work):