Skip to content

Commit

Permalink
Merge 9bb6649 into f17f612
Browse files Browse the repository at this point in the history
  • Loading branch information
nickrobinson251 committed Jan 22, 2020
2 parents f17f612 + 9bb6649 commit 5a20d99
Show file tree
Hide file tree
Showing 13 changed files with 43 additions and 157 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.3.1"
version = "0.3.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.5.1"
ChainRulesCore = "0.6"
FiniteDifferences = "^0.7"
Reexport = "0.2"
Requires = "0.5.2, 1"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ The most important `AbstractDifferential`s when getting started are the ones abo
### Other `AbstractDifferential`s:
- `Composite{P}`: this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type.
- `DoesNotExist`: Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`.
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.
- `InplaceableThunk`: it is like a `Thunk` but it can do in-place `add!`.

-------------------------------

Expand Down
5 changes: 1 addition & 4 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ using Reexport
# Basically everything this package does is overloading these, so we make an exception
# to the normal rule of only overload via `ChainRulesCore.rrule`.
import ChainRulesCore: rrule, frule

# Deal with name clashes, by defining in this module which one we mean.
const accumulate = ChainRulesCore.accumulate
const accumulate! = ChainRulesCore.accumulate!
using ChainRulesCore: AbstractZero

using LinearAlgebra
using LinearAlgebra.BLAS
Expand Down
19 changes: 0 additions & 19 deletions src/helper_functions.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,3 @@
# Internal helpers for defining the `add!` field of an `InplaceableThunk`

_update!(x, y) = x + y
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y

_update!(x, ::Zero) = x
_update!(::Zero, y) = y
_update!(::Zero, ::Zero) = Zero()


function _update!(x::NamedTuple, y, p::Symbol)
y = extern(y)
yp = getproperty(y, p)
xp = getproperty(x, p)
new_xp = _update!(xp, yp)
new = NamedTuple{(p,)}((new_xp,))
return merge(x, new)
end

"""
_checked_rrule
Expand Down
28 changes: 14 additions & 14 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,33 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!

function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback(Ȳ::NamedTuple{(:U,:S,:V)})
function svd_pullback(Ȳ::Composite{<:SVD})
∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V))
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
end

function rrule(::typeof(getproperty), F::SVD, x::Symbol)
function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
function getproperty_svd_pullback(Ȳ)
if x === :U
∂ = @thunk((; U=Ȳ, S=(zero(F.S)), V=(zero(F.V))))
C = Composite{T}
∂F = if x === :U
C(U=Ȳ,)
elseif x === :S
∂ = @thunk((; U=(zero(F.U)), S=Ȳ, V=(zero(F.V))))
C(S=Ȳ,)
elseif x === :V
∂ = @thunk((; U=(zero(F.U)), S=(zero(F.S)), V=Ȳ))
C(V=Ȳ,)
elseif x === :Vt
# TODO: This could be made to work, but it'd be a pain
# TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
end

update = (X̄::NamedTuple{(:U,:S,:V)}) -> _update!(X̄, ∂, x)
∂F = InplaceableThunk(∂, update)
return NO_FIELDS, ∂F, DoesNotExist()
end
return getproperty(F, x), getproperty_svd_pullback
end

function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)
# When not `Zero`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix`
function svd_rev(USV::SVD, Ū, s̄, V̄)
# Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
U = USV.U
s = USV.S
Expand All @@ -56,11 +55,12 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra
ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ

S = Diagonal(s)
S̄ = Diagonal(s̄)
S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄)

# TODO: consider using MuladdMacro here
Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt
_add!(Ā, U * S̄ * Vt)
_add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ))
Ā = _add!(Ā, U * S̄ * Vt)
Ā = _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ))

return Ā
end
Expand Down
14 changes: 9 additions & 5 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Some utility functions for optimizing linear algebra operations that aren't specific
# to any particular rule definition

# F .* (X - X'), overwrites X
# F .* (X - X'), overwrites X if possible
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
Expand All @@ -11,22 +11,26 @@ function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
return X
end
_mulsubtrans!(X::AbstractZero, F::AbstractZero) = X
_mulsubtrans!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X
_mulsubtrans!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F

# I - X, overwrites X
function _eyesubx!(X::AbstractMatrix)
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
return X
end

# X + Y, overwrites X
# X + Y, overwrites X if possible
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
return X
end
_add!(X, Y) = X + Y # handles all `AbstractZero` overloads
26 changes: 0 additions & 26 deletions test/helper_functions.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,4 @@
@testset "helper functions" begin
@testset "_update! Array" begin
# Hits fallback, since we can't update `Diagonal`s in place
X = Diagonal([1, 1])
Y = copy(X)
@test ChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5]
@test X == Y # no change to X

X = [1 2; 3 4]
Y = copy(X)
@test ChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5]
@test X != Y # X has been updated
end
@testset "_update! Zero" begin
X = [1 2; 3 4]
@test ChainRules._update!(X, Zero()) === X
@test ChainRules._update!(Zero(), X) === X
@test ChainRules._update!(Zero(), Zero()) === Zero()
end
@testset "_update! NamedTuple" begin
X = (A=[1 0; 0 1], B=[2 2; 2 2])
old_X = deepcopy(X)
Y = deepcopy(X)
@test ChainRules._update!(X, Y, :A) == (A=[2 0; 0 2], B=[2 2; 2 2])
@test X.A != old_X.A
@test X.B == old_X.B
end
@testset "_checked_rrule" begin
try
@eval cool(x,y) = x + y
Expand Down
7 changes: 2 additions & 5 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,8 @@

@test ds === NO_FIELDS

@test extern(dx) == extern(accumulate(zeros(3, 2), dx))
@test extern(dy) == extern(accumulate(zeros(2, 5), dy))

