### Tensor Computation Workshop 2017, September 15, NYC


# TensorOperations.jl:
## Elementary tensor operations with Julia
### (and fun with metaprogramming)


### Jutho Haegeman
#### Department of Physics and Astronomy, UGent

## Overview


* **Motivation: Tensor Network Decompositions in Quantum Physics**
* **Intro to the Julia Language**
* **TensorOperations.jl**
* **Implementation of basic tensor operations (with metaprogramming)**
* **Optimization of tensor contraction order**
* **Outlook**

## Motivation: quantum many body physics
* quantum bit (qubit):

$$\vert\Psi\rangle = \alpha \vert 0\rangle + \beta \vert 1\rangle\quad\text{with}\quad\alpha,\beta\in\mathbb{C}$$

* intrinsically probabilistic (with $|\alpha|^2+|\beta|^2=1$):

    * $|\alpha|^2$: probability of measuring 0
    * $|\beta|^2$: probability of measuring 1

* for $N$ different qubits? 

$$\vert\Psi\rangle = \Psi_{00000}\vert 00000 \rangle + \Psi_{00001} \vert 00001\rangle + \ldots+ \Psi_{11111} \vert 11111\rangle$$

**$\Rightarrow$ storing a quantum state of $N$ qubits requires $2^N$ complex numbers: $\Psi_{i_1,i_2,\ldots,i_{N}}$**

## Motivation: quantum many body physics
* quantum state is a high-order tensor / multidimensional array:
  <img src="psi.png" style="width: 500px;"/>

* Curse of dimensionality: exponential scaling in $N$, the number of degrees of freedom (qubits, spins, atoms, ...)
  
* Realistic materials: $N$ is of the order of Avogadro's number, i.e. $O(10^{23})$
  <img src="graphene.jpg" style="width: 300px;"/>

## Motivation: tensor network decompositions
* graphical notation:
    * matrix - vector multiplication: <img src="matvec.png" style="width: 400px;"/>
    * matrix - matrix multiplication: <img src="matmat.png" style="width: 400px;"/>
* tensor network decompositions provide an efficient description of quantum states
  <img src="tn2.png" style="width: 600px;"/>

## Introduction to the Julia Language

  <img src="julia.png" style="width: 200px;"/>


* Selling point: dynamic high-level language with the speed of a statically-compiled language

* Key features:
    * Just-in-time compiled (using LLVM infrastructure)
    * Dynamic type system
    * Multiple dispatch:
        * define function behavior across many combinations of argument types
        * automatic generation of efficient, specialized code for different argument types
    * Good support for computational science: numerics, statistics, multidimensional arrays, ...
    * Homoiconic and powerful metaprogramming facilities
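A minimal sketch of dispatch on *combinations* of argument types (the function `combine` and its methods are purely hypothetical, chosen only for illustration): Julia picks the most specific method matching the types of all arguments, not just the first one.

```julia
# Hypothetical function `combine`: the method is selected from the types of *both* arguments.
combine(x::Number, y::Number) = "number-number"
combine(x::AbstractArray, y::Number) = "array-number"
combine(x::Number, y::AbstractArray) = "number-array"
combine(x::AbstractArray, y::AbstractArray) = "array-array"

# A more specific signature wins over a more generic one when both match.
combine(x::Integer, y::Integer) = "integer-integer"

println(combine(1.0, 2))
println(combine([1, 2], 3))
println(combine(1, 2))
```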

### Code generation

In [31]:
function myabs(x)
    if x < 0
        return -x
    end
    return x
end
function myabs2(x::Real)
    if x < 0
        return -x
    end
    return x
end
function myabs2(x::Unsigned)
    return x
end

myabs2 (generic function with 2 methods)

### Code generation

In [62]:
@code_native myabs(3) # native assembly code for a 64-bit integer argument

	.section	__TEXT,__text,regular,pure_instructions
Filename: In[31]
	pushq	%rbp
	movq	%rsp, %rbp
Source line: 2
	testq	%rdi, %rdi
	js	L14
Source line: 5
	movq	%rdi, %rax
	popq	%rbp
	retq
Source line: 3
L14:
	negq	%rdi
	movq	%rdi, %rax
	popq	%rbp
	retq
	nopw	%cs:(%rax,%rax)


In [33]:
code_llvm(myabs,Tuple{UInt64}) # LLVM code for 64-bit unsigned integer


define i64 @julia_myabs_62677(i64) #0 !dbg !5 {
L4:
  ret i64 %0
}


In [34]:
code_llvm(myabs,Tuple{Float64}) # LLVM code for 64-bit floating point


define double @julia_myabs_62678(double) #0 !dbg !5 {
top:
  %1 = fcmp uge double %0, 0.000000e+00
  br i1 %1, label %L9, label %if

if:                                               ; preds = %top
  %2 = fsub double -0.000000e+00, %0
  ret double %2

L9:                                               ; preds = %top
  ret double %0
}


### Type inference & type stability

In [35]:
mysqrt(x) = x < zero(x) ? sqrt(complex(x)) : sqrt(x)
code_warntype(mysqrt,Tuple{Float64})

Variables:
  #self#::#mysqrt
  x::Float64

Body:
  begin 
      unless (Base.lt_float)(x::Float64, (Base.sitofp)(Float64, 0)::Float64)::Bool goto 3
      return $(Expr(:invoke, MethodInstance for sqrt(::Complex{Float64}), :(Main.sqrt), :($(Expr(:new, Complex{Float64}, :(x), :((Base.sitofp)(Float64, 0)::Float64))))))
      3: 
      return (Base.Math.sqrt_llvm)(x::Float64)::Float64
  end::Union{Complex{Float64}, Float64}


### Type inference & type stability

In [36]:
function summyabs(v::Vector)
    s = myabs(v[1])
    for i = 2:length(v)
        s += myabs(v[i])
    end
    return s
end
code_warntype(summyabs, Tuple{Vector{Float64}})

Variables:
  #self#::#summyabs
  v::Array{Float64,1}
  i::Int64
  #temp#@_4::Int64
  s::Float64
  fy::Float64
  #temp#@_7::Float64

Body:
  begin 
      SSAValue(2) = (Base.arrayref)(v::Array{Float64,1}, 1)::Float64
      $(Expr(:inbounds, false))
      # meta: location In[31] myabs 2
      # meta: location float.jl < 491
      fy::Float64 = (Base.sitofp)(Float64, 0)::Float64
      # meta: pop location
      unless (Base.or_int)((Base.lt_float)(SSAValue(2), fy::Float64)::Bool, (Base.and_int)((Base.and_int)((Base.eq_float)(SSAValue(2), fy::Float64)::Bool, (Base.lt_float)(fy::Float64, 9.223372036854776e18)::Bool)::Bool, (Base.slt_int)((Base.fptosi)(Int64, fy::Float64)::Int64, 0)::Bool)::Bool)::Bool goto 11 # line 3:
      #temp#@_7::Float64 = (Base.neg_float)(SSAValue(2))::Float64
      goto 14
      11:  # line 5:
      #temp#@_7::Float64 = SSAValue(2)
      14: 
      # meta: pop location
      $(Expr(:inbounds, :pop))
      s::Float64 = #temp#@_7::Float64 # line 3:
      SSAValue(3) =

### Type inference & type stability

In [37]:
function summysqrt(v::Vector)
    s = mysqrt(v[1])
    for i = 2:length(v)
        s += mysqrt(v[i])
    end
    return s
end
code_warntype(summysqrt,Tuple{Vector{Int64}})

Variables:
  #self#::#summysqrt
  v::Array{Int64,1}
  i::Int64
  #temp#@_4::Int64
  s::Union{Complex{Float64}, Float64}
  #temp#@_6::Union{Complex{Float64}, Float64}
  #temp#@_7::Union{Complex{Float64}, Float64}
  #temp#@_8::Core.MethodInstance
  #temp#@_9::Union{Complex{Float64}, Float64}

Body:
  begin 
      SSAValue(2) = (Base.arrayref)(v::Array{Int64,1}, 1)::Int64
      $(Expr(:inbounds, false))
      # meta: location In[35] mysqrt 1
      unless (Base.slt_int)(SSAValue(2), 0)::Bool goto 7
      #temp#@_6::Union{Complex{Float64}, Float64} = $(Expr(:invoke, MethodInstance for sqrt(::Complex{Float64}), :(Base.sqrt), :($(Expr(:new, Complex{Float64}, :((Base.sitofp)(Float64, SSAValue(2))::Float64), :((Base.sitofp)(Float64, 0)::Float64))))))
      goto 9
      7: 
      #temp#@_6::Union{Complex{Float64}, Float64} = (Base.Math.sqrt_llvm)((Base.sitofp)(Float64, SSAValue(2))::
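One way to remove this instability, sketched here with hypothetical names `mysqrt_stable` and `summysqrt_stable`, is to let both branches return the same type, here always a complex number. `Base.return_types` (an internal reflection utility, used only for inspection) then reports a single concrete return type instead of a `Union`.

```julia
# Hypothetical type-stable variant: both branches return a complex number.
mysqrt_stable(x::Real) = x < zero(x) ? sqrt(complex(x)) : complex(sqrt(x))

function summysqrt_stable(v::Vector)
    s = mysqrt_stable(v[1])          # s has a single concrete type from the start
    for i = 2:length(v)
        s += mysqrt_stable(v[i])
    end
    return s
end

# Inspect what type inference concludes for concrete argument types.
println(Base.return_types(mysqrt_stable, Tuple{Float64}))
println(Base.return_types(summysqrt_stable, Tuple{Vector{Int64}}))
```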

### Homoiconicity

In [38]:
ex=:(function summysqrt(v::Vector)
        s = mysqrt(v[1])
        for i = 2:length(v)
            s += mysqrt(v[i])
        end
        return s
    end);

In [39]:
typeof(ex)

Expr

In [40]:
println(ex.head),println(ex.args[1]),println(ex.args[2]);

function
summysqrt(v::Vector)
begin  # In[38], line 2:
    s = mysqrt(v[1]) # In[38], line 3:
    for i = 2:length(v) # In[38], line 4:
        s += mysqrt(v[i])
    end # In[38], line 6:
    return s
end


In [41]:
Meta.show_sexpr(ex)

(:function, (:call, :summysqrt, (:(::), :v, :Vector)), (:block,
    (:line, 2, Symbol("In[38]")),
    (:(=), :s, (:call, :mysqrt, (:ref, :v, 1))),
    (:line, 3, Symbol("In[38]")),
    (:for, (:(=), :i, (:(:), 2, (:call, :length, :v))), (:block,
        (:line, 4, Symbol("In[38]")),
        (:+=, :s, (:call, :mysqrt, (:ref, :v, :i)))
      )),
    (:line, 6, Symbol("In[38]")),
    (:return, :s)
  ))
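Because expressions are ordinary `Expr` objects, they can be built and rewritten with regular Julia code before they are evaluated. A minimal sketch (the function `replace_symbol` is hypothetical) that walks the expression tree recursively and swaps one symbol for another:

```julia
# Recursively replace every occurrence of the symbol `from` by `to` in an expression tree.
replace_symbol(ex, from, to) = ex                                  # literals etc.: unchanged
replace_symbol(s::Symbol, from, to) = s === from ? to : s
replace_symbol(ex::Expr, from, to) =
    Expr(ex.head, map(a -> replace_symbol(a, from, to), ex.args)...)

ex1 = :(s += mysqrt(v[i]))
ex2 = replace_symbol(ex1, :mysqrt, :sqrt)
println(ex2)

# Expressions can also be assembled directly from a head and a list of arguments.
ex3 = Expr(:call, :+, 1, Expr(:call, :*, 2, 3))   # same tree as :(1 + 2 * 3)
println(eval(ex3))
```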

### Metaprogramming

In [42]:
macro twice(ex)
    Expr(:block, esc(ex), esc(ex))
end

@twice (macro with 1 method)

In [43]:
x=3;
@twice x+=1
x

5

In [44]:
macroexpand(:(@twice x+=1))

quote 
    x += 1
    x += 1
end
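Macros such as `@tensor` receive their argument *unevaluated*, so they can inspect index labels before any array is touched. As a toy illustration (the macro `@indexlabels` is hypothetical and unrelated to the actual parser in TensorOperations.jl), a macro can extract the labels of an indexed expression like `A[a,d,e,c]` without ever evaluating `A`:

```julia
# Hypothetical macro: returns the index labels of an expression like A[i,j,k] as a tuple.
# Only the syntax tree is inspected, so A and the labels need not be defined.
macro indexlabels(ex)
    (ex isa Expr && ex.head == :ref) ||
        error("expected an indexed expression such as A[i,j]")
    labels = Tuple(ex.args[2:end])   # ex.args[1] is the array name, the rest are labels
    return QuoteNode(labels)         # quote the tuple so its symbols are not evaluated
end

println(@indexlabels A[a, d, e, c])
```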

## TensorOperations.jl
* general tensor operations include permutations, partial traces, contractions
    * graphical:
      <img src="tensorcontraction.png" style="width: 500px;"/>
      
    * index notation with Einstein summation convention:
      $$D_{a,b,c} = A_{a,d,e,c}\cdot B_{f,e,b,d,f}+C_{c,b,a}$$

In [64]:
n=3;
A=randn(n,n,n,n);
B=randn(n,n,n,n,n);
C=randn(n,n,n);

D=zeros(n,n,n);
for a=1:n, b=1:n, c=1:n
    D[a,b,c] += C[c,b,a]
    for d=1:n, e=1:n, f=1:n
        D[a,b,c] += A[a,d,e,c]*B[f,e,b,d,f]
    end
end

using TensorOperations
@tensor D2[a,b,c] := A[a,d,e,c]*B[f,e,b,d,f] + C[c,b,a];

@tensor D3[α,β,3] := A[α,d',-7,3]*B[f′′,-7,β,d',f′′] + C[3,β,α];

vecnorm(D-D2)

macroexpand(:(@tensor D2[a,b,c] := A[a,d,e,c]*B[f,e,b,d,f] + C[c,b,a];))

quote 
    D2 = (TensorOperations.deindexify)((TensorOperations.indexify)(A, (TensorOperations.Indices){(:a, :d, :e, :c)}()) * (TensorOperations.indexify)(B, (TensorOperations.Indices){(:f, :e, :b, :d, :f)}()) + (TensorOperations.indexify)(C, (TensorOperations.Indices){(:c, :b, :a)}()), (TensorOperations.Indices){(:a, :b, :c)}())
end

In [46]:
function f1!(D,n,A,B,C)
    for a=1:n, b=1:n, c=1:n
        D[a,b,c] += C[c,b,a]
        for d=1:n, e=1:n, f=1:n
            D[a,b,c] += A[a,d,e,c]*B[f,e,b,d,f]
        end
    end
    return D
end
function f2!(D,n,A,B,C)
    @tensor D[a,b,c] = A[a,d,e,c]*B[f,e,b,d,f] + C[c,b,a] 
    return D
end

n=30;
A=randn(n,n,n,n);
B=randn(n,n,n,n,n);
C=randn(n,n,n);
D=zeros(n,n,n);

In [47]:
@time f1!(D,n,A,B,C);
@time f2!(D,n,A,B,C);

  6.015464 seconds (7.77 k allocations: 372.718 KiB)
  0.027367 seconds (1.63 k allocations: 6.682 MiB)


In [48]:
@time f1!(D,n,A,B,C);
@time f2!(D,n,A,B,C);

  5.773314 seconds (4 allocations: 160 bytes)
  0.016246 seconds (117 allocations: 6.598 MiB)


### What is going on underneath?
* **Basic tensor operations** (inspired by BLAS)
    * permutations and addition: `C = β*C + α*permutation(op(A))`
    * partial trace: `C = β*C + α*partialtrace(op(A))`
    * contraction: `C = β*C + α*contract(op(A),op(B))`
    
  `op` can be the identity (doing nothing) or `conj`
    
  also available via a function-based syntax
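A minimal sketch of the first primitive for dense arrays. The name `permute_add!` and its keyword arguments are hypothetical (this is not the function-based API of TensorOperations.jl), and it simply allocates a permuted temporary with `permutedims` rather than using an optimized kernel. As in BLAS, the old contents of `C` are ignored when `β == 0`, so uninitialized or `NaN` entries do not leak into the result.

```julia
# Hypothetical helper for C = β*C + α*permutation(op(A)); op is identity or conj.
function permute_add!(C::AbstractArray, A::AbstractArray, perm;
                      α=1, β=0, conjA::Bool=false)
    opA = conjA ? conj(A) : A            # op(A)
    Ap = permutedims(opA, perm)          # permutation (allocates a temporary here)
    size(C) == size(Ap) ||
        throw(DimensionMismatch("C has size $(size(C)), expected $(size(Ap))"))
    if iszero(β)
        C .= α .* Ap                     # BLAS convention: old C is not read when β == 0
    else
        C .= β .* C .+ α .* Ap
    end
    return C
end

# Usage: C[k,i,j] = C[k,i,j] + 2*conj(A[i,j,k])
A = randn(ComplexF64, 2, 3, 4)
C = ones(ComplexF64, 4, 2, 3)
permute_add!(C, A, (3, 1, 2); α=2, β=1, conjA=true)
```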

## Implementation of basic tensor operations (with metaprogramming)
### 1. Permutations

In [49]:
A=randn(10,10,10,10,10,10,10,10);
B=zeros(10,10,10,10,10,10,10,10);

In [50]:
myreverse!(B,A) = (@tensor B[8,7,6,5,4,3,2,1] = A[1,2,3,4,5,6,7,8])

myreverse! (generic function with 1 method)

In [51]:
@time copy!(B,A);
@time permutedims!(B,A,[8,7,6,5,4,3,2,1]);
@time myreverse!(B,A);

  0.216277 seconds (8 allocations: 224 bytes)
 19.048340 seconds (450.00 M allocations: 6.706 GiB, 4.22% gc time)
  0.284500 seconds (260.51 k allocations: 3.995 MiB, 0.53% gc time)


In [52]:
@time copy!(B,A);
@time permutedims!(B,A,[8,7,6,5,4,3,2,1]);
@time myreverse!(B,A);

  0.128558 seconds (8 allocations: 224 bytes)
 20.239143 seconds (450.00 M allocations: 6.706 GiB, 2.12% gc time)
  0.309911 seconds (260.01 k allocations: 3.968 MiB)


#### Timing results:
  <img src="permutationtimings.pdf" style="width: 1600px;"/>



### 1. Permutations
* How to optimize permutations? Why is it slower than normal copy?
* Even for matrix transposition?
  <img src="transpose.png" style="width: 600px;"/>
  Memory is linear (column-major in Julia) $\Rightarrow$ `transpose` requires strided, cache-unfriendly memory access!
* A cache-oblivious (divide-and-conquer) approach already gives decent efficiency!
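A minimal sketch of a divide-and-conquer transpose for the matrix case (function names and the base-case size are hypothetical): the longer side of the current block is halved recursively, so at some level of the recursion the blocks fit in cache regardless of the cache size. The small base case only limits recursion overhead; it is not tuned to any particular cache.

```julia
# B[j, i] = A[i, j] for i in rows, j in cols, by recursively halving the longer side.
function transpose_rec!(B::AbstractMatrix, A::AbstractMatrix, rows, cols, base::Int)
    if length(rows) <= base && length(cols) <= base
        for j in cols, i in rows                 # small block: plain double loop
            B[j, i] = A[i, j]
        end
    elseif length(rows) >= length(cols)          # split the rows
        mid = first(rows) + length(rows) ÷ 2 - 1
        transpose_rec!(B, A, first(rows):mid, cols, base)
        transpose_rec!(B, A, (mid + 1):last(rows), cols, base)
    else                                         # split the columns
        mid = first(cols) + length(cols) ÷ 2 - 1
        transpose_rec!(B, A, rows, first(cols):mid, base)
        transpose_rec!(B, A, rows, (mid + 1):last(cols), base)
    end
    return B
end

function cache_oblivious_transpose!(B::AbstractMatrix, A::AbstractMatrix; base::Int=32)
    size(B) == reverse(size(A)) ||
        throw(DimensionMismatch("B must have size $(reverse(size(A)))"))
    return transpose_rec!(B, A, axes(A, 1), axes(A, 2), max(base, 1))
end
```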

### 1. Permutations
* How to generalize cache-oblivious approach to multidimensional permutations?
    1. What is the best blocking (divide and conquer) strategy?
    2. How to write nested loops depending on the dimensionality of the array?


* Solution to 1: split along the dimension for which the minimum of the memory jumps (strides) of the two arrays is maximal.
* Solution to 2: generated functions!
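A minimal recursive sketch of the blocking rule stated above (the names `blocked_permutedims!`, `permute_rec!`, `choose_split` and the default block size are hypothetical, and the actual kernels may differ in many details). Both arrays are addressed through linear offsets; the dimension that gets halved is the one where the smaller of the two strides is largest, until a block is small enough to copy with a plain loop.

```julia
# Dimension to halve: the one where min(stride in dst, stride in src) is largest.
function choose_split(dims, sdst, ssrc)
    best, bestk = -1, 0
    for k in eachindex(dims)
        dims[k] > 1 || continue                   # a length-1 dimension cannot be split
        m = min(sdst[k], ssrc[k])
        if m > best
            best, bestk = m, k
        end
    end
    return bestk
end

# Copy a block: dst[odst + Σ(i_k-1)*sdst[k]] = src[osrc + Σ(i_k-1)*ssrc[k]].
function permute_rec!(dst, src, dims, sdst, ssrc, odst, osrc, blocksize)
    if prod(dims) <= blocksize
        for I in CartesianIndices(dims)           # base case: plain loop over the block
            od, os = odst, osrc
            for k in eachindex(dims)
                od += (I[k] - 1) * sdst[k]
                os += (I[k] - 1) * ssrc[k]
            end
            dst[od] = src[os]
        end
    else                                          # divide: halve one dimension, recurse
        k = choose_split(dims, sdst, ssrc)
        h = dims[k] ÷ 2
        permute_rec!(dst, src, Base.setindex(dims, h, k), sdst, ssrc, odst, osrc, blocksize)
        permute_rec!(dst, src, Base.setindex(dims, dims[k] - h, k), sdst, ssrc,
                     odst + h * sdst[k], osrc + h * ssrc[k], blocksize)
    end
    return dst
end

# dst = permutedims(src, p); strides of src are taken along each *destination* dimension.
function blocked_permutedims!(dst::Array{T,N}, src::Array{T,N}, p::NTuple{N,Int};
                              blocksize::Int=1024) where {T,N}
    dims = size(dst)
    dims == map(k -> size(src, k), p) || throw(DimensionMismatch("size mismatch"))
    ssrc = map(k -> stride(src, k), p)
    return permute_rec!(dst, src, dims, strides(dst), ssrc, 1, 1, max(blocksize, 1))
end
```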

parse -> expressions -> macro expansion -> new expression -> type inference -> generated functions -> compile -> run

[TensorOperations.jl kernels](https://github.com/Jutho/TensorOperations.jl/blob/master/src/implementation/kernels.jl)
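The role of generated functions can be illustrated with a much simpler kernel than the linked ones: a `@generated` function builds its body from the argument *types*, here the dimensionality `N` and a permutation passed as a type parameter through `Val`, so exactly `N` nested loops are emitted for each case. The name `permutecopy!` is hypothetical, and this sketch does no blocking at all.

```julia
# Hypothetical @generated kernel: dst = permutedims(src, p) with p known at compile time.
@generated function permutecopy!(dst::AbstractArray{T,N}, src::AbstractArray{T,N},
                                 ::Val{p}) where {T,N,p}
    # This part runs once per combination of argument types and builds an expression.
    i = [Symbol(:i_, k) for k in 1:N]             # loop variables i_1, ..., i_N over src
    idst = [i[p[k]] for k in 1:N]                 # dst index k is src index p[k]
    expected = Expr(:tuple, [:(size(src, $(p[k]))) for k in 1:N]...)
    body = :(dst[$(idst...)] = src[$(i...)])
    for k in 1:N                                  # wrap loops; i_1 ends up innermost
        body = quote
            for $(i[k]) in axes(src, $k)
                $body
            end
        end
    end
    # The returned expression becomes the method body (it must not contain closures).
    return quote
        size(dst) == $expected || throw(DimensionMismatch("size mismatch"))
        $body
        return dst
    end
end
```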


### 2. Partial trace
* very similar, but requires somewhat more care
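A naive reference partial trace makes the operation concrete (the name `partialtrace` and its convention, two tuples of dimensions traced pairwise, are hypothetical; there is no blocking here). One observation relevant to an optimized kernel: when the two dimensions of a traced pair advance together, the linear memory offset jumps by the sum of their two strides.

```julia
# Naive partial trace: dimensions t1[k] and t2[k] of A are set equal and summed over;
# the remaining ("open") dimensions are kept in their original order.
function partialtrace(A::AbstractArray, t1::Tuple, t2::Tuple)
    openidx = Tuple(setdiff(1:ndims(A), (t1..., t2...)))
    all(size(A, t1[k]) == size(A, t2[k]) for k in eachindex(t1)) ||
        throw(DimensionMismatch("traced dimensions must have equal length"))
    C = zeros(eltype(A), map(d -> size(A, d), openidx))
    idx = ones(Int, ndims(A))                     # full index into A, filled in below
    for I in CartesianIndices(C), J in CartesianIndices(map(d -> size(A, d), t1))
        for (n, d) in enumerate(openidx)
            idx[d] = I[n]                         # open indices come from the output
        end
        for k in eachindex(t1)
            idx[t1[k]] = idx[t2[k]] = J[k]        # a traced pair shares one summation index
        end
        C[I] += A[idx...]
    end
    return C
end
```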

### 3. Tensor contraction: very similar to matrix multiplication

* Native divide-and-conquer implementation
* TTGT (transpose-transpose-GEMM-transpose) approach using BLAS
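The TTGT strategy fits in a few lines: permute both tensors so that the contracted indices are grouped together, reshape them into matrices, perform a single matrix-matrix multiplication (handled by BLAS for the standard floating-point types), and reshape back. The function `contract_ttgt` and its conventions (contracted dimensions given as tuples; result ordered as the open dimensions of `A` followed by those of `B`) are assumptions made for this sketch. With that output order the final transpose reduces to a reshape.

```julia
using LinearAlgebra

# Product of the lengths of dimensions `ds` of `A` (1 for an empty tuple).
dimprod(A, ds) = foldl(*, map(d -> size(A, d), ds); init=1)

# Contract dimension cA[k] of A with dimension cB[k] of B for every k.
function contract_ttgt(A::AbstractArray, B::AbstractArray, cA::Tuple, cB::Tuple)
    all(size(A, cA[k]) == size(B, cB[k]) for k in eachindex(cA)) ||
        throw(DimensionMismatch("contracted dimensions must have equal length"))
    oA = Tuple(setdiff(1:ndims(A), cA))           # open dimensions of A
    oB = Tuple(setdiff(1:ndims(B), cB))           # open dimensions of B
    Ap = permutedims(A, (oA..., cA...))           # T: A -> (open, contracted)
    Bp = permutedims(B, (cB..., oB...))           # T: B -> (contracted, open)
    m, k, n = dimprod(A, oA), dimprod(A, cA), dimprod(B, oB)
    Cmat = reshape(Ap, m, k) * reshape(Bp, k, n)  # G: one matrix-matrix product
    # final T: for this output order only a reshape is needed
    return reshape(Cmat, (map(d -> size(A, d), oA)..., map(d -> size(B, d), oB)...))
end
```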

## Optimization of tensor contraction order

### Contraction order matters!

* matrix-matrix-vector multiplication `A*B*v` with $n\times n$ matrices: 
  
  `A*(B*v)` (two matrix-vector products, $O(n^2)$) is much faster than `(A*B)*v` (one matrix-matrix product, $O(n^3)$)


* Optimal contraction order in more complicated tensor networks?
  <img src="mera.png" style="width: 200px;"/>
  
* Pairwise contraction is always sufficient, but in which sequence?
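Counting scalar multiplications makes the `A*B*v` example concrete: multiplying an $m\times k$ matrix by a $k\times n$ matrix takes $m\cdot k\cdot n$ multiplications. A small sketch (the helper `matmulcost` is hypothetical):

```julia
# Number of scalar multiplications for an (m×k) times (k×n) product.
matmulcost(m, k, n) = m * k * n

n = 1000
cost_AB_first = matmulcost(n, n, n) + matmulcost(n, n, 1)   # (A*B)*v
cost_Bv_first = matmulcost(n, n, 1) + matmulcost(n, n, 1)   # A*(B*v)
println((cost_AB_first, cost_Bv_first))

# Both orders give the same result up to rounding errors.
A, B, v = randn(n, n), randn(n, n), randn(n)
println((A * B) * v ≈ A * (B * v))
```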

### What is optimal contraction sequence?

* Manual determination can become a laborious task
* Contraction of a two-dimensional multiscale entanglement renormalization ansatz (MERA):
  <img src="2dmerac.png" style="width: 1200px;"/>

### Algorithmic determination of optimal contraction sequence

"Faster identification of optimal contraction sequences for tensor networks"

Robert N. C. Pfeifer, JH, and Frank Verstraete, Phys Rev E 90, 033315 (2014)

* Breadth-first constructive approach:
  <img src="algorithm.png" style="width: 800px;"/>
* Add tricks to make it efficient   

  <img src="mera2.png" style="width: 1200px;"/>
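The breadth-first constructive idea can be sketched as a dynamic program over subsets of tensors: the cheapest way to build the intermediate tensor for a set of tensors is found by combining the cheapest results for two complementary smaller sets. This sketch (function name and input format hypothetical) omits all of the efficiency tricks, scales exponentially with the number of tensors, and assumes the common cost model in which contracting two tensors costs the product of the dimensions of all distinct indices involved. For the matrix-matrix-vector chain it recovers `A*(B*v)`.

```julia
# Tensors are lists of index labels; a label on two tensors is contracted, a label
# appearing once is open. Subsets of tensors are encoded as bitmasks.
function optimal_order(network::Vector{Vector{Symbol}}, dims::Dict{Symbol,Int})
    N = length(network)
    full = (1 << N) - 1
    occurrences = Dict{Symbol,Int}()
    for t in network, l in t
        occurrences[l] = get(occurrences, l, 0) + 1
    end
    function labels(S)                            # all labels on the tensors in S
        ls = Set{Symbol}()
        for k in 1:N
            (S >> (k - 1)) & 1 == 1 && union!(ls, network[k])
        end
        return ls
    end
    function legs(S)                              # indices of the intermediate tensor for S
        outside = labels(full & ~S)
        return Set(l for l in labels(S) if l in outside || occurrences[l] == 1)
    end
    cost = Dict{Int,Float64}()
    tree = Dict{Int,Any}()
    for k in 1:N
        cost[1 << (k - 1)] = 0.0
        tree[1 << (k - 1)] = k
    end
    for S in 1:full                               # proper subsets of S come before S
        count_ones(S) < 2 && continue
        A = (S - 1) & S
        while A > 0                               # enumerate all splits S = A ∪ B
            B = S ⊻ A
            if A < B                              # visit each unordered split once
                c = cost[A] + cost[B] +
                    foldl(*, (dims[l] for l in union(legs(A), legs(B))); init=1.0)
                if c < get(cost, S, Inf)
                    cost[S] = c
                    tree[S] = (tree[A], tree[B])
                end
            end
            A = (A - 1) & S
        end
    end
    return cost[full], tree[full]
end

# A[i,j] * B[j,k] * v[k] with all dimensions 100
println(optimal_order([[:i, :j], [:j, :k], [:k]], Dict(:i => 100, :j => 100, :k => 100)))
```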
  

In [53]:
ex=:(result[-1,-2,-3,-4,-5,-6] := 
        W1[-1,1,2]*W2[-2,3,4]*W3[-3,5,6]*
        U1[2,3,7,8]*U2[4,5,9,10]*
        h[7,8,9,11,12,13]*
        conj(U1)[14,15,11,12]*conj(U2)[16,17,13,10]*
        conj(W1)[-4,1,14]*conj(W2)[-5,15,16]*conj(W3)[-6,17,6])
Meta.show_sexpr(ex)


(:(:=), (:ref, :result, -1, -2, -3, -4, -5, -6), (:call, :*, (:ref, :W1, -1, 1, 2), (:ref, :W2, -2, 3, 4), (:ref, :W3, -3, 5, 6), (:ref, :U1, 2, 3, 7, 8), (:ref, :U2, 4, 5, 9, 10), (:ref, :h, 7, 8, 9, 11, 12, 13), (:ref, (:call, :conj, :U1), 14, 15, 11, 12), (:ref, (:call, :conj, :U2), 16, 17, 13, 10), (:ref, (:call, :conj, :W1), -4, 1, 14), (:ref, (:call, :conj, :W2), -5, 15, 16), (:ref, (:call, :conj, :W3), -6, 17, 6)))

In [66]:
@optimalcontractiontree !(1,5) W1[-1,1,2]*W2[-2,3,4]*W3[-3,5,6]*
        U1[2,3,7,8]*U2[4,5,9,10]*
        h[7,8,9,11,12,13]*
        conj(U1)[14,15,11,12]*conj(U2)[16,17,13,10]*
        conj(W1)[-4,1,14]*conj(W2)[-5,15,16]*conj(W3)[-6,17,6]


LoadError: MethodError: Cannot `convert` an object of type TensorOperations.Power{:χ,Int64} to an object of type TensorOperations.Power{:chi,Int64}
This may have arisen from a call to the constructor TensorOperations.Power{:chi,Int64}(...),
since type constructors fall back to convert methods.

In [55]:
tic()
@optimalcontractiontree W1[-1,1,2]*W2[-2,3,4]*W3[-3,5,6]*
        U1[2,3,7,8]*U2[4,5,9,10]*
        h[7,8,9,11,12,13]*
        conj(U1)[14,15,11,12]*conj(U2)[16,17,13,10]*
        conj(W1)[-4,1,14]*conj(W2)[-5,15,16]*conj(W3)[-6,17,6]
toc()


elapsed time: 0.004234351 seconds


0.004234351

In [65]:
using TensorOperations
function f1(W1, W2, W3, U1, U2, h)
    @tensor result[:] := 
        W1[-1,1,2]*W2[-2,3,4]*W3[-3,5,6]*
        U1[2,3,7,8]*U2[4,5,9,10]*
        h[7,8,9,11,12,13]*
        conj(U1)[14,15,11,12]*conj(U2)[16,17,13,10]*
        conj(W1)[-4,1,14]*conj(W2)[-5,15,16]*conj(W3)[-6,17,6]
    return result
end
function f2(W1, W2, W3, U1, U2, h)
    @tensoropt result[-1,-2,-3,-4,-5,-6] := 
        W1[-1,1,2]*W2[-2,3,4]*W3[-3,5,6]*
        U1[2,3,7,8]*U2[4,5,9,10]*
        h[7,8,9,11,12,13]*
        conj(U1)[14,15,11,12]*conj(U2)[16,17,13,10]*
        conj(W1)[-4,1,14]*conj(W2)[-5,15,16]*conj(W3)[-6,17,6]
    return result
end
Meta.show_sexpr(macroexpand(:(@tensoropt result[-1,-2,-3,-4,-5,-6] := 
        W1[-1,1,2]*W2[-2,3,4]*W3[-3,5,6]*
        U1[2,3,7,8]*U2[4,5,9,10]*
        h[7,8,9,11,12,13]*
        conj(U1)[14,15,11,12]*conj(U2)[16,17,13,10]*
        conj(W1)[-4,1,14]*conj(W2)[-5,15,16]*conj(W3)[-6,17,6])))

(:(=), :result, (:call, :(TensorOperations.deindexify), (:call, :*, (:call, :*, (:call, :(TensorOperations.indexify), (:call, :conj, :W1), (:call, (:curly, :(TensorOperations.Indices), (-4, 1, 14)))), (:call, :(TensorOperations.indexify), :W1, (:call, (:curly, :(TensorOperations.Indices), (-1, 1, 2))))), (:call, :*, (:call, :*, (:call, :(TensorOperations.indexify), :W3, (:call, (:curly, :(TensorOperations.Indices), (-3, 5, 6)))), (:call, :(TensorOperations.indexify), (:call, :conj, :W3), (:call, (:curly, :(TensorOperations.Indices), (-6, 17, 6))))), (:call, :*, (:call, :*, (:call, :(TensorOperations.indexify), :W2, (:call, (:curly, :(TensorOperations.Indices), (-2, 3, 4)))), (:call, :(TensorOperations.indexify), :U2, (:call, (:curly, :(TensorOperations.Indices), (4, 5, 9, 10))))), (:call, :*, (:call, :*, (:call, :(TensorOperations.indexify), (:call, :conj, :U2), (:call, (:curly, :(TensorOperations.Indices), (16, 17, 13, 10)))), (:call, :(TensorOperations.indexify), (:call, :conj, :W2),

In [57]:
x=5;
W1=W2=W3=randn(x,x,x);
U1=U2=randn(x,x,x,x);
h=randn(x,x,x,x,x,x);

In [58]:
@time result=f1(W1, W2, W3, U1, U2, h);
@time result2=f2(W1, W2, W3, U1, U2, h);

  2.078769 seconds (1.95 M allocations: 2.404 GiB, 12.24% gc time)
  0.033222 seconds (7.19 k allocations: 3.434 MiB)


In [59]:
@time result=f1(W1, W2, W3, U1, U2, h);
@time result2=f2(W1, W2, W3, U1, U2, h);

  2.449243 seconds (1.94 M allocations: 2.403 GiB, 19.23% gc time)
  0.001907 seconds (922 allocations: 3.090 MiB)


In [60]:
vecnorm(result-result2)/vecnorm(result)

8.32871813299754e-16

## Outlook
Possible features that might be added to `TensorOperations.jl` in the future:
* More flexible index notation
    * to allow for a mixed combination with manual loops in order to take slices
    * to automatically specify reshapes: e.g. 
    ```julia
    C[(a,b),(c,d)] = A[a,c,e]*B[b,d,e]
    Q,R = qr(C)
    E[a,f,b] = Q[(a,b),g]*D[g,f]
    ```
* Multi-threading support? Currently being implemented in Julia
* GPU support? Also being developed in Julia
* Linking to TBLIS, HPTT, ...
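Until such syntax exists, the reshape example can be written by hand with explicit loops, `reshape`, `permutedims` and `qr` from the `LinearAlgebra` standard library. A sketch with arbitrary small dimensions (all sizes and the array `D` are made up for illustration); in Julia's column-major layout the grouped index `(a,b)` corresponds to row `a + na*(b-1)`.

```julia
using LinearAlgebra

na, nb, nc, nd, ne, nf = 2, 3, 4, 5, 6, 2
A = randn(na, nc, ne); B = randn(nb, nd, ne)

# C[a,b,c,d] = Σ_e A[a,c,e] * B[b,d,e], written as explicit loops
C = zeros(na, nb, nc, nd)
for a in 1:na, b in 1:nb, c in 1:nc, d in 1:nd, e in 1:ne
    C[a, b, c, d] += A[a, c, e] * B[b, d, e]
end

# C[(a,b),(c,d)]: rows group (a,b), columns group (c,d); a varies fastest
Cmat = reshape(C, na * nb, nc * nd)
F = qr(Cmat)
Q = Matrix(F.Q)                       # here na*nb ≤ nc*nd, so Q is square
D = randn(size(Q, 2), nf)

# E[a,f,b] = Σ_g Q[(a,b),g] * D[g,f]: multiply, split the row index, then permute
E = permutedims(reshape(Q * D, na, nb, nf), (1, 3, 2))
```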

## Conclusions
* Julia is a promising language for developing tensor algorithms (and much more)
* I would like to see all of BLAS and LAPACK (with a modernized interface) written in pure Julia, with perhaps a small amount of injected LLVM / assembly code (`Base.llvmcall("...")`) for the holy quadruple of types (real and complex, single and double precision), while also allowing completely generic user types (finite fields, ...)