Skip to content

Commit

Permalink
Add support for eigvals, svdvals and diag, diagm. (#130)
Browse files Browse the repository at this point in the history
* Add eigvals/eigvals!/eigvecs

* Add svdvals/svdvals!

* Add rrules eigvals/svdvals

* Simplify non-inplace implementations

* Remove eigvecs

The easiest implementation would just be a wrapper around `eigen`, and
this should await some more experiments with the eigenvector gauge
fixing.

* Add tests

* Add support for diag/diagm

* Add tests AD rules

* Switch eigvals to return complex values

* Re-enable tests

* Remove specializations for TrivialTensorMap
  • Loading branch information
lkdvos committed Jun 26, 2024
1 parent 9787d05 commit 53929cd
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 1 deletion.
29 changes: 29 additions & 0 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,20 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
return (U′, Σ′, V′, ϵ), tsvd!_pullback
end

function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
U, S, V = tsvd(t)
s = diag(S)
project_t = ProjectTo(t)

function svdvals_pullback(Δs′)
Δs = unthunk(Δs′)
ΔS = diagm(codomain(S), domain(S), Δs)
return NoTangent(), project_t(U * ΔS * V)
end

return s, svdvals_pullback
end

function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
D, V = eig(t; kwargs...)

Expand Down Expand Up @@ -266,6 +280,21 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k
return (D, V), eigh!_pullback
end

function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
sortby=nothing, kwargs...)
@assert sortby === nothing "only `sortby=nothing` is supported"
(D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...)
d = diag(D)
project_t = ProjectTo(t)
function eigvals_pullback(Δd′)
Δd = unthunk(Δd′)
ΔD = diagm(codomain(D), domain(D), Δd)
return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2])
end

return d, eigvals_pullback
end

function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
alg isa TensorKit.QR || alg isa TensorKit.QRpos ||
error("only `alg=QR()` and `alg=QRpos()` are supported")
Expand Down
13 changes: 13 additions & 0 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ function tsvd(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return tsvd!(permute(t, (p₁, p₂); copy=true); kwargs...)
end

LinearAlgebra.svdvals(t::AbstractTensorMap) = LinearAlgebra.svdvals!(copy(t))
function LinearAlgebra.svdvals!(t::AbstractTensorMap)
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
end

"""
leftorth(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple;
alg::OrthogonalFactorizationAlgorithm = QRpos()) -> Q, R
Expand Down Expand Up @@ -168,6 +173,14 @@ function LinearAlgebra.eigen(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple;
return eigen!(permute(t, (p₁, p₂); copy=true); kwargs...)
end

function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
return LinearAlgebra.eigvals!(copy(t); kwargs...)
end
function LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...)
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
for (c, b) in blocks(t))
end

"""
eig(t::AbstractTensor, (leftind, rightind)::Index2Tuple; kwargs...) -> D, V
Expand Down
13 changes: 13 additions & 0 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,19 @@ function isometry(::Type{A},
return t
end

# Diagonal tensors
# ----------------
# TODO: consider adding a specialised DiagonalTensorMap type
function LinearAlgebra.diag(t::AbstractTensorMap)
return SectorDict(c => LinearAlgebra.diag(b) for (c, b) in blocks(t))
end
function LinearAlgebra.diagm(codom::VectorSpace, dom::VectorSpace, v::SectorDict)
return TensorMap(SectorDict(c => LinearAlgebra.diagm(blockdim(codom, c),
blockdim(dom, c), b)
for (c, b) in v), codom dom)
end
LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(blocks(t)))

# In-place methods
#------------------
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
Expand Down
38 changes: 38 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
end
end

# make sure that norms are computed correctly:
function FiniteDifferences.to_vec(t::TensorKit.SectorDict)
T = scalartype(valtype(t))
vec = mapreduce(vcat, t; init=T[]) do (c, b)
return reshape(b, :) .* sqrt(dim(c))
end
vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec))

function from_vec(x_real)
x = T <: Real ? x_real : reinterpret(T, x_real)
ctr = 0
return TensorKit.SectorDict(c => (n = length(b);
b′ = reshape(view(x, ctr .+ (1:n)), size(b)) ./
sqrt(dim(c));
ctr += n;
b′)
for (c, b) in t)
end
return vec_real, from_vec
end

function _randomize!(a::TensorMap)
for b in values(blocks(a))
copyto!(b, randn(size(b)))
Expand All @@ -43,12 +64,18 @@ end
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...)
return ChainRulesCore.rrule(tsvd!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals), args...; kwargs...)
return ChainRulesCore.rrule(svdvals!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(TensorKit.eig), args...; kwargs...)
return ChainRulesCore.rrule(eig!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(TensorKit.eigh), args...; kwargs...)
return ChainRulesCore.rrule(eigh!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals), args...; kwargs...)
return ChainRulesCore.rrule(eigvals!, args...; kwargs...)
end
function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...)
return ChainRulesCore.rrule(leftorth!, args...; kwargs...)
end
Expand Down Expand Up @@ -330,5 +357,16 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
fkwargs=(; trunc=truncdim(2 * dim(c))))
end

let D = LinearAlgebra.eigvals(C)
ΔD = diag(TensorMap(randn, complex(scalartype(C)), space(C)))
test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD,
fkwargs=(; sortby=nothing))
end

let S = LinearAlgebra.svdvals(C)
ΔS = diag(TensorMap(randn, real(scalartype(C)), space(C)))
test_rrule(LinearAlgebra.svdvals, C; atol, output_tangent=ΔS)
end
end
end
23 changes: 22 additions & 1 deletion test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ for V in spacelist
@test Base.promote_typeof(t, tc) == typeof(tc)
@test Base.promote_typeof(tc, t) == typeof(tc + t)
end
@timedtestset "diag/diagm" begin
W = V1 V2 V3 V4 V5
t = TensorMap(randn, ComplexF64, W)
d = LinearAlgebra.diag(t)
D = LinearAlgebra.diagm(codomain(t), domain(t), d)
@test LinearAlgebra.isdiag(D)
@test LinearAlgebra.diag(D) == d
end
@timedtestset "Permutations: test via inner product invariance" begin
W = V1 V2 V3 V4 V5
t = Tensor(rand, ComplexF64, W)
Expand Down Expand Up @@ -408,7 +416,14 @@ for V in spacelist
@test UdU one(UdU)
VVd = V * V'
@test VVd one(VVd)
@test U * S * V permute(t, ((3, 4, 2), (1, 5)))
t2 = permute(t, ((3, 4, 2), (1, 5)))
@test U * S * V t2

s = LinearAlgebra.svdvals(t2)
s′ = LinearAlgebra.diag(S)
for (c, b) in s
@test b s′[c]
end
end
end
@testset "empty tensor" begin
Expand Down Expand Up @@ -458,6 +473,12 @@ for V in spacelist
t2 = permute(t, ((1, 3), (2, 4)))
@test t2 * V V * D

d = LinearAlgebra.eigvals(t2; sortby=nothing)
d′ = LinearAlgebra.diag(D)
for (c, b) in d
@test b d′[c]
end

# Somehow moving these test before the previous one gives rise to errors
# with T=Float32 on x86 platforms. Is this an OpenBLAS issue?
VdV = V' * V
Expand Down

0 comments on commit 53929cd

Please sign in to comment.