# SELF ATTENTION IN TRANSFORMERS

## Generate Data

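Every token has three vectors: a **query** (what the token is looking for), a **key** (what it can be matched against), and a **value** (the information it passes on when attended to).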

In [1]:
import numpy as np
import math 

In [2]:
# Create random query/key/value matrices purely for illustration.
# Seed the generator so "Restart & Run All" reproduces the same numbers.
SEED = 42
np.random.seed(SEED)

# L: length of the input sequence (number of tokens).
# d_k: dimension of the query and key vectors (they must match to be dotted).
# d_v: dimension of the value vectors.
L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)  # one query vector per token
k = np.random.randn(L, d_k)  # one key vector per token
v = np.random.randn(L, d_v)  # one value vector per token

In [3]:
# Query matrix: one row per token (L rows, d_k columns).
q

array([[ 1.18304853e+00, -3.13631763e-01,  9.24653097e-04,
        -1.19670113e+00, -4.15901609e-01, -7.93845646e-01,
         8.70198594e-01,  6.66830205e-01],
       [-2.28167078e-01, -2.67476597e-01, -1.05013566e+00,
         1.41707463e+00,  1.34952354e+00, -1.85019868e+00,
        -4.06035619e-01,  8.64095133e-01],
       [ 3.61913131e-01,  9.99244111e-01, -6.80026644e-01,
         2.45144538e+00,  1.52780840e+00,  1.60902785e+00,
         6.83520823e-01, -3.07668668e-01],
       [-9.46288710e-01, -3.79265717e-02,  2.27832956e-01,
        -5.71507154e-01, -1.17008646e+00, -8.38085754e-01,
        -2.26409738e-01, -2.72844812e-01]])

# Key matrix: one row per token, same width d_k as the queries.
k

In [5]:
# Value matrix: one row per token (L rows, d_v columns).
v

array([[ 1.2333225 , -0.62699353,  0.32310016, -1.1867326 , -0.18207297,
         2.04800005,  0.02230134,  0.18301048],
       [-0.85412563,  0.14094044,  0.5287847 , -0.54090079, -0.16981355,
        -1.91462681,  1.12083824, -0.19820833],
       [ 1.01747248,  0.16747363,  0.01942133,  0.06486241,  2.29862547,
         0.45047009,  0.10376919,  1.1047455 ],
       [ 0.2057432 ,  0.12865555, -0.51736786, -1.0502618 ,  2.48146085,
         0.20489614, -0.39753814,  0.8390537 ]])

![image.png](attachment:5036da96-0273-4a07-8745-6e7ef8ab9f3f.png)

In [6]:
# Raw attention scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[ 2.79101243, -1.86582682,  0.38558697, -0.19154529],
       [-4.59673636, -1.14300148, -0.93929017, -4.34065046],
       [-1.4650403 ,  5.68861161,  2.59477458, -4.10608373],
       [ 0.05678689, -3.78451946, -1.74907036,  2.78762434]])

In [7]:
# Why divide by sqrt(d_k)? Each raw score sums d_k products of standard-normal
# entries, so its variance grows with d_k (compare the three numbers below).
(q.var(), k.var(), (q @ k.T).var())

(0.9583443565030865, 0.9814021653379571, 8.173378769074978)

In [8]:
# Scale the raw scores by 1/sqrt(d_k) so their variance is back near 1.
scaled = (q @ k.T) / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(0.9583443565030865, 0.9814021653379571, 1.021672346134372)

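Notice the reduction in variance of the product: after dividing by $\sqrt{d_k}$ the scores have roughly unit variance, like the inputs. Without the scaling, the variance grows with $d_k$. Large scores push the softmax into a nearly one-hot regime where gradients become very small.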

In [9]:
# Scaled attention scores, before masking.
scaled

array([[ 0.98677191, -0.6596694 ,  0.13632558, -0.06772149],
       [-1.62519173, -0.40411205, -0.33208922, -1.53465169],
       [-0.51796996,  2.01122792,  0.91739135, -1.45171982],
       [ 0.0200772 , -1.33802969, -0.61838975,  0.98557404]])

# MASKING 
- Masking ensures that a token cannot take context from tokens that come after it (words that have not been generated yet).
- This causal mask is not required in the encoder, but it is required in the decoder.

In [10]:
# Lower-triangular matrix of ones: row i marks positions 0..i as visible.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [11]:

# Convert to an *additive* mask: 0 where attention is allowed (on or below the
# diagonal) and -inf where a token would look ahead. Adding -inf before the
# softmax drives those weights to exactly 0.
# Rebuilt from L instead of mutating `mask` in place, so re-running this cell
# gives the same result (the in-place version turned every entry into -inf on
# a second run). Uses np.inf because the np.infty alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [12]:
# Additive mask: 0 = allowed position, -inf = future position to hide.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [13]:
# Masked (future) positions become -inf; allowed positions keep their scaled score.
np.add(scaled, mask)

array([[ 0.98677191,        -inf,        -inf,        -inf],
       [-1.62519173, -0.40411205,        -inf,        -inf],
       [-0.51796996,  2.01122792,  0.91739135,        -inf],
       [ 0.0200772 , -1.33802969, -0.61838975,  0.98557404]])

![image.png](attachment:56fe4aff-e163-44b5-9cf9-d196e9e0740e.png)

In [14]:
def softmax(x):
    """Softmax along the last axis.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but prevents overflow (e.g. scores around
    1000 would otherwise give inf / inf = NaN). Entries equal to -inf, such as
    masked positions, get probability 0. A row that is entirely -inf yields NaN.

    Args:
        x: array of scores (1-D, 2-D, or batched); normalization is over the
           last axis.

    Returns:
        Array with the same shape as x whose last-axis slices sum to 1.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    # keepdims makes the sum broadcast correctly for any number of leading axes.
    return exps / np.sum(exps, axis=-1, keepdims=True)

In [15]:
# Attention weights: softmax over each row of the masked, scaled scores.
attention = softmax(np.add(scaled, mask))

In [16]:

# Attention weights: each row sums to 1 and is 0 at masked (future) positions.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.2277465 , 0.7722535 , 0.        , 0.        ],
       [0.05635516, 0.70688764, 0.2367572 , 0.        ],
       [0.22668846, 0.05829229, 0.11971449, 0.59530476]])

In [17]:
# Each output row is the attention-weighted combination of the value vectors.
new_v = attention @ v
new_v

array([[ 1.2333225 , -0.62699353,  0.32310016, -1.1867326 , -0.18207297,
         2.04800005,  0.02230134,  0.18301048],
       [-0.37871662, -0.03395384,  0.48194076, -0.68798673, -0.17260559,
        -1.0121524 ,  0.8706503 , -0.11138708],
       [-0.29337282,  0.10394532,  0.39659787, -0.43387795,  0.41391628,
        -1.13135861,  0.81813159,  0.13175902],
       [ 0.47407715, -0.03727817, -0.20159939, -0.91800981,  1.70123159,
         0.52855344, -0.15384199,  0.66167906]])

In [18]:
# Original value vectors, shown again for comparison with new_v.
v

array([[ 1.2333225 , -0.62699353,  0.32310016, -1.1867326 , -0.18207297,
         2.04800005,  0.02230134,  0.18301048],
       [-0.85412563,  0.14094044,  0.5287847 , -0.54090079, -0.16981355,
        -1.91462681,  1.12083824, -0.19820833],
       [ 1.01747248,  0.16747363,  0.01942133,  0.06486241,  2.29862547,
         0.45047009,  0.10376919,  1.1047455 ],
       [ 0.2057432 ,  0.12865555, -0.51736786, -1.0502618 ,  2.48146085,
         0.20489614, -0.39753814,  0.8390537 ]])