diff --git a/src/TensorOperations.jl b/src/TensorOperations.jl
index 538265e3..fe7f5f09 100644
--- a/src/TensorOperations.jl
+++ b/src/TensorOperations.jl
@@ -3,7 +3,7 @@ module TensorOperations
 
 export tensorcopy, tensoradd, tensortrace, tensorcontract, tensorproduct, scalar
 export tensorcopy!, tensoradd!, tensortrace!, tensorcontract!, tensorproduct!
-export @tensor
+export @tensor, @tensoropt, @optimalcontractiontree
 
 # Auxiliary functions
 #---------------------
@@ -25,7 +25,12 @@ include("implementation/strides.jl")
 
 # Index notation
 #----------------
+import Base.Iterators.flatten
+
 include("indexnotation/tensormacro.jl")
+include("indexnotation/nconstyle.jl")
+include("indexnotation/poly.jl")
+include("indexnotation/optimize.jl")
 include("indexnotation/indexedobject.jl")
 include("indexnotation/sum.jl")
 include("indexnotation/product.jl")
diff --git a/src/indexnotation/indexedobject.jl b/src/indexnotation/indexedobject.jl
index b2592923..ae1aab3e 100644
--- a/src/indexnotation/indexedobject.jl
+++ b/src/indexnotation/indexedobject.jl
@@ -39,7 +39,7 @@ Base.eltype(::Type{IndexedObject{I,C, A, T}}) where {I,C, A, T} = promote_type(e
     Expr(:block, meta, :($J))
 end
 
-indexify(object, ::Indices{I}) where {I} = IndexedObject{I,:N}(object)
+@inline indexify(object, ::Indices{I}) where {I} = IndexedObject{I,:N}(object)
 
 deindexify(A::IndexedObject{I,:N}, ::Indices{I}) where {I} = A.α == 1 ? A.object : A.α*A.object
 deindexify(A::IndexedObject{I,:C}, ::Indices{I}) where {I} = A.α == 1 ? conj(A.object) : A.α*conj(A.object)
diff --git a/src/indexnotation/nconstyle.jl b/src/indexnotation/nconstyle.jl
new file mode 100644
index 00000000..bb2f6879
--- /dev/null
+++ b/src/indexnotation/nconstyle.jl
@@ -0,0 +1,58 @@
+# check if a list of indices specifies a tensor contraction in ncon style
+function isnconstyle(network::Vector)
+    allindices = Vector{Int}()
+    for ind in network
+        all(i->isa(i, Integer), ind) || return false
+        append!(allindices, ind)
+    end
+    while length(allindices) > 0
+        i = pop!(allindices)
+        if i > 0 # positive labels represent contractions or traces and should appear twice
+            k = findnext(allindices, i, 1)
+            l = findnext(allindices, i, k+1)
+            if k == 0 || l != 0
+                return false
+            end
+            deleteat!(allindices, k)
+        elseif i < 0 # negative labels represent open indices and should appear once
+            findnext(allindices, i, 1) == 0 || return false
+        else # i == 0
+            return false
+        end
+    end
+    return true
+end
+
+function ncontree(network::Vector) # binary contraction tree, pairing tensors on the smallest remaining positive label
+    contractionindices = Vector{Vector{Int}}(length(network))
+    for k = 1:length(network)
+        indices = network[k]
+        # trace indices have already been removed, remove open indices by filtering on positive values
+        contractionindices[k] = filter(i->i>0, indices)
+    end
+    partialtrees = collect(Any, 1:length(network))
+    _ncontree!(partialtrees, contractionindices)
+end
+
+function _ncontree!(partialtrees, contractionindices) # recursive helper: merges two partial trees per call until one remains
+    if length(partialtrees) == 1
+        return partialtrees[1]
+    end
+    if all(isempty, contractionindices) # disconnected network
+        partialtrees[end-1] = (partialtrees[end-1], partialtrees[end])
+        pop!(partialtrees)
+        pop!(contractionindices)
+    else
+        let firstind = minimum(flatten(contractionindices))
+            i1 = findfirst(x->in(firstind,x), contractionindices)
+            i2 = findnext(x->in(firstind,x), contractionindices, i1+1)
+            newindices = unique2(vcat(contractionindices[i1], contractionindices[i2]))
+            newtree = (partialtrees[i1], partialtrees[i2])
+            partialtrees[i1] = newtree
+            deleteat!(partialtrees, i2)
+            contractionindices[i1] = newindices
+            deleteat!(contractionindices, i2)
+        end
+    end
+    _ncontree!(partialtrees, contractionindices)
+end
diff --git a/src/indexnotation/optimize.jl b/src/indexnotation/optimize.jl
new file mode 100644
index 00000000..5a98cfec
--- /dev/null
+++ b/src/indexnotation/optimize.jl
@@ -0,0 +1,259 @@
+function optimaltree(network, optdata::Dict) # returns (tree, cost); indices missing from optdata cost one(costtype)
+    numtensors = length(network)
+    allindices = unique(flatten(network))
+    numindices = length(allindices)
+    costtype = valtype(optdata)
+    allcosts = [get(optdata, i, one(costtype)) for i in allindices]
+    maxcost = prod(allcosts)*maximum(allcosts) + zero(costtype) # add zero for type stability: Power -> Poly
+    tensorcosts = Vector{costtype}(numtensors)
+    for k = 1:numtensors
+        tensorcosts[k] = prod(get(optdata, i, one(costtype)) for i in network[k])
+    end
+    initialcost = min(maxcost, maximum(tensorcosts)^2 + zero(costtype)) # just some arbitrary guess
+
+    if max(numindices, numtensors) <= 32 # the same bitset type stores index sets and tensor sets, so both counts must fit
+        return _optimaltree(UInt32, network, allindices, allcosts, initialcost, maxcost)
+    elseif max(numindices, numtensors) <= 64
+        return _optimaltree(UInt64, network, allindices, allcosts, initialcost, maxcost)
+    elseif max(numindices, numtensors) <= 128
+        return _optimaltree(UInt128, network, allindices, allcosts, initialcost, maxcost)
+    else
+        return _optimaltree(BitVector, network, allindices, allcosts, initialcost, maxcost)
+    end
+end
+
+storeset(::Type{IntSet}, ints, maxint) = sizehint!(IntSet(ints), maxint)
+function storeset(::Type{BitVector}, ints, maxint)
+    set = falses(maxint)
+    set[ints] = true
+    return set
+end
+function storeset(::Type{T}, ints, maxint) where {T<:Unsigned} # bit i-1 set <=> i in the set; maxint is not used here
+    set = zero(T)
+    u = one(T)
+    for i in ints
+        set |= (u<<(i-1))
+    end
+    return set
+end
+_intersect(s1::T, s2::T) where {T<:Unsigned} = s1 & s2
+_intersect(s1::BitVector, s2::BitVector) = s1 .& s2
+_intersect(s1::IntSet, s2::IntSet) = intersect(s1, s2)
+_union(s1::T, s2::T) where {T<:Unsigned} = s1 | s2
+_union(s1::BitVector, s2::BitVector) = s1 .| s2
+_union(s1::IntSet, s2::IntSet) = union(s1, s2)
+_setdiff(s1::T, s2::T) where {T<:Unsigned} = s1 & (~s2)
+_setdiff(s1::BitVector, s2::BitVector) = s1 .& (.~s2)
+_setdiff(s1::IntSet, s2::IntSet) = setdiff(s1, s2)
+_isemptyset(s::Unsigned) = iszero(s)
+_isemptyset(s::BitVector) = !any(s)
+_isemptyset(s::IntSet) = isempty(s)
+
+function computecost(allcosts, ind1::T, ind2::T) where {T<:Unsigned} # product of the costs of every index in ind1 ∪ ind2
+    cost = one(eltype(allcosts))
+    ind = _union(ind1, ind2)
+    n = 1
+    while !iszero(ind)
+        if isodd(ind)
+            cost *= allcosts[n]
+        end
+        ind = ind>>1
+        n += 1
+    end
+    return cost
+end
+function computecost(allcosts, ind1::BitVector, ind2::BitVector)
+    cost = one(eltype(allcosts))
+    ind = _union(ind1, ind2)
+    for n in find(ind)
+        cost *= allcosts[n]
+    end
+    return cost
+end
+function computecost(allcosts, ind1::IntSet, ind2::IntSet)
+    cost = one(eltype(allcosts))
+    ind = _union(ind1, ind2)
+    for n in ind
+        cost *= allcosts[n]
+    end
+    return cost
+end
+
+function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C, maxcost::C) where {T,S,C} # T: set representation for index sets and tensor sets
+    numindices = length(allindices)
+    numtensors = length(network)
+    indexsets = Array{T}(numtensors)
+
+    tabletensor = zeros(Int, (numindices,2))
+    tableindex = zeros(Int, (numindices,2))
+
+    adjacencymatrix=falses(numtensors,numtensors)
+    costfac = maximum(allcosts)
+
+    @inbounds for n = 1:numtensors
+        indn = findin(allindices, network[n])
+        indexsets[n] = storeset(T, indn, numindices)
+        for i in indn
+            if tabletensor[i,1] == 0
+                tabletensor[i,1] = n
+                tableindex[i,1] = findfirst(network[n], allindices[i])
+            elseif tabletensor[i,2] == 0
+                tabletensor[i,2] = n
+                tableindex[i,2] = findfirst(network[n], allindices[i])
+                n1 = tabletensor[i,1]
+                adjacencymatrix[n1,n] = true
+                adjacencymatrix[n,n1] = true
+            else
+                error("no index should appear more than two times")
+            end
+        end
+    end
+    componentlist = connectedcomponents(adjacencymatrix)
+    numcomponent = length(componentlist)
+
+    # generate output structures
+    costlist = Vector{C}(numcomponent)
+    treelist = Vector{Any}(numcomponent)
+    indexlist = Vector{T}(numcomponent)
+
+    # run over components
+    for c=1:numcomponent
+        # find optimal contraction for every component
+        component = componentlist[c]
+        componentsize = length(component)
+        costdict = Array{Dict{T, C}}(componentsize)
+        treedict = Array{Dict{T, Any}}(componentsize)
+        indexdict = Array{Dict{T, T}}(componentsize)
+
+        for k=1:componentsize
+            costdict[k] = Dict{T, C}()
+            treedict[k] = Dict{T, Any}()
+            indexdict[k] = Dict{T, T}()
+        end
+
+        for i in component
+            s = storeset(T, [i], numtensors)
+            costdict[1][s] = zero(C)
+            treedict[1][s] = i
+            indexdict[1][s] = indexsets[i]
+        end
+
+        # run over currentcost
+        currentcost = initialcost
+        previouscost = zero(initialcost)
+        while currentcost <= maxcost
+            nextcost = maxcost
+            # construct all subsets of n tensors that can be constructed with cost <= currentcost
+            for n=2:componentsize
+                # construct subsets by combining two smaller subsets
+                for k = 1:div(n-1,2)
+                    for s1 in keys(costdict[k]), s2 in keys(costdict[n-k])
+                        if _isemptyset(_intersect(s1, s2)) && get(costdict[n], _union(s1, s2), currentcost) > previouscost
+                            ind1 = indexdict[k][s1]
+                            ind2 = indexdict[n-k][s2]
+                            cind = _intersect(ind1, ind2)
+                            if !_isemptyset(cind)
+                                s = _union(s1, s2)
+                                cost = costdict[k][s1] + costdict[n-k][s2] + computecost(allcosts, ind1, ind2)
+                                if cost <= get(costdict[n], s, currentcost)
+                                    costdict[n][s] = cost
+                                    indexdict[n][s] = _setdiff(_union(ind1,ind2), cind)
+                                    treedict[n][s] = (treedict[k][s1], treedict[n-k][s2])
+                                elseif currentcost < cost < nextcost
+                                    nextcost = cost
+                                end
+                            end
+                        end
+                    end
+                end
+                if iseven(n) # treat the case k = n/2 special
+                    k = div(n,2)
+                    it = keys(costdict[k])
+                    state1 = start(it)
+                    while !done(it, state1)
+                        s1, nextstate1 = next(it, state1)
+                        state2 = nextstate1
+                        while !done(it, state2)
+                            s2, nextstate2 = next(it, state2)
+                            if _isemptyset(_intersect(s1, s2)) && get(costdict[n], _union(s1, s2), currentcost) > previouscost
+                                ind1 = indexdict[k][s1]
+                                ind2 = indexdict[k][s2]
+                                cind = _intersect(ind1, ind2)
+                                if !_isemptyset(cind)
+                                    s = _union(s1, s2)
+                                    cost = costdict[k][s1] + costdict[k][s2] + computecost(allcosts, ind1, ind2)
+                                    if cost <= get(costdict[n], s, currentcost)
+                                        costdict[n][s] = cost
+                                        indexdict[n][s] = _setdiff(_union(ind1,ind2), cind)
+                                        treedict[n][s] = (treedict[k][s1], treedict[k][s2])
+                                    elseif currentcost < cost < nextcost
+                                        nextcost = cost
+                                    end
+                                end
+                            end
+                            state2 = nextstate2
+                        end
+                        state1 = nextstate1
+                    end
+                end
+            end
+            if !isempty(costdict[componentsize])
+                break
+            end
+            previouscost = currentcost
+            currentcost = min(maxcost, nextcost*costfac)
+        end
+        if isempty(costdict[componentsize])
+            error("Maxcost $maxcost reached without finding solution") # should be impossible
+        end
+        s = storeset(T, component, numtensors)
+        costlist[c] = costdict[componentsize][s]
+        treelist[c] = treedict[componentsize][s]
+        indexlist[c] = indexdict[componentsize][s]
+    end
+    tree = treelist[1]
+    cost = costlist[1]
+    ind = indexlist[1]
+    for c = 2:numcomponent
+        tree = (tree, treelist[c])
+        cost = cost + costlist[c] + computecost(allcosts, ind, indexlist[c])
+        ind = _union(ind, indexlist[c])
+    end
+    return tree, cost
+end
+
+function connectedcomponents(A::AbstractMatrix{Bool})
+    # For a given adjacency matrix of size n x n, connectedcomponents returns
+    # a list componentlist that contains integer vectors, where every integer
+    # vector groups the indices of the vertices of a connected component of the
+    # graph encoded by A. The number of connected components is given by
+    # length(componentlist).
+    #
+    # Used as auxiliary function to analyze contraction graph in contract.
+
+    n=size(A,1)
+    assert(size(A,2)==n)
+
+    componentlist=Vector{Vector{Int}}()
+    assignedlist=falses((n,))
+
+    for i=1:n
+        if !assignedlist[i]
+            assignedlist[i]=true
+            checklist=[i]
+            currentcomponent=[i]
+            while !isempty(checklist)
+                j=pop!(checklist)
+                for k=find(A[j,:])
+                    if !assignedlist[k]
+                        push!(currentcomponent,k)
+                        push!(checklist,k)
+                        assignedlist[k]=true;
+                    end
+                end
+            end
+            push!(componentlist,currentcomponent)
+        end
+    end
+    return componentlist
+end
diff --git a/src/indexnotation/poly.jl b/src/indexnotation/poly.jl
new file mode 100644
index 00000000..50fc1c4a
--- /dev/null
+++ b/src/indexnotation/poly.jl
@@ -0,0 +1,162 @@
+# lightweight poly for abstract cost estimation
+abstract type AbstractPoly{D,T<:Number} end
+Base.one(x::AbstractPoly) = one(typeof(x))
+Base.zero(x::AbstractPoly) = zero(typeof(x))
+
+function Base.show(io::IO, p::AbstractPoly{D,T}) where {D,T<:Real}
+    N=degree(p)
+    for i=N:-1:0
+        if i > 0
+            print(io,"$(abs(p[i]))*")
+            print(io, "$D")
+            i>1 && print(io,"^$i")
+            print(io, p[i-1]<0 ? " - " : " + ")
+        else
+            print(io,"$(abs(p[i]))")
+        end
+    end
+end
+function Base.show(io::IO, p::AbstractPoly{D,T}) where {D,T<:Complex}
+    N=degree(p)
+    for i=N:-1:0
+        if i>0
+            print(io,"($(p[i]))*")
+            print(io,"$D")
+            i>1 && print(io,"^$i")
+            print(io, " + ")
+        else
+            print(io,"($(p[i]))")
+        end
+    end
+end
+struct Power{D,T} <: AbstractPoly{D,T} # single term coeff*D^N
+    coeff::T
+    N::Int
+end
+degree(p::Power)=p.N
+Base.getindex(p::Power{D,T},i::Int) where {D,T} = (i==p.N ? p.coeff : zero(T))
+Power{D}(coeff::T, N::Int=0) where {D,T} = Power{D,T}(coeff, N)
+
+Base.one(::Type{Power{D,T}}) where {D,T} = Power{D,T}(one(T), 0)
+Base.zero(::Type{Power{D,T}}) where {D,T} = Power{D,T}(zero(T), 0)
+
+Base.convert(::Type{Power{D}}, coeff::Number) where {D} = Power{D}(coeff, 0)
+Base.convert(::Type{Power{D,T}}, coeff::Number) where {D,T} = Power{D,T}(coeff, 0)
+Base.convert(::Type{Power{D,T}}, p::Power{D}) where {D,T} = Power{D,T}(p.coeff, p.N)
+
+function Base.show(io::IO,p::Power{D,T}) where {D,T}
+    if p.coeff==1
+    elseif p.coeff==-1
+        print(io,"-")
+    elseif isa(p.coeff,Complex)
+        print(io,"($(p.coeff))")
+    else
+        print(io,"$(p.coeff)")
+    end
+    p.coeff==1 || p.coeff==-1 || p.N==0 || print(io,"*")
+    p.N==0 && (p.coeff==1 || p.coeff==-1) && print(io,"1")
+    p.N>0 && print(io,"$D")
+    p.N>1 && print(io,"^$(p.N)")
+end
+
+Base.:*(p1::Power{D}, p2::Power{D}) where {D} = Power{D}(p1.coeff*p2.coeff, degree(p1)+degree(p2))
+Base.:*(p::Power{D}, s::Number) where {D} = Power{D}(p.coeff*s, degree(p))
+Base.:*(s::Number, p::Power) = *(p,s)
+Base.:^(p::Power{D}, n::Int) where {D} = Power{D}(p.coeff^n, n*degree(p))
+
+struct Poly{D,T} <: AbstractPoly{D,T} # coeffs[i+1] is the coefficient of D^i
+    coeffs::Vector{T}
+end
+degree(p::Poly)=length(p.coeffs)-1
+Base.getindex(p::Poly{D,T}, i::Int) where {D,T} = (0<=i<=degree(p) ? p.coeffs[i+1] : zero(T))
+Poly{D}(coeffs::Vector{T}) where {D,T} = Poly{D,T}(coeffs)
+Poly{D}(c0::T) where {D,T} = Poly{D,T}([c0])
+Poly{D}(p::Power{D,T}) where {D,T} = Poly{D,T}(vcat(zeros(T, p.N), p.coeff))
+Poly{D,T}(c0::Number) where {D,T} = Poly{D,T}([T(c0)])
+Poly{D,T1}(p::Power{D,T2}) where {D,T1,T2} = Poly{D,T1}(vcat(zeros(T1,p.N),T1(p.coeff)))
+
+Base.one(::Type{Poly{D,T}}) where {D,T} = Poly{D,T}([one(T)])
+Base.zero(::Type{Poly{D,T}}) where {D,T} = Poly{D,T}(T[])
+
+Base.convert(::Type{Poly{D}}, x::Number) where {D} = Poly{D}([x])
+Base.convert(::Type{Poly{D,T}}, x::Number) where {D,T} = Poly{D,T}(T[x])
+Base.convert(::Type{Poly{D}}, p::Power{D}) where {D} = Poly{D}(vcat(fill(zero(p.coeff), p.N), p.coeff))
+Base.convert(::Type{Poly{D,T}}, p::Power{D}) where {D,T} = Poly{D,T}(vcat(fill(zero(T), p.N), convert(T, p.coeff)))
+Base.convert(::Type{Poly{D,T}}, p::Poly{D}) where {D,T} = Poly{D,T}(convert(Vector{T}, p.coeffs))
+
+Base.:+(p::Poly{D}, s::Number) where {D} = Poly{D}([p[i]+(i==0 ? s : zero(s)) for i=0:degree(p)])
+Base.:+(s::Number, p::Poly) = +(p,s)
+function Base.:+(p1::Power{D,T1}, p2::Power{D,T2}) where {D,T1,T2}
+    T=promote_type(T1,T2)
+    coeffs=zeros(T,max(degree(p1),degree(p2))+1)
+    coeffs[p1.N+1]=p1.coeff
+    coeffs[p2.N+1]+=p2.coeff
+    return Poly{D,T}(coeffs)
+end
+function Base.:+(p::Power{D,T1}, s::T2) where {D,T1,T2}
+    T=promote_type(T1,T2)
+    coeffs=zeros(T,degree(p)+1)
+    coeffs[p.N+1] = p.coeff
+    coeffs[1] += s
+    return Poly{D,T}(coeffs)
+end
+Base.:+(s::Number, p::Power) = +(p,s)
+
+function Base.:+(p1::Power{D,T1}, p2::Poly{D,T2}) where {D,T1,T2}
+    T=promote_type(T1,T2)
+    coeffs=zeros(T,max(degree(p1),degree(p2))+1)
+    coeffs[(0:degree(p2))+1]=p2.coeffs
+    coeffs[p1.N+1]+=p1.coeff
+    return Poly{D,T}(coeffs)
+end
+Base.:+(p1::Poly{D},p2::Power{D}) where {D}=+(p2,p1)
+function Base.:+(p1::Poly{D,T1}, p2::Poly{D,T2}) where {D,T1,T2}
+    T=promote_type(T1,T2)
+    coeffs=zeros(T, max(degree(p1),degree(p2))+1)
+    coeffs[(0:degree(p1))+1] = p1.coeffs
+    for j=0:degree(p2)
+        coeffs[j+1] += p2.coeffs[j+1]
+    end
+    return Poly{D,T}(coeffs)
+end
+
+Base.:-(p::Poly{D}) where {D} = Poly{D}(-p.coeffs)
+Base.:-(p::Poly{D},s::Number) where {D} = Poly{D}([p[i]-(i==0 ? s : zero(s)) for i=0:degree(p)]) # subtract s from the constant term only
+Base.:-(s::Number,p::Poly{D}) where {D} = Poly{D}([-p[i]+(i==0 ? s : zero(s)) for i=0:degree(p)])
+Base.:-(p1::Union{Power{D},Poly{D}}, p2::Union{Power{D},Poly{D}}) where {D} = Poly{D}([p1[i]-p2[i] for i=0:max(degree(p1),degree(p2))])
+
+Base.:*(p1::Power{D}, p2::Poly{D}) where {D} = Poly{D}([p1.coeff*p2[n-degree(p1)] for n=0:degree(p1)+degree(p2)])
+Base.:*(p1::Poly{D}, p2::Power{D}) where {D} = *(p2,p1)
+Base.:*(p::Poly{D}, s::Number) where {D} = Poly{D}(s*p.coeffs)
+Base.:*(s::Number, p::Poly) = *(p,s)
+function Base.:*(p1::Poly{D}, p2::Poly{D}) where {D}
+    N=degree(p1)+degree(p2)
+    s=p1[0]*p2[0]
+    coeffs=zeros(typeof(s), N+1)
+    for i=0:degree(p1)
+        for j=0:degree(p2)
+            coeffs[i+j+1]+=p1[i]*p2[j]
+        end
+    end
+    return Poly{D}(coeffs)
+end
+
+Base.promote_rule(::Type{Power{D,T1}}, ::Type{Power{D,T2}}) where {D,T1<:Number,T2<:Number} = Power{D,promote_type(T1,T2)}
+Base.promote_rule(::Type{Power{D,T1}}, ::Type{T2}) where {D,T1<:Number,T2<:Number} = Power{D,promote_type(T1,T2)}
+Base.promote_rule(::Type{Poly{D,T1}}, ::Type{Poly{D,T2}}) where {D,T1<:Number,T2<:Number} = Poly{D,promote_type(T1,T2)}
+Base.promote_rule(::Type{Poly{D,T1}}, ::Type{Power{D,T2}}) where {D,T1<:Number,T2<:Number} = Poly{D,promote_type(T1,T2)}
+Base.promote_rule(::Type{Poly{D,T1}}, ::Type{T2}) where {D,T1<:Number,T2<:Number} = Poly{D,promote_type(T1,T2)}
+
+function Base.:(==)(p1::AbstractPoly{D}, p2::AbstractPoly{D}) where {D}
+    for i=max(degree(p1),degree(p2)):-1:0
+        p1[i]==p2[i] || return false
+    end
+    return true
+end
+function Base.:<(p1::AbstractPoly{D}, p2::AbstractPoly{D}) where {D} # compare coefficients from the highest degree down
+    for i=max(degree(p1),degree(p2)):-1:0
+        p1[i]<p2[i] && return true
+        p1[i]>p2[i] && return false
+    end
+    return false
+end
diff --git a/src/indexnotation/tensormacro.jl b/src/indexnotation/tensormacro.jl
index 73ea5545..d059b5da 100644
--- a/src/indexnotation/tensormacro.jl
+++ b/src/indexnotation/tensormacro.jl
@@ -1,33 +1,155 @@
 # indexnotation/tensormacro.jl
 #
 # Defines the @tensor macro which switches to an index-notation environment.
