Skip to content

Commit

Permalink
Make Rules store a second function for updating
Browse files Browse the repository at this point in the history
Now `Rule`s store two things: one callable thing, which is used for
evaluating the rule, and one other thing, which can be `nothing` or a
function with the signature `u(value, args...)` that evaluates the rule
for the given arguments and adds the result to `value`, doing so in
place when possible.

The advantage of this is that we can more easily define custom ways of
accumulating the results of rules, and we can even share intermediate
steps between the regular rule evaluation and custom updating. As a
test/proof of concept, the `rrule` for `svd` now also has an updating
function that handles `NamedTuple`s appropriately.
  • Loading branch information
ararslan committed May 30, 2019
1 parent a2e6451 commit 0056ab9
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 7 deletions.
39 changes: 37 additions & 2 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,28 @@ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
"""
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))

# Special purpose updating for operations which can be done in-place

_update!(x, y) = x + y
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y

_update!(x, ::Zero) = x
_update!(::Zero, y) = y
_update!(::Zero, ::Zero) = Zero()

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns
return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns))
end

function _update!(x::NamedTuple, y, p::Symbol)
new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),))
return merge(x, new)
end

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns
return _update!(x, getproperty(y, p), p)
end

#####
##### `Rule`
#####
Expand All @@ -123,14 +145,18 @@ Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b)
Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b)

"""
Rule(propation_function)
Rule(propation_function[, updating_function])
Return a `Rule` that wraps the given `propation_function`. It is assumed that
`propation_function` is a callable object whose arguments are differential
values, and whose output is a single differential value calculated by applying
internally stored/computed partial derivatives to the input differential
values.
If an updating function is provided, it is assumed to have the signature `u(y, xs...)`
and to store the result of the propagation function applied to the arguments `xs` into
`y` in-place, returning `y`.
For example:
```
Expand All @@ -141,12 +167,21 @@ rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * Δ
See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref)
"""
struct Rule{F} <: AbstractRule
struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule
f::F
u::U
end

# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some
# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}`
Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)

(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)

# Specialized accumulation
# TODO: Does this need to be overdubbed in the rule context?
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)

#####
##### `DNERule`
#####
Expand Down
11 changes: 7 additions & 4 deletions src/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ end

function rrule(::typeof(getproperty), F::SVD, x::Symbol)
if x === :U
return F.U, (Rule(->(U=Ȳ, S=zero(F.S), V=zero(F.V))), DNERule())
rule = ->(U=Ȳ, S=zero(F.S), V=zero(F.V))
elseif x === :S
return F.S, (Rule(->(U=zero(F.U), S=Ȳ, V=zero(F.V))), DNERule())
rule = ->(U=zero(F.U), S=Ȳ, V=zero(F.V))
elseif x === :V
return F.V, (Rule(->(U=zero(F.U), S=zero(F.S), V=)), DNERule())
rule = ->(U=zero(F.U), S=zero(F.S), V=Ȳ)
elseif x === :Vt
return F.Vt, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=')), DNERule())
# TODO: This could be made to work, but it'd be a pain
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
end
update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x)
return getproperty(F, x), (Rule(rule, update), DNERule())
end

function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)
Expand Down
23 changes: 23 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,27 @@ cool(x) = x + 1
end
@test i == 1 # rules only iterate once, yielding themselves
end
@testset "helper functions" begin
# Hits fallback, since we can't update `Diagonal`s in place
X = Diagonal([1, 1])
Y = copy(X)
@test ChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5]
@test X == Y # no change to X

X = [1 2; 3 4]
Y = copy(X)
@test ChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5]
@test X != Y # X has been updated

# Reusing above X
@test ChainRules._update!(X, Zero()) === X
@test ChainRules._update!(Zero(), X) === X
@test ChainRules._update!(Zero(), Zero()) === Zero()

X = (A=[1 0; 0 1], B=[2 2; 2 2])
Y = deepcopy(X)
@test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4])
@test X.A != Y.A
@test X.B != Y.B
end
end
16 changes: 15 additions & 1 deletion test/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,28 @@
for n in [4, 6, 10], m in [3, 5, 10]
X = randn(rng, n, m)
F, dX = rrule(svd, X)
for p in [:U, :S, :V, :Vt]
for p in [:U, :S, :V]
Y, (dF, dp) = rrule(getproperty, F, p)
@test dp isa ChainRules.DNERule
= randn(rng, size(Y)...)
X̄_ad = dX(dF(Ȳ))
X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)
@test X̄_ad X̄_fd rtol=1e-6 atol=1e-6
end
@test_throws ArgumentError rrule(getproperty, F, :Vt)
end
@testset "accumulate!" begin
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
F, dX = rrule(svd, X)
= (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
for p in [:U, :S, :V]
Y, (dF, _) = rrule(getproperty, F, p)
= ones(size(Y)...)
ChainRules.accumulate!(X̄, dF, Ȳ)
end
@test.U ones(3, 2) atol=1e-6
@test.S ones(2) atol=1e-6
@test.V ones(2, 2) atol=1e-6
end
@testset "Helper functions" begin
X = randn(rng, 10, 10)
Expand Down

0 comments on commit 0056ab9

Please sign in to comment.