Skip to content

Commit

Permalink
Add an rrule for the Cholesky decomposition
Browse files Browse the repository at this point in the history
This is a direct port of the code from Nabla.
  • Loading branch information
ararslan committed Jun 5, 2019
1 parent a2e6451 commit 96cd630
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 0 deletions.
192 changes: 192 additions & 0 deletions src/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using LinearAlgebra: checksquare
using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!

#####
##### `svd`
#####
Expand Down Expand Up @@ -79,3 +82,192 @@ function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
end
X
end

#####
##### `cholesky`
#####

# Reverse-mode rule for `cholesky` of a real matrix. Returns the factorization `F` together
# with a `Rule` wrapping the pullback. The pullback densifies both the incoming sensitivity
# `Ȳ` and the upper factor `F.U`, then runs the blocked reverse sweep `chol_blocked_rev`
# with block size 25 in the upper-triangular convention.
function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
    F = cholesky(X)
    function cholesky_pullback(Ȳ)
        Ū = Matrix(Ȳ)
        U = Matrix(F.U)
        return chol_blocked_rev(Ū, U, 25, true)
    end
    return F, Rule(cholesky_pullback)
end

# Reverse-mode rule for reading the `U` or `L` factor out of a `Cholesky` object.
#
# Returns the requested factor and a tuple of two differentials:
#   * `∂F` maps the sensitivity `Ȳ` of the requested factor to a triangular matrix whose
#     shape follows `F.uplo`. When the requested factor is not the stored one, the adjoint
#     `Ȳ'` is wrapped in the opposite triangular type.
#   * a `DNERule()` for the symbol argument `x`.
#
# Throws an `ArgumentError` for any property other than `:U` or `:L`.
function rrule(::typeof(getproperty), F::Cholesky, x::Symbol)
    # Without this guard an unsupported property skips every branch below and surfaces as
    # an opaque `UndefVarError` on `∂F` at the `return`.
    if x !== :U && x !== :L
        throw(ArgumentError(
            "rrule for getproperty(::Cholesky, x) only supports x = :U or :L, got $(repr(x))"
        ))
    end
    if x === :U
        if F.uplo === 'U'
            ∂F = Ȳ -> UpperTriangular(Ȳ)
        else
            ∂F = Ȳ -> LowerTriangular(Ȳ')
        end
    else  # x === :L
        if F.uplo === 'L'
            ∂F = Ȳ -> LowerTriangular(Ȳ)
        else
            ∂F = Ȳ -> UpperTriangular(Ȳ')
        end
    end
    return getproperty(F, x), (∂F, DNERule())
end

# TODO: The comment below about row- versus column-major is incorrect; Julia's arrays are
# column-major. This comment along with the below implementation was copied from Nabla and
# works, but the implementation should perhaps be revisited to ensure we're actually
# being cache friendly.
#
# See [1] for implementation details: pages 5-9 in particular. The derivations presented in
# [1] assume column-major layout, whereas Julia primarily uses row-major. We therefore
# implement both the derivations in [1] and their transpose, which is more appropriate to
# Julia.
#
# [1] - "Differentiation of the Cholesky decomposition", Murray 2016

"""
    level2partition(A::AbstractMatrix, j::Integer, upper::Bool)

Split the square matrix `A` around its `j`th diagonal entry, following the
`level2partition` procedure of [1], and return the views `(r, d, B, c)`.

With `upper == false` the views are taken from the lower triangle: `r` is row `j` to the
left of the diagonal, `d` is the (zero-dimensional) view of `A[j, j]`, `B` is the block
below row `j` and left of column `j`, and `c` is column `j` below the diagonal. With
`upper == true` the transposed layout is taken from the upper triangle instead.

Throws a `BoundsError` if `j` is not in `1:size(A, 1)`, and whatever `checksquare` throws
if `A` is not square.
"""
function level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
    n = checksquare(A)
    @boundscheck checkbounds(1:n, j)
    before, after = 1:j-1, j+1:n
    d = view(A, j, j)
    if upper
        return view(A, before, j), d, view(A, before, after), view(A, j, after)
    end
    return view(A, j, before), d, view(A, after, before), view(A, after, j)
end

"""
    level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)

Split the square matrix `A` around its diagonal block `A[j:k, j:k]`, following the
`level3partition` procedure of [1], and return the views `(R, D, B, C)`.

With `upper == false` the views are taken from the lower triangle: `R` is rows `j:k` left
of the block, `D` is the diagonal block itself, `B` is the block below row `k` and left of
column `j`, and `C` is columns `j:k` below row `k`. With `upper == true` the transposed
layout is taken from the upper triangle instead.

Only `j` is explicitly bounds-checked against `1:size(A, 1)`; `k` is validated only
through the bounds checks performed by `view`.
"""
function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
    n = checksquare(A)
    @boundscheck checkbounds(1:n, j)
    head, blk, tail = 1:j-1, j:k, k+1:n
    D = view(A, blk, blk)
    if upper
        return view(A, head, blk), D, view(A, head, tail), view(A, blk, tail)
    end
    return view(A, blk, head), D, view(A, tail, head), view(A, tail, blk)
end

"""
    chol_unblocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool)

Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked
manner, overwriting `Σ̄` in place and returning it.

If `upper` is `false`, the sensitivities are read from and written to the lower triangle
of `Σ̄`, and the factor is read from the lower triangle of `L`. If `upper` is `true`, the
upper triangles are used instead. If at input `upper` is `false` and `tril(Σ̄) = L̄`, at
output `tril(Σ̄)` holds the lower triangle of the sensitivity of `Σ`, where `Σ = LLᵀ`.
Analogously, if at input `upper` is `true` and `triu(Σ̄) = triu(Ū)`, at output `triu(Σ̄)`
holds the upper triangle of the sensitivity of `Σ`, where `Σ = UᵀU`.

The unused triangle of the result is zeroed by the final `triu!` / `tril!`.
"""
function chol_unblocked_rev!(
    Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool
) where T<:Real
    n = checksquare(Σ̄)
    # Sweep the diagonal from the last entry back to the first.
    j = n
    @inbounds for _ in 1:n
        r, d, B, c = level2partition(L, j, upper)
        r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper)

        # d̄ <- d̄ - c'c̄ / d.
        d̄[1] -= dot(c, c̄) / d[1]

        # [d̄ c̄'] <- [d̄ c̄'] / d.
        d̄ ./= d
        c̄ ./= d

        # r̄ <- r̄ - [d̄ c̄'] [r' B']'.
        # `Σ̄[j, j]` is the entry that `d̄` views, i.e. the value just updated above.
        r̄ = axpy!(-Σ̄[j, j], r, r̄)
        r̄ = gemv!(upper ? 'N' : 'T', -one(T), B, c̄, one(T), r̄)

        # B̄ <- B̄ - c̄ r.
        B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄)
        d̄ ./= 2
        j -= 1
    end
    return (upper ? triu! : tril!)(Σ̄)
