In [1]:
import numpy as np
import math

In [2]:
# Toy problem size: sequence length L, query/key width d_k, value width d_v.
L, d_k, d_v = 4, 8, 8
# Seeded local generator so a fresh kernel run always produces the same
# q, k, v (the unseeded global RNG gave different numbers on every run).
rng = np.random.default_rng(0)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))
q, k, v

(array([[ 0.28510219, -0.21361618, -0.00468576, -1.86013234,  1.84238025,
         -0.40951539, -0.53836045, -0.52704631],
        [-0.89695075, -1.00895405, -0.83749199,  0.93796229, -1.21645196,
          0.61239255, -0.75620668, -0.18775102],
        [ 0.86570504, -1.02519589, -0.32764229, -0.62259984, -1.56114   ,
         -2.11567798, -0.00892709,  1.25706832],
        [-0.28797953, -0.41004683,  2.6873104 , -1.50403716, -0.26002903,
          1.22397   , -0.65568646,  2.24927254]]),
 array([[-5.18476288e-01,  2.38313015e-01,  3.36328785e-01,
          1.94731855e-02, -1.80993620e-01,  1.14522335e+00,
         -5.00267683e-01,  1.24965011e+00],
        [ 5.27354756e-01, -2.11144691e-01,  9.84897854e-01,
         -1.05407432e+00,  1.26387909e+00,  1.77559954e+00,
          5.64745381e-01,  1.41759127e-03],
        [-8.43719732e-01,  2.15499802e+00, -1.34713987e+00,
          1.80658846e-01, -1.48284572e+00, -4.73379789e-01,
          1.61858326e-01,  3.93551771e-01],
        [-2.79

## Self-Attention

In [15]:
# Two equivalent spellings of "every query dotted with every key":
# the matrix product q times k-transposed, and the same contraction via einsum.
q @ k.T, np.einsum("ij,kj->ik", q, k)

(array([[-1.4282697 ,  3.4481835 , -3.86329326,  2.46585413],
        [ 1.02637224, -2.95091664,  1.19777933, -5.30003038],
        [-1.38049054, -4.72637792,  1.19891977, -0.68364712],
        [ 5.51372756,  5.644333  , -3.94731954,  3.98676444]]),
 array([[-1.4282697 ,  3.4481835 , -3.86329326,  2.46585413],
        [ 1.02637224, -2.95091664,  1.19777933, -5.30003038],
        [-1.38049054, -4.72637792,  1.19891977, -0.68364712],
        [ 5.51372756,  5.644333  , -3.94731954,  3.98676444]]))

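## Why we need sqrt(d_k) in the denominator
To reduce the variance of the scores. Each entry of `q @ k.T` is a sum of d_k products, so its variance grows with d_k; dividing by sqrt(d_k) brings it back to roughly the scale of q and k.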

In [18]:
# Variance of the inputs next to the variance of their unscaled dot products.
np.var(q), np.var(k), np.var(np.einsum("ij,kj->ik", q, k))

(1.2591072378909316, 0.8003583179557943, 12.131480935404243)

In [20]:
# Divide the raw scores by sqrt(d_k), then compare the variances again.
scaled = np.divide(np.einsum("ij,kj->ik", q, k), np.sqrt(d_k))
np.var(q), np.var(k), np.var(scaled)

(1.2591072378909316, 0.8003583179557943, 1.5164351169255301)

## Masking
- Masking ensures a word does not get context from words that are generated after it (in the future).
- It is not required in the encoder, but it is required in the decoder.

In [32]:
# L x L lower-triangular matrix of ones: row i holds 1s in columns 0..i.
mask = np.tri(L)
mask

## Consider the sentence -> "My name is Arun".
## Read each row of this matrix as one word: in the first row, "My" can only see "My";
## similarly, "name" can see "My" and "name", and so on.

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [33]:
## Additive form of the causal mask: 0 where a position may be attended to
## (on or below the diagonal), -inf where it may not (future positions).
## Built fresh instead of rewriting the 0/1 mask in place, so re-running this
## cell gives the same result; np.inf is used because the np.infty alias was
## removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)
mask

## Why -inf: this mask is added to the scores before softmax, and
## exp(-inf) = 0, so the blocked positions end up with zero attention weight.

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [34]:
# Scores before and after adding the mask: future positions become -inf.
scaled, np.add(scaled, mask)

(array([[-0.5049696 ,  1.21911697, -1.36588043,  0.87181109],
        [ 0.36287738, -1.04330658,  0.42347894, -1.87384371],
        [-0.48807711, -1.67102694,  0.42388215, -0.24170576],
        [ 1.94939707,  1.99557307, -1.39558821,  1.40953409]]),
 array([[-0.5049696 ,        -inf,        -inf,        -inf],
        [ 0.36287738, -1.04330658,        -inf,        -inf],
        [-0.48807711, -1.67102694,  0.42388215,        -inf],
        [ 1.94939707,  1.99557307, -1.39558821,  1.40953409]]))

In [37]:
def softmax(x):
    """Softmax over the last axis of ``x``.

    Parameters
    ----------
    x : array_like
        Scores of any dimensionality. Entries equal to ``-inf`` get weight 0.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``x`` whose slices along the last axis are
        non-negative and sum to 1. A slice made up entirely of ``-inf``
        has no defined distribution and yields NaN.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty axis.
        return np.exp(x)
    # Subtracting the row maximum leaves the result unchanged mathematically
    # but keeps exp() from overflowing on large scores.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    # keepdims lets the row sums broadcast for any number of leading axes.
    return exps / np.sum(exps, axis=-1, keepdims=True)

# Add the mask first, then normalise each row into attention weights.
attention = softmax(np.add(scaled, mask))
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.80316336, 0.19683664, 0.        , 0.        ],
       [0.26346515, 0.08071878, 0.65581607, 0.        ],
       [0.37518559, 0.39291638, 0.01322932, 0.21866871]])

In [48]:
# Each row of new_v is a weighted combination of the rows of v, with the
# weights taken from the matching row of attention.
new_v = attention @ v
# One output row per attention row, each as wide as a row of v.
assert new_v.shape == (attention.shape[0], v.shape[-1])
# Only the last expression of a cell is displayed, so show the shapes here.
attention.shape, v.shape

((4, 4), (4, 8))

## Putting it all together

In [46]:
def softmax(x):
    """Softmax over the last axis of ``x``.

    Parameters
    ----------
    x : array_like
        Scores of any dimensionality. Entries equal to ``-inf`` get weight 0.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``x`` whose slices along the last axis are
        non-negative and sum to 1. A slice made up entirely of ``-inf``
        has no defined distribution and yields NaN.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty axis.
        return np.exp(x)
    # Subtracting the row maximum leaves the result unchanged mathematically
    # but keeps exp() from overflowing on large scores.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    # keepdims lets the row sums broadcast for any number of leading axes.
    return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask = None):
    """Compute ``softmax(q @ k^T / sqrt(d_k) [+ mask]) @ v``.

    Parameters
    ----------
    q : numpy.ndarray
        Queries, shape ``(..., L_q, d_k)``.
    k : numpy.ndarray
        Keys, shape ``(..., L_k, d_k)``; must have at least 2 dimensions.
    v : numpy.ndarray
        Values, shape ``(..., L_k, d_v)``.
    mask : numpy.ndarray, optional
        Additive mask broadcastable to ``(..., L_q, L_k)``: 0 where attention
        is allowed, ``-inf`` where it is blocked. ``None`` applies no mask.

    Returns
    -------
    out : numpy.ndarray
        Attention-weighted values, shape ``(..., L_q, d_v)``.
    attention : numpy.ndarray
        Attention weights, shape ``(..., L_q, L_k)``.
    """
    d_k = q.shape[-1]
    # Swap only the last two axes of k: for 2-D input this equals k.T, and
    # for batched input it leaves the leading (batch/head) axes in place.
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores + mask
    attention = softmax(scores)
    out = np.matmul(attention, v)
    return out, attention

# Run unmasked attention on the toy q, k, v, then print every input and result,
# one per line.
values, attention = scaled_dot_product_attention(q, k, v, mask=None)
print(
    f"Q {q}",
    f"K {k}",
    f"V {v}",
    f"New V {values}",
    f"Attention {attention}",
    sep="\n",
)

Q [[ 0.28510219 -0.21361618 -0.00468576 -1.86013234  1.84238025 -0.40951539
  -0.53836045 -0.52704631]
 [-0.89695075 -1.00895405 -0.83749199  0.93796229 -1.21645196  0.61239255
  -0.75620668 -0.18775102]
 [ 0.86570504 -1.02519589 -0.32764229 -0.62259984 -1.56114    -2.11567798
  -0.00892709  1.25706832]
 [-0.28797953 -0.41004683  2.6873104  -1.50403716 -0.26002903  1.22397
  -0.65568646  2.24927254]]
K [[-5.18476288e-01  2.38313015e-01  3.36328785e-01  1.94731855e-02
  -1.80993620e-01  1.14522335e+00 -5.00267683e-01  1.24965011e+00]
 [ 5.27354756e-01 -2.11144691e-01  9.84897854e-01 -1.05407432e+00
   1.26387909e+00  1.77559954e+00  5.64745381e-01  1.41759127e-03]
 [-8.43719732e-01  2.15499802e+00 -1.34713987e+00  1.80658846e-01
  -1.48284572e+00 -4.73379789e-01  1.61858326e-01  3.93551771e-01]
 [-2.79691239e-02  9.54441214e-01  1.69463428e+00 -4.37507769e-01
   1.14677320e+00 -8.75178076e-01  6.87536285e-01  4.34911065e-01]]
V [[-0.35554833 -0.81257688 -0.50922826 -0.09947729  1.683118