# Self Attention in Transformers

## Generate Data

In [1]:
import numpy as np
import math

# Seed the global NumPy RNG so Restart & Run All regenerates the same Q, K, V.
# NOTE: outputs recorded before this seed was added will differ after a re-run.
SEED = 0
np.random.seed(SEED)

# L: sequence length; d_k: width of each query/key row; d_v: width of each value row.
L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

In [2]:
# Print each matrix under its heading, in the order Q, K, V.
print("\n".join(
    f"{title} {matrix}"
    for title, matrix in (("Q\n", q), ("K\n", k), ("V\n", v))
))

Q
 [[ 0.29768014  1.46816034 -1.32397561 -0.76432103  0.38979776 -0.49550435
  -1.33092564 -1.72066063]
 [-0.05143646  1.38880016  0.79178977  1.58027136  0.00966491 -0.43574485
  -0.87423673 -0.0353768 ]
 [ 1.07928344 -0.83696223 -0.25376834  1.12296124 -0.55213534 -0.32317013
  -0.88158989  0.36128264]
 [ 0.85476158  0.24667341 -1.08318915 -0.92442848 -2.22248362  0.06401839
  -0.44924721 -2.22433039]]
K
 [[-0.09587018 -0.03241947 -0.8745195  -1.16439671  0.58552362 -0.90144685
   0.31727494 -0.81312424]
 [ 0.380725    0.51069887 -0.54695683  1.68399384 -1.15499012 -0.17677576
  -0.84175614 -0.71653484]
 [ 0.33919297  0.04711329  0.53569413 -2.11925621  0.69489994  0.17131378
   0.31011484  0.98110667]
 [ 1.0598199  -0.22156219 -0.31718686  0.77735678 -0.8814225  -0.89105165
  -1.21505333  0.91815065]]
