Add custom rrule to speed up AD with mapped functions #4

Merged
oschulz merged 10 commits into master from bcased-trafo-rrules
Dec 10, 2021

Conversation

oschulz
Collaborator

@oschulz oschulz commented Nov 8, 2021

AD of _with_ladj_on_mapped causes significant overhead when using AD on with_logabsdet_jacobian with mapped/broadcasted functions. This PR adds a custom ChainRulesCore.rrule for _with_ladj_on_mapped that results in a significant speedup. GC overhead is also reduced.

Pro: Very significant speed gain in the benchmark example below. It should be similar for other lightweight transformations (the example uses a log transformation).

Con: The price is adding ChainRulesCore as a dependency.

This would make ChangesOfVariables itself less lightweight. On the other hand it's likely that packages that we'd want to depend on ChangesOfVariables will depend on ChainRulesCore already anyway (directly or indirectly).

Benchmark example:

using ChangesOfVariables, LinearAlgebra, Zygote, BenchmarkTools

function foo(xs)
    ys, ladj = with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
    dot(ys, ys) + ladj
end

grad_foo(xs) = Zygote.gradient(foo, xs)[1]

xs = rand(10^3);
grad_foo(xs);

@benchmark grad_foo($xs)

Without this PR (no custom rrule):

julia> @benchmark grad_foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  30.507 μs …   2.065 ms  ┊ GC (min … max):  0.00% … 95.86%
 Time  (median):     46.969 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   58.705 μs ± 122.055 μs  ┊ GC (mean ± σ):  15.00% ±  7.01%

         ▁▆█▇▄▃▃▂▁▁                                            ▁
  ▅▄▆██▆▇█████████████▇▆▅▅▄▅▄▆▆▆▆▆▅▅▆▆▆▇▆▇▆▅▆▆▆▆▆▆▅▄▅▅▅▃▅▃▄▄▄▃ █
  30.5 μs       Histogram: log(frequency) by time       130 μs <

 Memory estimate: 248.20 KiB, allocs estimate: 99.

With this PR (custom rrule for _with_ladj_on_mapped):

julia> @benchmark grad_foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.175 μs …  2.898 ms  ┊ GC (min … max):  0.00% … 93.90%
 Time  (median):     27.598 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   34.333 μs ± 81.148 μs  ┊ GC (mean ± σ):  12.31% ±  5.25%

         ▁▄▇█▇▅▄▃▃▂▁▁▁ ▁▁▂▃▃▃▂▂▁                              ▂
  ▄▆▅▃▃▅▇█████████████████████████▇▇▇▇▇▆▆▅▆▅▅▆▄▅▄▄▅▄▅▄▃▄▄▄▅▄▄ █
  19.2 μs      Histogram: log(frequency) by time      64.4 μs <

 Memory estimate: 136.44 KiB, allocs estimate: 50.
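
For readers unfamiliar with custom reverse rules, the idea can be sketched on a toy function. This is not the rule added in this PR: `split_pairs_and_sum` is a hypothetical stand-in for the internal helper, and the sketch assumes the incoming cotangent is a plain tuple of a concrete array and a concrete number (no ZeroTangent handling).

```julia
using ChainRulesCore

# Hypothetical stand-in for the internal helper this PR targets: it takes a
# collection of (y, ladj) pairs and returns all y values plus the summed ladj.
function split_pairs_and_sum(y_with_ladj)
    ys = map(first, y_with_ladj)
    ladj = sum(last, y_with_ladj)
    return ys, ladj
end

# Hand-written reverse rule: the pullback builds the cotangent of every pair
# directly instead of letting the AD backend differentiate `map` and `sum`.
function ChainRulesCore.rrule(::typeof(split_pairs_and_sum), y_with_ladj)
    function split_pairs_and_sum_pullback(ΔΩ)
        # Assumption: both cotangent components are concrete values.
        Δys, Δladj = unthunk(ΔΩ)
        # Each y passes through unchanged and each ladj enters the sum with
        # weight 1, so pair i receives (Δys[i], Δladj). Plain tuples are used
        # for brevity; a stricter rule might wrap them in `Tangent`.
        Δpairs = map(Δy -> (Δy, Δladj), Δys)
        return NoTangent(), Δpairs
    end
    return split_pairs_and_sum(y_with_ladj), split_pairs_and_sum_pullback
end

# Usage: call the rule by hand and push a cotangent through the pullback.
demo_pairs = [(1.0, 0.5), (2.0, 0.25)]
(ys, ladj), pullback = rrule(split_pairs_and_sum, demo_pairs)
_, Δpairs = pullback(([10.0, 20.0], 3.0))
```

Calling the rule by hand like this is only meant to show the data flow; an AD backend that consumes ChainRulesCore rules would pick up the `rrule` method automatically.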

Speeds up AD of with_logabsdet_jacobian with mapped/broadcasted
functions significantly.

codecov bot commented Nov 8, 2021

Codecov Report

Merging #4 (ff72a98) into master (9d2a665) will not change coverage.
The diff coverage is 100.00%.

❗ Current head ff72a98 differs from pull request most recent head c143c49. Consider uploading reports for the commit c143c49 to get more accurate results

@@            Coverage Diff            @@
##            master        #4   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files            2         2           
  Lines           45        51    +6     
=========================================
+ Hits            45        51    +6     
Impacted Files Coverage Δ
src/with_ladj.jl 100.00% <100.00%> (ø)


Member

@devmotion devmotion left a comment

In principle, I think it would be fine for ChangesOfVariables to depend on ChainRulesCore. ChainRulesCore is designed as a very very lightweight dependency so it can be added to basically every other package without increasing compilation times and number of dependencies significantly.

However, I am not convinced yet that we should add the rule in this PR. In general, the function seems so simple that Zygote and other AD backends should be able to differentiate through it efficiently on their own. Therefore it's not clear to me whether this is a Zygote-specific issue, e.g. whether the rules for map, broadcast, or sum are implemented inefficiently in Zygote and/or ChainRules.

Is the Zygote gradient with map faster or is it a broadcast-specific issue?

It also seems that the primal definition can be improved, maybe this already leads to better performance of the Zygote gradient. It would be good to check this first before adding a rule.

Collaborator Author

oschulz commented Nov 8, 2021

Is the Zygote gradient with map faster or is it a broadcast-specific issue?

It's slow with sum(map(last, y_with_ladj)), sum(broadcast(last, y_with_ladj)), and sum(last, y_with_ladj). :-(

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
@oschulz oschulz force-pushed the bcased-trafo-rrules branch 2 times, most recently from 1fd9f9e to 40fb11e Compare November 9, 2021 12:07
@oschulz oschulz force-pushed the bcased-trafo-rrules branch from 40fb11e to d819f29 Compare November 9, 2021 12:08
@oschulz oschulz changed the title [RFQ] Add custom rrule to speed up AD with mapped functions Add custom rrule to speed up AD with mapped functions Nov 9, 2021
Member

@devmotion devmotion left a comment

The ChainRulesTestUtils issue should be fixed in ChainRulesTestUtils 1.3.1, which was just released. I restarted CI.

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Collaborator Author

oschulz commented Dec 9, 2021

The packages that currently depend on ChangesOfVariables already depend on ChainRulesCore directly or indirectly, so adding it as a dependency here doesn't affect their load time (I tested this for LogExpFunctions and TransformVariables, just to make sure).

Collaborator Author

oschulz commented Dec 9, 2021

@devmotion is this PR good to go from your side?

Member

@devmotion devmotion left a comment

I'm fine with adding ChainRulesCore as a dependency - it's very lightweight and most packages that depend on ChangesOfVariables will depend on ChainRulesCore anyway.

I only have two questions:

  • Are the benchmarks in the PR still correct?
  • Would it be useful to add a frule as well?

Collaborator Author

oschulz commented Dec 10, 2021

  • Are the benchmarks in the PR still correct?

Looks like it:

Current master:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  31.854 μs …   2.385 ms  ┊ GC (min … max):  0.00% … 94.80%
 Time  (median):     43.974 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   59.229 μs ± 146.713 μs  ┊ GC (mean ± σ):  18.31% ±  7.22%

This PR:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  18.304 μs …   2.671 ms  ┊ GC (min … max):  0.00% … 97.44%
 Time  (median):     25.303 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   34.384 μs ± 105.128 μs  ┊ GC (mean ± σ):  16.69% ±  5.42%

  • Would it be useful to add a frule as well?

I don't think it would bring a speedup, at least not with ForwardDiff, which doesn't support frule anyway (we'd need a Dual rule and would have to depend on ForwardDiff). Maybe hold off on that until we can benchmark it, once a major AD framework supports frules? We can always add an frule later.

@oschulz oschulz merged commit c143c49 into master Dec 10, 2021
@oschulz oschulz deleted the bcased-trafo-rrules branch December 10, 2021 10:57