# Self Attention in Transformers

## Generate Data

In [1]:
import numpy as np
import math

# Sequence length L, query/key dimension d_k and value dimension d_v.
L, d_k, d_v = 4, 8, 8

# Fixed seed so that "Restart & Run All" reproduces the same Q, K, V every time.
SEED = 42
rng = np.random.default_rng(SEED)

# Random stand-ins for the query, key and value matrices: one row per sequence position.
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [2]:
# Show the three randomly generated input matrices, each under its own header.
for header, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(header, matrix)

Q
 [[-8.40458589e-02  3.99050485e-01  3.14351957e-02  1.40126345e+00
   8.03466156e-01  1.42979362e+00 -1.17161170e+00 -1.59378880e+00]
 [ 5.41274377e-01 -8.47902052e-02  2.76591907e-01  2.65523872e-01
  -2.88359451e-01 -2.78486756e-01 -4.10098863e-01 -2.39111812e-01]
 [-1.17558885e+00  1.48631684e+00  2.35151985e-01 -2.12834985e-01
   3.57923144e-04  1.97128215e+00  2.80536086e-02  1.24816311e-01]
 [ 6.96524172e-02 -1.25102375e-01 -1.87162323e+00 -4.72408516e-01
   7.77379372e-01 -2.70110097e-01  2.05871084e+00 -1.23797616e+00]]
K
 [[ 0.05598404 -2.12251061 -1.60331461  0.94812262  0.65457539 -0.44988676
  -0.59631323  1.93199314]
 [-0.51362603  0.01482244  0.98385343  1.59871025 -1.64343588 -0.46582907
   0.10012008  0.22925616]
 [ 0.15454578  0.74031909 -0.84109077  0.58904157 -0.52782392 -1.65597203
  -0.57589727 -0.00436591]
 [-0.11334296 -0.85600336 -1.19641934  0.99629672  0.34737519  1.78293352
   0.00761595 -0.25721772]]
V
 [[-1.46622247 -2.27459364 -0.59591388 -1.06136793  0.

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}} + M\bigg)
$$

$$
\text{new V} = \text{self attention}.V
$$

In [3]:
# Raw (unscaled) scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[-2.07138256, -0.1489474 , -1.02870224,  4.25576092],
       [-0.26232432,  0.92510273,  0.79523362, -0.59346007],
       [-4.46155767, -0.37050762, -2.68576434,  1.85046339],
       [-0.16669537, -3.86372028,  0.07085985,  1.99034302]])

In [None]:
# Why we need sqrt(d_k) in the denominator: each score sums d_k products, so for
# independent entries its variance grows roughly in proportion to d_k.
(q.var(), k.var(), (q @ k.T).var())

(0.8672192297664698, 0.9229851723027697, 5.1446872979260165)

In [None]:
# Divide every score by sqrt(d_k) to undo that growth in variance.
raw_to_scaled = math.sqrt(d_k)
scaled = (q @ k.T) / raw_to_scaled
(q.var(), k.var(), scaled.var())

(0.8672192297664698, 0.9229851723027697, 0.643085912240752)

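Notice the reduction in variance: dividing the scores by $\sqrt{d_k}$ divides their variance by exactly $d_k$, bringing it back to the same scale as the variances of Q and K. Without this scaling, large scores would push the softmax towards a near one-hot output whose gradients are very small.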

In [None]:
# Scaled scores: entry (i, j) is query i dotted with key j, divided by sqrt(d_k).
scaled

array([[ 0.68537216,  1.92208565, -0.13566043,  0.43920453],
       [ 0.47796088,  0.42358302, -0.60457577, -0.13480942],
       [ 0.37611945, -0.30709922, -0.65849946, -0.24225621],
       [ 0.78209275, -0.99700418,  1.88206279,  0.79213542]])

## Masking

- This ensures that a word cannot get context from words that come after it in the sequence (i.e. words that have not been generated yet).
- Not required in the encoders, but required in the decoders.

In [None]:
# Lower-triangular pattern of ones: row i has ones in columns 0..i (positions i may attend to).
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [None]:
# Turn the lower-triangular pattern into an additive mask:
# 0 where position i may attend to position j (j <= i), -inf where j > i.
# Built directly from np.tril rather than by mutating `mask` in place, so that
# re-running this cell gives the same result instead of blocking every position.
# np.inf is used because the np.infty alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)

In [None]:
# Additive mask: 0 leaves a score unchanged, -inf blocks it (exp(-inf) = 0 in the softmax).
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [None]:
# Masked scores: blocked (future) positions become -inf, the rest keep their scaled value.
np.add(scaled, mask)

array([[ 0.68537216,        -inf,        -inf,        -inf],
       [ 0.47796088,  0.42358302,        -inf,        -inf],
       [ 0.37611945, -0.30709922, -0.65849946,        -inf],
       [ 0.78209275, -0.99700418,  1.88206279,  0.79213542]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [None]:
def softmax(x, axis=-1):
  """Numerically stable softmax, normalised along ``axis`` (default: last axis).

  Args:
    x: array-like of scores, any shape. Entries equal to -inf receive weight
      exactly 0, which is how the additive causal mask takes effect.
    axis: axis along which the outputs sum to 1.

  Returns:
    Array with the same shape as ``x``. A slice that is entirely -inf has no
    finite score to normalise and comes out as NaN.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; np.max would raise on an empty reduction.
    return np.exp(x)
  # Subtracting the per-slice maximum cancels in the ratio below, so the result
  # is unchanged, but it keeps np.exp from overflowing to inf for large scores.
  shifted = x - np.max(x, axis=axis, keepdims=True)
  exps = np.exp(shifted)
  # keepdims lets the division broadcast for inputs of any rank,
  # e.g. batched (B, L, L) score tensors, not only 1-D/2-D arrays.
  return exps / np.sum(exps, axis=axis, keepdims=True)

In [None]:
# Apply the causal mask, then normalise each row of scores into attention weights.
masked_scores = scaled + mask
attention = softmax(masked_scores)

In [None]:
# Attention weights: each row sums to 1, and masked (future) positions get weight exactly 0.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.51359112, 0.48640888, 0.        , 0.        ],
       [0.53753304, 0.27144826, 0.1910187 , 0.        ],
       [0.19293995, 0.03256643, 0.57960627, 0.19488734]])

In [None]:
# Each output row is the attention-weighted average of the rows of v.
new_v = attention @ v
new_v

array([[-0.00368231,  1.43739233, -0.59614565, -1.23171219,  1.12030717,
        -0.98620738, -0.15461465, -1.03106383],
       [ 0.41440401, -0.13671232,  0.02128364, -0.60532081,  0.49977893,
        -1.1936286 , -0.27463831, -1.10169151],
       [ 0.32673907,  0.72121642, -0.00947672, -0.59897862,  0.90155754,
        -0.88535361, -0.21384855, -0.7053796 ],
       [ 0.18700384,  1.67754576,  0.33105314, -0.41795742,  1.4258469 ,
        -0.18788199, -0.10285145,  0.54683565]])

In [None]:
# Original values for comparison: the first query can only attend to itself,
# so new_v[0] equals v[0].
v

array([[-0.00368231,  1.43739233, -0.59614565, -1.23171219,  1.12030717,
        -0.98620738, -0.15461465, -1.03106383],
       [ 0.85585446, -1.79878344,  0.67321704,  0.05607552, -0.15542661,
        -1.41264124, -0.40136933, -1.17626611],
       [ 0.50465335,  2.28693419,  0.67128338,  0.2506863 ,  1.78802234,
         0.14775751, -0.11405725,  0.88026286],
       [-0.68069105,  0.68385101,  0.17994557, -1.68013201,  0.91543969,
        -0.19108312,  0.03160471,  1.40527326]])