Skip to content

Commit

Permalink
Merge pull request #157 from JuliaDiff/npr/composite
Browse files Browse the repository at this point in the history
Change SVD to use `Composite`
  • Loading branch information
nickrobinson251 committed Jan 22, 2020
2 parents f17f612 + 9bb6649 commit 3a1422c
Show file tree
Hide file tree
Showing 13 changed files with 43 additions and 157 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.3.1"
version = "0.3.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.5.1"
ChainRulesCore = "0.6"
FiniteDifferences = "^0.7"
Reexport = "0.2"
Requires = "0.5.2, 1"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ The most important `AbstractDifferential`s when getting started are the ones abo
### Other `AbstractDifferential`s:
- `Composite{P}`: this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type.
- `DoesNotExist`: Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`.
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.
- `InplaceableThunk`: it is like a `Thunk` but it can do in-place `add!`.

-------------------------------

Expand Down
5 changes: 1 addition & 4 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ using Reexport
# Basically everything this package does is overloading these, so we make an exception
# to the normal rule of only overload via `ChainRulesCore.rrule`.
import ChainRulesCore: rrule, frule

# Deal with name clashes, by defining in this module which one we mean.
const accumulate = ChainRulesCore.accumulate
const accumulate! = ChainRulesCore.accumulate!
using ChainRulesCore: AbstractZero

using LinearAlgebra
using LinearAlgebra.BLAS
Expand Down
19 changes: 0 additions & 19 deletions src/helper_functions.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,3 @@
# Internal helpers for defining the `add!` field of an `InplaceableThunk`

_update!(x, y) = x + y
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y

_update!(x, ::Zero) = x
_update!(::Zero, y) = y
_update!(::Zero, ::Zero) = Zero()


function _update!(x::NamedTuple, y, p::Symbol)
y = extern(y)
yp = getproperty(y, p)
xp = getproperty(x, p)
new_xp = _update!(xp, yp)
new = NamedTuple{(p,)}((new_xp,))
return merge(x, new)
end

"""
_checked_rrule
Expand Down
28 changes: 14 additions & 14 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,33 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!

function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback(Ȳ::NamedTuple{(:U,:S,:V)})
function svd_pullback(Ȳ::Composite{<:SVD})
∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V))
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
end

function rrule(::typeof(getproperty), F::SVD, x::Symbol)
function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
function getproperty_svd_pullback(Ȳ)
if x === :U
= @thunk((; U=Ȳ, S=(zero(F.S)), V=(zero(F.V))))
C = Composite{T}
∂F = if x === :U
C(U=Ȳ,)
elseif x === :S
= @thunk((; U=(zero(F.U)), S=Ȳ, V=(zero(F.V))))
C(S=Ȳ,)
elseif x === :V
= @thunk((; U=(zero(F.U)), S=(zero(F.S)), V=))
C(V=,)
elseif x === :Vt
# TODO: This could be made to work, but it'd be a pain
# TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
end

update = (X̄::NamedTuple{(:U,:S,:V)}) -> _update!(X̄, ∂, x)
∂F = InplaceableThunk(∂, update)
return NO_FIELDS, ∂F, DoesNotExist()
end
return getproperty(F, x), getproperty_svd_pullback
end

function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)
# When not `Zero`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix`
function svd_rev(USV::SVD, Ū, s̄, V̄)
# Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
U = USV.U
s = USV.S
Expand All @@ -56,11 +55,12 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra
ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ

S = Diagonal(s)
= Diagonal(s̄)
= isa AbstractZero ?: Diagonal(s̄)

# TODO: consider using MuladdMacro here
= _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt
_add!(Ā, U ** Vt)
_add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \') * ImVVᵀ))
= _add!(Ā, U ** Vt)
= _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \') * ImVVᵀ))

return
end
Expand Down
14 changes: 9 additions & 5 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Some utility functions for optimizing linear algebra operations that aren't specific
# to any particular rule definition

