Skip to content

Commit

Permalink
nconstyle and optimized contraction order
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Sep 15, 2017
1 parent 1a92478 commit 1dcaa41
Show file tree
Hide file tree
Showing 6 changed files with 679 additions and 34 deletions.
7 changes: 6 additions & 1 deletion src/TensorOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module TensorOperations
export tensorcopy, tensoradd, tensortrace, tensorcontract, tensorproduct, scalar
export tensorcopy!, tensoradd!, tensortrace!, tensorcontract!, tensorproduct!

export @tensor
export @tensor, @tensoropt, @optimalcontractiontree

# Auxiliary functions
#---------------------
Expand All @@ -25,7 +25,12 @@ include("implementation/strides.jl")

# Index notation
#----------------
import Base.Iterators.flatten

include("indexnotation/tensormacro.jl")
include("indexnotation/nconstyle.jl")
include("indexnotation/poly.jl")
include("indexnotation/optimize.jl")
include("indexnotation/indexedobject.jl")
include("indexnotation/sum.jl")
include("indexnotation/product.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/indexnotation/indexedobject.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Base.eltype(::Type{IndexedObject{I,C, A, T}}) where {I,C, A, T} = promote_type(e
Expr(:block, meta, :($J))
end

indexify(object, ::Indices{I}) where {I} = IndexedObject{I,:N}(object)
@inline indexify(object, ::Indices{I}) where {I} = IndexedObject{I,:N}(object)

deindexify(A::IndexedObject{I,:N}, ::Indices{I}) where {I} = A.α == 1 ? A.object : A.α*A.object
deindexify(A::IndexedObject{I,:C}, ::Indices{I}) where {I} = A.α == 1 ? conj(A.object) : A.α*conj(A.object)
Expand Down
58 changes: 58 additions & 0 deletions src/indexnotation/nconstyle.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# check if a list of indices specifies a tensor contraction in ncon style
function isnconstyle(network::Vector)
allindices = Vector{Int}()
for ind in network
all(i->isa(i, Integer), ind) || return false
append!(allindices, ind)
end
while length(allindices) > 0
i = pop!(allindices)
if i > 0 # positive labels represent contractions or traces and should appear twice
k = findnext(allindices, i, 1)
l = findnext(allindices, i, k+1)
if k == 0 || l != 0
return false
end
deleteat!(allindices, k)
elseif i < 0 # negative labels represent open indices and should appear once
findnext(allindices, i, 1) == 0 || return false
else # i == 0
return false
end
end
return true
end

function ncontree(network::Vector)
contractionindices = Vector{Vector{Int}}(length(network))
for k = 1:length(network)
indices = network[k]
# trace indices have already been removed, remove open indices by filtering on positive values
contractionindices[k] = filter(i->i>0, indices)
end
partialtrees = collect(Any, 1:length(network))
_ncontree!(partialtrees, contractionindices)
end

function _ncontree!(partialtrees, contractionindices)
if length(partialtrees) == 1
return partialtrees[1]
end
if all(isempty, contractionindices) # disconnected network
partialtrees[end-1] = (partialtrees[end-1], partialtrees[end])
pop!(partialtrees)
pop!(contractionindices)
else
let firstind = minimum(flatten(contractionindices))
i1 = findfirst(x->in(firstind,x), contractionindices)
i2 = findnext(x->in(firstind,x), contractionindices, i1+1)
newindices = unique2(vcat(contractionindices[i1], contractionindices[i2]))
newtree = (partialtrees[i1], partialtrees[i2])
partialtrees[i1] = newtree
deleteat!(partialtrees, i2)
contractionindices[i1] = newindices
deleteat!(contractionindices, i2)
end
end
_ncontree!(partialtrees, contractionindices)
end
259 changes: 259 additions & 0 deletions src/indexnotation/optimize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
function optimaltree(network, optdata::Dict)
numtensors = length(network)
allindices = unique(flatten(network))
numindices = length(allindices)
costtype = valtype(optdata)
allcosts = [get(optdata, i, one(costtype)) for i in allindices]
maxcost = prod(allcosts)*maximum(allcosts) + zero(costtype) # add zero for type stability: Power -> Poly
tensorcosts = Vector{costtype}(numtensors)
for k = 1:numtensors
tensorcosts[k] = prod(get(optdata, i, one(costtype)) for i in network[k])
end
initialcost = min(maxcost, maximum(tensorcosts)^2 + zero(costtype)) # just some arbitrary guess

if numindices <= 32
return _optimaltree(UInt32, network, allindices, allcosts, initialcost, maxcost)
elseif numindices <= 64
return _optimaltree(UInt64, network, allindices, allcosts, initialcost, maxcost)
elseif numindices <= 128
return _optimaltree(UInt128, network, allindices, allcosts, initialcost, maxcost)
else
return _optimaltree(BitVector, network, allindices, allcosts, initialcost, maxcost)
end
end

storeset(::Type{IntSet}, ints, maxint) = sizehint!(IntSet(ints), maxint)
function storeset(::Type{BitVector}, ints, maxint)
set = falses(maxint)
set[ints] = true
return set
end
function storeset(::Type{T}, ints, maxint) where {T<:Unsigned}
set = zero(T)
u = one(T)
for i in ints
set |= (u<<(i-1))
end
return set
end
_intersect(s1::T, s2::T) where {T<:Unsigned} = s1 & s2
_intersect(s1::BitVector, s2::BitVector) = s1 .& s2
_intersect(s1::IntSet, s2::IntSet) = intersect(s1, s2)
_union(s1::T, s2::T) where {T<:Unsigned} = s1 | s2
_union(s1::BitVector, s2::BitVector) = s1 .| s2
_union(s1::IntSet, s2::IntSet) = union(s1, s2)
_setdiff(s1::T, s2::T) where {T<:Unsigned} = s1 & (~s2)
_setdiff(s1::BitVector, s2::BitVector) = s1 .& (.~s2)
_setdiff(s1::IntSet, s2::IntSet) = setdiff(s1, s2)
_isemptyset(s::Unsigned) = iszero(s)
_isemptyset(s::BitVector) = !any(s)
_isemptyset(s::IntSet) = isempty(s)

function computecost(allcosts, ind1::T, ind2::T) where {T<:Unsigned}
cost = one(eltype(allcosts))
ind = _union(ind1, ind2)
n = 1
while !iszero(ind)
if isodd(ind)
cost *= allcosts[n]
end
ind = ind>>1
n += 1
end
return cost
end
function computecost(allcosts, ind1::BitVector, ind2::BitVector)
cost = one(eltype(allcosts))
ind = _union(ind1, ind2)
for n in find(ind)
cost *= allcosts[n]
end
return cost
end
function computecost(allcosts, ind1::IntSet, ind2::IntSet)
cost = one(eltype(allcosts))
ind = _union(ind1, ind2)
for n in ind
cost *= allcosts[n]
end
return cost
end

function _optimaltree(::Type{T}, network, allindices, allcosts::Vector{S}, initialcost::C, maxcost::C) where {T,S,C}
numindices = length(allindices)
numtensors = length(network)
indexsets = Array{T}(numtensors)

tabletensor = zeros(Int, (numindices,2))
tableindex = zeros(Int, (numindices,2))

adjacencymatrix=falses(numtensors,numtensors)
costfac = maximum(allcosts)

@inbounds for n = 1:numtensors
indn = findin(allindices, network[n])
indexsets[n] = storeset(T, indn, numindices)
for i in indn
if tabletensor[i,1] == 0
tabletensor[i,1] = n
tableindex[i,1] = findfirst(network[n], allindices[i])
elseif tabletensor[i,2] == 0
tabletensor[i,2] = n
tableindex[i,2] = findfirst(network[n], allindices[i])
n1 = tabletensor[i,1]
adjacencymatrix[n1,n] = true
adjacencymatrix[n,n1] = true
else
error("no index should appear more than two times")
end
end
end
componentlist = connectedcomponents(adjacencymatrix)
numcomponent = length(componentlist)

# generate output structures
costlist = Vector{C}(numcomponent)
treelist = Vector{Any}(numcomponent)
indexlist = Vector{T}(numcomponent)

# run over components
for c=1:numcomponent
# find optimal contraction for every component
component = componentlist[c]
componentsize = length(component)
costdict = Array{Dict{T, C}}(componentsize)
treedict = Array{Dict{T, Any}}(componentsize)
indexdict = Array{Dict{T, T}}(componentsize)

for k=1:componentsize
costdict[k] = Dict{T, C}()
treedict[k] = Dict{T, Any}()
indexdict[k] = Dict{T, T}()
end

for i in component
s = storeset(T, [i], numtensors)
costdict[1][s] = zero(C)
treedict[1][s] = i
indexdict[1][s] = indexsets[i]
end

# run over currentcost
currentcost = initialcost
previouscost = zero(initialcost)
while currentcost <= maxcost
nextcost = maxcost
# construct all subsets of n tensors that can be constructed with cost <= currentcost
for n=2:componentsize
# construct subsets by combining two smaller subsets
for k = 1:div(n-1,2)
for s1 in keys(costdict[k]), s2 in keys(costdict[n-k])
if _isemptyset(_intersect(s1, s2)) && get(costdict[n], _union(s1, s2), currentcost) > previouscost
ind1 = indexdict[k][s1]
ind2 = indexdict[n-k][s2]
cind = _intersect(ind1, ind2)
if !_isemptyset(cind)
s = _union(s1, s2)
cost = costdict[k][s1] + costdict[n-k][s2] + computecost(allcosts, ind1, ind2)
if cost <= get(costdict[n], s, currentcost)
costdict[n][s] = cost
indexdict[n][s] = _setdiff(_union(ind1,ind2), cind)
treedict[n][s] = (treedict[k][s1], treedict[n-k][s2])
elseif currentcost < cost < nextcost
nextcost = cost
end
end
end
end
end
if iseven(n) # treat the case k = n/2 special
k = div(n,2)
it = keys(costdict[k])
state1 = start(it)
while !done(it, state1)
s1, nextstate1 = next(it, state1)
state2 = nextstate1
while !done(it, state2)
s2, nextstate2 = next(it, state2)
if _isemptyset(_intersect(s1, s2)) && get(costdict[n], _union(s1, s2), currentcost) > previouscost
ind1 = indexdict[k][s1]
ind2 = indexdict[k][s2]
cind = _intersect(ind1, ind2)
if !_isemptyset(cind)
s = _union(s1, s2)
cost = costdict[k][s1] + costdict[k][s2] + computecost(allcosts, ind1, ind2)
if cost <= get(costdict[n], s, currentcost)
costdict[n][s] = cost
indexdict[n][s] = _setdiff(_union(ind1,ind2), cind)
treedict[n][s] = (treedict[k][s1], treedict[k][s2])
elseif currentcost < cost < nextcost
nextcost = cost
end
end
end
state2 = nextstate2
end
state1 = nextstate1
end
end
end
if !isempty(costdict[componentsize])
break
end
previouscost = currentcost
currentcost = min(maxcost, nextcost*costfac)
end
if isempty(costdict[componentsize])
error("Maxcost $maxcost reached without finding solution") # should be impossible
end
s = storeset(T, component, numtensors)
costlist[c] = costdict[componentsize][s]
treelist[c] = treedict[componentsize][s]
indexlist[c] = indexdict[componentsize][s]
end
tree = treelist[1]
cost = costlist[1]
ind = indexlist[1]
for c = 2:numcomponent
tree = (tree, treelist[c])
cost = cost + costlist[c] + computecost(allcosts, ind, indexlist[c])
ind = _union(ind, indexlist[c])
end
return tree, cost
end

function connectedcomponents(A::AbstractMatrix{Bool})
# For a given adjacency matrix of size n x n, connectedcomponents returns
# a list componentlist that contains integer vectors, where every integer
# vector groups the indices of the vertices of a connected component of the
# graph encoded by A. The number of connected components is given by
# length(componentlist).
#
# Used as auxiliary function to analyze contraction graph in contract.

n=size(A,1)
assert(size(A,2)==n)

componentlist=Vector{Vector{Int}}()
assignedlist=falses((n,))

for i=1:n
if !assignedlist[i]
assignedlist[i]=true
checklist=[i]
currentcomponent=[i]
while !isempty(checklist)
j=pop!(checklist)
for k=find(A[j,:])
if !assignedlist[k]
push!(currentcomponent,k)
push!(checklist,k)
assignedlist[k]=true;
end
end
end
push!(componentlist,currentcomponent)
end
end
return componentlist
end

0 comments on commit 1dcaa41

Please sign in to comment.