# Understanding Self-Attention: The Core Mechanism Behind Transformers

In the world of deep learning, particularly in Natural Language Processing (NLP) and computer vision, **self-attention** has become one of the most important mechanisms that drive state-of-the-art models. Self-attention is at the heart of the **Transformer architecture**, which has revolutionized the way machines understand and generate text. In this blog post, we will dive deep into the concept of self-attention, its workings, and why it's such a pivotal innovation in modern AI.


### 1. What is Self-Attention?
At its core, **self-attention** is a mechanism that allows a model to weigh the importance of different words in a sentence (or elements in a sequence) with respect to each other. Unlike traditional models that process words sequentially (e.g., RNNs and LSTMs), self-attention enables each word to attend to every other word in the sequence simultaneously, learning complex relationships between them.

In simpler terms, self-attention helps a model decide which parts of the input sequence are most relevant when producing the output for a particular word. This makes it highly effective for handling long-range dependencies in data, such as long sentences in NLP or distant pixels in images.


### 2. Why Do We Need Self-Attention?
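Before self-attention, sequence data was typically processed by Recurrent Neural Networks (RNNs) or Long Short-Term Memory networks (LSTMs). These models have a major limitation: **sequential processing**. They consume the input one step at a time, so the computation for step $ t $ cannot start until step $ t-1 $ has finished, which makes training hard to parallelize. Information between distant positions must also pass through many intermediate steps, and gradients flowing back along that path tend to vanish, which makes long-range dependencies hard to learn (LSTMs mitigate this problem but do not eliminate it).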

**Self-attention** overcomes these limitations by relating all positions of the sequence to each other in a single step. Because the computations for different positions do not depend on each other, they can run in parallel, which leads to significant speed-ups during training. And because any two words are connected directly, regardless of their distance in the sequence, the model can capture long-range dependencies without routing information through many intermediate states.


### 3. How Does Self-Attention Work?
Self-attention works by computing a set of **attention scores** for each element in the sequence, indicating how much focus each word should give to every other word. This process involves three key components:


#### Query, Key, and Value Vectors

Each word in the input sequence is transformed into three vectors:

+ **Query (Q):** Represents the word you are currently focusing on.
+ **Key (K):** Represents the potential words that could be attended to.
+ **Value (V):** Represents the actual information that will be passed on once the attention scores are calculated.

These vectors are obtained by multiplying the input embeddings (representations of words) with learned weight matrices. The **Query** determines what information to look for, the **Key** tells the model where the relevant information is located, and the **Value** holds the information that gets passed through after attention scores are applied.


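+ $ \mathbf{Q} = \mathbf{X} \mathbf{W_Q} $
+ $ \mathbf{K} = \mathbf{X} \mathbf{W_K} $
+ $ \mathbf{V} = \mathbf{X} \mathbf{W_V} $

Here $ \mathbf{X} $ is the matrix of input embeddings (one row per word), and the weight matrices $ \mathbf{W_Q} $, $ \mathbf{W_K} $, and $ \mathbf{W_V} $ are learned during training.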
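To make the projections concrete, here is a minimal pure-Python sketch that multiplies a toy embedding matrix by three weight matrices. The embeddings and the matrices `W_Q`, `W_K`, `W_V` below are made-up numbers chosen for readability, not learned weights, and the helper `matmul` is a hypothetical name introduced only for this example.

```python
def matmul(a, b):
    """Multiply matrix a (n x m) by matrix b (m x p), both given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

# X: 3 "words", each with a 4-dimensional embedding (one row per word)
X = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 2.0],
    [1.0, 1.0, 1.0, 1.0],
]

# Hypothetical projection matrices mapping 4-dim embeddings to d_k = 2
W_Q = [[1, 0], [0, 1], [1, 0], [0, 1]]
W_K = [[0, 1], [1, 0], [0, 1], [1, 0]]
W_V = [[1, 1], [0, 1], [1, 0], [0, 0]]

Q = matmul(X, W_Q)  # one 2-dim query per word
K = matmul(X, W_K)  # one 2-dim key per word
V = matmul(X, W_V)  # one 2-dim value per word
print(Q)
print(K)
print(V)
```

Every word gets its own query, key, and value row; in a real model the weight matrices are trained parameters (in the PyTorch code later in this post they are `nn.Linear` layers).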



#### Attention Scores

The next step is to compute the **attention scores**, which quantify the relevance between the current word (query) and all other words (keys). This is done by calculating the dot product between the query vector and key vector for every pair of words in the sequence.

The attention score $ \alpha_{ij} $ between the query $ Q_i $ and key $ K_j $ is given by:

$$ \alpha_{ij} = \frac{Q_i \cdot K_j}{\sqrt{d_k}} $$

Where:

- $ Q_i $ and $ K_j $ are the query and key vectors for words $ i $ and $ j $, respectively.
- $ d_k $ is the dimension of the key vectors (used for scaling to avoid large values).
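The score formula translates directly into code. The sketch below computes the full matrix of $ \alpha_{ij} $ values for toy query and key vectors; the vectors are invented for illustration and `attention_scores` is a hypothetical helper name.

```python
import math

def attention_scores(Q, K):
    """Return the matrix alpha[i][j] = (Q_i . K_j) / sqrt(d_k)."""
    d_k = len(K[0])  # dimension of each key vector
    scale = math.sqrt(d_k)
    return [[sum(q * k for q, k in zip(Q[i], K[j])) / scale for j in range(len(K))]
            for i in range(len(Q))]

# Toy query/key vectors for 3 words with d_k = 4 (made-up numbers)
Q = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [1.0, 1.0, 1.0, 1.0]]
K = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [0.5, 0.5, 0.5, 0.5]]

alpha = attention_scores(Q, K)
for row in alpha:
    print(row)
```

Row $ i $ holds the scores of word $ i $ against every word $ j $; here $ d_k = 4 $, so each raw dot product is divided by 2.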


#### Softmax and Normalization

Once the attention scores are computed, they are passed through a softmax function. For each query $ i $, the softmax is taken over all keys, so the resulting weights are non-negative and sum to 1. Each row can therefore be read as a probability distribution describing how word $ i $ spreads its attention across the sequence:
$$ \text{Attention Weight}_{ij} = \text{Softmax}(\alpha_{ij}) = \frac{e^{\alpha_{ij}}}{\sum_k e^{\alpha_{ik}}} $$


Now, the attention weights determine how much each word should focus on others in the sequence. Higher attention weights mean that the word is more important for computing the current word's output.
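A small sketch of the softmax step for one row of scores follows. It subtracts the row maximum before exponentiating; this is a common implementation trick (not part of the formula above) that leaves the result unchanged mathematically but avoids overflow when scores are large. The scores are made-up numbers.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of attention scores."""
    m = max(row)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(a - m) for a in row]
    total = sum(exps)
    return [e / total for e in exps]

# Scores alpha_ij for one query i against 3 keys (made-up numbers)
scores = [2.0, 1.0, 0.1]
weights = softmax(scores)
print([round(w, 3) for w in weights])
print(sum(weights))  # equal to 1 up to floating-point rounding

# Without the max subtraction, math.exp(1000.0) would raise OverflowError
print(softmax([1000.0, 1000.0]))
```

The highest score gets the largest weight, but the other words still receive some attention.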


#### Weighted Sum of Values

Finally, the attention weights are used to compute a weighted sum of the value vectors for each word. This produces an output representation that is a mixture of information from all the words in the sequence, weighted by their attention scores.

$$
\text{Output}_i = \sum_j \text{Attention Weight}_{ij} \times V_j
$$


This output is a new representation of word $i$ enriched with context from other words in the sequence.
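Putting the three steps together (scaled scores, softmax, weighted sum of values), a single attention head can be sketched in plain Python as below. This is an illustrative sketch for one head with no masking or batching; `self_attention` is a hypothetical name, and the inputs are toy values.

```python
import math

def self_attention(Q, K, V):
    """Single-head scaled dot-product attention on lists of vectors.

    Returns one output vector per query: Output_i = sum_j weight_ij * V_j.
    """
    d_k = len(K[0])
    outputs = []
    for q in Q:
        # Scaled dot-product scores of this query against every key
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        # Softmax over the keys (max subtracted for numerical stability)
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors, one output dimension at a time
        out = [sum(w * v[d] for w, v in zip(weights, V)) for d in range(len(V[0]))]
        outputs.append(out)
    return outputs

Q = [[0.0, 0.0], [0.0, 0.0]]  # all-zero queries give every key the same score
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(self_attention(Q, K, V))  # uniform weights, so each output is the mean of V: [[2.0, 3.0], [2.0, 3.0]]
```

With all-zero queries every key scores the same, so each output is simply the average of the values; with a query that matches one key strongly, the output moves toward that key's value vector.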

#### Why Scale the Dot Product by $ \sqrt{d_k} $?

When the **dimension of the query and key vectors** is large, their dot product tends to be **large in magnitude**. To see why, assume the components of $ Q_i $ and $ K_j $ are independent random variables with mean 0 and variance 1. Then $ Q_i \cdot K_j = \sum_{m=1}^{d_k} Q_{im} K_{jm} $ is a sum of $ d_k $ independent terms, each with mean 0 and variance 1, so the dot product has variance $ d_k $ and a standard deviation of $ \sqrt{d_k} $. The typical magnitude of the scores therefore grows with the dimensionality.

This matters because of how the **softmax function** behaves. When the scores fed into it differ by large amounts, softmax saturates: the largest score receives a weight close to 1 and all others receive weights close to 0. The model then attends to essentially a single word, which is rarely what we want when modeling complex relationships in a sequence, and because a saturated softmax has extremely small gradients, learning also slows down.

Dividing the dot product by $ \sqrt{d_k} $ counteracts this growth: under the same assumptions, the scaled score $ \frac{Q_i \cdot K_j}{\sqrt{d_k}} $ has variance 1 regardless of $ d_k $. The scores stay in a range where softmax produces a smoother distribution and useful gradients, so the model neither collapses onto a single word nor ignores important words because of disproportionately large values.
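The variance argument can be checked empirically. The sketch below draws random query and key vectors with independent standard-normal components (an idealized assumption; learned vectors need not behave this way) and measures the spread of the raw and scaled dot products. `dot_product_std` is a hypothetical helper name.

```python
import math
import random

def dot_product_std(d_k, n_samples=2000, seed=0):
    """Empirical standard deviation of q . k for random q, k with unit-variance entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / n_samples
    return math.sqrt(sum((x - mean) ** 2 for x in dots) / n_samples)

for d_k in (4, 64, 256):
    raw = dot_product_std(d_k)
    # Raw spread should be close to sqrt(d_k); scaled spread should be close to 1
    print(f"d_k={d_k:4d}  std(q.k)={raw:6.2f}  std(q.k)/sqrt(d_k)={raw / math.sqrt(d_k):.2f}")
```

The raw standard deviation should come out near 2, 8, and 16, while the scaled column should stay near 1 for every $ d_k $.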

### 4. Multi-Head Attention
A major enhancement of the self-attention mechanism is **multi-head attention**. Instead of computing a single set of attention weights over the sequence, multi-head attention runs several attention computations ("heads") in parallel, each with its own learned projections of the query, key, and value vectors.

Each attention head captures a different aspect of the relationships between words. By concatenating the outputs from all attention heads and passing them through a final linear layer, the model can capture richer and more diverse relationships in the data.

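Because each head has its own projections, the heads can learn to attend to different parts of the sequence from different perspectives. In the standard formulation, each of the $ h $ heads works in a subspace of dimension $ d_{model}/h $, so the total cost is similar to that of single-head attention over the full model dimension.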
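The split-attend-concatenate pattern can be sketched in plain Python. This sketch assumes `Q`, `K`, and `V` have already been projected to `d_model` columns and simply slices those columns into heads, mirroring the reshape in the PyTorch code later in this post; the final output linear layer (`fc_out` there) is omitted. `attend` and `multi_head_attention` are hypothetical helper names.

```python
import math

def attend(Q, K, V):
    """Single-head scaled dot-product attention on lists of vectors."""
    d_k = len(K[0])
    outputs = []
    for q in Q:
        # Scaled dot-product score of this query against every key
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        # Numerically stable softmax over the keys
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        # Weighted sum of the value vectors
        outputs.append([sum(e / total * v[d] for e, v in zip(exps, V)) for d in range(len(V[0]))])
    return outputs

def multi_head_attention(Q, K, V, heads):
    """Split Q, K, V column-wise into heads, attend per head, then concatenate."""
    d_model = len(Q[0])
    assert d_model % heads == 0, "d_model must be divisible by the number of heads"
    head_dim = d_model // heads
    per_head = []
    for h in range(heads):
        cols = slice(h * head_dim, (h + 1) * head_dim)  # this head's slice of features
        per_head.append(attend([r[cols] for r in Q], [r[cols] for r in K], [r[cols] for r in V]))
    # Concatenate the head outputs for each position along the feature dimension
    return [[x for h in range(heads) for x in per_head[h][i]] for i in range(len(Q))]

# Two words, d_model = 4, two heads of size 2 (made-up numbers)
Q = [[0.0, 0.0, 10.0, 0.0], [0.0, 0.0, 0.0, 10.0]]
K = [[0.0, 0.0, 10.0, 0.0], [0.0, 0.0, 0.0, 10.0]]
V = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
out = multi_head_attention(Q, K, V, heads=2)
print([[round(x, 3) for x in row] for row in out])
```

In this toy input the two heads behave differently: the first head sees all-zero queries and averages the values, while the second head matches each word almost entirely to itself.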

### 5. Benefits of Self-Attention
Self-attention offers several key benefits:

+ **Parallelization:** Unlike RNNs, which process sequences step by step, self-attention processes all positions of a sequence in parallel, which substantially speeds up training and the encoding of full input sequences.
+ **Long-Range Dependencies:** Self-attention can capture long-range dependencies in data, unlike RNNs or LSTMs, which struggle with information that is far apart in the sequence.
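+ **Short Paths Between Positions:** Every word is connected to every other word through a single attention step, so the number of operations needed to relate two positions does not grow with their distance. The trade-off is cost: computing all pairwise scores takes time and memory that grow quadratically with sequence length ($ O(n^2) $ for a sequence of $ n $ elements), which becomes a bottleneck for very long sequences.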
+ **Flexibility:** Self-attention has been applied successfully in NLP, computer vision, speech, and other fields where relationships between elements need to be modeled, making it highly versatile.
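The quadratic cost is easy to quantify. The sketch below counts attention-score entries for a few sequence lengths, assuming 8 heads (the value used in the test cell later in this post), float32 scores, and a single sequence in a single layer. Optimized attention kernels can avoid storing the whole matrix at once, so these figures describe the straightforward implementation only. `attention_score_count` is a hypothetical helper name.

```python
def attention_score_count(seq_len, heads=1):
    """Number of entries in the attention-score matrices: one per (query, key) pair per head."""
    return heads * seq_len * seq_len

for n in (512, 1024, 4096):
    entries = attention_score_count(n, heads=8)
    # Assuming 4-byte float32 scores for one sequence in one layer
    print(f"seq_len={n:5d}: {entries:,} scores, ~{entries * 4 / 2**20:.0f} MiB in float32")
```

Doubling the sequence length quadruples the number of scores, which is why long-context models often rely on specialized attention variants or kernels.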

### 6. Applications of Self-Attention
Self-attention is the backbone of several groundbreaking models that have achieved state-of-the-art performance across a variety of tasks:

+ **Machine Translation:** The Transformer, with self-attention, is used for high-quality machine translation, such as in Google Translate.
+ **Text Generation:** Models like GPT use self-attention to generate coherent, contextually relevant text.
+ **Text Classification:** BERT-based models have been fine-tuned for tasks like sentiment analysis, question answering, and document classification.
+ **Vision:** Vision Transformers (ViT) apply self-attention to image patches, enabling the model to learn global dependencies between distant regions of an image.
+ **Speech Recognition:** Self-attention models are also used in speech-to-text tasks, improving transcription accuracy.
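To see how an image becomes a sequence for a Vision Transformer, the arithmetic below assumes a 224×224 RGB image split into non-overlapping 16×16 patches (a common ViT configuration). The helper names are hypothetical, and real models may add extra tokens, such as a classification token, to the patch sequence.

```python
def num_patches(image_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    assert image_size % patch_size == 0, "image size must be a multiple of the patch size"
    per_side = image_size // patch_size
    return per_side * per_side

def patch_vector_length(patch_size, channels=3):
    """Length of one flattened patch (before the linear embedding layer)."""
    return patch_size * patch_size * channels

print(num_patches(224, 16))     # patches, i.e. tokens in the sequence
print(patch_vector_length(16))  # raw values per flattened RGB patch
print(224 * 224)                # tokens if every pixel were its own token instead
```

Attending over 196 patch tokens is far cheaper than attending over 50,176 individual pixels, which is one reason ViT works on patches.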

### 7. Implementing Self-Attention in PyTorch
```python
import torch
import torch.nn.functional as F
import torch.nn as nn
```

```python
class SelfAttention(nn.Module):
    """Multi-head self-attention: project x to Q, K, V, attend per head, then recombine."""

    def __init__(self, embed_size, d_model, heads):
        super().__init__()
        self.embed_size = embed_size  # Dimensionality of input embeddings
        self.d_model = d_model        # Dimensionality of the model (hidden layers)
        self.heads = heads
        self.head_dim = d_model // heads  # Dimensionality of each head

        # Ensure d_model is divisible by the number of heads
        assert self.head_dim * heads == d_model, "d_model must be divisible by number of heads"

        # Linear layers for queries, keys, and values (projecting from embed_size to d_model)
        self.values = nn.Linear(embed_size, d_model)   # W_V
        self.keys = nn.Linear(embed_size, d_model)     # W_K
        self.queries = nn.Linear(embed_size, d_model)  # W_Q

        # Output linear layer to combine the multi-head output
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, embed_size)
        # mask (optional): broadcastable to (batch_size, heads, seq_len, seq_len);
        # positions where mask == 0 are not attended to
        batch_size, seq_len, _ = x.shape

        # Step 1: Project the input embeddings into queries, keys, and values
        queries = self.queries(x)  # (batch_size, seq_len, d_model)
        keys = self.keys(x)        # (batch_size, seq_len, d_model)
        values = self.values(x)    # (batch_size, seq_len, d_model)

        # Step 2: Split d_model into `heads` chunks of size head_dim and move the
        # head axis forward: (batch_size, heads, seq_len, head_dim)
        queries = queries.reshape(batch_size, seq_len, self.heads, self.head_dim).permute(0, 2, 1, 3)
        keys = keys.reshape(batch_size, seq_len, self.heads, self.head_dim).permute(0, 2, 1, 3)
        values = values.reshape(batch_size, seq_len, self.heads, self.head_dim).permute(0, 2, 1, 3)

        # Step 3: Scaled dot-product scores Q K^T / sqrt(head_dim)
        energy = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # energy: (batch_size, heads, seq_len, seq_len)

        # Step 4: Apply the mask (if any); blocked positions get the most negative
        # value representable in this dtype, so their softmax weight is ~0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Step 5: Softmax over the key dimension gives the attention weights
        attention = torch.softmax(energy, dim=-1)  # (batch_size, heads, seq_len, seq_len)

        # Step 6: Weighted sum of the values
        out = torch.matmul(attention, values)  # (batch_size, heads, seq_len, head_dim)

        # Step 7: Concatenate the heads back to (batch_size, seq_len, d_model)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)

        # Step 8: Final linear layer mixes information across heads
        out = self.fc_out(out)

        return out
```



```python
# Test the SelfAttention module
if __name__ == "__main__":
    embed_size = 128  # Dimensionality of input embeddings (e.g., word embeddings)
    d_model = 256     # Dimensionality of the model (hidden layers)
    heads = 8         # Number of attention heads
    seq_length = 10   # Length of the sequence (e.g., sentence length)
    batch_size = 4    # Number of samples in the batch

    # Random input tensor representing a batch of word embeddings
    x = torch.rand((batch_size, seq_length, embed_size))  # (batch_size, seq_len, embed_size)

    # Instantiate the SelfAttention module
    self_attention = SelfAttention(embed_size, d_model, heads)

    # Perform forward pass
    out = self_attention(x)

    print("Output shape:", out.shape)  # Expected shape: (batch_size, seq_len, d_model)
```

```
Output shape: torch.Size([4, 10, 256])
```
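The `forward` method above accepts an optional `mask` in which positions equal to 0 are blocked. A common use is a causal (look-ahead) mask for decoder-style models such as GPT, where each word may only attend to itself and earlier words. In PyTorch such a mask could be built with `torch.tril(torch.ones(seq_len, seq_len))`, which broadcasts against the `(batch_size, heads, seq_len, seq_len)` score tensor. The pure-Python sketch below (hypothetical helper names, toy equal scores) illustrates the effect of that mask on the attention weights, using the same "fill with a huge negative number" approach as the class.

```python
import math

def causal_mask(seq_len):
    """Lower-triangular mask: position i may attend to positions 0..i only."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax_row(scores, mask_row):
    """Softmax over one row of scores, ignoring positions where mask_row is 0."""
    # Masked positions get a huge negative score, so exp() drives their weight to 0
    filled = [s if keep else -1e20 for s, keep in zip(scores, mask_row)]
    m = max(filled)
    exps = [math.exp(s - m) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(3)
print(mask)  # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
scores = [0.0, 0.0, 0.0]  # equal raw scores, so any difference comes from the mask
for i in range(3):
    print([round(w, 3) for w in masked_softmax_row(scores, mask[i])])
```

The first word can only attend to itself, the second splits its attention between the first two words, and the last word attends to all three equally.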
