# Notes

**Question**

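* Why do we need an attention mechanism?
* Why do we divide by sqrt(d_k) in the self-attention formula?
    * To keep the scores (QK^{T}) on a reasonable scale: with unit-variance entries, the dot product of a query and a key has variance d_k, so dividing by sqrt(d_k) brings it back to about 1 and keeps the softmax from saturating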
* Why do we need masking?
    * It stops the model from cheating by looking at future words
    * We don't need it in the encoders, only in the decoders
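For reference, the scaled dot-product attention that the code below builds step by step can be written as follows, where $M$ is the mask: 0 where a position may be attended to, $-\infty$ where it is blocked (all zeros in the encoder).

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}} + M\right) V
$$

The usual argument for the $\sqrt{d_k}$, assuming the components of a query $q$ and a key $k$ are independent with mean 0 and variance 1:

$$
\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k,
\qquad
\operatorname{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1
$$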

**Huh**

* [Truncated Backpropagation](https://arxiv.org/abs/1705.08209) 
* BRNN (bidirectional RNN)

**Aha**

* RNNs are slow, since we need to feed them one token at a time
* The encoder transforms input vectors into higher-quality vectors by adding context
* Vectors: 
    * Q => What I am looking for
    * K => What I can offer
    * V => What I actually offer
* Softmax transforms a vector into a probability distribution
* Multi-Head Attention => several attention mechanisms (heads) running in parallel, with their outputs concatenated
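A minimal NumPy sketch of multi-head attention, assuming random matrices `w_q`, `w_k`, `w_v`, `w_o` (hypothetical names standing in for learned weights) and a `d_model` divisible by `num_heads`. Each head attends over its own slice of the projected queries, keys, and values; the head outputs are concatenated and mixed by one final projection.

```python
import numpy as np


def softmax(x):
    # Subtract the row max for numerical stability; each row then sums to 1.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v):
    # Scaled dot-product attention for a single head (no mask).
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)
    return softmax(scores) @ v


def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Hypothetical sketch: x is (L, d_model); every weight matrix is (d_model, d_model)."""
    L, d_model = x.shape
    d_head = d_model // num_heads
    # Project once, then split the last dimension into num_heads chunks of size d_head.
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    heads = []
    for h in range(num_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        heads.append(attention(q[:, cols], k[:, cols], v[:, cols]))
    # Heads run side by side; their outputs are concatenated back to d_model columns.
    concat = np.concatenate(heads, axis=-1)  # (L, d_model)
    return concat @ w_o


rng = np.random.default_rng(0)
L, d_model, num_heads = 4, 8, 2
x = rng.standard_normal((L, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) for _ in range(4))

out = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(out.shape)  # (4, 8)
```

With `num_heads=1` this reduces to plain single-head attention followed by `w_o`; more heads let each slice learn a different notion of "what I am looking for".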



# Code

In [22]:
import numpy as np
import math

In [23]:
L, d_k, d_v = 4, 8, 8

q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)


* (Aha) L => length of the input sequence (e.g. "My name is ...")


* (Huh) what are d_k and d_v?
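`d_k` is the length of each query and key vector (the two must match, because scores are dot products of queries with keys), and `d_v` is the length of each value vector, which sets the width of the output. They are equal in the cell above; a quick shape check with a different `d_v` (5 is an arbitrary choice for illustration) shows which dimension goes where. Scaling and softmax are left out because only the shapes matter here.

```python
import numpy as np

L, d_k, d_v = 4, 8, 5  # d_v deliberately differs from d_k

q = np.random.randn(L, d_k)  # one query vector per token
k = np.random.randn(L, d_k)  # one key vector per token (same size as the queries)
v = np.random.randn(L, d_v)  # one value vector per token (any size)

scores = np.matmul(q, k.T)   # (L, d_k) @ (d_k, L) -> (L, L): d_k is summed away
out = np.matmul(scores, v)   # (L, L) @ (L, d_v) -> (L, d_v)

print(scores.shape, out.shape)  # (4, 4) (4, 5)
```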

In [24]:
np.matmul(q, k.T)

array([[ 7.28849593, -7.23187435,  3.43074998, -0.92817642],
       [-5.87109603,  5.42403214, -3.432637  , -2.45887263],
       [-3.17989029,  1.62472654, -5.05057619,  4.4481915 ],
       [ 0.12658553, -3.83293587, -0.74856257, -6.26228145]])

* (Huh) what does matmul do?
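`np.matmul(a, b)` is ordinary matrix multiplication: entry `[i, j]` is the dot product of row `i` of `a` with column `j` of `b`. With `b = k.T`, column `j` is key `j`, so `np.matmul(q, k.T)[i, j]` is the similarity score between query `i` and key `j`. A small check that rebuilds the matrix with explicit loops (the fixed seed is only there for reproducibility):

```python
import numpy as np

rng = np.random.default_rng(42)
L, d_k = 4, 8
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))

scores = np.matmul(q, k.T)  # same as q @ k.T

# Rebuild the same matrix one dot product at a time.
manual = np.empty((L, L))
for i in range(L):          # query i
    for j in range(L):      # key j
        manual[i, j] = np.dot(q[i], k[j])

print(np.allclose(scores, manual))  # True
```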

In [25]:
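# note: the last value divides the variance by sqrt(d_k); dividing the scores
# themselves by sqrt(d_k) divides their variance by d_k
q.var(), k.var(), np.matmul(q, k.T).var(), np.matmul(q, k.T).var() / math.sqrt(d_k)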

(np.float64(1.7088368202136353),
 np.float64(0.9332648496490006),
 np.float64(18.435487410675933),
 np.float64(6.517929081284088))
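With only 4 x 8 random entries the variances above are noisy. A larger sample makes the pattern clearer: when the entries of `q` and `k` have variance about 1, `q @ k.T` has variance about `d_k`, and dividing the score matrix by `sqrt(d_k)` brings it back to about 1. The sizes below (500 tokens, `d_k = 64`) are arbitrary choices for the demo.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_k = 500, 64
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))

scores = q @ k.T
scaled = scores / np.sqrt(d_k)  # scale the matrix first, then measure its variance

print(round(scores.var(), 1))   # roughly 64 (= d_k)
print(round(scaled.var(), 2))   # roughly 1
print(np.isclose(scaled.var(), scores.var() / d_k))  # True: variance shrinks by d_k
```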

## Masking

In [28]:
mask = np.tril(np.ones((L, L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

* (Huh) what's a triangular matrix?
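A triangular matrix has only zeros on one side of its main diagonal. `np.tril` keeps the diagonal and everything below it (lower triangular); `np.triu` keeps the diagonal and everything above it (upper triangular). In the mask, row `i` is a token and the 1s mark the positions it may look at: itself and the tokens before it.

```python
import numpy as np

L = 4
ones = np.ones((L, L))

lower = np.tril(ones)  # 1 on and below the diagonal: token i sees tokens 0..i
upper = np.triu(ones)  # 1 on and above the diagonal

print(lower)
print(upper)

# Row i of the lower-triangular mask allows exactly i + 1 positions.
print(lower.sum(axis=1))  # [1. 2. 3. 4.]
```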

In [32]:
# turn the 0/1 mask into an additive mask: 0 = allowed, -inf = blocked
mask[mask == 0] = -np.inf
mask[mask == 1] = 0

In [33]:
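# note: .var() makes `scaled` a single number, so every unmasked score below is identical;
# the per-pair scaled scores would be np.matmul(q, k.T) / math.sqrt(d_k)
scaled = np.matmul(q, k.T).var() / math.sqrt(d_k)
scaled + mask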

array([[6.51792908,       -inf,       -inf,       -inf],
       [6.51792908, 6.51792908,       -inf,       -inf],
       [6.51792908, 6.51792908, 6.51792908,       -inf],
       [6.51792908, 6.51792908, 6.51792908, 6.51792908]])

* (Aha) -inf means "no context for you, buddy": after the softmax, exp(-inf) = 0, so masked (future) positions get zero weight
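A minimal sketch of the masked step using the scaled score matrix `q @ k.T / sqrt(d_k)`, so that the weights depend on how well each query matches each key. It uses a fixed seed and builds the additive mask with `np.where` instead of in-place assignment (an alternative, not the only way to do it).

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(1)
L, d_k = 4, 8
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))

scaled = q @ k.T / np.sqrt(d_k)               # (L, L): one score per query/key pair
causal = np.tril(np.ones((L, L), dtype=bool))  # True where token i may see token j
mask = np.where(causal, 0.0, -np.inf)          # 0 where allowed, -inf for future tokens

attention = softmax(scaled + mask)
print(np.round(attention, 3))

# Future positions get exactly zero weight, and every row still sums to 1.
print(np.all(attention[~causal] == 0))         # True
print(np.allclose(attention.sum(axis=-1), 1))  # True
```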

In [39]:
def softmax(x):
    # exponentiate, then divide each row by its sum so the row adds up to 1
    return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)

x = np.random.randn(10, 2)

print("softmax x")
print(f"{softmax(x)}")

print("\n")
print("x")
print(f"{x}")

softmax x
[[0.31975761 0.68024239]
 [0.22497968 0.77502032]
 [0.07906715 0.92093285]
 [0.35108277 0.64891723]
 [0.51068233 0.48931767]
 [0.22200682 0.77799318]
 [0.80085823 0.19914177]
 [0.48678087 0.51321913]
 [0.17752741 0.82247259]
 [0.30979596 0.69020404]]


x
[[ 0.4259755   1.18086143]
 [-1.47609278 -0.23921365]
 [-1.06771174  1.38737794]
 [-0.86724152 -0.25295836]
 [ 0.04656291  0.00382708]
 [-0.65384737  0.60016231]
 [-0.16581249 -1.55747948]
 [ 0.11699794  0.16988679]
 [ 0.17420282  1.70739294]
 [-0.82380603 -0.02273263]]


* (Aha) In the softmax output, each row sums to one (because it is a probability distribution)
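The `softmax` above exponentiates `x` directly, which is fine for small inputs but overflows for large ones (`np.exp(1000)` is `inf`). Softmax does not change when the same constant is subtracted from every entry in a row, so a common sketch subtracts the row maximum first. `small` below is the first row of `x` from the output above.

```python
import numpy as np


def naive_softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)


def stable_softmax(x):
    # exp(x - c) / sum(exp(x - c)) == exp(x) / sum(exp(x)) for any constant c per row
    shifted = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=-1, keepdims=True)


small = np.array([[0.4259755, 1.18086143]])
big = np.array([[1000.0, 1001.0]])

print(np.allclose(naive_softmax(small), stable_softmax(small)))  # True
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(big))  # [[nan nan]]: inf / inf
print(stable_softmax(big))     # about [[0.269 0.731]]
```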

In [41]:
attention = softmax(scaled + mask)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333333, 0.33333333, 0.33333333, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]])

In [42]:
new_v = np.matmul(attention, v)
new_v

array([[-2.16871155, -0.78041098,  0.25572825,  0.73675932, -0.3480004 ,
        -0.56011948, -0.53965195, -0.67495813],
       [-0.62182327, -0.16760881,  0.06496911, -0.13564645, -0.79712902,
         0.50635758,  0.0850759 , -0.65417167],
       [-0.30656334,  0.23549939, -0.22881219,  0.09043942, -0.5740131 ,
         0.48148352, -0.27546704, -0.31740677],
       [-0.14594315,  0.16658348, -0.02133608, -0.26571102,  0.00498718,
         0.371516  ,  0.24024852, -0.57709043]])

In [43]:
v

array([[-2.16871155, -0.78041098,  0.25572825,  0.73675932, -0.3480004 ,
        -0.56011948, -0.53965195, -0.67495813],
       [ 0.92506501,  0.44519336, -0.12579003, -1.00805221, -1.24625765,
         1.57283464,  0.70980376, -0.63338521],
       [ 0.32395651,  1.04171578, -0.81637479,  0.54261115, -0.12778125,
         0.43173541, -0.99655292,  0.35612303],
       [ 0.33591742, -0.04016425,  0.60109226, -1.33416232,  1.74198804,
         0.04161343,  1.78739519, -1.35614143]])
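Comparing `new_v` with `v`: each row of `new_v` is a weighted average of the rows of `v`, with the weights taken from the matching row of `attention`. Row 0 equals `v[0]` (all weight on the first token), and row 1 is the mean of `v[0]` and `v[1]`; in the first column, (-2.16871155 + 0.92506501) / 2 = -0.62182327. Putting the steps together, a minimal sketch of masked scaled dot-product attention as one function (the name `scaled_dot_product_attention` and the `causal` flag are illustrative choices, not a library API):

```python
import math

import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def scaled_dot_product_attention(q, k, v, causal=False):
    """Return (output, attention) for q, k of shape (L, d_k) and v of shape (L, d_v)."""
    d_k = q.shape[-1]
    scaled = np.matmul(q, k.T) / math.sqrt(d_k)        # (L, L) scores
    if causal:
        L = scaled.shape[0]
        mask = np.triu(np.full((L, L), -np.inf), k=1)  # -inf strictly above the diagonal
        scaled = scaled + mask
    attention = softmax(scaled)                         # each row sums to 1
    return np.matmul(attention, v), attention           # (L, d_v), (L, L)


rng = np.random.default_rng(7)
L, d_k, d_v = 4, 8, 8
q, k, v = (rng.standard_normal((L, d)) for d in (d_k, d_k, d_v))

new_v, attention = scaled_dot_product_attention(q, k, v, causal=True)

# Row i of new_v is the attention-weighted average of v[0..i].
row = 2
manual = sum(attention[row, j] * v[j] for j in range(L))
print(np.allclose(new_v[row], manual))  # True
print(np.allclose(new_v[0], v[0]))      # True: the first token can only see itself
```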