In [1]:
import numpy as np
import math

In [2]:
# L: sequence length (number of tokens); d_k: query/key dimension; d_v: value dimension.
# q, k, v are random stand-ins for the query/key/value matrices. A fixed seed makes
# "Restart & Run All" reproduce the same numbers on every run.
SEED = 0
L, d_k, d_v = 4, 8, 8
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [11]:
# Shapes of the query, key and value matrices, space-separated on one line.
print(*(arr.shape for arr in (q, k, v)))

(4, 8) (4, 8) (4, 8)


In [8]:
# Dump each matrix with its label, separated by a blank-line spacer.
for i, (label, arr) in enumerate((("q", q), ("k", k), ("v", v))):
    if i:
        print("\n")
    print(label, arr)

q [[-1.46868161 -0.92527241  0.93325244 -0.24670097 -0.97912115  0.03440879
   0.00422575 -1.35696244]
 [-0.82643577  2.0792953  -1.50199646 -1.51169856  0.43840298 -1.27706233
  -0.06526329  1.74137974]
 [-0.32772483 -0.79495567  0.43778303 -0.61373461  2.68386832 -0.51070335
  -0.12103809 -0.20937305]
 [ 1.53227183  1.15381532  0.07229121 -0.03409058  1.39929093 -0.20088247
   1.4441639   0.41387737]]


k [[-1.07006165  0.41741129  0.58066157  2.81828782 -0.3954185   1.00287525
   1.29573473  0.01670122]
 [ 1.4829161  -0.49853378  1.37847177 -1.25058329 -0.12087623 -0.01708252
  -0.14912928  0.03904045]
 [ 0.09108514 -1.2712991   0.13225173 -1.78843173 -0.84319619  0.401934
   0.37202635 -0.87575333]
 [ 0.59664739  2.19420334 -0.14255928  0.76902345  0.26976998  0.47959014
   1.87090092 -1.65504414]]


v [[ 9.46521011e-01 -1.61713281e-01  5.50877055e-04 -6.79487547e-01
   7.44079389e-01  2.41903399e-01  2.61895363e+00 -5.13317778e-01]
 [-1.18682638e+00 -8.49863642e-01  2.45513597e+00

### Self Attention 
$$\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

where $M$ is an optional additive mask ($0$ = keep a position, $-\infty$ = block it).

In [12]:
# Raw attention scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[ 1.43647306, -0.05751172,  3.63651276, -1.22318017],
       [-4.8898623 , -2.39554876, -2.6460075 , -0.37746218],
       [-3.19036766,  0.97550615, -0.1936712 , -1.87504671],
       [-0.08871011,  1.47438053, -2.34255361,  5.70745688]])

In [13]:
# Why divide by sqrt(d_k)? Each score in q @ k.T is a sum of d_k products, so for
# independent unit-variance entries its variance grows in proportion to d_k, even
# though q and k themselves stay near variance 1. Large-variance scores push the
# softmax toward one-hot outputs.
q.var(), k.var(), (q @ k.T).var()

(1.1832384721009135, 1.129963243856758, 6.638794087956057)

In [14]:
# Dividing by sqrt(d_k) brings the score variance back to roughly that of q and k.
scaled = q @ k.T / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(1.1832384721009135, 1.129963243856758, 0.829849260994507)

In [15]:
# Scaled scores: row i holds query i's (scaled) dot products with every key j.
scaled

array([[ 0.50786992, -0.02033346,  1.28570142, -0.4324595 ],
       [-1.72882739, -0.84695439, -0.93550492, -0.13345303],
       [-1.1279653 ,  0.34489351, -0.06847311, -0.66292912],
       [-0.03136376,  0.52127224, -0.82821777,  2.01789073]])

### Masking
* This is to ensure words don't get context from words generated in the future.
* Not required in the encoder, but required in the decoder.

In [16]:
# Lower-triangular 0/1 pattern: (i, j) is 1 when token i may look at token j (j <= i).
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [17]:
# Turn the 0/1 lower-triangular pattern into an additive mask: 0 where attention is
# allowed, -inf where it must be blocked (future positions).
# Rebuilding from np.tril instead of mutating `mask` in place keeps this cell
# idempotent. Re-running the old in-place version turned every 0 into -inf, which
# made the softmax all-NaN. np.inf replaces the np.infty alias removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [18]:
# Additive causal mask: 0 where query i may attend to key j (j <= i), -inf elsewhere.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [19]:
# Adding the mask sets every future position to -inf, which the softmax maps to weight 0.
scaled + mask

array([[ 0.50786992,        -inf,        -inf,        -inf],
       [-1.72882739, -0.84695439,        -inf,        -inf],
       [-1.1279653 ,  0.34489351, -0.06847311,        -inf],
       [-0.03136376,  0.52127224, -0.82821777,  2.01789073]])

### Softmax

$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

