Configured rule for maximum(f, xs) #490

Closed · wants to merge 3 commits

Conversation

@mcabbott (Member) commented Aug 1, 2021

This uses the RuleConfig{>:HasReverseMode} story to call back into AD to write a rule for maximum(f, xs).

It's much simplified from the first attempt:

  • On Julia 1.7+, for a total reduction, it calls y, i = findmax(f, xs), and then uses rrule_via_ad(f, xs[i]); see the sketch below.
  • Otherwise, it just calls broadcasting.
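For concreteness, the fast path looks roughly like this (a minimal sketch, not the PR's exact code, which also handles projection and thunks):

using ChainRulesCore

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(maximum), f, xs::AbstractArray{<:Real})
    y, ind = findmax(f, xs)                       # Julia 1.7+; N calls to f
    _, back_f = rrule_via_ad(config, f, xs[ind])  # one extra call, at the maximum
    function maximum_pullback(dy)
        _, dx_at_ind = back_f(dy)                 # df/dx at the maximising element
        dxs = zeros(float(eltype(xs)), size(xs))  # gradient is zero everywhere else
        dxs[ind] = unthunk(dx_at_ind)
        return (NoTangent(), NoTangent(), dxs)
    end
    return y, maximum_pullback
end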

Fast case, before & after:

julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30)));
  min 2.908 ms, mean 4.031 ms (21816 allocations, 9.29 MiB)  # before
  min 14.875 μs, mean 17.546 μs (52 allocations, 8.92 KiB)  # after

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
  min 17.500 μs, mean 25.453 μs (46 allocations, 36.83 KiB)  # just broadcasting, to compare

Before this PR, gradient(x -> sum(maximum(sqrt, x, dims=1)), (rand(30,30))) gives an error with Zygote. After, it is the same speed as broadcasting.

What doesn't seem easy now is testing the broadcast path.

First attempt

However, it only needs one such call to rrule_via_ad, rather than one for every element. That means it ends up calling f, say, N^2 + 1 times for a matrix (or N^2 + N with dims). This is much more efficient than calling it via AD all N^2 times, saving the pullbacks somewhere, and then invoking just one of them. Not always faster than Zygote's current broadcasting (which uses ForwardDiff), but much less memory:

julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30)));
  9.625 μs (73 allocations: 9.11 KiB)   # this PR
  9.333 μs (66 allocations: 8.95 KiB)   # this PR, with rrule instead of rrule_via_ad
  
julia> @btime gradient(x -> sum(maximum(sqrt, x, dims=1)), $(rand(30,30)));
  10.125 μs (34 allocations: 13.92 KiB)  # this PR, take 1
  15.208 μs (33 allocations: 29.31 KiB)  # this PR, with mask allowing multiple maxima
  17.166 μs (33 allocations: 29.31 KiB)  # with rrule instead of rrule_via_ad

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
  8.833 μs (48 allocations: 36.98 KiB)  # broadcasting with Duals

julia> @btime maximum(sqrt, $(rand(30,30)));  # forward pass
  1.438 μs (0 allocations: 0 bytes)

If this is OK, then perhaps the sum(f, x) rule from #441 should also consider calling f more times. There's a commit here doing that, which cuts the memory use by quite a bit. Perhaps there are functions f for which calling twice would be slower? Perhaps writing sum(f, x) vs. sum(f.(x)) is how you emphasise that you care more about memory? (It may make sense to remove this & discuss sum in another thread.) [Now removed here.]
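To illustrate the "call f more times" idea (a hedged sketch, not the #441 code; sum_recompute is a hypothetical name, and ForwardDiff stands in for however the per-element derivative is actually obtained):

using ForwardDiff, ChainRulesCore

function sum_recompute(f, xs::AbstractArray{<:Real})
    y = sum(f, xs)  # forward pass stores nothing
    function sum_pullback(dy)
        # second pass: re-derive f at every element, instead of saving N^2 pullbacks
        dxs = ForwardDiff.derivative.(f, xs) .* dy
        return (NoTangent(), NoTangent(), dxs)
    end
    return y, sum_pullback
end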

julia> @btime gradient(x -> sum(sqrt, x), $(rand(30,30)));
  4.173 μs (16 allocations: 50.02 KiB)  # before
  1.954 μs (2 allocations: 7.20 KiB)    # after

julia> @btime gradient(x -> sum(sum(sqrt, x, dims=1)), $(rand(30,30)));
  10.625 μs (42 allocations: 51.47 KiB)  # before
  2.704 μs (18 allocations: 8.20 KiB)    # after

# Compare broadcasting:

julia> @btime gradient(x -> sum(sqrt.(x)), $(rand(30,30)));
  2.616 μs (10 allocations: 28.70 KiB)

julia> @btime gradient(x -> sum(sum(sqrt.(x), dims=1)), $(rand(30,30)));
  3.542 μs (26 allocations: 36.81 KiB)

# Forward only:

julia> @btime sum(sqrt, x) setup=(x=$(rand(30,30)));
  833.333 ns (0 allocations: 0 bytes)
  
julia> @btime sum(sqrt.(x)) setup=(x=$(rand(30,30)));
  873.544 ns (1 allocation: 7.19 KiB)

All WIP, needs more careful testing, etc.

@oxinabox changed the title from "Callback rule for maximum(f, xs)" to "Configured rule for maximum(f, xs)" on Aug 5, 2021.
@mcabbott (Member, Author) commented Aug 5, 2021

First attempt

With a more expensive function:

julia> @btime gradient(x -> sum(maximum(log∘exp, x)), $(rand(30,30)));
  34.791 μs (162 allocations: 11.11 KiB)

julia> @btime gradient(x -> sum(maximum(log∘exp, x, dims=1)), $(rand(30,30)));
  326.292 μs (2615 allocations: 87.55 KiB)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x))), $(rand(30,30)));
  22.333 μs (48 allocations: 36.86 KiB)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x), dims=1)), $(rand(30,30)));
  16.250 μs (13 allocations: 36.72 KiB)

# without AD:

julia> @btime maximum(log∘exp, $(rand(30,30)));
  13.000 μs (0 allocations: 0 bytes)

julia> @btime maximum(log∘exp, $(rand(30,30)), dims=1);
  15.500 μs (4 allocations: 416 bytes)

julia> @btime findmax(log∘exp, $(rand(30,30)));
  15.334 μs (0 allocations: 0 bytes)

The dims=1 case is very slow, because (1) it takes a second complete (N^2) pass to find the indices at which f attains its maximum, since there is no findmax(sqrt, rand(3,3), dims=1), and (2) it needs N calls to rrule_via_ad, which doesn't infer for log∘exp, just as Zygote's generic broadcasting doesn't.

The broadcasted one uses dual numbers, which is much quicker. Note BTW that there is no chunk mode in play here -- it always evaluates f exactly 900 times.
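For reference, the Dual trick per element looks like this (illustrative only):

using ForwardDiff: Dual, value, partials

d = (log ∘ exp)(Dual(0.3, 1.0))  # one evaluation of f, carrying the derivative along
value(d)        # ≈ 0.3, the primal log(exp(0.3))
partials(d, 1)  # ≈ 1.0, since d/dx log(exp(x)) = 1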

I'm not so sure why the complete reduction is slower than broadcasting here, but it's much closer, and uses 3x less memory.

Diffractor, BTW, does not see this rule. It does see #480, but broadcast times are variable:

julia> @btime Diffractor.gradient(x -> maximum(sqrt, x), $(rand(30,30)));
ERROR: TypeError: in typeassert, expected Int64, got a value of type Nothing
...
  [8] (::Diffractor.∂⃖recurse{1})(::typeof(Base._mapreduce), ::typeof(sqrt), ::typeof(max), ::IndexLinear, ::Matrix{Float64})

julia> @btime gradient(x -> maximum(sqrt.(x)), $(rand(30,30)));
  11.417 μs (12 allocations: 64.33 KiB)  # Zygote 8.833 μs (48 allocations: 36.98 KiB)

julia> @btime gradient(x -> maximum((log∘exp).(x)), $(rand(30,30)));
  2.155 ms (17143 allocations: 586.41 KiB)  # Zygote 22.333 μs (48 allocations: 36.86 KiB)

@mcabbott (Member, Author) commented Nov 24, 2021

This has been much simplified. For the case of a complete reduction only, maximum(f, x), this saves the position of the maximum, and calls rrule_via_ad(f, x[i]) once. This saves memory compared to broadcasting, but in the end not much time -- might still not be worth the complication:

julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30)));  # this PR + Zygote + Julia 1.8
  min 8.625 μs, mean 10.906 μs (52 allocations, 8.92 KiB. GC mean 13.94%)

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
  min 10.041 μs, mean 16.087 μs (49 allocations, 36.88 KiB. GC mean 20.75%)

julia> @btime gradient(x -> sum(maximum(log∘exp, x)), $(rand(30,30)));  # with a more expensive function:
  min 20.208 μs, mean 22.335 μs (116 allocations, 10.88 KiB. GC mean 5.22%)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x))), $(rand(30,30)));
  min 19.291 μs, mean 25.757 μs (49 allocations, 36.88 KiB. GC mean 13.03%)

julia> @btime maximum(log∘exp, $(rand(30,30)));
  min 8.958 μs, mean 9.128 μs (0 allocations)

That means it calls f in total N+1 times. If f is stateful, then as far as I know the result of maximum(f, x) is already ill-defined, since no evaluation order is guaranteed. If f closes over something, that will get a gradient contribution from only one entry, which should be fine.

Instead of using rrule_via_ad, this would be a good use case for derivatives_given_output when that's defined.
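That replacement might look something like this (purely illustrative; the return shape of derivatives_given_output was not pinned down at the time, so this assumes it yields an iterable of per-input partials):

# hypothetical: reuse the saved output y = f(xs[ind]) instead of a second AD pass
dfdx = only(only(derivatives_given_output(y, f, xs[ind])))
dxs[ind] = dfdx * dy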

For cases with dims, it just calls broadcasting. Earlier commits tried to handle this, but it gets complicated, and the saving is less clear. This case is not so easy to test.

On Julia 1.6 and below, the method findmax(f, x) which the fast path needs doesn't exist, so it always calls broadcasting.
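The fallback composes two existing pieces, along these lines (a sketch with a hypothetical name, not the PR's exact code):

using ChainRulesCore

function maximum_via_broadcast(config, f, xs; dims = :)
    fxs, back_bc = rrule_via_ad(config, broadcast, f, xs)  # f.(xs) plus its pullback
    y, back_max = rrule(maximum, fxs; dims = dims)  # assumes ChainRules' rule for plain maximum
    function pullback(dy)
        _, dfxs = back_max(dy)
        _, df, dxs = back_bc(unthunk(dfxs))
        return (NoTangent(), df, dxs)
    end
    return y, pullback
end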

@mcabbott (Member, Author) commented:
Status here is as in (edited) first message above.

Perhaps the broadcast path can be easily tested using JuliaDiff/ChainRulesTestUtils.jl#243 once that's available.

@mzgubic (Member) left a comment:
A few questions, generally looks good. Do you plan to extend the tests?

Several review threads on src/rulesets/Base/mapreduce.jl, all resolved.
Comment on lines 134 to 135
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(; dims = 2), check_inferred=false)
@mcabbott (Member, Author):

I thought these needed JuliaDiff/ChainRulesTestUtils.jl#243: with dims it always calls broadcast.

@mzgubic (Member):
Yep, they do need JuliaDiff/ChainRulesTestUtils.jl#243 (now merged), but also JuliaDiff/FiniteDifferences.jl#203 so that InplaceableThunks get to_vec'd correctly (tested locally).

@mcabbott (Member, Author):
But where do InplaceableThunks come from? This path of the rule doesn't make them.

I do still get an error with only the CRTU update:

julia> test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
test_rrule: maximum on typeof(sqrt),Matrix{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/fCvaU/src/testers.jl:193
  Got exception outside of a @test
  DimensionMismatch("second dimension of A, 4, does not match length of x, 7")
  Stacktrace:
    [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
      @ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
    [2] mul!
      @ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
    [3] mul!
      @ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
    [4] *(tA::Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
      @ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
    [5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/R6uao/src/grad.jl:80
    [6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#45"{ChainRulesTestUtils.var"#call#41"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{typeof(broadcast), typeof(sqrt), Matrix{Float64}}, Tuple{Bool, Bool, Bool}}, ȳ::InplaceableThunk{Thunk{ChainRules.var"#1316#1319"{Matrix{Float64}, Int64, Matrix{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, Matrix{CartesianIndex{2}}}}, ChainRules.var"#1317#1320"{Matrix{Float64}, Int64, Matrix{CartesianIndex{2}}}}, x::Matrix{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/R6uao/src/grad.jl:73
    [7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/fCvaU/src/finite_difference_calls.jl:51
    [8] f_pb
      @ ~/.julia/packages/ChainRulesTestUtils/fCvaU/src/rule_config.jl:40 [inlined]
    [9] (::ChainRules.var"#minormax_f_back2#2098"{ChainRules.var"#maximum_pullback#1326"{ChainRules.var"#findmax_pullback#1318"{Int64
  ...

@mcabbott (Member, Author):
Solved by the to_vec PR, as you said.

Can this thing give less cryptic errors than these "DimensionMismatch" when it goes wrong?

@mzgubic (Member):
Yeah, I agree with you in general: JuliaDiff/ChainRulesTestUtils.jl#244

Here though this is coming from rrule_via_ad using the _make_j′vp_call rather than the usual place 😂

Solving JuliaDiff/ChainRulesTestUtils.jl#213 would be a big QoL improvement indeed. It's on my list

@mzgubic (Member):
JuliaDiff/FiniteDifferences.jl#203 is now merged, so I think we can update the tests

@mcabbott (Member, Author):
Great!

This one is weird locally, but on 1.6 it seems to work (or will once changed to ≈ [10 0 0; 0 -20 0]):

julia> y2, bk2 = rrule(CFG, minimum, abs, [1 2 3; -5 -4 -4], dims = 2);

julia> @test y2 == hcat([1, 4])
Test Passed
  Expression: y2 == hcat([1, 4])
   Evaluated: [1; 4;;] == [1; 4;;]

julia> bk2(hcat([10, 20]))
(NoTangent(), NoTangent(), NoTangent())

Commits:

  • save less stuff in sum(f, xs) rule (probably destroyed in the rebase)
  • re-organise
  • change to use BitArray
  • add a few tests
  • Revert "save less stuff in sum(f, xs) rule" (reverts commit c8034da)
  • tidy, add cumsum trick
  • tests for multiple maxima
  • tweaks
  • fixup
  • update, tidy
  • Apply 3 suggestions (Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>)
  • add an error
  • remove error, as closing over `y` breaks inference
  • simplify, update
  • solve Core.Box
  • tests
  • approx