In [2]:
import numpy as np
import tensorflow as tf

In [3]:
A = np.array([
    [1, 2, 3],
    [4, 5, 6],
    [0, -1, -2]
])

B = np.array([
    [0, 0, 1],
    [2, 3, 4], 
    [-1, 0, 2]
])

In [4]:
print('Matmul:')
print(np.matmul(A, B))
print('Einsum:')
print(np.einsum('ij,jk -> ik', A, B))

Matmul:
[[ 1  6 15]
 [ 4 15 36]
 [ 0 -3 -8]]
Einsum:
[[ 1  6 15]
 [ 4 15 36]
 [ 0 -3 -8]]


In [5]:
print('Normal:')
print(A*B)
print('Einsum:')
print(np.einsum('ij,ij -> ij', A, B))

Normal:
[[ 0  0  3]
 [ 8 15 24]
 [ 0  0 -4]]
Einsum:
[[ 0  0  3]
 [ 8 15 24]
 [ 0  0 -4]]


In [6]:
print('Normal:')
print(A.T)
print('Einsum:')
print(np.einsum('ij -> ji', A))

Normal:
[[ 1  4  0]
 [ 2  5 -1]
 [ 3  6 -2]]
Einsum:
[[ 1  4  0]
 [ 2  5 -1]
 [ 3  6 -2]]


In [7]:
np.einsum('ij,jk -> ki', A, B)    # first matmul then transpose 

array([[ 1,  4,  0],
       [ 6, 15, -3],
       [15, 36, -8]])

In [8]:
A_3d = np.array([
    [
        [2, 5, 5, 2],
        [2, -2, 2, 3],
        [1, 5, 3, 0]
    ],
    [
        [1, 3, 1, 22],
        [0, 2, 2, 0],
        [1, 5, 4, 1]
    ]
])

B_3d = np.array([
    [
        [2, 5, 5, 2, 0],
        [2, -2, 2, 3, 1],
        [1, 5, 3, 0, 4],
        [-3, -4, 0, 7, 2]
    ],
    [
        [1, 3, 1, 22, 0],
        [0, 2, 2, 0, 1],
        [1, 5, 4, 1, 4],
        [-3, -4, 2, 1, 0]
    ]
])

In [9]:
print(A_3d.shape, B_3d.shape)

(2, 3, 4) (2, 4, 5)


(2, 3, 4)x(2, 4, 5) -> (2, 3, 5)

In [10]:
print('Normal:')
print(A_3d @ B_3d)
print('Einsum:')
print(np.einsum('bij,bjk -> bik', A_3d, B_3d))

Normal:
[[[ 13  17  35  33  29]
  [ -7  12  12  19  12]
  [ 15  10  24  17  17]]

 [[-64 -74  55  45   7]
  [  2  14  12   2  10]
  [  2  29  29  27  21]]]
Einsum:
[[[ 13  17  35  33  29]
  [ -7  12  12  19  12]
  [ 15  10  24  17  17]]

 [[-64 -74  55  45   7]
  [  2  14  12   2  10]
  [  2  29  29  27  21]]]


In [11]:
print('Normal:')
print(np.sum(A))
print('Einsum:')
print(np.einsum('ij -> ', A))
print(np.einsum('bij -> ', A_3d))

Normal:
18
Einsum:
18
70


In [12]:
print('Normal:')
print(np.sum(A, axis= 0))
print('Einsum:')
print(np.einsum('ij -> j', A))

Normal:
[5 6 7]
Einsum:
[5 6 7]


In [13]:
print('Normal:')
print(np.sum(A, axis= 1))
print('Einsum:')
print(np.einsum('ij -> i', A))

Normal:
[ 6 15 -3]
Einsum:
[ 6 15 -3]


In [14]:
Q = np.random.randn(32, 64, 512)
K = np.random.randn(32, 128, 512)

In [15]:
np.einsum('bij,bkj -> bik', Q, K).shape    # bij x bkj.T

(32, 64, 128)

In [16]:
A = np.random.randn(2, 4, 4, 2)
B = np.random.randn(2, 4, 4, 1)

In [17]:
np.einsum('bcij,bcik -> bcjk', A, B).shape

(2, 4, 2, 1)

In [18]:
np.einsum('bcij,bcik -> bckj', A, B).shape

(2, 4, 1, 2)

In [19]:
print('Normal:')
print(np.matmul(np.transpose(B, (0, 1, 3, 2)), A))
print('Einsum:')
print(np.einsum('bcij,bcik -> bcjk', A, B))

Normal:
[[[[-0.72771207 -1.93534818]]

  [[-1.318677   -2.02757867]]

  [[-0.51570906 -0.16685797]]

  [[-1.03833446 -0.54490717]]]


 [[[-5.19482103  1.5033355 ]]

  [[ 1.00128279  2.49137695]]

  [[-0.81537162  5.75194177]]

  [[-1.29919953 -3.80853147]]]]
Einsum:
[[[[-0.72771207]
   [-1.93534818]]

  [[-1.318677  ]
   [-2.02757867]]

  [[-0.51570906]
   [-0.16685797]]

  [[-1.03833446]
   [-0.54490717]]]


 [[[-5.19482103]
   [ 1.5033355 ]]

  [[ 1.00128279]
   [ 2.49137695]]

  [[-0.81537162]
   [ 5.75194177]]

  [[-1.29919953]
   [-3.80853147]]]]


In [21]:
tf.einsum('ijkl -> ', A)

<tf.Tensor: shape=(), dtype=float64, numpy=-0.870390895468884>