## Self Attention
`Self Attention` is a term used in Transformer-based architectures such as `BERT`, `RoBERTa`, `Swin Transformer`, `ViT` and many more.

These architectures rely entirely on `Self Attention` mechanisms to draw global dependencies between `inputs` and `outputs`.
The `self-attention` mechanism allows the inputs to interact with each other (“self”) and find out which of them they should pay more attention to (“attention”). The outputs are aggregates of these interactions, weighted by the attention scores.
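
In compact form, and assuming the plain (unscaled) dot-product score used in the walkthrough that follows, the output for input $i$ can be sketched as a weighted sum. Here $q_i$, $k_j$ and $v_j$ denote the query, key and value vectors, and $\alpha_{ij}$ is a symbol introduced only for this sketch:

$$
\alpha_{ij} = \frac{\exp(q_i \cdot k_j)}{\sum_{l} \exp(q_i \cdot k_l)},
\qquad
\text{output}_i = \sum_{j} \alpha_{ij}\, v_j
$$

Each $\alpha_{ij}$ says how much input $i$ attends to input $j$, and for a fixed $i$ the weights sum to 1.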

To illustrate the concept, we walk through the following steps.
### 1. Prepare Inputs
We start with 3 inputs, each of dimension 4: `[1, 0, 1, 0]` (Input 1), `[0, 2, 0, 2]` (Input 2) and `[1, 1, 1, 1]` (Input 3).
### 2. Initialize Weights
Every input must have 3 representations: `key`, `query` and `value`. In this example, let's give each representation dimension 3. Since every input has dimension 4, each weight matrix must have shape `[4, 3]`.

To obtain these representations, every input is multiplied by a set of weights for `key`, a set of weights for `query` and a set of weights for `value`. Let's take the weights to be:

Weights for key:
```
[[0, 0, 1],
 [1, 1, 0],
 [0, 1, 0],
 [1, 1, 0]]
```

Weights for query:
```
[[1, 0, 1],
 [1, 0, 0],
 [0, 0, 1],
 [0, 1, 1]]
```

Weights for value:
```
[[0, 2, 0],
 [0, 3, 0],
 [1, 0, 3],
 [1, 1, 0]]
```
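In practice, these weights are initialised randomly, for example from a `Gaussian` distribution or with the `Xavier` or `Kaiming` initialisation schemes, and are then learned during training. Here they are small hand-picked integers so that the arithmetic is easy to follow.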
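
As a rough illustration of what random initialisation could look like, here is a dependency-free sketch. The function names, the `std=0.02` default and the seeds are assumptions made for this example, not part of the walkthrough. The Xavier and Kaiming formulas are the commonly used uniform and normal variants.

```python
import math
import random

def gaussian_init(fan_in, fan_out, std=0.02, seed=0):
    """Draw every weight from N(0, std^2). The std value is an arbitrary choice."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]

def xavier_uniform_init(fan_in, fan_out, seed=0):
    """Xavier/Glorot uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    rng = random.Random(seed)
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-bound, bound) for _ in range(fan_out)] for _ in range(fan_in)]

def kaiming_normal_init(fan_in, fan_out, seed=0):
    """Kaiming/He normal: N(0, 2 / fan_in), a variant derived for ReLU layers."""
    rng = random.Random(seed)
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]

# Hypothetical random counterparts of the hand-picked [4, 3] matrices above.
w_key_random = xavier_uniform_init(4, 3)
w_query_random = gaussian_init(4, 3)
w_value_random = kaiming_normal_init(4, 3)
print(len(w_key_random), len(w_key_random[0]))  # 4 3
```

The rest of the walkthrough keeps the integer weights, since random values would make the hand calculations hard to check.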

### 3. Derive key, query and value
`Key` representation for Input 1:
```
               [0, 0, 1]
[1, 0, 1, 0] x [1, 1, 0] = [0, 1, 1]
               [0, 1, 0]
               [1, 1, 0]
```

Use the same set of weights to get the `key` representation for Input 2:
```
               [0, 0, 1]
[0, 2, 0, 2] x [1, 1, 0] = [4, 4, 0]
               [0, 1, 0]
               [1, 1, 0]
```

Same for Input 3:
```
               [0, 0, 1]
[1, 1, 1, 1] x [1, 1, 0] = [2, 3, 1]
               [0, 1, 0]
               [1, 1, 0]
```

A faster way is to `vectorize` the above operations by stacking the inputs into a single matrix:
```
               [0, 0, 1]
[1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]
```
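
To check that the vectorized form gives the same result as the three separate products, here is a small sketch in plain Python. The `matmul` helper is written for this example and is not a library function.

```python
def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix, both given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

inputs = [
    [1, 0, 1, 0],  # Input 1
    [0, 2, 0, 2],  # Input 2
    [1, 1, 1, 1],  # Input 3
]
w_key = [
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0],
]

# One product per input, as in the first three blocks above.
keys_one_by_one = [matmul([x], w_key)[0] for x in inputs]

# A single product over the stacked inputs.
keys_batched = matmul(inputs, w_key)

print(keys_batched)  # [[0, 1, 1], [4, 4, 0], [2, 3, 1]]
```

Both approaches produce the same rows. The batched form simply does all three products in one call.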

Now, let's do the same to obtain the `value` representations for every input:
```
               [0, 2, 0]
[1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3] 
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]
```

and the `query` representations:
```
               [1, 0, 1]
[1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]
```

### 4. Calculate Attention Scores for Input 1

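To obtain the `attention` scores, we take the dot product (`@`) of Input 1's `query` with every `key`, including its own. Since there are 3 `key` representations, we get 3 attention scores. Each column of the matrix below is one `key`, so the matrix is the transpose of the keys.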
```
            [0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
            [1, 0, 1]
```
The dot product (`@`) is only one of several `score functions`. Others include the `scaled dot product` and `additive/concat` scoring.
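
As a hedged sketch of how two of these score functions differ, the snippet below compares the plain dot product with the scaled dot product, which divides by the square root of the key dimension. The function names are made up for this example. Additive scoring is left out because it needs extra learned parameters that this walkthrough does not define.

```python
import math

def dot_score(query, key):
    """Plain dot product, as used in this walkthrough."""
    return sum(q * k for q, k in zip(query, key))

def scaled_dot_score(query, key):
    """Dot product divided by sqrt(d_k), where d_k is the key dimension."""
    return dot_score(query, key) / math.sqrt(len(key))

query_1 = [1, 0, 2]
keys = [
    [0, 1, 1],
    [4, 4, 0],
    [2, 3, 1],
]

dot_scores = [dot_score(query_1, k) for k in keys]
scaled_scores = [scaled_dot_score(query_1, k) for k in keys]

print(dot_scores)  # [2, 4, 4]
print([round(s, 4) for s in scaled_scores])
```

Scaling keeps the ordering of the scores but shrinks their spread, which makes the softmax in the next step less peaked.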

### 5. Calculate Softmax
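Take the `softmax` across these attention scores. The exact result is approximately `[0.06, 0.47, 0.47]`; for readability, the rest of the walkthrough rounds it to `[0.0, 0.5, 0.5]`.
```python
import torch
torch.softmax(torch.tensor([2.0, 4.0, 4.0]), dim=-1)  # ≈ [0.0634, 0.4683, 0.4683], rounded below to [0.0, 0.5, 0.5]
```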
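
For reference, softmax itself is short enough to write by hand. The sketch below uses only the standard library. Subtracting the maximum score is a common numerical-stability trick and is an addition of this example. It shows the unrounded weights behind the `[0.0, 0.5, 0.5]` used in this walkthrough.

```python
import math

def softmax(scores):
    """Exponentiate each score and normalise so the results sum to 1."""
    shift = max(scores)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

weights = softmax([2, 4, 4])
print([round(w, 4) for w in weights])  # [0.0634, 0.4683, 0.4683]
```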
### 6. Multiply Scores with Values
The softmaxed attention score for each input is multiplied by its corresponding `value`. This results in 3 `alignment vectors`, also known as **`weighted values`**.

### 7. Sum Weighted Values
Take all the `weighted values` and sum them elementwise.
```
  [0.0, 0.0, 0.0]
+ [1.0, 4.0, 0.0]
+ [1.0, 3.0, 1.5]
-----------------
= [2.0, 7.0, 1.5]
```

The resulting vector is `Output 1`, which is based on the `query` representation of Input 1 interacting with all the keys, including its own.
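
The multiplication and the sum for `Output 1` can be reproduced in a few lines of plain Python. This sketch assumes the rounded weights `[0.0, 0.5, 0.5]` used above.

```python
values = [
    [1, 2, 3],  # value for Input 1
    [2, 8, 0],  # value for Input 2
    [2, 6, 3],  # value for Input 3
]
weights = [0.0, 0.5, 0.5]  # rounded softmax scores for Input 1's query

# Scale each value vector by its attention weight.
weighted_values = [[w * v for v in value] for w, value in zip(weights, values)]

# Sum the weighted vectors elementwise.
output_1 = [sum(column) for column in zip(*weighted_values)]

print(weighted_values)  # [[0.0, 0.0, 0.0], [1.0, 4.0, 0.0], [1.0, 3.0, 1.5]]
print(output_1)         # [2.0, 7.0, 1.5]
```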

### 8. Repeat for Input 2 and Input 3

NOTE: The dimensions of `query` and `key` must always be the same because of the `dot` product score function. However, the dimension of `value` may differ from that of `query` and `key`. The resulting output then has the dimension of `value`.
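
To make the note concrete, the sketch below reuses the queries and keys from this walkthrough (dimension 3) with a made-up set of 2-dimensional values. The `values_2d` matrix and the helper functions are assumptions for this example only.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(scores):
    """Softmax over one row of scores."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Dot-product attention: one output row per query row."""
    keys_t = [list(col) for col in zip(*keys)]  # transpose the keys
    scores = matmul(queries, keys_t)
    weights = [softmax(row) for row in scores]
    return matmul(weights, values)

queries = [[1, 0, 2], [2, 2, 2], [2, 1, 3]]  # dimension 3
keys = [[0, 1, 1], [4, 4, 0], [2, 3, 1]]     # dimension 3, must match the queries
values_2d = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical values of dimension 2

outputs = attention(queries, keys, values_2d)
print(len(outputs), len(outputs[0]))  # 3 2
```

There is still one output per input, but each output now has 2 components because the values do.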


### Code It
Step 1: Prepare Inputs

```python
import torch

x = [
    [1, 0, 1, 0],  # Input 1
    [0, 2, 0, 2],  # Input 2
    [1, 1, 1, 1],  # Input 3
]

x = torch.tensor(x, dtype=torch.float32)
```

Step 2: Initialize Weights

```python
w_key = [
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0],
]
w_query = [
    [1, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1],
]
w_value = [
    [0, 2, 0],
    [0, 3, 0],
    [1, 0, 3],
    [1, 1, 0],
]

w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
```

Step 3: Derive key, query and value

```python
keys = x @ w_key
querys = x @ w_query
values = x @ w_value

print(keys)
```

```
tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])
```


```python
print(querys)
```

```
tensor([[1., 0., 2.],
        [2., 2., 2.],
        [2., 1., 3.]])
```


```python
print(values)
```

```
tensor([[1., 2., 3.],
        [2., 8., 0.],
        [2., 6., 3.]])
```


Step 4: Calculate Attention Scores

```python
attn_scores = querys @ keys.T

print(attn_scores)
```

```
tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])
```


Step 5: Calculate Softmax

```python
attn_scores_softmax = torch.nn.functional.softmax(attn_scores, dim=-1)

print(attn_scores_softmax)
```

```
tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
        [6.0337e-06, 9.8201e-01, 1.7986e-02],
        [2.9539e-04, 8.8054e-01, 1.1917e-01]])
```


```python
# Round the softmax output by hand so the numbers are easier to follow.
# This is an approximation for readability, not a required step.
attn_scores_softmax = [
    [0.0, 0.5, 0.5],
    [0.0, 1.0, 0.0],
    [0.0, 0.9, 0.1],
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
```

```python
attn_scores_softmax
```

```
tensor([[0.0000, 0.5000, 0.5000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.9000, 0.1000]])
```

Step 6: Multiply Attention Scores with Values

```python
# Broadcast so that weighted_values[j, i] = attn_scores_softmax[i, j] * values[j]
weighted_values = values[:, None] * attn_scores_softmax.T[:, :, None]

print(weighted_values)
```

```
tensor([[[0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[1.0000, 4.0000, 0.0000],
         [2.0000, 8.0000, 0.0000],
         [1.8000, 7.2000, 0.0000]],

        [[1.0000, 3.0000, 1.5000],
         [0.0000, 0.0000, 0.0000],
         [0.2000, 0.6000, 0.3000]]])
```


Step 7: Sum Weighted Values

```python
outputs = weighted_values.sum(dim=0)

print(outputs)
```

```
tensor([[2.0000, 7.0000, 1.5000],
        [2.0000, 8.0000, 0.0000],
        [2.0000, 7.8000, 0.3000]])
```
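
Steps 6 and 7 together amount to a single matrix product between the attention weights and the values. The plain-Python sketch below illustrates this and also shows how much the hand-rounded weights change the result. The `matmul` helper is written for this example, and `attn_exact` is copied from the softmax output printed in Step 5.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

values = [
    [1.0, 2.0, 3.0],
    [2.0, 8.0, 0.0],
    [2.0, 6.0, 3.0],
]
attn_rounded = [
    [0.0, 0.5, 0.5],
    [0.0, 1.0, 0.0],
    [0.0, 0.9, 0.1],
]
attn_exact = [  # softmax output from Step 5, before rounding
    [6.3379e-02, 4.6831e-01, 4.6831e-01],
    [6.0337e-06, 9.8201e-01, 1.7986e-02],
    [2.9539e-04, 8.8054e-01, 1.1917e-01],
]

# Weighted sum of the values for every input at once.
outputs_rounded = matmul(attn_rounded, values)
outputs_exact = matmul(attn_exact, values)

for row in outputs_rounded:
    print([round(v, 4) for v in row])
for row in outputs_exact:
    print([round(v, 4) for v in row])  # first row is about [1.9366, 6.6831, 1.5951]
```

With the rounded weights this reproduces the outputs above. With the unrounded weights the first output shifts slightly, because Input 1's own value no longer receives a weight of exactly zero.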
