In [1]:
import numpy as np
import math

# L: sequence length, d_k: query/key width, d_v: value width
L, d_k, d_v = 4, 8, 8

# Draw Q, K, V from a seeded generator so that Restart & Run All always
# produces the same matrices (the unseeded global RNG changed them every run).
rng = np.random.default_rng(0)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [2]:
# Show the three input matrices, each under its own heading.
print("\n".join(
    f"{title} {matrix}"
    for title, matrix in (("Q\n", q), ("K\n", k), ("V\n", v))
))

Q
 [[-0.53021291  0.17351938 -2.80372733  1.36853866 -0.62573526 -0.78760894
   0.87646577 -0.29423695]
 [-0.61079107 -0.28797028  0.19836336  0.16409689 -0.31869415  1.38278105
   0.25201184  1.22194168]
 [ 0.47424825  1.86374193 -0.18145084 -0.2654385  -0.01710023  1.4925224
   0.04318913 -1.14576111]
 [-0.67504673  1.13966909 -0.18227342 -0.89253254 -1.11796724 -0.47959207
  -1.61684476 -0.38093655]]
K
 [[-1.50791136e+00 -1.02548933e+00  1.32186843e+00  6.69433702e-01
  -7.11392191e-01 -6.09670864e-04  3.10016980e-01 -5.51820988e-01]
 [-6.25872602e-01  1.85842578e+00  6.16654571e-01 -1.36516740e-01
  -8.68447851e-01  3.19752629e-01  5.32315132e-01 -1.88929267e+00]
 [-1.74681454e-01 -6.22473374e-01 -2.06277244e+00 -1.45441534e-01
  -1.31098537e+00  1.86493034e-01 -3.30760078e-01  1.82804623e+00]
 [-1.94555521e-01 -6.24075896e-01 -1.43470780e+00 -7.47771927e-01
   6.39369278e-01 -2.50337344e+00  1.21469967e+00  3.35458618e-01]]
V
 [[-0.75019022  0.40213689  0.87469443 -0.08750707  0.3

In [None]:

# Self-attention: raw similarity of every query row with every key row
q @ k.T

array([[-1.2887322 ,  0.05259302,  5.41468384,  5.53158304],
       [ 1.21809987, -1.50851289,  2.67899554, -3.05810347],
       [-2.38702001,  5.77090827, -2.63810085, -4.87573555],
       [-0.48467569,  3.22650468,  1.12892056, -1.25695118]])

In [4]:

# Motivation for the sqrt(d_k) divisor: compare the variance of q and k
# with the variance of their unscaled dot products.
np.var(q), np.var(k), np.var(q @ k.T)

(np.float64(0.9364650972271772),
 np.float64(1.086444058695338),
 np.float64(10.097741326679849))

In [5]:

# Divide the raw scores by sqrt(d_k), then compare the variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(np.float64(0.9364650972271772),
 np.float64(1.086444058695338),
 np.float64(1.262217665834981))

In [6]:
# NOTE(review): this cell repeats the previous one line for line; it
# recomputes the same `scaled` matrix and the same three variances.
scaled = np.divide(q @ k.T, math.sqrt(d_k))
np.var(q), np.var(k), np.var(scaled)

(np.float64(0.9364650972271772),
 np.float64(1.086444058695338),
 np.float64(1.262217665834981))

In [None]:
# Inspect the scaled score matrix, i.e. matmul(q, k.T) / sqrt(d_k).
scaled

array([[-0.45563564,  0.01859444,  1.91437983,  1.95570994],
       [ 0.43066334, -0.53333985,  0.94716796, -1.08120285],
       [-0.84393902,  2.04032419, -0.9327095 , -1.72383284],
       [-0.17135873,  1.14074167,  0.39913369, -0.44439935]])

In [8]:
# Masking: ones on and below the diagonal, zeros above it, so that
# row i marks positions 0..i.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [12]:
# Additive form of the mask: 0 where attention is allowed, -inf where it is
# blocked. It is rebuilt from L rather than edited in place, because the
# in-place version was not idempotent: running it a second time matched the
# freshly written zeros and turned the whole mask into -inf.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)

In [13]:
# Show the additive mask: 0 for allowed positions, -inf for blocked ones.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [None]:
# Adding the mask leaves allowed scores unchanged and sends blocked ones to -inf.
np.add(scaled, mask)

array([[-0.45563564,        -inf,        -inf,        -inf],
       [ 0.43066334, -0.53333985,        -inf,        -inf],
       [-0.84393902,  2.04032419, -0.9327095 ,        -inf],
       [-0.17135873,  1.14074167,  0.39913369, -0.44439935]])

In [15]:
#Now we use softmax to get attention weights

def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores with any number of dimensions. Entries equal
      to ``-np.inf`` (masked positions) receive a weight of exactly 0.

  Returns:
    ndarray with the same shape as ``x``; every slice along the last axis
    is non-negative and sums to 1. A slice made up entirely of ``-inf``
    comes out as NaN, because there is nothing to normalise.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; return an empty float array of the same shape.
    return np.exp(x)
  # Subtracting the per-row maximum does not change the result
  # mathematically, but it keeps np.exp from overflowing on large scores.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  weights = np.exp(shifted)
  # keepdims=True lets the normaliser broadcast for 1-D, 2-D and batched
  # input alike, which the previous double-transpose could not do.
  return weights / np.sum(weights, axis=-1, keepdims=True)

In [None]:

# The mask sends blocked scores to -inf before the softmax, so those
# positions get zero weight and no information leaks from them.
attention = softmax(np.add(scaled, mask))

In [18]:
# Attention weights produced by the softmax over the masked, scaled scores.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.72392259, 0.27607741, 0.        , 0.        ],
       [0.05049119, 0.90330657, 0.04620224, 0.        ],
       [0.13804211, 0.51268376, 0.24421555, 0.10505859]])

