In [1]:
import numpy as np
import math

# Sequence length L, query/key width d_k, value width d_v.
L, d_k, d_v = 4, 8, 8
# Seed the global NumPy RNG so Restart Kernel -> Run All reproduces the same Q, K, V.
SEED = 0
np.random.seed(SEED)
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

In [2]:
# Print each input matrix under its label.
for label, arr in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(label, arr)

Q
 [[-0.02342616 -0.18716334 -1.73346516  0.52789542  1.42154486 -1.46894137
   0.13851992 -1.56341084]
 [-0.35804576  0.80050323 -0.63886135  0.53806547  0.82521337  0.02598343
  -0.55472615  0.33875249]
 [ 0.52783449  0.3844467   1.58022069 -1.23535155  0.11972285  0.27601495
   0.06323297 -0.39270406]
 [-0.74044956  0.26270041 -0.12779019 -2.56851282  0.28333497  0.50577462
  -0.01517249 -1.19995946]]
K
 [[-1.2721121   1.09135684  2.40890221 -0.80410101 -0.97490223 -0.06513947
   1.71550349  3.06753451]
 [-0.81158362  0.3878751   0.78358842  0.84077022  0.78522692  0.31115562
   0.07651727  1.02937312]
 [ 2.60162308  0.07117347 -0.51445639  0.41881586 -0.03531282 -1.11458765
  -1.49240581  0.84990205]
 [ 0.33031148  0.15732307 -0.38543283  2.07617959  0.38670256  1.49890119
   0.33853783  0.87330501]]