-const prime = Symbol("'")
+macro tensor(ex::Expr)
+    tensorify(ex)
+end
+macro tensoropt(ex::Expr)
+    tensorify(ex, optdata(ex))
+end
+macro tensoropt(optex::Expr, ex::Expr)
+    tensorify(ex, optdata(optex, ex))
+end
+macro optimalcontractiontree(ex::Expr)
+    if isassignment(ex) || isdefinition(ex)
+        _,ex = getlhsrhs(ex::Expr)
+    elseif !(ex.head == :call && ex.args[1] == :*)
+        error("cannot compute optimal contraction tree for this expression")
+    end
+    network = [getindices(ex.args[k]) for k = 2:length(ex.args)]
+    tree, cost = optimaltree(network, optdata(ex))
+    return tree, cost
+end
+macro optimalcontractiontree(optex::Expr, ex::Expr)
+    if isassignment(ex) || isdefinition(ex)
+        _,ex = getlhsrhs(ex::Expr)
+    elseif !(ex.head == :call && ex.args[1] == :*)
+        error("cannot compute optimal contraction tree for this expression")
+    end
+    network = [getindices(ex.args[k]) for k = 2:length(ex.args)]
+    tree, cost = optimaltree(network, optdata(optex, ex))
+    return tree, cost
+end
 
