VJP/JVP support pytree #501
[sc-55113]
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #501   +/-  ##
=======================================
  Coverage   99.55%   99.55%
=======================================
  Files          43       43
  Lines        7786     7802   +16
  Branches      540      542    +2
=======================================
+ Hits         7751     7767   +16
  Misses         18       18
  Partials       17       17
I am not too familiar with JVPs or VJPs, so some more comments would be nice. But the code looks good! For example, why is there a midpoint in one of the options, and why is the shape of the VJP the same as the parameters?
@erick-xanadu It is because the JVP has the same shape as the function's return values, whereas the VJP has the same shape as the parameters.
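To illustrate the shape relationship described above, here is a minimal sketch using plain JAX rather than this project's API. The function `f` and its inputs are made up for illustration; it maps 3 parameters to 2 outputs so the two shapes are easy to tell apart.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function: 3 parameters in, 2 outputs out.
    return jnp.stack([jnp.sum(x ** 2), jnp.prod(x)])


x = jnp.array([1.0, 2.0, 3.0])

# JVP: the tangent is shaped like the parameters,
# and the result is shaped like the outputs.
tangent = jnp.ones_like(x)
out, jvp_out = jax.jvp(f, (x,), (tangent,))

# VJP: the cotangent is shaped like the outputs,
# and the result is shaped like the parameters.
out_again, f_vjp = jax.vjp(f, x)
(vjp_out,) = f_vjp(jnp.ones_like(out_again))

print(out.shape, jvp_out.shape, vjp_out.shape)
```

The JVP pushes a perturbation of the inputs forward to the outputs, while the VJP pulls a weighting of the outputs back to the inputs, which is why their shapes follow opposite ends of the function.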
How does the behaviour compare to JAX? Is it 1-1, or are there certain deviations?
else:
    func_res = results[: len(jaxpr.out_avals)]
    vjps = results[len(jaxpr.out_avals) :]
    results = tuple([*func_res, tuple(vjps)])
This structure seems a bit strange, no? The function results are expanded, but the vjps are in another tuple.
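For reference, a small stand-alone sketch of the structure the snippet produces. The helper `split_results` and the string placeholders are hypothetical; `n_outputs` stands in for `len(jaxpr.out_avals)`.

```python
def split_results(results, n_outputs):
    # Hypothetical helper mirroring the snippet under review:
    # the first n_outputs entries are function results, the rest are vjps.
    func_res = results[:n_outputs]
    vjps = results[n_outputs:]
    # Function results stay expanded at the top level; vjps are nested in one tuple.
    return tuple([*func_res, tuple(vjps)])


flat = ("out0", "out1", "vjp0", "vjp1", "vjp2")
structured = split_results(flat, n_outputs=2)
print(structured)  # ('out0', 'out1', ('vjp0', 'vjp1', 'vjp2'))
```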
For the general question: our vjp is very different from JAX's vjp (https://jax.readthedocs.io/en/latest/_autosummary/jax.vjp.html), which returns a tuple (res, f_vjp):
res, f_vjp = jax.vjp(f, *params)
Here res is unflattened; after that you need to call the returned function to get the vjps:
vjps = f_vjp(cot)
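A runnable version of the two-step JAX pattern described above might look like the following. The function `f`, the dictionary keys, and the values are illustrative assumptions, not code from this PR.

```python
import jax
import jax.numpy as jnp


def f(params):
    # Hypothetical function with a PyTree (dict) input and a PyTree (dict) output.
    a, b = params["a"], params["b"]
    return {"sum": a + b, "prod": a * b}


params = {"a": jnp.array(2.0), "b": jnp.array(3.0)}

# Step 1: jax.vjp returns the unflattened results and a pullback function.
res, f_vjp = jax.vjp(f, params)

# Step 2: the cotangent must have the same PyTree structure as the results.
cot = jax.tree_util.tree_map(jnp.ones_like, res)
(vjps,) = f_vjp(cot)

print(res)
print(vjps)
```

Note that `f_vjp` returns a tuple with one entry per positional argument of `f`, and each entry mirrors the PyTree structure of that argument.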
That's true about the vjp, I was mainly thinking of the PyTree behaviour for inputs, outputs, tangents, cotangents, and gradients. Those should ideally all match JAX's version.
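As a point of comparison, this sketch shows the PyTree behaviour of plain `jax.jvp`: tangents follow the structure of the inputs, and output tangents follow the structure of the outputs. The function `f` and the parameter names are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_structure


def f(params):
    # Hypothetical function: dict of arrays in, dict of arrays out.
    return {"sq": params["w"] ** 2, "shifted": params["w"] + params["b"]}


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

# Tangents must share the PyTree structure (and leaf shapes) of the inputs.
tangents = {"w": jnp.array([1.0, 0.0]), "b": jnp.array(1.0)}

out, out_tangents = jax.jvp(f, (params,), (tangents,))

# Output tangents share the PyTree structure of the outputs.
print(tree_structure(out) == tree_structure(out_tangents))
```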
About the vjp difference, I think we should still return a tuple of (results, gradients), just like for the jvp, because, as you say, we use the same function style for both.
Context:
Following #500, we aim to add support for arbitrary function return structures (PyTrees) for VJP and JVP.
Description of the Change: