Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take nothing seriously #967

Closed
wants to merge 7 commits into from
Closed

Take nothing seriously #967

wants to merge 7 commits into from

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented May 7, 2021

ChainRules has a method rrule(_...) = nothing to indicate that no rule has been defined, which is the signal for Zygote to keep looking inside the function. At present, this is done by checking that this exact method would be the one called, rather than allowing other user-defined rrule methods which return nothing to play the same role. This fixes that, by checking the return type not just the method's identity.

I'm told that compile time might be a concern. This appears to be unchanged locally.

All but 3 of the tests of ChainRules integration pass if I try them individually, but many seem to fail when running Pkg.test, I'm not sure why. All other tests pass.

@mcabbott mcabbott marked this pull request as draft May 8, 2021 02:00
@oxinabox
Copy link
Member

oxinabox commented May 9, 2021

I will review in next few days.
There is a benchmark suite on the original ChainRules PR.
You might want to run it.

@mcabbott
Copy link
Member Author

mcabbott commented May 9, 2021

Digging in old issues:

  • Around here Add ChainRules #366 (comment) you discuss explicitly testing return === nothing. This inserts function calls which ought to be trivial into the hot code, and might cause inference to give up. The present implementation avoids this.
  • Further down, you discuss what happens if rules change. I think this PR requires Zygote.refresh() in order to see that an rrule now returns nothing, or now doesn't. Could this be improved? I'm not sure I follow what return false, m.instance is doing, but it's doing it for methods without a rule.
  • I don't see the benchmarks mentioned.

@mzgubic
Copy link
Collaborator

mzgubic commented May 10, 2021

What's the intended use case for this?

@willtebbutt
Copy link
Member

willtebbutt commented May 10, 2021

The ability to write

rrule(::typeof(foo), ::AbstractArray) = <some_generic_code>
rrule(::typeof(foo), ::SubtypeOfAbstractArray) = nothing

to say "don't use the generic rrule for this subtype, try and AD through it".

Comment on lines 14 to 19
if m.method === chainrules_fallback
# no rule exists
return false, m.instance
elseif Core.Compiler.return_type(rrule, Tuple{T.parameters...}) === Nothing
Core.println(" has_chain_rrule got Nothing")
# or we hit a rule telling us to keep digging
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the chainrules_fallback is just one rule that matches this pattern so we can get rid of the clause mentioning it explictly.
and can delete the const on line 1 of this fule

Suggested change
if m.method === chainrules_fallback
# no rule exists
return false, m.instance
elseif Core.Compiler.return_type(rrule, Tuple{T.parameters...}) === Nothing
Core.println(" has_chain_rrule got Nothing")
# or we hit a rule telling us to keep digging
if Core.Compiler.return_type(rrule, Tuple{T.parameters...}) === Nothing
# no applicable rule

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is excluding this likely to affect compile time? IIUC @mcabbott 's reason for keeping this branch was to avoid having to do type inference some of the time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was initially thinking not to call Compiler.return_type more times than needed, but without thinking much about whether that actually matters. This does seem to produce specialisations, so I guess there may be some cost, even if all of them are very simple?

julia> f(x) = nothing; f(1); f(1.0); f(1+im);

