# Write your first Attention mechanism

Before transformers, recurrent neural networks (RNNs) were considered the cutting edge in Natural Language Processing (NLP). An RNN is a type of neural network in which the outputs from previous steps are fed as inputs to the current step. This lets an RNN retain information from previous steps, which makes it well suited to sequential data like text. In the context of NLP, an RNN takes an input, such as a word or character, processes it through its network, and generates a vector known as the hidden state. If you are unfamiliar with RNNs, don't worry: you don't need to know their detailed workings to follow this discussion.

One area where RNNs played an important role was the development of machine translation systems, where the model translates text from one language to another. However, the word order in the source language can differ from the word order in the target language because the two languages have different grammatical structures. To address this issue we can use an encoder-decoder architecture. The encoder's role is to convert the input sequence into a numerical representation, typically referred to as the final hidden state. The encoder updates its hidden state at each step, trying to capture the entire meaning of the input sentence in the final hidden state. The decoder then takes this final hidden state and starts generating the translated sentence, one word at a time. \
However, a significant challenge of this architecture is that the final hidden state of the encoder creates an information bottleneck. It has to represent the meaning of the whole input sequence because this is all the decoder has access to when generating the output. This is especially challenging for long sequences, where information at the start of the sequence might be lost in the process of compressing everything into a single, fixed representation.

To address this challenge, an "attention mechanism" is introduced, permitting the decoder to selectively access the different hidden states of the encoder. But why selectively? Using all the states at the same time would create a huge input for the decoder. Instead, the attention mechanism lets the decoder assign a different amount of weight, or "attention", to each of the encoder states at every decoding timestep. \
The authors of the paper "Attention Is All You Need" demonstrated that RNN architectures are not required for NLP applications such as machine translation, and proposed the transformer architecture, which relies on a "self-attention mechanism".
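To make the idea of weighting encoder states concrete, here is a minimal sketch. It assumes toy sizes, random hidden states, and a plain dot product as the scoring function; real encoder-decoder models may score the states differently.

```python
import torch

torch.manual_seed(0)  # reproducible toy values

# Assumed toy sizes: 4 source tokens, hidden size 8
encoder_states = torch.randn(4, 8)  # one hidden state per input token
decoder_state = torch.randn(8)      # current decoder hidden state

# Score each encoder state against the decoder state (dot product is an
# assumption here), then normalize the scores into weights that sum to 1
scores = encoder_states @ decoder_state   # shape: (4,)
weights = torch.softmax(scores, dim=0)    # shape: (4,)

# The context handed to the decoder is a weighted sum of the encoder states
context = weights @ encoder_states        # shape: (8,)
print(weights, context.shape)
```

The context has the size of a single hidden state no matter how long the input is, yet it is recomputed at every decoding step with different weights.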

In [None]:
%pip install --upgrade pip
%pip install --disable-pip-version-check torch==2.0.1

The main idea behind the self-attention mechanism is that, instead of using a fixed embedding for each token, we use the whole sequence to compute each token's new embedding as a weighted average of all the embeddings. Given a sequence of token embeddings $ x_{1}, ..., x_{n} $, self-attention produces a sequence of new embeddings $ x_{1}^{'}, ..., x_{n}^{'} $ where each $ x_{i}^{'} $ is a linear combination of all the $ x_{j},  j=1...n $:

$(1) \; x_{i}^{'} = \sum \limits _{j=1} ^{n} w_{ji}x_{j}$

The coefficients $ w_{ji} $ are the attention weights; they are normalized so that $ \sum_{j} w_{ji} = 1 $.
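Equation (1) can be checked on a tiny example. The embeddings and weights below are hand-picked for illustration (they are not learned), with `W[j, i]` playing the role of $w_{ji}$:

```python
import torch

# Three toy token embeddings x_1..x_3 (assumed values)
X = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])

# W[j, i] = w_ji; every column sums to 1, so each new embedding
# is a weighted average of the original embeddings
W = torch.tensor([[0.5, 0.2, 0.1],
                  [0.3, 0.6, 0.1],
                  [0.2, 0.2, 0.8]])

# x_i' = sum_j w_ji * x_j  is row i of W^T X
X_new = W.T @ X
print(X_new)  # rows: [0.7, 0.5], [0.4, 0.8], [0.9, 0.9]
```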


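There are several ways to implement a self-attention layer. The original implementation, introduced in the paper "Attention Is All You Need", is called "Scaled Dot-Product Attention":

$ (2) \; Attention(Q,K,V) =  softmax\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V$

where $ Q $, $ K $, and $ V $ are the query, key, and value matrices and $ d_{k} $ is the dimension of the keys.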

In [11]:
import torch
from torch import nn

Putting these steps together, we will have the following function:
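A minimal sketch of equation (2) as a function could look like this. The name `sdpa_sketch` and the toy shapes are assumptions made for illustration:

```python
import math
import torch

def sdpa_sketch(Q, K, V):
    """Hypothetical helper that follows equation (2) step by step."""
    d_k = K.size(-1)
    scores = Q @ K.transpose(-2, -1)          # similarity of each query with each key
    scaled = scores / math.sqrt(d_k)          # scale by sqrt(d_k)
    weights = torch.softmax(scaled, dim=-1)   # each row sums to 1
    return weights @ V                        # weighted sum of the values

torch.manual_seed(0)
Q = torch.randn(3, 4)  # 3 queries of dimension d_k = 4
K = torch.randn(5, 4)  # 5 keys of dimension d_k = 4
V = torch.randn(5, 2)  # 5 values of dimension d_v = 2
out = sdpa_sketch(Q, K, V)
print(out.shape)  # torch.Size([3, 2])
```

There is one output row per query, and its size is set by the values, not by the keys.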

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. 

$(3) \; MultiHead(Q,K,V) =  Concat(head_{1}, ..., head_{h})W^{O} $ \
\
where $ head_{i} = Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})$ and \
$ W_{i}^{Q} \in \mathbb{R}^{d_{model}\times d_{k}}, W_{i}^{K} \in \mathbb{R}^{d_{model} \times d_{k}}, W_{i}^{V} \in \mathbb{R}^{d_{model} \times d_{v}} $ and $ W^{O} \in \mathbb{R}^{hd_{v} \times d_{model}} $ are weight matrices.

The three matrices $ W_{i}^{Q} $, $ W_{i}^{K} $, and $ W_{i}^{V} $ project the embedded input tokens $ x_{i} $ into query, key, and value vectors, which are the crucial components of the attention mechanism. All of these matrices are trainable: as the model is exposed to more data during training, it adjusts their weights.
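The projection step can be sketched with plain matrix products. The sizes and the random matrices are assumptions for illustration; in a real model the matrices would be trainable parameters:

```python
import torch

torch.manual_seed(0)
d_model, d_k, d_v = 5, 3, 4      # assumed toy sizes
x = torch.rand(6, d_model)       # 6 embedded tokens

# Randomly initialized stand-ins for W^Q, W^K and W^V of a single head
W_Q = torch.rand(d_model, d_k)
W_K = torch.rand(d_model, d_k)
W_V = torch.rand(d_model, d_v)

queries = x @ W_Q   # (6, d_k)
keys = x @ W_K      # (6, d_k)
values = x @ W_V    # (6, d_v)
print(queries.shape, keys.shape, values.shape)
```

Queries and keys must share the dimension $d_k$ because they are compared with a dot product, while the values can have their own dimension $d_v$.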

# What are Query, Key, and Value?

In attention mechanisms, we use terms like "key," "query," and "value" which come from information retrieval and databases. They help us store, search, and get information efficiently.

Think of a "query" like a search term you put into a database. It's what the model is currently focusing on or trying to understand, like a word in a sentence. The query helps the model figure out how much attention to give to other parts of the input.

A "key" is like an index in a database used for searching. Each item in the input sequence, such as each word in a sentence, has a key. These keys are matched with the query to find relevant information.

The "value" in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values. 
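The database analogy can be pushed one step further with a small sketch of a "soft" lookup. The keys and values below are made up so that the query matches exactly one key:

```python
import torch

# Three stored items: well-separated keys and their associated values (assumed)
keys = 10.0 * torch.eye(3)
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0]])

query = keys[1]  # a query that matches the second key

scores = keys @ query                   # high score only for the matching key
weights = torch.softmax(scores, dim=0)  # almost all weight on the second item
retrieved = weights @ values            # approximately values[1]
print(weights, retrieved)
```

A database returns exactly one value for a key. Attention returns a blend of all the values, weighted by how well each key matches the query.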

With this introduction, let's code our very first attention mechanism. Imagine we have an embedding model that generates embeddings in a 5-dimensional embedding space. Assume that our embedding model has generated the following embedding vectors for our input sentence "Write your very first Attention mechanism".

Please note that the embedding values in this example are completely random and don't carry any meaning.

In [14]:
import torch
# Write your very first Attention mechanism
inputs = torch.tensor(
  [[0.172, 0.295, 0.618, 0.459, 0.818], # Write 
   [0.265, 0.563, 0.718, 0.323, 0.126], # your  
   [0.071, 0.235, 0.594, 0.954, 0.418], # very   
   [0.206, 0.333, 0.044, 0.862, 0.152], # first    
   [0.300, 0.505, 0.727, 0.495, 0.898], # Attention     
   [0.095, 0.809, 0.596, 0.110, 0.447]] # mechanism   
)

In [16]:
query = inputs[4]
print(query)

tensor([0.3000, 0.5050, 0.7270, 0.4950, 0.8980])
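As an intermediate step, we can score this single query against every token. The sketch below repeats the `inputs` tensor so that it runs on its own; `query_scores` is a name introduced here for illustration:

```python
import torch

inputs = torch.tensor(
  [[0.172, 0.295, 0.618, 0.459, 0.818], # Write
   [0.265, 0.563, 0.718, 0.323, 0.126], # your
   [0.071, 0.235, 0.594, 0.954, 0.418], # very
   [0.206, 0.333, 0.044, 0.862, 0.152], # first
   [0.300, 0.505, 0.727, 0.495, 0.898], # Attention
   [0.095, 0.809, 0.596, 0.110, 0.447]] # mechanism
)
query = inputs[4]  # embedding of "Attention"

# One dot product per token: how similar is each token to the query?
query_scores = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    query_scores[i] = torch.dot(x_i, query)
print(query_scores)
```

With these particular embeddings, the query scores highest against itself.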


First, we compute the attention score matrix, which holds the dot product of each embedding vector with every embedding vector in the sequence (including itself).

That is, $ AttentionScores \in \mathbb{R}^{d_{t}\times d_{t}} $ \
where $ d_{t} $ is the number of input tokens (i.e., words); here $ d_{t} = 6 $.

In [18]:
token_nums = inputs.shape[0]
attn_scores = torch.zeros(token_nums, token_nums)

print(attn_scores)

tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])


In [21]:
attn_scores = torch.matmul(inputs, inputs.T) 
print(attn_scores)

tensor([[1.3783, 0.9067, 1.2284, 0.6809, 1.6116, 1.0395],
        [0.9067, 1.0229, 0.9384, 0.5712, 1.1588, 1.0004],
        [1.2284, 0.9384, 1.4979, 1.0049, 1.4194, 0.8427],
        [0.6809, 0.5712, 1.0049, 0.9214, 0.8251, 0.4780],
        [1.6116, 1.1588, 1.4194, 0.8251, 1.9250, 1.3262],
        [1.0395, 1.0004, 0.8427, 0.4780, 1.3262, 1.2306]])
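The matrix product is a compact way of writing two nested loops. A small sketch on random toy embeddings (an assumption, so that the block runs on its own) confirms the equivalence and shows two properties of the score matrix:

```python
import torch

torch.manual_seed(0)
x = torch.rand(6, 5)  # assumed toy embeddings: 6 tokens, 5 dimensions
n = x.shape[0]

# Fill the score matrix one dot product at a time
scores_loop = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        scores_loop[i, j] = torch.dot(x[i], x[j])

scores_matmul = x @ x.T

same = torch.allclose(scores_loop, scores_matmul, atol=1e-6)
# Queries and keys are the same vectors here, so the matrix is symmetric
symmetric = torch.allclose(scores_matmul, scores_matmul.T, atol=1e-6)
# Each diagonal entry is the squared norm of the corresponding embedding
diag_ok = torch.allclose(scores_matmul.diagonal(), (x ** 2).sum(dim=1), atol=1e-6)
print(same, symmetric, diag_ok)
```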


In [22]:
# Normalize along the last dimension so that each row (one query token) sums to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2017, 0.1259, 0.1736, 0.1004, 0.2547, 0.1437],
        [0.1599, 0.1796, 0.1650, 0.1143, 0.2057, 0.1756],
        [0.1740, 0.1302, 0.2278, 0.1391, 0.2106, 0.1183],
        [0.1533, 0.1374, 0.2120, 0.1950, 0.1771, 0.1252],
        [0.1986, 0.1263, 0.1638, 0.0904, 0.2716, 0.1493],
        [0.1696, 0.1631, 0.1393, 0.0967, 0.2259, 0.2053]])


In [24]:
# Each context vector is a weighted average of all input embeddings
context_vectors = torch.matmul(attn_weights, inputs)
print(context_vectors.shape) # 6x5

torch.Size([6, 5])


As you can see, the context vectors have the same size as our inputs. In other words, we simply modified the embeddings so that each one also reflects attention to the other tokens.
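To see what happens for a single token, the sketch below builds one context vector by hand: scores, a hand-written softmax, then a weighted sum. It uses random toy embeddings (an assumption) and compares the result with the matrix version:

```python
import torch

torch.manual_seed(0)
x = torch.rand(6, 5)  # assumed toy embeddings: 6 tokens, 5 dimensions
i = 4                 # token whose context vector we build

# 1) Scores of token i against every token
scores = x @ x[i]                                # shape: (6,)

# 2) Softmax by hand (subtracting the max only improves numerical stability)
exp_scores = torch.exp(scores - scores.max())
weights = exp_scores / exp_scores.sum()          # sums to 1

# 3) Weighted sum of all embeddings
context_i = torch.zeros(x.shape[1])
for j in range(x.shape[0]):
    context_i += weights[j] * x[j]

# Same result as row i of the matrix formulation
full = torch.softmax(x @ x.T, dim=-1) @ x
print(torch.allclose(context_i, full[i], atol=1e-6))
```

Because the weights are positive and sum to 1, every component of the context vector lies between the smallest and largest value of that component across the inputs.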

# Self Attention

The paper "Attention Is All You Need" introduces <em>Scaled Dot-Product Attention</em>, which divides the dot products by $ \sqrt{d_{k}} $ before applying the softmax. The scaling matters when the embedding dimension grows, which is typically greater than a thousand for GPT-like LLMs: large dot products can result in very small gradients during backpropagation because of the softmax function applied to them. As the dot products increase, the softmax function behaves more like a step function, and its gradients approach zero. These small gradients can drastically slow down learning. The authors of the paper suspect that, for large values of $ d_{k} $, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients; scaling by $ \frac{1}{\sqrt{d_{k}}} $ counteracts this effect.

<center><figure><img src="imgs/scaled-dot-product.png" alt="drawing" width="300"/><figcaption>Fig. 1: Scaled Dot-Product Attention.</figcaption></figure></center>    
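The effect of the scaling factor can be observed numerically. The sketch below assumes query and key components drawn independently from a standard normal distribution, which is an idealization; trained models will deviate from it:

```python
import torch

torch.manual_seed(0)
d_k = 1024
q = torch.randn(4000, d_k)  # 4000 random queries (assumed unit-variance components)
k = torch.randn(4000, d_k)  # 4000 random keys

dots = (q * k).sum(dim=-1)  # one dot product per (query, key) pair
raw_std = dots.std()                    # grows like sqrt(d_k), about 32 here
scaled_std = (dots / d_k ** 0.5).std()  # back to roughly 1
print(raw_std, scaled_std)

# Large inputs make the softmax behave like a step function
logits = torch.tensor([1.0, 2.0, 3.0])
soft = torch.softmax(logits, dim=0)        # approx. [0.09, 0.24, 0.67]
sharp = torch.softmax(logits * 30, dim=0)  # nearly one-hot
print(soft, sharp)
```

When nearly all of the probability mass sits on one entry, small changes in the scores barely change the output, which is why the gradients vanish.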

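In [6]:
from math import sqrt

def scaled_dot_product_attention(Q, K, V):
    # Q, K: (num_tokens, d_k); V: (num_tokens, d_v)
    dim_k = K.size(-1)
    # Similarity of every query with every key
    attn_scores = torch.matmul(Q, K.transpose(-2, -1))
    # Scale by sqrt(d_k), then normalize each row into attention weights
    attn_weights = torch.softmax(attn_scores / sqrt(dim_k), dim=-1)
    # Weighted sum of the values
    return torch.matmul(attn_weights, V)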
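In practice, attention is usually applied to a batch of sequences at once. A minimal sketch, assuming a leading batch dimension and a hypothetical helper name `batched_attention`, only needs to transpose the last two dimensions of the keys:

```python
import math
import torch

def batched_attention(Q, K, V):
    """Hypothetical helper: attention over tensors with optional leading batch dims."""
    d_k = K.size(-1)
    # transpose(-2, -1) swaps only the token and feature dimensions
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q = torch.randn(2, 6, 4)  # (batch, tokens, d_k)
K = torch.randn(2, 6, 4)
V = torch.randn(2, 6, 3)  # (batch, tokens, d_v)

out = batched_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 6, 3])

# Each sequence in the batch is processed independently of the others
single = batched_attention(Q[0], K[0], V[0])
print(torch.allclose(out[0], single, atol=1e-6))
```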

In [110]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # Trainable projections for queries, keys, and values
        self.W_q = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, head_dim, bias=False)

    def forward(self, x):
        keys = self.W_k(x)
        queries = self.W_q(x)
        values = self.W_v(x)

        attention_outputs = scaled_dot_product_attention(queries, keys, values)
        return attention_outputs

In [111]:
embed_dim = inputs.shape[-1]  # 5; must match the embedding size of `inputs`
head_dim = 2

self_attention = SelfAttention(embed_dim, head_dim)
context = self_attention(inputs)
print(context.shape)

torch.Size([6, 2])
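It can help to see what an `nn.Linear` layer without bias actually computes. The sketch below, with assumed toy sizes, shows that it is a plain matrix product with the transposed weight matrix:

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, head_dim = 5, 2
W_q = nn.Linear(embed_dim, head_dim, bias=False)
x = torch.rand(6, embed_dim)  # assumed toy embeddings

# nn.Linear stores its weight as (out_features, in_features)
print(W_q.weight.shape)  # torch.Size([2, 5])

# Without a bias, the layer is just x @ W^T
manual = x @ W_q.weight.T
print(torch.allclose(W_q(x), manual, atol=1e-6))
```

Since the weights are randomly initialized, the projected values differ from run to run unless a seed is set.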

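In [None]:
class MultiHeadAttention_v2(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        # One independent self-attention module per head
        self.heads = nn.ModuleList(
            [SelfAttention(embed_dim, head_dim) for _ in range(num_heads)]
        )
        # W^O maps the concatenated heads (num_heads * head_dim) back to embed_dim
        self.output_linear = nn.Linear(num_heads * head_dim, embed_dim)

    def forward(self, hidden_state):
        # Run every head on the same input and concatenate along the feature axis
        x = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        x = self.output_linear(x)
        return x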
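The shapes in equation (3) can be traced with plain tensors. This is a sketch under assumed toy sizes, with random matrices standing in for the trained weights; in self-attention, the queries, keys, and values are all computed from the same input `x`:

```python
import math
import torch

torch.manual_seed(0)
n_tokens, d_model, h = 6, 8, 2  # assumed toy sizes
d_k = d_v = d_model // h        # 4 dimensions per head

x = torch.rand(n_tokens, d_model)

# Random stand-ins for the per-head projections and the output projection
W_Q = torch.rand(h, d_model, d_k)
W_K = torch.rand(h, d_model, d_k)
W_V = torch.rand(h, d_model, d_v)
W_O = torch.rand(h * d_v, d_model)

heads = []
for i in range(h):
    q, k, v = x @ W_Q[i], x @ W_K[i], x @ W_V[i]
    weights = torch.softmax(q @ k.T / math.sqrt(d_k), dim=-1)
    heads.append(weights @ v)      # (n_tokens, d_v)

concat = torch.cat(heads, dim=-1)  # (n_tokens, h * d_v)
output = concat @ W_O              # (n_tokens, d_model)
print(concat.shape, output.shape)
```

The output has the same shape as the input, which is what allows attention layers to be stacked.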