# Self Attention in Transformers

## Generate Data

In [1]:
import numpy as np
import math

L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

In [2]:
print("Q\n", q)
print("K\n", k)
print("V\n", v)

Q
 [[-3.35336618e-01 -4.42061127e-01 -1.17153026e+00 -8.51012074e-01
  -1.83310562e+00 -1.01997786e+00  6.96294774e-01  5.13188841e-04]
 [-7.87534906e-01  1.21066774e+00  3.39951764e-01  1.63836578e+00
   2.54813915e+00 -9.37917890e-01  1.90957862e+00 -2.78035979e-01]
 [-4.53429169e-01 -8.28140974e-01 -6.84221156e-01 -2.21878794e-01
   2.38625669e-01  1.45336504e+00 -1.31057753e+00 -1.53303384e-02]
 [-2.66411325e-01  1.31584574e+00  4.34539796e-01  1.49393629e+00
  -2.14807639e-01 -1.99427124e+00  9.81772364e-01 -1.05129162e+00]]
K
 [[-0.8205968   1.4327968  -1.11966656  0.53342803  0.12464472  0.77386094
   0.57159063 -0.59304417]
 [ 0.35956654  0.70422676  1.69415205  0.74546541  0.39063949 -0.10742524
   1.59620862 -0.40137663]
 [-0.91570749  1.1063129   0.75561256  0.68401174  0.70281035 -0.46485235
  -0.78802292  0.37576825]
 [ 0.58927671  0.07248648 -0.23737439 -0.06384246 -0.65548825  0.10876345
  -0.63519811 -1.13629993]]
V
 [[-0.4547183   1.08644825  0.09757016  0.00665423 -1.

## Self Attention

$$
\text{self attention} = softmax\bigg(\frac{Q.K^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention}.V
$$ 

In [3]:
np.matmul(q, k.T)

array([[-0.12055479, -2.54632389, -3.0120028 ,  0.7505471 ],
       [ 3.72238603,  6.62253125,  4.05564529, -3.23092701],
       [ 0.24769156, -4.21952803, -0.65063544,  0.7009105 ],
       [ 2.02889466,  4.80011561,  2.65727714,  0.23472856]])

In [4]:
# Why we need sqrt(d_k) in denominator
q.var(), k.var(), np.matmul(q, k.T).var()

(1.2149267729294917, 0.5843325614322845, 9.0454926090218)

In [5]:
scaled = np.matmul(q, k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(1.2149267729294917, 0.5843325614322845, 1.1306865761277245)

Notice the reduction in variance of the product

In [6]:
scaled

array([[-0.04262255, -0.90026145, -1.0649038 ,  0.26535847],
       [ 1.3160622 ,  2.34141838,  1.43388714, -1.1423052 ],
       [ 0.08757219, -1.49182844, -0.23003436,  0.24780928],
       [ 0.71732259,  1.69709715,  0.93948934,  0.08298908]])

## Masking

- This is to ensure words don't get context from words generated in the future. 
- Not required in the encoders, but required int he decoders

In [7]:
mask = np.tril(np.ones( (L, L) ))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [8]:
mask[mask == 0] = -np.infty
mask[mask == 1] = 0

In [9]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [None]:
scaled + mask

array([[ 0.68537216,        -inf,        -inf,        -inf],
       [ 0.47796088,  0.42358302,        -inf,        -inf],
       [ 0.37611945, -0.30709922, -0.65849946,        -inf],
       [ 0.78209275, -0.99700418,  1.88206279,  0.79213542]])

## Softmax

$$
\text{softmax} = \frac{e^{x_i}}{\sum_j e^x_j}
$$

In [None]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [None]:
attention = softmax(scaled + mask)

In [None]:
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.51359112, 0.48640888, 0.        , 0.        ],
       [0.53753304, 0.27144826, 0.1910187 , 0.        ],
       [0.19293995, 0.03256643, 0.57960627, 0.19488734]])

In [None]:
new_v = np.matmul(attention, v)
new_v

array([[-0.00368231,  1.43739233, -0.59614565, -1.23171219,  1.12030717,
        -0.98620738, -0.15461465, -1.03106383],
       [ 0.41440401, -0.13671232,  0.02128364, -0.60532081,  0.49977893,
        -1.1936286 , -0.27463831, -1.10169151],
       [ 0.32673907,  0.72121642, -0.00947672, -0.59897862,  0.90155754,
        -0.88535361, -0.21384855, -0.7053796 ],
       [ 0.18700384,  1.67754576,  0.33105314, -0.41795742,  1.4258469 ,
        -0.18788199, -0.10285145,  0.54683565]])

In [None]:
v

array([[-0.00368231,  1.43739233, -0.59614565, -1.23171219,  1.12030717,
        -0.98620738, -0.15461465, -1.03106383],
       [ 0.85585446, -1.79878344,  0.67321704,  0.05607552, -0.15542661,
        -1.41264124, -0.40136933, -1.17626611],
       [ 0.50465335,  2.28693419,  0.67128338,  0.2506863 ,  1.78802234,
         0.14775751, -0.11405725,  0.88026286],
       [-0.68069105,  0.68385101,  0.17994557, -1.68013201,  0.91543969,
        -0.19108312,  0.03160471,  1.40527326]])

# Function

In [10]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

def scaled_dot_product_attention(q, k, v, mask=None):
  d_k = q.shape[-1]
  scaled = np.matmul(q, k.T) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [11]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[-3.35336618e-01 -4.42061127e-01 -1.17153026e+00 -8.51012074e-01
  -1.83310562e+00 -1.01997786e+00  6.96294774e-01  5.13188841e-04]
 [-7.87534906e-01  1.21066774e+00  3.39951764e-01  1.63836578e+00
   2.54813915e+00 -9.37917890e-01  1.90957862e+00 -2.78035979e-01]
 [-4.53429169e-01 -8.28140974e-01 -6.84221156e-01 -2.21878794e-01
   2.38625669e-01  1.45336504e+00 -1.31057753e+00 -1.53303384e-02]
 [-2.66411325e-01  1.31584574e+00  4.34539796e-01  1.49393629e+00
  -2.14807639e-01 -1.99427124e+00  9.81772364e-01 -1.05129162e+00]]
K
 [[-0.8205968   1.4327968  -1.11966656  0.53342803  0.12464472  0.77386094
   0.57159063 -0.59304417]
 [ 0.35956654  0.70422676  1.69415205  0.74546541  0.39063949 -0.10742524
   1.59620862 -0.40137663]
 [-0.91570749  1.1063129   0.75561256  0.68401174  0.70281035 -0.46485235
  -0.78802292  0.37576825]
 [ 0.58927671  0.07248648 -0.23737439 -0.06384246 -0.65548825  0.10876345
  -0.63519811 -1.13629993]]
V
 [[-0.4547183   1.08644825  0.09757016  0.00665423 -1.