## Self Attention in Transformers


Generate Data


In [11]:
import numpy as np
import math

# Toy problem: a sequence of L tokens with d_k-dimensional queries/keys and
# d_v-dimensional values.
L, d_k, d_v = 4, 10, 10

# Draw from a dedicated, seeded generator so that Restart & Run All produces
# the same q, k and v (and therefore the same outputs) on every run.
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [12]:
# Dump each input matrix under its own heading.
for heading, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(heading, matrix)

Q
 [[ 2.1242154  -0.05097662 -0.41177697  1.2434784   0.21638206  0.40670121
   1.37134177  0.13827728 -0.24159657  1.31941207]
 [-1.65945225  2.5133612  -0.52408406  0.52430126  1.29746741 -0.24841566
  -1.35109741 -1.50804762 -0.69676327  0.17439543]
 [ 0.74734326  0.01635423 -1.81193073  0.48258882  0.02638898 -1.00850241
  -2.42754604  0.64763816 -1.17444421 -0.41234463]
 [-1.02032181 -0.58109091 -0.97416148 -1.13679723 -1.2481006   0.52204578
   0.00593693 -0.27470854  0.70069493  1.94716725]]
K
 [[-0.06408409 -1.26976691  0.5749145  -0.10629726  2.16337022 -0.38491943
  -0.41933464 -1.04573974  1.02006385 -0.14139777]
 [ 1.63077553  0.10346948  1.91457149 -0.64366476 -0.18445495 -0.66989645
  -1.27784937  0.11644065  1.12501153  0.43821596]
 [-1.8321942   0.42343224 -0.24660732 -0.76167373 -1.19830539  0.53104853
   2.13644932 -0.84111169  0.3407909  -1.37863616]
 [-0.20837821 -1.51194795 -1.24852003 -1.36934492 -0.40535539  1.16545723
   0.76523327  1.04657534  1.68085271  1.149

## Scaled Dot-Product Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$ 

In [9]:
# Raw (unscaled) attention scores: the dot product of every query row of q
# with every key row of k, giving an L x L matrix.
np.matmul(q, k.T)

array([[-3.55199318,  0.97288475, -0.2141719 , -0.74254354],
       [-1.35526289,  2.62442942,  0.35446297, -2.05105865],
       [ 1.63129485,  3.14958207,  0.46746191, -1.3425128 ],
       [-1.60039187,  1.77116277, -1.17169757,  1.20568295]])

In [10]:
# Why we need sqrt(d_k) in the denominator: compare the variance of q and of k
# with the variance of their unscaled product q @ k.T.
q.var(), k.var(), np.matmul(q, k.T).var()

(0.9001736505696285, 1.3136031103635024, 3.146410416897136)

In [13]:
# Divide the raw scores by sqrt(d_k), then compare the three variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(1.21872256513619, 1.1430141103758429, 1.3286904104194095)



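Notice the reduction in variance of the product: dividing the scores by $\sqrt{d_k}$ divides their variance by $d_k$, bringing it back to the same order as the variance of $Q$ and $K$.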


In [14]:
# Display the scaled score matrix (rows: queries, columns: keys).
scaled

array([[-0.40521635,  0.04042821, -1.23021428,  0.35967032],
       [ 0.27468612, -0.95388936,  0.01641138, -2.50330607],
       [-0.4792087 , -0.06736921, -2.34423138, -1.07302556],
       [-0.57326447, -0.43479068,  0.72745546,  2.56537949]])

Masking

    - This is to ensure words don't get context from words generated in the future.
    - Not required in the encoders, but required in the decoders.



In [15]:
# Lower-triangular L x L matrix of ones: entry (i, j) is 1 when j <= i and
# 0 otherwise.
mask = np.tril(np.ones( (L, L) ))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [16]:
# Additive causal mask: 0 where a query may attend (key index <= query index),
# -inf where it may not.
#
# The mask is rebuilt from L in a single expression instead of overwriting
# `mask` in place: the in-place version is not idempotent (running the cell a
# second time turns the allowed zeros into -inf as well). np.inf is used
# because the np.infty alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [17]:
# Adding the mask leaves allowed positions unchanged (+0) and sends the
# disallowed ones to -inf.
scaled + mask

array([[-0.40521635,        -inf,        -inf,        -inf],
       [ 0.27468612, -0.95388936,        -inf,        -inf],
       [-0.4792087 , -0.06736921, -2.34423138,        -inf],
       [-0.57326447, -0.43479068,  0.72745546,  2.56537949]])

## Softmax

$$
\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [18]:
def softmax(x):
  """Numerically stable softmax over the last axis of ``x``.

  Args:
    x: array-like of scores with at least one dimension. Entries equal to
      ``-inf`` receive a weight of exactly 0.

  Returns:
    ``np.ndarray`` with the same shape as ``x`` whose entries along the last
    axis are non-negative and sum to 1.

  Note:
    A slice along the last axis that is entirely ``-inf`` has no valid entry
    to normalise over and yields NaN.
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum does not change the result (it cancels in
  # the ratio) but keeps np.exp from overflowing on large scores.
  exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
  # keepdims makes the normaliser broadcast against inputs of any rank,
  # not only 2-D matrices.
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [20]:
# Attention weights without masking: each row is a softmax over all keys.
attention = softmax(scaled)
attention

array([[0.19423197, 0.30329251, 0.08511942, 0.4173561 ],
       [0.47008929, 0.13759948, 0.36308898, 0.02922225],
       [0.31087758, 0.46929833, 0.04815267, 0.17167142],
       [0.03461044, 0.03975077, 0.12708743, 0.79855137]])

In [21]:
# Attention weights with the mask added: scores of -inf get weight 0 after
# the softmax, and the remaining weights in each row still sum to 1.
attention = softmax(scaled + mask)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.77356915, 0.22643085, 0.        , 0.        ],
       [0.37530708, 0.5665606 , 0.05813233, 0.        ],
       [0.03461044, 0.03975077, 0.12708743, 0.79855137]])

In [22]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[-0.87140596, -0.04071352,  2.41630292, -0.4193706 ,  2.35233482,
        -0.22222797, -2.14203221, -0.19885411, -0.18682697,  1.00402161],
       [-0.57026872,  0.23399815,  1.88201651, -0.05627454,  2.02505647,
        -0.47075369, -1.80229218, -0.08534686, -0.18927814,  0.99904136],
       [-0.07051071,  0.70531439,  0.89880557,  0.51359861,  1.30353882,
        -0.79452002, -1.15085628,  0.06711615, -0.21707625,  1.03246021],
       [-0.96067747,  2.29573774, -0.52686951, -0.75006147, -1.43108213,
        -1.19656409, -1.52237979,  0.78141992, -0.16604796,  0.07718431]])

In [24]:
# Display the original value matrix for comparison with new_v.
v

array([[-8.71405964e-01, -4.07135166e-02,  2.41630292e+00,
        -4.19370604e-01,  2.35233482e+00, -2.22227973e-01,
        -2.14203221e+00, -1.98854107e-01, -1.86826970e-01,
         1.00402161e+00],
       [ 4.58524346e-01,  1.17251196e+00,  5.67021406e-02,
         1.18419210e+00,  9.06955998e-01, -1.31980687e+00,
        -6.41618132e-01,  3.02434683e-01, -1.97652208e-01,
         9.82027039e-01],
       [-5.58676086e-02,  9.68400703e-01, -6.91099233e-01,
         1.28634942e-03, -1.60245819e+00,  6.30187748e-01,
         2.85196222e-01, -5.09184335e-01, -6.01675745e-01,
         1.70758649e+00],
       [-1.17919071e+00,  2.66415804e+00, -6.57343739e-01,
        -9.80253645e-01, -1.68417141e+00, -1.52338139e+00,
        -1.82703734e+00,  1.05314611e+00, -9.42450930e-02,
        -2.67502422e-01]])

## Function

In [25]:
def softmax(x):
  """Numerically stable softmax over the last axis of ``x``.

  Args:
    x: array-like of scores with at least one dimension. Entries equal to
      ``-inf`` receive a weight of exactly 0.

  Returns:
    ``np.ndarray`` with the same shape as ``x`` whose entries along the last
    axis are non-negative and sum to 1.
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum cancels in the ratio but keeps np.exp
  # from overflowing on large scores.
  exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
  return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q @ k^T / sqrt(d_k) + mask) @ v.

  Args:
    q: queries, shape ``(..., L_q, d_k)``.
    k: keys, shape ``(..., L_k, d_k)``. Must have at least two dimensions.
    v: values, shape ``(..., L_k, d_v)``.
    mask: optional additive mask broadcastable to ``(..., L_q, L_k)``; use 0
      where attention is allowed and ``-inf`` where it is not. ``None``
      disables masking.

  Returns:
    Tuple ``(out, attention)``: ``out`` has shape ``(..., L_q, d_v)`` and
    ``attention`` holds the weights with shape ``(..., L_q, L_k)``.
  """
  d_k = q.shape[-1]
  # Swap only the last two axes: ``k.T`` would reverse every axis and break
  # inputs that carry leading batch/head dimensions.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [27]:
# Run attention without a mask, then dump every matrix under its heading.
values, attention = scaled_dot_product_attention(q, k, v)
for heading, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(heading, matrix)

Q
 [[ 2.1242154  -0.05097662 -0.41177697  1.2434784   0.21638206  0.40670121
   1.37134177  0.13827728 -0.24159657  1.31941207]
 [-1.65945225  2.5133612  -0.52408406  0.52430126  1.29746741 -0.24841566
  -1.35109741 -1.50804762 -0.69676327  0.17439543]
 [ 0.74734326  0.01635423 -1.81193073  0.48258882  0.02638898 -1.00850241
  -2.42754604  0.64763816 -1.17444421 -0.41234463]
 [-1.02032181 -0.58109091 -0.97416148 -1.13679723 -1.2481006   0.52204578
   0.00593693 -0.27470854  0.70069493  1.94716725]]
K
 [[-0.06408409 -1.26976691  0.5749145  -0.10629726  2.16337022 -0.38491943
  -0.41933464 -1.04573974  1.02006385 -0.14139777]
 [ 1.63077553  0.10346948  1.91457149 -0.64366476 -0.18445495 -0.66989645
  -1.27784937  0.11644065  1.12501153  0.43821596]
 [-1.8321942   0.42343224 -0.24660732 -0.76167373 -1.19830539  0.53104853
   2.13644932 -0.84111169  0.3407909  -1.37863616]
 [-0.20837821 -1.51194795 -1.24852003 -1.36934492 -0.40535539  1.16545723
   0.76523327  1.04657534  1.68085271  1.149

In [26]:
# Run attention with the mask applied, then dump every matrix under its heading.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for heading, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(heading, matrix)

Q
 [[ 2.1242154  -0.05097662 -0.41177697  1.2434784   0.21638206  0.40670121
   1.37134177  0.13827728 -0.24159657  1.31941207]
 [-1.65945225  2.5133612  -0.52408406  0.52430126  1.29746741 -0.24841566
  -1.35109741 -1.50804762 -0.69676327  0.17439543]
 [ 0.74734326  0.01635423 -1.81193073  0.48258882  0.02638898 -1.00850241
  -2.42754604  0.64763816 -1.17444421 -0.41234463]
 [-1.02032181 -0.58109091 -0.97416148 -1.13679723 -1.2481006   0.52204578
   0.00593693 -0.27470854  0.70069493  1.94716725]]
K
 [[-0.06408409 -1.26976691  0.5749145  -0.10629726  2.16337022 -0.38491943
  -0.41933464 -1.04573974  1.02006385 -0.14139777]
 [ 1.63077553  0.10346948  1.91457149 -0.64366476 -0.18445495 -0.66989645
  -1.27784937  0.11644065  1.12501153  0.43821596]
 [-1.8321942   0.42343224 -0.24660732 -0.76167373 -1.19830539  0.53104853
   2.13644932 -0.84111169  0.3407909  -1.37863616]
 [-0.20837821 -1.51194795 -1.24852003 -1.36934492 -0.40535539  1.16545723
   0.76523327  1.04657534  1.68085271  1.149