In [2]:
import tensorflow as tf

In [3]:
# Seed the global generator so the draws below are repeatable.
tf.random.set_seed(10)

# Two 3x3 int32 matrices with entries sampled from [0, 9).
# Both calls pass the same op-level seed; the displayed tensors still differ.
t1 = tf.random.uniform(
    shape=(3, 3), minval=0, maxval=9, dtype=tf.int32, seed=5
)
t2 = tf.random.uniform(
    shape=(3, 3), minval=0, maxval=9, dtype=tf.int32, seed=5
)

(t1, t2)

(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
 array([[6, 1, 6],
        [4, 5, 5],
        [3, 5, 5]], dtype=int32)>,
 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
 array([[1, 7, 7],
        [2, 5, 5],
        [0, 4, 6]], dtype=int32)>)

In [4]:
# Matrix multiplication: j appears in both inputs but not in the output, so it
# is summed over; the result is indexed by (i, k). Same as tf.matmul(t1, t2).
tf.einsum("ij, jk -> ik", t1, t2)

<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[ 8, 71, 83],
       [14, 73, 83],
       [13, 66, 76]], dtype=int32)>

In [5]:
# Element-wise (Hadamard) product: no index is dropped from the output, so
# nothing is summed and each entry is t1[i, j] * t2[i, j].
tf.einsum("ij, ij -> ij", t1, t2)

<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[ 6,  7, 42],
       [ 8, 25, 25],
       [ 0, 20, 30]], dtype=int32)>

In [6]:
# Sum of all elements: the output side is empty, so both i and j are summed.
# Equivalent to tf.reduce_sum(t1).
tf.einsum("ij -> ", t1)

<tf.Tensor: shape=(), dtype=int32, numpy=40>

In [7]:
# Column sums: i is missing from the output, so it is summed over for each j.
# Equivalent to tf.reduce_sum(t1, axis=0).
tf.einsum("ij -> j", t1)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([13, 11, 16], dtype=int32)>

In [8]:
# Row sums: j is missing from the output, so it is summed over for each i.
# Equivalent to tf.reduce_sum(t1, axis=1).
tf.einsum("ij -> i", t1)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([13, 14, 13], dtype=int32)>

In [9]:
# Inputs for the batched matmul: a batch of two 3x5 matrices (t3) and a batch
# of two 5x2 matrices (t4). No dtype is passed, so the samples are float32
# values in [0, 10).
t3 = tf.random.uniform(shape=(2, 3, 5), minval=0, maxval=10, seed=5)
t4 = tf.random.uniform(shape=(2, 5, 2), minval=0, maxval=10, seed=5)

(t3, t4)

(<tf.Tensor: shape=(2, 3, 5), dtype=float32, numpy=
 array([[[9.894415  , 3.6855626 , 1.1616302 , 8.626854  , 4.758624  ],
         [4.570756  , 5.5284667 , 3.4956563 , 9.879353  , 3.0193317 ],
         [3.7229264 , 2.9353356 , 3.838569  , 8.439041  , 3.5416162 ]],
 
        [[0.08355021, 7.017478  , 6.5004754 , 5.003395  , 0.30138493],
         [1.469748  , 3.1486166 , 5.5211163 , 6.7279778 , 1.9601166 ],
         [1.105243  , 4.431342  , 0.65983534, 3.7309551 , 9.60889   ]]],
       dtype=float32)>,
 <tf.Tensor: shape=(2, 5, 2), dtype=float32, numpy=
 array([[[4.5276299e+00, 1.6004682e-01],
         [1.9785523e-01, 9.3377552e+00],
         [3.2783031e-01, 8.3442554e+00],
         [3.2186508e-04, 3.0256987e+00],
         [3.8291883e+00, 6.4413571e+00]],
 
        [[9.3078785e+00, 3.5569692e+00],
         [8.5994987e+00, 7.2223902e-01],
         [2.2673368e-01, 5.4778409e+00],
         [1.4468837e+00, 1.5051723e-01],
         [2.8195024e+00, 4.5141850e+00]]], dtype=float32)>)

In [10]:
# Sum of all elements again, this time for a rank-3 tensor: b, i and j are all
# missing from the output, so every axis is summed.
# Equivalent to tf.reduce_sum(t3).
tf.einsum("bij -> ", t3)

<tf.Tensor: shape=(), dtype=float32, numpy=134.36826>

In [11]:
# Batched matmul: b is kept on both inputs and on the output, so every batch
# entry gets its own (i, j) x (j, k) product.
# Here: (2, 3, 5) x (2, 5, 2) -> (2, 3, 2).
tf.einsum("bij, bjk -> bik", t3, t4)

<tf.Tensor: shape=(2, 3, 2), dtype=float32, numpy=
array([[[ 64.13272 , 102.44565 ],
        [ 34.49928 , 130.8642  ],
        [ 32.259438, 108.382095]],

       [[ 70.68743 ,  43.087658],
        [ 57.269737,  47.606705],
        [ 81.03494 ,  54.684155]]], dtype=float32)>

In [17]:
# First step of scaled dot-product attention: the raw scores Q @ K^T.
# b indexes the batch, q the query rows and k the key rows; m is shared by both
# inputs and absent from the output, so it is the axis that gets contracted.
q = tf.random.uniform(shape=(32, 64, 512))   # queries
k = tf.random.uniform(shape=(32, 128, 512))  # keys

# The full result would be tf.einsum("bqm, bkm -> bqk", q, k);
# only the first two batch entries are shown to keep the output short.
tf.einsum("bqm, bkm -> bqk", q, k)[:2]

