Skip to content

Commit

Permalink
Merge 5d391d0 into a2e6451
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jun 6, 2019
2 parents a2e6451 + 5d391d0 commit 1b4b6ad
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 7 deletions.
43 changes: 41 additions & 2 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@ See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
"""
store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...)))

# Special purpose updating for operations which can be done in-place. This function is
# just internal and free-form; it is not a method of `accumulate!` directly as it does
# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`.
# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular
# rule.

_update!(x, y) = x + y
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y

_update!(x, ::Zero) = x
_update!(::Zero, y) = y
_update!(::Zero, ::Zero) = Zero()

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns
return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns))
end

function _update!(x::NamedTuple, y, p::Symbol)
new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),))
return merge(x, new)
end

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns
return _update!(x, getproperty(y, p), p)
end

#####
##### `Rule`
#####
Expand All @@ -123,14 +149,18 @@ Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b)
Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b)

"""
Rule(propation_function)
Rule(propation_function[, updating_function])
Return a `Rule` that wraps the given `propation_function`. It is assumed that
`propation_function` is a callable object whose arguments are differential
values, and whose output is a single differential value calculated by applying
internally stored/computed partial derivatives to the input differential
values.
If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)`
and to store the result of the propagation function applied to the arguments `xs` into
`Δ` in-place, returning `Δ`.
For example:
```
Expand All @@ -141,12 +171,21 @@ rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * Δ
See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref)
"""
struct Rule{F} <: AbstractRule
struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule
f::F
u::U
end

# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some
# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}`
Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing)

(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...)

# Specialized accumulation
# TODO: Does this need to be overdubbed in the rule context?
accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...)

#####
##### `DNERule`
#####
Expand Down
11 changes: 7 additions & 4 deletions src/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ end

function rrule(::typeof(getproperty), F::SVD, x::Symbol)
if x === :U
return F.U, (Rule(->(U=Ȳ, S=zero(F.S), V=zero(F.V))), DNERule())
rule = ->(U=Ȳ, S=zero(F.S), V=zero(F.V))
elseif x === :S
return F.S, (Rule(->(U=zero(F.U), S=Ȳ, V=zero(F.V))), DNERule())
rule = ->(U=zero(F.U), S=Ȳ, V=zero(F.V))
elseif x === :V
return F.V, (Rule(->(U=zero(F.U), S=zero(F.S), V=)), DNERule())
rule = ->(U=zero(F.U), S=zero(F.S), V=Ȳ)
elseif x === :Vt
return F.Vt, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=')), DNERule())
# TODO: This could be made to work, but it'd be a pain
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
end
update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x)
return getproperty(F, x), (Rule(rule, update), DNERule())
end

function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)
Expand Down
23 changes: 23 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,27 @@ cool(x) = x + 1
end
@test i == 1 # rules only iterate once, yielding themselves
end
@testset "helper functions" begin
# Hits fallback, since we can't update `Diagonal`s in place
X = Diagonal([1, 1])
Y = copy(X)
@test ChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5]
@test X == Y # no change to X

X = [1 2; 3 4]
Y = copy(X)
@test ChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5]
@test X != Y # X has been updated

# Reusing above X
@test ChainRules._update!(X, Zero()) === X
@test ChainRules._update!(Zero(), X) === X
@test ChainRules._update!(Zero(), Zero()) === Zero()

X = (A=[1 0; 0 1], B=[2 2; 2 2])
Y = deepcopy(X)
@test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4])
@test X.A != Y.A
@test X.B != Y.B
end
end
16 changes: 15 additions & 1 deletion test/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,28 @@
for n in [4, 6, 10], m in [3, 5, 10]
X = randn(rng, n, m)
F, dX = rrule(svd, X)
for p in [:U, :S, :V, :Vt]
for p in [:U, :S, :V]
Y, (dF, dp) = rrule(getproperty, F, p)
@test dp isa ChainRules.DNERule
= randn(rng, size(Y)...)
X̄_ad = dX(dF(Ȳ))
X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)
@test X̄_ad X̄_fd rtol=1e-6 atol=1e-6
end
@test_throws ArgumentError rrule(getproperty, F, :Vt)
end
@testset "accumulate!" begin
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
F, dX = rrule(svd, X)
= (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
for p in [:U, :S, :V]
Y, (dF, _) = rrule(getproperty, F, p)
= ones(size(Y)...)
ChainRules.accumulate!(X̄, dF, Ȳ)
end
@test.U ones(3, 2) atol=1e-6
@test.S ones(2) atol=1e-6
@test.V ones(2, 2) atol=1e-6
end
@testset "Helper functions" begin
X = randn(rng, 10, 10)
Expand Down

0 comments on commit 1b4b6ad

Please sign in to comment.