-macro tensor(arg)
-    tensorify(arg)
+function optdata(ex::Expr) # default cost table: every index occurring in ex costs χ
+    allindices = unique(getallindices(ex))
+    cost = Power{:χ}(1,1)
+    return Dict{Any, typeof(cost)}(i=>cost for i in allindices)
 end
+function optdata(optex::Expr, ex::Expr) # cost table from a tuple: (i, j, ...) or (i=>cost, j=>cost, ...)
+    optex.head == :tuple || error("invalid index cost specification")
+
+    isempty(optex.args) && return optdata(ex) # empty specification: fall back to the default cost table
 
-function tensorify(ex::Expr)
-    if ex.head == :(=) || ex.head == :(:=) || ex.head == :(+=) || ex.head == :(-=)
-        lhs = ex.args[1]
-        rhs = ex.args[2]
+    args = optex.args
+    if isa(args[1], Expr) && args[1].head == :call && args[1].args[1] == :(=>)
+        indices = Vector{Any}(length(args))
+        costs = Vector{Any}(length(args))
+        costtype = typeof(parsecost(args[1].args[3]))
+        for k = 1:length(args)
+            if isa(args[k], Expr) && args[k].head == :call && args[k].args[1] == :(=>)
+                indices[k] = args[k].args[2]
+                costs[k] = parsecost(args[k].args[3])
+                costtype = promote_type(costtype, typeof(costs[k]))
+            else
+                error("invalid index cost specification")
+            end
+        end
+        costs = convert(Vector{costtype}, costs)
+    else
+        indices = args
+        costtype = Power{:χ,Int} # must be the type of the entries created on the next line
+        costs = fill(Power{:χ,Int}(1,1), length(args))
+    end
+    makeindices!(indices)
+    return Dict{Any, costtype}(indices[k]=>costs[k] for k = 1:length(args))
+end
+
+function parsecost(ex::Expr) # evaluate a cost expression built from numbers, symbols and * + - ^ /
+    if ex.head == :call && ex.args[1] == :*
+        return *(map(parsecost, ex.args[2:end])...)
+    elseif ex.head == :call && ex.args[1] == :+
+        return +(map(parsecost, ex.args[2:end])...)
+    elseif ex.head == :call && ex.args[1] == :-
+        return -(map(parsecost, ex.args[2:end])...)
+    elseif ex.head == :call && ex.args[1] == :^
+        return ^(map(parsecost, ex.args[2:end])...)
+    elseif ex.head == :call && ex.args[1] == :/
+        return /(map(parsecost, ex.args[2:end])...)
+    else
+        error("invalid index cost specification: $ex")
+    end
+end
+parsecost(ex::Number) = ex
+parsecost(ex::Symbol) = Power{ex}(1,1)
+
+
+isassignment(ex::Expr) = ex.head == :(=) || ex.head == :(+=) || ex.head == :(-=)
+isdefinition(ex::Expr) = ex.head == :(:=) || (ex.head == :call && ex.args[1] == :(≝))
+
+function getlhsrhs(ex::Expr)
+    if ex.head == :(=) || ex.head == :(+=) || ex.head == :(-=) || ex.head == :(:=)
+        return ex.args[1], ex.args[2]
+    elseif ex.head == :call && ex.args[1] == :(≝)
+        return ex.args[2], ex.args[3]
+    else
+        error("invalid assignment or definition $ex")
+    end
+end
+
+function tensorify(ex::Expr, optdata = nothing)
+    # assignment case
+    if isassignment(ex) || isdefinition(ex)
+        #TODO: remove when := is removed
+        # if ex.head == :(:=)
+        #     warn(":= will likely be deprecated as assignment operator in Julia, use ≝ (\\eqdef + TAB) or go to http://github.com/Jutho/TensorOperations.jl to suggest ASCII alternatives", once=true, key=:warnaboutcoloneq)
+        # end
+        lhs, rhs = getlhsrhs(ex)
+        # process left hand side
         if isa(lhs, Expr) && lhs.head == :ref
