Skip to content

Commit

Permalink
Update AD rules, clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 25, 2023
1 parent 1c1d870 commit a917787
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 226 deletions.
316 changes: 138 additions & 178 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ function _repartition(p::Union{IndexTuple,Index2Tuple},
return _repartition(p, N₁)
end

# Taking a block of a `ZeroTangent` hands back the same `ZeroTangent`, whatever the
# sector. This lets code that loops over sectors call `block` on a cotangent without
# first checking whether that cotangent is a zero tangent.
function TensorKit.block(t::ZeroTangent, c::Sector)
    return t
end

# Constructors
# ------------

Expand All @@ -36,20 +38,19 @@ end

function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...)
function TensorMap_pullback(Δt)
∂d = @thunk(convert(Array, Δt))
∂d = convert(Array, Δt)
return NoTangent(), ∂d, fill(NoTangent(), length(args))...
end
return TensorMap(d, args...), TensorMap_pullback
end

function ChainRulesCore.rrule(::typeof(convert), ::Type{<:Array}, t::AbstractTensorMap)
function convert_pullback(Δt)
spacetype(t) <: ComplexSpace ||
error("currently only implemented or ComplexSpace spacetypes")
∂d = TensorMap(Δt, codomain(t), domain(t))
return NoTangent(), NoTangent(), ∂d
function ChainRulesCore.rrule(::typeof(convert), T::Type{<:Array}, t::AbstractTensorMap)
A = convert(T, t)
function convert_pullback(ΔA)
∂t = TensorMap(ΔA, codomain(t), domain(t))
return NoTangent(), NoTangent(), ∂t
end
return convert(Array, t), convert_pullback
return A, convert_pullback
end

# Base Linear Algebra
Expand Down Expand Up @@ -122,53 +123,14 @@ end
# --------------

function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; kwargs...)
T = eltype(t)

U, S, V, ϵ = tsvd(t; kwargs...)

F = similar(S)
for (k, dst) in blocks(F)
src = blocks(S)[k]
@inbounds for i in axes(dst, 1), j in axes(dst, 2)
if i == j
dst[i, j] = zero(eltype(S))
else
sᵢ, sⱼ = src[i, i], src[j, j]
dst[i, j] = 1 / (abs(sⱼ - sᵢ) < 1e-12 ? 1e-12 : sⱼ^2 - sᵢ^2)
end
end
end

function tsvd!_pullback(ΔUSV)
dU, dS, dV = ΔUSV

