Skip to content

Commit

Permalink
restructure tensoroperations
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Jul 27, 2023
1 parent 8b9bd72 commit 8ae2a65
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 187 deletions.
4 changes: 3 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ using TupleTools: StaticLength
using Strided

using VectorInterface
using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon, IndexTuple, Index2Tuple, linearize

using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon
using TensorOperations: IndexTuple, Index2Tuple, linearize, Backend
const TO = TensorOperations

using LRUCache
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,7 @@ function ⊗(t1::AbstractTensorMap{S}, t2::AbstractTensorMap{S}) where S
end

# deligne product of tensors
function (t1::AbstractTensorMap{<:ElementarySpace{ℂ}},
t2::AbstractTensorMap{<:ElementarySpace{ℂ}})
function (t1::AbstractTensorMap, t2::AbstractTensorMap)
S1 = spacetype(t1)
I1 = sectortype(S1)
S2 = spacetype(t2)
Expand Down
326 changes: 142 additions & 184 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,143 @@
# Implement full TensorOperations.jl interface
#----------------------------------------------
TO.tensorstructure(t::AbstractTensorMap) = space(t)
function TO.tensorstructure(t::AbstractTensorMap, iA::Int, conjA::Symbol)
return conjA == :N ? space(t, iA) : conj(space(t, iA))
end

function TO.tensoralloc(ttype::Type{<:AbstractTensorMap}, structure, istemp=false,
backend::Backend...)
M = storagetype(ttype)
return TensorMap(structure) do d
return TO.tensoralloc(M, d, istemp, backend...)
end
end

function TO.tensorfree!(t::AbstractTensorMap, backend::Backend...)
for (c, b) in blocks(t)
TO.tensorfree!(b, backend...)
end
return nothing
end

TO.tensorscalar(t::AbstractTensorMap) = scalar(t)

_canonicalize(p::Index2Tuple{N₁,N₂}, ::AbstractTensorMap{<:IndexSpace,N₁,N₂}) where {N₁,N₂} = p
function _canonicalize(p::Index2Tuple, ::AbstractTensorMap)
p′ = linearize(p)
p₁ = TupleTools.getindices(p′, codomainind(t))
p₂ = TupleTools.getindices(p′, domainind(t))
return (p₁, p₂)
end

# tensoradd!
function TO.tensoradd!(C::AbstractTensorMap{S},
A::AbstractTensorMap{S}, pC::Index2Tuple, conjA::Symbol,
α::Number, β::Number, backend::Backend...) where {S}
if conjA == :N
A′ = A
pC′ = _canonicalize(pC, C)
elseif conjA == :C
A′ = adjoint(A)
pC′ = adjointtensorindices(A, _canonicalize(pA, C))
else
throw(ArgumentError("unknown conjugation flag $conjA"))
end
# TODO: novel syntax for tensoradd!?
# tensoradd!(C, A′, pC′, α, β, backend...)
add!(α, A′, β, C, pC′[1], pC′[2])
return C
end

function TO.tensoradd_type(TC, ::Index2Tuple{N₁,N₂}, A::AbstractTensorMap{S},
::Symbol) where {S,N₁,N₂}
M = similarstoragetype(A, TC)
return tensormaptype(S, N₁, N₂, M)
end

function TO.tensoradd_structure(pC::Index2Tuple{N₁,N₂},
A::AbstractTensorMap{S}, conjA::Symbol) where {S,N₁,N₂}
if conjA == :N
cod = ProductSpace{S,N₁}(space.(Ref(A), pC[1]))
dom = ProductSpace{S,N₂}(dual.(space.(Ref(A), pC[2])))
return dom cod
else
return TO.tensoradd_structure(adjoint(A), adjointtensorindices(A, pC), :N)
end
end

