<a href="https://colab.research.google.com/github/Erickrus/llm/blob/main/einsum_is_all_you_need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**EINSUM IS ALL YOU NEED**

EINSTEIN SUMMATION IN DEEP LEARNING

pytorch: https://pytorch.org/docs/stable/generated/torch.einsum.html

tutorial url: https://rockt.github.io/2018/04/30/einsum

– Tim Rocktäschel, 30/04/2018 – *updated 02/05/2018*

When talking to colleagues I realized that not everyone knows about *einsum*, my favorite function for developing deep learning models.
This post is trying to change that once and for all! :)
Einstein summation (einsum) is implemented in [numpy](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html), as well as deep learning libraries such as [TensorFlow](https://www.tensorflow.org/api_docs/python/tf/einsum) and, thanks to [Thomas Viehmann](https://github.com/pytorch/pytorch/pull/6307), recently also [PyTorch](http://pytorch.org/docs/master/torch.html?highlight=torch%20max#torch.einsum).
For background reading on einsum, I recommend the excellent blog posts by [Olexa Bilaniuk](https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/) and [Alex Riley](http://ajcr.net/Basic-guide-to-einsum/).
While their posts discuss einsum in the context of numpy, I am going to illustrate how einsum is extremely useful for writing elegant PyTorch/TensorFlow models.[^1]




## 1 Einsum Notation[^2]

If you are anything like me, you find it difficult to remember the names and signatures of all the different functions in PyTorch/TensorFlow for calculating dot products, outer products, transposes and matrix-vector or matrix-matrix multiplications.
Einsum notation is an elegant way to express all of these, as well as complex operations on tensors, using essentially a domain-specific language.  
This has benefits beyond not having to memorize or regularly look up specific library functions.
Once you understand and make use of einsum, you will be able to write more concise and efficient code more quickly.
When not using einsum it is easy to introduce unnecessary reshaping and transposing of tensors, as well as intermediate tensors that could be omitted.
Furthermore, domain-specific languages like einsum can sometimes be compiled to high-performing code, and an einsum-like domain-specific language is in fact the basis for the recently introduced [Tensor Comprehensions](http://pytorch.org/2018/03/05/tensor-comprehensions.html)[^3] in PyTorch which automatically generate GPU code and auto-tune that code for specific input sizes.
In addition, projects like [opt einsum](https://github.com/dgasmith/opt_einsum) and [tf einsum opt](https://github.com/Bihaqo/tf_einsum_opt) can be used to optimize tensor contraction order of einsum expressions.[^4]
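To make "contraction order" a little more concrete, here is a minimal sketch that uses numpy's built-in `np.einsum_path` and the `optimize` argument of `np.einsum`, not opt einsum or tf einsum opt themselves. The shapes and variable names are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# A chain of three matrices; the shapes are made up for this example.
A = rng.random((2, 50))
B = rng.random((50, 50))
C = rng.random((50, 2))

# Ask numpy for a pairwise contraction order. `path` is a list that starts
# with the string 'einsum_path'; `report` is a human-readable summary.
path, report = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='greedy')
print(report)

# Evaluate once with the precomputed path and once without optimization.
optimized = np.einsum('ij,jk,kl->il', A, B, C, optimize=path)
unoptimized = np.einsum('ij,jk,kl->il', A, B, C)
```

Both calls compute the same result; only the order in which the operands are contracted differs.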

Let's say we want to multiply two matrices **A** in $ℝ^{I × K}$ and **B** in $ℝ^{K × J}$ followed by calculating the sum of each column resulting in a vector **c** in $ℝ^J$. Using Einstein summation notation, we can write this as

$ c_j = \sum_i\sum_k A_{ik}B_{kj} = A_{ik}B_{kj} $

which specifies how each element $c_j$ of **c** is calculated by multiplying values in the row vectors **$A_{i:}$** with values in the column vector **$B_{:j}$** and summing them up. Note that in Einstein notation the summation Sigmas can be dropped, as we implicitly sum over repeated indices (k in this example) and over indices not mentioned in the output specification (i in this example).
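As a quick check of this formula, the following sketch computes the same quantity in numpy, once with einsum and once with an explicit matrix product followed by a sum. The sizes I = 4, K = 3, J = 5 are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((4, 3))   # I x K
B = rng.random((3, 5))   # K x J

# k is repeated (summed), i is missing from the output (summed), j is kept.
c = np.einsum('ik,kj->j', A, B)

# Reference: matrix product, then sum over the row index i.
c_ref = (A @ B).sum(axis=0)
```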

So far so good, but we can also express more basic operations using einsum. For instance, calculating the dot product of two vectors **a**, **b** in $ℝ^I$ can be written as

$ c = \sum_i a_i b_i = a_i b_i. $
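In numpy this dot product might look as follows; the two vectors are made-up values chosen so the result is easy to verify by hand.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])

# The repeated index i is summed; the empty output means a scalar result.
c = np.einsum('i,i->', a, b)
print(c)  # 1*4 + 2*5 + 3*6 = 32.0
```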

A problem that I encounter often in deep learning is applying a transformation to vectors in a higher-order tensor. For example, I might have a tensor that contains T-long sequences of K-dimensional word vectors for N training examples in a batch and I want to project the word vectors to a different dimension Q. Let **T** in $ℝ^{N × T × K}$ be an order-3 tensor where the first dimension corresponds to the batch, the second dimension to the sequence length, and the last dimension to the word vectors. In addition, let **W** in $ℝ^{K × Q}$ be a projection matrix. The desired computation can be expressed using einsum

$ C_{ntq} = \sum_k T_{ntk}W_{kq} = T_{ntk}W_{kq}. $
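A small numpy sketch of this projection, compared against the reshape-based version one might write without einsum. The sizes are arbitrary, and the tensor the text calls **T** is named `X` here to avoid clashing with the sequence length.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, K, Q = 2, 5, 4, 3          # batch, sequence length, input dim, output dim
X = rng.random((N, T, K))        # the order-3 tensor called T in the text
W = rng.random((K, Q))           # projection matrix

# Contract the shared index k; n, t and q survive in the output.
C = np.einsum('ntk,kq->ntq', X, W)

# The same projection written with explicit reshapes, for comparison.
C_ref = (X.reshape(N * T, K) @ W).reshape(N, T, Q)
```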

As a final example, say you are given an order-4 tensor **T** in $ℝ^{N × T × K × M}$ and you are supposed to project vectors in the 3rd dimension to Q using the projection matrix from before. However, let's say I also ask you to sum over the 2nd dimension and transpose the first and last dimension in the result, yielding a tensor **C** in $ℝ^{M × Q × N}$.[^5] Einsum to the rescue!

$ C_{mqn} = \sum_t\sum_k T_{ntkm}W_{kq} = T_{ntkm}W_{kq}. $

Note that transposing the result of the tensor contraction is achieved by swapping n with m in the output ($C_{mqn}$ instead of $C_{nqm}$).
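The order-4 example can be sketched in numpy as below, with a step-by-step reference that sums, contracts, and permutes explicitly. Sizes and the name `X` (for the tensor called **T** in the text) are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, K, M, Q = 2, 5, 4, 6, 3
X = rng.random((N, T, K, M))     # the order-4 tensor called T in the text
W = rng.random((K, Q))           # projection matrix

# t and k are absent from the output, so both are summed over;
# writing the output as 'mqn' also reorders the remaining axes.
C = np.einsum('ntkm,kq->mqn', X, W)

# Step-by-step reference: sum over t, contract k, then permute axes.
summed = X.sum(axis=1)                                 # (N, K, M)
projected = np.tensordot(summed, W, axes=([1], [0]))   # (N, M, Q)
C_ref = projected.transpose(1, 2, 0)                   # (M, Q, N)
```

The einsum version needs no intermediate tensors in user code, which is the point the section is making.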


## 2 All you Need: Einsum in numpy, PyTorch, and TensorFlow

Einsum is implemented in [numpy](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html) via `np.einsum`, in [PyTorch](http://pytorch.org/docs/master/torch.html?highlight=torch%20max#torch.einsum) via `torch.einsum`, and in [TensorFlow](https://www.tensorflow.org/api_docs/python/tf/einsum) via `tf.einsum`.[^6] All three einsum functions share the same signature `einsum(equation, operands)` where `equation` is a string representing the Einstein summation and `operands` is a sequence of tensors.[^7] The examples above can all be written using an equation string. For instance, our first example $c_j = \sum_i\sum_k A_{ik}B_{kj}$ can be written as the equation string "`ik,kj -> j`". Note that the naming of the indices (i, j, k) is arbitrary but it needs to be used consistently.


What's great about having einsum not only in numpy but also in PyTorch and TensorFlow is that it can be used in arbitrary computation graphs for neural network architectures and that we can backpropagate through it.
A typical call to einsum has the following form
$
result = einsum("{\color{red}\square\square},{\color{purple}\square\square\square},{\color{blue}\square\square}\,\text{->}\,{\color{green}\square\square}", {\color{red}\text{arg1}}, {\color{purple}\text{arg2}}, {\color{blue}\text{arg3}})
$
where $ \square $ is a placeholder for a character identifying a tensor dimension.

From this equation string, we can infer that $ \color{red}\text{arg1} $ and $ \color{blue}\text{arg3} $ are matrices, $ \color{purple}\text{arg2} $ is an order-3 tensor, and that the $\color{green}\textbf{result}$ of this einsum operation is a matrix.
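One concrete instance of this template, as a sketch: the index letters and shapes below are my own choice, picked only so that neighbouring operands share an index.

```python
import numpy as np

rng = np.random.default_rng(0)
arg1 = rng.random((2, 3))        # matrix, indices i j
arg2 = rng.random((3, 4, 5))     # order-3 tensor, indices j k l
arg3 = rng.random((5, 6))        # matrix, indices l m

# j, k and l do not appear after '->', so they are summed over.
result = np.einsum('ij,jkl,lm->im', arg1, arg2, arg3)

# Reference built from simpler steps.
step = np.tensordot(arg1, arg2, axes=([1], [0]))   # (i, k, l)
result_ref = step.sum(axis=1) @ arg3               # (i, m)
```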

Note that einsum works with a variable number of inputs.
In the example above, einsum specifies an operation on three arguments, but it can also be used for operations involving one, two, or more than three arguments.
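To illustrate the variable number of inputs, here is a short numpy sketch with a single operand and with four operands. All shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((3, 4))

# One operand: drop indices from the output to sum over them.
total = np.einsum('ij->', A)        # sum of all entries
col_sums = np.einsum('ij->j', A)    # sum over rows, one value per column

# Four operands: a chain of matrix products in a single call.
B = rng.random((4, 5))
C = rng.random((5, 2))
D = rng.random((2, 3))
chain = np.einsum('ab,bc,cd,de->ae', A, B, C, D)
```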

Einsum is best learned by studying examples, so let's go through some examples that correspond to library functions which are used in many deep learning models. The cells below use numpy; the same equation strings can be passed to `torch.einsum`.


FOOTNOTES:

1
My examples use PyTorch, but translating them to TensorFlow is trivial.

2
The first version of this post was incorrectly using a summation Sigma which is not Einstein notation but classical notation. Thanks to Christian Wolf and reddit/ML for pointing this out.

3
Vasilache, Zinenko, Theodoridis, Goyal, DeVito, Moses, Verdoolaege, Adams and Cohen. Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions. arXiv preprint arXiv:1802.04730. 2018

4
Thanks to Stephan Hoyer and Alexander Novikov for the pointers.

5
Thanks to Blammar for pointing out a previous error.

6
Thanks to Martin Trapp for pointing out that there is also a Julia implementation.

7
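In numpy and TensorFlow, operands can be a variable-length argument list, whereas in PyTorch at the time of writing they had to be passed as a list. More recent PyTorch releases also accept a variable-length argument list.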

8
Farquhar, Rocktäschel, Igl and Whiteson. TreeQN and ATreeC: Differentiable Tree-Structured Models for Deep Reinforcement Learning. in: International Conference on Learning Representations (ICLR). 2018

9
He, Zhang, Ren and Sun. Deep Residual Learning for Image Recognition. in: 2016 IEEE Conference on Computer Vision and Pattern Recognition, CVPR. 2016

10
Rocktäschel, Grefenstette, Hermann, Kocisky and Blunsom. Reasoning about Entailment with Neural Attention. in: International Conference on Learning Representations (ICLR). 2016

1. **Matrix Scalar Multiplication:**

   If you have a matrix $A$ and you want to multiply it by a scalar $k$, the result $B$ is obtained by multiplying each element of $A$ by $k$:

   $
   B_{ij} = k \cdot A_{ij}
   $


In [None]:
#@title numpy example
import numpy as np

A = np.round(np.random.uniform(size=[4, 3]) * 5, 0)
k = np.round(np.random.uniform(size=1) * 3, 0)
B = np.einsum('ij,k->ij', A, k)

print("Matrix Scalar Multiplication")
print("A{0}: \n{1}".format(A.shape, A))
print("k{0}: {1}".format(k.shape, k))
print(("einsum: %s,%s->%s" % (str(A.shape), str(k.shape), str(B.shape))).replace(" ",""))
print("B{0}: \n{1}".format(B.shape, B))

Matrix Scalar Multiplication
A(4, 3): 
[[4. 2. 4.]
 [1. 1. 2.]
 [1. 4. 4.]
 [3. 0. 4.]]
k(1,): [1.]
einsum:(4,3),(1,)->(4,3)
B(4, 3): 
[[4. 2. 4.]
 [1. 1. 2.]
 [1. 4. 4.]
 [3. 0. 4.]]
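The cell above stores the scalar in a length-1 array with its own index `k`, which einsum then sums over. An alternative sketch, assuming the scalar is a plain 0-dimensional value, uses an empty subscript for it. The matrix and the value of `k` here are made up.

```python
import numpy as np

A = np.array([[4.0, 2.0, 4.0],
              [1.0, 1.0, 2.0]])
k = 3.0

# The empty subscript before the comma marks a 0-dimensional (scalar) operand.
B = np.einsum(',ij->ij', k, A)
print(B)
```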


2. **Matrix Multiplication (Dot Product):**

   Matrix multiplication involves taking the dot product of rows and columns. If you have matrices $A$ and $B$ with dimensions $m \times n$ and $n \times p$, respectively, the resulting matrix $C$ will have dimensions $m \times p$. The element at position $(i, j)$ in $C$ is obtained by summing the products of corresponding elements in the $i$-th row of $A$ and the $j$-th column of $B$:

   $
   C_{ij} = \sum_{k=1}^{n} A_{ik} \cdot B_{kj}
   $


In [None]:
#@title numpy example
import numpy as np

A = np.round(np.random.uniform(size=[4, 3]) * 5, 0)
B = np.round(np.random.uniform(size=[3, 4]) * 5, 0)

C = np.einsum('ik,kj->ij', A, B)

print("Matrix Multiplication (Dot Product)")
print("A{0}: \n{1}".format(A.shape, A))
print("B{0}: \n{1}".format(B.shape, B))
print(("einsum: %s,%s->%s" % (str(A.shape), str(B.shape), str(C.shape))).replace(" ",""))
print("C{0}: \n{1}".format(C.shape, C))


Matrix Multiplication (Dot Product)
A(4, 3): 
[[3. 2. 2.]
 [0. 3. 5.]
 [1. 2. 3.]
 [3. 2. 4.]]
B(3, 4): 
[[2. 1. 5. 4.]
 [1. 2. 5. 4.]
 [2. 4. 2. 3.]]
einsum:(4,3),(3,4)->(4,4)
C(4, 4): 
[[12. 15. 29. 26.]
 [13. 26. 25. 27.]
 [10. 17. 21. 21.]
 [16. 23. 33. 32.]]


3. **Hadamard Product (Element-wise Multiplication):**

   If you have two matrices $A$ and $B$ of the same dimensions, the Hadamard product $C$ is obtained by multiplying corresponding elements:

   $
   C_{ij} = A_{ij} \cdot B_{ij}
   $


In [None]:
#@title numpy example

import numpy as np

A = np.round(np.random.uniform(size=[4, 3]) * 5, 0)
B = np.round(np.random.uniform(size=[4, 3]) * 5, 0)

C = np.einsum('ij,ij->ij', A, B)

print("Hadamard Product (Element-wise Multiplication)")
print("A{0}: \n{1}".format(A.shape, A))
print("B{0}: \n{1}".format(B.shape, B))
print(("einsum: %s,%s->%s" % (str(A.shape), str(B.shape), str(C.shape))).replace(" ",""))
print("C{0}: \n{1}".format(C.shape, C))



Hadamard Product (Element-wise Multiplication)
A(4, 3): 
[[4. 2. 4.]
 [0. 2. 5.]
 [5. 5. 4.]
 [3. 2. 1.]]
B(4, 3): 
[[3. 4. 2.]
 [2. 1. 3.]
 [4. 1. 4.]
 [3. 2. 4.]]
einsum:(4,3),(4,3)->(4,3)
C(4, 3): 
[[12.  8.  8.]
 [ 0.  2. 15.]
 [20.  5. 16.]
 [ 9.  4.  4.]]


4. **Transpose:**

   If you have a matrix $A$ with dimensions $m \times n$, the transpose $B$ is obtained by swapping its rows and columns. The element at position $(i, j)$ in $B$ is the element at position $(j, i)$ in $A$:

   $
   B_{ij} = A_{ji}
   $


In [None]:
#@title numpy example
import numpy as np

A = np.round(np.random.uniform(size=[4, 3]) * 5, 0)
B = np.einsum('ji->ij', A)

print("Transpose")
print("A{0}: \n{1}".format(A.shape, A))

print(("einsum: %s->%s" % (str(A.shape), str(B.shape))).replace(" ",""))
print("B{0}: \n{1}".format(B.shape, B))

Transpose
A(4, 3): 
[[4. 4. 4.]
 [4. 3. 1.]
 [4. 2. 1.]
 [2. 4. 1.]]
einsum:(4,3)->(3,4)
B(3, 4): 
[[4. 4. 4. 2.]
 [4. 3. 2. 4.]
 [4. 1. 1. 1.]]


5. **Trace:**

   The trace of a square matrix $A$ is the sum of its diagonal elements:

   $
   \text{tr}(A) = \sum_{i=1}^{n} A_{ii}
   $


In [None]:
#@title numpy example
import numpy as np

A = np.round(np.random.uniform(size=[3, 3]) * 5, 0)
trace_A = np.einsum('ii', A)

print("Trace")
print("A{0}: \n{1}".format(A.shape, A))

print(("einsum: %s->%s" % (str(A.shape), str(np.shape(trace_A)))).replace(" ",""))
print("trace_A{0}: \n{1}".format(np.shape(trace_A), trace_A))


Trace
A(3, 3): 
[[2. 1. 1.]
 [1. 1. 3.]
 [4. 1. 4.]]
einsum:(3,3)->()
trace_A(): 
7.0
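The same repeated index can also extract the diagonal instead of summing it, if `i` is kept in the output. A small sketch, reusing the matrix printed above:

```python
import numpy as np

A = np.array([[2.0, 1.0, 1.0],
              [1.0, 1.0, 3.0],
              [4.0, 1.0, 4.0]])

diagonal = np.einsum('ii->i', A)   # keep i: the diagonal entries
trace = np.einsum('ii->', A)       # drop i: their sum
```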


6. **Vector-Matrix Multiplication:**

**Row Vector by Matrix:**

Given a row vector $ \mathbf{v} $ and a matrix $ \mathbf{M} $, the result $ \mathbf{r} $ is obtained by taking the dot product of the row vector with each column of the matrix:

$ \mathbf{r}_i = \sum_j \mathbf{v}_j \cdot \mathbf{M}_{ji} $




In [None]:
#@title numpy example
import numpy as np

# Row vector by matrix
v = np.round(np.random.uniform(size=[3]) * 5, 0)
M = np.round(np.random.uniform(size=[3, 4]) * 5, 0)

r = np.einsum('j,ji->i', v, M)

print("Row vector by matrix")
print("M{0}: \n{1}".format(M.shape, M))
print("v{0}: \n{1}".format(v.shape, v))
print(("einsum: %s,%s->%s" % (str(v.shape), str(M.shape), str(r.shape))).replace(" ",""))
print("r{0}: \n{1}".format(r.shape, r))

Row vector by matrix
M(3, 4): 
[[1. 2. 1. 1.]
 [5. 1. 4. 2.]
 [1. 3. 3. 2.]]
v(3,): 
[1. 4. 1.]
einsum:(3,),(3,4)->(4,)
r(4,): 
[22.  9. 20. 11.]


**Matrix by Column Vector:**

Given a matrix $ \mathbf{M} $ and a column vector $ \mathbf{v} $, the result $ \mathbf{r} $ is obtained by taking the dot product of each row of the matrix with the column vector:

$ \mathbf{r}_i = \sum_j \mathbf{M}_{ij} \cdot \mathbf{v}_j $

In [None]:
#@title numpy example

import numpy as np

# Matrix by column vector
v = np.round(np.random.uniform(size=[4]) * 5, 0)
M = np.round(np.random.uniform(size=[3, 4]) * 5, 0)

r = np.einsum('ij,j->i', M, v)

print("Col vector by matrix")
print("M{0}: \n{1}".format(M.shape, M))
print("v{0}: \n{1}".format(v.shape, v))
print(("einsum: %s,%s->%s" % (str(M.shape), str(v.shape), str(r.shape))).replace(" ",""))
print("r{0}: \n{1}".format(r.shape, r))

Col vector by matrix
M(3, 4): 
[[2. 3. 2. 3.]
 [4. 1. 2. 1.]
 [4. 5. 1. 0.]]
v(4,): 
[2. 2. 0. 4.]
einsum:(3,4),(4,)->(3,)
r(3,): 
[22. 14. 18.]
