-
Notifications
You must be signed in to change notification settings - Fork 4
Add custom rrule to speed up AD with mapped functions #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Speeds up AD of with_logabsdet_jacobian with mapped/broadcasted functions significantly.
Codecov Report
@@ Coverage Diff @@
## master #4 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 2 2
Lines 45 51 +6
=========================================
+ Hits 45 51 +6
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In principle, I think it would be fine for ChangesOfVariables to depend on ChainRulesCore. ChainRulesCore is designed as a very very lightweight dependency so it can be added to basically every other package without increasing compilation times and number of dependencies significantly.
However, I am not convinced yet that we should add the rule in this PR. In general, the function seems so simple that Zygote and other AD backends should be able to differentiate through it efficiently automatically. Therefore it's not clear to me if this is a Zygote specific issue, i.e., if e.g. rules for map
, broadcast
, or sum
are implemented inefficiently in Zygote and/or ChainRules.
Is the Zygote gradient with map
faster or is it a broadcast
-specific issue?
It also seems that the primal definition can be improved, maybe this already leads to better performance of the Zygote gradient. It would be good to check this first before adding a rule.
It's slow with |
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1fd9f9e
to
40fb11e
Compare
40fb11e
to
d819f29
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ChainRulesTestutils issue should be fixed in CRTestutils 1.3.1 which was just released. I restarted CI.
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
The packages that currently depend on ChangesOfVariables already depend on ChainRulesCore directly or indirectly, so adding it as a dependency here doesn't affect their load time (just tested for LogExpFunctions and TransformVariables, just to make sure). |
@devmotion is this PR good to go from your side? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with adding ChainRulesCore as a dependency - it's very lightweight and most packages that depend on ChangesOfVariables will depend on ChainRulesCore anyway.
I only have two questions:
- Are the benchmarks in the PR still correct?
- Would it be useful to add a
frule
as well?
Looks like it: Current master:
This PR:
I don't think it would bring a speedup (at least not with ForwardDiff, which doesn't support |
AD of
_with_ladj_on_mapped
causes a significant overhead when using AD onwith_logabsdet_jacobian
with mapped/broadcasted functions. This PR adds a customChainRulesCore.rrule
for_with_ladj_on_mapped
that results in a siginificant speedup. GC overhead is also reduced.Pro: Very significant speed gain in the benchmark example below. Should be be similar for other lightweight transformations (example uses a
log
-trafo).Con: The price is adding ChainRulesCore as a dependency
This would make ChangesOfVariables itself less lightweight. On the other hand it's likely that packages that we'd want to depend on ChangesOfVariables will depend on ChainRulesCore already anyway (directly or indirectly).
Benchmark example:
Without this PR (no custom rrule):
With this PR (custom rrule for
_with_ladj_on_mapped
):