Skip to content

Commit

Permalink
some AD tests amd updates
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 18, 2023
1 parent 25c06e5 commit 1c1d870
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 16 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ WignerSymbols = "1,2"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -41,4 +44,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"

[targets]
test = ["Combinatorics", "HalfIntegers", "LinearAlgebra", "Random", "TensorOperations", "Test", "TestExtras", "WignerSymbols"]
test = ["Combinatorics", "HalfIntegers", "LinearAlgebra", "Random", "TensorOperations", "Test", "TestExtras", "WignerSymbols", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"]
31 changes: 16 additions & 15 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ end
# Factorizations
# --------------

function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kwargs...)
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; kwargs...)
T = eltype(t)

U, S, V, ϵ = tsvd(t; kwargs...)
Expand All @@ -139,7 +139,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kw
end
end

function tsvd_pullback(ΔUSV)
function tsvd!_pullback(ΔUSV)
dU, dS, dV = ΔUSV

∂t = zero(t)
Expand Down Expand Up @@ -174,7 +174,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kw
return NoTangent(), ∂t
end

return (U, S, V, ϵ), tsvd_pullback
return (U, S, V, ϵ), tsvd!_pullback
end

function _elementwise_mult(a::AbstractTensorMap, b::AbstractTensorMap)
Expand All @@ -185,9 +185,9 @@ function _elementwise_mult(a::AbstractTensorMap, b::AbstractTensorMap)
return dst
end

function ChainRulesCore.rrule(::typeof(leftorth!), t; alg=QRpos())
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
alg isa TensorKit.QR || alg isa TensorKit.QRpos || error("only QR and QRpos supported")
Q, R = leftorth(t; alg)
Q, R = leftorth(t; alg=alg)

function leftorth_pullback((ΔQ, ΔR))
∂t = similar(t)
Expand All @@ -210,9 +210,9 @@ function ChainRulesCore.rrule(::typeof(leftorth!), t; alg=QRpos())
return (Q, R), leftorth_pullback
end

function ChainRulesCore.rrule(::typeof(rightorth!), tensor; alg=LQpos())
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
alg isa TensorKit.LQ || alg isa TensorKit.LQpos || error("only LQ and LQpos supported")
L, Q = rightorth(tensor; alg)
L, Q = rightorth(t; alg)

function rightorth_pullback((ΔL, ΔQ))
∂t = similar(t)
Expand Down Expand Up @@ -282,7 +282,7 @@ end

lq_pullback(A, L, Q, ::Nothing, ::Nothing) = nothing
function lq_pullback(A, L, Q, ΔL, ΔQ)
M = lq_rqnk(L)
M = lq_rank(L)
N = size(L, 1)

l = view(L, :, 1:M)
Expand Down Expand Up @@ -310,16 +310,17 @@ end

qr_pullback_fullrank(Q, R, ::Nothing, ::Nothing) = nothing
function qr_pullback_fullrank(Q, R, ΔQ, ::Nothing)
b = ΔQ + q * copyltu!(-ΔQ' * Q)
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', r, copy(adjoint(b)))
b = ΔQ + Q * copyltu!(-Q' * ΔQ)
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))'
end
function qr_pullback_fullrank(Q, R, ::Nothing, ΔR)
b = q * copyltu!(R * ΔR)
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', r, copy(adjoint(b)))
b = Q * copyltu!(R * ΔR')
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))'
end
function qr_pullback_fullrank(Q, R, ΔQ, ΔR)
b = ΔQ + q * copyltu(R * ΔR' - ΔQ' * Q)
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', r, copy(adjoint(b)))
b = ΔQ + Q * copyltu!(R * ΔR' - ΔQ' * Q)
return b / R'
return LinearAlgebra.LAPACK.trtrs!('U', 'N', 'N', R, copy(adjoint(b)))'
end

lq_pullback_fullrank(L, Q, ::Nothing, ::Nothing) = nothing
Expand All @@ -346,7 +347,7 @@ function qr_rank(r::AbstractMatrix)
end

function lq_rank(l::AbstractMatrix)
Base.require_one_based_indexing(r)
Base.require_one_based_indexing(l)
l₀ = l[1, 1]
for i in axes(l, 2)
abs(l[i, i] / l₀) < 1e-12 && return i - 1
Expand Down
88 changes: 88 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using TestEnv: TestEnv;
TestEnv.activate("TensorKit");
using TensorKit
using TensorOperations
using ChainRulesCore
using ChainRulesTestUtils
using Random
using FiniteDifferences
using Test

## Test utility
# -------------
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
return TensorMap(randn, scalartype(x), space(x))
end
function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, expected::AbstractTensorMap, msg=""; kwargs...)
ChainRulesTestUtils.@test_msg msg isapprox(actual, expected; kwargs...)
end
# function ChainRulesTestUtils.test_approx(actual::NTuple{N}, expected::NTuple{N}, msg="";
# kwargs...) where {N}
# @test all(isapprox.(actual, expected; Ref(kwargs)...))
# end
function FiniteDifferences.to_vec(t::T) where {T<:TensorKit.TrivialTensorMap}
vec, from_vec = to_vec(t.data)
return vec, x -> T(from_vec(x), codomain(t), domain(t))
end
function FiniteDifferences.to_vec(t::AbstractTensorMap)
vec, from_vec′ = to_vec(blocks(t))
function from_vec(x)
blocks′ = from_vec′(x)
t′ = similar(t)
for (c, b) in blocks(t′)
b .= blocks′[c]
end
return t′
end

return vec, from_vec
end
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))

ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...) = ChainRulesCore.rrule(tsvd!, args...)
function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...)
return ChainRulesCore.rrule(leftorth!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(TensorKit.rightorth), args...; kwargs...)
return ChainRulesCore.rrule(rightorth!, args...; kwargs...)
end
##

ChainRulesTestUtils.test_method_tables()

Vtr = (ℂ^3, (ℂ^4)', ℂ^5, ℂ^6, (ℂ^7)')
T = Float64

A = TensorMap(randn, T, Vtr[1] Vtr[2] Vtr[3] Vtr[4] Vtr[5])
B = TensorMap(randn, T, space(A))
test_rrule(+, A, B)
test_rrule(-, A, B)
C = TensorMap(randn, T, domain(A), codomain(A))
test_rrule(*, A, C)
α = randn(T)
test_rrule(*, α, A)
test_rrule(*, A, α)

test_rrule(permute, A, ((1, 3, 2), (5, 4)))

D = Tensor(randn, T, ProductSpace{ComplexSpace,0}())
test_rrule(TensorKit.scalar, D)

# LinearAlgebra
# -------------
using LinearAlgebra
for i in 1:3
E = TensorMap(randn, T, (Vtr[1:i]...) (Vtr[1:i]...))
test_rrule(tr, E)
end

test_rrule(adjoint, A)
test_rrule(norm, A, 2)



test_rrule(tsvd, A; atol=1e-6)
test_rrule(leftorth, A; fkwargs=(;alg = TensorKit.QR()))
test_rrule(leftorth, A; fkwargs=(;alg = TensorKit.QRpos()))
test_rrule(rightorth, A; fkwargs=(;alg = TensorKit.LQ()))
test_rrule(rightorth, A; fkwargs=(;alg = TensorKit.LQpos()))

0 comments on commit 1c1d870

Please sign in to comment.