# 🧠 Distilled | Attention Is All You Need
###### <span style="font-size:14pt">A Simplified Guide to Transformers and Attention Mechanisms</span> [paper link](https://arxiv.org/abs/1706.03762)

## 📌 Introduction

The Transformer architecture, introduced in the 2017 paper "Attention Is All You Need", has become the foundation of modern natural language processing. Unlike previous models that relied on recurrence (RNNs, LSTMs) or convolutions (CNNs), Transformers rely solely on **attention mechanisms** to model relationships in data.

This notebook provides a **distilled** explanation of the core attention concepts that power Transformers. My goal is to make these ideas accessible by combining intuitive explanations with practical code examples in PyTorch.

By the end of this notebook, you'll understand:

- What attention is and why it's powerful  
- The difference between **self-attention**, **multi-head attention**, **causal attention**, and **cross-attention**  
- How to implement each type of attention in code  
- How these pieces come together in the Transformer architecture  

Whether you're a student, researcher, or developer, this notebook is designed to be a clear, hands-on starting point for understanding how attention works in modern AI.


---

## 🔍 What is Attention?

At its core, **attention** is a mechanism that allows a model to dynamically focus on different parts of its input when performing a task — such as translating a sentence or answering a question.

Imagine reading a sentence and trying to understand the meaning of a word like "it". To figure out what "it" refers to, your brain "attends" to earlier parts of the sentence. Attention mechanisms allow neural networks to do the same — to look at other words in the input and weigh their importance dynamically.

### 📐 The Scaled Dot-Product Attention

The most common form of attention used in Transformers is called **scaled dot-product attention**. It operates on three matrices, each holding one vector per token:

- **Q (Query)** – what each token is looking for  
- **K (Key)** – what each token offers to be matched against a query  
- **V (Value)** – the information each token passes on when its key matches  

The formula is:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V
$$

Where:

- $Q$, $K$, and $V$ are matrices whose rows are the query, key, and value vectors.  
- $d_k$ is the dimension of the key (and query) vectors, used to scale the dot products.  
- The **softmax** is applied to each row, so the attention weights for each query are positive and sum to 1.
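To make the formula concrete, here is a tiny hand-checkable case in plain Python (no PyTorch), so every step is visible. The numbers are illustrative choices, not from the paper: one query $q = [1, 0]$, two keys $[1, 0]$ and $[0, 1]$, two values $[1, 2]$ and $[3, 4]$, so $d_k = 2$.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_single_query(q, keys, values):
    """Scaled dot-product attention for one query vector (plain-Python sketch)."""
    d_k = len(q)
    # Step 1: dot product of the query with every key, divided by sqrt(d_k)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
    # Step 2: softmax turns the scores into weights that sum to 1
    weights = softmax(scores)
    # Step 3: the output is the weighted sum of the value vectors
    dim_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim_v)]
    return scores, weights, output

scores, weights, output = attention_single_query(
    q=[1.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0]],
    values=[[1.0, 2.0], [3.0, 4.0]],
)
print("scores: ", [round(s, 4) for s in scores])   # [0.7071, 0.0]
print("weights:", [round(w, 4) for w in weights])  # [0.6698, 0.3302]
print("output: ", [round(o, 4) for o in output])   # [1.6605, 2.6605]
```

The query matches the first key better, so the first value receives about 67% of the weight and the output lands closer to $[1, 2]$ than to $[3, 4]$.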

### 🧠 Intuition

- The **dot product** $QK^T$ measures how similar each query is to each key.  
- Dividing by $\sqrt{d_k}$ keeps the scores in a reasonable range. If the query and key components have mean 0 and variance 1, their dot product has variance $d_k$; large scores push the softmax into regions where its gradients are tiny, which slows learning.  
- The **softmax** turns each row of scores into a probability distribution: how much attention the query pays to each position.  
- Finally, these weights are applied to the **values** $V$, so each output row is a weighted combination of the value vectors.
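The variance argument behind the $\sqrt{d_k}$ factor is easy to check numerically. This sketch assumes, as the paper does, that query and key components are independent with mean 0 and variance 1; it draws such vectors with Python's `random` module and measures the variance of the raw and scaled dot products. The sample count and dimensions are arbitrary choices.

```python
import math
import random
import statistics

def dot_product_variance(d_k, n_samples=1000, scale=False, seed=0):
    """Empirical variance of q.k for random q, k with i.i.d. N(0, 1) components."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dot = sum(qi * ki for qi, ki in zip(q, k))
        # Optionally apply the 1/sqrt(d_k) scaling from the attention formula
        samples.append(dot / math.sqrt(d_k) if scale else dot)
    return statistics.pvariance(samples)

for d_k in (4, 64, 256):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d_k))={scaled:.2f}")
```

The raw variance grows roughly in proportion to $d_k$, while the scaled variance stays near 1 for every dimension, which keeps the softmax inputs in a range where its gradients remain useful.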

This core idea is used throughout the Transformer — and extended in various ways to support richer behavior like multi-head, causal, and cross attention.
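The formula fits in a few lines of PyTorch, and the result can be checked against the built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later). The function name `attention` and its boolean `mask` convention (True = may attend) are choices made for this sketch; the optional mask anticipates the causal attention covered below.

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    """Scaled dot-product attention written directly from the formula.

    Q: (..., L, d_k), K: (..., S, d_k), V: (..., S, d_v).
    mask (optional): boolean, broadcastable to (..., L, S); True = may attend.
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (..., L, S)
    if mask is not None:
        scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ V            # (..., L, d_v)

torch.manual_seed(0)
Q, K, V = (torch.randn(1, 2, 3, 4) for _ in range(3))  # (batch, heads, tokens, d_k)

ours = attention(Q, K, V)
builtin = F.scaled_dot_product_attention(Q, K, V)

# Lower-triangular boolean mask vs. the built-in causal option
causal = torch.tril(torch.ones(3, 3, dtype=torch.bool))
ours_causal = attention(Q, K, V, mask=causal)
builtin_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(ours, builtin, atol=1e-5),
      torch.allclose(ours_causal, builtin_causal, atol=1e-5))
```

Writing the function by hand makes each step explicit; in real code the built-in is usually preferable because PyTorch can dispatch it to fused, memory-efficient kernels.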

---




## 🎯 Types of Attention

In this section, we'll explore the four key types of attention used in the Transformer architecture:

- **Self-Attention**
- **Multi-Head Attention**
- **Causal Attention**
- **Cross Attention**

Each attention type builds on the same underlying mechanism but serves different purposes within the model. We'll explain each one and provide a PyTorch implementation.

### 🧩 Self-Attention

**Self-Attention** allows a sequence element (e.g., a word) to attend to **other elements in the same sequence** to build context. For example, to understand the meaning of "bank", we may need to attend to "river" or "money" nearby.

In Transformers, each token in the input is compared with every other token (including itself), and a weighted representation is formed.

This is the foundation of both encoder and decoder blocks.


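```python
import torch
import torch.nn.functional as F

# Random stand-ins for 3 tokens with 4-dimensional queries, keys, and values.
# In a real model, Q, K, and V are learned linear projections of the *same*
# input sequence; that shared source is what makes this "self"-attention.
Q = torch.rand((3, 4))
K = torch.rand((3, 4))
V = torch.rand((3, 4))

dk = Q.size(-1)
attention_scores = Q @ K.T / dk**0.5                     # (3, 3): query-key similarities
attention_weights = F.softmax(attention_scores, dim=-1)  # each row sums to 1
output = attention_weights @ V                           # (3, 4): weighted sum of values

print("Attention Weights:\n", attention_weights)
print("\nSelf-Attention Output:\n", output)
```

Output (the inputs are random, so your values will differ):

```text
Attention Weights:
 tensor([[0.3237, 0.3267, 0.3496],
        [0.3015, 0.3176, 0.3809],
        [0.2737, 0.3728, 0.3534]])

Self-Attention Output:
 tensor([[0.3456, 0.1623, 0.3909, 0.7038],
        [0.3319, 0.1635, 0.4014, 0.7107],
        [0.3484, 0.1527, 0.3724, 0.6964]])
```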
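The cell above uses random stand-ins for Q, K, and V. A closer sketch of self-attention derives all three from one input sequence through projection layers. The class name `SelfAttention` is illustrative, and the `nn.Linear` layers are untrained and bias-free, a simplifying assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Single-head self-attention: Q, K and V are projections of the same x."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_k, bias=False)

    def forward(self, x):                        # x: (seq_len, d_model)
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = Q @ K.T / Q.size(-1) ** 0.5     # (seq_len, seq_len)
        weights = F.softmax(scores, dim=-1)      # row i: how much token i attends to each token
        return weights @ V, weights              # (seq_len, d_k), (seq_len, seq_len)

torch.manual_seed(0)
x = torch.rand(3, 4)                             # 3 tokens, 4-dim embeddings
attn = SelfAttention(d_model=4, d_k=4)
output, weights = attn(x)
print(output.shape, weights.shape)               # torch.Size([3, 4]) torch.Size([3, 3])
print(weights.sum(dim=-1))                       # each row sums to 1
```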


### 🔁 Multi-Head Attention

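**Multi-Head Attention** runs several attention operations ("heads") in parallel. Each head has its own learned projections of Q, K, and V, so different heads can learn to capture different relationships in the input (for example, one head may track neighbouring words while another links a pronoun to the noun it refers to).

The outputs of all heads are concatenated and passed through a final linear projection back to the model dimension. Because each head works in a smaller space ($d_k = d_{model} / h$ for $h$ heads), the total cost is similar to that of a single head with full dimensionality.

This is a core component of all Transformer layers.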


```python
import torch
import torch.nn.functional as F

# Input: 3 tokens, 8-dimensional embeddings
x = torch.rand((3, 8))

# Parameters
num_heads = 2
seq_len, d_model = x.shape
d_k = d_model // num_heads  # each head works in a smaller subspace

# Random projection weights for Q, K, V (one matrix per head) and the output.
# torch.rand draws from [0, 1), so every weight is positive; that is why the
# output values below are large. A trained model would learn these weights.
W_q = torch.rand((num_heads, d_model, d_k))
W_k = torch.rand((num_heads, d_model, d_k))
W_v = torch.rand((num_heads, d_model, d_k))
W_o = torch.rand((num_heads * d_k, d_model))  # final output projection

# Step 1: Project input into Q, K, V for each head
Q = torch.einsum('nd, hdf -> hnf', x, W_q)  # shape: (heads, tokens, d_k)
K = torch.einsum('nd, hdf -> hnf', x, W_k)
V = torch.einsum('nd, hdf -> hnf', x, W_v)

# Step 2: Compute scaled dot-product attention per head
scores = torch.einsum('hnf, hmf -> hnm', Q, K) / (d_k ** 0.5)  # (heads, tokens, tokens)
weights = F.softmax(scores, dim=-1)
attn_output = torch.einsum('hnm, hmf -> hnf', weights, V)  # (heads, tokens, d_k)

# Step 3: Concatenate heads: (heads, tokens, d_k) -> (tokens, heads * d_k)
attn_output = attn_output.transpose(0, 1).reshape(seq_len, -1)

# Step 4: Final linear projection
final_output = attn_output @ W_o  # (tokens, d_model)

print("Multi-Head Attention Output:\n", final_output)
```

Output (random inputs and weights, so values differ between runs):

```text
Multi-Head Attention Output:
 tensor([[10.5482,  9.1351,  7.7318,  9.3621,  9.1614, 10.2630,  5.0940,  7.4390],
        [10.6146,  9.2001,  7.7724,  9.4230,  9.2220, 10.3311,  5.1205,  7.4808],
        [10.4710,  9.0539,  7.6865,  9.2882,  9.0836, 10.1739,  5.0604,  7.3916]])
```
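One way to check a from-scratch implementation is to reuse the weights of PyTorch's built-in `torch.nn.MultiheadAttention` and compare outputs. This sketch assumes the module's default settings (bias enabled, no dropout, equal query/key/value sizes), under which the three input projections are stacked in `in_proj_weight` / `in_proj_bias` and the output projection lives in `out_proj`. The helper `split_heads` is hypothetical, written for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 1, 3, 8, 2
d_k = d_model // num_heads

mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
x = torch.rand(batch_size, seq_len, d_model)

# in_proj_weight stacks the Q, K, V projections as one (3 * d_model, d_model) matrix
W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)

def split_heads(t):
    # (batch, seq, d_model) -> (batch, heads, seq, d_k); head h gets a contiguous slice of features
    return t.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)

with torch.no_grad():
    Q = split_heads(F.linear(x, W_q, b_q))
    K = split_heads(F.linear(x, W_k, b_k))
    V = split_heads(F.linear(x, W_v, b_v))

    weights = F.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)
    heads = weights @ V                                         # (batch, heads, seq, d_k)
    concat = heads.transpose(1, 2).reshape(batch_size, seq_len, d_model)
    manual_out = mha.out_proj(concat)                           # final linear projection

    ref_out, _ = mha(x, x, x)                                   # self-attention: query = key = value = x

print(torch.allclose(manual_out, ref_out, atol=1e-5))           # True if the two agree
```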


### 🚫 Causal (Masked) Attention

**Causal Attention** prevents tokens from attending to **future positions** in the sequence. It's used in language models like GPT to ensure that prediction at position `t` only depends on positions ≤ `t`.

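This is achieved with a **lower-triangular mask**: scores for positions above the diagonal (future keys) are set to $-\infty$ before the softmax, so they receive exactly zero weight.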
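To see the mask on its own, the sketch below builds a 4×4 lower-triangular boolean mask, fills the blocked positions with $-\infty$, and applies softmax. The scores are random placeholders; only the pattern of zeros in the result matters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)  # placeholder query-key scores

# mask[i, j] is True where query i may attend to key j, i.e. j <= i
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~mask, float('-inf'))
weights = F.softmax(masked_scores, dim=-1)

print(mask.int())
print(weights)  # entries above the diagonal are exactly 0; row 0 is [1, 0, 0, 0]
```

Because $e^{-\infty} = 0$, the first token can only attend to itself, the second to the first two tokens, and so on.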

```python
import torch
import torch.nn.functional as F

# Input: batch_size=1, 3 tokens, 8-dimensional embeddings
x = torch.rand((1, 3, 8))  # (batch_size, seq_len, d_model)

# Parameters
num_heads = 2
batch_size, seq_len, d_model = x.shape
d_k = d_model // num_heads

# Randomly initialized projection weights for Q, K, V for each head
W_q = torch.rand((num_heads, d_model, d_k))
W_k = torch.rand((num_heads, d_model, d_k))
W_v = torch.rand((num_heads, d_model, d_k))
W_o = torch.rand((num_heads * d_k, d_model))  # final output projection

# Step 1: Project input into Q, K, V for each head
Q = torch.einsum('bnd, hdf -> bhnf', x, W_q)  # shape: (batch_size, heads, seq_len, d_k)
K = torch.einsum('bnd, hdf -> bhnf', x, W_k)
V = torch.einsum('bnd, hdf -> bhnf', x, W_v)

# Step 2: Causal mask: mask[i, j] = 1 if j <= i (past or present), 0 if j > i (future).
# The two unsqueezes add broadcastable batch and head dimensions.
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(1)  # (1, 1, seq_len, seq_len)

# Step 3: Scaled dot-product attention per head; future positions get -inf,
# so the softmax assigns them zero weight
scores = torch.einsum('bhnf, bhmf -> bhnm', Q, K) / (d_k ** 0.5)  # (batch_size, heads, seq_len, seq_len)
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
attn_output = torch.einsum('bhnm, bhmf -> bhnf', weights, V)  # (batch_size, heads, seq_len, d_k)

# Step 4: Concatenate heads: (batch_size, heads, seq_len, d_k) -> (batch_size, seq_len, heads * d_k)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

# Step 5: Final linear projection
final_output = attn_output @ W_o  # (batch_size, seq_len, d_model)

print("Causal Multi-Head Attention Output:\n", final_output)
```

Output (random inputs and weights, so values differ between runs):

```text
Causal Multi-Head Attention Output:
 tensor([[[ 6.6643,  9.4161, 10.0614,  8.9188, 11.8910,  8.9163, 12.2942,
          12.1606],
         [ 6.4994,  9.0797,  9.5428,  8.6220, 11.6023,  8.6823, 11.8995,
          11.7685],
         [ 6.2610,  8.9364,  9.2423,  8.4325, 11.2500,  8.5664, 11.6209,
          11.5014]]])
```
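A practical test that a mask really is causal: change the last token and confirm that the outputs at earlier positions do not move. The sketch uses a compact single-head helper, `causal_attention` (hypothetical, not a PyTorch API), with the same masking steps as the cell above; the sizes and random weights are arbitrary.

```python
import torch
import torch.nn.functional as F

def causal_attention(x, W_q, W_k, W_v):
    """Single-head causal self-attention over x of shape (seq_len, d_model)."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    seq_len = x.size(0)
    scores = Q @ K.T / Q.size(-1) ** 0.5
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float('-inf'))  # block future positions
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
d_model = 8
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

x = torch.randn(5, d_model)
x_changed = x.clone()
x_changed[-1] = torch.randn(d_model)  # alter only the last token

out = causal_attention(x, W_q, W_k, W_v)
out_changed = causal_attention(x_changed, W_q, W_k, W_v)
print(torch.allclose(out[:-1], out_changed[:-1]))  # earlier positions are unaffected
print(torch.allclose(out[-1], out_changed[-1]))    # the last position sees the change
```

Without the `masked_fill` line, the first check would fail, since every position would then depend on the altered token.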


### 🔄 Cross Attention

**Cross Attention** is used in **encoder-decoder** models: the queries come from the decoder, while the keys and values come from the encoder's output. This lets every decoder position look at the entire input sequence.

It is central to applications such as machine translation and image captioning.

```python
import torch
import torch.nn.functional as F

# Input: batch_size=1, 8-dimensional embeddings.
# encoder_input stands in for the encoder's output; decoder_input for the decoder's hidden states.
encoder_input = torch.rand((1, 3, 8))  # (batch_size, src_len, d_model)
decoder_input = torch.rand((1, 3, 8))  # (batch_size, tgt_len, d_model)

# Parameters
num_heads = 2
d_model = encoder_input.size(-1)
d_k = d_model // num_heads
batch_size = decoder_input.size(0)
tgt_len = decoder_input.size(1)  # number of query positions (decoder side)
src_len = encoder_input.size(1)  # number of key/value positions (encoder side); may differ from tgt_len

# Randomly initialized projection weights for Q, K, V for each head
W_q = torch.rand((num_heads, d_model, d_k))
W_k = torch.rand((num_heads, d_model, d_k))
W_v = torch.rand((num_heads, d_model, d_k))
W_o = torch.rand((num_heads * d_k, d_model))  # final output projection

# Step 1: Queries come from the decoder; keys and values come from the encoder
Q = torch.einsum('bnd, hdf -> bhnf', decoder_input, W_q)  # (batch_size, heads, tgt_len, d_k)
K = torch.einsum('bnd, hdf -> bhnf', encoder_input, W_k)  # (batch_size, heads, src_len, d_k)
V = torch.einsum('bnd, hdf -> bhnf', encoder_input, W_v)  # (batch_size, heads, src_len, d_k)

# Step 2: Scaled dot-product attention per head. No causal mask here: every decoder
# position may attend to every encoder position, because the whole source sequence
# is already known. Causal masking belongs to the decoder's self-attention.
scores = torch.einsum('bhnf, bhmf -> bhnm', Q, K) / (d_k ** 0.5)  # (batch_size, heads, tgt_len, src_len)
weights = F.softmax(scores, dim=-1)
attn_output = torch.einsum('bhnm, bhmf -> bhnf', weights, V)  # (batch_size, heads, tgt_len, d_k)

# Step 3: Concatenate heads: -> (batch_size, tgt_len, heads * d_k)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, -1)

# Step 4: Final linear projection
final_output = attn_output @ W_o  # (batch_size, tgt_len, d_model)

print("Cross Attention Output:\n", final_output)
```

Running this prints a tensor of shape `(1, 3, 8)`, one `d_model`-dimensional vector per decoder position. Because the inputs and weights are random, the values change on every run.
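Encoder and decoder sequences usually have different lengths, and cross-attention handles this naturally: the weight matrix has shape (target length × source length). The sketch below uses PyTorch's built-in `torch.nn.MultiheadAttention`, passing decoder states as the query and encoder outputs as key and value. The sizes (3 target tokens, 5 source tokens) and the random tensors are illustrative assumptions. No `attn_mask` is passed, so every decoder position can attend to every encoder position.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 8, 2
cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

encoder_output = torch.rand(1, 5, d_model)  # (batch, src_len=5, d_model), stand-in for encoder states
decoder_hidden = torch.rand(1, 3, d_model)  # (batch, tgt_len=3, d_model), stand-in for decoder states

# Query from the decoder; key and value from the encoder
output, weights = cross_attn(decoder_hidden, encoder_output, encoder_output)

print(output.shape)   # torch.Size([1, 3, 8]): one vector per decoder position
print(weights.shape)  # torch.Size([1, 3, 5]): weights averaged over heads (the default)
```

The output length follows the decoder (queries), while the encoder length only appears in the attention weights.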


### Conclusion

In this section, we’ve explored key attention mechanisms that power **Transformer** models: **Self-Attention**, **Multi-Head Attention**, **Causal Attention**, and **Cross-Attention**. Here's a simplified summary:

- **Self-Attention:** Each token attends to every token in the same sequence, including itself, to gather contextual information. This is essential for capturing relationships within the data.
  
- **Multi-Head Attention:** Several attention heads, each with its own projections, run in parallel, so the model can capture several kinds of relationships at once.

- **Causal Attention (Masked Attention):** In autoregressive models like GPT, tokens can only attend to themselves and previous tokens, which prevents information from future tokens leaking into predictions during training.

- **Cross-Attention:** In encoder-decoder models (like machine translation), the decoder's queries attend to the encoder's output, focusing on the relevant parts of the input sequence.

Together, these mechanisms enable Transformers to handle tasks such as **machine translation**, **text generation**, and **image captioning**, where each output depends on context spread across the whole input.
