## Self-attention with NumPy

Self-attention is the main building block of the transformer
architecture. A self-attention block takes in a sequence of tokens
and calculates how much "attention" each token should give to
every other token in the sequence.

In [3]:
import numpy as np

In [2]:
class SelfAttention:
    def __init__(self,
                 Wq,  # The query matrix
                 Wk,  # The key matrix
                 Wv   # The value matrix
                 ):
        self.Wq = Wq
        self.Wk = Wk
        self.Wv = Wv

    def forward(self, tokens):
        '''
        TODO
        Tokens are stacked as rows, so the projections multiply on the right:
        queries = np.matmul(tokens, self.Wq)
        keys = np.matmul(tokens, self.Wk)
        values = np.matmul(tokens, self.Wv)
        '''
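
The `forward` method above is left as a TODO. One way it might be completed is sketched below, assuming tokens are stacked as rows and that the block applies scaled dot-product attention as described in the source linked at the end (scale the scores by the square root of the key dimension, then apply a row-wise softmax). The class name `SelfAttentionSketch` and the helper `softmax` are hypothetical names introduced for illustration, not part of the original notebook.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max along the axis before exponentiating, for numerical stability.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)


class SelfAttentionSketch:
    """Hypothetical completion of SelfAttention (an assumption, not the author's code)."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq = Wq  # query matrix, shape (d_model, d_k)
        self.Wk = Wk  # key matrix, shape (d_model, d_k)
        self.Wv = Wv  # value matrix, shape (d_model, d_v)

    def forward(self, tokens):
        # tokens: one embedding per row, shape (n_tokens, d_model)
        queries = np.matmul(tokens, self.Wq)
        keys = np.matmul(tokens, self.Wk)
        values = np.matmul(tokens, self.Wv)

        # Raw scores: dot product of every query with every key.
        scores = np.matmul(queries, keys.T)
        # Scale by sqrt(d_k), then turn each row into weights that sum to 1.
        weights = softmax(scores / np.sqrt(keys.shape[-1]), axis=-1)
        # Each output row is a weighted average of the value vectors.
        return np.matmul(weights, values)


# Usage with random placeholder data: 4 tokens of dimension 3, projected to 2.
rng = np.random.default_rng(0)
attn = SelfAttentionSketch(rng.normal(size=(3, 2)),
                           rng.normal(size=(3, 2)),
                           rng.normal(size=(3, 2)))
out = attn.forward(rng.normal(size=(4, 3)))
print(out.shape)  # (4, 2)
```

The output has one row per input token, and its width is set by the number of columns in `Wv`.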

### A toy example
Say we have the words 'It's', 'snowing', and 'outside', given
by the embeddings

In [26]:
its = np.array([0.3, 0.8, -0.3])
snowing = np.array([-0.9, -0.1, 0])
outside = np.array([1, -0.2, 0.3])

combined = np.stack((its, snowing, outside), axis=0)
print(combined)

[[ 0.3  0.8 -0.3]
 [-0.9 -0.1  0. ]
 [ 1.  -0.2  0.3]]


(These are not actual embeddings.) Let us pass this sequence of tokens
through a self-attention block. A self-attention block is
parameterized by three learnable matrices:
- a *query* matrix
- a *key* matrix
- a *value* matrix

They must be shaped so that the matrix product of an
embedding vector with each of the above matrices is valid.
Since our embedding vectors are 1x3, each matrix needs three rows; here we use 3x2 matrices, for example:

In [15]:
Wq = np.array([[0.27, 0.14],
               [0.24, 0.09],
               [0.56, 0.84]])

Wk = np.array([[0.1 , 0.69],
               [0.38, 0.63],
               [0.91, 0.01]])

Wv = np.array([[0.81, 0.8],
               [0.67, 0.98],
               [0.04, 0.3]])
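
The shape requirement can be checked directly. The sketch below assumes the embeddings are stacked as rows, as in `combined`, so a (3, 3) token matrix times a (3, 2) weight matrix yields one 2-dimensional vector per token. The output dimension of 2 is simply the choice made in this example; the name `projected` is introduced here for illustration.

```python
import numpy as np

# Three tokens with 3-dimensional embeddings, as in the toy example.
combined = np.array([[0.3, 0.8, -0.3],
                     [-0.9, -0.1, 0.0],
                     [1.0, -0.2, 0.3]])

# Any (3, k) matrix is a valid projection; k = 2 matches the matrices in this example.
Wq = np.array([[0.27, 0.14],
               [0.24, 0.09],
               [0.56, 0.84]])

# The inner dimensions must agree for the matrix product to be defined.
assert combined.shape[1] == Wq.shape[0]

projected = np.matmul(combined, Wq)
print(projected.shape)  # (3, 2): one 2-dimensional vector per token
```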

We first calculate the queries and keys:

In [37]:
queries = np.matmul(combined, Wq)
keys = np.matmul(combined, Wk)
print("queries: \n", queries, '\n')
print("keys: \n", keys)

queries: 
 [[ 0.105 -0.138]
 [-0.267 -0.135]
 [ 0.39   0.374]] 

keys: 
 [[ 0.061  0.708]
 [-0.128 -0.684]
 [ 0.297  0.567]]


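Next, we take the dot product of every query with every key. Entry (i, j) of the result is the raw attention score of token i's query against token j's key:

In [31]:
np.matmul(queries, keys.T)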

array([[-0.091299,  0.080952, -0.047061],
       [-0.111867,  0.126516, -0.155844],
       [ 0.288582, -0.305736,  0.327888]])
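
The remaining steps described in the source linked at the end can be sketched as follows: compute the values, scale the scores by the square root of the key dimension, apply a row-wise softmax, and take the weighted sum of the values. This is only one way to continue the example; the names `values`, `scores`, `weights`, and `output` are assumptions, not part of the original notebook.

```python
import numpy as np

# Toy embeddings and weight matrices, repeated so the block runs on its own.
combined = np.array([[0.3, 0.8, -0.3],
                     [-0.9, -0.1, 0.0],
                     [1.0, -0.2, 0.3]])
Wq = np.array([[0.27, 0.14], [0.24, 0.09], [0.56, 0.84]])
Wk = np.array([[0.1, 0.69], [0.38, 0.63], [0.91, 0.01]])
Wv = np.array([[0.81, 0.8], [0.67, 0.98], [0.04, 0.3]])

queries = np.matmul(combined, Wq)
keys = np.matmul(combined, Wk)
values = np.matmul(combined, Wv)

# Scale the raw scores by sqrt(d_k); here d_k = 2 is the key dimension.
scores = np.matmul(queries, keys.T) / np.sqrt(keys.shape[1])

# Row-wise softmax: each row becomes non-negative weights that sum to 1.
exps = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = exps / exps.sum(axis=1, keepdims=True)

# Each output row is a weighted average of the value vectors.
output = np.matmul(weights, values)
print(weights)
print(output.shape)  # (3, 2)
```

Row i of `weights` can be read as how much token i attends to each token in the sequence, itself included.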

In [None]:
## TODO

Source(s):
- https://jalammar.github.io/illustrated-transformer/