# Einsum Tutorial 

In [1]:
import numpy as np

In [2]:
def prettyPrint(x, precision=4, linewidth=100):
    """Print an array with fixed-point float formatting, followed by a blank line.

    Args:
        x: Array (or array-like) to print.
        precision: Digits shown after the decimal point for float entries
            (default 4, matching the original hard-coded format).
        linewidth: Characters per output line before NumPy wraps the output.

    Returns:
        None. Output goes to stdout.
    """
    # An explicit 'float' formatter overrides `precision`, so build the
    # formatter from the same value to keep the two settings in agreement.
    float_fmt = '{{:0.{}f}}'.format(precision).format
    with np.printoptions(precision=precision, suppress=True,
                         formatter={'float': float_fmt}, linewidth=linewidth):
        print(x)
    # Blank line separates consecutive printed arrays.
    print("")

## Rules
* Free indices: indices that appear in the output specification.
* Summation indices: all other indices, i.e. those that appear in the input arguments but NOT in the output specification.

1. Repeating a letter across different inputs means the corresponding elements
   are multiplied together along that axis.
2. Omitting a letter from the output means the products are summed over that axis.
3. The remaining (unsummed) axes can be returned in any order.

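Example: `np.einsum('ij,jk -> ik', A, B)`. The left side `ij,jk` labels the axes of `A` (`i`, `j`) and of `B` (`j`, `k`). The right side `ik` gives the axes of the output. Because `j` is repeated across the inputs and omitted from the output, those elements are multiplied and summed over `j`, so this computes the matrix product `A @ B`.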


In [3]:
# Small integer operands shared by the examples below:
# x is a 2x2 matrix and y is a 2x1 column.
x = np.array([[1, 2],
              [3, 4]])
y = np.array([[5],
              [6]])
for caption, arr in (('x shape: ', x), ('y shape: ', y)):
    print(caption, arr.shape)
    prettyPrint(arr)

x shape:  (2, 2)
[[1 2]
 [3 4]]

y shape:  (2, 1)
[[5]
 [6]]



## Permute

In [4]:
# No index is summed; the output just lists the input axes in swapped order (transpose).
z = np.einsum('ij -> ji', x)
np.testing.assert_array_equal(z, x.T)
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  (2, 2)
[[1 3]
 [2 4]]



## Sum all 

In [5]:
# Both i and j are omitted from the output, so every element is summed into a 0-d result.
z = np.einsum('ij -> ', x)
np.testing.assert_array_equal(z, x.sum())
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  ()
10



## Row sums

In [6]:
# j is omitted from the output, so each row is summed across its columns.
z = np.einsum('ij -> i ', x)
np.testing.assert_array_equal(z, x.sum(axis=1))
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  (2,)
[3 7]



## Column sums

In [7]:
# i is omitted from the output, so each column is summed down its rows.
z = np.einsum('ij -> j ', x)
np.testing.assert_array_equal(z, x.sum(axis=0))
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  (2,)
[4 6]



## Matrix-Vector Multiplication

In [8]:
z = np.einsum('ij,ij -> i ', x, y)
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  (2,)
[15 42]



## Dot product 

In [9]:
# Both indices are repeated and omitted from the output, so the element-wise
# products are summed over everything: for the column y this is y . y.
z = np.einsum('ij,ij ->  ', y, y)
np.testing.assert_array_equal(z, np.vdot(y, y))
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  ()
61



## Hadamard (element-wise) product

In [10]:
# Indices repeated across inputs but all kept in the output: element-wise product, no summation.
z = np.einsum('ij,ij -> ij', x, x)
np.testing.assert_array_equal(z, x * x)
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  (2, 2)
[[ 1  4]
 [ 9 16]]



## Outer Product

In [11]:
b = np.array([1,5])
print('b shape: ',b.shape)
prettyPrint(b)

# i and j are distinct indices and both kept, so nothing is summed: z_ij = b_i * b_j.
z = np.einsum('i,j -> ij', b, b)
np.testing.assert_array_equal(z, np.outer(b, b))
print('z shape: ',z.shape)
prettyPrint(z)

b shape:  (2,)
[1 5]

z shape:  (2, 2)
[[ 1  5]
 [ 5 25]]



## Batch Matrix Multiplication 

In [12]:
# Seeded generator so the random operands (and printed results) are reproducible on re-run.
rng = np.random.default_rng(0)
c = rng.uniform(0, 1, (2, 3, 2))
print('c shape: ',c.shape)
prettyPrint(c)

d = rng.uniform(0, 1, (2, 2, 3))
print('d shape: ',d.shape)
prettyPrint(d)

# i is shared and kept (the batch axis); k is repeated and omitted, so it is
# contracted: z[i] = c[i] @ d[i] for each batch entry.
z = np.einsum('ijk,ikl -> ijl', c, d)
np.testing.assert_allclose(z, c @ d)
print('z shape: ',z.shape)
prettyPrint(z)

c shape:  (2, 3, 2)
[[[0.8832 0.3920]
  [0.4225 0.8603]
  [0.5957 0.9276]]

 [[0.6047 0.5411]
  [0.3407 0.0153]
  [0.6460 0.0249]]]

d shape:  (2, 2, 3)
[[[0.1002 0.1918 0.4342]
  [0.4898 0.6019 0.4387]]

 [[0.6598 0.0907 0.1521]
  [0.6551 0.3510 0.4693]]]

z shape:  (2, 3, 3)
[[[0.2806 0.4053 0.5554]
  [0.4637 0.5988 0.5608]
  [0.5141 0.6725 0.6656]]

 [[0.7534 0.2447 0.3459]
  [0.2348 0.0363 0.0590]
  [0.4426 0.0673 0.1099]]]



## Matrix diagonal 

In [13]:
# Repeating i within one operand selects the elements where both indices match (the diagonal).
z = np.einsum('ii ->i', x)
np.testing.assert_array_equal(z, np.diag(x))
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  (2,)
[1 4]



## Matrix trace 

In [14]:
# Diagonal as above, but i is also omitted from the output, so the diagonal is summed (trace).
z = np.einsum('ii ->', x)
np.testing.assert_array_equal(z, np.trace(x))
print('z shape: ',z.shape)
prettyPrint(z)

z shape:  ()
5