-            dst = tensorify(lhs.args[1])
-            src = ex.head == :(-=) ? tensorify(Expr(:call,:-,rhs)) : tensorify(rhs)
-            indices = makeindex_expr(lhs)
-            if ex.head == :(:=)
-                return :($dst = deindexify($src, $indices))
+            dst = esc(lhs.args[1])
+            if length(lhs.args) == 2 && lhs.args[2] == :(:)
+                indices = getindices(rhs)
+                if all(isa(i, Integer) && i < 0 for i in indices)
+                    indices = makeindices!(sort(indices, rev=true))
+                else
+                    error("cannot automatically infer index order of left hand side")
+                end
             else
+                indices = makeindices!(lhs.args[2:end])
+            end
+            src = ex.head == :(-=) ? tensorify(Expr(:call, :-, rhs), optdata) : tensorify(rhs, optdata)
+            if isassignment(ex)
                 value = ex.head == :(=) ? 0 : +1
-                return :(deindexify!($dst, $src, $indices, $value))
+                return :(deindexify!($dst, $src, Indices{$(tuple(indices...))}(), $value))
+            else
+                return :($dst = deindexify($src, Indices{$(tuple(indices...))}() ))
             end
+        elseif isdefinition(ex)
+            # if lhs is not an index expression, there is no difference between assignment and definition
+            ex = Expr(:(=), lhs, rhs)
         end
     end