∂t = zero(t)
#A_s bar term
if dS != ChainRulesCore.ZeroTangent()
∂t += U * _elementwise_mult(dS, one(dS)) * V
end
#A_uo bar term
if dU != ChainRulesCore.ZeroTangent()
J = _elementwise_mult((U' * dU), F)
∂t += U * (J + J') * S * V
end
#A_vo bar term
if dV != ChainRulesCore.ZeroTangent()
VpdV = V * dV'
K = _elementwise_mult(VpdV, F)
∂t += U * S * (K + K') * V
end
#A_d bar term, only relevant if matrix is complex
if dV != ChainRulesCore.ZeroTangent() && T <: Complex
L = _elementwise_mult(VpdV, one(F))
∂t += 1 / 2 * U * pinv(S) * (L' - L) * V
end

if codomain(t) != domain(t)
pru = U * U'
prv = V' * V
∂t += (one(pru) - pru) * dU * pinv(S) * V
∂t += U * pinv(S) * dV * (one(prv) - prv)
function tsvd!_pullback((ΔU, ΔS, ΔV, Δϵ))
∂t = similar(t)
for (c, b) in blocks(∂t)
copyto!(b,
svd_rev(block(U, c), block(S, c), block(V, c),
block(ΔU, c), block(ΔS, c), block(ΔV, c)))
end

return NoTangent(), ∂t
Expand All @@ -177,62 +139,83 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; k
return (U, S, V, ϵ), tsvd!_pullback
end

# Blockwise elementwise (Hadamard) product of two tensor maps.
#
# The result is allocated with `similar(a)`; every block of the result is overwritten
# with the broadcasted product of the blocks of `a` and `b` stored under the same key.
function _elementwise_mult(a::AbstractTensorMap, b::AbstractTensorMap)
    result = similar(a)
    for (key, target) in blocks(result)
        product = blocks(a)[key] .* blocks(b)[key]
        copyto!(target, product)
    end
    return result
end
"""
    svd_rev(U, S, V, ΔU, ΔS, ΔV; atol=0, rtol=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4))
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
alg isa TensorKit.QR || alg isa TensorKit.QRpos || error("only QR and QRpos supported")
Q, R = leftorth(t; alg=alg)
Implements the following back propagation formula for the SVD:
function leftorth_pullback((ΔQ, ΔR))
∂t = similar(t)
ΔR = ΔR isa ZeroTangent ? zero(R) : ΔR
ΔQ = ΔQ isa ZeroTangent ? zero(Q) : ΔQ
```math
ΔA = UΔSV' + U(J + J')SV' + US(K + K')V' + \\frac{1}{2}US^{-1}(L' - L)V'\\
J = F ∘ (U'ΔU), \\qquad K = F ∘ (V'ΔV), \\qquad L = I ∘ (V'ΔV)\\
F_{i ≠ j} = \\frac{1}{s_j^2 - s_i^2}\\
F_{ii} = 0
```
if sectortype(t) === Trivial
copyto!(∂t.data, qr_pullback(t.data, Q.data, R.data, ΔQ.data, ΔR.data))
else
for (c, b) in blocks(∂t)
copyto!(b,
qr_pullback(block(t, c), block(Q, c), block(R, c),
block(ΔQ, c), block(ΔR, c)))
end
end
# References
return NoTangent(), ∂t
Wan, Zhou-Quan, and Shi-Xin Zhang. 2019. “Automatic Differentiation for Complex Valued SVD.” https://doi.org/10.48550/ARXIV.1909.02659.
"""
function svd_rev(U::AbstractMatrix, S::AbstractMatrix, V::AbstractMatrix, ΔU, ΔS, ΔV;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(scalartype(S))^(3 / 4))
# project out gauge invariance dependence?
# ΔU * U + ΔV * V' = 0

tol = atol > 0 ? atol : rtol * S[1, 1]
F = _invert_S²(S, tol)
S⁻¹ = pinv(S; atol=tol)

term = Diagonal(diag(ΔS))

J = F .* (U' * ΔU)
term += (J + J') * S
VΔV = (V * ΔV')
K = F .* VΔV
term += S * (K + K')

if scalartype(U) <: Complex && !(ΔV isa ZeroTangent) && !(ΔU isa ZeroTangent)
L = LinearAlgebra.Diagonal(diag(VΔV))
term += 0.5 * S⁻¹ * (L' - L)
end

return (Q, R), leftorth_pullback
end
ΔA = U * term * V

function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
alg isa TensorKit.LQ || alg isa TensorKit.LQpos || error("only LQ and LQpos supported")
L, Q = rightorth(t; alg)
if size(U, 1) != size(V, 2)
UUd = U * U'
VdV = V' * V
ΔA += (one(UUd) - UUd) * ΔU * S⁻¹ * V + U * S⁻¹ * ΔV * (one(VdV) - VdV)
end

function rightorth_pullback((ΔL, ΔQ))
∂t = similar(t)
ΔL = ΔL isa ZeroTangent ? zero(L) : ΔL
ΔQ = ΔQ isa ZeroTangent ? zero(Q) : ΔQ
return ΔA
end

if sectortype(t) === Trivial
copyto!(∂t.data, lq_pullback(t.data, Q.data, L.data, ΔQ.data, ΔL.data))
function _invert_S²(S::AbstractMatrix{T}, tol::Real) where {T<:Real}
F = similar(S)
@inbounds for i in axes(F, 1), j in axes(F, 2)
F[i, j] = if i == j
zero(T)
else
for (c, b) in blocks(∂t)
copyto!(b,
lq_pullback(block(t, c), block(Q, c), block(L, c),
block(ΔQ, c), block(ΔL, c)))
end
sᵢ, sⱼ = S[i, i], S[j, j]
1 / (abs(sⱼ - sᵢ) < tol ? tol : sⱼ^2 - sᵢ^2)
end

return NoTangent(), ∂t
end
return F
end

# Reverse-mode rule for `leftorth!`.
#
# Only the `QR` and `QRpos` algorithms are supported; any other `alg` raises an error.
# The primal is computed with the non-mutating `leftorth(t; alg)`.
#
# Returns `(Q, R)` together with a pullback. The pullback takes the cotangents
# `(ΔQ, ΔR)` and returns one tangent per primal argument: `NoTangent()` for the
# function itself and the cotangent of `t`, computed by `qr_pullback!` into a fresh
# `similar(t)`.
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
    alg isa TensorKit.QR || alg isa TensorKit.QRpos || error("only QR and QRpos supported")
    Q, R = leftorth(t; alg)
    leftorth!_pullback((ΔQ, ΔR)) = NoTangent(), qr_pullback!(similar(t), t, Q, R, ΔQ, ΔR)
    # Shortcut when both cotangents vanish. It must still return a tuple with one
    # entry per primal argument, like the method above; a bare `ZeroTangent()` would
    # break callers that destructure the pullback result.
    leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent()
    return (Q, R), leftorth!_pullback
end

return (L, Q), rightorth_pullback
# Reverse-mode rule for `rightorth!`.
#
# Only the `LQ` and `LQpos` algorithms are supported; any other `alg` raises an error.
# The primal is computed with the non-mutating `rightorth(t; alg)`.
#
# Returns `(L, Q)` together with a pullback. The pullback takes the cotangents
# `(ΔL, ΔQ)` and returns one tangent per primal argument: `NoTangent()` for the
# function itself and the cotangent of `t`, computed by `lq_pullback!` into a fresh
# `similar(t)`.
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
    alg isa TensorKit.LQ || alg isa TensorKit.LQpos || error("only LQ and LQpos supported")
    L, Q = rightorth(t; alg)
    rightorth!_pullback((ΔL, ΔQ)) = NoTangent(), lq_pullback!(similar(t), t, L, Q, ΔL, ΔQ)
    # Shortcut when both cotangents vanish. It must still return a tuple with one
    # entry per primal argument, like the method above; a bare `ZeroTangent()` would
    # break callers that destructure the pullback result.
    rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent()
    return (L, Q), rightorth!_pullback
end

"""
Expand All @@ -251,108 +234,85 @@ function copyltu!(A::AbstractMatrix)
return A
end

qr_pullback(A, Q, R, ::Nothing, ::Nothing) = nothing
function qr_pullback(A, Q, R, ΔQ, ΔR)
M = qr_rank(R)
N = size(R, 2)

q = view(Q, :, 1:M)
Δq = isnothing(ΔQ) ? nothing : view(ΔQ, :, 1:M)

r = view(R, 1:M, :)
Δr = isnothing(ΔR) ? nothing : view(ΔR, 1:M, :)

N == M && return qr_pullback_fullrank(q, r, Δq, Δr)
# Tensor-map level QR pullback: fills `ΔA` block by block and returns it.
#
# For every sector of `ΔA`, the matching blocks of `t`, `Q`, `R` and of the cotangents
# `ΔQ`, `ΔR` are extracted with `block` and handed to the `qr_pullback!` method that
# receives those blocks, which writes into the corresponding block of `ΔA`.
function qr_pullback!(ΔA::AbstractTensorMap{S}, t::AbstractTensorMap{S},
                      Q::AbstractTensorMap{S}, R::AbstractTensorMap{S}, ΔQ, ΔR) where {S}
    for (sector, dst) in blocks(ΔA)
        A_c, Q_c, R_c = block(t, sector), block(Q, sector), block(R, sector)
        ΔQ_c, ΔR_c = block(ΔQ, sector), block(ΔR, sector)
        qr_pullback!(dst, A_c, Q_c, R_c, ΔQ_c, ΔR_c)
    end
    return ΔA
end

B = view(A, :, (M + 1):N)
U = view(r, :, 1:M)
function qr_pullback!(ΔA, A, Q::M, R::M, ΔQ, ΔR) where {M<:AbstractMatrix}
m = qr_rank(R)
n = size(R, 2)

if !isnothing(ΔR)
ΔD = view(Δr, :, (M + 1):N)
ΔA = qr_pullback_fullrank(q, U, !isnothing(Δq) ? Δq + B * ΔD' : B * ΔD',
view(Δr, :, 1:M))
ΔB = q * ΔD
if n == m # full rank
q = view(Q, :, 1:m)
Δq = view(ΔQ, :, 1:m)
r = view(R, 1:m, :)
Δr = view(ΔR, 1:m, :)
ΔA = qr_pullback_fullrank!(ΔA, q, r, Δq, Δr)
else
ΔA = qr_pullback_fullrank(q, U, Δq, nothing)
ΔB = zero(B)
q = view(Q, :, 1:m)
Δq = view(ΔQ, :, 1:m) + view(A, :, (m + 1):n) * view(ΔR, :, (m + 1):n)'
r = view(R, 1:m, 1:m)
Δr = view(ΔR, 1:m, 1:m)

qr_pullback_fullrank!(view(ΔA, :, 1:m), q, r, Δq, Δr)
ΔA[:, (m + 1):n] = q * view(ΔR, :, (m + 1):n)
end

return hcat(ΔA, ΔB)
return ΔA
end

lq_pullback(A, L, Q, ::Nothing, ::Nothing) = nothing
function lq_pullback(A, L, Q, ΔL, ΔQ)
M = lq_rank(L)
N = size(L, 1)

l = view(L, :, 1:M)
Δl = isnothing(ΔL) ? nothing : view(ΔL, :, 1:M)
q = view(Q, 1:M, :)
Δq = isnothing(ΔQ) ? nothing : view(ΔQ, 1:M, :)
# QR pullback for the full-rank case, written into `ΔA`.
#
# Builds `b = ΔQ + Q * copyltu!(R * ΔR' - ΔQ' * Q)`, solves the upper-triangular
# system `R * X = b'` in place on a copy of `b'`, and stores `X'` in `ΔA`.
# The return value is whatever `adjoint!(ΔA, X)` returns.
# NOTE(review): `copyltu!` is defined elsewhere in this file; its semantics are taken
# as given here.
function qr_pullback_fullrank!(ΔA, Q, R, ΔQ, ΔR)
    M = copyltu!(R * ΔR' - ΔQ' * Q)
    rhs = copy(adjoint(ΔQ + Q * M))
    X = LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, rhs)
    return adjoint!(ΔA, X)
end

N == M && return lq_pullback_fullrank(l, q, Δl, Δq)
# Tensor-map level LQ pullback: fills `ΔA` block by block and returns it.
#
# For every sector of `ΔA`, the matching blocks of `t`, `L`, `Q` and of the cotangents
# `ΔL`, `ΔQ` are extracted with `block` and handed to the `lq_pullback!` method that
# receives those blocks, which writes into the corresponding block of `ΔA`.
function lq_pullback!(ΔA::AbstractTensorMap{S}, t::AbstractTensorMap{S},
                      L::AbstractTensorMap{S}, Q::AbstractTensorMap{S}, ΔL, ΔQ) where {S}
    for (sector, dst) in blocks(ΔA)
        A_c, L_c, Q_c = block(t, sector), block(L, sector), block(Q, sector)
        ΔL_c, ΔQ_c = block(ΔL, sector), block(ΔQ, sector)
        lq_pullback!(dst, A_c, L_c, Q_c, ΔL_c, ΔQ_c)
    end
    return ΔA
end

B = view(A, (M + 1):N, :)
U = view(l, 1:M, :)
function lq_pullback!(ΔA, A, L::M, Q::M, ΔL, ΔQ) where {M<:AbstractMatrix}
m = qr_rank(L)
n = size(L, 1)

if !isnothing(ΔL)
ΔD = view(Δl, (M + 1):N, :)
ΔA = lq_pullback_fullrank(U, q, view(Δl, 1:M, :),
!isnothing(Δq) ? Δq + ΔD' * B : ΔD' * B)
ΔB = ΔD * q
if n == m # full rank
l = view(L, :, 1:m)
Δl = view(ΔL, :, 1:m)
q = view(Q, 1:m, :)
Δq = view(ΔQ, 1:m, :)
ΔA = lq_pullback_fullrank!(ΔA, l, q, Δl, Δq)
else
ΔA = lq_pullback_fullrank(U, q, nothing, Δq)
ΔB = zero(B)
end
l = view(L, 1:m, 1:m)
Δl = view(ΔL, 1:m, 1:m)
q = view(Q, 1:m, :)
Δq = view(ΔQ, 1:m, :) + view(ΔL, (m + 1):n, 1:m)' * view(A, (m + 1):n, :)

return vcat(ΔA, ΔB)
end
lq_pullback_fullrank!(view(ΔA, 1:m, :), l, q, Δl, Δq)
ΔA[(m + 1):n, :] = view(ΔL, (m + 1):n, :) * q
end

qr_pullback_fullrank(Q, R, ::Nothing, ::Nothing) = nothing
function qr_pullback_fullrank(Q, R, ΔQ, ::Nothing)
b = ΔQ + Q * copyltu!(-Q' * ΔQ)
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))'
end
function qr_pullback_fullrank(Q, R, ::Nothing, ΔR)
b = Q * copyltu!(R * ΔR')
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))'
end
function qr_pullback_fullrank(Q, R, ΔQ, ΔR)
b = ΔQ + Q * copyltu!(R * ΔR' - ΔQ' * Q)
return b / R'
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))'
return ΔA
end

lq_pullback_fullrank(L, Q, ::Nothing, ::Nothing) = nothing
function lq_pullback_fullrank(L, Q, ΔL, ::Nothing)
b = copyltu!(L' * ΔL) * Q
return LinearAlgebra.LAPACK.trtrs!('L', 'N', 'N', L, b)
end
function lq_pullback_fullrank(L, Q, ::Nothing, ΔQ)
b = copyltu!(-ΔQ * Q') + ΔQ
return LinearAlgebra.LAPACK.trtrs!('L', 'N', 'N', L, b)
end
function lq_pullback_fullrank(L, Q, ΔL, ΔQ)
b = copyltu!(L' * ΔL - ΔQ * Q') + ΔQ
return LinearAlgebra.LAPACK.trtrs!('L', 'N', 'N', L, b)
# LQ pullback for the full-rank case, computed in place in `ΔA`.
#
# First sets `ΔA = copyltu!(L' * ΔL - ΔQ * Q') * Q + ΔQ`, then overwrites `ΔA` with
# the solution of the lower-triangular system `L' * X = ΔA` (the `'C'` flag selects
# the conjugate transpose of `L`). Returns the result of `trtrs!`.
# NOTE(review): `copyltu!` is defined elsewhere in this file; its semantics are taken
# as given here.
function lq_pullback_fullrank!(ΔA, L, Q, ΔL, ΔQ)
    M = copyltu!(L' * ΔL - ΔQ * Q')
    mul!(ΔA, M, Q)
    axpy!(true, ΔQ, ΔA)
    solved = LinearAlgebra.LAPACK.trtrs!('L', 'C', 'N', L, ΔA)
    return solved
end

function qr_rank(r::AbstractMatrix)
Base.require_one_based_indexing(r)
m, n = size(r)
r₀ = r[1, 1]
for i in axes(r, 1)
abs(r[i, i] / r₀) < 1e-12 && return i - 1
end
return size(r, 1)
end

function lq_rank(l::AbstractMatrix)
Base.require_one_based_indexing(l)
l₀ = l[1, 1]
for i in axes(l, 2)
abs(l[i, i] / l₀) < 1e-12 && return i - 1
end
return size(l, 2)
i = findfirst(x -> abs(x / r₀) < 1e-12, diag(r))
return isnothing(i) ? min(m, n) : i - 1
end

function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
Expand Down
Loading

0 comments on commit a917787

Please sign in to comment.