# Einsum Operator

# 1. Import Dependencies

In [1]:
import numpy as np

In [2]:
A = np.array([
    [2, 6, 5, 2],
    [2, -2, 2, 3],
    [1, 5, 4, 0],
])

B = np.array([
    [2, 9, 0, 3, 0],
    [3, 6, 8, -2, 2],
    [1, 3, 5, 0, 1],
    [3, 0, 2, 0, 5]
])

A.shape, B.shape

((3, 4), (4, 5))

In [3]:
np.matmul(A, B)

array([[33, 69, 77, -6, 27],
       [ 9, 12,  0, 10, 13],
       [21, 51, 60, -7, 14]])

In [4]:
np.einsum('ij,jk -> ik', A,B)

array([[33, 69, 77, -6, 27],
       [ 9, 12,  0, 10, 13],
       [21, 51, 60, -7, 14]])

In [5]:
C = np.array([
    [2, 6, 5, 2],
    [2, -2, 2, 3],
    [1, 5, 4, 0],
])

D = np.array([
    [2, 9, 0, 3],
    [3, 6, 8, -2],
    [1, 3, 5, 0],
])

print("Hardamond C := \n")
print(C*D, "\n")

print("Einsum C =: \n")
print(np.einsum('ij, ij -> ij', C,D))

Hardamond C := 

[[  4  54   0   6]
 [  6 -12  16  -6]
 [  1  15  20   0]] 

Einsum C =: 

[[  4  54   0   6]
 [  6 -12  16  -6]
 [  1  15  20   0]]


In [6]:
print("Transpose A =: \n")
print(A.T, "\n")

print("Einsum Transpose A =: \n")
print(np.einsum("ij -> ji", A))

Transpose A =: 

[[ 2  2  1]
 [ 6 -2  5]
 [ 5  2  4]
 [ 2  3  0]] 

Einsum Transpose A =: 

[[ 2  2  1]
 [ 6 -2  5]
 [ 5  2  4]
 [ 2  3  0]]


In [7]:
E = np.array([
   [[2, 6, 5, 2],
    [2, -2, 2, 3],
    [1, 5, 4, 0]],
   [[1, 3, 1, 22],
    [0, 2, 2, 0],
    [1, 5, 4, 1]]
])

F = np.array([
   [[2, 9, 0, 3, 0],
    [3, 6, 8, -2, 2],
    [1, 3, 5, 0, 1],
    [3, 0, 2, 0, 5]],
   
   [[1, 0, 0, 3, 0],
    [3, 0, 4, -2, 2],
    [1, 0, 2, 0, 0],
    [3, 0, 1, 1, 0]]
])

E.shape, F.shape

((2, 3, 4), (2, 4, 5))

In [8]:
print("Batch Multiplication G := \n")
print(np.matmul(E,F), "\n")

Batch Multiplication G := 

[[[33 69 77 -6 27]
  [ 9 12  0 10 13]
  [21 51 60 -7 14]]

 [[77  0 36 19  6]
  [ 8  0 12 -4  4]
  [23  0 29 -6 10]]] 



In [9]:
print("Einsum G := \n")
print(np.einsum('bij, bjk -> bik', E, F))

Einsum G := 

[[[33 69 77 -6 27]
  [ 9 12  0 10 13]
  [21 51 60 -7 14]]

 [[77  0 36 19  6]
  [ 8  0 12 -4  4]
  [23  0 29 -6 10]]]


In [10]:
A = np.array([
    [2, 6, 5, 2],
    [2, -2, 2, 3],
    [1, 5, 4, 0],
])

print('Axix 0 Sum A =: \n')
print(np.sum(A, axis=0), "\n")

print('Einsum A =: \n')
print(np.einsum("ij -> j", A))
print("------------------------------")

print('Axix 1 Sum A =: \n')
print(np.sum(A, axis=1), "\n")

print('Einsum A =: \n')
print(np.einsum("ij -> i", A))


Axix 0 Sum A =: 

[ 5  9 11  5] 

Einsum A =: 

[ 5  9 11  5]
------------------------------
Axix 1 Sum A =: 

[15  5 10] 

Einsum A =: 

[15  5 10]


In [11]:
# Q = batchsize, sequence_query, modelsize
# K = batch_size, sequence_key, modelsize

QUERY = np.random.randn(32, 64, 512) # bqm
KEY = np.random.randn(32, 128, 512) # bkm

In [16]:
np.einsum("bqm, bkm -> bqk", QUERY, KEY)

array([[[ 2.62377540e+00,  4.14949557e+01,  3.44319776e+01, ...,
          1.74242969e+01, -1.63258487e+01,  2.60614442e+01],
        [-3.74288092e+01, -3.78417038e+00, -3.20201611e+01, ...,
         -8.67733708e+00,  7.57202900e+00, -4.77030874e+01],
        [ 3.14401276e+01,  4.49124892e+01, -2.31896720e+01, ...,
         -3.83776424e+00,  2.28355437e+01, -8.02941736e+00],
        ...,
        [-4.58666657e+01,  3.20503102e+01, -2.39478731e+01, ...,
         -2.32809714e+01, -1.58657308e+01, -3.82159257e+01],
        [ 1.77999102e+01, -4.77133041e+00, -3.97629362e+01, ...,
          6.80074719e+00,  2.18073947e+01, -3.69441161e+00],
        [-1.10681529e+01,  2.55859610e+01, -1.92556179e+01, ...,
         -1.34218796e+01, -5.40021288e-01,  7.16250748e+00]],

       [[ 4.40450021e-01,  1.46991142e+01,  1.85541796e+01, ...,
         -2.51797976e+01,  2.21282142e+01, -7.00480774e+00],
        [-4.07380436e+00,  7.84714692e+00, -1.29933347e+01, ...,
         -1.70436961e+01, -3.13036683e

In [17]:
A = np.random.randn(2, 4, 4, 2) # bcij
B = np.random.randn(2, 4, 4, 1) # bcik

In [24]:
np.einsum('bcij, bcik -> bckj', A, B)

array([[[[ 0.86435269,  1.71804362]],

        [[ 1.97032939, -2.62070444]],

        [[ 0.20283039,  1.59711511]],

        [[ 1.45017914, -0.6235231 ]]],


       [[[ 0.13255149,  0.46472008]],

        [[-2.40781342,  2.85164692]],

        [[-1.38631939,  1.97734122]],

        [[-1.00945959, -1.63059331]]]])