## Self-Attention - PyTorch

- A self-attention module takes in n inputs, and returns n outputs. 
- In layman’s terms, the self-attention mechanism allows the inputs to interact with each other and find out who they should pay more attention to. 
- The outputs are aggregates of these interactions and attention scores.
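
As a rough sketch of the idea, the whole mechanism can be written as one small function. The name `self_attention` and the random weights below are illustrative assumptions, not part of the walkthrough that follows; the function simply maps n input vectors to n output vectors.

```python
import torch
from torch.nn.functional import softmax


def self_attention(x, w_query, w_key, w_value):
    """Hypothetical helper: unscaled dot-product self-attention."""
    queries = x @ w_query              # (n, d)
    keys = x @ w_key                   # (n, d)
    values = x @ w_value               # (n, d)
    scores = queries @ keys.T          # (n, n): every input scored against every input
    weights = softmax(scores, dim=-1)  # each row sums to 1
    return weights @ values            # (n, d): one output per input


torch.manual_seed(0)                   # arbitrary seed, only for repeatability
x = torch.rand(4, 4)                   # n = 4 inputs of dimension 4 (placeholder data)
w_q, w_k, w_v = (torch.rand(4, 5) for _ in range(3))
out = self_attention(x, w_q, w_k, w_v)
print(out.shape)  # torch.Size([4, 5])
```

The sections below unpack each line of this function with concrete numbers.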

## Why do we need self-attention?

![alt text](https://miro.medium.com/max/1200/0*hJfCjjx0r0slacNm.png)

- Attention mechanisms are common in neural-network-based NLP models such as RNNs and CNNs.

- Self-attention is capable of learning long-range dependencies within a sentence.

- Self-attention is unique in that it ignores the distance between words and computes their dependency relationships directly. This lets it learn the internal structure of a sentence, and makes it easier to compute in parallel.

- Below are the mathematical steps to calculate self attention.

![alt text](http://jalammar.github.io/images/t/transformer_self_attention_vectors.png)

In [0]:
import torch

## 1. Prepare inputs - All inputs are vectorised:

In [0]:


input = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1],  # Input 3
  [3, 1, 1, 1]  # Input 4
 ]
input = torch.tensor(input, dtype=torch.float32)
input

tensor([[1., 0., 1., 0.],
        [0., 2., 0., 2.],
        [1., 1., 1., 1.],
        [3., 1., 1., 1.]])

## 2. Initialize weights: 
- Every input must have three representations. 
- These representations are called key, query, and value.
- Let's say each of these representations has a size of 5.
- Because every input has a dimension of 4, each set of weights must have a shape of 4×5.
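
A quick shape check may make the 4×5 requirement concrete. The random tensors are placeholders (an assumption for illustration). `torch.nn.Linear` with `bias=False` is shown as one common way to hold such a weight matrix as a learnable parameter, though this walkthrough uses fixed hand-written weights instead.

```python
import torch

x = torch.rand(4, 4)       # 4 inputs, each of dimension 4 (placeholder data)
w = torch.rand(4, 5)       # one weight set: input dim 4 -> representation dim 5
print((x @ w).shape)       # torch.Size([4, 5])

# The same projection as a learnable layer. Note that nn.Linear stores its
# weight transposed, with shape (out_features, in_features) = (5, 4).
proj = torch.nn.Linear(4, 5, bias=False)
print(proj.weight.shape)   # torch.Size([5, 4])
print(proj(x).shape)       # torch.Size([4, 5])
```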

In [0]:
w_key = [
  [0, 0, 1, 1, 1],
  [1, 1, 0, 0, 1],
  [0, 1, 0, 1, 1],
  [1, 1, 0, 0, 0]
]
w_query = [
  [1, 0, 1, 0, 1],
  [1, 0, 0, 1, 1],
  [0, 0, 1, 0, 0],
  [0, 1, 1, 1, 0]
]
w_value = [
  [0, 2, 0, 3, 1],
  [0, 3, 0, 1, 3],
  [1, 0, 3, 1, 1],
  [1, 1, 0, 2, 2]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)

Weights for key: 
 tensor([[0., 0., 1., 1., 1.],
        [1., 1., 0., 0., 1.],
        [0., 1., 0., 1., 1.],
        [1., 1., 0., 0., 0.]])
Weights for query: 
 tensor([[1., 0., 1., 0., 1.],
        [1., 0., 0., 1., 1.],
        [0., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0.]])
Weights for value: 
 tensor([[0., 2., 0., 3., 1.],
        [0., 3., 0., 1., 3.],
        [1., 0., 3., 1., 1.],
        [1., 1., 0., 2., 2.]])


## 3. Derive keys, queries and values

In [0]:
queries = torch.matmul(input, w_query)
queries

tensor([[1., 0., 2., 0., 1.],
        [2., 2., 2., 4., 2.],
        [2., 1., 3., 2., 2.],
        [4., 1., 5., 2., 4.]])

In [0]:
# The above could also be written with the @ operator:
input @ w_query

tensor([[1., 0., 2., 0., 1.],
        [2., 2., 2., 4., 2.],
        [2., 1., 3., 2., 2.],
        [4., 1., 5., 2., 4.]])

In [0]:
values = torch.matmul(input, w_value)
values

tensor([[ 1.,  2.,  3.,  4.,  2.],
        [ 2.,  8.,  0.,  6., 10.],
        [ 2.,  6.,  3.,  7.,  7.],
        [ 2., 10.,  3., 13.,  9.]])

In [0]:
input @ w_value

tensor([[ 1.,  2.,  3.,  4.,  2.],
        [ 2.,  8.,  0.,  6., 10.],
        [ 2.,  6.,  3.,  7.,  7.],
        [ 2., 10.,  3., 13.,  9.]])

In [0]:
keys = torch.matmul(input, w_key)
keys

tensor([[0., 1., 1., 2., 2.],
        [4., 4., 0., 0., 2.],
        [2., 3., 1., 2., 3.],
        [2., 3., 3., 4., 5.]])

In [0]:
input @ w_key

tensor([[0., 1., 1., 2., 2.],
        [4., 4., 0., 0., 2.],
        [2., 3., 1., 2., 3.],
        [2., 3., 3., 4., 5.]])

## 4. Calculate Attention Scores

- To obtain attention scores, we take the dot product of each input's query with all keys, including its own. For Input 1, for example, there are 4 key representations, so we obtain 4 attention scores. Doing this for all 4 queries at once gives a 4×4 matrix.

In [0]:
attn_sc = torch.matmul(queries, keys.T)
attn_sc

# Row i of the result holds the attention scores from Query i (i = 1..4)

tensor([[ 4.,  6.,  7., 13.],
        [16., 20., 26., 42.],
        [12., 16., 20., 34.],
        [18., 28., 32., 54.]])
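
To see where a single row of this matrix comes from, the sketch below recomputes only the scores for Input 1 by taking its query and dotting it with each key. The tensors are copied from the outputs printed in step 3.

```python
import torch

query_1 = torch.tensor([1., 0., 2., 0., 1.])     # first row of `queries`
keys = torch.tensor([[0., 1., 1., 2., 2.],
                     [4., 4., 0., 0., 2.],
                     [2., 3., 1., 2., 3.],
                     [2., 3., 3., 4., 5.]])

# One dot product per key, written out explicitly
scores_1 = torch.stack([torch.dot(query_1, k) for k in keys])
print(scores_1)            # tensor([ 4.,  6.,  7., 13.])

# The same result as a single matrix-vector product
print(keys @ query_1)      # tensor([ 4.,  6.,  7., 13.])
```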

## 5. Calculate Softmax


In [0]:
from torch.nn.functional import softmax

attn_scores_softmax = softmax(attn_sc, dim=-1)
print(attn_scores_softmax)

tensor([[1.2298e-04, 9.0869e-04, 2.4701e-03, 9.9650e-01],
        [5.1091e-12, 2.7895e-10, 1.1254e-07, 1.0000e+00],
        [2.7895e-10, 1.5230e-08, 8.3153e-07, 1.0000e+00],
        [2.3195e-16, 5.1091e-12, 2.7895e-10, 1.0000e+00]])
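
Softmax exponentiates each score and divides by the row sum, so the weights in a row are positive and add up to 1. The sketch below redoes the first row by hand in plain Python (the helper name `softmax_row` is made up for this example). It also shows, as an aside not used in this walkthrough, the scaling by the square root of the key dimension that the Transformer's scaled dot-product attention applies before softmax; with scores this large it softens the distribution noticeably.

```python
import math


def softmax_row(scores):
    """Softmax of a list of floats, shifted by the max for numerical stability."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


row_1 = [4.0, 6.0, 7.0, 13.0]            # attention scores for Query 1
weights = softmax_row(row_1)
print(weights)                           # approx. [1.23e-04, 9.09e-04, 2.47e-03, 9.965e-01]

d_k = 5                                  # size of each key vector in this walkthrough
scaled = softmax_row([s / math.sqrt(d_k) for s in row_1])
print(scaled)                            # less peaked than `weights`
```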


## 6. Multiply attention scores with values

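- For each query, the softmaxed attention scores are multiplied with their corresponding values. This results in 4 alignment vectors per query (a tensor of shape 4×4×5 in the code below, indexed by value first and query second). In this tutorial, we'll refer to them as weighted values.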

In [0]:
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
print(weighted_values)

tensor([[[1.2298e-04, 2.4596e-04, 3.6893e-04, 4.9191e-04, 2.4596e-04],
         [5.1091e-12, 1.0218e-11, 1.5327e-11, 2.0436e-11, 1.0218e-11],
         [2.7895e-10, 5.5789e-10, 8.3684e-10, 1.1158e-09, 5.5789e-10],
         [2.3195e-16, 4.6390e-16, 6.9586e-16, 9.2781e-16, 4.6390e-16]],

        [[1.8174e-03, 7.2695e-03, 0.0000e+00, 5.4521e-03, 9.0869e-03],
         [5.5789e-10, 2.2316e-09, 0.0000e+00, 1.6737e-09, 2.7895e-09],
         [3.0460e-08, 1.2184e-07, 0.0000e+00, 9.1380e-08, 1.5230e-07],
         [1.0218e-11, 4.0873e-11, 0.0000e+00, 3.0655e-11, 5.1091e-11]],

        [[4.9401e-03, 1.4820e-02, 7.4102e-03, 1.7291e-02, 1.7291e-02],
         [2.2507e-07, 6.7521e-07, 3.3761e-07, 7.8775e-07, 7.8775e-07],
         [1.6631e-06, 4.9892e-06, 2.4946e-06, 5.8207e-06, 5.8207e-06],
         [5.5789e-10, 1.6737e-09, 8.3684e-10, 1.9526e-09, 1.9526e-09]],

        [[1.9930e+00, 9.9650e+00, 2.9895e+00, 1.2954e+01, 8.9685e+00],
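         [2.0000e+00, 1.0000e+01, 3.0000e+00, 1.3000e+01, 9.0000e+00],
         [2.0000e+00, 1.0000e+01, 3.0000e+00, 1.3000e+01, 9.0000e+00],
         [2.0000e+00, 1.0000e+01, 3.0000e+00, 1.3000e+01, 9.0000e+00]]])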
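
The indexing with `None` relies on broadcasting, which can be hard to read at first. As a sketch with placeholder tensors (random data, assumed only for illustration), the shapes line up like this: entry `[j, i]` of the result is value `j` scaled by the weight that query `i` assigns to input `j`.

```python
import torch

values = torch.rand(4, 5)                         # placeholder values
attn = torch.softmax(torch.rand(4, 4), dim=-1)    # placeholder attention weights

v = values[:, None]           # shape (4, 1, 5)
a = attn.T[:, :, None]        # shape (4, 4, 1)
weighted = v * a              # broadcasts to shape (4, 4, 5)
print(weighted.shape)         # torch.Size([4, 4, 5])

# Spot check one entry against the definition
j, i = 2, 1
print(torch.allclose(weighted[j, i], attn[i, j] * values[j]))  # True
```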


## 7. Sum the weighted values

- Sum the weighted values over all inputs to get the output vector for each query.

In [0]:
outputs = weighted_values.sum(dim=0)
print(outputs)


tensor([[ 1.9999,  9.9873,  2.9973, 12.9777,  8.9951],
        [ 2.0000, 10.0000,  3.0000, 13.0000,  9.0000],
        [ 2.0000, 10.0000,  3.0000, 13.0000,  9.0000],
        [ 2.0000, 10.0000,  3.0000, 13.0000,  9.0000]])
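
Steps 6 and 7 together amount to a weighted sum of the value vectors, which is the same thing as one matrix product. A sketch using the scores and values printed above (the names `outputs_bcast` and `outputs_mm` are chosen just for this comparison):

```python
import torch
from torch.nn.functional import softmax

attn_sc = torch.tensor([[ 4.,  6.,  7., 13.],
                        [16., 20., 26., 42.],
                        [12., 16., 20., 34.],
                        [18., 28., 32., 54.]])
values = torch.tensor([[1.,  2., 3.,  4.,  2.],
                       [2.,  8., 0.,  6., 10.],
                       [2.,  6., 3.,  7.,  7.],
                       [2., 10., 3., 13.,  9.]])
attn = softmax(attn_sc, dim=-1)

# Broadcast-and-sum, as in steps 6 and 7
outputs_bcast = (values[:, None] * attn.T[:, :, None]).sum(dim=0)
# Single matrix product
outputs_mm = attn @ values

print(torch.allclose(outputs_bcast, outputs_mm, atol=1e-5))  # True
print(outputs_mm[0])   # approx. tensor([ 1.9999,  9.9873,  2.9973, 12.9777,  8.9951])
```

The matrix-product form avoids materialising the intermediate 4×4×5 tensor, which is why it is the form usually seen in practice.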


## Conclusion

- These attention outputs are then passed to the feed-forward layers of architectures such as the Transformer.
![alt text](http://jalammar.github.io/images/t/encoder_with_tensors.png)

## References

- https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
- http://jalammar.github.io/illustrated-transformer/
- https://medium.com/@Alibaba_Cloud/self-attention-mechanisms-in-natural-language-processing-9f28315ff905
- https://bastings.github.io/annotated_encoder_decoder/
- https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
- https://towardsdatascience.com/attention-seq2seq-with-pytorch-learning-to-invert-a-sequence-34faf4133e53
- https://talbaumel.github.io/blog/attention/