test_accumulation(rand(3, 2), dx)
test_accumulation(rand(2, 5), dy)
@test extern(dx) == extern(zeros(3, 2) .+ dx)
@test extern(dy) == extern(zeros(2, 5) .+ dy)
end

@testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^)
Expand Down
7 changes: 2 additions & 5 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@

x̄, ȳ = rand(), rand()
∂x = pullback(ȳ)[3]
@test isequal(
extern(ChainRules.accumulate(x̄, ∂x)),
x̄ .+ ȳ .* cos.(x)
)
@test isequal(extern(x̄ .+ ∂x), x̄ .+ ȳ .* cos.(x))

x̄, ȳ = Zero(), rand(3, 3)
∂x = pullback(ȳ)[3]
@test extern(extern(accumulate(x̄, ∂x))) == ȳ .* cos.(x)
@test extern(extern(x̄ .+ ∂x)) == ȳ .* cos.(x)
end
@testset "frule" begin
x = rand(3, 3)
Expand Down
5 changes: 0 additions & 5 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
function generate_well_conditioned_matrix(rng, N)
A = randn(rng, N, N)
return A * A' + I
end

@testset "linalg" begin
@testset "dot" begin
@testset "Vector" begin
Expand Down
14 changes: 7 additions & 7 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@test dself1 === NO_FIELDS
@test dp === DoesNotExist()

ΔF = extern(dF)
ΔF = unthunk(dF)
dself2, dX = dX_pullback(ΔF)
@test dself2 === NO_FIELDS
X̄_ad = extern(dX)
X̄_ad = unthunk(dX)
X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)
@test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6
@test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6))
end
@testset "Vt" begin
Y, dF_pullback = rrule(getproperty, F, :Vt)
Expand All @@ -28,17 +28,17 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
end
end

@testset "accumulate!" begin
@testset "+" begin
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
F, dX_pullback = rrule(svd, X)
X̄ = (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
X̄ = Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
for p in [:U, :S, :V]
Y, dF_pullback = rrule(getproperty, F, p)
Ȳ = ones(size(Y)...)
(dself, dF, dp) = dF_pullback(Ȳ)
dself, dF, dp = dF_pullback(Ȳ)
@test dself === NO_FIELDS
@test dp === DoesNotExist()
ChainRules.accumulate!(X̄, dF)
X̄ += dF
end
@test X̄.U ≈ ones(3, 2) atol=1e-6
@test X̄.S ≈ ones(2) atol=1e-6
Expand Down
4 changes: 0 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ using Random
using Statistics
using Test

# For testing purposes we use a lot of
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Zero, One, DoesNotExist, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.

include("test_util.jl")
Expand Down
65 changes: 5 additions & 60 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ using ChainRulesCore: AbstractDifferential

const _fdm = central_fdm(5, 1)

# Useful for LinearAlgebra tests
function generate_well_conditioned_matrix(rng, N)
A = randn(rng, N, N)
return A * A' + I
end

"""
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
Expand Down Expand Up @@ -115,10 +120,6 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
# Correctness testing via finite differencing.
x̄_fd = j′vp(fdm, f, ȳ, x)
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)

# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
test_accumulation(x̄, x̄_ad)
test_accumulation(Zero(), x̄_ad)
end

function _make_fdm_call(fdm, f, ȳ, xs, ignores)
Expand Down Expand Up @@ -172,13 +173,6 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end
end

# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
for (x̄, x̄_ad) in zip(x̄s, x̄s_ad)
x̄ === nothing && continue
test_accumulation(x̄, x̄_ad)
test_accumulation(Zero(), x̄_ad)
end
end

function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...)
Expand All @@ -188,52 +182,3 @@ end
function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...)
return isapprox(extern(d_ad), d_fd; kwargs...)
end

function test_accumulation(x̄, ∂x)
@test all(extern(x̄ + ∂x) .≈ extern(x̄) .+ extern(∂x))
test_accumulate(x̄, ∂x)
test_accumulate!(x̄, ∂x)
test_store!(x̄, ∂x)
end

function test_accumulate(x̄::Zero, ∂x)
@test extern(accumulate(x̄, ∂x)) ≈ extern(∂x)
end

function test_accumulate(x̄::Number, ∂x)
@test extern(accumulate(x̄, ∂x)) ≈ extern(x̄) + extern(∂x)
end

function test_accumulate(x̄::AbstractArray, ∂x)
x̄_old = copy(x̄)
@test all(extern(accumulate(x̄, ∂x)) .≈ (extern(x̄) .+ extern(∂x)))
@test x̄ == x̄_old # make sure didn't mutate x̄
end

test_accumulate!(x̄::Zero, ∂x) = nothing

function test_accumulate!(x̄::Number, ∂x)
# This case won't have been inplace as `Number` is immutable
@test accumulate!(x̄, ∂x) ≈ accumulate(x̄, ∂x)
end

function test_accumulate!(x̄::AbstractArray, ∂x)
x̄_copy = copy(x̄)

accumulate!(x̄_copy, ∂x) # this should have actually been in-place
@test extern(x̄_copy) ≈ (extern(x̄) .+ extern(∂x))
end

test_store!(x̄::Zero, ∂x) = nothing
test_store!(x̄::Number, ∂x) = nothing

function test_store!(x̄::AbstractArray, ∂x)
x̄_store = copy(x̄)
store!(x̄_store, ∂x)
@test x̄_store ≈ extern(∂x)

# store! is the same as `accumulate!` to a zero array
x̄_acc = false.*x̄
accumulate!(x̄_acc, ∂x)
@test x̄_acc ≈ x̄_store
end

0 comments on commit 5a20d99

Please sign in to comment.