diff --git a/src/tensors/planar.jl b/src/tensors/planar.jl index 346e75513..8c8c43c58 100644 --- a/src/tensors/planar.jl +++ b/src/tensors/planar.jl @@ -219,7 +219,7 @@ function _add_modules(ex::Expr) (ex.args[i] for i in argind)...) elseif ex.head == :call && ex.args[1] == :trace! @assert ex.args[4] == :(:N) - argind = [2,4,5,6,7,8,9,10] + argind = [2,3,5,6,7,8,9,10] return Expr(ex.head, GlobalRef(TensorKit, Symbol(:planar_trace!)), (ex.args[i] for i in argind)...) elseif ex.head == :call && ex.args[1] == :contract! @@ -322,3 +322,47 @@ function planar_contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S}, # end return C end + +function planar_trace!(α, tsrc::AbstractTensorMap{S}, + β, tdst::AbstractTensorMap{S, N₁, N₂}, + p1::IndexTuple{N₁}, p2::IndexTuple{N₂}, + q1::IndexTuple{N₃}, q2::IndexTuple{N₃}) where {S, N₁, N₂, N₃} + if BraidingStyle(sectortype(S)) == Bosonic() + return trace!(α, tsrc, β, tdst, p1, p2, q1, q2) + end + + @boundscheck begin + all(i->space(tsrc, p1[i]) == space(tdst, i), 1:N₁) || + throw(SpaceMismatch("trace: tsrc = $(codomain(tsrc))←$(domain(tsrc)), + tdst = $(codomain(tdst))←$(domain(tdst)), p1 = $(p1), p2 = $(p2)")) + all(i->space(tsrc, p2[i]) == space(tdst, N₁+i), 1:N₂) || + throw(SpaceMismatch("trace: tsrc = $(codomain(tsrc))←$(domain(tsrc)), + tdst = $(codomain(tdst))←$(domain(tdst)), p1 = $(p1), p2 = $(p2)")) + all(i->space(tsrc, q1[i]) == dual(space(tsrc, q2[i])), 1:N₃) || + throw(SpaceMismatch("trace: tsrc = $(codomain(tsrc))←$(domain(tsrc)), + q1 = $(q1), q2 = $(q2)")) + end + + cod = codomain(tsrc) + dom = domain(tsrc) + n = length(cod) + pdata = (p1..., p2...) + if iszero(β) + fill!(tdst, β) + elseif β != 1 + mul!(tdst, β, tdst) + end + r1 = (p1..., q1...) + r2 = (p2..., q2...) + for (f1, f2) in fusiontrees(tsrc) + for ((f1′, f2′), coeff) in transpose(f1, f2, r1, r2) + f1′′, g1 = split(f1′, N₁) + f2′′, g2 = split(f2′, N₂) + if g1 == g2 + coeff *= dim(g1.coupled)/dim(g1.uncoupled[1]) + TO._trace!(α*coeff, tsrc[f1, f2], true, tdst[f1′′, f2′′], pdata, q1, q2) + end + end + end + return tdst +end diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index ab891680f..3c97d3d36 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -177,7 +177,7 @@ function trace!(α, tsrc::AbstractTensorMap{S}, β, tdst::AbstractTensorMap{S, N tdst = $(codomain(tdst))←$(domain(tdst)), p1 = $(p1), p2 = $(p2)")) all(i->space(tsrc, q1[i]) == dual(space(tsrc, q2[i])), 1:N₃) || throw(SpaceMismatch("trace: tsrc = $(codomain(tsrc))←$(domain(tsrc)), - q1 = $(p1), q2 = $(q2)")) + q1 = $(q1), q2 = $(q2)")) end I = sectortype(S)