In [2]:
import numpy as np
import math

# Seeded generator so that "Restart Kernel -> Run All" always yields the same
# Q, K and V. NOTE: outputs captured from an earlier, unseeded run will not
# match these values until the notebook is re-executed.
SEED = 0
rng = np.random.default_rng(SEED)

# l   : number of rows in q, k and v (also the side of the square mask)
# d_k : width of each query/key row (used for the 1/sqrt(d_k) scaling)
# d_v : width of each value row
l, d_k, d_v = 4, 8, 8
q = rng.standard_normal((l, d_k))
k = rng.standard_normal((l, d_k))
v = rng.standard_normal((l, d_v))

In [3]:
# Dump the three input matrices, each preceded by its label on its own line.
for heading, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(heading, matrix)

Q
 [[ 0.94142301 -0.80811881  0.2406522  -0.46755052  0.10591556  1.80989183
  -1.64695504  1.04237871]
 [-1.02135628  0.55656575 -0.01198069 -0.75074019 -0.03859188  0.90127641
  -0.88331595 -0.09496225]
 [-1.42517688 -0.13237865  0.43719959  1.15830376  1.28994537 -0.08678479
  -0.91320545  0.30338664]
 [ 0.65252126 -0.36736193  1.28117239 -1.0100903   0.8692625   0.20136011
   0.73339642  0.37824704]]
K
 [[-1.28743798 -0.43084084 -0.20666856  1.05117486 -0.43817876 -0.83808158
   0.17144337 -2.11779588]
 [-1.80326155  0.01636247 -0.67958714 -1.4289214  -0.44625137  0.63996727
  -0.34637386  1.0625549 ]
 [-1.4160158  -0.39064908  0.36067154  1.30137457 -0.28661542  0.39955587
   0.51092093 -2.14149157]
 [ 0.29136595 -0.78595696 -0.9465062  -1.48521404  1.30378315  0.26926597
   1.21011638  0.78784062]]
