In [2]:
import numpy as np
import math

# Toy problem sizes: sequence length, key/query width and value width.
L, d_k, d_v = 4, 8, 8

# Draw Q, K and V from a seeded generator so that Restart & Run All
# reproduces the same matrices (and hence the same attention results).
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [3]:
# Show the three input matrices, each under its own heading.
for heading, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
    print(heading, matrix)

Q
 [[ 1.23767444  0.18502693 -1.09854368  0.16503678  0.31023597 -0.37130178
   0.6124875   0.25187821]
 [-1.37761389 -0.7662307  -1.4039775   1.49568658  0.74314376 -2.19738003
   0.71133727  2.01005643]
 [ 1.20876575 -0.27331337 -0.65803996  0.80016005  0.66394185 -2.26971528
  -0.49590017  0.02441049]
 [-1.51969392  0.43542701  0.23870059  0.23507748 -0.63689157  0.89787709
  -0.04342489  0.97911984]]
K
 [[ 0.08344936  1.30839146 -0.1798441  -0.0232427  -0.10818522  0.8712742
  -0.74525497  0.22306002]
 [-0.35561789 -0.29492574 -0.04721697 -0.46935789 -0.3424444   0.33506727
   0.16858585  0.62867649]
 [-0.43523471 -0.35258237 -0.41215324 -0.234798    0.34846414 -1.61257461
  -0.7165962  -0.44720911]
 [-0.4904113   0.73225942 -0.21706811  0.88411808 -0.65757577  0.95325741
  -0.20484205  0.0970765 ]]
