Skip to content

Commit

Permalink
finalize ncon implementation, add @ncon macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Nov 14, 2019
1 parent 8f1f254 commit c8ac71e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/TensorOperations.jl
Expand Up @@ -8,13 +8,14 @@ using LinearAlgebra: mul!, BLAS.BlasFloat
using LRUCache

# export macro API
export @tensor, @tensoropt, @tensoropt_verbose, @optimalcontractiontree, @notensor
export @tensor, @tensoropt, @tensoropt_verbose, @optimalcontractiontree, @notensor, @ncon

export enable_blas, disable_blas, enable_cache, disable_cache, clear_cache, cachesize

# export function based API
export tensorcopy, tensoradd, tensortrace, tensorcontract, tensorproduct, scalar
export tensorcopy!, tensoradd!, tensortrace!, tensorcontract!, tensorproduct!
export ncon

# Convenient type alias
const IndexTuple{N} = NTuple{N,Int}
Expand Down
49 changes: 49 additions & 0 deletions src/indexnotation/tensormacros.jl
Expand Up @@ -127,3 +127,52 @@ macro optimalcontractiontree(expressions...)
tree, cost = optimaltree(network, optdict)
return tree, cost
end

macro ncon(args...)
if length(args) == 2
return _nconmacro(args[1], args[2])
else
return _nconmacro(args[2], args[3], args[1])
end
end
function _nconmacro(tensors, indices, kwargs = nothing)
if !(tensors isa Expr) # there is not much that we can do
if kwargs === nothing
ex = Expr(:call, :ncon, tensors, indices,
Expr(:call, :fill, false, Expr(:call, :length, tensors)),
QuoteNode(gensym()))
else
ex = Expr(:call, :ncon, kwargs, tensors, indices,
Expr(:call, :fill, false, Expr(:call, :length, tensors)),
QuoteNode(gensym()))
end
return esc(ex)
end
if tensors.head == :vect || tensors.head == :tuple
tensorargs = tensors.args
elseif tensors.head == :ref
tensorargs = tensors.args[2:end]
else
throw(ArgumentError("invalid @ncon syntax"))
end
conjlist = fill(false, length(tensorargs))
for i = 1:length(tensorargs)
if tensorargs[i] isa Expr
if tensorargs[i].head == :call && tensorargs[i].args[1] == :conj
tensorargs[i] = tensorargs[i].args[2]
conjlist[i] = true
end
end
end
if tensors.head == :ref
tensorex = Expr(:ref, tensors.args[1], tensorargs...)
else
tensorex = Expr(:ref, :Any, tensorargs...)
end
if kwargs === nothing
ex = Expr(:call, :ncon, tensorex, indices, conjlist, QuoteNode(gensym()))
else
ex = Expr(:call, :ncon, kwargs, tensorex, indices, conjlist, QuoteNode(gensym()))
end
return esc(ex)
end
6 changes: 6 additions & 0 deletions test/tensor.jl
Expand Up @@ -247,9 +247,15 @@ withcache = TensorOperations.use_cache() ? "with" : "without"
@tensor HrA12′′[:] := rhoL[-1, 1] * H[-2, -3, 4, 5] * A2[2, 5, 3] * rhoR[3, -4] * A1[1, 4, 2] # should be contracted in exactly same order
@tensor HrA12′′′[a, s1, s2, c] := H[s1, s2, t1, t2] * rhoL[a, a'] * rhoR[c', c] * A1[a', t1, b] * A2[b, t2, c'] order=(a',b,c',t1,t2)# should be contracted in exactly same order
@tensoropt HrA12′′′′[:] := rhoL[-1, 1] * H[-2, -3, 4, 5] * A2[2, 5, 3] * rhoR[3, -4] * A1[1, 4, 2]

@test HrA12′ == HrA12′′ == HrA12′′′ # should be exactly equal
@test HrA12 HrA12′
@test HrA12 HrA12′′′′
@test HrA12′′ == ncon([rhoL, H, A2, rhoR, A1],
[[-1,1],[-2,-3,4,5],[2,5,3],[3,-4],[1,4,2]])
@test HrA12′′ == @ncon([rhoL, H, A2, rhoR, A1],
[[-1,1],[-2,-3,4,5],[2,5,3],[3,-4],[1,4,2]];
order = [1,2,3,4,5], output=[-1,-2,-3,-4])
@test E @tensor scalar(rhoL[a', a] * A1[a, s, b] * A2[b, s', c] * rhoR[c, c'] * H[t, t', s, s'] * conj(A1[a', t, b']) * conj(A2[b', t', c']))
end
println("tensor network examples: $(time()-t0) seconds")
Expand Down

0 comments on commit c8ac71e

Please sign in to comment.