# Neural Self-Attention

<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" align="left" src="https://i.creativecommons.org/l/by-nc-sa/4.0/80x15.png" /></a>&nbsp;| Dennis G. Wilson<br> https://supaerodatascience.github.io/deep-learning/


Based on a [Medium article](https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a) by [Raimi Karim](https://towardsdatascience.com/@remykarem)

Colab version by [Manuel Romero](https://twitter.com/mrm8488)

Additional illustrations by [Jay Alammar](https://jalammar.github.io/illustrated-transformer/)

## Background

Attention modules were first introduced in neural networks as a way to index the latent space of a recurrent network, although similar ideas were used in Neural Turing Machines to index an external memory. The idea in both is the same: a mask which specifies which part of a sequence to pay **attention** to. 

+ Alex Graves, Greg Wayne, and Ivo Danihelka. "Neural turing machines." CoRR, abs/1410.5401, 2014. [pdf](https://arxiv.org/pdf/1410.5401.pdf)

+ Bahdanau, Dzmitry, Kyung Hyun Cho, and Yoshua Bengio. "Neural machine translation by jointly learning to align and translate." 3rd International Conference on Learning Representations, ICLR 2015. [pdf](https://arxiv.org/pdf/1409.0473.pdf)


![Self-attention example](https://jalammar.github.io/images/t/transformer_self-attention_visualization.png)

While originally used as part of recurrent or convolutional networks, Transformer networks rely entirely on attention modules to analyze sequences. Specifically, they use self-attention, where an entire sequence is presented for analysis. A self-attention module takes in n inputs, and returns n outputs. What happens in this module? Put simply, the self-attention mechanism allows the inputs to interact with each other (“self”) and find out who they should pay more attention to (“attention”). The outputs are aggregates of these interactions and attention scores.

+ Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30, 2017. [pdf](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)

A self-attention head has three weight matrices (represented as fully-connected NN layers) which compute a **query**, **key**, and **value** per input. The keys and values can be thought of as a dictionary the network has learned: for each key, it stores associations with a list of possible values. In a language model, these are often associated words, based on either order or context. The query is similar to the key, but it is used for one input at a time, while the keys are calculated over all inputs. Each query is compared against every key to compute a score, which identifies the keys most similar to that query. The scores then weight the value vectors to output the relevant values for each input.
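
As a minimal sketch of the three weight matrices written as fully-connected layers, assuming PyTorch's `nn.Linear` without bias. The sizes (`d_in = 4`, `d_head = 3`) and the names `to_query`, `to_key` and `to_value` are illustrative choices, not part of the original tutorial:

```python
import torch
from torch import nn

torch.manual_seed(0)  # reproducible random weights

d_in, d_head = 4, 3   # illustrative sizes
n = 3                 # number of inputs in the sequence

# One fully-connected layer per representation (hypothetical names).
to_query = nn.Linear(d_in, d_head, bias=False)
to_key = nn.Linear(d_in, d_head, bias=False)
to_value = nn.Linear(d_in, d_head, bias=False)

inputs = torch.randn(n, d_in)  # random stand-in for a sequence of n inputs
q, k, v = to_query(inputs), to_key(inputs), to_value(inputs)
print(q.shape, k.shape, v.shape)  # each has shape (n, d_head)
```

Note that `nn.Linear` stores its weight with shape `(d_head, d_in)`, the transpose of a `d_in`×`d_head` matrix applied as `inputs @ W`.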

![Query, key and value](https://jalammar.github.io/images/t/transformer_self_attention_vectors.png)

The following animation illustrates the whole process of the self-attention head. A sequence of input tokens is used to compute keys and values, which are then compared with each query, calculated per input. For each query, the scores of all keys are calculated and then used to weight the sum of values, resulting in an output value vector per input.

![Self-attention animation](https://miro.medium.com/max/1973/1*_92bnsMJy8Bl539G4v93yg.gif)

In the following, we explain and implement these steps:
1. Prepare inputs
2. Initialise weights
3. Derive key, query and value
4. Calculate attention scores for Input 1
5. Calculate softmax
6. Multiply scores with values
7. Sum weighted values to get Output 1
8. Repeat steps 4–7 for Input 2 & Input 3

In [None]:
import torch

### Step 1: Prepare inputs

In text-based models, an encoding matrix is used to turn each token into a vector encoding, usually floating-point and high-dimensional (768, for example). For the sake of simplicity, this tutorial uses 3 inputs, each with dimension 4 and integer-valued entries.

![Prepare inputs](https://miro.medium.com/max/1973/1*hmvdDXrxhJsGhOQClQdkBA.png)
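
To illustrate what such an encoding matrix might look like, here is a small sketch assuming PyTorch's `nn.Embedding`. The vocabulary size, the token indices and the dimension of 4 are made-up toy values; a real model would use a much larger dimension:

```python
import torch
from torch import nn

torch.manual_seed(0)  # reproducible random embedding

vocab_size, d_model = 10, 4  # hypothetical toy sizes
embedding = nn.Embedding(vocab_size, d_model)  # one learnable row per token

token_ids = torch.tensor([2, 7, 5])  # three made-up token indices
encoded = embedding(token_ids)       # looks up one row per token
print(encoded.shape)  # torch.Size([3, 4])
```

The lookup simply selects rows of the embedding matrix, so `encoded[0]` is row 2 of `embedding.weight`.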


In [None]:
x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]
x = torch.tensor(x, dtype=torch.float32)
x

### Step 2: Initialise weights

Every input must have three representations (see diagram below). These representations are called **key** (orange), **query** (red), and **value** (purple). For this example, let’s say we want these representations to have a dimension of 3. Because every input has a dimension of 4, each set of weights must have a shape of 4×3.

![Initialise weights](https://miro.medium.com/max/1975/1*VPvXYMGjv0kRuoYqgFvCag.gif)

In [None]:
w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)

Note: *In a neural network setting, these weights are usually small numbers, initialised randomly, for example by sampling from a Gaussian distribution or by using the Xavier or Kaiming initialisation schemes.*
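
As a hedged sketch of what random initialisation could look like, the following uses functions from `torch.nn.init` on tensors of the same 4×3 shape. The standard deviation of 0.02 and the `relu` nonlinearity are arbitrary illustrative choices:

```python
import torch
from torch import nn

torch.manual_seed(0)  # reproducible sampling

d_in, d_head = 4, 3  # same shape as the hand-picked weights

w_gauss = torch.empty(d_in, d_head)
nn.init.normal_(w_gauss, mean=0.0, std=0.02)  # std chosen for illustration

w_xavier = torch.empty(d_in, d_head)
nn.init.xavier_uniform_(w_xavier)  # bound depends on fan-in plus fan-out

w_kaiming = torch.empty(d_in, d_head)
nn.init.kaiming_uniform_(w_kaiming, nonlinearity="relu")  # bound depends on fan-in

for name, w in [("gaussian", w_gauss), ("xavier", w_xavier), ("kaiming", w_kaiming)]:
    print(name, tuple(w.shape), float(w.abs().max()))
```

All three produce small values centred on zero; they differ in how the spread is chosen from the layer dimensions.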

### Step 3: Derive key, query and value

Now that we have the three sets of weights, let’s actually obtain the **key**, **query** and **value** representations for every input.

Obtaining the keys:
```
               [0, 0, 1]
[1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]
```
![Deriving the keys](https://miro.medium.com/max/1975/1*dr6NIaTfTxEWzxB2rc0JWg.gif)

Obtaining the values:
```
               [0, 2, 0]
[1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3] 
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]
```
![Deriving the values](https://miro.medium.com/max/1975/1*5kqW7yEwvcC0tjDOW3Ia-A.gif)


Obtaining the queries:
```
               [1, 0, 1]
[1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]
```
![Deriving the queries](https://miro.medium.com/max/1975/1*wO_UqfkWkv3WmGQVHvrMJw.gif)

Note: *In practice, a bias vector may be added to the product of the matrix multiplication.*
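
A minimal sketch of what adding a bias could look like for the keys. The bias vector `b_key` is a made-up example and is not used anywhere else in this tutorial:

```python
import torch

x = torch.tensor([[1., 0., 1., 0.],
                  [0., 2., 0., 2.],
                  [1., 1., 1., 1.]])
w_key = torch.tensor([[0., 0., 1.],
                      [1., 1., 0.],
                      [0., 1., 0.],
                      [1., 1., 0.]])
b_key = torch.tensor([0.5, -1.0, 0.0])  # hypothetical bias, one entry per key dimension

# The bias is broadcast, so the same vector is added to the key of every input.
keys_with_bias = x @ w_key + b_key
print(keys_with_bias)
# tensor([[0.5000, 0.0000, 1.0000],
#         [4.5000, 3.0000, 0.0000],
#         [2.5000, 2.0000, 1.0000]])
```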

In [None]:
keys = torch.matmul(x, w_key)
querys = torch.matmul(x, w_query)
values = torch.matmul(x, w_value)

print("Keys: \n", keys)
# tensor([[0., 1., 1.],
#         [4., 4., 0.],
#         [2., 3., 1.]])

print("Querys: \n", querys)
# tensor([[1., 0., 2.],
#         [2., 2., 2.],
#         [2., 1., 3.]])
print("Values: \n", values)
# tensor([[1., 2., 3.],
#         [2., 8., 0.],
#         [2., 6., 3.]])

### Step 4: Calculate attention scores
![Calculating attention scores](https://miro.medium.com/max/1973/1*u27nhUppoWYIGkRDmYFN2A.gif)

To obtain **attention scores**, we start by taking the dot product of Input 1’s **query** (red) with **all keys** (orange), including its own. Since there are 3 key representations (because we have 3 inputs), we obtain 3 attention scores (blue).

```
            [0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
            [1, 0, 1]
```
Notice that we only use the query from Input 1. Later we’ll repeat this same step for the other queries.

Note: *The above operation is known as dot product attention, one of the several score functions. Other score functions include scaled dot product and additive/concat.*            
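
For comparison, here is a rough sketch of an additive score function, which passes each query/key pair through a small `tanh` layer instead of taking a dot product. The parameters `w_q`, `w_k` and `v_a` and the hidden size of 5 are hypothetical; in a real network they would be learned:

```python
import torch

torch.manual_seed(0)  # reproducible random parameters

querys = torch.tensor([[1., 0., 2.], [2., 2., 2.], [2., 1., 3.]])
keys = torch.tensor([[0., 1., 1.], [4., 4., 0.], [2., 3., 1.]])

d_head, d_hidden = 3, 5  # d_hidden is an arbitrary illustrative size
w_q = torch.randn(d_head, d_hidden)  # hypothetical learned parameters
w_k = torch.randn(d_head, d_hidden)
v_a = torch.randn(d_hidden)

# Pair every query i with every key j: shape (3, 3, d_hidden).
hidden = torch.tanh((querys @ w_q)[:, None, :] + (keys @ w_k)[None, :, :])
additive_scores = hidden @ v_a  # shape (3, 3): one score per (query, key) pair
print(additive_scores.shape)
```

The result has the same 3×3 layout as the dot-product scores, so the remaining steps (softmax and weighting the values) would apply unchanged.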

In [None]:
attn_scores = torch.matmul(querys, keys.T)
print(attn_scores)

# tensor([[ 2.,  4.,  4.],  # attention scores from Query 1
#         [ 4., 16., 12.],  # attention scores from Query 2
#         [ 4., 12., 10.]]) # attention scores from Query 3

### Step 5: Calculate softmax
![Calculating softmax](https://miro.medium.com/max/1973/1*jf__2D8RNCzefwS0TP1Kyg.gif)

Take the **softmax** across these **attention scores** (blue). The result is rounded here for readability.
```
softmax([2, 4, 4]) ≈ [0.0, 0.5, 0.5]
```
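
The softmax for Input 1 can also be worked out by hand, which shows the unrounded values. This is a small sketch; subtracting the maximum before exponentiating is a common numerical-stability trick and does not change the result:

```python
import torch

scores = torch.tensor([2.0, 4.0, 4.0])  # attention scores for Input 1

exps = torch.exp(scores - scores.max())  # shift by the max to avoid overflow
weights = exps / exps.sum()              # normalise so the weights sum to 1
print(weights)  # approximately [0.0634, 0.4683, 0.4683]
```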

In [None]:
from torch.nn.functional import softmax

attn_scores_softmax = softmax(attn_scores, dim=-1)
print(attn_scores_softmax)
# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
#         [6.0337e-06, 9.8201e-01, 1.7986e-02],
#         [2.9539e-04, 8.8054e-01, 1.1917e-01]])

# For readability, approximate the above as follows
attn_scores_softmax = [
  [0.0, 0.5, 0.5],
  [0.0, 1.0, 0.0],
  [0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
print(attn_scores_softmax)

### Step 6: Multiply scores with values
![Multiplying scores with values](https://miro.medium.com/max/1973/1*9cTaJGgXPbiJ4AOCc6QHyA.gif)

The softmaxed attention score for each input (blue) is multiplied with its corresponding **value** (purple). This results in 3 alignment vectors (yellow). In this tutorial, we’ll refer to them as **weighted values**.
```
1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]
``` 

In [None]:
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
print(weighted_values)
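
The broadcasting expression used for this step is compact but not obvious. The sketch below spells out the shapes and checks the result against explicit loops, using the rounded softmax values from Step 5 (here named `attn`, an illustrative name):

```python
import torch

values = torch.tensor([[1., 2., 3.], [2., 8., 0.], [2., 6., 3.]])
attn = torch.tensor([[0.0, 0.5, 0.5], [0.0, 1.0, 0.0], [0.0, 0.9, 0.1]])

# (3, 1, 3) * (3, 3, 1) broadcasts to (3, 3, 3).
broadcast = values[:, None] * attn.T[:, :, None]

# Same result with explicit loops: entry [j, i] is value j weighted for query i.
looped = torch.zeros(3, 3, 3)
for i in range(3):      # query / output index
    for j in range(3):  # key / value index
        looped[j, i] = attn[i, j] * values[j]

print(torch.allclose(broadcast, looped))  # True
print(broadcast[:, 0])  # the three weighted values for Input 1
```

The first axis indexes the value being weighted, which is why the next step sums over `dim=0`.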

### Step 7: Sum weighted values
![Summing weighted values](https://miro.medium.com/max/1973/1*1je5TwhVAwwnIeDFvww3ew.gif)

Take all the **weighted values** (yellow) and sum them element-wise:

```
  [0.0, 0.0, 0.0]
+ [1.0, 4.0, 0.0]
+ [1.0, 3.0, 1.5]
-----------------
= [2.0, 7.0, 1.5]
```

The resulting vector ```[2.0, 7.0, 1.5]``` (dark green) **is Output 1**, which is based on the **query representation from Input 1** interacting with all keys, including its own.
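
The vector above relies on the rounded softmax values. As a sketch of the same computation for Input 1 alone, using the unrounded softmax instead, the result shifts slightly:

```python
import torch

query_1 = torch.tensor([1., 0., 2.])  # query of Input 1
keys = torch.tensor([[0., 1., 1.], [4., 4., 0.], [2., 3., 1.]])
values = torch.tensor([[1., 2., 3.], [2., 8., 0.], [2., 6., 3.]])

scores_1 = keys @ query_1                   # dot product with every key: [2, 4, 4]
weights_1 = torch.softmax(scores_1, dim=0)  # unrounded attention weights
output_1 = weights_1 @ values               # weighted sum of the value rows
print(output_1)  # approximately [1.9366, 6.6831, 1.5951]
```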


### Step 8: Repeat for Input 2 & Input 3
![Repeating for Input 2 and Input 3](https://miro.medium.com/max/1973/1*G8thyDVqeD8WHim_QzjvFg.gif)

Note: *The dimension of **query** and **key** must always be the same because of the dot product score function. However, the dimension of **value** may be different from **query** and **key**. The resulting output will consequently follow the dimension of **value**.*
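
To make the note about dimensions concrete, here is a small sketch with random weights in which the values are larger than the queries and keys. The sizes `d_qk = 3` and `d_v = 5` are arbitrary illustrative choices:

```python
import torch

torch.manual_seed(0)  # reproducible random inputs and weights

n, d_in = 3, 4
d_qk, d_v = 3, 5  # hypothetical sizes: queries and keys share d_qk, values use d_v

x = torch.randn(n, d_in)
w_query = torch.randn(d_in, d_qk)
w_key = torch.randn(d_in, d_qk)
w_value = torch.randn(d_in, d_v)

scores = (x @ w_query) @ (x @ w_key).T  # (n, n): requires matching d_qk
weights = torch.softmax(scores, dim=-1)
outputs = weights @ (x @ w_value)       # (n, d_v): follows the value dimension
print(scores.shape, outputs.shape)
```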

In [None]:
outputs = weighted_values.sum(dim=0)
print(outputs)

# tensor([[2.0000, 7.0000, 1.5000],  # Output 1
#         [2.0000, 8.0000, 0.0000],  # Output 2
#         [2.0000, 7.8000, 0.3000]]) # Output 3
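
Steps 6 and 7 together amount to a single matrix product of the attention weights with the values. The following sketch checks this against the broadcast-and-sum version, again using the rounded softmax values (named `attn` here for illustration):

```python
import torch

values = torch.tensor([[1., 2., 3.], [2., 8., 0.], [2., 6., 3.]])
attn = torch.tensor([[0.0, 0.5, 0.5], [0.0, 1.0, 0.0], [0.0, 0.9, 0.1]])

# Broadcast, then sum over the value axis (Steps 6 and 7).
via_broadcast = (values[:, None] * attn.T[:, :, None]).sum(dim=0)

# The same computation as one matrix multiplication.
via_matmul = attn @ values

print(via_matmul)
print(torch.allclose(via_broadcast, via_matmul))  # True
```

Row `i` of `attn @ values` is the weighted sum of all value vectors for query `i`, which is why the two forms agree.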

The following image summarizes the different steps. One additional step used in practice is to divide the scores, before the softmax, by the square root of the dimension of the key vectors. This stabilizes gradients, as it leads to less drastic differences in the softmax output.

![Self-attention steps](https://jalammar.github.io/images/t/self-attention-output.png)
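
One way to see why the square root of the dimension is a natural scale is to look at how dot products of random vectors spread out as the dimension grows. The sketch below assumes unit-variance random vectors as stand-ins for queries and keys, which is a simplification of what a trained network produces:

```python
import torch

torch.manual_seed(0)  # reproducible sampling

spread = {}
for d in (4, 64, 512):
    q = torch.randn(10_000, d)  # random stand-ins for query vectors
    k = torch.randn(10_000, d)  # random stand-ins for key vectors
    dots = (q * k).sum(dim=-1)  # one dot product per row
    spread[d] = dots.std().item()
    print(f"d={d:4d}  std of q.k = {spread[d]:6.2f}  sqrt(d) = {d ** 0.5:6.2f}")
```

Under this assumption the standard deviation of the scores grows roughly like the square root of the dimension, so larger dimensions push the softmax towards more extreme outputs unless the scores are scaled down.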

## Exercise 1

Implement a division by the square root of the key vector dimensionality. How does this change the softmax values?

## Exercise 2

Make the first and the third input vectors identical. Then, implement a simple positional encoding scheme to add to the inputs. Verify that the first and third outputs differ even though the first and third input vectors are the same.

## Exercise 3

Using identical vectors for the first and third inputs, change any of the weight matrices to try to get the first and third outputs to be as close as possible.