V
 [[ 3.85965083e-02 -3.25931476e-01  4.19487217e-01  5.29292944e-01
  -5.92452147e-01 -2.96825284e-01 -2.23785032e-01  8.50729990e-01]
 [-1.83609861e-01  7.60972745e-01 -6.70625975e-01

Self Attention

In [5]:
# Raw (unscaled) scores: dot product of every query row with every key row.
q @ k.T

array([[-0.2182425 , -0.4893428 , -0.03458858, -0.74607259],
       [-2.9764395 ,  0.47300563,  3.49095555, -0.79228462],
       [-1.83133939, -1.74986828,  3.88949474, -2.4389173 ],
       [ 1.49646326,  1.4175835 , -1.72225815,  2.59879978]])

In [6]:
# Variance of the inputs next to the variance of the raw score matrix.
np.var(q), np.var(k), np.var(q @ k.T)

(1.0413895219605402, 0.3467872290805614, 4.0121944848665745)

In [7]:
# Divide the raw scores by sqrt(d_k), then compare variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(1.0413895219605402, 0.3467872290805614, 0.5015243106083216)

In [8]:
# Display the scaled score matrix, q @ k.T / sqrt(d_k), from the previous cell.
scaled

array([[-0.07716037, -0.17300881, -0.01222891, -0.26377649],
       [-1.05233028,  0.16723275,  1.23423917, -0.28011491],
       [-0.64747625, -0.61867186,  1.37514405, -0.86228748],
       [ 0.52907966,  0.50119145, -0.60891021,  0.91881447]])

#### Masking

- Masking ensures that a position cannot attend to positions that come after it, so a word never takes context from words that have not been generated yet.
- It is not required in the encoder, but it is required in the decoder.

In [9]:
# Lower-triangular matrix of ones: row i has a 1 in every column j <= i.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [11]:
# Build the additive mask: 0 where attention is allowed (on and below the
# diagonal) and -inf where it is blocked (above the diagonal).
# It is rebuilt from scratch rather than mutated in place, so re-running this
# cell always gives the same result. `np.inf` is used because the `np.infty`
# alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [12]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [13]:
# Adding the mask keeps allowed scores unchanged (+0) and turns blocked ones into -inf.
scaled + mask

array([[-0.07716037,        -inf,        -inf,        -inf],
       [-1.05233028,  0.16723275,        -inf,        -inf],
       [-0.64747625, -0.61867186,  1.37514405,        -inf],
       [ 0.52907966,  0.50119145, -0.60891021,  0.91881447]])

#### Softmax

In [14]:
def softmax(x):
    """Softmax over the last axis of ``x``.

    Each slice along the last axis is mapped to non-negative weights that
    sum to 1. The per-slice maximum is subtracted before exponentiating,
    which leaves the result unchanged mathematically but prevents
    ``exp`` from overflowing to ``inf`` (and the result becoming NaN) for
    large inputs. Entries equal to ``-inf`` receive weight 0.

    Parameters
    ----------
    x : array_like
        Scores with at least one dimension; any number of leading axes.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x``.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; keep the (empty) shape without calling max().
        return np.exp(x)
    # keepdims=True lets the reductions broadcast back against x for any rank.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)

In [15]:
# Softmax turns each row of masked scores into attention weights; -inf entries become 0.
attention = softmax(scaled + mask)

In [16]:
# Display the attention weights produced by the softmax in the previous cell.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.22801336, 0.77198664, 0.        , 0.        ],
       [0.10430436, 0.10735247, 0.78834316, 0.        ],
       [0.26528374, 0.25798766, 0.0850135 , 0.3917151 ]])

In [17]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[ 0.03859651, -0.32593148,  0.41948722,  0.52929294, -0.59245215,
        -0.29682528, -0.22378503,  0.85072999],
       [-0.13294384,  0.51314406, -0.4220656 ,  1.00735511,  1.25318309,
        -0.17738227,  1.44321142,  0.6386532 ],
       [-0.31212626, -0.42379857, -0.99943955,  0.56102371, -0.65796138,
         0.4561752 , -0.17627354,  0.87586935],
       [ 0.73538019,  0.30081276, -0.80842003,  0.47791425,  0.3219984 ,
        -0.16800258,  0.24965289,  0.66426098]])

In [18]:
# Display the original value matrix for comparison with new_v.
v

array([[ 3.85965083e-02, -3.25931476e-01,  4.19487217e-01,
         5.29292944e-01, -5.92452147e-01, -2.96825284e-01,
        -2.23785032e-01,  8.50729990e-01],
       [-1.83609861e-01,  7.60972745e-01, -6.70625975e-01,
         1.14855517e+00,  1.79830844e+00, -1.42103673e-01,
         1.93557416e+00,  5.76014368e-01],
       [-3.76030492e-01, -5.98083204e-01, -1.23195150e+00,
         4.85214972e-01, -1.00111085e+00,  6.37273937e-01,
        -4.57567819e-01,  9.20028195e-01],
       [ 2.05373222e+00,  6.17287037e-01, -1.63883759e+00,
        -1.56868984e-04,  2.56137375e-01, -2.72584720e-01,
        -3.86595733e-01,  5.40588637e-01]])

In [25]:
def softmax(x):
    """Softmax over the last axis of ``x``.

    The per-slice maximum is subtracted before exponentiating so that large
    scores cannot overflow ``exp`` and produce NaN. Entries equal to
    ``-inf`` receive weight 0. Works for any number of leading axes.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; keep the (empty) shape without calling max().
        return np.exp(x)
    # keepdims=True lets the reductions broadcast back against x for any rank.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Parameters
    ----------
    q, k, v : numpy.ndarray
        Query, key and value arrays with at least two dimensions. The last
        axis of ``q`` and ``k`` is the feature axis; any leading axes are
        treated as batch axes by ``np.matmul``.
    mask : numpy.ndarray or None, optional
        Additive mask, broadcast against the score matrix (0 to keep a
        score, ``-inf`` to block it). ``None`` applies no mask.

    Returns
    -------
    out : numpy.ndarray
        Attention-weighted combination of ``v``.
    attention : numpy.ndarray
        Attention weights (softmax over the last axis of the scores).
    """
    d_k = q.shape[-1]
    # Transpose only the last two axes of k: identical to k.T for 2-D input,
    # but also correct when leading batch axes are present.
    scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

In [26]:
# Run attention without a mask, then print the inputs and both results.
values, attention = scaled_dot_product_attention(q, k, v)
for heading, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
    print(heading, matrix)

Q
 [[ 1.23767444  0.18502693 -1.09854368  0.16503678  0.31023597 -0.37130178
   0.6124875   0.25187821]
 [-1.37761389 -0.7662307  -1.4039775   1.49568658  0.74314376 -2.19738003
   0.71133727  2.01005643]
 [ 1.20876575 -0.27331337 -0.65803996  0.80016005  0.66394185 -2.26971528
  -0.49590017  0.02441049]
 [-1.51969392  0.43542701  0.23870059  0.23507748 -0.63689157  0.89787709
  -0.04342489  0.97911984]]
K
 [[ 0.08344936  1.30839146 -0.1798441  -0.0232427  -0.10818522  0.8712742
  -0.74525497  0.22306002]
 [-0.35561789 -0.29492574 -0.04721697 -0.46935789 -0.3424444   0.33506727
   0.16858585  0.62867649]
 [-0.43523471 -0.35258237 -0.41215324 -0.234798    0.34846414 -1.61257461
  -0.7165962  -0.44720911]
 [-0.4904113   0.73225942 -0.21706811  0.88411808 -0.65757577  0.95325741
  -0.20484205  0.0970765 ]]
V
 [[ 3.85965083e-02 -3.25931476e-01  4.19487217e-01  5.29292944e-01
  -5.92452147e-01 -2.96825284e-01 -2.23785032e-01  8.50729990e-01]
 [-1.83609861e-01  7.60972745e-01 -6.70625975e-01

In [27]:
# For the decoder, pass the mask so that blocked positions get zero weight.

values, attention = scaled_dot_product_attention(q, k, v, mask)
for heading, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
    print(heading, matrix)

Q
 [[ 1.23767444  0.18502693 -1.09854368  0.16503678  0.31023597 -0.37130178
   0.6124875   0.25187821]
 [-1.37761389 -0.7662307  -1.4039775   1.49568658  0.74314376 -2.19738003
   0.71133727  2.01005643]
 [ 1.20876575 -0.27331337 -0.65803996  0.80016005  0.66394185 -2.26971528
  -0.49590017  0.02441049]
 [-1.51969392  0.43542701  0.23870059  0.23507748 -0.63689157  0.89787709
  -0.04342489  0.97911984]]
K
 [[ 0.08344936  1.30839146 -0.1798441  -0.0232427  -0.10818522  0.8712742
  -0.74525497  0.22306002]
 [-0.35561789 -0.29492574 -0.04721697 -0.46935789 -0.3424444   0.33506727
   0.16858585  0.62867649]
 [-0.43523471 -0.35258237 -0.41215324 -0.234798    0.34846414 -1.61257461
  -0.7165962  -0.44720911]
 [-0.4904113   0.73225942 -0.21706811  0.88411808 -0.65757577  0.95325741
  -0.20484205  0.0970765 ]]
V
 [[ 3.85965083e-02 -3.25931476e-01  4.19487217e-01  5.29292944e-01
  -5.92452147e-01 -2.96825284e-01 -2.23785032e-01  8.50729990e-01]
 [-1.83609861e-01  7.60972745e-01 -6.70625975e-01