Gradient/Jacobians with pytrees #500
[sc-55112]
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@           Coverage Diff           @@
##             main    #500    +/-  ##
=======================================
  Coverage   99.54%  99.55%
=======================================
  Files          43      43
  Lines        7775    7786    +11
  Branches      536     540     +4
=======================================
+ Hits         7740    7751    +11
  Misses         18      18
  Partials       17      17

☔ View full report in Codecov by Sentry.
…lyst into gradient_pytree
🥳
**Context:** Following #500, we aim to support arbitrary return structures for functions passed to VJP and JVP.
**Description of the Change:**
- JVP and VJP are updated to support pytrees as return values.
- Clean up the tests.
@@ -522,6 +523,27 @@ def _check_grad_params(
    return GradParams(method, scalar_out, h, argnum)


def _unflatten_derivatives(results, out_tree, argnum, num_results):
    """Unflatten the flat list of derivatives results given the out tree."""
    num_trainable_params = len(argnum) if isinstance(argnum, list) else 1
Checking for a list is unnecessarily restrictive; maybe it would be better to check for any iterable?
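A minimal sketch of what that suggestion could look like, assuming `argnum` is either an int, `None`, or some iterable of ints. The helper name `count_trainable_params` is hypothetical and not part of this PR:

```python
from collections.abc import Iterable


def count_trainable_params(argnum):
    """Hypothetical helper: number of arguments being differentiated."""
    # Any iterable of indices (list, tuple, range, ...) counts its elements.
    if isinstance(argnum, Iterable):
        return len(tuple(argnum))
    # A single int or None means exactly one differentiated argument.
    return 1
```

With this check, `argnum=(0, 1)` and `argnum=[0, 1]` would be treated the same way, whereas the list-only check counts a tuple as a single parameter.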
We should include tests for Pytree arguments as well, unless we already have those?
    results = tuple(
        tuple(results[i * num_trainable_params : i * num_trainable_params + num_trainable_params])
        for i in range(0, num_results)
    )
    # In jax, argnums=[int] wraps single derivatives in a tuple
    if (
        (isinstance(argnum, int) or argnum is None)
        and num_trainable_params == 1
        and num_results != 1
    ):
        results = tuple(r[0] if len(r) == 1 else r for r in results)
I find these lines very difficult to follow, maybe there is a way to refactor them or explain what they are doing?
This is doing:
flat list of derivatives -> structure of the Jacobian (a tuple of tuples if there are multiple return values and multiple parameters)
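To make the reshaping concrete, here is a small standalone sketch of the same grouping in plain Python. The function name `group_derivatives` and the string labels are hypothetical stand-ins for the real derivative arrays; the slicing follows the result-major order used in the quoted code:

```python
def group_derivatives(flat, num_params, num_results):
    """Split a result-major flat sequence into one tuple per result."""
    return tuple(
        tuple(flat[i * num_params : (i + 1) * num_params])
        for i in range(num_results)
    )


# Hypothetical labels standing in for derivative arrays.
flat = ["dr0/dp0", "dr0/dp1", "dr1/dp0", "dr1/dp1"]
jacobian = group_derivatives(flat, num_params=2, num_results=2)

# With a single parameter each inner tuple has length one; the quoted code
# unwraps it when argnum is an int or None and there are several results.
single = group_derivatives(["dr0/dp0", "dr1/dp0"], num_params=1, num_results=2)
unwrapped = tuple(r[0] if len(r) == 1 else r for r in single)
```

Here `jacobian` has one inner tuple per result with one entry per parameter, and `unwrapped` is a flat tuple with one derivative per result.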
    if out_tree.children() != []:
        results = tree_unflatten(out_tree, results)
    else:
        results = results[0]
Is this sort of special case really required? I would have hoped that unflattening the tree would take care of it.
Yes, me too, but in some cases there is an error internal to JAX.
Context:
Regardless of the structure of a function's return value, taking derivatives currently returns the derivatives as a flat list.
Description of the Change:
Add support for arbitrary return structures when taking derivatives, preserving the structure of the return value in the result (pytree support).
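As a rough illustration of the idea, the sketch below uses only `jax.tree_util` to flatten a nested output, apply a stand-in computation per leaf, and restore the original structure. The `output` dictionary and the `10.0 *` scaling are hypothetical placeholders for a real function output and its derivatives; this is not the code path used in this PR:

```python
from jax.tree_util import tree_flatten, tree_unflatten

# Hypothetical function output with nested structure.
output = {"a": 1.0, "b": (2.0, 3.0)}
leaves, out_tree = tree_flatten(output)

# Stand-in for per-leaf derivative results, kept in leaf order.
flat_derivatives = [10.0 * leaf for leaf in leaves]
restored = tree_unflatten(out_tree, flat_derivatives)

# A bare scalar flattens to a leaf treedef with no children, which is
# the case the `out_tree.children()` check in the diff distinguishes.
_, scalar_tree = tree_flatten(1.0)
scalar_is_leaf = scalar_tree.children() == []
```

The `restored` value has the same dictionary-of-tuple layout as `output`, which is the property the change aims to provide for derivatives.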