In [1]:
import numpy as np
import math

In [3]:
L, d_k, d_v = 4, 8, 8

q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

print("Q\n", q)
print("K\n", k)
print("V\n", v)

Q
 [[-1.81050378 -0.24897168 -2.26884808  1.10849275 -1.57123851 -1.58251502
  -0.8258621  -1.18132231]
 [ 1.44839782  0.17246017  0.87394158 -0.99372012 -0.41441144 -0.59054794
   1.58063832 -0.88785146]
 [-1.2315141  -1.5809589   0.42536996 -0.79258902  2.54660601 -0.56873598
  -0.21982826 -0.09892881]
 [-2.89024226  0.20176127 -0.1206807  -0.76937137  0.08701138  0.57072396
   1.28525246  0.85153431]]
K
 [[-0.23323083 -0.75191959  1.47204636 -0.49110547 -0.72249873 -0.11749525
   1.27161585  1.74945649]
 [-0.92549353  1.88537757 -0.00339091 -1.27223034 -0.51699712  0.19816817
   0.58770827  0.88543809]
 [-1.18828683 -1.32452554  0.83091507 -0.84525652  1.44333182  0.27031521
  -0.20778694  1.29604248]
 [ 0.31188761 -1.49117412  2.13415465 -0.22833452 -0.67549589  1.84580266
   1.85052021 -0.21984828]]
V
 [[-1.77765873  1.40105675 -0.57268865  0.46924732 -2.38076693  1.466956
  -0.66044631 -2.44672177]
 [ 0.16348705 -0.58464001 -0.62697152 -1.3981173  -1.90586158  0.31570041
   0.107

In [4]:
np.matmul(q, k.T)

array([[-5.07045992, -1.22899281, -4.39605061, -8.41680152],
       [ 2.13252227,  0.48598556, -2.62031681,  4.59669034],
       [ 0.26568553, -2.48011628,  8.02012076, -0.0928719 ],
       [ 3.71672693,  5.61197262,  4.83366724,  1.9016761 ]])

In [5]:
q.var(), k.var(), np.matmul(q, k.T).var()

(1.4074323633697174, 1.155268970772072, 18.107191946568555)

In [12]:
scaled = np.matmul(q, k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(1.4074323633697174, 1.155268970772072, 2.2633989933210694)

In [13]:
mask = np.tril(np.ones((L, L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [14]:
mask[mask==0] = -np.infty
mask[mask==1] = 0

mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [15]:
scaled + mask

array([[-1.7926783 ,        -inf,        -inf,        -inf],
       [ 0.75396048,  0.17182184,        -inf,        -inf],
       [ 0.09393402, -0.87685352,  2.83554089,        -inf],
       [ 1.31406141,  1.98413195,  1.70895944,  0.67234403]])

In [16]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [23]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.shape[-1]
    scaled = np.matmul(q, k.T) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

In [25]:
values, attention = scaled_dot_product(q, k, v)

print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("Values\n", values)
print("Attention\n", attention)

Q
 [[-1.81050378 -0.24897168 -2.26884808  1.10849275 -1.57123851 -1.58251502
  -0.8258621  -1.18132231]
 [ 1.44839782  0.17246017  0.87394158 -0.99372012 -0.41441144 -0.59054794
   1.58063832 -0.88785146]
 [-1.2315141  -1.5809589   0.42536996 -0.79258902  2.54660601 -0.56873598
  -0.21982826 -0.09892881]
 [-2.89024226  0.20176127 -0.1206807  -0.76937137  0.08701138  0.57072396
   1.28525246  0.85153431]]
K
 [[-0.23323083 -0.75191959  1.47204636 -0.49110547 -0.72249873 -0.11749525
   1.27161585  1.74945649]
 [-0.92549353  1.88537757 -0.00339091 -1.27223034 -0.51699712  0.19816817
   0.58770827  0.88543809]
 [-1.18828683 -1.32452554  0.83091507 -0.84525652  1.44333182  0.27031521
  -0.20778694  1.29604248]
 [ 0.31188761 -1.49117412  2.13415465 -0.22833452 -0.67549589  1.84580266
   1.85052021 -0.21984828]]
V
 [[-1.77765873  1.40105675 -0.57268865  0.46924732 -2.38076693  1.466956
  -0.66044631 -2.44672177]
 [ 0.16348705 -0.58464001 -0.62697152 -1.3981173  -1.90586158  0.31570041
   0.107

In [28]:
values, attention = scaled_dot_product(q, k, v, mask=mask)

print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("Values\n", values)
print("Attention\n", attention)

Q
 [[-1.81050378 -0.24897168 -2.26884808  1.10849275 -1.57123851 -1.58251502
  -0.8258621  -1.18132231]
 [ 1.44839782  0.17246017  0.87394158 -0.99372012 -0.41441144 -0.59054794
   1.58063832 -0.88785146]
 [-1.2315141  -1.5809589   0.42536996 -0.79258902  2.54660601 -0.56873598
  -0.21982826 -0.09892881]
 [-2.89024226  0.20176127 -0.1206807  -0.76937137  0.08701138  0.57072396
   1.28525246  0.85153431]]
K
 [[-0.23323083 -0.75191959  1.47204636 -0.49110547 -0.72249873 -0.11749525
   1.27161585  1.74945649]
 [-0.92549353  1.88537757 -0.00339091 -1.27223034 -0.51699712  0.19816817
   0.58770827  0.88543809]
 [-1.18828683 -1.32452554  0.83091507 -0.84525652  1.44333182  0.27031521
  -0.20778694  1.29604248]
 [ 0.31188761 -1.49117412  2.13415465 -0.22833452 -0.67549589  1.84580266
   1.85052021 -0.21984828]]
V
 [[-1.77765873  1.40105675 -0.57268865  0.46924732 -2.38076693  1.466956
  -0.66044631 -2.44672177]
 [ 0.16348705 -0.58464001 -0.62697152 -1.3981173  -1.90586158  0.31570041
   0.107