<tf.Tensor: shape=(2, 64, 128), dtype=float32, numpy=
array([[[127.25105 , 121.83553 , 128.3844  , ..., 124.94056 ,
         119.92784 , 128.60434 ],
        [138.84723 , 136.54141 , 135.33191 , ..., 129.22368 ,
         129.08049 , 135.8829  ],
        [134.38132 , 129.30405 , 129.6083  , ..., 130.16269 ,
         120.84741 , 131.50003 ],
        ...,
        [128.86937 , 129.02179 , 124.95419 , ..., 123.75957 ,
         123.69379 , 129.42894 ],
        [125.92457 , 127.344696, 125.03459 , ..., 121.406235,
         121.61434 , 128.03055 ],
        [129.02173 , 124.84944 , 127.394485, ..., 123.89872 ,
         123.41255 , 129.51161 ]],

       [[132.25992 , 130.0847  , 126.11891 , ..., 132.88153 ,
         126.51277 , 134.9429  ],
        [129.32639 , 127.70118 , 125.5228  , ..., 131.28888 ,
         127.06064 , 135.20764 ],
        [126.740036, 131.41684 , 128.4274  , ..., 125.79219 ,
         129.82951 , 134.39114 ],
        ...,
        [128.09706 , 124.842636, 123.862144, ..., 131.

In [13]:
# Chunked vector-matrix product: for every chunk, a column from B is
# transposed and multiplied with the matching 4x2 matrix from A.

# Alternative B with shape (1, 4, 4), kept for reference:
# B = tf.constant([
#     [[1, 1, 1, 1],
#      [1, 1, 1, 1],
#      [1, 1, 1, 1],
#      [1, 1, 1, 1]]
# ])

# WARNING: to build a column by hand, give every value its own brackets and
# keep them in the same row (see B below). The nesting of the brackets
# determines the resulting shape, so it WILL make a difference.
# Or, just use tf.einsum("bij -> bji", B) to get the transpose (this equation
# applies to a rank-3 tensor such as the alternative B above).
# Shape (1, 4, 4, 1): one batch entry, 4 chunks, each a 4x1 column.
B = tf.constant([
    [[[  2], [ -3], [ 2], [ 5]],
     [[  3], [  9], [ 9], [ 6]],
     [[ 12], [  2], [ 3], [ 1]],
     [[  6], [ 1 ], [ 0], [ 1]]]
])
# tf.transpose(B, perm=[0,1,3,2]) -> (1, 4, 1, 4)
# First chunk worked by hand (transposed first column of B times the first
# 4x2 matrix of A):
#   2*2 + (-3)*3 + 2*1 + 5*3 =  4 -  9 + 2 + 15 = 12
#   2*9 + (-3)*6 + 2*3 + 5*0 = 18 - 18 + 6 +  0 = 6

# Shape (1, 4, 4, 2): one batch entry, 4 chunks, each a 4x2 matrix.
A = tf.constant([
    [[[2, 9], [3,  6], [1, 3], [3, 0]],
     [[0, 3], [8, -2], [5, 0], [2, 0]],
     [[0, 5], [2,  1], [1, 3], [5, 0]],
     [[0, 9], [8,  2], [1, 3], [2, 5]]]
])

# Index letters used in the equations below:
# b -> batch size
# c -> number of chunks

# Error: this equation pairs a (4, 1) with a (4, 2) matrix per chunk, but a
# plain matmul needs (1, 4) and (4, 2).
# tf.einsum("bcij, bcjk -> bcik", B, A)
# OK: i (the length-4 axis of both inputs) is summed over, which multiplies the
# transposed column by the matrix and yields one (1, 2) row per chunk.
tf.einsum("bcij, bcik -> bcjk", B, A)

<tf.Tensor: shape=(1, 4, 1, 2), dtype=int32, numpy=
array([[[[ 12,   6]],

        [[129,  -9]],

        [[ 12,  71]],

        [[ 10,  61]]]], dtype=int32)>

In [14]:
# Re-seed, then sample random inputs for the chunked product:
# A has shape (1, 4, 4, 2) and B has shape (1, 4, 4, 1), both int32 in [0, 5).
tf.random.set_seed(10)
A = tf.random.uniform(
    shape=(1, 4, 4, 2), minval=0, maxval=5, dtype=tf.int32, seed=5
)
B = tf.random.uniform(
    shape=(1, 4, 4, 1), minval=0, maxval=5, dtype=tf.int32, seed=5
)

(A, B)

(<tf.Tensor: shape=(1, 4, 4, 2), dtype=int32, numpy=
 array([[[[1, 1],
          [4, 1],
          [2, 3],
          [2, 3]],
 
         [[2, 3],
          [0, 1],
          [3, 2],
          [2, 0]],
 
         [[1, 2],
          [4, 2],
          [0, 2],
          [1, 4]],
 
         [[2, 1],
          [3, 3],
          [1, 3],
          [4, 2]]]], dtype=int32)>,
 <tf.Tensor: shape=(1, 4, 4, 1), dtype=int32, numpy=
 array([[[[0],
          [3],
          [4],
          [0]],
 
         [[1],
          [0],
          [4],
          [3]],
 
         [[0],
          [4],
          [3],
          [3]],
 
         [[4],
          [4],
          [4],
          [1]]]], dtype=int32)>)

In [15]:
# Chunked product on the random inputs: i is summed over, so for every chunk
# the (4, 1) column of B is transposed and multiplied with the (4, 2) matrix
# of A. Result shape: (1, 4, 1, 2).
tf.einsum("bcij, bcik -> bcjk", B, A)

<tf.Tensor: shape=(1, 4, 1, 2), dtype=int32, numpy=
array([[[[20, 15]],

        [[20, 11]],

        [[19, 26]],

        [[28, 30]]]], dtype=int32)>