V
 [[-0.89252585  0.74905082 -0.49263931  0.03493668 -0.25496666 -1.47050847
  -0.54195602  0.49376536]
 [ 0.91337644  1.22186658  0.2680119  -0.60423314 -0.17160254  0.55892619
   1.8

In [3]:
# Raw (unscaled) attention scores: entry (i, j) is the dot product of q[i] and k[j].
q @ k.T

array([[-10.62305706,  -1.90763618,   1.09020725,  -1.24356618],
       [ -1.36119908,   1.51518564,   0.73717266,   1.83712686],
       [  3.31719401,  -0.29917607,  -0.66975048,  -2.80057588],
       [ -1.02997086,  -2.41334243,  -4.48860481,  -5.67208404]])

In [4]:
# Compare the variance of the inputs with the variance of the unscaled scores.
np.var(q), np.var(k), np.var(q @ k.T)

(0.8027879987605494, 1.2205949433628949, 10.615283839224102)

In [5]:
# Divide by sqrt(d_k) to pull the score variance back toward that of the inputs.
scaled = (q @ k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()


(0.8027879987605494, 1.2205949433628949, 1.3269104799030127)

In [6]:
# Scaled scores: entry (i, j) is q[i] . k[j] / sqrt(d_k), one row per query.
scaled


array([[-3.75581784, -0.67445124,  0.38544647, -0.43966704],
       [-0.48125655,  0.53569902,  0.26062989,  0.64952243],
       [ 1.17280519, -0.10577471, -0.23679255, -0.9901531 ],
       [-0.36414969, -0.8532454 , -1.58696145, -2.00538455]])

In [7]:
# Lower-triangular 0/1 pattern: 1 where key j <= query i, 0 above the diagonal.
mask = np.tri(L, L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [8]:
# Turn the lower-triangular pattern into an additive causal mask:
# 0 where query i may attend to key j (j <= i), -inf where j > i.
# The mask is rebuilt from scratch rather than mutated in place, so re-running
# this cell gives the same result. (The old in-place version turned every entry
# into -inf on a second run.) `np.inf` is used because the `np.infty` alias was
# removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)


In [9]:
# Additive mask: 0 keeps a score unchanged, -inf drives its softmax weight to 0.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
# Apply the causal mask to the scaled scores (future positions become -inf).
np.add(scaled, mask)

array([[-3.75581784,        -inf,        -inf,        -inf],
       [-0.48125655,  0.53569902,        -inf,        -inf],
       [ 1.17280519, -0.10577471, -0.23679255,        -inf],
       [-0.36414969, -0.8532454 , -1.58696145, -2.00538455]])

In [11]:
def softmax(x):
  """Numerically stable softmax over the last axis.

  Subtracting the per-row maximum before exponentiating leaves the result
  mathematically unchanged but prevents overflow (exp of large scores -> inf,
  inf / inf -> NaN). Entries equal to -inf get probability 0. A row that is
  entirely -inf has no finite entry and comes out as NaN.

  Args:
    x: array-like of scores with any number of leading (batch) axes.

  Returns:
    Array with the same shape as x whose values along the last axis sum to 1.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; np.max rejects a zero-length reduction axis.
    return np.exp(x)
  # keepdims keeps the reduced axis so broadcasting works for any rank
  # (the old .T trick only lined up correctly for 1-D and 2-D input).
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [12]:
# Normalise each row of the masked scores into attention weights.
attention = softmax(np.add(scaled, mask))

In [13]:
# Attention weights: each row is the softmax of one query's masked scores, so it sums to 1.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.26562084, 0.73437916, 0.        , 0.        ],
       [0.65673942, 0.18285755, 0.16040303, 0.        ],
       [0.47589078, 0.29180701, 0.14010282, 0.09219939]])

In [14]:
# Each output row is the attention-weighted average of the value rows.
new_v = attention @ v
new_v

array([[-0.89252585,  0.74905082, -0.49263931,  0.03493668, -0.25496666,
        -1.47050847, -0.54195602,  0.49376536],
       [ 0.43369115,  1.09627686,  0.06596708, -0.43445632, -0.19374579,
         0.01986605,  1.20445656, -0.61156691],
       [-0.35825661,  0.65492252, -0.12618337, -0.10502515, -0.29978786,
        -0.55392562, -0.17513489,  0.30951719],
       [-0.11326728,  0.63647892, -0.02901007, -0.19347618, -0.22179851,
        -0.31727115,  0.25691451,  0.11716336]])

In [15]:
# Original values for comparison with new_v. Row 0 of new_v equals v[0]
# because the causal mask lets query 0 attend only to key 0.
v

array([[-0.89252585,  0.74905082, -0.49263931,  0.03493668, -0.25496666,
        -1.47050847, -0.54195602,  0.49376536],
       [ 0.91337644,  1.22186658,  0.2680119 , -0.60423314, -0.17160254,
         0.55892619,  1.83612425, -1.01135943],
       [ 0.37955973, -0.3767772 ,  0.92482225, -0.10898084, -0.62942943,
         1.93020916, -0.96606767,  1.06093207],
       [-0.08925758, -0.25758453, -0.02543965, -0.20080647,  0.40995321,
        -0.55310637,  1.24058596,  0.31092586]])

In [16]:
def softmax(x):
  """Numerically stable softmax over the last axis.

  Subtracting the per-row maximum before exponentiating leaves the result
  mathematically unchanged but prevents overflow (exp of large scores -> inf,
  inf / inf -> NaN). Entries equal to -inf get probability 0. A row that is
  entirely -inf has no finite entry and comes out as NaN.

  Args:
    x: array-like of scores with any number of leading (batch) axes.

  Returns:
    Array with the same shape as x whose values along the last axis sum to 1.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; np.max rejects a zero-length reduction axis.
    return np.exp(x)
  # keepdims keeps the reduced axis so broadcasting works for any rank.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q @ k^T / sqrt(d_k) + mask) @ v.

  Args:
    q: queries, shape (..., L_q, d_k).
    k: keys, shape (..., L_k, d_k).
    v: values, shape (..., L_k, d_v).
    mask: optional additive mask broadcastable to (..., L_q, L_k);
      0 keeps a score, -inf blocks it.

  Returns:
    (out, attention): out has shape (..., L_q, d_v); attention has shape
    (..., L_q, L_k) and each row sums to 1 (NaN if a row is fully masked).
  """
  d_k = q.shape[-1]
  # Transpose only the last two axes so leading batch/head axes line up;
  # k.T would reverse every axis and break batched inputs.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [17]:
# Run the packaged attention function with the causal mask and show inputs and outputs.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for label, arr in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(label, arr)

Q
 [[-0.02342616 -0.18716334 -1.73346516  0.52789542  1.42154486 -1.46894137
   0.13851992 -1.56341084]
 [-0.35804576  0.80050323 -0.63886135  0.53806547  0.82521337  0.02598343
  -0.55472615  0.33875249]
 [ 0.52783449  0.3844467   1.58022069 -1.23535155  0.11972285  0.27601495
   0.06323297 -0.39270406]
 [-0.74044956  0.26270041 -0.12779019 -2.56851282  0.28333497  0.50577462
  -0.01517249 -1.19995946]]
K
 [[-1.2721121   1.09135684  2.40890221 -0.80410101 -0.97490223 -0.06513947
   1.71550349  3.06753451]
 [-0.81158362  0.3878751   0.78358842  0.84077022  0.78522692  0.31115562
   0.07651727  1.02937312]
 [ 2.60162308  0.07117347 -0.51445639  0.41881586 -0.03531282 -1.11458765
  -1.49240581  0.84990205]
 [ 0.33031148  0.15732307 -0.38543283  2.07617959  0.38670256  1.49890119
   0.33853783  0.87330501]]
V
 [[-0.89252585  0.74905082 -0.49263931  0.03493668 -0.25496666 -1.47050847
  -0.54195602  0.49376536]
 [ 0.91337644  1.22186658  0.2680119  -0.60423314 -0.17160254  0.55892619
   1.8