# tensortrace!
function TO.tensortrace!(C::AbstractTensorMap{S}, pC::Index2Tuple,
A::AbstractTensorMap{S}, qA::Index2Tuple, conjA::Symbol,
α::Number, β::Number, backend::Backend...) where {S}
if conjA == :N
A′ = A
pC′ = _canonicalize(pC, C)
qA′ = qA
elseif conjA == :C
A′ = adjoint(A)
pC′ = adjointtensorindices(A, _canonicalize(pC, C))
qA′ = adjointtensorindices(A, qA)
else
throw(ArgumentError("unknown conjugation flag $conjA"))
end
# TODO: novel syntax for tensortrace?
# tensortrace!(C, pC′, A′, qA′, α, β, backend...)
trace!(α, A′, β, C, pC′[1], pC′[2], qA′[1], qA′[2])
return C
end

# tensorcontract!
function TO.tensorcontract!(C::AbstractTensorMap{S,N₁,N₂}, pC::Index2Tuple,
A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number, backend::Backend...) where {S,N₁,N₂}
pC′ = _canonicalize(pC, C)
if conjA == :N
A′ = A
pA′ = pA
elseif conjA == :C
A′ = A'
pA′ = adjointtensorindices(A, pA)
else
throw(ArgumentError("unknown conjugation flag $conjA"))
end
if conjB == :N
B′ = B
pB′ = pB
elseif conjB == :C
B′ = B'
pB′ = adjointtensorindices(B, pB)
else
throw(ArgumentError("unknown conjugation flag $conjB"))
end
# TODO: novel syntax for tensorcontract?
# tensorcontract!(C, pC′, A′, pA′, B′, pB′, α, β, backend...)
contract!(α, A′, B′, β, C, pA′[1], pA′[2], pB′[2], pB′[1], pC′[1], pC′[2])
return C
end

function TO.tensorcontract_type(TC, ::Index2Tuple{N₁,N₂},
A::AbstractTensorMap{S}, pA, conjA,
B::AbstractTensorMap{S}, pB, conjB) where {S,N₁,N₂}
M = similarstoragetype(A, TC)
M == similarstoragetype(B, TC) || throw(ArgumentError("incompatible storage types"))
return tensormaptype(S, N₁, N₂, M)
end

function TO.tensorcontract_structure(pC::Index2Tuple{N₁,N₂},
A::AbstractTensorMap{S}, pA::Index2Tuple, conjA,
B::AbstractTensorMap{S}, pB::Index2Tuple, conjB) where {S,N₁,N₂}

spaces1 = TO.flag2op(conjA).(space.(Ref(A), pA[1]))
spaces2 = TO.flag2op(conjB).(space.(Ref(B), pB[2]))
spaces = (spaces1..., spaces2...)
cod = ProductSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
dom = ProductSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
return dom cod
end

