-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RuleConfig and Zygote support #49
Conversation
This addresses the motivation behind #39. |
Codecov Report
@@ Coverage Diff @@
## master #49 +/- ##
==========================================
+ Coverage 83.22% 83.36% +0.14%
==========================================
Files 5 6 +1
Lines 465 475 +10
==========================================
+ Hits 387 396 +9
- Misses 78 79 +1
Continue to review full report at Codecov.
|
Seems that the Julia 1.0 tests are failing because ZygoteRuleConfig is only available in a recent version of Zygote which is not compatible with Julia 1.0 |
@odow is JuMP likely to drop Julia 1.0 support any time soon? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps add a patch version bump so we can release after this?
RuleConfig support is now available only in Julia 1.6+. I will go ahead and merge this then if 1.0 support is still required for RuleConfig, we can open another issue/PR discussing/implementing that. |
|
||
AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) | ||
return (vs) -> begin | ||
_, back = rrule_via_ad(ab.ruleconfig, f, xs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's reason for moving this in the function body? Isn't it better to obtain back
only once outside of the function instead of doing it at every invocation of the pullback function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good point.
AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) | ||
return (vs) -> begin | ||
_, back = rrule_via_ad(ab.ruleconfig, f, xs...) | ||
if vs isa Tuple && length(vs) === 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this compile away in the same way as, I would assume, if vs isa Tuple{Any}
?
return (vs) -> begin | ||
_, back = rrule_via_ad(ab.ruleconfig, f, xs...) | ||
if vs isa Tuple && length(vs) === 1 | ||
return Base.tail(back(vs[1])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this rules out support for functors f
as it removes the derivatives wrt to f
? But maybe generally the design of AbstractDifferentiation does not support them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is so. AbstractDifferentiation supports a strictly smaller set of applications than, say, Zygote. It's all about scalars and arrays. And even there, I'm pretty sure results will be inconsistent if one tries to, say, take the gradient of a structured array. And functors are not supported.
This PR implements RuleConfig support via rrule_via_ad. I tested Zygote and tests pass. I tried Yota but tests didn't pass for some reason so I will open another PR. This is a very minimal PR to get things working but I didn't do the excellent benchmarks @sethaxen does in these kinds of PRs. I think optimisations can be added as a later PR.