Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChainRules adjoints #106

Merged
merged 13 commits into from
Aug 30, 2021
Merged

Add ChainRules adjoints #106

merged 13 commits into from
Aug 30, 2021

Conversation

devmotion
Copy link
Member

Currently, DistributionsAD contains custom adjoints for some functions from StatsFuns. However, as many other parts of DistributionsAD, this is type piracy and arguably these definitions should be part of StatsFuns, following the general policy of ChainRules, or of another more lightweight package without all the dependencies and other parts of DistributionsAD.

This PR moves the StatsFuns-specific adjoints from DistributionsAD to StatsFuns and extends the implementation of the adjoints for binomlogpdf and poislogpdf. The PR adds a dependency on the lightweight package ChainRulesCore (only dependencies are LinearAlgebra, SparseArrays, and MuladdMacro which does not have any dependencies).

I am not sure if the maintainers are willing to take an additional dependency on ChainRulesCore but actually StatsFuns already depends on ChainRulesCore indirectly via SpecialFunctions. The corresponding PR was received with mixed feelings but merged and released in the end.

@nalimilan
Copy link
Member

ChainRulesCore seems like a much heavier dependency that we would be willing to take for such a low-level package as StatsFuns. I'm not convinced SpecialFunctions should have accepted that. I wish we had a broader discussion to define a common policy before that happened. Why can't these definitions live in another package? Or can't ChainRulesCore be more lightweight?

@devmotion
Copy link
Member Author

To me it seems there are two main alternatives if one wants AD to work automatically without users having to load additional packages: one can move the rules to StatsFuns or to a glue package that is loaded conditionally by ChainRules, similar to the proposal in JuliaDiff/ChainRules.jl#280 (users are not supposed to load ChainRules but ChainRules-compatible AD backends will load it). As far as I can see, the main advantage of the rules living in StatsFuns would be that there is no additional package needed which has to be maintained and whose compat bounds have to be updated in ChainRules, and that maybe one can reuse code when caching intermediate steps in the forward pass (this is not done in this initial PR but I assume it would be a natural improvement). The main advantage of the rules being in a separate glue package seems to be that StatsFuns does not have to depend on and to load ChainRulesCore, but since SpecialFunctions already depends on it the argument might not be completely valid anymore.

IMO keeping the rules in DistributionsAD is not a good long-term solution since DistributionsAD is a too heavy dependency if one is interested only in the rules for StatsFuns. Additionally, I think AD should really "just work" since it does not scale if users have to load manually an AD package with rules for every package they use.

To me it seems that ChainRulesCore is a very lightweight interface package since it only depends on two standard libraries and a standalone package without any dependencies. Unfortunately, I am not familiar with the code in ChainRulesCore, so I don't know if it would be possible to drop even some of these dependencies. Arguably, ChainsRulesCore is also quite lightweight compared with the Rmath dependency of StatsFuns - at least Rmath seems to be a reason for the maintainers of NNlib to not use StatsFuns but to reimplement logsumexp, logistic, softmax, etc.

Of course, it's completely fine if you disagree with these points and rather want these rules to live in a separate glue package. I also think it's fine if different packages and maintainers handle it differently, e.g., SpecialFunctions and apparently in the future AbstractFFTs include ChainRules rules whereas Distances rather not it seems. Since packages have different existing dependencies and maintainers put different focus on loading times and AD, I am not sure if a common policy can be achieved or would even be desirable.

@kleinschmidt
Copy link
Member

What's the actual cost associate with taking a "heavy" dependency...additional precompile/load times? Do we have any way of measuring how bad that cost is? I feel like a lot of these discussions devolve into FUD about the heaviness of dependencies, which is not to say that the concern isn't valid, just that there's a lack of data on something seems measurable...

@devmotion
Copy link
Member Author

With the latest StatsFuns release (Julia 1.5.3) I get

julia> @time using StatsFuns
  1.324432 seconds (2.62 M allocations: 141.903 MiB, 2.22% gc time)