end

# Non-mutating counterpart of `chol_unblocked_rev!`: runs the unblocked reverse sweep on a
# copy of `Σ̄`, leaving the caller's array untouched, and returns the result.
function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool)
    scratch = copy(Σ̄)
    return chol_unblocked_rev!(scratch, L, upper)
end

"""
    chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)

Compute the sensitivities of the Cholesky factorization using a blocked procedure,
overwriting `Σ̄` in place and returning it.

`Σ̄` holds the sensitivities of `L` on input and is transformed into the sensitivities of
`Σ`, where `Σ = LLᵀ`. `nb` is the block size to use and must be positive. If the upper
triangle has been used to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`,
then this should be indicated by passing `upper = true`.

The unused triangle of the result is zeroed by the final `triu!` / `tril!`.

Throws an `ArgumentError` if `nb` is not positive.
"""
function chol_blocked_rev!(
    Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::Integer, upper::Bool
) where T<:Real
    n = checksquare(Σ̄)
    nb > 0 || throw(ArgumentError("block size `nb` must be positive, got $nb"))

    # Scratch space for `D̄ + D̄ᵀ`. It is only used for full `nb × nb` diagonal blocks,
    # which can only occur when `nb <= n`, so there is no need to allocate more than `n`.
    tmp = Matrix{T}(undef, min(nb, n), min(nb, n))

    # Blocks are processed from the bottom-right corner back to the top-left; `k` is the
    # last row/column of the current diagonal block and `j` its first.
    k = n
    if upper
        @inbounds for _ in 1:nb:n
            j = max(1, k - nb + 1)
            R, D, B, C = level3partition(L, j, k, true)
            R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true)

            C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄)
            gemm!('N', 'N', -one(T), R, C̄, one(T), B̄)
            gemm!('N', 'T', -one(T), C, C̄, one(T), D̄)
            chol_unblocked_rev!(D̄, D, true)
            gemm!('N', 'T', -one(T), B, C̄, one(T), R̄)
            if size(D̄, 1) == nb
                # Full-sized block: form `D̄ + D̄ᵀ` in the preallocated buffer.
                tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
                gemm!('N', 'N', -one(T), R, tmp, one(T), R̄)
            else
                # Smaller leading block: the buffer has the wrong size, so allocate.
                gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄)
            end

            k -= nb
        end
        return triu!(Σ̄)
    else
        @inbounds for _ in 1:nb:n
            j = max(1, k - nb + 1)
            R, D, B, C = level3partition(L, j, k, false)
            R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false)

            C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄)
            gemm!('N', 'N', -one(T), C̄, R, one(T), B̄)
            gemm!('T', 'N', -one(T), C̄, C, one(T), D̄)
            chol_unblocked_rev!(D̄, D, false)
            gemm!('T', 'N', -one(T), C̄, B, one(T), R̄)
            if size(D̄, 1) == nb
                # Full-sized block: form `D̄ + D̄ᵀ` in the preallocated buffer.
                tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
                gemm!('N', 'N', -one(T), tmp, R, one(T), R̄)
            else
                # Smaller leading block: the buffer has the wrong size, so allocate.
                gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄)
            end

            k -= nb
        end
        return tril!(Σ̄)
    end
