Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RuleConfig and Zygote support #49

Merged
merged 8 commits into from
Feb 8, 2022
Merged

RuleConfig and Zygote support #49

merged 8 commits into from
Feb 8, 2022

Conversation

mohamed82008
Copy link
Member

This PR implements RuleConfig support via rrule_via_ad. I tested Zygote and tests pass. I tried Yota but tests didn't pass for some reason so I will open another PR. This is a very minimal PR to get things working but I didn't do the excellent benchmarks @sethaxen does in these kinds of PRs. I think optimisations can be added as a later PR.

@mohamed82008
Copy link
Member Author

This addresses the motivation behind #39.

@mohamed82008 mohamed82008 mentioned this pull request Feb 8, 2022
@codecov-commenter
Copy link

codecov-commenter commented Feb 8, 2022

Codecov Report

Merging #49 (6207218) into master (9a3b564) will increase coverage by 0.14%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #49      +/-   ##
==========================================
+ Coverage   83.22%   83.36%   +0.14%     
==========================================
  Files           5        6       +1     
  Lines         465      475      +10     
==========================================
+ Hits          387      396       +9     
- Misses         78       79       +1     
Impacted Files Coverage Δ
src/AbstractDifferentiation.jl 78.41% <100.00%> (-0.10%) ⬇️
src/ruleconfig.jl 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9a3b564...6207218. Read the comment docs.

@mohamed82008
Copy link
Member Author

Seems that the Julia 1.0 tests are failing because ZygoteRuleConfig is only available in a recent version of Zygote which is not compatible with Julia 1.0

@mohamed82008
Copy link
Member Author

mohamed82008 commented Feb 8, 2022

@odow is JuMP likely to drop Julia 1.0 support any time soon?

Copy link
Member

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps add a patch version bump so we can release after this?

src/ruleconfig.jl Outdated Show resolved Hide resolved
@mohamed82008
Copy link
Member Author

RuleConfig support is now available only in Julia 1.6+. I will go ahead and merge this then if 1.0 support is still required for RuleConfig, we can open another issue/PR discussing/implementing that.

@mohamed82008 mohamed82008 merged commit 8f0d6db into master Feb 8, 2022

AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...)
return (vs) -> begin
_, back = rrule_via_ad(ab.ruleconfig, f, xs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's reason for moving this in the function body? Isn't it better to obtain back only once outside of the function instead of doing it at every invocation of the pullback function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point.

AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...)
return (vs) -> begin
_, back = rrule_via_ad(ab.ruleconfig, f, xs...)
if vs isa Tuple && length(vs) === 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this compile away in the same way as, I would assume, if vs isa Tuple{Any}?

return (vs) -> begin
_, back = rrule_via_ad(ab.ruleconfig, f, xs...)
if vs isa Tuple && length(vs) === 1
return Base.tail(back(vs[1]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this rules out support for functors f as it removes the derivatives wrt to f? But maybe generally the design of AbstractDifferentiation does not support them?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is so. AbstractDifferentiation supports a strictly smaller set of applications than, say, Zygote. It's all about scalars and arrays. And even there, I'm pretty sure results will be inconsistent if one tries to, say, take the gradient of a structured array. And functors are not supported.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants