In [1]:
import numpy as np
import math

L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

In [2]:
print("Q\n", q)
print("K\n", k)
print("V\n", v)

Q
 [[ 0.27425579  3.30726302  1.21639574 -1.30345721  0.58099169 -0.10331169
   1.20447982  1.82315813]
 [-1.14932628 -0.08398634  1.95051774 -0.11457736 -1.64998464 -0.20648265
   0.28634909 -0.1754566 ]
 [-0.35376869  0.0277062   0.7418196   0.58293707 -0.4950439   0.51551094
   0.83226284  0.37152595]
 [ 0.95910223  1.04487724 -1.206584    0.71541962  0.85301133  1.16689068
   0.84551017  2.17667541]]
K
 [[-0.48211478 -0.38406614 -0.17979803  0.78520215 -0.96527334 -0.6630798
   0.62012532 -2.20334705]
 [-0.78279528  1.06238304  0.08317115 -0.8090412  -0.20314522 -0.1864577
   0.55461076 -1.45649221]
 [ 0.65025037  0.79775616  0.02508248 -1.56677952 -1.0167757  -0.10593784
   0.06236182 -0.27459765]
 [-0.69599466  1.03763871  0.40960599  1.97186402  0.45891292  1.17874799
  -1.13045595 -0.83790218]]
