In [1]:
import numpy as np
import math

# Toy sizes: sequence length L, key/query dimension d_k, value dimension d_v.
L, d_k, d_v = 4, 8, 8

# Seed the generator so that Restart & Run All reproduces the same Q/K/V and
# therefore the same outputs in every later cell.
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [3]:
# Show the generated query, key and value matrices.
for title, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
    print(title, matrix)

Q
 [[-1.4553675   0.52822167 -0.08908464 -0.41361999  0.54665672  0.21945883
  -1.09894967  0.80929701]
 [-0.51415209  0.90760674 -1.4335523  -2.04511947 -0.61882516 -0.63503193
   1.87869091 -0.52581161]
 [ 1.37506346  1.03859363  0.80758575 -2.25798031  0.45322613  1.61573788
   0.21857565  0.55213646]
 [ 0.46949285  0.11316391 -0.85712567 -1.18369103  0.02316273  0.5767991
  -2.1741154  -0.35041046]]
K
 [[-0.28445272 -1.07196609 -1.93513569 -0.23985278  1.15788102 -0.1993426
   1.12183335  0.57522366]
 [ 2.57037699 -0.04824941 -1.10011083  3.09217572 -0.51949531 -2.15972381
   0.97396318  3.81101708]
 [ 1.37036296 -0.40076792  1.11391891  0.58940167  0.7881421   0.52416013
  -1.52193945  1.67225601]
 [-0.82578705  0.60693118 -0.78561119  1.27544268 -1.98608909  1.2081519
   1.1276339   0.94256832]]
V
 [[-0.22641115 -0.76829269  0.33259227 -0.34406533  0.01781909  0.38991619
  -1.07700178 -0.36599018]
 [-0.18291884  0.27608871  0.8284444   0.18075344  0.08854557 -1.38488221
   0.1890

In [4]:
# Raw (unscaled) scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[-0.05874937, -3.69136011,  1.02266399, -0.23210939],
       [ 3.65315597, -4.59328945, -8.4296985 ,  1.57790612],
       [-1.76017889, -5.79409611,  2.83159528, -2.20071861],
       [-1.04101081, -6.22662217,  1.9890718 , -3.28641529]])

In [5]:
# Why the scores are divided by sqrt(d_k): each entry of q @ k.T is a sum of
# d_k products, so its variance grows roughly in proportion to d_k (compare
# the third number with the first two). Large-variance scores push softmax
# towards saturated, near one-hot outputs.
tuple(arr.var() for arr in (q, k, q @ k.T))

(1.1174121613390589, 1.9434908929369135, 11.489498438636566)

In [8]:
# Divide by sqrt(d_k) so the spread of the scores no longer grows with the key dimension.
scaled = q @ k.T / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(1.1174121613390589, 1.9434908929369135, 1.4361873048295708)

In [9]:
# Scores after dividing by sqrt(d_k); note the smaller spread than the
# unscaled np.matmul(q, k.T) output above.
scaled

array([[-0.02077104, -1.30509288,  0.36156632, -0.08206306],
       [ 1.29158568, -1.62397306, -2.98034848,  0.55787406],
       [-0.62231722, -2.04852233,  1.00112011, -0.77807153],
       [-0.3680529 , -2.20144338,  0.70324308, -1.16192327]])

In [14]:
# Masking
# * Ensures a token cannot take context from tokens that come after it (the "future").
# * Needed in the decoder's self-attention, where tokens are generated left to right; not needed in the encoder.

In [15]:
# Causal pattern: 1 where query i may attend to key j (j <= i), 0 above the diagonal.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [16]:
# Turn the 0/1 causal pattern into an *additive* mask: 0 where attention is
# allowed (on/below the diagonal) and -inf where it is not, so those scores
# become exactly 0 after exp() inside the softmax.
# Built from scratch rather than mutating `mask` in place: the in-place version
# is not idempotent (running it a second time turns every remaining 0 into -inf,
# leaving an all -inf mask). np.inf is used because the np.infty alias was
# removed in NumPy 2.0.
mask = np.where(np.tri(L, dtype=bool), 0.0, -np.inf)

In [17]:
# Additive mask: 0 keeps a score unchanged, -inf removes it (exp(-inf) == 0 in the softmax).
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [18]:
# Future positions (above the diagonal) now score -inf, so the softmax gives them zero weight.
scaled + mask

array([[-0.02077104,        -inf,        -inf,        -inf],
       [ 1.29158568, -1.62397306,        -inf,        -inf],
       [-0.62231722, -2.04852233,  1.00112011,        -inf],
       [-0.3680529 , -2.20144338,  0.70324308, -1.16192327]])

