In [1]:
import numpy as np
import math

In [2]:
# Toy problem sizes: sequence length L, query/key width d_k, value width d_v.
L, d_k, d_v = 4, 8, 8

# Draw q, k, v from a seeded generator so that Restart Kernel -> Run All
# reproduces the same matrices (and therefore the same downstream numbers).
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))
q, k, v

(array([[-0.68190211, -0.94305326,  0.07893161, -0.03624853,  0.45741098,
         -0.44829929, -0.24813051,  1.45937929],
        [-1.14309258, -0.57973174, -1.0412738 ,  0.94769575,  0.05873672,
          0.19985851,  1.36647004, -0.13683946],
        [-0.19955125, -0.42674957,  0.22373456, -1.32353963, -1.44472564,
          1.46642934,  0.44756885,  0.12039613],
        [ 0.1985702 ,  0.97842152, -0.03877199, -1.0715728 , -0.50543109,
          0.75281765,  0.55503435, -0.97773298]]),
 array([[ 0.0760034 ,  1.21254768,  0.11962905,  0.40280899,  2.08565063,
         -0.46898407, -0.52284828, -0.24912036],
        [-1.07594132, -1.38524548,  1.55240629, -1.27517719,  0.07886014,
          1.11728429, -0.49441666, -0.83698795],
        [ 0.55275853,  0.43031192, -2.00058078,  0.94076683, -1.73413799,
         -0.33796419, -1.03680266,  0.49883279],
        [ 1.11621141, -0.09282793,  1.0912765 , -1.30005241, -1.01872229,
          2.19376868, -0.81599013, -0.51761354]]),
 array([[ 1.

## Self-Attention

In [3]:
# Raw attention scores q . k^T, computed two equivalent ways (matrix product
# and einsum) so the two displayed arrays can be compared side by side.
q @ k.T, np.einsum("ij,kj->ik", q, k)

(array([[-0.27006439,  0.64519485, -0.63120076, -2.54270617],
        [-1.18425073, -1.12513051,  0.43897104, -4.25607863],
        [-5.00391795,  4.04336443, -0.38091035,  6.0429506 ],
        [-0.68864038,  1.0824369 , -0.84087035,  3.70119829]]),
 array([[-0.27006439,  0.64519485, -0.63120076, -2.54270617],
        [-1.18425073, -1.12513051,  0.43897104, -4.25607863],
        [-5.00391795,  4.04336443, -0.38091035,  6.0429506 ],
        [-0.68864038,  1.0824369 , -0.84087035,  3.70119829]]))

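## Why we need sqrt(d_k) in the denominator
To reduce the variance of the scores: the raw dot products have a much larger variance than q and k themselves, and dividing by sqrt(d_k) brings it back to about the same scale (compare the two cells below).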

In [4]:
# Variance of q, of k, and of the unscaled score matrix q . k^T.
tuple(np.var(arr) for arr in (q, k, np.einsum("ij,kj->ik", q, k)))

(0.6284323494612694, 1.1413571687505781, 7.648227968818941)

In [5]:
# Divide the raw scores by sqrt(d_k), then show the variances of q, k and the
# scaled scores for comparison with the unscaled case.
scaled = np.divide(np.einsum("ij,kj->ik", q, k), np.sqrt(d_k))
tuple(np.var(arr) for arr in (q, k, scaled))

(0.6284323494612694, 1.1413571687505781, 0.9560284961023675)

## Masking
- Masking ensures that a word cannot take context from words that come after it (words that would only be generated in the future).
- It is not required in encoders, but it is required in decoders.

In [6]:
# Lower-triangular 0/1 pattern: entry (i, j) is 1 when position i may look at
# position j, and 0 otherwise.
#
# Example sentence: "My name is Arun". Read each row as one word:
#   row 0 ("My")   can only see "My"
#   row 1 ("name") can see "My" and "name"
#   ... and so on, each word seeing itself and the words before it.
mask = np.tril(np.ones(shape=(L, L)), k=0)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [7]:
# Turn the 0/1 causal pattern into an additive mask: 0 where attention is
# allowed (on and below the diagonal) and -inf where it is blocked (future
# positions). The mask is added to the scores before softmax; since
# exp(-inf) == 0, blocked positions end up with exactly zero weight.
#
# The mask is rebuilt from L with np.where instead of being edited in place,
# so re-running this cell always yields the same result. np.inf is used
# because the np.infty alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [8]:
# Scaled scores before and after adding the mask: blocked (future) entries
# become -inf, allowed entries are unchanged.
scaled, np.add(scaled, mask)

(array([[-0.09548218,  0.22811083, -0.22316317, -0.89898239],
        [-0.41869586, -0.39779371,  0.1551997 , -1.50475103],
        [-1.76915216,  1.4295452 , -0.13467215,  2.13650567],
        [-0.24347114,  0.38269924, -0.29729256,  1.3085712 ]]),
 array([[-0.09548218,        -inf,        -inf,        -inf],
        [-0.41869586, -0.39779371,        -inf,        -inf],
        [-1.76915216,  1.4295452 , -0.13467215,        -inf],
        [-0.24347114,  0.38269924, -0.29729256,  1.3085712 ]]))

In [9]:
def softmax(x):
    """Return the softmax of ``x`` taken over its last axis.

    Each slice along the last axis is exponentiated and normalised so that it
    sums to 1. The maximum of each slice is subtracted first; this does not
    change the result mathematically but keeps ``exp`` from overflowing on
    large inputs. Entries equal to ``-inf`` receive a weight of exactly 0.

    Parameters
    ----------
    x : array_like
        Scores of any dimensionality >= 1; normalisation runs over axis -1.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    # keepdims=True lets the reductions broadcast back against x for any rank.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)

# Attention weights: softmax over each row of the masked, scaled scores.
attention = softmax(np.add(scaled, mask))
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.49477465, 0.50522535, 0.        , 0.        ],
       [0.03265052, 0.79995709, 0.1673924 , 0.        ],
       [0.11710785, 0.21904247, 0.11097155, 0.55287814]])

In [10]:
# Weighted sum of the value vectors: row i of new_v combines the rows of v
# using the attention weights in row i of `attention`.
new_v = np.matmul(attention, v)
# Shapes of the two operands of the product above.
attention.shape, v.shape

((4, 4), (4, 8))

## Putting it all together

In [11]:
def softmax(x):
    """Return the softmax of ``x`` taken over its last axis.

    The maximum of each slice along the last axis is subtracted before
    exponentiating, which leaves the result unchanged mathematically but
    keeps ``exp`` from overflowing. Entries equal to ``-inf`` get weight 0.

    Parameters
    ----------
    x : array_like
        Scores of any dimensionality >= 1; normalisation runs over axis -1.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    # keepdims=True lets the reductions broadcast back against x for any rank.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask = None):
    """Compute softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    The last two axes of ``q``, ``k`` and ``v`` are treated as matrices; any
    leading axes are batch axes handled by ``np.matmul`` broadcasting.

    Parameters
    ----------
    q : numpy.ndarray
        Queries; ``d_k`` is read from ``q.shape[-1]``.
    k : numpy.ndarray
        Keys, at least 2-D, with the same last-axis size as ``q``.
    v : numpy.ndarray
        Values, with as many rows (axis -2) as ``k``.
    mask : numpy.ndarray or None, optional
        Additive mask broadcastable to the score matrix (0 to keep an entry,
        ``-inf`` to block it). ``None`` applies no mask.

    Returns
    -------
    out : numpy.ndarray
        Attention-weighted combination of the rows of ``v``.
    attention : numpy.ndarray
        The attention weights; each slice along the last axis sums to 1.
    """
    d_k = q.shape[-1]
    # Swap only the last two axes: same as k.T for 2-D input, but also correct
    # when k carries leading batch axes (k.T would reverse every axis).
    scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

# Run the assembled function on the toy q, k, v without a mask (the default),
# then print the inputs, the attended values and the attention weights.
values, attention = scaled_dot_product_attention(q=q, k=k, v=v)
for line in (f"Q {q}", f"K {k}", f"V {v}", f"New V {values}", f"Attention {attention}"):
    print(line)

Q [[-0.68190211 -0.94305326  0.07893161 -0.03624853  0.45741098 -0.44829929
  -0.24813051  1.45937929]
 [-1.14309258 -0.57973174 -1.0412738   0.94769575  0.05873672  0.19985851
   1.36647004 -0.13683946]
 [-0.19955125 -0.42674957  0.22373456 -1.32353963 -1.44472564  1.46642934
   0.44756885  0.12039613]
 [ 0.1985702   0.97842152 -0.03877199 -1.0715728  -0.50543109  0.75281765
   0.55503435 -0.97773298]]
K [[ 0.0760034   1.21254768  0.11962905  0.40280899  2.08565063 -0.46898407
  -0.52284828 -0.24912036]
 [-1.07594132 -1.38524548  1.55240629 -1.27517719  0.07886014  1.11728429
  -0.49441666 -0.83698795]
 [ 0.55275853  0.43031192 -2.00058078  0.94076683 -1.73413799 -0.33796419
  -1.03680266  0.49883279]
 [ 1.11621141 -0.09282793  1.0912765  -1.30005241 -1.01872229  2.19376868
  -0.81599013 -0.51761354]]
V [[ 1.23745848  1.69623919 -0.68700635  0.06021433 -0.39171069 -1.04289411
   0.46420511 -0.97087047]
 [-0.86684658 -0.99046743  0.7318422  -1.4798993   0.94394208  0.65871688
   0.7061