# Actual implementations
function cached_permute(sym::Symbol, t::TensorMap{S},
p1::IndexTuple{N₁}, p2::IndexTuple{N₂}=();
copy::Bool=false) where {S,N₁,N₂}
Expand Down Expand Up @@ -202,7 +342,8 @@ function trace!(α, tsrc::AbstractTensorMap{S}, β, tdst::AbstractTensorMap{S,N
coeff *= twist(g1.uncoupled[i])
end
end
TO.tensortrace!(tdst[f₁′′, f₂′′], (p1, p2), tsrc[f₁, f₂], (q1, q2), :N, α*coeff, true)
TO.tensortrace!(tdst[f₁′′, f₂′′], (p1, p2), tsrc[f₁, f₂], (q1, q2), :N,
α * coeff, true)
end
end
end
Expand Down Expand Up @@ -323,186 +464,3 @@ function scalar(t::AbstractTensorMap{S}) where {S<:IndexSpace}
return dim(codomain(t)) == dim(domain(t)) == 1 ?
first(blocks(t))[2][1, 1] : throw(DimensionMismatch())
end

TO.tensorscalar(t::AbstractTensorMap) = scalar(t)

function TO.tensoradd!(tdst::AbstractTensorMap{S},
tsrc::AbstractTensorMap{S}, pA::Index2Tuple,
conjA::Symbol, α::Number, β::Number) where {S}
if conjA == :N
p = linearize(pA)
pl = TupleTools.getindices(p, codomainind(tdst))
pr = TupleTools.getindices(p, domainind(tdst))
add!(α, tsrc, β, tdst, pl, pr)
else
p = adjointtensorindices(tsrc, linearize(pA))
pl = TupleTools.getindices(p, codomainind(tdst))
pr = TupleTools.getindices(p, domainind(tdst))
add!(α, adjoint(tsrc), β, tdst, pl, pr)
end
return tdst
end

function TO.tensortrace!(tdst::AbstractTensorMap{S},
pC::Index2Tuple, tsrc::AbstractTensorMap{S},
pA::Index2Tuple, conjA::Symbol, α::Number,
β::Number) where {S}
if conjA == :N
p = linearize(pC)
pl = TupleTools.getindices(p, codomainind(tdst))
pr = TupleTools.getindices(p, domainind(tdst))
trace!(α, tsrc, β, tdst, pl, pr, pA[1], pA[2])
else
p = adjointtensorindices(tsrc, linearize(pC))
pl = TupleTools.getindices(p, codomainind(tdst))
pr = TupleTools.getindices(p, domainind(tdst))
q1 = adjointtensorindices(tsrc, pA[1])
q2 = adjointtensorindices(tsrc, pA[2])
trace!(α, adjoint(tsrc), β, tdst, pl, pr, q1, q2)
end
return tdst
end

# # function TO.similarstructure_from_indices(T::Type, p1::IndexTuple, p2::IndexTuple,
# # A::AbstractTensorMap, CA::Symbol=:N)
# # if CA == :N
# # _similarstructure_from_indices(T, p1, p2, A)
# # else
# # p1 = adjointtensorindices(A, p1)
# # p2 = adjointtensorindices(A, p2)
# # _similarstructure_from_indices(T, p1, p2, adjoint(A))
# # end
# # end

# # function TO.similarstructure_from_indices(T::Type, poA::IndexTuple, poB::IndexTuple,
# # p1::IndexTuple, p2::IndexTuple,
# # A::AbstractTensorMap, B::AbstractTensorMap,
# # CA::Symbol=:N, CB::Symbol=:N)
# # if CA == :N && CB == :N
# # _similarstructure_from_indices(T, poA, poB, p1, p2, A, B)
# # elseif CA == :C && CB == :N
# # poA = adjointtensorindices(A, poA)
# # _similarstructure_from_indices(T, poA, poB, p1, p2, adjoint(A), B)
# # elseif CA == :N && CB == :C
# # poB = adjointtensorindices(B, poB)
# # _similarstructure_from_indices(T, poA, poB, p1, p2, A, adjoint(B))
# # else
# # poA = adjointtensorindices(A, poA)
# # poB = adjointtensorindices(B, poB)
# # _similarstructure_from_indices(T, poA, poB, p1, p2, adjoint(A), adjoint(B))
# # end
# # end

# function _similarstructure_from_indices(::Type{T}, p1::IndexTuple{N₁}, p2::IndexTuple{N₂},
# t::AbstractTensorMap{S}) where {T,S<:IndexSpace,N₁,
# N₂}
# cod = ProductSpace{S,N₁}(space.(Ref(t), p1))
# dom = ProductSpace{S,N₂}(dual.(space.(Ref(t), p2)))
# return dom → cod
# end
# function _similarstructure_from_indices(::Type{T}, oindA::IndexTuple, oindB::IndexTuple,
# p1::IndexTuple{N₁}, p2::IndexTuple{N₂},
# tA::AbstractTensorMap{S},
# tB::AbstractTensorMap{S}) where {T,S<:IndexSpace,N₁,
# N₂}
# spaces = (space.(Ref(tA), oindA)..., space.(Ref(tB), oindB)...)
# cod = ProductSpace{S,N₁}(getindex.(Ref(spaces), p1))
# dom = ProductSpace{S,N₂}(dual.(getindex.(Ref(spaces), p2)))
# return dom → cod
# end

function TO.tensorcontract!(C::AbstractTensorMap{S,N₁,N₂},
pC::Index2Tuple,
A::AbstractTensorMap{S}, pA::Index2Tuple,
conjA::Symbol,
B::AbstractTensorMap{S}, pB::Index2Tuple,
conjB::Symbol,
α::Number, β::Number) where {S,N₁,N₂}
p = linearize(pC)
pl = ntuple(n -> p[n], N₁)
pr = ntuple(n -> p[N₁ + n], N₂)

if conjA == :C
pA = adjointtensorindices(A, pA)
A = A'
elseif conjA != :N
throw(ArgumentError("unknown conjugation flag $conjA"))
end

if conjB == :C
pB = adjointtensorindices(B, pB)
B = B'
elseif conjB != :N
throw(ArgumentError("unknown conjugation flag $conjB"))
end

contract!(α, A, B, β, C, pA[1], pA[2], pB[2], pB[1], pl, pr)
return C

# if conjA == :N && conjB == :N
# contract!(α, tA, tB, β, tC, pA[1], pA[2], pB[2], pB[1], pl, pr)
# elseif conjA == :N && conjB == :C
# pB[2] = adjointtensorindices(tB, pB[2])
# pB[1] = adjointtensorindices(tB, pB[1])
# contract!(α, tA, tB', β, tC, pA[1], pA[2], pB[2], pB[1], pl, pr)
# elseif conjA == :C && conjB == :N
# pA[1] = adjointtensorindices(tA, pA[1])
# pA[2] = adjointtensorindices(tA, pA[2])
# contract!(α, tA', tB, β, tC, pA[1], pA[2], pB[2], pB[1], pl, pr)
# elseif conjA == :C && conjB == :C
# pA[1] = adjointtensorindices(tA, pA[1])
# pA[2] = adjointtensorindices(tA, pA[2])
# pB[2] = adjointtensorindices(tB, pB[2])
# pB[1] = adjointtensorindices(tB, pB[1])
# contract!(α, tA', tB', β, tC, pA[1], pA[2], pB[2], pB[1], pl, pr)
# else
# error("unknown conjugation flags: $conjA and $conjB")
# end
# return tC
end

function TO.tensoradd_type(TC, ::Index2Tuple{N₁,N₂}, ::AbstractTensorMap{S},
::Symbol) where {S,N₁,N₂}
return tensormaptype(S, N₁, N₂, TC)
end

function TO.tensoradd_structure(pC::Index2Tuple{N₁,N₂}, A::AbstractTensorMap{S},
conjA::Symbol) where {S,N₁,N₂}
if conjA == :N
cod = ProductSpace{S,N₁}(space.(Ref(A), pC[1]))
dom = ProductSpace{S,N₂}(dual.(space.(Ref(A), pC[2])))
return dom cod
else
return TO.tensoradd_structure(adjoint(A), adjointtensorindices(A, pC), :N)
end
end

function TO.tensorcontract_type(TC, ::Index2Tuple{N₁,N₂},
::AbstractTensorMap{S}, pA, conjA,
::AbstractTensorMap{S}, pB, conjB) where {S,N₁,N₂}
return tensormaptype(S, N₁, N₂, TC)
end

function TO.tensorcontract_structure(pC::Index2Tuple{N₁,N₂},
A::AbstractTensorMap{S}, pA::Index2Tuple,
conjA, B::AbstractTensorMap,
pB::Index2Tuple, conjB) where {S,N₁,N₂}
spaces1 = conjA == :N ? space.(Ref(A), pA[1]) :
space.(Ref(A'), adjointtensorindices(A, pA[1]))
spaces2 = conjB == :N ? space.(Ref(B), pB[2]) :
space.(Ref(B'), adjointtensorindices(B, pB[2]))
spaces = (spaces1..., spaces2...)

cod = ProductSpace{S,N₁}(getindex.(Ref(spaces), pC[1]))
dom = ProductSpace{S,N₂}(dual.(getindex.(Ref(spaces), pC[2])))
return dom cod
end

TO.tensorstructure(t::AbstractTensorMap) = space(t)
function TO.tensorstructure(::AbstractTensorMap, iA::Int, conjA::Symbol)
return conjA == :N ? space(A, iA) : space(A', iA)
end

function TO.tensoralloc(ttype::Type{<:AbstractTensorMap}, structure, istemp=false)
return TensorMap(undef, scalartype(ttype), structure)
end

0 comments on commit 8ae2a65

Please sign in to comment.