In [19]:

# Mix the value rows using the attention weights.
new_v = attention @ v
new_v

array([[-0.75019022,  0.40213689,  0.87469443, -0.08750707,  0.30595976,
         0.57752085, -0.66289836, -1.41503872],
       [-0.60766979,  0.21784661,  0.44589257, -0.27115811, -0.04690101,
         0.73796781, -0.39412598, -1.13985976],
       [-0.25616189, -0.16222222, -0.50053149, -0.62164293, -0.77786182,
         1.01675176,  0.2689651 , -0.46597971],
       [-0.1764691 ,  0.08129623,  0.06401155,  0.10919793,  0.270461  ,
         0.24657159,  0.23857381, -0.65186896]])

In [20]:
# Original value matrix, shown for comparison with new_v.
v

array([[-0.75019022,  0.40213689,  0.87469443, -0.08750707,  0.30595976,
         0.57752085, -0.66289836, -1.41503872],
       [-0.23395663, -0.26539431, -0.6784999 , -0.7527228 , -0.97216284,
         1.15868743,  0.31064158, -0.41829304],
       [-0.1504113 ,  1.23816146,  1.47606625,  1.35739857,  1.8365123 ,
        -1.27824809,  0.47251054, -0.36114874],
       [ 0.79733874, -1.33763958, -0.66016079,  1.67229083,  2.64740769,
        -1.09484413,  0.52757604, -1.46474318]])

### Final function summarizing the whole thing!

In [21]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores with any number of dimensions. Entries equal
      to ``-np.inf`` (masked positions) receive a weight of exactly 0.

  Returns:
    ndarray with the same shape as ``x``; every slice along the last axis
    is non-negative and sums to 1. A slice made up entirely of ``-inf``
    comes out as NaN, because there is nothing to normalise.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; return an empty float array of the same shape.
    return np.exp(x)
  # Shift by the per-row maximum so np.exp cannot overflow on large scores.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  weights = np.exp(shifted)
  # keepdims=True makes the normaliser broadcast for any number of dimensions.
  return weights / np.sum(weights, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Scaled dot-product attention: softmax(q kᵀ / sqrt(d_k) + mask) v.

  Args:
    q: queries, shape (..., L_q, d_k).
    k: keys, shape (..., L_k, d_k).
    v: values, shape (..., L_k, d_v).
    mask: optional additive mask broadcastable to (..., L_q, L_k); use 0
      for allowed positions and -inf for blocked ones. ``None`` means no
      masking.

  Returns:
    Tuple ``(out, attention)``: ``out`` has shape (..., L_q, d_v) and
    ``attention`` holds the softmax weights of shape (..., L_q, L_k).
  """
  d_k = q.shape[-1]
  # Transpose only the last two axes: k.T would reverse every axis and so
  # break inputs that carry leading batch/head dimensions. A 1-D k has no
  # second axis to swap and is passed through, as k.T did.
  k_t = np.swapaxes(k, -1, -2) if np.ndim(k) >= 2 else k
  scaled = np.matmul(q, k_t) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [22]:
# Run the full pipeline once, then print the inputs and both results,
# each under its own heading.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("\n".join(
    f"{title} {matrix}"
    for title, matrix in (
        ("Q\n", q),
        ("K\n", k),
        ("V\n", v),
        ("New V\n", values),
        ("Attention\n", attention),
    )
))

Q
 [[-0.53021291  0.17351938 -2.80372733  1.36853866 -0.62573526 -0.78760894
   0.87646577 -0.29423695]
 [-0.61079107 -0.28797028  0.19836336  0.16409689 -0.31869415  1.38278105
   0.25201184  1.22194168]
 [ 0.47424825  1.86374193 -0.18145084 -0.2654385  -0.01710023  1.4925224
   0.04318913 -1.14576111]
 [-0.67504673  1.13966909 -0.18227342 -0.89253254 -1.11796724 -0.47959207
  -1.61684476 -0.38093655]]
K
 [[-1.50791136e+00 -1.02548933e+00  1.32186843e+00  6.69433702e-01
  -7.11392191e-01 -6.09670864e-04  3.10016980e-01 -5.51820988e-01]
 [-6.25872602e-01  1.85842578e+00  6.16654571e-01 -1.36516740e-01
  -8.68447851e-01  3.19752629e-01  5.32315132e-01 -1.88929267e+00]
 [-1.74681454e-01 -6.22473374e-01 -2.06277244e+00 -1.45441534e-01
  -1.31098537e+00  1.86493034e-01 -3.30760078e-01  1.82804623e+00]
 [-1.94555521e-01 -6.24075896e-01 -1.43470780e+00 -7.47771927e-01
   6.39369278e-01 -2.50337344e+00  1.21469967e+00  3.35458618e-01]]
V
 [[-0.75019022  0.40213689  0.87469443 -0.08750707  0.3