In [1]:
import numpy as np
import math 

L, d_k, d_v = 4,8,8
q = np.random.rand(L, d_k)
k = np.random.rand(L, d_k)
v = np.random.rand(L, d_v)

In [2]:
print("Query (q):", q)
print("Key (k):", k)
print("Value (v):", v)

Query (q): [[0.03804345 0.14047541 0.3922199  0.11841733 0.77818584 0.37189161
  0.8484695  0.57376775]
 [0.45076212 0.43299163 0.35443178 0.36221555 0.8420644  0.97925817
  0.32987794 0.32044776]
 [0.92737552 0.83521621 0.79393371 0.04742947 0.72977081 0.73897854
  0.25233071 0.29189709]
 [0.39761686 0.10225246 0.56884807 0.82885827 0.33576961 0.87406877
  0.86853091 0.78706834]]
Key (k): [[0.22632587 0.55259542 0.85119661 0.35770502 0.6182692  0.60176701
  0.16856799 0.8310279 ]
 [0.51847161 0.3884267  0.14162048 0.99902546 0.61664483 0.8844136
  0.38515788 0.60073248]
 [0.57231284 0.29113161 0.54804465 0.02444004 0.79062592 0.39249552
  0.52194399 0.14666225]
 [0.779745   0.16220611 0.12817453 0.70254974 0.25724824 0.85201203
  0.97423032 0.81656838]]
Value (v): [[0.08902852 0.38781269 0.49968116 0.74165462 0.13852491 0.60615516
  0.01318295 0.56203986]
 [0.79318981 0.79775557 0.54488489 0.62042129 0.99231183 0.89416733
  0.19494001 0.27932318]
 [0.34553113 0.39220165 0.61692035 0.5

Self Attention

                        self attention = softmax((Q.K^t / root d_k) + M) V

In [3]:
np.matmul(q, k.T)

array([[1.78721325, 1.72838306, 1.56874104, 1.9980852 ],
       [2.20436141, 2.51883259, 1.85641929, 2.35562089],
       [2.54518239, 2.34116996, 2.25171133, 2.29521026],
       [2.46124756, 2.94190963, 1.76663557, 3.30178924]])

In [4]:
# If we didn't do scaling
q.var(),k.var(),np.matmul(q, k.T).var()

(0.08209684324885873, 0.0750886873680431, 0.20091428192891453)

In [5]:
# With scaling
scaled = np.matmul(q, k.T) / math.sqrt(d_k)      # Dividing by root d_k for scaling
q.var(),k.var(),scaled.var()

(0.08209684324885873, 0.0750886873680431, 0.025114285241114313)

In [6]:
scaled

array([[0.6318753 , 0.61107569, 0.55463371, 0.7064298 ],
       [0.77935945, 0.8905418 , 0.65634333, 0.83283775],
       [0.89985786, 0.82772858, 0.79610018, 0.81147937],
       [0.87018242, 1.04012213, 0.62459999, 1.16735878]])

Masking -
1) This is to ensure words don't get context from words generated in the future
2) Not required in encoders but required in decoders

In [7]:
mask = np.tril(np.ones((L, L)))  # Lower triangular matrix
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [8]:
mask[mask == 0] = -np.inf
mask[mask == 1] = 0

In [9]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
scaled + mask

array([[0.6318753 ,       -inf,       -inf,       -inf],
       [0.77935945, 0.8905418 ,       -inf,       -inf],
       [0.89985786, 0.82772858, 0.79610018,       -inf],
       [0.87018242, 1.04012213, 0.62459999, 1.16735878]])

Softmax

In [11]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=1)).T

In [12]:
attention = softmax(scaled + mask)

In [13]:
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.47223301, 0.52776699, 0.        , 0.        ],
       [0.35312551, 0.32855172, 0.31832277, 0.        ],
       [0.23182843, 0.27477079, 0.18134751, 0.31205327]])

In [14]:
new_v = np.matmul(attention, v)
new_v

array([[0.08902852, 0.38781269, 0.49968116, 0.74165462, 0.13852491,
        0.60615516, 0.01318295, 0.56203986],
       [0.46066161, 0.60416701, 0.5235382 , 0.67767167, 0.58912546,
        0.75815847, 0.10910833, 0.41283133],
       [0.40203254, 0.52389723, 0.55185283, 0.62694851, 0.41379373,
        0.77817116, 0.34206303, 0.34036656],
       [0.30635951, 0.39353128, 0.47814102, 0.63679817, 0.33605283,
        0.84974548, 0.46673054, 0.26717886]])

In [None]:
v       # values get distinct compared to new_vas the attention weights are applied

array([[0.08902852, 0.38781269, 0.49968116, 0.74165462, 0.13852491,
        0.60615516, 0.01318295, 0.56203986],
       [0.79318981, 0.79775557, 0.54488489, 0.62042129, 0.99231183,
        0.89416733, 0.19494001, 0.27932318],
       [0.34553113, 0.39220165, 0.61692035, 0.50643838, 0.12205   ,
        0.84927036, 0.85875075, 0.1574623 ],
       [0.01638684, 0.04262268, 0.32271831, 0.64907688, 0.02931246,
        0.99187343, 0.81517565, 0.10119041]])