In [21]:
def softmax(x):
    """Softmax over the last axis of ``x``.

    The per-row maximum is subtracted before exponentiating so large logits do
    not overflow to ``inf`` (which would give inf/inf = NaN). This does not change
    the result, because softmax is invariant to adding a constant to every entry
    of a row. Entries equal to ``-inf`` (masked positions) map to exactly 0.
    Works for any rank >= 1, e.g. (L, L) scores or batched (batch, L, L) scores.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    # keepdims keeps the reduced axis so the division broadcasts row-wise for any rank
    return exp / np.sum(exp, axis=-1, keepdims=True)

In [22]:
# Row-wise softmax of the masked, scaled scores: attention[i, j] is the weight query i gives to key j.
attention = softmax(scaled + mask)

In [23]:
# Each row sums to 1; the causal mask leaves zeros above the diagonal.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.2927898 , 0.7072102 , 0.        , 0.        ],
       [0.12126221, 0.52890775, 0.34983004, 0.        ],
       [0.09131849, 0.15869582, 0.04116133, 0.70882436]])

In [24]:
# Each output row is the attention-weighted average of the rows of v.
new_v = attention @ v
new_v

array([[ 9.46521011e-01, -1.61713281e-01,  5.50877055e-04,
        -6.79487547e-01,  7.44079389e-01,  2.41903399e-01,
         2.61895363e+00, -5.13317778e-01],
       [-5.62204034e-01, -6.48380238e-01,  1.73645850e+00,
         3.71230441e-01,  1.24072236e+00,  3.67671755e-01,
        -5.47825387e-02,  4.04519750e-01],
       [-9.83353575e-01,  1.04497228e-02,  2.07873171e+00,
         3.28437086e-01,  3.50507400e-01,  2.32789511e-02,
        -6.24755527e-01,  3.25687919e-01],
       [-3.34514196e-01, -5.71953185e-01,  8.68757358e-01,
        -1.02028760e+00,  8.19153481e-01,  4.54040984e-01,
        -7.02832800e-01,  1.60689997e-01]])

In [25]:
# Compare with new_v: row 0 is unchanged because query 0 can attend only to itself.
v

array([[ 9.46521011e-01, -1.61713281e-01,  5.50877055e-04,
        -6.79487547e-01,  7.44079389e-01,  2.41903399e-01,
         2.61895363e+00, -5.13317778e-01],
       [-1.18682638e+00, -8.49863642e-01,  2.45513597e+00,
         8.06234777e-01,  1.44633590e+00,  4.19740703e-01,
        -1.16172735e+00,  7.84510679e-01],
       [-1.34467907e+00,  1.37083399e+00,  2.23000994e+00,
        -4.45661483e-02, -1.44269941e+00, -6.51913443e-01,
        -9.37280950e-01, -7.71798049e-02],
       [-2.50070271e-01, -6.75401899e-01,  5.46393559e-01,
        -1.52978589e+00,  8.19753492e-01,  5.53272833e-01,
        -1.01442716e+00,  1.21671413e-01]])

### Function

In [27]:
def softmax(x):
    """Softmax over the last axis of ``x``.

    The per-row maximum is subtracted before exponentiating so large logits do
    not overflow to ``inf``. Softmax is invariant to that shift, so the result is
    unchanged. Entries equal to ``-inf`` (masked positions) map to exactly 0.
    Works for any rank >= 1.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute ``softmax(q @ k^T / sqrt(d_k) + mask) @ v``.

    Parameters
    ----------
    q : ndarray, shape (..., L_q, d_k)
    k : ndarray, shape (..., L_k, d_k)
    v : ndarray, shape (..., L_k, d_v)
    mask : ndarray or None
        Additive mask broadcastable to (..., L_q, L_k): 0 keeps a position,
        -inf blocks it.

    Returns
    -------
    out : ndarray, shape (..., L_q, d_v)
        Attention-weighted combination of the rows of ``v``.
    attention : ndarray, shape (..., L_q, L_k)
        Attention weights; each row sums to 1 (a fully masked row yields NaN).
    """
    d_k = q.shape[-1]
    # swapaxes (not .T) transposes only the last two axes, leaving batch/head axes intact
    scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

In [28]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
# Print the inputs next to the outputs so the effect of the causal mask is easy to see.
for title, arr in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
    print(title, arr)

Q
 [[-1.46868161 -0.92527241  0.93325244 -0.24670097 -0.97912115  0.03440879
   0.00422575 -1.35696244]
 [-0.82643577  2.0792953  -1.50199646 -1.51169856  0.43840298 -1.27706233
  -0.06526329  1.74137974]
 [-0.32772483 -0.79495567  0.43778303 -0.61373461  2.68386832 -0.51070335
  -0.12103809 -0.20937305]
 [ 1.53227183  1.15381532  0.07229121 -0.03409058  1.39929093 -0.20088247
   1.4441639   0.41387737]]
K
 [[-1.07006165  0.41741129  0.58066157  2.81828782 -0.3954185   1.00287525
   1.29573473  0.01670122]
 [ 1.4829161  -0.49853378  1.37847177 -1.25058329 -0.12087623 -0.01708252
  -0.14912928  0.03904045]
 [ 0.09108514 -1.2712991   0.13225173 -1.78843173 -0.84319619  0.401934
   0.37202635 -0.87575333]
 [ 0.59664739  2.19420334 -0.14255928  0.76902345  0.26976998  0.47959014
   1.87090092 -1.65504414]]
V
 [[ 9.46521011e-01 -1.61713281e-01  5.50877055e-04 -6.79487547e-01
   7.44079389e-01  2.41903399e-01  2.61895363e+00 -5.13317778e-01]
 [-1.18682638e+00 -8.49863642e-01  2.45513597e+00 