### Summary

This notebook collects notes on the attention mechanism in LLMs, along with implementations of popular attention variants.

#### What is attention?

Attention is a mechanism that **summarizes semantic similarity**: it computes dot products between query-key pairs, normalizes them into weights, and returns a weighted sum of the value vectors.

#### Why do we need attention?

In tasks like natural language processing and sequence modeling, the input is often a *variable-length sequence*, such as text, audio, or video frames. Recurrent models suffer from *vanishing* or *exploding gradients* on long sequences; gated variants such as LSTM and GRU mitigate this problem but do not eliminate it. In addition, because a recurrent model passes information step by step through its hidden state, it struggles to establish direct, controllable dependencies between distant positions, so details from early in a long sequence tend to be lost.

Attention is motivated by the idea that, when processing the current time step (or current word), the model should *adaptively* "attend" to the more important parts of the input sequence and ignore less relevant ones. It does so by explicitly computing the relevance (semantic similarity) between positions and using these relevance scores as weights to extract contextual information.

In applications like machine translation, text summarization, and reading comprehension, the information needed from different positions of the input sequence varies at each time step. Attention lets the model dynamically assign "attention weights," so that the word being generated can emphasize the semantically relevant parts of the input. Compared with pure RNN/CNN models, attention captures long-range dependencies much better, because any two positions are connected directly rather than through many intermediate steps.

#### Computation steps for scaled dot-product attention

This is the version used in the original [Transformer paper](https://www.arxiv.org/abs/1706.03762).

1. Get `Q`, `K`, `V`
    - `Q` (query): typically the hidden representation of the input sequence at a given position, depending on the task and the module.
    - `K` (key): vectors generated from the *reference sequence* (the same sequence for self-attention, a different one for cross-attention).
    - `V` (value): one vector per key position, carrying the information passed back to the query.
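
In practice `Q`, `K`, and `V` are usually produced by multiplying hidden states with three projection matrices. The NumPy sketch below assumes illustrative sizes and uses random matrices as stand-ins for learned weights:

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model, d_k, d_v = 4, 8, 8, 6  # illustrative sizes (d_v may differ from d_k)

# Hidden states of one sequence: one row per position
X = rng.standard_normal((seq_len, d_model))

# Projection matrices (learned in a real model; random here for illustration)
W_q = rng.standard_normal((d_model, d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_v))

# Self-attention case: Q, K, V are all projected from the same sequence X
Q = X @ W_q  # (seq_len, d_k)
K = X @ W_k  # (seq_len, d_k)
V = X @ W_v  # (seq_len, d_v)

print(Q.shape, K.shape, V.shape)  # (4, 8) (4, 8) (4, 6)
```

For cross-attention, `K` and `V` would instead be projected from a different sequence than `Q`.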

2. Compute attention score

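The Transformer uses the scaled dot product between the query and each key $K_i$:

$$\text{score}(Q,K_i)=\frac{Q\cdot K_i}{\sqrt{d_k}}$$

where $d_k$ is the dimension of the query and key vectors.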

The scaling factor $\frac{1}{\sqrt{d_k}}$ keeps the dot products from growing with $d_k$. Without it, large scores push the softmax into regions where it is nearly one-hot and its gradients are extremely small.
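
A quick numerical check of why the scaling helps, following the paper's assumption that query and key components are independent with mean 0 and variance 1 (the dimension and sample count below are arbitrary choices): the raw dot product then has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_samples = 512, 10_000

# Query/key components drawn i.i.d. from N(0, 1)
q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))

raw = np.sum(q * k, axis=1)   # unscaled dot products, one per sample
scaled = raw / np.sqrt(d_k)   # scaled as in the Transformer

print(f"variance of raw dot products:    {raw.var():.1f}")     # close to d_k
print(f"variance of scaled dot products: {scaled.var():.3f}")  # close to 1
```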

3. Compute attention weight

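Apply a `softmax` over the scores of all positions in the reference sequence (`K`). The resulting weights (the *extent of attention*) are non-negative and sum to 1:

$$\alpha_i = \text{softmax}_i\big(\text{score}(Q,K_i)\big) = \frac{\exp\big(\text{score}(Q,K_i)\big)}{\sum_j \exp\big(\text{score}(Q,K_j)\big)}$$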
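
When implementing this step, softmax is usually computed in a numerically stable way by subtracting the maximum score first. Softmax is shift-invariant, so the weights are unchanged, but `exp` no longer overflows. A minimal NumPy sketch:

```python
import numpy as np

def softmax(scores, axis=-1):
    """Numerically stable softmax along `axis`."""
    # Subtracting the max leaves the result unchanged but keeps exp() finite
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

scores = np.array([2.0, 1.0, 0.1])
alpha = softmax(scores)
print(alpha.round(3))  # [0.659 0.242 0.099]

# A naive np.exp(1000.0) overflows to inf; the shifted version does not
print(softmax(np.array([1000.0, 999.0])).round(3))  # [0.731 0.269]
```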

4. Sum up

$$\text{Attention}(Q,K,V) = \sum_i \alpha_i \cdot V_i$$
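
Putting steps 2 to 4 together for a single query, written as explicit loops so that each line mirrors the formulas above. This is a NumPy sketch; the random vectors are placeholders for real projected queries, keys, and values:

```python
import numpy as np

rng = np.random.default_rng(0)
n_keys, d_k, d_v = 5, 4, 3

q = rng.standard_normal(d_k)            # a single query vector
K = rng.standard_normal((n_keys, d_k))  # one key per reference position
V = rng.standard_normal((n_keys, d_v))  # one value per reference position

# Step 2: scaled score for each key K_i
scores = np.array([q @ K[i] / np.sqrt(d_k) for i in range(n_keys)])

# Step 3: softmax over the keys (max subtracted for numerical stability)
weights = np.exp(scores - scores.max())
weights /= weights.sum()

# Step 4: weighted sum of the value vectors, sum_i alpha_i * V_i
output = sum(weights[i] * V[i] for i in range(n_keys))

print(output.shape)  # (3,)
```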

This output vector is the **contextual representation** of the query's position. Stacking all queries as the rows of a matrix $Q$ (and all keys and values as the rows of $K$ and $V$) gives the paper's compact matrix form:

$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
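
The matrix form can be sketched in NumPy as follows (a minimal, unbatched version; the function name and random inputs are illustrative choices). The loop at the end checks that each output row equals the result of running that query on its own, i.e. the matrix form is the per-query computation done for all queries at once:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Matrix form: softmax(Q K^T / sqrt(d_k)) V.

    Q: (m, d_k) queries, K: (n, d_k) keys, V: (n, d_v) values.
    Returns the (m, d_v) output and the (m, n) attention weights.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # (m, n): all query-key scores
    scores = scores - scores.max(axis=-1, keepdims=True)  # stabilize exp()
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)        # softmax over keys, per row
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4))
K = rng.standard_normal((5, 4))
V = rng.standard_normal((5, 2))

out, weights = scaled_dot_product_attention(Q, K, V)

# Each row of the matrix form equals the single-query computation for that query
for i in range(Q.shape[0]):
    out_i, _ = scaled_dot_product_attention(Q[i:i + 1], K, V)
    assert np.allclose(out[i], out_i[0])

print(out.shape, weights.shape)  # (3, 2) (3, 5)
```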

#### Self-Attention vs. Cross-Attention

For self-attention:
- Scenario: modeling the dependencies between different positions **within the same sequence**.
- Example: in a Transformer encoder or decoder, the self-attention sublayer derives Q, K, and V from the same sequence.
- Advantage: allows the network to directly model interactions between **any two positions** in the sequence, without relying on step-by-step propagation like in recurrent structures.

For cross-attention:
- Scenario: in tasks like translation, the decoder needs to look up or align with the encoder's output sequence. The decoder's current hidden state is therefore used as `Q`, and the encoder's output is used as `K` and `V`.
- Reason: allows the decoder to **look across sequences** and find the most relevant positions in the source-sentence representation when generating the next word.
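
A small NumPy sketch of the difference, reduced to shapes. The learned projections $W_q, W_k, W_v$ are omitted for brevity, and the sequence lengths are arbitrary assumptions: self-attention yields a square weight matrix over one sequence, while cross-attention yields a target-by-source matrix.

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a numerically stable softmax."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
d = 8
src = rng.standard_normal((6, d))  # encoder output: 6 source positions
tgt = rng.standard_normal((3, d))  # decoder hidden states: 3 target positions

# Self-attention: Q, K, V all come from the same sequence
_, self_w = attention(tgt, tgt, tgt)
print(self_w.shape)   # (3, 3): each target position vs. each target position

# Cross-attention: Q from the decoder, K and V from the encoder output
out, cross_w = attention(tgt, src, src)
print(cross_w.shape)  # (3, 6): each target position vs. each source position
print(out.shape)      # (3, 8): one context vector per target position
```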

#### Mainstream attention methods
1. **Scaled Dot-Product Attention** 
    - After applying linear transformations to Q, K, V, compute attention using "dot product + softmax + weighted sum".
2. **Multi-Head Attention (MHA)**
    - Split Q, K, V into `h` subspaces (heads), compute attention separately in each, then concatenate and project again.
    - Advantage: enables the model to learn different types of attention patterns in different subspaces, enhancing model expressiveness (loosely analogous to using multiple kernels in a CNN).
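3. **FlashAttention**
    - An implementation optimization rather than a new formula: it computes attention block by block (tiles that fit in fast on-chip memory, combined with an online softmax), so the full $N \times N$ score matrix for a sequence of length $N$ is never written to GPU main memory. The result is the same as scaled dot-product attention up to floating-point rounding; the gains are in speed and memory use.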
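4. **Sparse Attention**
    - Restricts each query to a subset of key positions (e.g., local windows, selected global tokens), so efficient implementations can skip the blocked query-key pairs.
    - Equivalent to adding a mask $M$ to the scores, with $M_{ij}=0$ if query $i$ may attend to key $j$ and $M_{ij}=-\infty$ otherwise: $$\text{Attention}_\text{sparse}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$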
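
A minimal sketch of sparse attention with a local (sliding-window) pattern, assuming a window of one position on each side; the helper names are hypothetical. This dense NumPy version still computes every score and only masks them, so it illustrates the masked formula, not the memory or speed savings, which come from kernels that skip blocked pairs entirely.

```python
import numpy as np

def local_window_mask(seq_len, window):
    """Boolean (seq_len, seq_len) mask: True where |i - j| <= window."""
    idx = np.arange(seq_len)
    return np.abs(idx[:, None] - idx[None, :]) <= window

def masked_attention(Q, K, V, mask):
    """softmax(Q K^T / sqrt(d_k) + M) V, with M = 0 where mask is True, -inf elsewhere.

    Assumes every query row has at least one allowed key (otherwise that row is NaN).
    """
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = np.where(mask, scores, -np.inf)              # add M: blocked pairs -> -inf
    scores = scores - scores.max(axis=-1, keepdims=True)  # row max is finite (diagonal allowed)
    weights = np.exp(scores)                              # exp(-inf) = 0 for blocked keys
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d = 6, 4
X = rng.standard_normal((seq_len, d))

mask = local_window_mask(seq_len, window=1)
out, weights = masked_attention(X, X, X, mask)

print(mask.astype(int))      # banded pattern: each position sees itself and its neighbours
print(np.round(weights, 2))  # zeros outside the band |i - j| <= 1
```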


#### Weight sharing in Transformer

In some variants or specific implementations, to reduce the number of parameters, **weight sharing** is applied across **different layers of the Encoder or Decoder**.

This is feasible because:

- Within an encoder (or decoder) stack, all layers have the **same structure** (e.g., self-attention + feed-forward network in the encoder), so one set of weights can be reused at every depth.
- Experiments with cross-layer sharing (e.g., ALBERT) report a **large reduction in parameter count** at the cost of only a modest drop in performance.

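However, note that the original Transformer paper (Vaswani et al.) does **not** share weights across layers: each encoder and decoder layer has its own parameters. The only weight sharing it describes is between the embedding layers and the pre-softmax linear transformation. In practice, whether to share layers can be chosen based on project needs.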
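
To make the parameter saving concrete, here is a toy NumPy sketch. The layer built by `make_layer` (a single linear map + ReLU) is a hypothetical stand-in for a full Transformer block; the point is only that a shared stack reuses one parameter set at every depth, so it keeps the same depth and compute while storing $1/N$ of the parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_layers = 16, 6

def make_layer():
    """Hypothetical toy layer (weight matrix + bias) standing in for a Transformer block."""
    return {"W": rng.standard_normal((d_model, d_model)) * 0.1, "b": np.zeros(d_model)}

def apply_layer(layer, x):
    return np.maximum(0.0, x @ layer["W"] + layer["b"])  # linear + ReLU

def run(stack, x):
    for layer in stack:
        x = apply_layer(layer, x)
    return x

def count_params(stack):
    # Count each distinct parameter set once (shared layers are the same object)
    unique = {id(layer): layer for layer in stack}.values()
    return sum(p.size for layer in unique for p in layer.values())

separate = [make_layer() for _ in range(n_layers)]  # every layer has its own weights
shared_layer = make_layer()
shared = [shared_layer] * n_layers                  # same weights reused at every depth

x = rng.standard_normal((3, d_model))  # 3 positions
print(run(separate, x).shape, run(shared, x).shape)  # (3, 16) (3, 16)
print(count_params(separate))  # 6 * (16*16 + 16) = 1632
print(count_params(shared))    # 16*16 + 16 = 272
```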


### Implementations

```python
# Multi-Head Attention
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Projection matrices for Q, K, V (each later split into num_heads chunks of head_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection applied after the heads are concatenated
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        # query: (batch, seq_q, embed_dim); key/value: (batch, seq_k, embed_dim)
        # mask (optional): (batch, seq_q, seq_k) or (batch, 1, seq_k); 0/False = blocked
        batch_size = query.size(0)

        # Project inputs, then split into heads: (batch, num_heads, seq, head_dim)
        q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention per head: (batch, num_heads, seq_q, seq_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            mask = mask.unsqueeze(1)  # add a head dimension so the mask broadcasts over heads
            # Most negative finite value: blocked keys get ~0 weight after softmax, and a
            # fully masked row becomes uniform instead of NaN (which -inf would produce)
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)

        # Dropout on the attention weights (nn.Dropout is a no-op in eval() mode)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)  # (batch, num_heads, seq_q, head_dim)

        # Concatenate heads and project output: (batch, seq_q, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return self.out_proj(attn_output)
```
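
The `mask` argument of `MultiHeadAttention.forward` is unsqueezed at dimension 1, so it should be 3-D, `(batch, seq_q, seq_k)` or `(batch, 1, seq_k)`, with 0/`False` marking blocked positions. As a sketch (built in NumPy; the helper names and sequence lengths are hypothetical), a causal mask and a padding mask can be combined like this:

```python
import numpy as np

def causal_mask(seq_len):
    """(seq_len, seq_len) boolean mask: position i may attend only to j <= i."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def padding_mask(lengths, max_len):
    """(batch, max_len) boolean mask: True for real tokens, False for padding."""
    return np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]

lengths = [4, 2]  # hypothetical batch: true lengths of two padded sequences
seq_len = 4

# Combine into (batch, seq_q, seq_k): keep (i, j) only if j <= i and key j is not padding
mask = causal_mask(seq_len)[None, :, :] & padding_mask(lengths, seq_len)[:, None, :]

print(mask.shape)           # (2, 4, 4)
print(mask[1].astype(int))  # second sequence: only keys 0 and 1 are real tokens
```

Converting with `torch.from_numpy(mask)` gives a boolean tensor that can be passed as `mask`. Every row here keeps at least key 0, so no row is fully masked.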