V
 [[ 2.41934434 -2.94570438 -0.15651827  0.92741815  0.53501952  0.07621312
  -1.25600643  0.59494809]
 [-1.53663081 -0.70187285 -0.26943615 -0.96015249  0.85452404  0.023894
   0.07307

Self Attention

In [3]:
np.matmul(q, k.T)

array([[-6.40704703,  2.36845307,  3.88414966, -1.57552724],
       [ 2.43946264,  1.85343637,  1.17966667,  0.1085104 ],
       [ 0.31779627, -0.17865697, -0.70404484,  0.85663308],
       [-5.95378869, -3.41212043, -1.22987658,  0.3204361 ]])

In [4]:
q.var(), k.var(), np.matmul(q, k.T).var()

(1.0788823578891231, 0.806879409904545, 7.6709803912297705)

In [5]:
scaled = np.matmul(q, k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(1.0788823578891231, 0.806879409904545, 0.9588725489037212)

In [6]:
scaled

array([[-2.2652332 ,  0.83737461,  1.37325428, -0.557033  ],
       [ 0.86248029,  0.65528871,  0.41707515,  0.03836422],
       [ 0.11235795, -0.06316478, -0.24891744,  0.30286553],
       [-2.10498218, -1.20636675, -0.43482703,  0.11329127]])

#### Masking

- This is to ensure words don't get context from words generated in the future.
- Not required in the encoders, but required in the decoders

In [7]:
mask = np.tril(np.ones((L, L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [8]:
mask[mask == 0] = -np.infty
mask[mask == 1] = 0

In [9]:
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
scaled + mask

array([[-2.2652332 ,        -inf,        -inf,        -inf],
       [ 0.86248029,  0.65528871,        -inf,        -inf],
       [ 0.11235795, -0.06316478, -0.24891744,        -inf],
       [-2.10498218, -1.20636675, -0.43482703,  0.11329127]])

#### Softmax

In [11]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [12]:
attention = softmax(scaled + mask)

In [13]:
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.55161339, 0.44838661, 0.        , 0.        ],
       [0.39435203, 0.33086858, 0.2747794 , 0.        ],
       [0.0556773 , 0.13675459, 0.29581307, 0.51175503]])

In [14]:
new_v = np.matmul(attention, v)
new_v

array([[ 2.41934434, -2.94570438, -0.15651827,  0.92741815,  0.53501952,
         0.07621312, -1.25600643,  0.59494809],
       [ 0.64553804, -1.93960035, -0.20714914,  0.08105674,  0.67828107,
         0.05275392, -0.66006318,  0.79319493],
       [ 0.52619549, -1.39053142,  0.48949317,  0.12928276,  0.72021807,
        -0.46282972, -0.53121545,  0.75850133],
       [-0.95636001,  0.08486834,  0.4403292 , -0.71580263, -0.47047009,
        -0.82875674,  0.06609433,  0.25165542]])

In [15]:
v

array([[ 2.41934434, -2.94570438, -0.15651827,  0.92741815,  0.53501952,
         0.07621312, -1.25600643,  0.59494809],
       [-1.53663081, -0.70187285, -0.26943615, -0.96015249,  0.85452404,
         0.023894  ,  0.07307707,  1.03708182],
       [ 0.29312605,  0.01215789,  2.33046738,  0.29564742,  0.82428626,
        -1.8225176 , -0.21866879,  0.65778061],
       [-1.8908105 ,  0.66685244, -0.39763628, -1.41393801, -1.6823543 ,
        -0.58063555,  0.37267215, -0.23033578]])

In [16]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

def scaled_dot_product_attention(q, k, v, mask=None):
    d_k = q.shape[-1]
    scaled = np.matmul(q, k.T) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

In [17]:
values, attention = scaled_dot_product_attention(q, k, v, mask=None)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[ 0.27425579  3.30726302  1.21639574 -1.30345721  0.58099169 -0.10331169
   1.20447982  1.82315813]
 [-1.14932628 -0.08398634  1.95051774 -0.11457736 -1.64998464 -0.20648265
   0.28634909 -0.1754566 ]
 [-0.35376869  0.0277062   0.7418196   0.58293707 -0.4950439   0.51551094
   0.83226284  0.37152595]
 [ 0.95910223  1.04487724 -1.206584    0.71541962  0.85301133  1.16689068
   0.84551017  2.17667541]]
K
 [[-0.48211478 -0.38406614 -0.17979803  0.78520215 -0.96527334 -0.6630798
   0.62012532 -2.20334705]
 [-0.78279528  1.06238304  0.08317115 -0.8090412  -0.20314522 -0.1864577
   0.55461076 -1.45649221]
 [ 0.65025037  0.79775616  0.02508248 -1.56677952 -1.0167757  -0.10593784
   0.06236182 -0.27459765]
 [-0.69599466  1.03763871  0.40960599  1.97186402  0.45891292  1.17874799
  -1.13045595 -0.83790218]]
V
 [[ 2.41934434 -2.94570438 -0.15651827  0.92741815  0.53501952  0.07621312
  -1.25600643  0.59494809]
 [-1.53663081 -0.70187285 -0.26943615 -0.96015249  0.85452404  0.023894
   0.07307

In [18]:
# For decoder pass mask

values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[ 0.27425579  3.30726302  1.21639574 -1.30345721  0.58099169 -0.10331169
   1.20447982  1.82315813]
 [-1.14932628 -0.08398634  1.95051774 -0.11457736 -1.64998464 -0.20648265
   0.28634909 -0.1754566 ]
 [-0.35376869  0.0277062   0.7418196   0.58293707 -0.4950439   0.51551094
   0.83226284  0.37152595]
 [ 0.95910223  1.04487724 -1.206584    0.71541962  0.85301133  1.16689068
   0.84551017  2.17667541]]
K
 [[-0.48211478 -0.38406614 -0.17979803  0.78520215 -0.96527334 -0.6630798
   0.62012532 -2.20334705]
 [-0.78279528  1.06238304  0.08317115 -0.8090412  -0.20314522 -0.1864577
   0.55461076 -1.45649221]
 [ 0.65025037  0.79775616  0.02508248 -1.56677952 -1.0167757  -0.10593784
   0.06236182 -0.27459765]
 [-0.69599466  1.03763871  0.40960599  1.97186402  0.45891292  1.17874799
  -1.13045595 -0.83790218]]
V
 [[ 2.41934434 -2.94570438 -0.15651827  0.92741815  0.53501952  0.07621312
  -1.25600643  0.59494809]
 [-1.53663081 -0.70187285 -0.26943615 -0.96015249  0.85452404  0.023894
   0.07307