V
 [[-0.74766703  0.36752208 -1.25601433  0.1143262   0.39920838 -1.11201019
  -0.39908666 -0.00720719]
 [-0.37161097  1.32498083 -0.31729656 -3.02009093  0.08730038  0.15698311
  -0.0

In [4]:
# Raw (unscaled) scores: dot product of every query row with every key row.
q @ k.T

array([[-5.45821756,  1.58274743, -3.91995407,  0.82973322],
       [-0.40030258,  3.73083398,  0.37074702, -0.5600381 ],
       [ 1.72751743,  0.62305652,  2.21418978, -1.65296892],
       [-3.23332836, -0.72117288, -2.23688894,  3.13945847]])

In [5]:
# Variance of the inputs versus the variance of their unscaled dot products.
np.var(q), np.var(k), np.var(q @ k.T)

(0.7178461432778144, 0.9592170200092118, 6.171352903968691)

In [7]:
# Divide the raw scores by sqrt(d_k) and compare variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(0.7178461432778144, 0.9592170200092118, 0.7714191129960863)

In [8]:
# l x l float matrix with ones on and below the main diagonal, zeros above it.
mask = np.tri(l)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [9]:
# Build the additive mask: 0 on/below the diagonal, -inf above it.
# The mask is rebuilt from `l` instead of being rewritten in place, so
# re-running this cell gives the same result (the in-place version turned
# every entry into -inf on a second run). `np.inf` replaces the `np.infty`
# alias, which was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((l, l))) == 1, 0.0, -np.inf)

In [10]:
# Display the additive mask: 0 on and below the diagonal, -inf above it.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [11]:
# Adding the mask leaves allowed scores unchanged and sends the rest to -inf.
np.add(scaled, mask)

array([[-1.92977133,        -inf,        -inf,        -inf],
       [-0.14152834,  1.319049  ,        -inf,        -inf],
       [ 0.61076964,  0.22028374,  0.7828343 ,        -inf],
       [-1.1431542 , -0.25497312, -0.79085967,  1.10996619]])

In [12]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like with at least one dimension. Entries equal to ``-inf``
      receive a weight of exactly 0.

  Returns:
    An array with the same shape as ``x`` whose entries along the last axis
    are non-negative and sum to 1. A slice that is entirely ``-inf`` has no
    valid distribution and comes back as NaN.
  """
  x = np.asarray(x)
  # Subtracting the per-row maximum does not change the result mathematically,
  # but keeps exp() from overflowing to inf (which would yield inf/inf = NaN).
  exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
  # keepdims makes the normaliser broadcast for any number of leading axes,
  # so no transpose trick is needed and batched input works.
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [13]:
# Normalise the masked scores row by row into attention weights.
attention = softmax(np.add(scaled, mask))

In [14]:
# Display the attention weights produced by softmax(scaled + mask).
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.18837904, 0.81162096, 0.        , 0.        ],
       [0.34910318, 0.23624789, 0.41464892, 0.        ],
       [0.06958739, 0.16914638, 0.09897605, 0.66229018]])

In [15]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[-0.74766703,  0.36752208, -1.25601433,  0.1143262 ,  0.39920838,
        -1.11201019, -0.39908666, -0.00720719],
       [-0.44245205,  1.14461567, -0.49413131, -2.42963244,  0.14605731,
        -0.08206863, -0.15012494, -1.52597841],
       [-0.64483242, -0.51064496, -0.83868634, -1.46603082,  0.1020694 ,
        -0.13347182,  0.13984064, -0.47987799],
       [ 0.43279494, -0.42422158,  1.067846  , -0.94509136, -0.06802195,
         0.03074479, -0.20770026,  0.14337322]])

In [16]:
# Display the original values for comparison with new_v.
v

array([[-0.74766703,  0.36752208, -1.25601433,  0.1143262 ,  0.39920838,
        -1.11201019, -0.39908666, -0.00720719],
       [-0.37161097,  1.32498083, -0.31729656, -3.02009093,  0.08730038,
         0.15698311, -0.09234037, -1.8784886 ],
       [-0.71392243, -2.29585068, -0.78439151, -1.9111405 , -0.13968454,
         0.52489599,  0.72586293, -0.08096722],
       [ 0.93364062, -0.67444448,  1.942584  , -0.38208695, -0.14607333,
         0.04472575, -0.3565701 ,  0.7090972 ]])

In [17]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like with at least one dimension. Entries equal to ``-inf``
      receive a weight of exactly 0.

  Returns:
    An array shaped like ``x`` whose entries along the last axis sum to 1.
    A slice that is entirely ``-inf`` comes back as NaN.
  """
  x = np.asarray(x)
  # Shift by the per-row maximum so exp() cannot overflow; the result is
  # mathematically unchanged.
  exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
  return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q @ k^T / sqrt(d_k) [+ mask]) @ v.

  Args:
    q: array of shape (..., l_q, d_k).
    k: array of shape (..., l_k, d_k).
    v: array of shape (..., l_k, d_v).
    mask: optional array added to the scaled scores before the softmax. It
      must broadcast against shape (..., l_q, l_k); use 0 for allowed
      positions and -inf for blocked ones.

  Returns:
    A tuple ``(out, attention)``: ``out`` has shape (..., l_q, d_v) and
    ``attention`` has shape (..., l_q, l_k) with rows that sum to 1.
  """
  d_k = q.shape[-1]
  # Swap only the last two axes of k. ``k.T`` would reverse every axis and
  # break inputs that carry leading batch/head dimensions.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [19]:
# Run attention without a mask (the default) and print inputs and results.
values, attention = scaled_dot_product_attention(q, k, v)
for heading, array in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(heading, array)

Q
 [[ 0.94142301 -0.80811881  0.2406522  -0.46755052  0.10591556  1.80989183
  -1.64695504  1.04237871]
 [-1.02135628  0.55656575 -0.01198069 -0.75074019 -0.03859188  0.90127641
  -0.88331595 -0.09496225]
 [-1.42517688 -0.13237865  0.43719959  1.15830376  1.28994537 -0.08678479
  -0.91320545  0.30338664]
 [ 0.65252126 -0.36736193  1.28117239 -1.0100903   0.8692625   0.20136011
   0.73339642  0.37824704]]
K
 [[-1.28743798 -0.43084084 -0.20666856  1.05117486 -0.43817876 -0.83808158
   0.17144337 -2.11779588]
 [-1.80326155  0.01636247 -0.67958714 -1.4289214  -0.44625137  0.63996727
  -0.34637386  1.0625549 ]
 [-1.4160158  -0.39064908  0.36067154  1.30137457 -0.28661542  0.39955587
   0.51092093 -2.14149157]
 [ 0.29136595 -0.78595696 -0.9465062  -1.48521404  1.30378315  0.26926597
   1.21011638  0.78784062]]
V
 [[-0.74766703  0.36752208 -1.25601433  0.1143262   0.39920838 -1.11201019
  -0.39908666 -0.00720719]
 [-0.37161097  1.32498083 -0.31729656 -3.02009093  0.08730038  0.15698311
  -0.0

In [20]:
# Run attention again with the additive mask and print inputs and results.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for heading, array in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(heading, array)

Q
 [[ 0.94142301 -0.80811881  0.2406522  -0.46755052  0.10591556  1.80989183
  -1.64695504  1.04237871]
 [-1.02135628  0.55656575 -0.01198069 -0.75074019 -0.03859188  0.90127641
  -0.88331595 -0.09496225]
 [-1.42517688 -0.13237865  0.43719959  1.15830376  1.28994537 -0.08678479
  -0.91320545  0.30338664]
 [ 0.65252126 -0.36736193  1.28117239 -1.0100903   0.8692625   0.20136011
   0.73339642  0.37824704]]
K
 [[-1.28743798 -0.43084084 -0.20666856  1.05117486 -0.43817876 -0.83808158
   0.17144337 -2.11779588]
 [-1.80326155  0.01636247 -0.67958714 -1.4289214  -0.44625137  0.63996727
  -0.34637386  1.0625549 ]
 [-1.4160158  -0.39064908  0.36067154  1.30137457 -0.28661542  0.39955587
   0.51092093 -2.14149157]
 [ 0.29136595 -0.78595696 -0.9465062  -1.48521404  1.30378315  0.26926597
   1.21011638  0.78784062]]
V
 [[-0.74766703  0.36752208 -1.25601433  0.1143262   0.39920838 -1.11201019
  -0.39908666 -0.00720719]
 [-0.37161097  1.32498083 -0.31729656 -3.02009093  0.08730038  0.15698311
  -0.0