end

# Non-mutating counterpart of `chol_blocked_rev!`: runs the blocked reverse sweep with block
# size `nb` on a copy of `Σ̄`, leaving the caller's array untouched, and returns the result.
function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
    scratch = copy(Σ̄)
    return chol_blocked_rev!(scratch, L, nb, upper)
end
63 changes: 63 additions & 0 deletions test/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblocked_rev

@testset "Factorizations" begin
@testset "svd" begin
rng = MersenneTwister(2)
Expand All @@ -21,4 +23,65 @@
@test ChainRules._add!(copy(X), Y) X + Y
end
end
@testset "cholesky" begin
    rng = MersenneTwister(4)
    @testset "the thing" begin
        X = generate_well_conditioned_matrix(rng, 10)
        V = generate_well_conditioned_matrix(rng, 10)
        F, dX = rrule(cholesky, X)
        for p in [:U, :L]
            Y, (dF, dp) = rrule(getproperty, F, p)
            @test dp isa ChainRules.DNERule
            # Random seed sensitivity with the same triangular structure as the factor.
            Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y)))
            # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
            # machinery from FDM because that isn't set up to respect necessary special
            # properties of the input. In the case of the Cholesky factorization, we
            # need the input to be Hermitian.
            # Compare the directional derivative along `V` from the rules against a
            # finite-difference estimate of the same quantity.
            X̄_ad = dot(dX(dF(Ȳ)), V)
            X̄_fd = central_fdm(5, 1)() do ε
                dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
            end
            @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6
        end
    end
    @testset "helper functions" begin
        A = randn(rng, 5, 5)
        # A level 3 partition with a 1×1 diagonal block must agree with level 2.
        r, d, B2, c = level2partition(A, 4, false)
        R, D, B3, C = level3partition(A, 4, 4, false)
        @test all(r .== R')
        @test all(d .== D)
        @test B2[1] == B3[1]
        @test all(c .== C)

        # Check that level 2 partition with `upper == true` is consistent with `false`
        rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true)
        @test r == rᵀ
        @test d == dᵀ
        @test B2' == B2ᵀ
        @test c == cᵀ

        # Check that level 3 partition with `upper == true` is consistent with `false`
        R, D, B3, C = level3partition(A, 2, 4, false)
        Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true)
        @test transpose(R) == Rᵀ
        @test transpose(D) == Dᵀ
        @test transpose(B3) == B3ᵀ
        @test transpose(C) == Cᵀ

        A = Matrix(LowerTriangular(randn(rng, 10, 10)))
        Ā = Matrix(LowerTriangular(randn(rng, 10, 10)))
        # NOTE: BLAS gets angry if we don't materialize the Transpose objects first
        B = Matrix(transpose(A))
        B̄ = Matrix(transpose(Ā))
        # Blocked and unblocked sweeps must agree for every block size, including ones
        # that do not divide the matrix size evenly.
        @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false)
        @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false)
        @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false)
        @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 10, false)
        # The upper-triangular path must be the transpose of the lower-triangular one.
        @test chol_unblocked_rev(Ā, A, false) ≈ transpose(chol_unblocked_rev(B̄, B, true))

        @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 1, true)
        @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 5, true)
        @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true)
    end
end
end

0 comments on commit 96cd630

Please sign in to comment.