+    # single tensor expression
     if ex.head == :ref
-        indices = makeindex_expr(ex)
-        t = tensorify(ex.args[1])
-        return :(indexify($t,$indices))
+        indices = makeindices!(ex.args[2:end])
+        t = esc(ex.args[1])
+        return :(indexify($t, Indices{$(tuple(indices...))}() ))
     end
+    # tensor contraction: structure contraction order
+    if ex.head == :call && ex.args[1] == :* && length(ex.args) > 3
+        network = [getindices(ex.args[k]) for k = 2:length(ex.args)]
+        if optdata === nothing
+            if isnconstyle(network)
+                tree = ncontree(network)
+                ex = tree2expr(ex.args[2:end], tree)
+            end
+        else
+            tree, = optimaltree(network, optdata)
+            ex = tree2expr(ex.args[2:end], tree)
+        end
+    end
+    # scalar
     if ex.head == :call && ex.args[1] == :scalar
         if length(ex.args) != 2
             error("scalar accepts only a single argument")
@@ -36,32 +158,71 @@ function tensorify(ex::Expr)
         indices = :(Indices{()}())
         return :(scalar(deindexify($src, $indices)))
     end
-    return Expr(ex.head,map(tensorify,ex.args)...)
+    return Expr(ex.head, map(a -> isa(a, Expr) ? tensorify(a, optdata) : tensorify(a), ex.args)...) # keep optdata for nested contractions
 end
 tensorify(ex::Symbol) = esc(ex)
 tensorify(ex) = ex
 
