In [None]:
import numpy as np
import math

# Toy problem size: sequence length L, query/key width d_k, value width d_v.
L, d_k, d_v = 4, 8, 8

# Seed the RNG so "Restart & Run All" reproduces the same q, k, v and therefore
# every downstream output. (Outputs saved before this seed was added came from
# an unseeded run and will change once the notebook is re-executed.)
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [None]:
# Show the three input matrices, each preceded by its label.
for label, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
    print(label, matrix)

Q
 [[-0.84543245 -0.52081905  0.55021128  0.79945822  0.92674455 -0.69706346
  -0.47514018  1.96004714]
 [ 0.44075183  0.98787987 -1.70070881  1.65625301  1.50906187  1.03956577
  -0.42002314 -0.95275528]
 [ 0.13741238 -0.06368303  0.12997533 -0.03770637 -1.28842835 -1.0042427
   1.55279742  0.64740795]
 [ 0.04753675  1.23275493 -0.49823731 -0.77458834  1.06250167 -0.79027712
  -0.28100786 -1.17684878]]
K
 [[ 0.16553263 -0.62029741  0.16454589  1.92687308 -0.25380693 -1.28127959
   0.37637922  0.20199905]
 [ 0.99902253 -0.41723682 -0.40106573 -1.68420824 -0.97821926 -0.25346888
   0.48163566  0.07444992]
 [-0.61401083  1.57437555  0.99657773 -0.63440352  1.47530764 -0.31123916
   0.17916518  1.14488319]
 [ 1.86128341  0.99221747  1.84987683  1.24174935 -0.24867515 -0.30516681
   1.21464081  1.57477463]]
V
 [[ 0.31643031  1.95333079 -1.35441478  1.80368225 -0.94033831 -1.48521246
  -0.33849278 -0.15601087]
 [ 0.4298108   0.67032049 -0.04052544  0.52608237  2.17850863  1.04441503
  -0.72

In [None]:
# Self attention
# Self-attention scores (unscaled): entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[ 2.68911934, -3.00722076,  3.48337229,  2.41196482],
       [ 0.30619578, -4.09215943, -0.72422385, -1.99195425],
       [ 2.33992436,  2.48621746, -0.60003695,  3.91866686],
       [-1.93191433, -0.02447515,  2.32228132, -2.7895285 ]])

In [None]:
# Why divide by sqrt(d_k)? Each score sums d_k products of roughly unit-variance
# entries, so the raw scores' variance grows with d_k (compare the three numbers
# below). Large-variance scores push the softmax toward one-hot outputs, which is
# the motivation given in "Attention Is All You Need".
(q.var(), k.var(), (q @ k.T).var())

(np.float64(0.9189056896328487),
 np.float64(0.9193535801425889),
 np.float64(6.150014617446438))

In [None]:
# Scaled scores: dividing by sqrt(d_k) brings the variance back near that of q and k.
scaled = (q @ k.T) / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(np.float64(0.9189056896328487),
 np.float64(0.9193535801425889),
 np.float64(0.7687518271808046))

In [None]:
# Causal mask: 1 where query i may attend to key j (j <= i), 0 above the diagonal.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [None]:
# Convert the causal mask to additive form: 0 on/below the diagonal (keep the
# score), -inf above it (exp(-inf) == 0 in the softmax).
# Built fresh from L instead of mutating `mask` in place. The original
# `mask[mask == 0] = -np.inf` was not idempotent: re-running it turned the 0s
# into -inf as well, producing an all -inf mask and NaN attention. The in-place
# edit also silently changed the array displayed by the previous cell.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)

In [None]:
# Additive mask: adding 0 leaves a score unchanged; adding -inf removes it, since exp(-inf) == 0 in the softmax.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [None]:
# Masked scores: entries above the diagonal become -inf, so each query only sees itself and earlier positions.
np.add(scaled, mask)

array([[ 0.95074726,        -inf,        -inf,        -inf],
       [ 0.10825656, -1.44679684,        -inf,        -inf],
       [ 0.82728819,  0.87901061, -0.2121451 ,        -inf],
       [-0.68303486, -0.00865327,  0.82105044, -0.98624726]])

In [None]:
def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    Parameters
    ----------
    x : array_like
        Scores (logits) of any shape. ``-inf`` entries receive exactly zero
        weight, which is how the additive causal mask is applied.
    axis : int, default -1
        Axis to normalise over. The default normalises each row (last axis),
        matching the previous behaviour for 1-D and 2-D input.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``. Every slice along ``axis`` sums
        to 1, except a slice whose entries are all ``-inf``, which yields NaN.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would reject the empty reduction.
        return np.exp(x)
    # Subtracting the per-slice max does not change the result mathematically,
    # but it keeps exp() from overflowing to inf (and inf/inf -> NaN) on large scores.
    # keepdims=True lets the reductions broadcast back for any rank/axis, unlike
    # the old transpose trick, which only worked for 1-D/2-D input.
    exp_shifted = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_shifted / np.sum(exp_shifted, axis=axis, keepdims=True)

In [None]:
# Attention weights: softmax over the keys (last axis) of the masked, scaled scores.
attention = softmax(scaled + mask)

# Sanity checks: every row must be a probability distribution, and the causal
# mask must leave zero weight strictly above the diagonal (on future positions).
# These catch, e.g., a corrupted mask producing NaN rows.
assert np.allclose(attention.sum(axis=-1), 1.0), "attention rows must sum to 1"
assert np.all(np.triu(attention, k=1) == 0), "causal mask leaked future positions"

In [None]:
# Row i = how much query i attends to each key; the zeros above the diagonal come from the causal mask.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.8256424 , 0.1743576 , 0.        , 0.        ],
       [0.41550007, 0.43755623, 0.1469437 , 0.        ],
       [0.12193194, 0.2393303 , 0.54869806, 0.09003969]])

In [None]:
# Attention output: row i mixes the value vectors using row i of the attention
# weights (rows sum to 1, so this is a weighted average).
new_v = attention @ v
new_v

array([[ 0.31643031,  1.95333079, -1.35441478,  1.80368225, -0.94033831,
        -1.48521246, -0.33849278, -0.15601087],
       [ 0.33619906,  1.7296282 , -1.12532819,  1.580923  , -0.39654365,
        -1.04415268, -0.40666199, -0.45659907],
       [ 0.3982562 ,  1.11145059, -0.74629685,  1.11701268,  0.44570594,
        -0.18585055, -0.48356288, -0.74721685],
       [ 0.58480509,  0.3808359 , -0.73644504,  0.96730777,  0.06856579,
        -0.0633084 , -0.31988114,  0.10078896]])