Skip to content

Commit

Permalink
Merge 1e0ee69 into b1daa7a
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 19, 2022
2 parents b1daa7a + 1e0ee69 commit 49ef7e2
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
47 changes: 47 additions & 0 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,53 @@ function rrule(::typeof(cumsum), x::AbstractArray; dims::Integer)
end
rrule(::typeof(cumsum), x::AbstractVector) = rrule(cumsum, x; dims=1)

#####
##### `maximum`, `minimum`
#####

for mimum in (:minimum, :maximum)
findm = Symbol(:find, string(mimum)[1:3])

@eval function rrule(
config::RuleConfig{>:HasReverseMode},
::typeof($mimum),
f::F,
xs::AbstractArray{<:Number};
dims = :,
) where {F}
project = ProjectTo(xs)
if dims isa Colon && VERSION >= v"1.7"
# The easy case is when we can use `findmax` to get index, and write into it:
y, ind = $findm(f, xs)
function minormax_f_back1(dy)
# Notice this evaluates `f` one more time, but this shouldn't matter unless `f` is
# sateful, in which case both this and `maximum(f.(xs))` give uncertain results.
_, one_back = rrule_via_ad(config, f, xs[ind])
df, one_dx = one_back(unthunk(dy))
x_thunk = @thunk project(_zerolike_writeat(xs, unthunk(one_dx), dims, ind))
x_ithunk = InplaceableThunk(x_thunk) do dxs
view(dxs, ind) .+= unthunk(one_dx) # TODO make _zerolike_writeat handle thunks
dxs
end
return (NoTangent(), df, x_ithunk)
end
return y, minormax_f_back1

else
# Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
fxs, cast_back = rrule_via_ad(config, broadcast, f, xs)
y, mm_back = rrule($mimum, fxs; dims)
function minormax_f_back2(dy)
_, dmid = mm_back(dy)
_, df, dxs = cast_back(dmid)
return (NoTangent(), df, project(dxs))
end
return y, minormax_f_back2
end

end # @eval function rrule(...)
end

#####
##### `prod`
#####
Expand Down
21 changes: 20 additions & 1 deletion test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
# This calls back into AD
test_rrule(sum, abs, [-4.0, 2.0, 2.0])
test_rrule(sum, cbrt, randn(5))
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0]) # Multiplier defined in test_helpers.jl

# Complex numbers
test_rrule(sum, sqrt, rand(ComplexF64, 5))
Expand Down Expand Up @@ -120,6 +120,25 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
@test rrule(SumRuleConfig(), Base.sum, xs, weights) isa Nothing
end

@testset "maximum(f, xs)" begin
test_rrule(maximum, abs, [-4.0, 2.0, 2.0])
test_rrule(minimum, sqrt, Float64[1 2; 3 4])
test_rrule(maximum, Multiplier(2.0), [2.0, 4.0, 8.0]) # Multiplier defined in test_helpers.jl

# repeated -- can't use FiniteDifferences
y1, bk1 = rrule(CFG, maximum, abs, [-4.0, 2.0, 4.0, 2.0])
@test y1 === 4.0
@test unthunk(bk1(10.0)[3]) == [-10, 0, 0, 0]

# dims keyword -- these need to call `rrule_via_ad(broadcast, ...`, which needs AD
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(; dims = 2), check_inferred=false)

@test_skip y2, bk2 = rrule(CFG, minimum, abs, [1 2 3; -5 -4 -4], dims = 2)
# @test y2 == hcat([1, 4])
# @test unthunk(bk2(hcat([10, 20]))[3]) == [10 0 0; 0 -20 0]
end

@testset "prod" begin
@testset "Array{$T}" for T in [Float64, ComplexF64]
@testset "size = $sz, dims = $dims" for (sz, dims) in [
Expand Down

0 comments on commit 49ef7e2

Please sign in to comment.