Skip to content

Commit

Permalink
Merge 4ed0163 into b1daa7a
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 19, 2022
2 parents b1daa7a + 4ed0163 commit 874fa06
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 20 deletions.
3 changes: 3 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ using Statistics
# to the normal rule of only overload via `ChainRulesCore.rrule`.
import ChainRulesCore: rrule, frule

# Experimental:
import ChainRulesCore: derivatives_given_output

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

Expand Down
2 changes: 2 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ function rrule(::typeof(identity), x)
return (x, identity_pullback)
end

derivatives_given_output(Ω, ::typeof(identity), x) = tuple(tuple(true))

# rouding related,
# we use `zero` rather than `ZeroTangent()` for scalar, and avoids issues with map etc
@scalar_rule round(x) zero(x)
Expand Down
5 changes: 5 additions & 0 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ let
return Ω, abs_pullback
end

function derivatives_given_output(Ω, ::typeof(abs), x::Union{Real, Complex})
signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω)
return tuple(tuple(signx))
end

## abs2
function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex})
return abs2(z), 2 * realdot(z, Δz)
Expand Down
71 changes: 52 additions & 19 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,59 @@ function rrule(
end

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=:
)
fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
y = sum(first, fx_and_pullbacks; dims=dims)
config::RuleConfig{>:HasReverseMode},
::typeof(sum),
f::F,
xs::AbstractArray{T};
dims = :,
) where {F,T}
project = ProjectTo(xs)

pullbacks = last.(fx_and_pullbacks)
if _uses_input_only(f, T)
# Then we can compute the forward pass as usual, save nothing but `xs`:
function sum_pullback_f1(dy)
dxs = broadcast(unthunk(dy), xs) do dyₖ, xᵢ
∂yₖ∂xᵢ = only(only(derivatives_given_output(nothing, f, xᵢ)))
dyₖ * conj(∂yₖ∂xᵢ)
end
return (NoTangent(), NoTangent(), project(dxs))
end
return sum(f, xs; dims), sum_pullback_f1
end

project = ProjectTo(xs)
# In the general case, we need to save all the pullbacks:
fx_and_pullbacks = map(xᵢ -> rrule_via_ad(config, f, xᵢ), xs)
y = sum(first, fx_and_pullbacks; dims)

function sum_pullback_f2(dy)
# For arrays of arrays, we ought to protect the element against broadcasting:
broadcast_dy = dims isa Colon ? Ref(unthunk(dy)) : unthunk(dy)
if Base.issingletontype(F)
# Then at least `f` has no gradient. Note that broadcasting here
# gets the shape right with or without `dims` keyword.
dxs = broadcast(fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
unthunk(last(pbᵢ(dyₖ)))
end
return (NoTangent(), NoTangent(), project(dxs))

function sum_pullback(ȳ)
call(f, x) = f(x)
# if dims is :, then need only left-handed only broadcast
broadcast_ȳ = dims isa Colon ? (ȳ,) :
f̄_and_x̄s = call.(pullbacks, broadcast_ȳ)
# no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
= if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
NoTangent()
else
sum(first, f̄_and_x̄s)
# Most general case. If `f` were stateful, we would need to reverse the order
# of iteration here, but since this function makes no guarantee, even the primal
# result is then ill-defined.
df_and_dxs = broadcast(fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
map(unthunk, pbᵢ(dyₖ))
end
return (NoTangent(), sum(first, df_and_dxs), project(map(last, df_and_dxs)))
end
x̄s = map(unthunk last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
return NoTangent(), f̄, project(x̄s)
end
return y, sum_pullback
return y, sum_pullback_f2
end

function _uses_input_only(f::F, ::Type{xT}) where {F,xT}
gT = Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, F, xT})
# Here we must check `<: Number`, to avoid this, the one rule which can return the `nothing`:
# ChainRules.derivatives_given_output("anything", exp, 1) == (("anything",),)
return isconcretetype(gT) && gT <: Tuple{Tuple{Number}}
end

# https://github.com/JuliaDiff/ChainRules.jl/issues/522
Expand Down Expand Up @@ -228,6 +257,7 @@ function ∇prod_dims(vald::Val{dims}, x, dy, y=prod(x; dims=dims)) where {dims}
∇prod_dims!(dx, vald, x, dy, y)
return dx
end
∇prod_dims(::Val{dims}, x, dy::AbstractZero, y=prod(x; dims=dims)) where {dims} = dy

function ∇prod_dims!(dx, ::Val{dims}, x, dy, y) where {dims}
iters = ntuple(d -> d in dims ? tuple(:) : axes(x,d), ndims(x)) # Without Val(dims) this is a serious type instability
Expand All @@ -244,6 +274,7 @@ function ∇prod(x, dy::Number=1, y::Number=prod(x))
∇prod!(dx, x, dy, y)
return dx
end
∇prod(x, dy::AbstractZero, y::Number=prod(x)) = dy

function ∇prod!(dx, x, dy::Number=1, y::Number=prod(x))
numzero = iszero(y) ? count(iszero, x) : 0
Expand Down Expand Up @@ -326,7 +357,8 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
dx = fill!(similar(x, T, axes(x)), zero(T))
∇cumprod_dim!(dx, vald, x, dy, y)
return dx
end
end
∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy::AbstractZero, y=cumprod(x; dims=dim)) where {dim} = dy