# F .* (X - X'), overwrites X
# F .* (X - X'), overwrites X if possible
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
Expand All @@ -11,22 +11,26 @@ function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
return X
end
_mulsubtrans!(X::AbstractZero, F::AbstractZero) = X
_mulsubtrans!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X
_mulsubtrans!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F

# I - X, overwrites X
function _eyesubx!(X::AbstractMatrix)
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
return X
end

# X + Y, overwrites X
# X + Y, overwrites X if possible
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
return X
end
_add!(X, Y) = X + Y # handles all `AbstractZero` overloads
26 changes: 0 additions & 26 deletions test/helper_functions.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,4 @@
@testset "helper functions" begin
@testset "_update! Array" begin
# Hits fallback, since we can't update `Diagonal`s in place
X = Diagonal([1, 1])
Y = copy(X)
@test ChainRules._update!(X, [1 2; 3 4]) == [2 2; 3 5]
@test X == Y # no change to X

X = [1 2; 3 4]
Y = copy(X)
@test ChainRules._update!(X, Diagonal([1, 1])) == [2 2; 3 5]
@test X != Y # X has been updated
end
@testset "_update! Zero" begin
X = [1 2; 3 4]
@test ChainRules._update!(X, Zero()) === X
@test ChainRules._update!(Zero(), X) === X
@test ChainRules._update!(Zero(), Zero()) === Zero()
end
@testset "_update! NamedTuple" begin
X = (A=[1 0; 0 1], B=[2 2; 2 2])
old_X = deepcopy(X)
Y = deepcopy(X)
@test ChainRules._update!(X, Y, :A) == (A=[2 0; 0 2], B=[2 2; 2 2])
@test X.A != old_X.A
@test X.B == old_X.B
end
@testset "_checked_rrule" begin
try
@eval cool(x,y) = x + y
Expand Down
7 changes: 2 additions & 5 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,8 @@

@test ds === NO_FIELDS

@test extern(dx) == extern(accumulate(zeros(3, 2), dx))
@test extern(dy) == extern(accumulate(zeros(2, 5), dy))

test_accumulation(rand(3, 2), dx)
test_accumulation(rand(2, 5), dy)
@test extern(dx) == extern(zeros(3, 2) .+ dx)
@test extern(dy) == extern(zeros(2, 5) .+ dy)
end

@testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^)
Expand Down
7 changes: 2 additions & 5 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@

x̄, ȳ = rand(), rand()
∂x = pullback(ȳ)[3]
@test isequal(
extern(ChainRules.accumulate(x̄, ∂x)),
.+.* cos.(x)
)
@test isequal(extern(x̄ .+ ∂x), x̄ .+.* cos.(x))

x̄, ȳ = Zero(), rand(3, 3)
∂x = pullback(ȳ)[3]
@test extern(extern(accumulate(x̄, ∂x))) ==.* cos.(x)
@test extern(extern(.+ ∂x)) ==.* cos.(x)
end
@testset "frule" begin
x = rand(3, 3)
Expand Down
5 changes: 0 additions & 5 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
function generate_well_conditioned_matrix(rng, N)
A = randn(rng, N, N)
return A * A' + I
end

@testset "linalg" begin
@testset "dot" begin
@testset "Vector" begin
Expand Down
14 changes: 7 additions & 7 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@test dself1 === NO_FIELDS
@test dp === DoesNotExist()

ΔF = extern(dF)
ΔF = unthunk(dF)
dself2, dX = dX_pullback(ΔF)
@test dself2 === NO_FIELDS
X̄_ad = extern(dX)
X̄_ad = unthunk(dX)
X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)
@test X̄_ad X̄_fd rtol=1e-6 atol=1e-6
@test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6))
end
@testset "Vt" begin
Y, dF_pullback = rrule(getproperty, F, :Vt)
Expand All @@ -28,17 +28,17 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
end
end

