In [1]:
import numpy as np
import math

# Toy problem size: L rows (positions), d_k columns for q and k, d_v columns for v.
L, d_k, d_v = 4, 8, 8

# Seeded generator so that "Restart Kernel -> Run All" always draws the same
# q, k and v; without a seed every run produces different matrices.
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [2]:
# Show the three input matrices, each under its own label.
for label, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
    print(label, matrix)

Q
 [[ 0.852226   -0.71171872  0.53089096  0.40084639  0.49603539 -0.35678672
   0.22365277  0.49172783]
 [ 0.33404883  0.4606918  -0.94216657 -1.05591888  1.123324    0.28033847
   0.37775621  0.99682997]
 [-0.66011769  1.04671902  1.63264508  1.90964659 -1.6512248   1.18788751
   0.02165373  0.99070253]
 [-0.57703345 -1.43415514  1.05907658 -1.07064769  0.20220469  1.57024659
  -0.19029823  0.0401484 ]]
K
 [[ 0.10526603  0.20051261 -0.5214993   0.19327802  1.05996943 -0.38725877
   1.0175988  -1.39186109]
 [-0.29628162  0.91106117 -0.18968601  0.65798033  0.10061668  1.52273672
  -0.96709365  1.08485234]
 [ 1.81358826  0.23180914 -0.92950598 -0.0116655   1.72159405  0.29452381
  -1.08987465 -0.15328583]
 [ 0.44674617 -1.24497462  0.27120307  1.05538752  0.95548548  0.68625328
  -0.42239781 -1.51177171]]
