# Self Attention for Transformer Neural Networks

In [2]:
import numpy as np
import math

# Sequence length and the per-token widths of the key and value vectors.
L, d_k, d_v = 4, 8, 8 # L is length of input seq

# Draw Q, K and V from a seeded generator so that Restart & Run All always
# reproduces the same matrices.
# NOTE(review): outputs saved from an earlier, unseeded run will not match
# these values until the notebook is re-executed.
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [3]:
# Show the three input matrices, each under its own heading.
for heading, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
    print(heading, matrix)

Q
 [[ 0.03150329  1.35143838  1.31164657  0.70608428 -1.99686372  1.57223659
   0.23142044 -0.78647482]
 [-0.65195725  0.00220972  2.0445276   1.84680798 -0.58622211 -0.30384119
   0.83235954  1.41259665]
 [-0.57780887 -0.19889084  0.55294506  0.18665258  1.73284592  1.5020829
   0.2913727  -0.65308217]
 [ 0.16997107  0.11478303 -0.5818222  -0.5968622   0.61287048  0.98165716
  -0.24918441  1.27734386]]
K
 [[-0.24171735 -0.26598006 -0.04842253 -1.22726851 -0.63681599  0.10354473
   1.94568937 -0.11057033]
 [-1.4858841   0.72175437  0.54311908 -1.75825835  0.75474015  0.12813026
   0.33796925 -0.52707357]
 [-1.09577162 -0.20531537 -1.68999833 -1.21844415 -1.07981043 -0.48485809
   2.40551271 -0.21745657]
 [-1.97456094  1.30160785  1.19910193 -0.23477499 -1.49570263  0.46499355
  -0.4010039  -0.31465564]]
V
 [[-0.92598431  2.1010246   0.41446724 -0.93862317 -0.76331392 -0.41016499
  -1.88031595 -0.6489421 ]
 [ 0.2738382   1.14945643 -2.24596163  0.62672257 -0.69421247  0.71430608
   0.04

### Self Attention

$$ \text{self attention} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right) $$

$$ \text{new V} = \text{self attention} \cdot V $$

In [4]:
# Raw attention scores: entry (i, j) is the dot product of query i with key j,
# i.e. the affinity of word i towards word j. The 1st row belongs to the 1st
# word, and each column is its affinity towards one of the words.
q @ k.T

array([[ 0.67452581, -0.41342103, -1.26736471,  6.97632708],
       [-0.40335257, -2.11102285, -2.51613643,  3.26548703],
       [-0.37211935,  2.63014624, -2.24444445, -0.30345187],
       [-0.22564354,  0.39459415, -0.5142258 , -1.50595883]])

In [5]:
# Why sqrt(d_k) appears in the denominator: compare the variance of q and k
# with the variance of their unscaled dot products. Dividing by sqrt(d_k)
# is what stabilizes the product.
raw_scores = np.matmul(q, k.T)
tuple(arr.var() for arr in (q, k, raw_scores))

(0.8868725621213633, 1.0660734565357852, 5.448324327093019)

In [6]:
# Scale the raw scores by 1/sqrt(d_k), then compare the variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
tuple(arr.var() for arr in (q, k, scaled))

(0.8868725621213633, 1.0660734565357852, 0.6810405408866272)

In [7]:
# Display the scaled score matrix (one row per query, one column per key).
scaled

array([[ 0.23848089, -0.14616641, -0.44808109,  2.46650409],
       [-0.14260667, -0.74635929, -0.88958857,  1.15452401],
       [-0.13156406,  0.92989712, -0.79353095, -0.10728644],
       [-0.07977704,  0.1395101 , -0.18180627, -0.53243685]])

### Masking

Masking ensures that a word cannot take context from words that come later in the sequence, i.e. words that have not been generated yet.

Not required in encoders, but required in the decoders

In [8]:
# Build a lower-triangular matrix of ones. Row i has ones in columns 0..i,
# which simulates the first word looking only at itself, the second word
# looking at words 1 and 2, and so on.
all_ones = np.ones((L, L))
mask = np.tril(all_ones)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [9]:
# Turn the 0/1 pattern into an additive mask: positions a word may look at
# (on or below the diagonal) become 0 and future positions become -inf, so
# that softmax assigns them zero weight.
# The pattern is rebuilt from L instead of rewriting `mask` in place, so
# re-running this cell cannot turn every entry into -inf.
# np.inf is used because the np.infty alias was removed in NumPy 2.0.
allowed = np.tril(np.ones((L, L))) == 1
mask = np.where(allowed, 0.0, -np.inf)
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
# Add the mask to the scaled scores: future positions become -inf, so no
# context can be taken from future words.
masked_scores = scaled + mask
masked_scores

array([[ 0.23848089,        -inf,        -inf,        -inf],
       [-0.14260667, -0.74635929,        -inf,        -inf],
       [-0.13156406,  0.92989712, -0.79353095,        -inf],
       [-0.07977704,  0.1395101 , -0.18180627, -0.53243685]])

### Softmax

$$ \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} $$

