Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove constraints on primal from Composites for SVD and Cholesky #207

Merged
merged 3 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.6.3"
version = "0.6.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
52 changes: 26 additions & 26 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!

function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback(::Composite{<:SVD})
∂X = @thunk(svd_rev(F, .U, .S, .V))
function svd_pullback(Ȳ::Composite)
∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V))
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
end

function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
function getproperty_svd_pullback()
function getproperty_svd_pullback(Ȳ)
C = Composite{T}
∂F = if x === :U
C(U=,)
C(U=Ȳ,)
elseif x === :S
C(S=,)
C(S=Ȳ,)
elseif x === :V
C(V=,)
C(V=Ȳ,)
elseif x === :Vt
# TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
Expand All @@ -32,8 +32,8 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
return getproperty(F, x), getproperty_svd_pullback
end

# When not `Zero`s expect `::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix`
function svd_rev(USV::SVD, , s̄, V̄)
# When not `Zero`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix`
function svd_rev(USV::SVD, Ū, s̄, V̄)
# Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
U = USV.U
s = USV.S
Expand All @@ -49,7 +49,7 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
# place functions here are significantly faster than their out-of-place, naively
# implemented counterparts, and allocate no additional memory.
Ut = U'
FUᵀŪ = _mulsubtrans!(Ut*, F) # F .* (UᵀŪ - ŪᵀU)
FUᵀŪ = _mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU)
FVᵀV̄ = _mulsubtrans!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV)
ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ
ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ
Expand All @@ -58,11 +58,11 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄)

# TODO: consider using MuladdMacro here
= _add!(U * FUᵀŪ * S, ImUUᵀ * ( / S)) * Vt
= _add!(, U * S̄ * Vt)
= _add!(, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ))
Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt
Ā = _add!(Ā, U * S̄ * Vt)
Ā = _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ))

return
return Ā
end

#####
Expand All @@ -71,31 +71,31 @@ end

function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
F = cholesky(X)
function cholesky_pullback(::Composite{<:Cholesky})
function cholesky_pullback(Ȳ::Composite)
∂X = if F.uplo === 'U'
@thunk(chol_blocked_rev(.U, F.U, 25, true))
@thunk(chol_blocked_rev(Ȳ.U, F.U, 25, true))
else
@thunk(chol_blocked_rev(.L, F.L, 25, false))
@thunk(chol_blocked_rev(Ȳ.L, F.L, 25, false))
end
return (NO_FIELDS, ∂X)
end
return F, cholesky_pullback
end

function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
function getproperty_cholesky_pullback()
function getproperty_cholesky_pullback(Ȳ)
C = Composite{T}
∂F = @thunk if x === :U
if F.uplo === 'U'
C(U=UpperTriangular(),)
C(U=UpperTriangular(Ȳ),)
else
C(L=LowerTriangular('),)
C(L=LowerTriangular(Ȳ'),)
end
elseif x === :L
if F.uplo === 'L'
C(L=LowerTriangular(),)
C(L=LowerTriangular(Ȳ),)
else
C(U=UpperTriangular('),)
C(U=UpperTriangular(Ȳ'),)
end
end
return NO_FIELDS, ∂F, DoesNotExist()
Expand Down Expand Up @@ -159,14 +159,14 @@ function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
end

"""
chol_unblocked_rev!(::AbstractMatrix, L::AbstractMatrix, upper::Bool)
chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool)

Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner.
If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle
of `` and `L` respectively. If `upper` is `true` then they are computed and stored in the
upper triangles. If at input `upper` is `false` and `tril() = L̄`, at output
`tril() = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
`triu() = triu()`, at output `triu() = triu(Σ̄)` where `Σ = UᵀU`.
of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the
upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output
`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`.
"""
function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real
n = checksquare(Σ̄)
Expand Down