V
 [[ 0.72285241  0.10539189 -0.33320704 -0.01506525  2.232198    0.71390796
   0.26893955  0.36988868]
 [ 0.59585095  2.09261223  0.01945972 -1.10371443 -0.90067813 -0.55410663
  -0.6

In [3]:
# Raw (unscaled) scores: dot product of every row of q with every row of k.
q @ k.T

array([[-0.04525951, -0.91409563,  1.31222278,  1.22508588],
       [ 0.49387279,  1.06068235,  3.05265432, -2.19507641],
       [-3.90909131,  4.7925497 , -5.16271116, -1.40924029],
       [-1.75083779,  0.59801281, -1.33904329,  1.97545272]])

In [4]:
# Why we need sqrt(d_k) in the denominator:
# compare the variance of q and of k with the variance of their raw dot products.
np.var(q), np.var(k), np.var(q @ k.T)

(0.7973117972700422, 0.7762168562728577, 5.9542115400914355)

In [5]:
# Divide the raw scores by sqrt(d_k), then compare variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(0.7973117972700422, 0.7762168562728577, 0.7442764425114294)

In [6]:
# Scaled scores, before any masking or softmax is applied.
scaled

array([[-0.01600165, -0.32318161,  0.46394081,  0.43313327],
       [ 0.1746104 ,  0.37500784,  1.07927629, -0.77607671],
       [-1.38207249,  1.6944222 , -1.82529404, -0.49824168],
       [-0.61901464,  0.21142946, -0.4734233 ,  0.69842801]])

In [7]:
# Masking
# * Ensures a word cannot get context from words that come later in the sequence (i.e. words generated in the future).
# * Needed in the decoder; not required in the encoder.

In [8]:
# L x L lower-triangular matrix of ones: row i holds 1s in columns 0..i, 0s after.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [9]:
# Turn the 0/1 pattern into an additive mask: 0 on and below the diagonal
# (attention allowed), -inf above the diagonal (attention blocked).
# The mask is rebuilt from its shape instead of being rewritten in place, so
# re-running this cell cannot turn the already-converted 0 entries into -inf.
# np.inf is used because the np.infty alias was removed in NumPy 2.0.
mask = np.triu(np.full(mask.shape, -np.inf), k=1)

In [10]:
# Additive mask: 0 where attention is allowed, -inf where it is blocked.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [11]:
# Adding the mask leaves allowed scores unchanged and sends blocked ones to -inf.
scaled + mask

array([[-0.01600165,        -inf,        -inf,        -inf],
       [ 0.1746104 ,  0.37500784,        -inf,        -inf],
       [-1.38207249,  1.6944222 , -1.82529404,        -inf],
       [-0.61901464,  0.21142946, -0.4734233 ,  0.69842801]])

In [12]:
# Softmax
def softmax(x):
    """Softmax over the last axis of ``x``.

    The maximum of each row is subtracted before exponentiating. This leaves
    the result unchanged mathematically but keeps ``exp`` from overflowing to
    ``inf`` (and the result from becoming NaN) when scores are large.

    Args:
        x: array-like of scores with at least one dimension. ``-inf`` entries
            are allowed and receive weight 0.

    Returns:
        Array of the same shape as ``x`` whose entries along the last axis
        are non-negative and sum to 1. A row that is entirely ``-inf`` has no
        finite maximum and comes out as NaN.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalize; np.max would raise on an empty last axis.
        return np.exp(x)
    # keepdims=True lets the per-row max and sum broadcast against x for any
    # number of leading (batch) dimensions, with no transposes needed.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)

In [13]:
# Softmax over each row of the masked scores; the -inf entries become weight 0.
attention = softmax(scaled + mask)

In [14]:
# Attention weights: each row sums to 1, with zeros above the diagonal.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.45006763, 0.54993237, 0.        , 0.        ],
       [0.04287387, 0.92960261, 0.02752352, 0.        ],
       [0.12217589, 0.28031271, 0.14132371, 0.4561877 ]])

In [15]:
# Each output row is the attention-weighted combination of the rows of v.
new_v = attention @ v
new_v

array([[ 0.72285241,  0.10539189, -0.33320704, -0.01506525,  2.232198  ,
         0.71390796,  0.26893955,  0.36988868],
       [ 0.6530102 ,  1.19822868, -0.13926417, -0.61374867,  0.50932801,
         0.01658569, -0.23501451,  0.68311345],
       [ 0.60667419,  2.0031956 ,  0.012607  , -1.00689075, -0.74611758,
        -0.47128119, -0.57227953,  0.89215656],
       [-0.17213533,  0.46402792,  1.23059846, -0.37714974,  0.33318289,
         0.03872264,  0.21205131,  0.33494046]])

In [16]:
# The original value matrix, for comparison with new_v.
v

array([[ 0.72285241,  0.10539189, -0.33320704, -0.01506525,  2.232198  ,
         0.71390796,  0.26893955,  0.36988868],
       [ 0.59585095,  2.09261223,  0.01945972, -1.10371443, -0.90067813,
        -0.55410663, -0.64745324,  0.93945829],
       [ 0.79125464,  1.93940499,  0.31983798,  0.71832996, -0.16523325,
         0.4799455 ,  0.65631903,  0.10809378],
       [-1.18218466, -0.89769655,  2.67576884, -0.3670441 ,  0.73716373,
         0.08548183,  0.58732254,  0.02439932]])

In [17]:
# Softmax
# NOTE(review): a function named softmax is also defined in an earlier cell;
# this definition shadows it. Keep a single definition if the cells are merged.
def softmax(x):
    """Softmax over the last axis of ``x``.

    The maximum of each row is subtracted before exponentiating. This leaves
    the result unchanged mathematically but keeps ``exp`` from overflowing to
    ``inf`` (and the result from becoming NaN) when scores are large.

    Args:
        x: array-like of scores with at least one dimension. ``-inf`` entries
            are allowed and receive weight 0.

    Returns:
        Array of the same shape as ``x`` whose entries along the last axis
        are non-negative and sum to 1. A row that is entirely ``-inf`` has no
        finite maximum and comes out as NaN.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalize; np.max would raise on an empty last axis.
        return np.exp(x)
    # keepdims=True lets the per-row max and sum broadcast against x for any
    # number of leading (batch) dimensions, with no transposes needed.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
    """Scaled dot-product attention.

    Scores are ``q @ k.T`` divided by ``sqrt(q.shape[-1])``. When ``mask`` is
    given it is added to the scores before ``softmax`` is applied.
    ``k.T`` is a full transpose, so this is written for 2-D ``q`` and ``k``.

    Args:
        q: query array; its last dimension sets the scaling factor.
        k: key array, transposed before the product with ``q``.
        v: value array, combined using the attention weights.
        mask: optional array added to the scaled scores (for example 0 for
            allowed entries and -inf for blocked ones).

    Returns:
        ``(out, attention)``: ``attention`` is the softmax of the (optionally
        masked) scaled scores and ``out`` is ``attention @ v``.
    """
    scores = (q @ k.T) / math.sqrt(q.shape[-1])
    weights = softmax(scores if mask is None else scores + mask)
    return weights @ v, weights

In [18]:
# Run the packaged function on the same inputs and print inputs and results together.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for label, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
    print(label, matrix)

Q
 [[ 0.852226   -0.71171872  0.53089096  0.40084639  0.49603539 -0.35678672
   0.22365277  0.49172783]
 [ 0.33404883  0.4606918  -0.94216657 -1.05591888  1.123324    0.28033847
   0.37775621  0.99682997]
 [-0.66011769  1.04671902  1.63264508  1.90964659 -1.6512248   1.18788751
   0.02165373  0.99070253]
 [-0.57703345 -1.43415514  1.05907658 -1.07064769  0.20220469  1.57024659
  -0.19029823  0.0401484 ]]
K
 [[ 0.10526603  0.20051261 -0.5214993   0.19327802  1.05996943 -0.38725877
   1.0175988  -1.39186109]
 [-0.29628162  0.91106117 -0.18968601  0.65798033  0.10061668  1.52273672
  -0.96709365  1.08485234]
 [ 1.81358826  0.23180914 -0.92950598 -0.0116655   1.72159405  0.29452381
  -1.08987465 -0.15328583]
 [ 0.44674617 -1.24497462  0.27120307  1.05538752  0.95548548  0.68625328
  -0.42239781 -1.51177171]]
V
 [[ 0.72285241  0.10539189 -0.33320704 -0.01506525  2.232198    0.71390796
   0.26893955  0.36988868]
 [ 0.59585095  2.09261223  0.01945972 -1.10371443 -0.90067813 -0.55410663
  -0.6