In [11]:
# convert a vector into a probability distribution
def softmax(x):
    """Turn scores into a probability distribution along the last axis.

    The largest value of each last-axis slice is subtracted before
    exponentiating. This does not change the result mathematically, but it
    keeps ``np.exp`` from overflowing on large scores (which would otherwise
    give inf / inf = NaN). Entries equal to -inf receive a weight of exactly 0.
    A slice made up entirely of -inf still yields NaN.

    Parameters
    ----------
    x : array_like
        Scores with at least one axis; normalization runs over the last axis,
        so any number of leading axes is accepted.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose last-axis slices sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalize; np.max would raise on an empty axis.
        return np.exp(x)
    # Shift so that the largest exponent on each slice is 0.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

In [14]:
# Normalize the masked scores with softmax to get the attention weights.
masked_scores = scaled + mask
attention = softmax(masked_scores)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.64651438, 0.35348562, 0.        , 0.        ],
       [0.22694122, 0.65599438, 0.11706439, 0.        ],
       [0.26426179, 0.32905581, 0.23862923, 0.16805317]])

In [15]:
# Weight the value vectors by the attention matrix; each resulting row is
# meant to better encapsulate the context of its word.
new_v = attention @ v
new_v

array([[-0.92598431,  2.1010246 ,  0.41446724, -0.93862317, -0.76331392,
        -0.41016499, -1.88031595, -0.6489421 ],
       [-0.5018643 ,  1.76465893, -0.52595612, -0.38529596, -0.73888755,
        -0.01268063, -1.20097096, -0.76045016],
       [-0.1859505 ,  1.18446658, -1.28570613,  0.21612375, -0.54497209,
         0.42179768, -0.46763433, -0.87200564],
       [-0.69232082,  0.74700734, -0.02429212, -0.10406648, -0.22222414,
         0.42886921, -0.73350902, -0.5449603 ]])

In [16]:
# Display the original value matrix for comparison with new_v.
v

array([[-0.92598431,  2.1010246 ,  0.41446724, -0.93862317, -0.76331392,
        -0.41016499, -1.88031595, -0.6489421 ],
       [ 0.2738382 ,  1.14945643, -2.24596163,  0.62672257, -0.69421247,
         0.71430608,  0.04153021, -0.96439499],
       [-1.32784019, -0.39618772,  0.79932404,  0.15384312,  0.71460485,
         0.39551092, -0.58221566, -0.78671425],
       [-1.3142574 , -0.54688957,  2.46639004, -0.58887559,  0.22254625,
         1.23671045, -0.66256083,  0.78310387]])

In [17]:
# Whole function
def softmax(x):
  """Turn scores into a probability distribution along the last axis.

  The largest value of each last-axis slice is subtracted before
  exponentiating, which leaves the result unchanged mathematically but keeps
  np.exp from overflowing on large scores. Entries equal to -inf receive a
  weight of exactly 0. A slice made up entirely of -inf still yields NaN.

  x: array_like with at least one axis. Returns an array of the same shape
  whose last-axis slices sum to 1.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalize; np.max would raise on an empty axis.
    return np.exp(x)
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q k^T / sqrt(d_k) + mask) and apply it to v.

  q, k, v: arrays whose last two axes are (sequence, feature). Any leading
    axes are treated as batch axes by np.matmul.
  mask: optional array added to the scaled scores before the softmax; it must
    broadcast against the score matrix. Use 0 for positions that may be
    attended to and -inf for positions that must be ignored.

  Returns (out, attention): the attention-weighted values and the attention
  weights themselves.
  """
  d_k = q.shape[-1]
  # Swap only the last two axes of k; k.T would reverse every axis and break
  # inputs that carry leading batch dimensions.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [18]:
# Run the complete attention routine on the toy inputs with the causal mask,
# then print the inputs next to the results.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for heading, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
    print(heading, matrix)

Q
 [[ 0.03150329  1.35143838  1.31164657  0.70608428 -1.99686372  1.57223659
   0.23142044 -0.78647482]
 [-0.65195725  0.00220972  2.0445276   1.84680798 -0.58622211 -0.30384119
   0.83235954  1.41259665]
 [-0.57780887 -0.19889084  0.55294506  0.18665258  1.73284592  1.5020829
   0.2913727  -0.65308217]
 [ 0.16997107  0.11478303 -0.5818222  -0.5968622   0.61287048  0.98165716
  -0.24918441  1.27734386]]
K
 [[-0.24171735 -0.26598006 -0.04842253 -1.22726851 -0.63681599  0.10354473
   1.94568937 -0.11057033]
 [-1.4858841   0.72175437  0.54311908 -1.75825835  0.75474015  0.12813026
   0.33796925 -0.52707357]
 [-1.09577162 -0.20531537 -1.68999833 -1.21844415 -1.07981043 -0.48485809
   2.40551271 -0.21745657]
 [-1.97456094  1.30160785  1.19910193 -0.23477499 -1.49570263  0.46499355
  -0.4010039  -0.31465564]]
V
 [[-0.92598431  2.1010246   0.41446724 -0.93862317 -0.76331392 -0.41016499
  -1.88031595 -0.6489421 ]
 [ 0.2738382   1.14945643 -2.24596163  0.62672257 -0.69421247  0.71430608
   0.04