@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim}
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
Expand All @@ -342,6 +374,7 @@ function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
∇cumprod!(dx, x, dy, y)
return dx
end
∇cumprod(x::AbstractVector, dy::AbstractZero, y=cumprod(x)) = dy

@inline function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y)
lo, hi = firstindex(x), lastindex(x)
Expand Down
18 changes: 17 additions & 1 deletion test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
@testset "sum(f, xs)" begin
# This calls back into AD
test_rrule(sum, abs, [-4.0, 2.0, 2.0])
test_rrule(sum, log, rand(3, 4) .+ 1)
test_rrule(sum, cbrt, randn(5))
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])

# Complex numbers
test_rrule(sum, log, rand(ComplexF64, 5))
test_rrule(sum, sqrt, rand(ComplexF64, 5))
test_rrule(sum, abs, rand(ComplexF64, 3, 4)) # complex -> real

Expand All @@ -82,6 +84,12 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()

test_rrule(sum, abs, @SVector[1.0, -3.0])

# Make sure the above test both `derivatives_given_output` path and general case:
@test ChainRules._uses_input_only(abs, Float32)
@test !ChainRules._uses_input_only(cbrt, Float64)
@test ChainRules._uses_input_only(log, ComplexF64)
@test !ChainRules._uses_input_only(abs, ComplexF64)

# covectors
x = [-4.0 2.0; 2.0 -1.0]
test_rrule(sum, inv, x[1, :]')
Expand All @@ -102,14 +110,22 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
# ... and Bool produced by function
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")


# Functions that return a Vector
# see https://github.com/FluxML/Zygote.jl/issues/1074
test_rrule(sum, make_two_vec, [1.0, 3.0, 5.0, 7.0])
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0])
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=2))
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=1))
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=(3, 4)))

# arrays of arrays, functions which return a scalar:
test_rrule(sum, sum, [[1,2], [3,4], [5,6]]; check_inferred=false)
x2345 = [rand(2,3) for _ in 1:4, _ in 1:5]
test_rrule(sum, prod, x2345; check_inferred=false)
test_rrule(sum, sum, x2345; fkwargs=(;dims=1), check_inferred=false)
test_rrule(sum, sum, x2345; fkwargs=(;dims=(1,2)), check_inferred=false)

test_rrule(sum, cumprod, [[1,2], [3,4], [5,6]]; check_inferred=false)
end

# https://github.com/JuliaDiff/ChainRules.jl/issues/522
Expand Down

0 comments on commit 874fa06

Please sign in to comment.