V
 [[-0.13140865  0.87150421  0.79974773  0.58881187  0.18071253  0.53201066
  -1.55340947 -0.79907812]
 [-0.0704488  -1.49683304  0.31476469  0.06035425 -0.45877017 -0.54012935
   0.5

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$ 

In [3]:
# Raw (unscaled) scores: dot product of every query row with every key row.
q @ k.T

array([[ 3.62342792,  2.29077639, -0.83422069, -0.04874055],
       [-2.42273914,  3.74487703, -3.25061387,  2.02458568],
       [-1.76742587,  3.19138627, -2.54711109,  4.46023516],
       [ 2.24082459,  4.01472731, -2.17466666,  0.88169475]])

In [4]:
# Why sqrt(d_k) goes in the denominator: compare the variance of q and k
# with the variance of their unscaled product.
np.var(q), np.var(k), np.var(q @ k.T)

(0.9860641577812217, 0.6872929635556801, 6.793918584341099)

In [6]:
# Divide the raw scores by sqrt(d_k), then compare the variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(0.9860641577812217, 0.6872929635556801, 0.8492398230426372)

Notice the reduction in variance of the product

In [7]:
# Display the scaled scores, i.e. q @ k.T divided by sqrt(d_k).
scaled

array([[ 1.28107523,  0.80991176, -0.29494155, -0.01723239],
       [-0.85656764,  1.32401397, -1.14926555,  0.71579913],
       [-0.62487941,  1.12832544, -0.90053976,  1.57693126],
       [ 0.79225113,  1.41942045, -0.76886077,  0.31172617]])

## Masking

- Masking ensures that a word cannot attend to (get context from) words that come later in the sequence, i.e. words that have not been generated yet.
- Not required in the encoders, but required in the decoders

In [8]:
# Lower-triangular matrix of ones: row i has a 1 in every column j <= i.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [9]:
# Build the additive mask directly: 0.0 where attention is allowed (on and below
# the diagonal) and -inf where it is blocked (above the diagonal).
# - np.inf replaces np.infty, which was removed in NumPy 2.0.
# - Rebuilding from L instead of rewriting `mask` in place keeps this cell
#   idempotent: re-running the in-place version turned every entry into -inf.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [10]:
# Display the mask in its additive form (0 where allowed, -inf where blocked).
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [12]:
# Adding the mask leaves allowed positions unchanged (+0) and sends blocked ones to -inf.
tmp = np.add(scaled, mask)
tmp

array([[ 1.28107523,        -inf,        -inf,        -inf],
       [-0.85656764,  1.32401397,        -inf,        -inf],
       [-0.62487941,  1.12832544, -0.90053976,        -inf],
       [ 0.79225113,  1.41942045, -0.76886077,  0.31172617]])

## Softmax

$$
\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [20]:
def softmax(x):
    """Return the softmax of ``x`` taken over its last axis.

    Each slice along the last axis is exponentiated and divided by the sum of
    its exponentials, so every such slice of the result sums to 1.

    Args:
        x: Array-like of scores. Any number of dimensions is accepted; the
            normalisation always runs over the last axis. Entries of ``-inf``
            (masked positions) receive a weight of exactly 0.

    Returns:
        A float ndarray with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    # Subtracting the per-row maximum does not change the result mathematically
    # but keeps np.exp from overflowing to inf (and the ratio from becoming nan)
    # for large scores. `initial` only matters for an empty last axis, where
    # np.max would otherwise raise.
    row_max = np.max(x, axis=-1, keepdims=True, initial=-np.inf)
    exps = np.exp(x - row_max)
    # keepdims lets the division broadcast for any rank, replacing the
    # transpose trick that was only valid for 2-D input.
    return exps / np.sum(exps, axis=-1, keepdims=True)

# Small demo of the building blocks used by softmax: a random 2x3 array,
# its transpose, and its sum over the last axis.
t = np.random.randn(2, 3)
print(t, t.T, t.sum(axis=-1), sep="\n")

[[ 0.57408516 -1.25974421  1.34331652]
 [-0.05632317  0.09484248  0.24030093]]
[[ 0.57408516 -0.05632317]
 [-1.25974421  0.09484248]
 [ 1.34331652  0.24030093]]
[0.65765747 0.27882023]


In [29]:
# Softmax over the masked scores (decoder-style attention).
# For unmasked, encoder-style attention, apply softmax to `scaled` alone.
attention = softmax(scaled + mask)

In [30]:
# Display the attention weights (softmax of the masked, scaled scores).
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.10150787, 0.89849213, 0.        , 0.        ],
       [0.1327643 , 0.76645822, 0.10077749, 0.        ],
       [0.27022175, 0.50593704, 0.05672023, 0.16712098]])

In [32]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[-0.13140865,  0.87150421,  0.79974773,  0.58881187,  0.18071253,
         0.53201066, -1.55340947, -0.79907812],
       [-0.0766367 , -1.25642817,  0.36399429,  0.11399686, -0.39385764,
        -0.4312987 ,  0.32117278, -0.93224027],
       [-0.013042  , -1.12468397,  0.35790449,  0.11441382, -0.3179529 ,
        -0.3297745 ,  0.13445699, -0.85383407],
       [ 0.10301021, -0.82641986,  0.28847815,  0.04570925, -0.00910417,
        -0.38647786, -0.15790951, -0.70784992]])

In [33]:
# Display the original value matrix v, for comparison with new_v.
v

array([[-0.13140865,  0.87150421,  0.79974773,  0.58881187,  0.18071253,
         0.53201066, -1.55340947, -0.79907812],
       [-0.0704488 , -1.49683304,  0.31476469,  0.06035425, -0.45877017,
        -0.54012935,  0.53295522, -0.94728437],
       [ 0.57949884, -0.9241015 ,  0.1039177 , -0.09941098,  0.0960839 ,
         0.13475285, -0.67270589, -0.21523799],
       [ 0.84545362, -1.50909078, -0.55514538, -0.82752848,  1.00958179,
        -1.58334694,  0.18172324, -0.00267494]])

# Function

In [34]:
def softmax(x):
    """Return the softmax of ``x`` taken over its last axis.

    Args:
        x: Array-like of scores with any number of dimensions. Entries of
            ``-inf`` (masked positions) receive a weight of exactly 0.

    Returns:
        A float ndarray with the same shape as ``x`` whose slices along the
        last axis each sum to 1.
    """
    x = np.asarray(x, dtype=float)
    # Subtract the per-row maximum so np.exp cannot overflow to inf (which
    # would make the ratio nan). `initial` only matters for an empty last
    # axis, where np.max would otherwise raise.
    row_max = np.max(x, axis=-1, keepdims=True, initial=-np.inf)
    exps = np.exp(x - row_max)
    # keepdims lets the division broadcast for inputs of any rank.
    return exps / np.sum(exps, axis=-1, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    Args:
        q: Query array whose last axis has length d_k.
        k: Key array whose last axis has length d_k. With more than two
            dimensions, the leading axes are treated as batch axes.
        v: Value array with one row per key.
        mask: Optional array added to the scaled scores before the softmax.
            Use 0 where attention is allowed and ``-inf`` where it is blocked.
            It must broadcast against the score matrix.

    Returns:
        Tuple ``(out, attention)``: ``attention`` holds the softmax weights and
        ``out`` is ``attention @ v``.
    """
    d_k = q.shape[-1]

    # Swap only the last two axes of k. `k.T` reverses every axis, which is
    # wrong once k carries leading batch/head axes. A 1-D k is passed through
    # unchanged, matching what `.T` did for it.
    k_t = np.swapaxes(k, -1, -2) if np.ndim(k) >= 2 else k
    scaled = np.matmul(q, k_t) / math.sqrt(d_k)

    if mask is not None:
        scaled = scaled + mask

    attention = softmax(scaled)

    out = np.matmul(attention, v)

    return out, attention

In [35]:
# Run masked attention on the sample data, then print the inputs followed by
# the outputs, each under its own heading.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("\n".join(
    f"{title} {matrix}"
    for title, matrix in (
        ("Q\n", q),
        ("K\n", k),
        ("V\n", v),
        ("New V\n", values),
        ("Attention\n", attention),
    )
))

Q
 [[ 0.29768014  1.46816034 -1.32397561 -0.76432103  0.38979776 -0.49550435
  -1.33092564 -1.72066063]
 [-0.05143646  1.38880016  0.79178977  1.58027136  0.00966491 -0.43574485
  -0.87423673 -0.0353768 ]
 [ 1.07928344 -0.83696223 -0.25376834  1.12296124 -0.55213534 -0.32317013
  -0.88158989  0.36128264]
 [ 0.85476158  0.24667341 -1.08318915 -0.92442848 -2.22248362  0.06401839
  -0.44924721 -2.22433039]]
K
 [[-0.09587018 -0.03241947 -0.8745195  -1.16439671  0.58552362 -0.90144685
   0.31727494 -0.81312424]
 [ 0.380725    0.51069887 -0.54695683  1.68399384 -1.15499012 -0.17677576
  -0.84175614 -0.71653484]
 [ 0.33919297  0.04711329  0.53569413 -2.11925621  0.69489994  0.17131378
   0.31011484  0.98110667]
 [ 1.0598199  -0.22156219 -0.31718686  0.77735678 -0.8814225  -0.89105165
  -1.21505333  0.91815065]]
V
 [[-0.13140865  0.87150421  0.79974773  0.58881187  0.18071253  0.53201066
  -1.55340947 -0.79907812]
 [-0.0704488  -1.49683304  0.31476469  0.06035425 -0.45877017 -0.54012935
   0.5