-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify VJP/JVP/Grad/Jacobian frontend #508
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #508 +/- ##
==========================================
- Coverage 99.55% 99.55% -0.01%
==========================================
Files 43 43
Lines 7824 7821 -3
Branches 541 537 -4
==========================================
- Hits 7789 7786 -3
Misses 18 18
Partials 17 17 ☔ View full report in Codecov by Sentry. |
This will be really nice to have in 😎 @rmoyard I think a lot of docstrings/codeblocks (especially the quickstart, sharp bits page, etc.) will likely also need to be updated 🤔 |
@josh146 I am double checking the docs |
… into clean_vjp_jvp
I have found some changes in the quickstart and sharp bits, and updated them |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks for going over those PRs again 👍 I think we should update the docstrings of the functions as well, what do you think?
Additionally, were the original PRs introduced as a breaking change in the changelog? (This PR also needs to be added to those changelog entries.)
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
@dime10 I have added a note in each function that we support function that returns any pytree-like shape |
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
… into clean_vjp_jvp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good on my side 👍
>>> workflow(jnp.array([2.0, 1.0])) | ||
array([[-1.32116540e-07, 1.33781874e-07], | ||
[-4.20735506e-01, 4.20735506e-01]]) | ||
array([[ 1.74393425e-16 4.54648713e-01] | ||
[-1.74393425e-16 -4.54648713e-01]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we expect the gradient to change 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have copied the wrong one normally there is just a transposition, sorry for that. And this was introduced in a previous PR. Open a Pr shortly.
[[ 3.48786850e-16 -4.20735492e-01]
[-8.71967125e-17 4.20735492e-01]]
Description of the Change:
Simplify the unflatten of VJP/JVP/grad/Jacobian and covers more edge cases.