-function makeindex_expr(ex::Expr)
+# for any index expression, get the list of uncontracted indices from that expression
+function getindices(ex::Expr)
     if ex.head == :ref
-        for i = 2:length(ex.args)
-            if isa(ex.args[i],Expr) && ex.args[i].head == prime
-                ex.args[i] = makesymbolprime(ex.args[i])
-            end
-            isa(ex.args[i],Int) || isa(ex.args[i],Symbol) || isa(ex.args[i],Char) || error("cannot make indices from $ex")
+        indices = makeindices!(ex.args[2:end])
+        return unique2(indices)
+    elseif ex.head == :call && (ex.args[1] == :+ || ex.args[1] == :-)
+        return getindices(ex.args[2]) # getindices on any of the args[2:end] should yield the same result
+    elseif ex.head == :call && ex.args[1] == :*
+        indices = getindices(ex.args[2])
+        for k = 3:length(ex.args)
+            append!(indices, getindices(ex.args[k]))
         end
+        return unique2(indices)
+    elseif ex.head == :call && length(ex.args) == 2
+        return getindices(ex.args[2])
+    else
+        return Vector{Any}()
+    end
+end
+getindices(ex) = Vector{Any}()
+
+function getallindices(ex::Expr)
+    if ex.head == :ref
+        return makeindices!(ex.args[2:end])
+    elseif !isempty(ex.args)
+        return unique(mapreduce(getallindices, vcat, ex.args))
     else
