From 5d391d0dbdf5ab5818d5c1e2b95e23ce9e4c1a1b Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Thu, 30 May 2019 01:20:46 -0700 Subject: [PATCH] Make Rules store a second function for updating Now `Rule`s store two things: one callable thing, which is used for evaluating the rule, and one other thing, which can be `nothing` or a function with the signature `u(value, args...)` that evaluates the rule for the given arguments and adds the result to `value`, doing so in place when possible. The advantage of this is that we can more easily define custom ways of accumulating the results of rules, and we can even share intermediate steps between the regular rule evaluation and custom updating. As a test/proof of concept, the `rrule` for `svd` now also has an updating function that handles `NamedTuple`s appropriately. --- src/rules.jl | 43 ++++++++++++++++++++++++++++-- src/rules/linalg/factorization.jl | 11 +++++--- test/rules.jl | 23 ++++++++++++++++ test/rules/linalg/factorization.jl | 16 ++++++++++- 4 files changed, 86 insertions(+), 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 25eb84bc1..c8e36d1ca 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -108,6 +108,32 @@ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) """ store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...))) +# Special purpose updating for operations which can be done in-place. This function is +# just internal and free-form; it is not a method of `accumulate!` directly as it does +# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`. +# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular +# rule. + +_update!(x, y) = x + y +_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y + +_update!(x, ::Zero) = x +_update!(::Zero, y) = y +_update!(::Zero, ::Zero) = Zero() + +function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns + return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns)) +end + +function _update!(x::NamedTuple, y, p::Symbol) + new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),)) + return merge(x, new) +end + +function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns + return _update!(x, getproperty(y, p), p) +end + ##### ##### `Rule` ##### @@ -123,7 +149,7 @@ Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b) Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b) """ - Rule(propation_function) + Rule(propation_function[, updating_function]) Return a `Rule` that wraps the given `propation_function`. It is assumed that `propation_function` is a callable object whose arguments are differential @@ -131,6 +157,10 @@ values, and whose output is a single differential value calculated by applying internally stored/computed partial derivatives to the input differential values. +If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)` +and to store the result of the propagation function applied to the arguments `xs` into +`Δ` in-place, returning `Δ`. + For example: ``` @@ -141,12 +171,21 @@ rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * Δ See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref) """ -struct Rule{F} <: AbstractRule +struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule f::F + u::U end +# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some +# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}` +Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing) + (rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...) +# Specialized accumulation +# TODO: Does this need to be overdubbed in the rule context? +accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) + ##### ##### `DNERule` ##### diff --git a/src/rules/linalg/factorization.jl b/src/rules/linalg/factorization.jl index 92b816da8..46bb388ef 100644 --- a/src/rules/linalg/factorization.jl +++ b/src/rules/linalg/factorization.jl @@ -12,14 +12,17 @@ end function rrule(::typeof(getproperty), F::SVD, x::Symbol) if x === :U - return F.U, (Rule(Ȳ->(U=Ȳ, S=zero(F.S), V=zero(F.V))), DNERule()) + rule = Ȳ->(U=Ȳ, S=zero(F.S), V=zero(F.V)) elseif x === :S - return F.S, (Rule(Ȳ->(U=zero(F.U), S=Ȳ, V=zero(F.V))), DNERule()) + rule = Ȳ->(U=zero(F.U), S=Ȳ, V=zero(F.V)) elseif x === :V - return F.V, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ)), DNERule()) + rule = Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ) elseif x === :Vt - return F.Vt, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ')), DNERule()) + # TODO: This could be made to work, but it'd be a pain + throw(ArgumentError("Vt is unsupported; use V and transpose the result")) end + update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x) + return getproperty(F, x), (Rule(rule, update), DNERule()) end function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix) diff --git a/test/rules.jl b/test/rules.jl index e7dcd0b03..9f52550a3 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -21,4 +21,27 @@ cool(x) = x + 1 end @test i == 1 # rules only iterate once, yielding themselves end + @testset "helper functions" begin + # Hits fallback, since we can't update `Diagonal`s in place + X = Diagonal([1, 1]) + Y = copy(X) + @test ChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5] + @test X == Y # no change to X + + X = [1 2; 3 4] + Y = copy(X) + @test ChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5] + @test X != Y # X has been updated + + # Reusing above X + @test ChainRules._update!(X, Zero()) === X + @test ChainRules._update!(Zero(), X) === X + @test ChainRules._update!(Zero(), Zero()) === Zero() + + X = (A=[1 0; 0 1], B=[2 2; 2 2]) + Y = deepcopy(X) + @test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4]) + @test X.A != Y.A + @test X.B != Y.B + end end diff --git a/test/rules/linalg/factorization.jl b/test/rules/linalg/factorization.jl index 56acd0962..782144bd8 100644 --- a/test/rules/linalg/factorization.jl +++ b/test/rules/linalg/factorization.jl @@ -4,7 +4,7 @@ for n in [4, 6, 10], m in [3, 5, 10] X = randn(rng, n, m) F, dX = rrule(svd, X) - for p in [:U, :S, :V, :Vt] + for p in [:U, :S, :V] Y, (dF, dp) = rrule(getproperty, F, p) @test dp isa ChainRules.DNERule Ȳ = randn(rng, size(Y)...) @@ -12,6 +12,20 @@ X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X) @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 end + @test_throws ArgumentError rrule(getproperty, F, :Vt) + end + @testset "accumulate!" begin + X = [1.0 2.0; 3.0 4.0; 5.0 6.0] + F, dX = rrule(svd, X) + X̄ = (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) + for p in [:U, :S, :V] + Y, (dF, _) = rrule(getproperty, F, p) + Ȳ = ones(size(Y)...) + ChainRules.accumulate!(X̄, dF, Ȳ) + end + @test X̄.U ≈ ones(3, 2) atol=1e-6 + @test X̄.S ≈ ones(2) atol=1e-6 + @test X̄.V ≈ ones(2, 2) atol=1e-6 end @testset "Helper functions" begin X = randn(rng, 10, 10)