# Why we need the attention mechanism

## The problem with modeling long sequences

Before the attention mechanism, encoder-decoder RNN architectures were used for sequence-to-sequence tasks such as machine translation, but they had some noticeable problems.
The encoder processes the input one element at a time and updates a hidden state (memory cell) at each step; the decoder then takes the final hidden state to produce the output (the hidden state here plays the role of an embedding of the input).

The issue is that the decoder cannot access the encoder's earlier hidden states during the decoding phase, so it relies solely on the **current** hidden state.

This is why the **attention mechanism** was introduced.

## Capturing data dependencies

One shortcoming of the RNN is that it must remember the entire encoded text in a single hidden state before passing it to the decoder.
*Bahdanau attention* was introduced to let the decoder selectively access different parts of the input sequence at each decoding step.
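
To make the bottleneck concrete, here is a minimal sketch that uses `torch.nn.RNN` as a stand-in encoder. The layer sizes and the random input are illustrative assumptions, not part of the original notebook. The encoder produces one hidden state per input element, but a plain encoder-decoder hands only the last one to the decoder, whereas attention lets the decoder look at all of them.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy encoder: 3-dimensional embeddings, 4-dimensional hidden state
encoder = torch.nn.RNN(input_size=3, hidden_size=4, batch_first=True)

tokens = torch.rand(1, 6, 3)  # (batch, sequence length, embedding size)
all_hidden_states, final_hidden_state = encoder(tokens)

print(all_hidden_states.shape)   # one hidden state per input element
print(final_hidden_state.shape)  # only the state after the last element

# The final hidden state is the hidden state of the last time step
is_last_step = torch.allclose(all_hidden_states[:, -1], final_hidden_state[0])
print(is_last_step)
```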

The transformer architecture came afterwards, with a self-attention mechanism inspired by the *Bahdanau attention mechanism*.
 - Self-attention is a mechanism that allows each position in the input sequence to attend to all positions in the same sequence when computing the representation of that sequence.


## Attending to different parts of the input with self-attention

In self-attention, the term “self” means that each element of a sequence (e.g., a word) pays attention to other elements within the same sequence to understand context and relationships. It learns how different parts of the input relate to each other, unlike traditional attention, which models relationships between two different sequences (e.g., input and output in seq2seq models).

### Simple self-attention mechanism without trainable weights

In self-attention, our goal is to calculate a context vector z<sup>(i)</sup> for each element x<sup>(i)</sup> in the input sequence. A context vector can be interpreted as an enriched embedding vector. For example, the importance or contribution of each input element for computing z<sup>(2)</sup> is determined by the attention weights α<sub>21</sub> to α<sub>2T</sub>. When computing z<sup>(2)</sup>, the attention weights are calculated with respect to input element x<sup>(2)</sup> and all other inputs.

<img src="self-attention.png" width="600">

This enhanced context vector, z<sup>(2)</sup>, is an embedding that contains information about x<sup>(2)</sup> and all other input elements x<sup>(1)</sup> to x<sup>(T)</sup>.

In self-attention, context vectors play a crucial role. Their purpose is to create enriched representations of each element in an input sequence (like a sentence) by incorporating information from all other elements in the sequence.

In [110]:
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your (x^1)
        [0.55, 0.87, 0.66],  # journey (x^2)
        [0.57, 0.85, 0.64],  # starts (x^3)
        [0.22, 0.58, 0.33],  # with (x^4)
        [0.77, 0.25, 0.10],  # one (x^5)
        [0.05, 0.80, 0.55],  # step (x^6)
    ]
)

<img src="attention-weight.png" width="650">

We calculate the intermediate attention scores between the query token and each input token. We determine these scores by computing the dot product of the query, x<sup>(2)</sup>, with every input token, including the query itself.

In [111]:
query = inputs[1]

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(query, x_i)

attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

After computing the attention scores ω<sub>21</sub> to ω<sub>2T</sub> with respect to the input query x<sup>(2)</sup>, the next step is to obtain the attention weights α<sub>21</sub> to α<sub>2T</sub> by normalizing the attention scores.
The goal is to obtain attention weights that sum up to 1.

In [112]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print(f"Attention weights: {attn_weights_2_tmp}")
print(f"Sum: {attn_weights_2_tmp.sum()}")

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: 1.0000001192092896


In practice, it is more common and advisable to use the softmax function for normalization. It handles extreme values better and offers more favorable gradient properties during training. The softmax function also ensures that the attention weights are always positive,
which makes the output interpretable as probabilities or relative importance.
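
A small sketch of why positivity matters: dividing by the sum only yields valid weights when all scores are positive. The scores below are made up for illustration (they do not come from the notebook's inputs) and include a negative value.

```python
import torch

# Hypothetical attention scores containing a negative value
scores = torch.tensor([1.0, -0.5, 0.2])

sum_normalized = scores / scores.sum()
softmax_normalized = torch.softmax(scores, dim=0)

print(sum_normalized)      # sums to 1 but contains a negative "weight"
print(softmax_normalized)  # sums to 1 and every weight is positive
```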

In [113]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)

print(f"Attention weights: {attn_weights_2_naive}")
print(f"Sum: {attn_weights_2_naive.sum()}")

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0


The naive softmax we implemented may run into numerical instability problems, such as overflow or underflow, when dealing with very large or very small inputs. In practice, it is advisable to use the PyTorch implementation.

In [114]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print(f"Attention weights: {attn_weights_2}")
print(f"Sum: {attn_weights_2.sum()}")

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0
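
The instability of the naive version can be reproduced with deliberately large scores. The sketch below also shows one common remedy, subtracting the maximum before exponentiating. `softmax_stable` is a hypothetical helper written for this illustration; it is not meant to mirror PyTorch's internal implementation.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Subtracting the maximum does not change the result mathematically,
    # but keeps every exponent <= 0 so exp() cannot overflow
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

large_scores = torch.tensor([1000.0, 1001.0, 1002.0])  # made-up extreme values

naive_result = softmax_naive(large_scores)    # exp(1000.) overflows to inf in float32
stable_result = softmax_stable(large_scores)
reference = torch.softmax(large_scores, dim=0)

print(naive_result)   # inf / inf produces nan entries
print(stable_result)
print(reference)
```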


In [115]:
query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)


tensor([0.4419, 0.6515, 0.5683])
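
The cell above computes the context vector z<sup>(2)</sup> as the weighted sum of all input vectors, using the attention weights as coefficients. As a sketch, the same result can be obtained with a single vector-matrix product; the variable names `context_loop` and `context_matmul` are introduced here only for the comparison.

```python
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your (x^1)
        [0.55, 0.87, 0.66],  # journey (x^2)
        [0.57, 0.85, 0.64],  # starts (x^3)
        [0.22, 0.58, 0.33],  # with (x^4)
        [0.77, 0.25, 0.10],  # one (x^5)
        [0.05, 0.80, 0.55],  # step (x^6)
    ]
)

# Attention weights for the second token: softmax over its dot products
attn_weights_2 = torch.softmax(inputs @ inputs[1], dim=0)

# Loop version: accumulate each input vector scaled by its attention weight
context_loop = torch.zeros(inputs.shape[1])
for weight, x_i in zip(attn_weights_2, inputs):
    context_loop += weight * x_i

# Vectorized version: (6,) @ (6, 3) -> (3,)
context_matmul = attn_weights_2 @ inputs

print(context_loop)
print(context_matmul)
```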


### Computing attention weights for all input tokens

In [116]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [117]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Because Python `for` loops are generally slow, we can use matrix multiplication to achieve the same result.

In [118]:
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [119]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [120]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [121]:
all_context_vector = attn_weights @ inputs
print(all_context_vector)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
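
The three steps of this section (scores, softmax normalization, weighted sum) can be collected into one small function. This is a sketch; `simple_self_attention` is a hypothetical helper name, not something defined in the original notebook. It also checks that `dim=-1` normalizes each row, so every token's weights sum to 1.

```python
import torch

def simple_self_attention(x):
    """Self-attention without trainable weights (illustrative helper)."""
    attn_scores = x @ x.T                               # (T, T) pairwise dot products
    attn_weights = torch.softmax(attn_scores, dim=-1)   # normalize along each row
    return attn_weights @ x, attn_weights               # (T, d) context vectors

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

context_vectors, attn_weights = simple_self_attention(inputs)

print(attn_weights.sum(dim=-1))  # each row sums to 1
print(context_vectors[1])        # second row corresponds to the earlier context_vec_2
```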


## Implementing self-attention with trainable weights

We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices W<sub>q</sub>, W<sub>k</sub> and W<sub>v</sub>

<img src="qkv.png" width="650">

In [122]:
x_2 = inputs[1]           # Second input element
d_in = inputs.shape[1]    # Input embedding size
d_out = 2               # Output embedding size

# NOTE: in GPT-like models the input and output embedding sizes are usually the same; we use different values here to make the shapes easier to follow

In [123]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

print("Query matrix weights: ", W_query)
print("Key matrix weights: ", W_key)
print("Value matrix weights: ", W_value)

Query matrix weights:  Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
Key matrix weights:  Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
Value matrix weights:  Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [124]:
x_2

tensor([0.5500, 0.8700, 0.6600])

In [125]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)

tensor([0.4306, 1.4551])


In [126]:
inputs, W_key

(tensor([[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]),
 Parameter containing:
 tensor([[0.1366, 0.1025],
         [0.1841, 0.7264],
         [0.3153, 0.6871]]))

In [127]:
keys = inputs @ W_key
values = inputs @ W_value

print("Keys shape: ", keys.shape)
print("Values shape: ", values.shape)

Keys shape:  torch.Size([6, 2])
Values shape:  torch.Size([6, 2])


The second step is to compute the attention scores, as shown below.

<img src="attn_score_qkv.png" width="650">

In [128]:
attn_score_22 = keys[1].dot(query_2)
print(attn_score_22)

tensor(1.8524)


In [129]:
query_2

tensor([0.4306, 1.4551])

In [130]:
keys.T

tensor([[0.3669, 0.4433, 0.4361, 0.2408, 0.1827, 0.3275],
        [0.7646, 1.1419, 1.1156, 0.6706, 0.3292, 0.9642]])

In [131]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


Now we need to normalize the attention scores to obtain the attention weights.

<img src="attn_score_norm_qkv.png" width="650">

We compute the attention weights by scaling the attention scores and applying the softmax function we used earlier. The difference from before is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys.

In [132]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


The reason for scaling by the embedding dimension size is to improve training performance by avoiding small gradients. For instance, when scaling up the embedding dimension, which is typically greater than a thousand for GPT-like LLMs, large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero. These small gradients can drastically slow down learning or cause training to stagnate.

The scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled dot-product attention.
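
The vanishing-gradient effect can be observed directly with autograd. In this sketch the scores are made up, and `softmax_grad_norm` is a hypothetical helper that measures how strongly the largest attention weight reacts to changes in the scores. Multiplying the scores by 100 mimics the large dot products that arise with high-dimensional keys.

```python
import torch

def softmax_grad_norm(scores):
    """Norm of the gradient of the largest softmax weight w.r.t. the scores."""
    s = scores.clone().requires_grad_(True)
    weights = torch.softmax(s, dim=0)
    weights.max().backward()
    return s.grad.norm().item()

base_scores = torch.tensor([0.1, 0.2, 0.3, 0.4])  # illustrative values

small_grad = softmax_grad_norm(base_scores)        # moderate scores
large_grad = softmax_grad_norm(base_scores * 100)  # same pattern, much larger magnitude

print(torch.softmax(base_scores * 100, dim=0))  # close to a one-hot (step-like) output
print(small_grad, large_grad)                   # the second gradient is far smaller
```

Dividing the scores by the square root of the key dimension keeps them in the moderate range where the softmax still passes useful gradients.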


Now, the final step is to compute the context vectors.

<img src="context_vec_qkv.png" width="650">

In [133]:
values

tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])

In [134]:
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

In [135]:
context_vec_2 = attn_weights_2 @ values
context_vec_2

tensor([0.3061, 0.8210])
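
The steps above, carried out here for the single query x<sup>(2)</sup>, can be summarized in matrix form for all inputs at once. The symbols X (the input matrix, one row per token), Q, K, V and Z (the matrix of context vectors) are notation introduced for this summary; d<sub>k</sub> is the key dimension used in the code, and the softmax is applied to each row.

$$
Q = X W_q, \qquad K = X W_k, \qquad V = X W_v, \qquad
Z = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

The second row of Z is the context vector z<sup>(2)</sup> computed above.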

### Implementing a compact self-attention Python class

In [136]:
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys    = x @ self.W_key
        values  = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values

        return context_vec

In [137]:
torch.manual_seed(123)

sa_v1 = SelfAttentionV1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Self-attention mechanism summary

<img src="att_weight_kqv_summary.png" width="650">

We can improve `SelfAttentionV1` by using PyTorch's `nn.Linear` layers, which effectively perform matrix multiplication when the bias units are disabled. Additionally, a significant advantage of using `nn.Linear` instead of `nn.Parameter(torch.rand(...))` is that `nn.Linear` has an **optimized weight initialization scheme, contributing to more stable and effective model training**.
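
As a quick sanity check of the first claim, the sketch below compares a bias-free `nn.Linear` layer with an explicit matrix multiplication. The sizes and the random input are chosen only for illustration. Note that `nn.Linear` stores its weight with shape `(d_out, d_in)`, so the explicit product needs a transpose.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2

linear = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)  # six illustrative input vectors

# The weight is stored as (d_out, d_in), i.e. transposed relative to
# the (d_in, d_out) matrices used with nn.Parameter
print(linear.weight.shape)

# Without a bias, the layer is just a matrix multiplication
same_result = torch.allclose(linear(x), x @ linear.weight.T)
print(same_result)
```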

In [138]:
import torch.nn as nn

class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys    = self.W_key(x)
        values  = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values

        return context_vec

In [139]:
torch.manual_seed(789)
sa_v2 = SelfAttentionV2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
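
The outputs of `sa_v1` and `sa_v2` differ because the two classes start from different initial weights, not because they compute something different. The following sketch re-declares compact versions of both classes so that it is self-contained, copies the transposed `nn.Linear` weights into a `SelfAttentionV1` instance, and checks that both then produce the same context vectors. The random input `x` is an assumption made for the comparison.

```python
import torch
import torch.nn as nn


class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries, keys, values = x @ self.W_query, x @ self.W_key, x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


torch.manual_seed(789)
sa_v2 = SelfAttentionV2(3, 2)
sa_v1 = SelfAttentionV1(3, 2)

# nn.Linear stores weights as (d_out, d_in), so transpose before copying
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

x = torch.rand(6, 3)  # illustrative inputs
outputs_match = torch.allclose(sa_v1(x), sa_v2(x), atol=1e-5)
print(outputs_match)
```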