-        error("cannot make indices from $ex")
+        return Vector{Any}()
     end
-    return :(Indices{$(tuple(ex.args[2:end]...))}())
 end
+getallindices(ex) = Vector{Any}()
 
-function makesymbolprime(ex::Expr)
-    if isa(ex,Expr) && ex.head == prime && length(ex.args) == 1
-        if isa(ex.args[1],Symbol) || isa(ex.args[1],Int)
-            return Symbol(ex.args[1],prime)
-        elseif isa(ex.args[1],Expr) && ex.args[1].head == prime
-            return Symbol(makesymbolprime(ex.args[1]),prime)
+# make the arguments of a :ref expression into a proper list of indices of type Int, Char or Symbol
+function makeindices!(list::Vector)
+    for i = 1:length(list)
+        if isa(list[i], Expr)
+            list[i] = makesymbol(list[i])
         end
+        isa(list[i], Int) || isa(list[i], Symbol) || isa(list[i], Char) || error("cannot make index from $(list[i])")
+    end
+    return list
+end
+# make a symbol from an index that is itself an expression: currently only supports priming
+const prime = Symbol("'")
+function makesymbol(ex::Expr)
+    if ex.head == prime && length(ex.args) == 1
+        if isa(ex.args[1], Symbol) || isa(ex.args[1], Int)
+            return Symbol(ex.args[1], "′")
+        elseif isa(ex.args[1], Expr)
+            return Symbol(makesymbol(ex.args[1]), "′")
+        end
+        # could be extended with other functionality
+    end
+    error("cannot make index from $ex")
+end
+
+function tree2expr(args, tree) # turn a nested tuple tree of argument positions into nested binary products
+    if isa(tree, Int)
+        return args[tree]
+    else
+        return Expr(:call, :*, tree2expr(args, tree[1]), tree2expr(args, tree[2]))
     end
-    error("cannot make indices from $ex")
 end