In [21]:
# Softmax
def softmax(x):
    """Softmax over the last axis of ``x``.

    The per-slice maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged, but it stops ``np.exp`` from overflowing
    to inf (and the output from turning into NaN) when scores are large.
    Normalising with ``keepdims=True`` instead of transposing keeps the function
    correct for batched (3-D and higher) inputs as well as 1-D/2-D ones.

    Parameters
    ----------
    x : array_like
        Scores. Entries may be ``-inf``; such positions receive weight 0.

    Returns
    -------
    numpy.ndarray
        Same shape as ``x``. Each slice along the last axis sums to 1, except
        a slice that is entirely ``-inf``, which yields NaN because there is
        nothing to normalise.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise, and np.max would raise on an empty axis.
        return np.exp(x)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

In [22]:
# Attention weights: softmax of the masked, scaled scores.
attention = softmax(np.add(scaled, mask))

In [23]:
# Each row sums to 1 and is 0 above the diagonal: row i attends only to positions 0..i.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.94861022, 0.05138978, 0.        , 0.        ],
       [0.15846082, 0.03806527, 0.80347391, 0.        ],
       [0.22069587, 0.03528278, 0.64424659, 0.09977476]])

In [24]:
# Each output row is the attention-weighted average of the rows of v.
new_v = attention @ v
new_v

array([[-0.22641115, -0.76829269,  0.33259227, -0.34406533,  0.01781909,
         0.38991619, -1.07700178, -0.36599018],
       [-0.22417609, -0.71462216,  0.358074  , -0.31709501,  0.02145371,
         0.2987097 , -1.01194147, -0.3045296 ],
       [-0.93111787, -1.10827526,  1.36948149, -0.91305965,  1.28067377,
         1.02966156,  0.24815927, -0.23288413],
       [-0.6801363 , -1.09638209,  1.10838589, -0.72851651,  1.07050575,
         0.87141249,  0.03371391, -0.30129853]])

In [25]:
# Original values for comparison: row 0 of new_v equals row 0 of v because
# position 0 attends only to itself (attention row 0 is [1, 0, 0, 0]).
v

array([[-0.22641115, -0.76829269,  0.33259227, -0.34406533,  0.01781909,
         0.38991619, -1.07700178, -0.36599018],
       [-0.18291884,  0.27608871,  0.8284444 ,  0.18075344,  0.08854557,
        -1.38488221,  0.18901486,  0.82997888],
       [-1.10554644, -1.24091193,  1.59960849, -1.07709684,  1.5862116 ,
         1.270223  ,  0.51230906, -0.25698706],
       [ 0.88729978, -1.37420489, -0.24843958,  0.35034689,  0.41631364,
         0.15921114, -0.65466042, -0.84437127]])

In [26]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k) + mask) @ v``.

    Uses the notebook-level ``softmax`` helper (normalises over the last axis)
    instead of redefining an identical copy here.

    Parameters
    ----------
    q : ndarray, shape (..., L_q, d_k)
        Queries. ``d_k`` is read from the last axis and used for scaling.
    k : ndarray, shape (..., L_k, d_k)
        Keys.
    v : ndarray, shape (..., L_k, d_v)
        Values.
    mask : ndarray or None, optional
        Additive mask broadcastable to ``(..., L_q, L_k)``. A 0 keeps a score
        and ``-inf`` removes it. ``None`` applies no mask.

    Returns
    -------
    out : ndarray, shape (..., L_q, d_v)
        Attention-weighted combination of the rows of ``v``.
    attention : ndarray, shape (..., L_q, L_k)
        The softmax attention weights.
    """
    d_k = q.shape[-1]
    # Transpose only the last two axes so any leading (batch/head) axes stay
    # aligned; ``k.T`` reverses *every* axis and breaks for 3-D+ inputs.
    # A 1-D ``k`` is passed through unchanged, exactly as ``k.T`` did.
    k_t = np.swapaxes(k, -1, -2) if np.ndim(k) >= 2 else k
    scaled = np.matmul(q, k_t) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

In [28]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)

# Dump the inputs, the attended values and the attention weights together.
for title, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
    print(title, matrix)

Q
 [[-1.4553675   0.52822167 -0.08908464 -0.41361999  0.54665672  0.21945883
  -1.09894967  0.80929701]
 [-0.51415209  0.90760674 -1.4335523  -2.04511947 -0.61882516 -0.63503193
   1.87869091 -0.52581161]
 [ 1.37506346  1.03859363  0.80758575 -2.25798031  0.45322613  1.61573788
   0.21857565  0.55213646]
 [ 0.46949285  0.11316391 -0.85712567 -1.18369103  0.02316273  0.5767991
  -2.1741154  -0.35041046]]
K
 [[-0.28445272 -1.07196609 -1.93513569 -0.23985278  1.15788102 -0.1993426
   1.12183335  0.57522366]
 [ 2.57037699 -0.04824941 -1.10011083  3.09217572 -0.51949531 -2.15972381
   0.97396318  3.81101708]
 [ 1.37036296 -0.40076792  1.11391891  0.58940167  0.7881421   0.52416013
  -1.52193945  1.67225601]
 [-0.82578705  0.60693118 -0.78561119  1.27544268 -1.98608909  1.2081519
   1.1276339   0.94256832]]
V
 [[-0.22641115 -0.76829269  0.33259227 -0.34406533  0.01781909  0.38991619
  -1.07700178 -0.36599018]
 [-0.18291884  0.27608871  0.8284444   0.18075344  0.08854557 -1.38488221
   0.1890