# In this tutorial we will learn some basics of Tensor Networks


# T1.1: Diagrammatic notations

In [46]:
##### Let's initialize some tensors in Julia
using LinearAlgebra

# order-3 tensor of uniform random entries, dims 2-by-3-by-4
A = rand(2, 3, 4)

# order-2 identity matrix, dims 5-by-5 (`1.0I` is LinearAlgebra's UniformScaling; Julia 0.7+ syntax)
B = Matrix(1.0I, 5, 5)

# order-4 tensor filled with ones, dims 2-by-4-by-2-by-4
C = ones(2, 4, 2, 4)

# order-2 matrix of zeros, dims 3-by-5
D = zeros(3, 5)

# complex random tensor: real part drawn first, then the imaginary part
E = complex.(rand(2, 3, 4), rand(2, 3, 4))

2×3×4 Array{Complex{Float64},3}:
[:, :, 1] =
 0.927357+0.463365im   0.79503+0.0410167im   0.98649+0.0481558im
 0.401424+0.823595im  0.518311+0.936971im   0.675189+0.914591im

[:, :, 2] =
 0.266294+0.938711im  0.0131703+0.855307im  0.0238696+0.412979im
 0.190188+0.159857im   0.523797+0.946069im   0.672592+0.390884im

[:, :, 3] =
 0.0142286+0.377461im  0.356804+0.269358im  0.591547+0.161821im
  0.331023+0.528359im  0.429539+0.434299im  0.555872+0.968045im

[:, :, 4] =
 0.380565+0.246555im  0.722401+0.59244im     0.171468+0.492581im
 0.606162+0.590639im  0.034199+0.0763232im  0.0628821+0.0630447im

# Permute and reshape functions

In [10]:
##### Ex.1.2(a): Permute
A = rand(4, 4)
# swap the two indices (i.e. a matrix transpose), materialised as a new array
Atilda = permutedims(A, (2, 1))

4×4 Array{Float64,2}:
 0.881664  0.733946  0.324903  0.251353
 0.930197  0.970817  0.590985  0.416322
 0.20118   0.081186  0.217943  0.989526
 0.808224  0.876775  0.862441  0.21242

In [11]:
# show the unpermuted A for comparison with Atilda (its transpose)
A

4×4 Array{Float64,2}:
 0.881664  0.930197  0.20118   0.808224
 0.733946  0.970817  0.081186  0.876775
 0.324903  0.590985  0.217943  0.862441
 0.251353  0.416322  0.989526  0.21242

The permute function reorders the storage of the elements of a tensor in computer memory, and thus incurs some (often non-negligible) computational cost. In contrast, the reshape function leaves the elements of a tensor unchanged in memory, changing only the metadata that describes how the tensor is to be interpreted (and thus incurs negligible cost).

# T1.3: Binary Tensor Contractions

In [107]:
##### Ex.1.3(a): Binary Tensor Contraction
# Computes C[a1,a3,b2,b3] = sum_{i,j} A[a1,i,a3,j] * B[i,b2,b3,j]
# by permuting, reshaping to matrices, and doing a single matrix product.
d = 10;
A = rand(d, d, d, d);  B = rand(d, d, d, d);

# put the summed indices (A: 2 and 4, B: 1 and 4) on the inner side of each matrix
Ap  = permutedims(A, (1, 3, 2, 4))
Bp  = permutedims(B, (1, 4, 2, 3))
App = reshape(Ap, d^2, d^2)
Bpp = reshape(Bp, d^2, d^2)

Cpp = App * Bpp
C   = reshape(Cpp, d, d, d, d)

10×10×10×10 Array{Float64,4}:
[:, :, 1, 1] =
 21.1753  24.688   21.7039  19.5421  …  19.5244  21.5617  21.8122  21.2162
 21.9283  18.7537  23.0226  22.0904     20.6288  22.9037  20.0999  23.7029
 21.6995  20.8393  22.1285  18.0818     23.191   21.0316  19.8575  20.4696
 19.868   21.5143  21.5291  22.5236     24.1452  20.096   22.6564  20.2768
 20.6781  21.0757  23.645   20.7892     22.0696  21.3847  18.9267  24.1184
 22.1595  19.7935  20.3354  22.5613  …  21.7464  23.0787  20.7267  21.3788
 21.1928  21.6383  21.1821  22.84       23.3314  20.7213  22.0296  23.2483
 18.4367  20.1891  21.2093  21.6924     21.2524  20.8109  20.0076  21.9238
 20.8562  20.2091  21.8254  21.2681     21.1659  17.0782  19.494   22.6893
 22.8383  18.5147  20.7855  19.054      21.5631  20.7245  20.2301  20.3345

[:, :, 2, 1] =
 23.3683  26.3756  24.1008  22.1318  …  21.9361  23.7766  21.0517  23.2808
 23.2038  21.3758  23.2444  23.8385     21.3442  23.9409  21.3138  23.7952
 23.4181  20.6789  22.7816  20.3803    

# ncon function

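The `ncon` function can be used to perform the contractions of a tensor network. It greatly reduces the programming effort, since it carries out the permute, reshape and matrix-multiply steps shown above automatically. It can also be used to take partial traces and to combine disjoint tensors (via outer products).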

In [31]:
"""
    ncon(tensor_list, connect_list_in; con_order=[], check_network=true)
------------------------
    by Glen Evenbly (c) for www.tensors.net, (v1.2) - last modified 6/2020
------------------------
Network CONtractor. Input is an array of tensors 'tensor_list' and an array of
vectors 'connect_list_in', with each vector labelling the indices of the
corresponding tensor. Labels should be positive integers for contracted indices
and negative integers for free indices. Optional input 'con_order' can be used
to specify order of index contractions (otherwise defaults to ascending order of
the positive indices). Checking of the consistency of the input network can be
disabled for slightly faster operation.

A single tensor (and its single label vector) may be passed directly instead of
one-element lists. Neither `tensor_list` nor `connect_list_in` is modified.

Returns an array whose indices are ordered -1, -2, -3, ...; when no free
indices remain, the scalar value is returned instead of a 0-dimensional array.

Further information can be found at: https://arxiv.org/abs/1402.0939
"""
function ncon(
  tensor_list,
  connect_list_in;
  con_order = [],
  check_network = true,
)

  # Work on private copies: both lists are pushed to / deleted from below and the
  # caller's arrays must survive unchanged. A shallow copy is enough for the
  # tensors, whose entries are only ever replaced (never modified in place); the
  # Any element type lets intermediate results of a different rank be stored.
  if tensor_list[1] isa Number
    # a single tensor was given rather than a list of tensors
    tensor_list = Any[tensor_list]
  else
    tensor_list = collect(Any, tensor_list)
  end
  connect_list = deepcopy(connect_list_in)
  if !(connect_list[1] isa AbstractArray)
    # a single label vector was given rather than a list of them
    connect_list = Any[connect_list]
  end

  # default contraction order: ascending positive labels
  flat_connect = vcat(connect_list...)
  if isempty(con_order)
    con_order = sort(unique(flat_connect[flat_connect .> 0]))
  end

  # check inputs if enabled
  if check_network
    dims_list = Any[[size(T)...] for T in tensor_list]
    ncon_check_inputs(connect_list, flat_connect, dims_list, con_order)
  end

  # do all partial traces (labels repeated within a single tensor)
  for ip = 1:length(connect_list)
    if length(connect_list[ip]) > length(unique(connect_list[ip]))
      tensor_list[ip], connect_list[ip], cont_label =
        ncon_partial_trace(tensor_list[ip], connect_list[ip])
      con_order = setdiff(con_order, cont_label)
    end
  end

  # do all binary contractions, driven by the first label left in con_order
  while !isempty(con_order)
    cont_ind = con_order[1]
    # the two tensors that carry this label
    locs = findall(labels -> cont_ind in labels, connect_list)
    lab1 = connect_list[locs[1]]
    lab2 = connect_list[locs[2]]

    # contract every label the pair shares at once, not only cont_ind
    cont_many = intersect(lab1, lab2)
    A_cont = [findfirst(lab1 .== x) for x in cont_many]
    B_cont = [findfirst(lab2 .== x) for x in cont_many]
    A_free = deleteat!(collect(1:length(lab1)), sort(A_cont))
    B_free = deleteat!(collect(1:length(lab2)), sort(B_cont))
    push!(
      tensor_list,
      ncon_tensordot(tensor_list[locs[1]], tensor_list[locs[2]], A_cont, B_cont),
    )
    push!(connect_list, vcat(lab1[A_free], lab2[B_free]))

    # remove the contracted pair and the labels just consumed
    deleteat!(connect_list, locs)
    deleteat!(tensor_list, locs)
    con_order = setdiff(con_order, cont_many)
  end

  # combine any remaining disjoint tensors by outer products
  while length(tensor_list) > 1
    s1 = size(tensor_list[end-1])
    s2 = size(tensor_list[end])
    tensor_list[end-1] = reshape(
      reshape(tensor_list[end-1], prod(s1)) *
      reshape(tensor_list[end], 1, prod(s2)),
      (s1..., s2...),
    )
    connect_list[end-1] = vcat(connect_list[end-1], connect_list[end])
    pop!(connect_list)
    pop!(tensor_list)
  end

  # order the free indices as -1, -2, -3, ...; a fully contracted network is a scalar
  if isempty(connect_list[1])
    return tensor_list[1][1]
  else
    return permutedims(tensor_list[1], sortperm(connect_list[1], by = abs))
  end
end

"""
ncon_tensordot: contract tensor `A` with tensor `B`, pairing index positions
`A_cont[k]` of `A` with `B_cont[k]` of `B`, via a single matrix product (similar
to the Numpy function of the same name). The result carries the uncontracted
indices of `A` followed by those of `B`, each group in ascending position order.
"""
function ncon_tensordot(A, B, A_cont, B_cont)

  sA, sB = size(A), size(B)
  # uncontracted index positions, in ascending order
  keepA = deleteat!(collect(1:ndims(A)), sort(A_cont))
  keepB = deleteat!(collect(1:ndims(B)), sort(B_cont))

  # A is viewed as (free x contracted) and B as (contracted x free)
  Amat = reshape(permutedims(A, [keepA; A_cont]), prod(sA[keepA]), prod(sA[A_cont]))
  Bmat = reshape(permutedims(B, [B_cont; keepB]), prod(sB[B_cont]), prod(sB[keepB]))

  out_dims = (sA[keepA]..., sB[keepB]...)
  return reshape(Amat * Bmat, out_dims)
end

"""
ncon_partial_trace: partial trace on tensor A over repeated labels in A_label

Returns `(B, B_label, cont_label)`: the traced tensor, the labels of its
remaining indices (in their original order), and the labels that were traced
out. The element type of `A` is preserved and `A_label` is not modified. If no
label is repeated, `A` is returned unchanged together with a copy of its labels
and an empty `cont_label`.
"""
function ncon_partial_trace(A, A_label)

  # work on a copy: the traced-out labels are deleted from it before returning
  A_label = collect(A_label)
  num_cont = length(A_label) - length(unique(A_label))
  if num_cont == 0
    return A, A_label, empty(A_label)
  end

  # positions of each repeated label, stored as consecutive pairs
  dup_list = Int[]
  for ele in unique(A_label)
    locs = findall(A_label .== ele)
    if length(locs) > 1
      append!(dup_list, locs)
    end
  end

  # first member of every pair, followed by the second member of every pair
  pairs = reshape(dup_list, 2, num_cont)
  cont_ind = vcat(pairs[1, :], pairs[2, :])
  free_ind = deleteat!(collect(1:length(A_label)), sort(dup_list))

  dims = size(A)
  free_dim = dims[free_ind]
  cont_dim = prod(dims[cont_ind[1:num_cont]])
  cont_label = unique(A_label[cont_ind])

  # view A as (free, paired-first, paired-second) and sum along the diagonal
  A_mat = reshape(
    permutedims(A, vcat(free_ind, cont_ind)),
    prod(free_dim),
    cont_dim,
    cont_dim,
  )
  B = zeros(eltype(A), prod(free_dim))
  for ip = 1:cont_dim
    B .+= view(A_mat, :, ip, ip)
  end

  return reshape(B, free_dim), deleteat!(A_label, sort(cont_ind)), cont_label
end

"""
ncon_check_inputs: check consistency of input tensor network

Raises an error (message prefixed "NCON error:") when the number of tensors and
label lists differ, when a tensor's rank differs from its number of labels, when
`con_order` does not list each positive label exactly once, when the negative
labels are not exactly -1, -2, ..., -n each used once, when a positive label
does not appear exactly twice, or when the two indices sharing a positive label
have different dimensions. Returns `true` otherwise.
"""
function ncon_check_inputs(connect_list, flat_connect, dims_list, con_order)

  pos_ind = flat_connect[flat_connect .> 0]
  neg_ind = flat_connect[flat_connect .< 0]

  # one label list per tensor
  if length(dims_list) != length(connect_list)
    error("NCON error: $(length(dims_list)) tensors given ",
      "but $(length(connect_list)) index sublists given")
  end

  # one label per tensor index
  for (ik, (dims, labels)) in enumerate(zip(dims_list, connect_list))
    if length(dims) != length(labels)
      error("NCON error: number of indices does not match number of labels on tensor ",
        "$(ik): $(length(dims))-indices versus $(length(labels))-labels")
    end
  end

  # con_order must list every contracted label exactly once
  if sort(con_order) != sort(unique(pos_ind))
    error("NCON error: invalid contraction order")
  end

  # free labels must be exactly -1, -2, ..., -n, each used once
  for ind = -1:-1:-length(neg_ind)
    num = count(x -> x == ind, neg_ind)
    if num == 0
      error("NCON error: no index labelled $(ind)")
    elseif num > 1
      error("NCON error: more than one index labelled $(ind)")
    end
  end

  # contracted labels must appear exactly twice, on indices of equal dimension
  flat_dims = reduce(vcat, dims_list; init = Any[])
  for ind in unique(pos_ind)
    num = count(x -> x == ind, pos_ind)
    if num == 1
      error("NCON error: only one index labelled $(ind)")
    elseif num > 2
      error("NCON error: more than two indices labelled $(ind)")
    end
    cont_dims = flat_dims[flat_connect .== ind]
    if cont_dims[1] != cont_dims[2]
      error("NCON error: dimension mismatch on index labelled $(ind): ",
        "dim-$(cont_dims[1]) versus dim-$(cont_dims[2])")
    end
  end

  return true
end
                


ncon_check_inputs

In [106]:
##### Ex.1.5(b): Contraction using ncon
# (ncon is defined in an earlier cell of this notebook, so no include is needed)
d = 10;
A = rand(d, d, d); B = rand(d, d, d, d);
C = rand(d, d, d); D = rand(d, d);

# one label list per tensor: positive labels are summed over, while the negative
# labels -1, -2 become the indices of the result, in that order
TensorArray = Any[A, B, C, D];
IndexArray  = Any[[1, -2, 2], [-1, 1, 3, 4], [5, 3, 2], [4, 5]];

E = ncon(TensorArray, IndexArray)

10×10 Array{Float64,2}:
 5934.91  5744.34  5981.41  5046.53  …  5861.63  5343.32  6068.09  5529.35
 6150.68  5965.05  6189.11  5216.7      6077.87  5551.89  6283.38  5757.45
 6083.9   5902.08  6161.76  5181.91     6044.95  5474.58  6254.42  5706.09
 6173.51  5951.41  6217.09  5231.15     6074.31  5551.09  6291.74  5762.49
 6119.96  5896.99  6178.12  5176.71     6032.36  5459.42  6281.92  5753.39
 5960.38  5802.34  6007.01  5054.13  …  5912.43  5388.71  6133.35  5639.99
 6099.53  5905.27  6162.89  5192.88     6020.27  5506.13  6282.63  5732.81
 6104.26  5933.7   6152.08  5175.07     6022.16  5520.28  6287.15  5714.53
 5858.08  5662.22  5922.29  4972.88     5779.77  5244.46  6040.07  5493.24
 6157.98  5998.13  6234.38  5250.07     6113.95  5575.87  6344.27  5759.93

In [43]:
##### Ex.1.5(c): Partial trace
d = 10;
A = rand(d, d, d, d, d, d);

# label 1 appears twice on A, so indices 3 and 6 are traced against each other
B = ncon(Any[A], Any[[-1, -2, 1, -3, -4, 1]]);

In [44]:
##### Ex.1.5(d): Disjoint networks
d = 10;
A = rand(d, d);
B = rand(d, d);

# no shared labels: ncon forms the outer product C[i,j,k,l] = A[i,j] * B[k,l]
C = ncon(Any[A, B], Any[[-1, -2], [-3, -4]]);

# Problem Sheet 1

In [109]:
d = 4;
# three random order-3 tensors, every dimension equal to d (drawn in the order A, B, C)
A, B, C = rand(d, d, d), rand(d, d, d), rand(d, d, d);

# b).

In [110]:
# Binary contraction

# We will start by contracting A with B, and then contract the result with C.
# Network computed: D[p,q,c] = sum_{m,j,r} A[p,m,j] * B[m,q,r] * C[j,r,c]

# A's index 2 is contracted with B's index 1, so move it to the end of A.
# B's index 1 is already first, so B can be reshaped directly (no permutation).
Ap  = permutedims(A, [1,3,2]);
App = reshape(Ap, d^2, d);      Bpp = reshape(B, d, d^2); # reshaping to matrices

Ap_Bp = App*Bpp; # first binary contraction: d^2-by-d^2, rows (p,j), columns (q,r)

# Split back into indices (p,j,q,r), then regroup as rows (p,q), columns (j,r)
Ap_Bp = reshape(Ap_Bp, d,d,d,d)
Ap_Bp = permutedims(Ap_Bp,[1,3,2,4]) ;
Ap_Bp = reshape(Ap_Bp, d^2, d^2)

# second binary contraction: columns (j,r) against C's first two indices
D = Ap_Bp*reshape(C, d^2, d)
D = reshape(D, d, d, d)

4×4×4 Array{Float64,3}:
[:, :, 1] =
  8.39431  10.3054  7.86183  6.09542
  9.88297  12.605   9.19239  7.18686
 10.9116   13.0974  9.5401   7.92622
  8.65083  10.5464  7.51779  6.32336

[:, :, 2] =
  8.13779  10.1176   7.83304  5.89421
  9.76389  12.1524   8.869    6.71718
 10.3122   12.4131   8.99222  7.28036
  8.34295   9.84292  6.8192   5.68353

[:, :, 3] =
  8.20018  10.8782   8.01295  6.16127
 11.4353   14.8206  10.7632   8.16382
 11.2298   14.1639  10.0993   8.40978
 10.394    13.1646   9.03182  7.55317

[:, :, 4] =
  7.86697  10.3938   8.10375  6.40823
  9.27385  12.7089   9.71256  7.49696
 10.4561   13.2812  10.307    8.54835
  8.05313  10.4099   7.8233   6.71518

In [111]:
# The above code can be written in a few lines

# contract A (index 2) with B (index 1), then regroup as rows (p,q), columns (j,r)
mid = reshape(
  permutedims(
    reshape(reshape(permutedims(A, [1, 3, 2]), d^2, d) * reshape(B, d, d^2), d, d, d, d),
    [1, 3, 2, 4],
  ),
  d^2, d^2,
)

# contract columns (j,r) with C's first two indices, then restore an order-3 tensor
D1 = reshape(mid * reshape(C, d^2, d), d, d, d)

4×4×4 Array{Float64,3}:
[:, :, 1] =
  8.39431  10.3054  7.86183  6.09542
  9.88297  12.605   9.19239  7.18686
 10.9116   13.0974  9.5401   7.92622
  8.65083  10.5464  7.51779  6.32336

[:, :, 2] =
  8.13779  10.1176   7.83304  5.89421
  9.76389  12.1524   8.869    6.71718
 10.3122   12.4131   8.99222  7.28036
  8.34295   9.84292  6.8192   5.68353

[:, :, 3] =
  8.20018  10.8782   8.01295  6.16127
 11.4353   14.8206  10.7632   8.16382
 11.2298   14.1639  10.0993   8.40978
 10.394    13.1646   9.03182  7.55317

[:, :, 4] =
  7.86697  10.3938   8.10375  6.40823
  9.27385  12.7089   9.71256  7.49696
 10.4561   13.2812  10.307    8.54835
  8.05313  10.4099   7.8233   6.71518

In [120]:
# by ncon function (the result is stored in D2 so the next cell can display it)
# NOTE(review): these labels contract A1-B2, A3-C2 and B3-C1, whereas the manual
# contraction in the cells above contracts A2-B1, A3-C1 and B3-C2, so D2 will in
# general differ from D and D1. The manual network written for ncon would be
#   ncon(Any[A,B,C], Any[[-1,1,2],[1,-2,3],[2,3,-3]])
# Confirm which network the problem intends.
D2 = ncon(Any[A,B,C],Any[[1,-2,2],[-1,1,3],[3,2,-3]]);


In [121]:
# display D2, the ncon result, for comparison with D and D1
D2

4×4×4 Array{Float64,3}:
[:, :, 1] =
 11.7302   11.475    13.1268   9.53629
  7.82193   8.00331   9.1225   6.71506
  8.65737   9.36316   9.81819  7.46549
  8.38096   8.72623   8.60349  7.18394

[:, :, 2] =
 11.252    12.2557   12.2543   7.83286
  7.23558   8.05684   8.06796  5.10447
  8.95743  10.3687    9.35729  6.80101
  7.72796   8.68686   7.83012  5.8647

[:, :, 3] =
 13.515    13.2784   15.1018   10.7639
  9.68529   9.87138  11.4891    7.87432
  9.28118   9.88462  10.1644    7.75076
  8.94471   8.60431   8.81901   6.99441

[:, :, 4] =
 11.3136   10.9328   13.1543   9.59904
  7.45716   7.65143   8.98773  6.79002
  8.60403   9.00791  10.0909   7.742
  8.76089   8.6799    9.32858  7.5501