and with this PR

julia> @time using StatsFuns
  1.439281 seconds (2.80 M allocations: 150.985 MiB, 1.69% gc time)

But even with these numbers, it is unclear to me if this difference is small or large, and if this means that ChainRulesCore is a too heavy dependency (similar to the discussion in JuliaStats/Distances.jl#172).

@kleinschmidt
Copy link
Member

I guess it might be hard to tell given that SpecialFunctions already has it as a dependency right?

@nalimilan
Copy link
Member

Other interesting criteria are precompilation times and the number of method invalidations.

There's also the less measurable maintenance cost associated to the risk of breakage, and of version conflicts in case one of the dependencies isn't ported quickly to a new major release of its own dependencies (this happens quite often).

@devmotion
Copy link
Member Author

of version conflicts in case one of the dependencies isn't ported quickly to a new major release of its own dependencies (this happens quite often)

I am pretty sure that this is not an issue here since MuladdMacro (the only non-stdlib dependency of ChainRulesCore which itself does not depend on any package) is completely stable and there are no plans to extend or change it in any way (it was written > 2 years ago and had only one commit (a dependency was removed) since then: https://github.com/SciML/MuladdMacro.jl/commits/master/src/MuladdMacro.jl).

@tpapp
Copy link
Contributor

tpapp commented Jan 6, 2021

I agree that in the long run this should be handled by a feature like JuliaLang/Pkg.jl#1285, but in the meantime, just depending on ChainRulesCore is a reasonable workaround.

It may make sense to include all adjoints in a separate single file so that they are easier to extract later on into a conditionally loaded package (also tests). But the organization of this PR is also OK with me.

@devmotion
Copy link
Member Author

devmotion commented Jan 6, 2021

It might even be possible to remove the MuladdMacro dependency completely: JuliaDiff/ChainRulesCore.jl#272

@devmotion
Copy link
Member Author

The PR was merged and released, so now ChainRulesCore only depends on LinearAlgebra and SparseArrays anymore. As mentioned above, it is unclear how reasonable it is to compare timings given that StatsFuns already indirectly depends on ChainRulesCore. Nevertheless, I ran julia --project=. --compiled-modules=no --startup-file=no -e '@time using StatsFuns;' on Julia 1.6 (beta 1) and got

master:

  7.424053 seconds (12.32 M allocations: 754.750 MiB, 3.42% gc time, 81.15% compilation time)

PR:

  7.541786 seconds (12.61 M allocations: 773.573 MiB, 3.28% gc time, 81.06% compilation time)

The timings fluctuate a bit and are not competely stable though (e.g. between 7.37 and 7.57 on master).

@codecov-io
Copy link

codecov-io commented Mar 31, 2021

Codecov Report

Merging #106 (f1e92f0) into master (bc45e18) will increase coverage by 0.86%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #106      +/-   ##
==========================================
+ Coverage   39.74%   40.61%   +0.86%     
==========================================
  Files          12       13       +1     
  Lines         478      485       +7     
==========================================
+ Hits          190      197       +7     
  Misses        288      288              
Impacted Files Coverage Δ
src/distrs/pois.jl 100.00% <ø> (ø)
src/chainrules.jl 100.00% <100.00%> (ø)
src/distrs/chisq.jl 100.00% <100.00%> (ø)
src/distrs/tdist.jl 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update bc45e18...f1e92f0. Read the comment docs.

@stevengj
Copy link

stevengj commented Apr 3, 2021

My feeling is that differentiability is more and more going to be a basic expectation of library functions in numerical computations, and that merging it directly into the source package is generally the right thing to do given that this is such a lightweight addition.

@devmotion
Copy link
Member Author

I updated the PR to ChainRulesCore 1. Below you can find some updated timings, it seems the difference to master decreased compared with #106 (comment) (possibly due to an updated Julia version and the performance improvements in ChainRulesCore 1).

master

(StatsFuns) pkg> precompile
Precompiling project...
  1 dependency successfully precompiled in 4 seconds (16 already precompiled)

julia> @time using StatsFuns
  1.616403 seconds (2.05 M allocations: 118.309 MiB, 1.80% gc time)

this PR

(StatsFuns) pkg> precompile
Precompiling project...
  1 dependency successfully precompiled in 4 seconds (16 already precompiled)

julia> @time using StatsFuns
  1.620797 seconds (2.05 M allocations: 118.359 MiB, 0.72% gc time)

system

julia> versioninfo()
Julia Version 1.6.2
Commit 1b93d53fc4 (2021-07-14 15:36 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-7500U CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Environment:
  JULIA_NUM_THREADS = 4 

@oxinabox
Copy link

oxinabox commented Aug 25, 2021

In 0.10 we removed the code that run in __init__ which was a surprising slow-down. (and also was not stable enough to keep in 1.0 that code was scary), and the dependency on MulAddMacros.
And in 1.0 we also removed basically all the invalidations.
(There is about 1 left but talked to Tim Holy about that, and he said it wasn't one that should really matter much)

We investigated making it even lighter, and you can find the issues with the plots, but concluded the additional development challenges that comes from splitting the code into extra packages was not worth it when the load time was already down to 0.05-0.07 seconds on most computers.
(JuliaDiff/ChainRulesCore.jl#413)

We have done all we can do to make this as ok to add to something like StatsFuns is possible.

  • The load time increase is literally a blink of the eye, next to the current 1.5 seconds.
  • The API is stabilized with the 1.0 release.
  • It's no longer something used just by one AD. It is reasonable to say it is the standard API. Is used by 5 separate AD systems, which is most of the ones in active use. (not all, but over half)

There were valid objections on these basis before, but I think now those are all resolved.

@matbesancon
Copy link
Member

I would be in favour of adding ChainRulesCore as a dependency here. It is now stable and the compilation time has been brought down.
The PR should be updated to admit only CRC 1.x now I guess

@devmotion
Copy link
Member Author

The PR should be updated to admit only CRC 1.x now I guess

Seems reasonable, I updated it 👍

@codecov-commenter
Copy link

codecov-commenter commented Aug 25, 2021

Codecov Report

Merging #106 (81ee960) into master (7a6dc09) will increase coverage by 1.31%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #106      +/-   ##
==========================================
+ Coverage   28.91%   30.23%   +1.31%     
==========================================
  Files          11       12       +1     
  Lines         370      377       +7     
==========================================
+ Hits          107      114       +7     
  Misses        263      263              
Impacted Files Coverage Δ
src/distrs/pois.jl 100.00% <ø> (ø)
src/chainrules.jl 100.00% <100.00%> (ø)
src/distrs/chisq.jl 100.00% <100.00%> (ø)
src/distrs/tdist.jl 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7a6dc09...81ee960. Read the comment docs.

@kleinschmidt
Copy link
Member

I don't have any objections at this point!

@devmotion
Copy link
Member Author

It seems all recent comments are in favour of this PR. I'll wait until the beginning of next week and merge if there are no objections.

/cc @andreasnoack @nalimilan

@nalimilan
Copy link
Member

Fine with me.

Regarding the implementation, wouldn't it make sense to move each method to the file corresponding to the relevant distribution?

@devmotion
Copy link
Member Author

Regarding the implementation, wouldn't it make sense to move each method to the file corresponding to the relevant distribution?

I don't mind, I don't have a strong opinion. Initially I had defined them next to the primal function definition. Above @tpapp suggested to move all definitions to a single file to make it easier to extract them and/or load them conditionally if Pkg supports conditional dependencies, so I updated the PR accordingly. SpecialFunctions also contains a single file with all ChainRules definitions but, of course, we can do it differently here.

@nalimilan
Copy link
Member

OK, as you prefer.

@devmotion
Copy link
Member Author

I'll keep them separate for now, I think it's maybe a bit easier to move them to the primal definitions if desired at some point instead of extracting them.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants