diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 9f67c1306..1b0e26ab0 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -522,9 +522,11 @@ const TransposeKey{I<:Sector, N₁, N₂} = Tuple{<:FusionTree{I}, <:FusionTree{ function _transpose((f1, f2, p1, p2)::TransposeKey{I,N₁,N₂}) where {I<:Sector, N₁, N₂} N = N₁ + N₂ p = linearizepermutation(p1, p2, length(f1), length(f2)) + newtrees = repartition(f1, f2, N₁) + length(p) == 0 && return newtrees i1 = findfirst(==(1), p) @assert i1 !== nothing - newtrees = repartition(f1, f2, N₁) + i1 == 1 && return newtrees Nhalf = N >> 1 while 1 < i1 <= Nhalf local newtrees′ @@ -766,7 +768,38 @@ function artin_braid(f::FusionTree{I, N}, i; inv::Bool = false) where {I<:Sector inner = f.innerlines vertices = f.vertices u = one(I) - oneT = one(eltype(Rsymbol(u,u,u))) * one(eltype(Fsymbol(u,u,u,u,u,u))) + + if BraidingStyle(I) isa NoBraiding + oneT = one(eltype(Fsymbol(u,u,u,u,u,u))) + else + oneT = one(eltype(Rsymbol(u,u,u))) * one(eltype(Fsymbol(u,u,u,u,u,u))) + end + + if u in (uncoupled[i],uncoupled[i+1]) # the braid simplifies drastically, we are braiding a trivial sector + a, b = uncoupled[i], uncoupled[i+1] + uncoupled′ = TupleTools.setindex(uncoupled,b,i); + uncoupled′ = TupleTools.setindex(uncoupled′,a,i+1); + vertices′ = vertices; + + if i > 1 #we also need to alter innerlines and vertices + incharges = (uncoupled[1],inner...,coupled′); + + vertices′ = TupleTools.setindex(vertices′,vertices[i],i-1) + vertices′ = TupleTools.setindex(vertices′,vertices[i-1],i) + + if a == u + inner = TupleTools.setindex(inner,incharges[i+1],i-1); + else + inner = TupleTools.setindex(inner,incharges[i-1],i-1); + end + end + + f′ = FusionTree{I}(uncoupled′, coupled′, isdual′, inner, vertices′) + return fusiontreedict(I)(f′ => oneT) + end + + BraidingStyle(I) isa NoBraiding && throw(SectorMismatch("cannot braid sector "*type_repr(I))) + if i == 1 a, b = uncoupled[1], uncoupled[2] c = N > 2 ? inner[1] : coupled′ diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index ac75b14a5..4422f4672 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -65,21 +65,26 @@ function fusiontrees(b::BraidingTensor) end end dim2 = offset2 - push!(data, c=>f((dim1, dim2))) push!(rowr, c=>rowrc) push!(colr, c=>colrc) end return TensorKeyIterator(rowr, colr) end +function Base.getindex(b::BraidingTensor{S}) where S + sectortype(S) == Trivial || throw(SectorMismatch()) + (V1,V2) = domain(b); + return reshape(storagetype(b)(LinearAlgebra.I, dim(V1)*dim(V2), dim(V1)*dim(V2)),(dim(V2),dim(V1),dim(V1),dim(V2))); +end + @inline function Base.getindex(b::BraidingTensor, f1::FusionTree{I,2}, f2::FusionTree{I,2}) where {I<:Sector} I == sectortype(b) || throw(SectorMismatch()) c = f1.coupled V1, V2 = domain(b) @boundscheck begin c == f2.coupled || throw(SectorMismatch()) - ((f1.uncoupled[1] ∈ V2) && (f2.uncoupled[1] ∈ V1)) || throw(SectorMismatch()) - ((f1.uncoupled[2] ∈ V1) && (f2.uncoupled[2] ∈ V2)) || throw(SectorMismatch()) + ((f1.uncoupled[1] ∈ sectors(V2)) && (f2.uncoupled[1] ∈ sectors(V1))) || throw(SectorMismatch()) + ((f1.uncoupled[2] ∈ sectors(V1)) && (f2.uncoupled[2] ∈ sectors(V2))) || throw(SectorMismatch()) end @inbounds begin d = (dims(V2 ⊗ V1, f1.uncoupled)..., dims(V1 ⊗ V2, f2.uncoupled)...) @@ -89,10 +94,11 @@ end data = fill!(storagetype(b)(undef, (n1, n2)), zero(eltype(b))) if f1.uncoupled == (a2, a1) braiddict = artin_braid(f2, 1; inv = b.adjoint) - r = get(dict, f1, zero(valtype(braiddict))) + r = get(braiddict, f1, zero(valtype(braiddict))) data[1:(n1+1):end] .= r # set diagonal end - return permutedims(sreshape(StridedView(data), d), (1,2,4,3)) + + return sreshape(StridedView(data), d) end end @@ -101,6 +107,11 @@ function Base.copy!(t::TensorMap, b::BraidingTensor) space(t) == space(b) || throw(SectorMismatch()) fill!(t, zero(eltype(t))) for (f1, f2) in fusiontrees(t) + if f1 == nothing || f2 == nothing + _one!(t.data) + return t + end + a1, a2 = f2.uncoupled c = f2.coupled f1.uncoupled == (a2, a1) || continue @@ -128,3 +139,300 @@ function block(b::BraidingTensor, s::Sector) # a1, a2 = f2.uncoupled # end + +function planar_contract!(α, A::BraidingTensor, B::AbstractTensorMap{S}, + β, C::AbstractTensorMap{S}, + oindA::IndexTuple{2}, cindA::IndexTuple{2}, + oindB::IndexTuple{N₂}, cindB::IndexTuple, + p1::IndexTuple, p2::IndexTuple, + syms::Union{Nothing, NTuple{3, Symbol}}) where {S, N₂} + codA = codomainind(A) + domA = domainind(A) + codB = codomainind(B) + domB = domainind(B) + oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) + + braidingtensor_levels = A.adjoint ? (1,2,2,1) : (2,1,1,2) + + if iszero(β) + fill!(C, β) + elseif β != 1 + rmul!(C, β) + end + + for (f1, f2) in fusiontrees(B) + if f1 == nothing && f2 == nothing + return TO.add!(α, B,:N,true,C, reverse(cindB),oindB); + end + + fmap = Dict{Tuple{typeof(f1),typeof(f2)},eltype(f1)}(); + + + braid_above = braidingtensor_levels[cindA[1]] > braidingtensor_levels[cindA[2]]; + for ((f1′,f2′),coeff) in transpose(f1,f2,cindB,oindB), + (f1′′,coeff′) in artin_braid(f1′,1,inv = braid_above) + + nk = (f1′′,f2′); + nv = coeff′*coeff; + + fmap[nk] = get(fmap,nk,zero(nv)) + nv; + end + + for ((f1′,f2′),c) in fmap + TO._add!(c*α, B[f1, f2], true,C[f1′,f2′], (reverse(cindB)...,oindB...)); + end + end + C +end +function planar_contract!(α, A::AbstractTensorMap{S}, B::BraidingTensor, + β, C::AbstractTensorMap{S}, + oindA::IndexTuple{N₁}, cindA::IndexTuple{N₁}, + oindB::IndexTuple{2}, cindB::IndexTuple{2}, + p1::IndexTuple, p2::IndexTuple, + syms::Union{Nothing, NTuple{3, Symbol}}) where {S, N₁} + codA = codomainind(A) + domA = domainind(A) + codB = codomainind(B) + domB = domainind(B) + oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) + + braidingtensor_levels = B.adjoint ? (1,2,2,1) : (2,1,1,2); + + if iszero(β) + fill!(C, β) + elseif β != 1 + rmul!(C, β) + end + + for (f1, f2) in fusiontrees(A) + if f1 == nothing && f2 == nothing + return TO.add!(α, A,:N,true,C, oindA,reverse(cindA)); + end + + fmap = Dict{Tuple{typeof(f1),typeof(f2)},eltype(f1)}(); + + braid_above = braidingtensor_levels[cindB[1]] > braidingtensor_levels[cindB[2]]; + + for ((f1′,f2′),coeff) in transpose(f1,f2,oindA,cindA), + (f2′′,coeff′) in artin_braid(f2′,1,inv = braid_above) + + nk = (f1′,f2′′); + nv = coeff′*coeff; + + fmap[nk] = get(fmap,nk,zero(nv)) + nv; + end + + for ((f1′,f2′),c) in fmap + TO._add!(c*α, A[f1, f2], true,C[f1′,f2′], (oindA...,reverse(cindA)...)); + end + end + C +end + +function planar_contract!(α, A::BraidingTensor, B::AbstractTensorMap{S}, + β, C::AbstractTensorMap{S}, + oindA::IndexTuple{0}, cindA::IndexTuple{4}, + oindB::IndexTuple{N₂}, cindB::IndexTuple, + p1::IndexTuple, p2::IndexTuple, + syms::Union{Nothing, NTuple{3, Symbol}}) where {S, N₂} + + braidingtensor_levels = A.adjoint ? (1,2,2,1) : (2,1,1,2); + codA = codomainind(A) + domA = domainind(A) + codB = codomainind(B) + domB = domainind(B) + + oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) + + if iszero(β) + fill!(C, β) + elseif β != 1 + rmul!(C, β) + end + + for (f1, f2) in fusiontrees(B) + if f1 == nothing && f2 == nothing + TO.trace!(α, B, :N,true,C, oindB, (),(cindB[1],cindB[2]), (cindB[3],cindB[4])) + break; + end + + local fmap; + + braid_above = braidingtensor_levels[cindA[2]] > braidingtensor_levels[cindA[3]]; + + for ((f1′,f2′),coeff) in transpose(f1,f2,cindB,oindB), + (f1′′,coeff′) in artin_braid(f1′,2,inv = braid_above), + (f1_tr1,c_tr1) in elementary_trace(f1′′, 1), + (f1_tr2,c_tr2) in elementary_trace(f1_tr1,1) + nk = (f1_tr2,f2′); + nv = coeff*coeff′*c_tr1*c_tr2*α; + if @isdefined fmap + fmap[nk] = get(fmap,nk,zero(nv)) + nv; + else + fmap = Dict(nk=>nv); + end + + end + + for ((f1′,f2′),c) in fmap + TO._trace!(c, B[f1, f2], true,C[f1′,f2′], oindB, (cindB[1],cindB[2]), (cindB[3],cindB[4])) + end + end + + C +end + +function planar_contract!(α, A::AbstractTensorMap{S}, B::BraidingTensor, + β, C::AbstractTensorMap{S}, + oindA::IndexTuple{N₁}, cindA::IndexTuple, + oindB::IndexTuple{0}, cindB::IndexTuple{4}, + p1::IndexTuple, p2::IndexTuple, + syms::Union{Nothing, NTuple{3, Symbol}}) where {S, N₁} + + braidingtensor_levels = B.adjoint ? (1,2,2,1) : (2,1,1,2); + + codA = codomainind(A) + domA = domainind(A) + codB = codomainind(B) + domB = domainind(B) + + oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) + + + if iszero(β) + fill!(C, β) + elseif β != 1 + rmul!(C, β) + end + + braid_above = braidingtensor_levels[cindB[2]] > braidingtensor_levels[cindB[3]]; + + for (f1, f2) in fusiontrees(A) + if f1 == nothing && f2 == nothing + TO.trace!(α, A, :N,true,C, oindA , (),(cindA[1],cindA[2]), (cindA[3],cindA[4])) + break; + end + + local fmap; + for ((f1′,f2′),coeff) in transpose(f1,f2,oindA,cindA), + (f2′′,coeff′) in artin_braid(f2′,2,inv = braid_above), + (f2_tr1,c_tr1) in elementary_trace(f2′′, 1), + (f2_tr2,c_tr2) in elementary_trace(f2_tr1,1) + + nk = (f1′,f2_tr2); + nv = coeff*coeff′*c_tr1*c_tr2*α; + if @isdefined fmap + fmap[nk] = get(fmap,nk,zero(nv)) + nv; + else + fmap = Dict(nk=>nv); + end + end + + for ((f1′,f2′),c) in fmap + TO._trace!(c, A[f1, f2], true,C[f1′,f2′], oindA , (cindA[1],cindA[2]), (cindA[3],cindA[4])) + end + + end + C +end + +function planar_contract!(α, A::BraidingTensor, B::AbstractTensorMap{S}, + β, C::AbstractTensorMap{S}, + oindA::IndexTuple{1}, cindA::IndexTuple{3}, + oindB::IndexTuple{N₂}, cindB::IndexTuple, + p1::IndexTuple, p2::IndexTuple, + syms::Union{Nothing, NTuple{3, Symbol}}) where {S, N₁, N₂} + + braidingtensor_levels = A.adjoint ? (1,2,2,1) : (2,1,1,2); + + codA = codomainind(A) + domA = domainind(A) + codB = codomainind(B) + domB = domainind(B) + + oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) + + for (f1, f2) in fusiontrees(B) + if f1 == nothing && f2 == nothing + TO.trace!(α, B,:N, true,C, (cindB[2],), oindB, (cindB[1],), (cindB[3],)) + break; + end + + local fmap + braid_above = braidingtensor_levels[cindA[1]] > braidingtensor_levels[cindA[2]]; + + for ((f1′,f2′),coeff) in transpose(f1,f2,cindB,oindB), + (f1′′,coeff′) in artin_braid(f1′,1,inv = braid_above), + (f1′′′,coeff′′) in elementary_trace(f1′′, 2) + nk = (f1′′′,f2′); + nv = coeff*coeff′*coeff′′*α + if @isdefined fmap + fmap[nk] = get(fmap,nk,zero(nv)) + nv + else + fmap = Dict(nk => nv) + end + end + + for ((f1′,f2′),c) in fmap + TO._trace!(c, B[f1, f2], true,C[f1′,f2′], (cindB[2],oindB...) , (cindB[1],), (cindB[3],)) + end + end + + + C +end + +function planar_contract!(α, A::AbstractTensorMap{S}, B::BraidingTensor, + β, C::AbstractTensorMap{S}, + oindA::IndexTuple{N₁}, cindA::IndexTuple, + oindB::IndexTuple{1}, cindB::IndexTuple{3}, + p1::IndexTuple, p2::IndexTuple, + syms::Union{Nothing, NTuple{3, Symbol}}) where {S, N₁} + braidingtensor_levels = B.adjoint ? (1,2,2,1) : (2,1,1,2); + + codA = codomainind(A) + domA = domainind(A) + codB = codomainind(B) + domB = domainind(B) + + oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) + + if iszero(β) + fill!(C, β) + elseif β != 1 + rmul!(C, β) + end + + for (f1, f2) in fusiontrees(A) + if f1 == nothing && f2 == nothing + + TO.trace!(α, A,:N, true,C, oindA, (cindA[2],) , (cindA[1],), (cindA[3],)) + break; + end + + braid_above = braidingtensor_levels[cindB[1]] > braidingtensor_levels[cindB[2]]; + + local fmap + + for ((f1′,f2′),coeff) in transpose(f1,f2,oindA,cindA), + (f2′′,coeff′) in artin_braid(f2′,1,inv = braid_above), + (f2′′′,coeff′′) in elementary_trace(f2′′, 2) + + nk = (f1′,f2′′′); + nv = coeff*coeff′*coeff′′*α; + + if @isdefined fmap + fmap[nk] = get(fmap,nk,zero(nv)) + nv + else + fmap = Dict(nk => nv); + end + + end + + for ((f1′,f2′),c) in fmap + TO._trace!(c, A[f1, f2], true,C[f1′,f2′], (oindA...,cindA[2]) , (cindA[1],), (cindA[3],)) + end + end + + C +end diff --git a/src/tensors/planar.jl b/src/tensors/planar.jl index 30e865558..7ffbaa91b 100644 --- a/src/tensors/planar.jl +++ b/src/tensors/planar.jl @@ -67,8 +67,10 @@ end _conj_to_adjoint(ex) = ex function get_possible_planar_indices(ex::Expr) - @assert TO.istensorexpr(ex) - if TO.isgeneraltensor(ex) + #@assert TO.istensorexpr(ex) + if !TO.istensorexpr(ex) + return [[]] + elseif TO.isgeneraltensor(ex) _,leftind,rightind = TO.decomposegeneraltensor(ex) ind = planar_unique2(vcat(leftind, reverse(rightind))) return length(ind) == length(unique(ind)) ? Any[ind] : Any[] @@ -270,6 +272,17 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) a1 = _extract_contraction_pairs(rhs.args[2], (oind1, reverse(cind1)), pre, temporaries) a2 = _extract_contraction_pairs(rhs.args[3], (cind2, reverse(oind2)), pre, temporaries) end + + if TO.isscalarexpr(a1) || TO.isscalarexpr(a2) + rhs = Expr(:call, :*, a1, a2) + s = gensym() + newlhs = Expr(:typed_vcat, s, Expr(:tuple, oind1...), + Expr(:tuple, reverse(oind2)...)) + push!(temporaries, s) + push!(pre, Expr(:(:=), newlhs, rhs)) + return newlhs + end + # note that index order in _extract... is only a suggestion, now we have actual index order _, l1, r1, = TO.decomposegeneraltensor(a1) _, l2, r2, = TO.decomposegeneraltensor(a2) @@ -307,6 +320,9 @@ function _extract_contraction_pairs(rhs, lhs, pre, temporaries) args = [_extract_contraction_pairs(a, lhs, pre, temporaries) for a in rhs.args[2:end]] return Expr(rhs.head, rhs.args[1], args...) + elseif TO.isscalarexpr(rhs) + #do nothing? + return rhs else throw(ArgumentError("unknown tensor expression")) end @@ -409,7 +425,7 @@ function _construct_braidingtensors(ex::Expr) lhs, rhs = TO.getlhs(ex), TO.getrhs(ex) if TO.istensorexpr(rhs) list = TO.gettensors(rhs) - if TO.isassignment(ex) && istensor(lhs) + if TO.isassignment(ex) && TO.istensor(lhs) obj, l, r = TO.decomposetensor(lhs) lhs_as_rhs = Expr(:typed_vcat, Expr(TO.prime, obj), Expr(:tuple, r...), Expr(:tuple, l...)) push!(list, lhs_as_rhs) @@ -434,8 +450,10 @@ function _construct_braidingtensors(ex::Expr) i += 1 end end - pre = Expr(:block) - for (t, construct_expr) in translatebraidings + + unresolved = Any[]; # list of indices that we couldn't yet figure out + indexmaps = Dict{Any,Any}(); # contains the expression to resolve the space at index indexmaps[i] + for (t,construct_expr) in translatebraidings obj, leftind, rightind = TO.decomposetensor(t) length(leftind) == length(rightind) == 2 || throw(ArgumentError("The name τ is reserved for the braiding, and should have two input and two output indices.")) @@ -446,25 +464,71 @@ function _construct_braidingtensors(ex::Expr) i2b, i1b, = leftind i1a, i2a, = rightind end - obj_and_pos = _findindex(i1a, list) - if !isnothing(obj_and_pos) - push!(construct_expr.args, Expr(:call, :space, obj_and_pos...)) + + obj_and_pos1a = _findindex(i1a, list) + obj_and_pos2a = _findindex(i2a, list) + obj_and_pos1b = _findindex(i1b, list) + obj_and_pos2b = _findindex(i2b, list) + + if !isnothing(obj_and_pos1a) + indexmaps[i1a] = Expr(:call, :space, obj_and_pos1a...); + indexmaps[i1b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1a...)); + elseif !isnothing(obj_and_pos1b) + indexmaps[i1a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos1b...)); + indexmaps[i1b] = Expr(:call, :space, obj_and_pos1b...); else - obj_and_pos = _findindex(i1b, list) - isnothing(obj_and_pos) && - throw(ArgumentError("cannot determine space of index $i1a and $i1b of braiding tensor")) - push!(construct_expr.args, Expr(TO.prime, Expr(:call, :space, obj_and_pos...))) + push!(unresolved,(i1a,i1b)); end - obj_and_pos = _findindex(i2a, list) - if !isnothing(obj_and_pos) - push!(construct_expr.args, Expr(:call, :space, obj_and_pos...)) + if !isnothing(obj_and_pos2a) + indexmaps[i2a] = Expr(:call, :space, obj_and_pos2a...); + indexmaps[i2b] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2a...)) + elseif !isnothing(obj_and_pos2b) + indexmaps[i2a] = Expr(TO.prime, Expr(:call, :space, obj_and_pos2b...)) + indexmaps[i2b] = Expr(:call, :space, obj_and_pos2b...); else - obj_and_pos = _findindex(i2b, list) - isnothing(obj_and_pos) && - throw(ArgumentError("cannot determine space of index $i2a and $i2b of braiding tensor")) - push!(construct_expr.args, Expr(TO.prime, Expr(:call, :space, obj_and_pos...))) + push!(unresolved,(i2a,i2b)); + end + end + + # very simple loop that tries to resolve as many indices as possible + changed = true; + while changed == true + changed = false; + + i = 1; + while i<=length(unresolved) + (i1,i2) = unresolved[i]; + if i1 in keys(indexmaps) + changed = true; + indexmaps[i2] = Expr(TO.prime,indexmaps[i1]); + deleteat!(unresolved,i); + elseif i2 in keys(indexmaps) + changed = true; + indexmaps[i1] = Expr(TO.prime,indexmaps[i2]); + deleteat!(unresolved,i); + else + i+=1 + end end + end + + !isempty(unresolved) && throw(ArgumentError("cannot determine the spaces of indices $(tuple(unresolved...)) for the braiding tensors in $(ex)")); + + pre = Expr(:block) + for (t, construct_expr) in translatebraidings + obj, leftind, rightind = TO.decomposetensor(t) + if _is_adjoint(obj) + i1b, i2b, = leftind + i2a, i1a, = rightind + else + i2b, i1b, = leftind + i1a, i2a, = rightind + end + + push!(construct_expr.args, indexmaps[i1a]); + push!(construct_expr.args, indexmaps[i2a]); + s = gensym() push!(pre.args, Expr(:(=), s, construct_expr)) ex = TO.replacetensorobjects(ex) do o, l, r @@ -500,11 +564,15 @@ function _update_temporaries(ex, temporaries) i = findfirst(==(lhs), temporaries) if i !== nothing rhs = ex.args[2] - if !(rhs isa Expr && rhs.head == :call && rhs.args[1] == :contract!) + if rhs isa Expr && rhs.head == :call && rhs.args[1] == :add! + newname = rhs.args[6] + temporaries[i] = newname + elseif rhs isa Expr && rhs.head == :call && rhs.args[1] == :contract! + newname = rhs.args[8] + temporaries[i] = newname + else @error "lhs = $lhs , rhs = $rhs" end - newname = rhs.args[8] - temporaries[i] = newname end elseif ex isa Expr for a in ex.args