Evaluating rules is unreasonably slow on first call #5

Closed
ararslan opened this issue May 27, 2019 · 8 comments


@ararslan
Member

Currently, evaluating rules can be pretty slow (and in some cases extremely slow) on the first call. This seems to be due to JIT compilation of the anonymous functions, whose use is a central design point of ChainRules.

Here's a basic example, where it takes 2.6 seconds(!!!) to evaluate the derivative for svd. Note that the ChainRules version is a simple port of the Nabla version, so the underlying code that does the computation is nearly identical.

julia> using ChainRules, LinearAlgebra

julia> F, dX = rrule(svd, randn(4, 4));

julia> nt = (U=F.U, S=F.S, V=F.V);

julia> @time dX(nt)
  2.644154 seconds (11.32 M allocations: 589.599 MiB, 6.32% gc time)
4×4 Array{Float64,2}:
  0.189976   0.714807   0.456714    2.18944 
 -0.313175  -0.760914  -0.0347824   0.465126
 -0.53937    1.4229    -1.20063    -1.46798 
 -0.517822   1.03534   -1.66866     0.263543

julia> @time dX(nt)
  0.017788 seconds (2.15 k allocations: 146.201 KiB)
4×4 Array{Float64,2}:
  0.189976   0.714807   0.456714    2.18944 
 -0.313175  -0.760914  -0.0347824   0.465126
 -0.53937    1.4229    -1.20063    -1.46798 
 -0.517822   1.03534   -1.66866     0.263543

Compare this to Nabla:

julia> using Nabla, LinearAlgebra

julia> X = randn(4, 4); F = svd(X); nt = (U=F.U, S=F.S, V=F.V);

julia> @time ∇(svd, Arg{1}, (), F, nt, X)
  0.631797 seconds (2.37 M allocations: 114.680 MiB, 3.77% gc time)
4×4 Array{Float64,2}:
 -1.09959    1.16123    2.27612    -0.97398 
  1.42299   -0.17832   -2.52324     1.48229 
 -3.32982   -0.746237   1.1226      2.98457 
  0.346265   0.402947  -0.0502055   0.661208

julia> @time dX(nt)
  0.006206 seconds (2.15 k allocations: 146.201 KiB)
4×4 Array{Float64,2}:
  0.189976   0.714807   0.456714    2.18944 
 -0.313175  -0.760914  -0.0347824   0.465126
 -0.53937    1.4229    -1.20063    -1.46798 
 -0.517822   1.03534   -1.66866     0.263543

We should find some way(s) to mitigate this so that AD systems which switch to using ChainRules underneath won't take an enormous performance hit by doing so.
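The first-call cost above can be reproduced without ChainRules. The sketch below uses a hypothetical `toy_rrule` (not part of any package) that, like a ChainRules rule, returns the primal result together with an anonymous pullback. Because every anonymous-function definition gets its own type, the closure body is compiled the first time it is called, so the first timing normally includes compilation and the second does not. Exact timings will vary by machine and Julia version.

```julia
# Hypothetical stand-in for a ChainRules-style rule: returns the primal
# result together with an anonymous pullback function.
function toy_rrule(x::AbstractMatrix{<:Real})
    y = 2 .* x
    pullback = dy -> 2 .* dy   # each anonymous-function definition has its own type
    return y, pullback
end

y, back = toy_rrule(randn(4, 4))
dy = ones(4, 4)
t_first = @elapsed back(dy)    # includes JIT compilation of the closure body
t_second = @elapsed back(dy)   # reuses the code compiled by the first call
println("first call: ", t_first, " s; second call: ", t_second, " s")
```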

@oxinabox
Member

oxinabox commented May 28, 2019

What happens if we do a top-level @nospecialize on the whole package?

Answer: it helps a bit, but not much.
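For readers unfamiliar with the macro, a minimal sketch of the per-argument form of `@nospecialize`, using a hypothetical helper `scale_grad`. It is only a hint to the compiler: it can reduce how many specializations get compiled, at the price of possible dynamic dispatch at run time, which is one reason it may not remove much of the first-call latency.

```julia
# Hypothetical helper. `@nospecialize(x)` hints that the compiler should not
# generate a separate specialization of this method for every concrete type of `x`.
scale_grad(@nospecialize(x)) = 2 .* x

a = scale_grad([1.0, 2.0])   # Vector{Float64} argument
b = scale_grad([1, 2])       # Vector{Int} argument; may reuse the unspecialized code
println(a, " ", b)
```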

@iamed2

iamed2 commented May 28, 2019

> What happens if we do a top-level @nospecialize on the whole package?

Is there still a large difference between the first and second calls, or are they about the same?

@oxinabox
Member

It was below the noise level for both.
There were some minor differences in allocations, but not much.

@iamed2

iamed2 commented May 28, 2019

You're probably just getting interpreter run time, then.

@oxinabox oxinabox transferred this issue from JuliaDiff/ChainRules.jl Aug 2, 2019
@oxinabox
Member

oxinabox commented Aug 7, 2019

This felt a lot better when I tried it after the split into two packages
and when working in Julia 1.3,
but I haven't done new timings.
I suspect splitting into two packages has made precompilation happier.
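As general background on moving compilation work earlier, a package can also add explicit `precompile` directives. The sketch below uses a hypothetical module `ToyRules`; it is not what the package split or #35 actually did. Inside a real package, precompile directives let some of this work be saved in the package's precompilation cache (how much depends on the Julia version). Note that precompiling the rule does not cover the pullback closure, which is a separate method compiled on its own first call.

```julia
module ToyRules

export toy_rrule

# Hypothetical rule used only for illustration.
function toy_rrule(x::AbstractMatrix{<:Real})
    y = 2 .* x
    pullback = dy -> 2 .* dy
    return y, pullback
end

# Compile `toy_rrule` for Matrix{Float64} before its first call.
# This covers `toy_rrule` itself; the pullback closure is a separate method.
precompile(toy_rrule, (Matrix{Float64},))

end # module
```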

@oxinabox
Member

#35 does indeed fix this.

Test Script

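using Pkg: @pkg_str
# Activate a local environment that has the relevant branches checked out
pkg"activate /Users/oxinabox/JuliaEnvs/ChainRulesWorld/"
using ChainRules
using LinearAlgebra

function main()
    x = [1 2; 3 4]
    @time f, res = rrule(svd, x)      # first call of the rule itself
    r = (U = f.U, S = f.S, V = f.V)   # use the SVD factors as the cotangent
    @time res(r)                      # first call of the pullback
end
main()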

Timing Results

Current Master

rrule: 0.043273 seconds (29.90 k allocations: 1.602 MiB, 14.56% gc time)
res: 0.029336 seconds (59.08 k allocations: 3.668 MiB)

No cassette branch #35

rrule: 0.040967 seconds (29.90 k allocations: 1.602 MiB)
res: 0.000014 seconds (19 allocations: 1.891 KiB)

@nickrobinson251
Contributor

pkg"activate ChainRulesWorld"

😆

@oxinabox
Member

I have a Manifest.toml there
that has all the right branches of ChainRules, ChainRulesCore, Zygote, FiniteDifferences, etc.
checked out for whatever I am working on.


4 participants