julia> first(methods(f)).specializations
svec(MethodInstance for f(::Int64), MethodInstance for f(::Float64), MethodInstance for f(::Complex{Int64}), #undef, #undef, #undef, #undef, #undef)

julia> g(x) = nothing; Core.Compiler.return_type(g, Tuple{Int}); Core.Compiler.return_type(g, Tuple{Float64}); Core.Compiler.return_type(g, Tuple{String});

julia> first(methods(g)).specializations
svec(MethodInstance for g(::Int64), MethodInstance for g(::Float64), MethodInstance for g(::String), #undef, #undef, #undef, #undef, #undef)

Copy link
Member

@oxinabox oxinabox May 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure that it will actually make a difference in this case though, I think we might already have had the specialization made to the the m.instance made.

On that basis, I wonder if there is a place in IRTools that we should be asking for the return type?
Or if we should be passing m.instance to something we can ccall?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IRTools I don't know.

Is there some subtlety of testing the negative if m.method !== chainrules_fallback (before this PR) which I'm overlooking?

Copy link
Member Author

@mcabbott mcabbott Jun 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fbb3fcf makes a change like the suggestion. See what CI thinks of it. This seems to let tests pass without refresh, which is strange but welcome.

@@ -41,7 +41,7 @@ end
end

@testset "ChainRules" begin
include("chainrules.jl")
# include("chainrules.jl")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to

Suggested change
# include("chainrules.jl")
include("chainrules.jl")

@oxinabox
Copy link
Member

oxinabox commented May 10, 2021

Around here #366 (comment) you discuss explicitly testing return === nothing. This inserts function calls which ought to be trivial into the hot code, and might cause inference to give up. The present implementation avoids this.

Yeah, i think this is the difference.
That proposal was inserting a "runtime" check that should optimize away, this is better since it is running at compile-time even without the optimizer.
Core.return_type can't be trusted in general (since it is allowed to return any thing as long as the true returned value is <: it), but i think can be trusted for methods that are simple.

Further down, you discuss what happens if rules change. I think this PR requires Zygote.refresh() in order to see that an rrule now returns nothing, or now doesn't.

This is a bug in Zygote right now, that needs to be tracked down
#718

Could this be improved? I'm not sure I follow what return false, m.instance is doing, but it's doing it for methods without a rule.

It is attaching a backedge to the method instance it hit (not the method, but the exact specialization for that input), so that if it gets invalidated this code (in theory) get invalidated too.
Before it was the chainrules_fallback now it will be whatever is hit instead that returned nothing.
It should still work (but probably won't work since something was going wrong without this, thus #718)

I don't see the benchmarks mentioned.

they are here
#366 (comment)

@mcabbott
Copy link
Member Author

Core.return_type can't be trusted in general (since it is allowed to return any thing as long as the true returned value is <: it), but i think can be trusted for methods that are simple.

Agree. Perhaps the documentation ought to stress that rrule(::typeof(f), ::MyType) = nothing is fine, but trying to do istriu(x) || return nothing is not something that's intended to work.

It also might be worth adding a friendly error, since no method matching iterate(::Nothing) isn't so helpful. This should be cheap, right? Just checking === nothing within the function chain_rrule, no extra calls.

This is a bug in Zygote right now, that needs to be tracked down #718

OK, that's good to hear. It is possible that this case was behaving a little worse than other rules, for me, but it is also possible that I got confused between all the branches & restarts & @evaling stuff.

@oxinabox
Copy link
Member

oxinabox commented May 14, 2021

It is weird that we need more refresh now.

Anyway, running those benchmarks would be nice still

@mcabbott
Copy link
Member Author

mcabbott commented May 14, 2021

Benchmarks from this post, #366:

julia> using Zygote

julia> gradient(sqrt, 1)
(0.5,)

julia> @time sin'(0.5);
  0.000000 seconds    # Zygote v0.6.10
  0.000000 seconds    # this PR

julia> @btime sin'(0.5);
  1.583 ns (0 allocations: 0 bytes)    # tagged
  2.125 ns (0 allocations: 0 bytes)    # this PR -- noise, 1.5ns on other runs

julia> f(x) = 4x^3 + 3x^2 + 2x + 1;  # typo: original had a y

julia> @time f'(π);
  0.000020 seconds (1 allocation: 16 bytes)    # tagged
  0.000012 seconds (1 allocation: 16 bytes)    # this PR

julia> @btime f'(π);
  13.096 ns (1 allocation: 16 bytes)    # tagged
  10.052 ns (1 allocation: 16 bytes)    # this PR

julia> function foo_inner(x)
           a = max(x, x^2)*x
           if a > 0
               return prod([a, 2a, 3a])
           else
               return sum([a, 2a, 3a])
           end
       end;

julia> foo(x) = foo_inner(x) + foo_inner(-x);

julia> @time foo'(0.5);
  0.093714 seconds (354.55 k allocations: 22.030 MiB, 98.03% compilation time)
  0.093125 seconds (351.82 k allocations: 21.836 MiB, 98.13% compilation time)    # this PR

julia> @btime foo'(0.5);
  14.042 μs (154 allocations: 5.25 KiB)
  13.666 μs (154 allocations: 5.11 KiB)    # this PR

julia> @time sin''(0.5);
  0.919726 seconds (2.62 M allocations: 157.897 MiB, 3.20% gc time, 99.51% compilation time)
  0.691970 seconds (2.49 M allocations: 149.152 MiB, 3.20% gc time, 99.47% compilation time)    # this PR

julia> @time sin''(0.5);
  0.000097 seconds (154 allocations: 36.391 KiB)
  0.000092 seconds (142 allocations: 15.031 KiB)    # this PR

julia> @time sin'''(0.5);
 98.225836 seconds (93.46 M allocations: 5.276 GiB, 1.39% gc time, 96.83% compilation time)
 54.666285 seconds (80.86 M allocations: 4.510 GiB, 1.86% gc time, 95.29% compilation time)    # this PR

julia> @time sin'''(0.5);
  0.003831 seconds (5.99 k allocations: 3.511 MiB)
  0.003073 seconds (5.42 k allocations: 1.709 MiB)    # this PR

julia> versioninfo()  # official 1.6, rosetta
Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Apple M1

(To also run the version before ChainRules, I think you need Zygote@0.4.20, NNlib@0.6.6, and I can't seem to make those work locally, Julia 1.4 or 1.6.)

Second run, same steps after re-starting Julia, but only showing a few:

julia> @time sin''(0.5);
  0.908306 seconds (2.65 M allocations: 159.441 MiB, 2.37% gc time, 99.52% compilation time)    # Zygote v0.6.10
  0.686682 seconds (2.49 M allocations: 149.080 MiB, 2.85% gc time, 99.46% compilation time)    # this PR

julia> @time sin''(0.5);
  0.000110 seconds (154 allocations: 36.391 KiB)
  0.000097 seconds (142 allocations: 15.031 KiB)    # this PR

julia> @time sin'''(0.5);
 93.237707 seconds (93.44 M allocations: 5.275 GiB, 1.42% gc time, 97.09% compilation time)
 50.542420 seconds (80.91 M allocations: 4.511 GiB, 2.00% gc time, 95.27% compilation time)    # this PR

julia> @time sin'''(0.5);
  0.003729 seconds (5.99 k allocations: 3.511 MiB)
  0.003005 seconds (5.42 k allocations: 1.709 MiB)    # this PR

julia> @btime sin'''(0.5); # was @time above, but I suspect @btime was intended
  1.907 ms (5989 allocations: 3.51 MiB)
  1.341 ms (5422 allocations: 1.71 MiB)    # this PR

So I think the speedup for sin'''(0.5) is real. While welcome it's slightly disturbing, as what this is doing ought not to have changed; I wonder if this is related to the need for refresh(). I see that I did not remove the Core.println statement in the new branch when running this, and nothing was printed.

Edit -- after d015df3 this is back to the first (slower) set of times above. The second (faster) set were from 388bd0f which needed many refresh statements.

@mcabbott mcabbott marked this pull request as ready for review June 5, 2021 18:49
@mcabbott
Copy link
Member Author

Status here is that the change not to use if m.method !== chainrules_fallback seems to have fixed the need for Zygote.refresh(). And also reverted the speedup seen at first in the above benchmarks. This still seems pretty spooky, but CI at least seems to think the PR has no downsides.

Perhaps a little more testing by hand would not be a bad idea. I mean to check that new rules do take effect as you'd expect.

@oxinabox
Copy link
Member

Status here is that the change not to use if m.method !== chainrules_fallback

hmmm, in #990 doesn't use the if m.method !== chainrules_fallback and it does seem to still need the Zygote.refresh()
I am going to be investigating more closely tomorrow though.

@oxinabox
Copy link
Member

@vtjnash says that #990 (comment)

Calling Core.Compiler.return_type inside @generated should be entirely prohibited, as it also risks crashing the compiler...

So I guess we can't do this

@mcabbott
Copy link
Member Author

Replaced by #1035, right?

@mcabbott mcabbott closed this Jul 27, 2021
@oxinabox
Copy link
Member

Right, it would have been nice to do it this way instead, but it has proved impossible.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants