In [1]:
import numpy as np
import math

In [2]:
# Sequence length and per-head key / value dimensions.
L, d_k, d_v = 4, 8, 8

# Seed the generator so re-running the notebook (Restart & Run All) yields the
# same Q, K, V — and therefore the same downstream outputs — every time.
SEED = 0
rng = np.random.default_rng(SEED)
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [3]:
# Dump each input matrix under its label.
for label, matrix in (('Q\n', q), ('K\n', k), ('V\n', v)):
    print(label, matrix)

Q
 [[-1.80431154 -0.37368754  0.18010639 -0.41175358 -0.30394883  0.44828695
   1.53403912 -1.24055926]
 [ 0.03852399 -0.12737531  0.73924372  0.12165477  1.73989877  0.89172102
   1.55223778  1.55887451]
 [ 0.54602632 -1.64355645 -0.78197596  0.90004299 -0.49636728  0.22885404
  -0.13821963  0.59972321]
 [-0.22494579  0.09870965  0.54050982  0.01470106  0.90192531 -0.14246449
   0.83499709  0.11098957]]
K
 [[-0.48213113  1.47288856 -1.8496946  -2.25106863 -1.35438539  0.93500355
  -0.11675162  0.97428812]
 [ 0.72348294 -0.19864244 -1.58457345  1.51892988  0.3553269  -0.81143822
  -0.61858356  0.53543886]
 [ 0.99838551 -0.54932221  2.27421914  1.85088532  0.28980389  1.20969071
  -0.89098249 -1.06620449]
 [-0.61397051 -0.33762844 -0.41239007  0.32558912  0.39238482 -0.3161828
  -0.88655283  0.22469547]]
V
 [[ 0.10021491 -0.32656936 -0.29499322  1.69677639  0.18486615  0.60787452
   0.13903236 -0.80996387]
 [-0.21982156  0.18985089  3.26693063 -0.2989055  -0.84915124 -2.368887
   0.0858

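## Self-Attention

Raw attention scores are $QK^\top$. For independent zero-mean, unit-variance entries, each dot product $q \cdot k$ has variance $d_k$. The scores are therefore divided by $\sqrt{d_k}$ before the softmax; compare the variances below.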

In [4]:
q @ k.T  # unscaled attention scores: entry (i, j) = q_i · k_j

array([[ 0.35630847, -4.2269084 , -1.53854002, -0.87413524],
       [-2.03257612, -1.16427776,  0.5526435 , -0.87100011],
       [-1.77698109,  3.37226327,  0.9521956 ,  0.82535806],
       [-2.12314023, -1.03750789,  0.20437988, -0.42971504]])

In [5]:
(q.var(), k.var(), (q @ k.T).var())  # variance of inputs vs. unscaled scores

(0.7166690029479453, 1.144581532816517, 2.724213784644609)

In [6]:
# Divide by sqrt(d_k) so the score variance does not grow with the head dimension.
scaled = (q @ k.T) / math.sqrt(d_k)
(q.var(), k.var(), scaled.var())

(0.7166690029479453, 1.144581532816517, 0.34052672308057613)

In [7]:
# Scaled scores, computed as (q @ k.T) / sqrt(d_k) in the previous cell.
scaled

array([[ 0.12597407, -1.4944378 , -0.54395604, -0.30905348],
       [-0.71862418, -0.41163435,  0.19538898, -0.30794504],
       [-0.62825769,  1.19227511,  0.33665198,  0.29180814],
       [-0.75064343, -0.36681443,  0.0722592 , -0.15192721]])

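## Masking

A causal mask lets query position $i$ attend only to key positions $j \le i$. The mask is added to the scaled scores. Allowed entries are $0$ and future entries are $-\infty$, and the softmax turns the latter into zero weight.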

In [8]:
# 0/1 lower-triangular pattern: row i has ones in columns 0..i (float64, like tril(ones)).
mask = np.tri(L, L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [9]:
# Additive causal mask: 0 where a query may attend, -inf where it may not
# (the softmax turns -inf into weight 0).
# Rebuilt from scratch instead of mutating `mask` in place so the cell is
# idempotent: re-running the old in-place version turned every allowed 0 into
# -inf, which made every softmax row NaN.
# `np.inf` replaces `np.infty`, which NumPy 2.0 removed.
mask = np.where(np.tril(np.ones((L, L), dtype=bool)), 0.0, -np.inf)
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
np.add(scaled, mask)  # future positions become -inf; allowed positions are unchanged

array([[ 0.12597407,        -inf,        -inf,        -inf],
       [-0.71862418, -0.41163435,        -inf,        -inf],
       [-0.62825769,  1.19227511,  0.33665198,        -inf],
       [-0.75064343, -0.36681443,  0.0722592 , -0.15192721]])

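## Softmax

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}},$$

This is applied along the last axis, so each row of the attention matrix sums to 1.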

In [11]:
def softmax(x):
    """Numerically stable softmax over the last axis of ``x``.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    inf, which would otherwise produce inf/inf = NaN for large scores. Using
    ``keepdims=True`` makes it work for 1-D, 2-D and batched N-D input.

    Note: a slice that is entirely -inf (fully masked) comes out as NaN.

    Args:
        x: array of scores; normalization is along ``axis=-1``.

    Returns:
        Array of the same shape as ``x`` whose last-axis slices sum to 1.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

In [12]:
attention = softmax(np.add(scaled, mask))  # row-wise weights over allowed key positions

In [13]:
# Each row is a distribution over key positions: it sums to 1, and masked
# (future) positions get weight 0.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.42384966, 0.57615034, 0.        , 0.        ],
       [0.10204396, 0.63013651, 0.26781953, 0.        ],
       [0.15232814, 0.22360162, 0.34686638, 0.27720387]])

In [14]:
# Each output row is the attention-weighted average of the value rows.
new_v = attention @ v
new_v

array([[ 0.10021491, -0.32656936, -0.29499322,  1.69677639,  0.18486615,
         0.60787452,  0.13903236, -0.80996387],
       [-0.08417421, -0.02903365,  1.75721043,  0.54696358, -0.41088332,
        -1.10718765,  0.10838188, -0.90574662],
       [ 0.17906353, -0.49147281,  2.05953944,  0.06241751, -0.51053434,
        -1.7528393 ,  0.13041618, -0.60688879],
       [-0.146911  , -0.42730744,  0.41758077,  0.28600241,  0.24678896,
        -0.66475584,  0.10564082, -0.2294829 ]])

In [15]:
# Original values, for comparison with `new_v`. Row 0 of `new_v` equals row 0
# of `v` because the first query's attention weights are [1, 0, 0, 0].
v

array([[ 0.10021491, -0.32656936, -0.29499322,  1.69677639,  0.18486615,
         0.60787452,  0.13903236, -0.80996387],
       [-0.21982156,  0.18985089,  3.26693063, -0.2989055 , -0.84915124,
        -2.368887  ,  0.08583361, -0.97620998],
       [ 1.14761907, -2.15734962,  0.11585952,  0.28983326,  0.02121723,
        -1.20285114,  0.23202891,  0.33944009],
       [-1.84374977,  1.18432453, -1.11168081, -0.02222925,  1.44709633,
         0.68384095, -0.0548812 , -0.02006045]])

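## Putting it together: `scaled_dot_product_attention`

The steps above are combined into one function that returns both the new values and the attention weights.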

In [16]:
def softmax(x):
    """Numerically stable softmax over the last axis of ``x``.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    inf, which would otherwise produce inf/inf = NaN for large scores. Using
    ``keepdims=True`` makes it work for 1-D, 2-D and batched N-D input.

    Note: a slice that is entirely -inf (fully masked) comes out as NaN.

    Args:
        x: array of scores; normalization is along ``axis=-1``.

    Returns:
        Array of the same shape as ``x`` whose last-axis slices sum to 1.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    Leading (batch) axes are supported via ``np.matmul`` broadcasting.

    Args:
        q: queries, shape (..., L_q, d_k).
        k: keys, shape (..., L_k, d_k).
        v: values, shape (..., L_k, d_v).
        mask: optional additive mask broadcastable to (..., L_q, L_k);
            0 keeps a position, -inf blocks it.

    Returns:
        (out, attention): ``out`` has shape (..., L_q, d_v); ``attention`` has
        shape (..., L_q, L_k) with rows summing to 1 (NaN for a fully masked row).
    """
    d_k = q.shape[-1]
    # swapaxes transposes only the last two axes. `k.T` would reverse every
    # axis and break batched (3-D+) inputs; for 2-D k the two are identical.
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores + mask
    attention = softmax(scores)
    out = np.matmul(attention, v)
    return out, attention

In [17]:
# Run the packaged function with the causal mask and show inputs alongside outputs.
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
for label, matrix in (
    ('Q\n', q),
    ('K\n', k),
    ('V\n', v),
    ('New V\n', values),
    ('Attention\n', attention),
):
    print(label, matrix)

Q
 [[-1.80431154 -0.37368754  0.18010639 -0.41175358 -0.30394883  0.44828695
   1.53403912 -1.24055926]
 [ 0.03852399 -0.12737531  0.73924372  0.12165477  1.73989877  0.89172102
   1.55223778  1.55887451]
 [ 0.54602632 -1.64355645 -0.78197596  0.90004299 -0.49636728  0.22885404
  -0.13821963  0.59972321]
 [-0.22494579  0.09870965  0.54050982  0.01470106  0.90192531 -0.14246449
   0.83499709  0.11098957]]
K
 [[-0.48213113  1.47288856 -1.8496946  -2.25106863 -1.35438539  0.93500355
  -0.11675162  0.97428812]
 [ 0.72348294 -0.19864244 -1.58457345  1.51892988  0.3553269  -0.81143822
  -0.61858356  0.53543886]
 [ 0.99838551 -0.54932221  2.27421914  1.85088532  0.28980389  1.20969071
  -0.89098249 -1.06620449]
 [-0.61397051 -0.33762844 -0.41239007  0.32558912  0.39238482 -0.3161828
  -0.88655283  0.22469547]]
V
 [[ 0.10021491 -0.32656936 -0.29499322  1.69677639  0.18486615  0.60787452
   0.13903236 -0.80996387]
 [-0.21982156  0.18985089  3.26693063 -0.2989055  -0.84915124 -2.368887
   0.0858