@testset "accumulate!" begin
@testset "+" begin
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
F, dX_pullback = rrule(svd, X)
= (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
= Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
for p in [:U, :S, :V]
Y, dF_pullback = rrule(getproperty, F, p)
= ones(size(Y)...)
(dself, dF, dp) = dF_pullback(Ȳ)
dself, dF, dp = dF_pullback(Ȳ)
@test dself === NO_FIELDS
@test dp === DoesNotExist()
ChainRules.accumulate!(X̄, dF)
+= dF
end
@test.U ones(3, 2) atol=1e-6
@test.S ones(2) atol=1e-6
Expand Down
4 changes: 0 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ using Random
using Statistics
using Test

# For testing purposes we use a lot of
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Zero, One, DoesNotExist, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.

include("test_util.jl")
Expand Down
65 changes: 5 additions & 60 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ using ChainRulesCore: AbstractDifferential

const _fdm = central_fdm(5, 1)

# Useful for LinearAlgebra tests
function generate_well_conditioned_matrix(rng, N)
A = randn(rng, N, N)
return A * A' + I
end

"""
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
Expand Down Expand Up @@ -115,10 +120,6 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
# Correctness testing via finite differencing.
x̄_fd = j′vp(fdm, f, ȳ, x)
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)

# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
test_accumulation(x̄, x̄_ad)
test_accumulation(Zero(), x̄_ad)
end

function _make_fdm_call(fdm, f, ȳ, xs, ignores)
Expand Down Expand Up @@ -172,13 +173,6 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end
end

# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
for (x̄, x̄_ad) in zip(x̄s, x̄s_ad)
=== nothing && continue
test_accumulation(x̄, x̄_ad)
test_accumulation(Zero(), x̄_ad)
end
end

function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...)
Expand All @@ -188,52 +182,3 @@ end
function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...)
return isapprox(extern(d_ad), d_fd; kwargs...)
end

function test_accumulation(x̄, ∂x)
@test all(extern(x̄ + ∂x) .≈ extern(x̄) .+ extern(∂x))
test_accumulate(x̄, ∂x)
test_accumulate!(x̄, ∂x)
test_store!(x̄, ∂x)
end

function test_accumulate(x̄::Zero, ∂x)
@test extern(accumulate(x̄, ∂x)) extern(∂x)
end

function test_accumulate(x̄::Number, ∂x)
@test extern(accumulate(x̄, ∂x)) extern(x̄) + extern(∂x)
end

function test_accumulate(x̄::AbstractArray, ∂x)
x̄_old = copy(x̄)
@test all(extern(accumulate(x̄, ∂x)) .≈ (extern(x̄) .+ extern(∂x)))
@test== x̄_old # make sure didn't mutate x̄
end

test_accumulate!(x̄::Zero, ∂x) = nothing

function test_accumulate!(x̄::Number, ∂x)
# This case won't have been inplace as `Number` is immutable
@test accumulate!(x̄, ∂x) accumulate(x̄, ∂x)
end

function test_accumulate!(x̄::AbstractArray, ∂x)
x̄_copy = copy(x̄)

accumulate!(x̄_copy, ∂x) # this should have actually been in-place
@test extern(x̄_copy) (extern(x̄) .+ extern(∂x))
end

test_store!(x̄::Zero, ∂x) = nothing
test_store!(x̄::Number, ∂x) = nothing

function test_store!(x̄::AbstractArray, ∂x)
x̄_store = copy(x̄)
store!(x̄_store, ∂x)
@test x̄_store extern(∂x)

# store! is the same as `accumulate!` to a zero array
x̄_acc = false.*
accumulate!(x̄_acc, ∂x)
@test x̄_acc x̄_store
end

2 comments on commit 3a1422c

@nickrobinson251
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/8290

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.2 -m "<description of version>" 3a1422c40e4915184561dce95f2e7a979b8dc5fa
git push origin v0.3.2

Please sign in to comment.