# Credits
https://github.com/ajhalthor/Transformer-Neural-Network/blob/main/Self_Attention_for_Transformer_Neural_Networks.ipynb
https://www.youtube.com/watch?v=rPFkX5fJdRY&t=667s

## Self-attention numerical example

In [14]:
import numpy as np
import math

In [19]:
# Toy problem setup: the sentence has 4 words, and the key vectors (shared
# size with the queries) and the value vectors each have 8 components.
L = 4
dk = 8
dv = 8

# Fix the global NumPy seed so every run draws the same matrices.
np.random.seed(42)

# One row per word; drawn in the order Q, K, V from a standard normal.
q = np.random.standard_normal(size=(L, dk))
k = np.random.standard_normal(size=(L, dk))
v = np.random.standard_normal(size=(L, dv))

In [20]:
# Show the three input matrices, each under its own label.
print(f'Q\n {q}', f'K\n {k}', f'V\n {v}', sep='\n')

Q
 [[ 0.49671415 -0.1382643   0.64768854  1.52302986 -0.23415337 -0.23413696
   1.57921282  0.76743473]
 [-0.46947439  0.54256004 -0.46341769 -0.46572975  0.24196227 -1.91328024
  -1.72491783 -0.56228753]
 [-1.01283112  0.31424733 -0.90802408 -1.4123037   1.46564877 -0.2257763
   0.0675282  -1.42474819]
 [-0.54438272  0.11092259 -1.15099358  0.37569802 -0.60063869 -0.29169375
  -0.60170661  1.85227818]]
K
 [[-0.01349722 -1.05771093  0.82254491 -1.22084365  0.2088636  -1.95967012
  -1.32818605  0.19686124]
 [ 0.73846658  0.17136828 -0.11564828 -0.3011037  -1.47852199 -0.71984421
  -0.46063877  1.05712223]
 [ 0.34361829 -1.76304016  0.32408397 -0.38508228 -0.676922    0.61167629
   1.03099952  0.93128012]
 [-0.83921752 -0.30921238  0.33126343  0.97554513 -0.47917424 -0.18565898
  -1.10633497 -1.19620662]]
V
 [[ 0.81252582  1.35624003 -0.07201012  1.0035329   0.36163603 -0.64511975
   0.36139561  1.53803657]
 [-0.03582604  1.56464366 -2.6197451   0.8219025   0.08704707 -0.29900735
   0.09

### Self-attention formula
1. First, we compute the matrix product of Q and the transpose of K: $QK^T$
2. Then we scale this matrix by $\sqrt{d_k}$ to reduce its variance: $\frac{QK^T}{\sqrt{d_k}}$
    1. In the decoder, we also need to mask the values at this point, before the softmax
3. Then we take the softmax of these scaled scores to get probabilities: $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$
4. Finally, we multiply this probability matrix by the V matrix: $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

In [21]:
#first step
# Raw attention scores: entry (i, j) is the dot product of query i with key j.
q @ k.T

array([[-2.72357421,  0.40818741,  2.39601116, -1.18323729],
       [ 5.60012069,  1.1597874 , -4.7248515 ,  2.43859568],
       [ 1.03699903, -3.70553788, -3.0399309 ,  0.04345596],
       [ 0.09460324,  2.97027193,  0.43247995, -0.80026704]])

In [23]:
#second step
# Divide the raw scores by sqrt(dk) (not dk itself) to reduce their variance.
raw_scores = np.matmul(q, k.T)
scaled = raw_scores / math.sqrt(dk)

# Compare the variance of the scores before and after the scaling.
print(f'before scaling {raw_scores.var()}')
print(f'after scaling {scaled.var()}')

before scaling 6.8375652409831655
after scaling 0.8546956551228956


In [27]:
#third step
#we compute the softmax of the scaled matrix so every ROW sums to one
def softmax(x):
    """Return the softmax of ``x`` taken along its last axis.

    For a 2-D score matrix this normalises each row into a probability
    distribution (non-negative entries that sum to one).

    Parameters
    ----------
    x : array_like
        Scores. Entries equal to ``-inf`` (e.g. masked positions) receive
        a weight of exactly zero as long as the same row contains at least
        one finite entry.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    # Subtracting the per-row maximum leaves the result mathematically
    # unchanged but keeps np.exp from overflowing on large scores.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    # keepdims=True lets the row sums broadcast without any transposes.
    return exps / np.sum(exps, axis=-1, keepdims=True)

# Row i holds the weights that word i assigns to every word in the sentence.
attetion = softmax(scaled)
attention_report = f'Attetion vector\n{attetion}'
print(attention_report)

Attetion vector
[[0.08431243 0.25513027 0.51521078 0.14534652]
 [0.64059204 0.1332861  0.01664257 0.2094793 ]
 [0.47006414 0.08789379 0.11121405 0.33082801]
 [0.17794451 0.49185018 0.20052305 0.12968226]]


In [30]:
#Final step
# Each output row is the attention-weighted combination of the rows of V.
new_v = attetion @ v
new_values_report = f'New values\n {new_v}'
print(new_values_report)

New values
 [[-0.1308104   0.77212573  0.10108921  0.16807328 -0.46588684 -0.43681263
   0.46851458 -0.42075407]
 [ 0.40109276  1.19080398 -0.35037302  0.94668908  0.08274232 -0.53010106
   0.17683369  0.41923385]
 [ 0.17910025  0.98456145 -0.06763014  0.80678092 -0.14453166 -0.49373081
   0.15002954  0.10067088]
 [ 0.01421368  1.14907671 -0.99239485  0.60451701 -0.14600018 -0.40496816
   0.24215067 -0.82777073]]


### For the decoder, we need to apply a mask before computing the probability vector

In [36]:
#for the decoder the steps are the following
# Build the additive mask: 0 on and below the diagonal, -inf above it, so
# after the softmax row i puts zero weight on every position j > i.
causal = np.tril(np.ones((L, L))) == 1
mask = np.where(causal, 0.0, -np.inf)

# Add the mask to the scaled scores, normalise, then weight the values.
masked_scaled = scaled + mask
masked_attetion = softmax(masked_scaled)
masked_new_v = masked_attetion @ v
print(f'Masked new values\n {masked_new_v}')

Masked new values
 [[ 0.81252582  1.35624003 -0.07201012  1.0035329   0.36163603 -0.64511975
   0.36139561  1.53803657]
 [ 0.66641301  1.39213367 -0.51081002  0.97225045  0.31434319 -0.58550834
   0.31495603  0.93081668]
 [ 0.52954961  1.21756173 -0.14905901  0.7267579   0.1310981  -0.57583252
   0.41805381  0.87397953]
 [ 0.01421368  1.14907671 -0.99239485  0.60451701 -0.14600018 -0.40496816
   0.24215067 -0.82777073]]
