# Self-Attention Mathematics: Step-by-Step Guide

This notebook provides a detailed explanation of the self-attention mechanism, breaking down the mathematics behind Query (Q), Key (K), and Value (V) matrices.

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for better visualizations
sns.set_style('whitegrid')
torch.manual_seed(42)

## 1. Understanding the Input

Let's start with a simple sentence and create token embeddings.

In [2]:
# Simple example: "The cat sat"
seq_len = 3
d_model = 4  # Small dimension for clarity

# Hand-made toy embeddings (in a real model these come from a learned embedding layer)
X = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],  # "The"
    [0.0, 1.0, 0.0, 1.0],  # "cat"
    [1.0, 1.0, 0.0, 0.0]   # "sat"
])

print("Input Embeddings (X):")
print(X)
print(f"\nShape: {X.shape} (seq_len={seq_len}, d_model={d_model})")

Input Embeddings (X):
tensor([[1., 0., 1., 0.],
        [0., 1., 0., 1.],
        [1., 1., 0., 0.]])

Shape: torch.Size([3, 4]) (seq_len=3, d_model=4)
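
The comment in the cell above notes that real embeddings come from an embedding layer. Such a layer is essentially a lookup table: a `(vocab_size, d_model)` matrix with one row per token ID. A minimal NumPy sketch, assuming a hypothetical three-word vocabulary and a random table (the names `vocab` and `embedding_table` are illustrative only):

```python
import numpy as np

# Hypothetical toy vocabulary mapping each word to an integer token ID
vocab = {"The": 0, "cat": 1, "sat": 2}
d_model = 4

# Embedding table: one row per vocabulary entry (random here; learned in practice)
rng = np.random.default_rng(0)
embedding_table = rng.standard_normal((len(vocab), d_model))

# Tokenize the sentence into IDs, then look up one row per token
token_ids = np.array([vocab[w] for w in "The cat sat".split()])
X = embedding_table[token_ids]  # shape (seq_len, d_model)

print(token_ids)  # [0 1 2]
print(X.shape)    # (3, 4)
```

In PyTorch the same role is played by `torch.nn.Embedding`, whose weight matrix is trained along with the rest of the model.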


## 2. Creating Query, Key, and Value Matrices

The core of attention is projecting the input into three different spaces:
- **Query (Q)**: What am I looking for?
- **Key (K)**: What do I contain?
- **Value (V)**: What information do I have to offer?
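
All three projections are plain linear maps of the same input, so they can be computed either as three separate matrix multiplies or as one multiply against the column-wise concatenation of the three weight matrices. The NumPy sketch below checks that both routes give identical results; the fused name `W_qkv` is hypothetical, and the weights are random stand-ins for learned ones.

```python
import numpy as np

rng = np.random.default_rng(42)
seq_len, d_model = 3, 4
X = rng.standard_normal((seq_len, d_model))

# Three separate projection matrices (random stand-ins for learned weights)
W_q = rng.standard_normal((d_model, d_model)) * 0.1
W_k = rng.standard_normal((d_model, d_model)) * 0.1
W_v = rng.standard_normal((d_model, d_model)) * 0.1

# Separate projections: one matmul each
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Fused projection: concatenate weights column-wise, do one matmul, then split
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)  # (d_model, 3 * d_model)
Q_f, K_f, V_f = np.split(X @ W_qkv, 3, axis=1)   # three (seq_len, d_model) blocks

print(np.allclose(Q, Q_f), np.allclose(K, K_f), np.allclose(V, V_f))  # True True True
```

The fused form is a common efficiency trick; mathematically it is the same computation.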

In [None]:
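# Initialize weight matrices randomly (in practice, these are learned).
# The 0.1 scale keeps the weights small, so the attention scores below will be
# close to zero and the resulting attention weights close to uniform.
W_q = torch.randn(d_model, d_model) * 0.1
W_k = torch.randn(d_model, d_model) * 0.1
W_v = torch.randn(d_model, d_model) * 0.1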

print("Weight Matrices:")
print(f"W_q shape: {W_q.shape}")
print(f"W_k shape: {W_k.shape}")
print(f"W_v shape: {W_v.shape}")

In [None]:
# Compute Q, K, V by matrix multiplication
Q = X @ W_q  # Query: what each token is looking for
K = X @ W_k  # Key: what each token represents
V = X @ W_v  # Value: what information each token carries

print("Query Matrix (Q):")
print(Q)
print("\nKey Matrix (K):")
print(K)
print("\nValue Matrix (V):")
print(V)

## 3. Computing Attention Scores

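Next, compute a score for every (query, key) pair: how strongly each token should attend to every token in the sequence, itself included.

**Formula**: $\text{scores} = \frac{QK^T}{\sqrt{d_k}}$, where $d_k$ is the dimension of the query and key vectors.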
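
Element by element, entry $(i, j)$ of the score matrix is $\text{scores}_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}}$, where $q_i$ is row $i$ of $Q$ and $k_j$ is row $j$ of $K$. The NumPy sketch below, with random stand-in matrices, computes each entry with an explicit loop and confirms it matches the single matrix product.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k = 3, 4
Q = rng.standard_normal((seq_len, d_k))
K = rng.standard_normal((seq_len, d_k))

# Explicit double loop: one dot product per (query i, key j) pair
loop_scores = np.zeros((seq_len, seq_len))
for i in range(seq_len):
    for j in range(seq_len):
        loop_scores[i, j] = np.dot(Q[i], K[j]) / np.sqrt(d_k)

# Matrix form: all pairs at once
matrix_scores = Q @ K.T / np.sqrt(d_k)

print(np.allclose(loop_scores, matrix_scores))  # True
```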

In [None]:
# Step 1: Compute raw attention scores (Q @ K^T)
d_k = d_model  # In our case, they're the same
raw_scores = Q @ K.transpose(-2, -1)

print("Raw Attention Scores (Q @ K^T):")
print(raw_scores)
print(f"\nShape: {raw_scores.shape}")
print("\nInterpretation: element (i, j) is the unnormalized similarity between query i and key j")

In [None]:
# Step 2: Scale by sqrt(d_k) to prevent large values
import math

scaled_scores = raw_scores / math.sqrt(d_k)

print(f"Scaled Attention Scores (divided by sqrt({d_k}) = {math.sqrt(d_k):.2f}):")
print(scaled_scores)
print("\nScaling keeps scores from growing with d_k; large scores saturate softmax and shrink its gradients")
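
The reason for the $\sqrt{d_k}$ factor can be checked numerically. If the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance $d_k$, so unscaled scores grow with the dimension and push softmax toward a near one-hot output. The NumPy sketch below assumes standard-normal queries and keys and $d_k = 64$ (values chosen only for illustration) and compares the score variance and the average largest softmax weight with and without scaling.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability (does not change the result)
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d_k = 64
n_queries, n_keys = 1000, 8

# Independent standard-normal queries and keys (an idealized assumption)
q = rng.standard_normal((n_queries, d_k))
k = rng.standard_normal((n_keys, d_k))

raw = q @ k.T                # (n_queries, n_keys); variance roughly d_k
scaled = raw / np.sqrt(d_k)  # variance roughly 1

print(f"variance raw: {raw.var():.1f}, scaled: {scaled.var():.2f}")

# Average of each row's largest attention weight: close to 1 means near one-hot
peak_raw = softmax(raw).max(axis=-1).mean()
peak_scaled = softmax(scaled).max(axis=-1).mean()
print(f"mean max weight raw: {peak_raw:.3f}, scaled: {peak_scaled:.3f}")
```

A near one-hot softmax has very small gradients with respect to most of its inputs, which is the saturation the scaling is meant to avoid.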

## 4. Applying Softmax to Get Attention Weights

Apply softmax along each row, so that every query's weights are non-negative and sum to 1.
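
Softmax is simple enough to write directly: exponentiate each score and divide by its row total. A common implementation detail, shown in this NumPy sketch, is subtracting each row's maximum first; this avoids overflow in `exp` and leaves the result unchanged, because softmax is invariant to adding the same constant to every entry in a row.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax for a 2-D array of attention scores."""
    # Shift each row so its largest entry is 0; exp() then cannot overflow
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(shifted)
    # Normalize so each row sums to 1
    return exp_scores / exp_scores.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.1],
                   [1000.0, 1001.0, 1002.0]])  # a naive exp() would overflow on row 2

weights = softmax_rows(scores)
print(weights.sum(axis=-1))                              # [1. 1.]
print(np.allclose(softmax_rows(scores + 5.0), weights))  # True: shift invariance
```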

In [None]:
# Apply softmax to each row
attention_weights = F.softmax(scaled_scores, dim=-1)

print("Attention Weights (after softmax):")
print(attention_weights)
print(f"\nRow sums (should be 1.0): {attention_weights.sum(dim=-1)}")

In [None]:
# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(attention_weights.numpy(), 
            annot=True, 
            fmt='.3f', 
            cmap='YlOrRd',
            xticklabels=['The', 'cat', 'sat'],
            yticklabels=['The', 'cat', 'sat'],
            cbar_kws={'label': 'Attention Weight'})
plt.title('Attention Weight Matrix\n(Row i shows how token i attends to all tokens)', fontsize=14)
plt.xlabel('Keys (attending TO)', fontsize=12)
plt.ylabel('Queries (attending FROM)', fontsize=12)
plt.tight_layout()
plt.show()

## 5. Computing the Final Attention Output

Each output row is a weighted sum of the value vectors, with the attention weights as coefficients.

**Formula**: $\text{output} = \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
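
Row $i$ of the output is $\sum_j w_{ij} v_j$, where $w_{ij}$ is the attention weight from query $i$ to key $j$ and $v_j$ is row $j$ of $V$. The NumPy sketch below uses made-up weights and values, builds that sum explicitly, and checks it against the single matrix product.

```python
import numpy as np

# Made-up attention weights (each row sums to 1) and value vectors
weights = np.array([[0.7, 0.2, 0.1],
                    [0.1, 0.8, 0.1],
                    [0.3, 0.3, 0.4]])
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])

# Explicit weighted sum of value rows for each query
manual = np.zeros((weights.shape[0], V.shape[1]))
for i in range(weights.shape[0]):
    for j in range(V.shape[0]):
        manual[i] += weights[i, j] * V[j]

print(manual[0])                         # [0.8 0.3]
print(np.allclose(manual, weights @ V))  # True
```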

In [None]:
# Multiply attention weights with values
attention_output = attention_weights @ V

print("Attention Output:")
print(attention_output)
print(f"\nShape: {attention_output.shape}")
print("\nEach row is a weighted combination of all value vectors")

## 6. Complete Self-Attention Function

Let's put it all together in a single function.

In [None]:
def self_attention(X, W_q, W_k, W_v, mask=None):
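    """
    Complete (single-head) scaled dot-product self-attention.
    
    Args:
        X: Input embeddings, shape (seq_len, d_model)
        W_q, W_k: Query/key projection matrices, shape (d_model, d_k)
        W_v: Value projection matrix, shape (d_model, d_v)
        mask: Optional (seq_len, seq_len) tensor; positions where mask == 0
            are blocked by setting their scores to -1e9 before softmax
    
    Returns:
        output: Attention output, shape (seq_len, d_v)
        attention_weights: Attention distribution, shape (seq_len, seq_len)
    """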
    # Step 1: Project to Q, K, V
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    
    # Step 2: Compute scaled attention scores
    d_k = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Step 3: Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Step 4: Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Step 5: Compute weighted sum of values
    output = attention_weights @ V
    
    return output, attention_weights

# Test the function
output, weights = self_attention(X, W_q, W_k, W_v)
print("Self-Attention Output:")
print(output)
print(f"\nMatches previous computation: {torch.allclose(output, attention_output)}")
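
The function uses `@` and `transpose(-2, -1)`, which act on the last two dimensions, so in PyTorch it should also accept a batch of sequences shaped `(batch, seq_len, d_model)` without changes. The same idea in NumPy, where `swapaxes(-2, -1)` plays the role of `transpose(-2, -1)`, is sketched below; the name `batched_self_attention` is hypothetical, and the check compares the batched result against running each sequence on its own.

```python
import numpy as np

def batched_self_attention(X, W_q, W_k, W_v, mask=None):
    """Scaled dot-product self-attention over the last two axes of X."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v              # (..., seq_len, d_k)
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d_k)   # (..., seq_len, seq_len)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)   # block positions where mask is 0
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(1)
batch, seq_len, d_model = 2, 3, 4
X = rng.standard_normal((batch, seq_len, d_model))
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

batched_out, _ = batched_self_attention(X, W_q, W_k, W_v)
per_seq = np.stack([batched_self_attention(x, W_q, W_k, W_v)[0] for x in X])
print(batched_out.shape)                  # (2, 3, 4)
print(np.allclose(batched_out, per_seq))  # True
```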

## 7. Understanding Masking (Optional)

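Masks prevent attention to certain positions (e.g., future tokens in autoregressive models) by setting their scores to a large negative value before the softmax, so their weights become effectively zero.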
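
Because blocked scores are replaced with a very large negative number, their exponentials underflow to zero and the remaining weights in each row still sum to 1. A NumPy sketch with a causal (lower-triangular) mask, using random scores as stand-ins:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len = 4
scores = rng.standard_normal((seq_len, seq_len))

# Causal mask: 1 where key position j <= query position i, 0 for future positions
causal_mask = np.tril(np.ones((seq_len, seq_len)))
masked_scores = np.where(causal_mask == 0, -1e9, scores)

# Row-wise softmax with max subtraction for stability
e = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

print(np.round(weights, 3))
print(np.triu(weights, k=1).max())  # 0.0: no weight on future tokens
```

The first row always puts all of its weight on position 0, since that is the only position it may attend to.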

In [None]:
# Create a causal mask (lower triangular)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

print("Causal Mask (prevents attending to future tokens):")
print(causal_mask)

# Apply masked attention
masked_output, masked_weights = self_attention(X, W_q, W_k, W_v, mask=causal_mask)

print("\nMasked Attention Weights:")
print(masked_weights)

In [None]:
# Visualize masked attention
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Original attention
sns.heatmap(attention_weights.numpy(), 
            annot=True, fmt='.3f', cmap='YlOrRd',
            xticklabels=['The', 'cat', 'sat'],
            yticklabels=['The', 'cat', 'sat'],
            ax=axes[0])
axes[0].set_title('Standard Self-Attention', fontsize=12)

# Masked attention
sns.heatmap(masked_weights.numpy(), 
            annot=True, fmt='.3f', cmap='YlOrRd',
            xticklabels=['The', 'cat', 'sat'],
            yticklabels=['The', 'cat', 'sat'],
            ax=axes[1])
axes[1].set_title('Causal (Masked) Self-Attention', fontsize=12)

plt.tight_layout()
plt.show()

## 8. Key Intuitions

### What did we learn?

1. **Query (Q)**: Represents what each token is "searching for"
2. **Key (K)**: Represents what each token "offers" or "contains"
3. **Value (V)**: The actual information to be aggregated

### The Attention Mechanism:
- **Scores**: Dot product between queries and keys measures relevance
- **Scaling**: Dividing by √d_k stabilizes gradients
- **Softmax**: Converts scores to a probability distribution
- **Output**: Weighted combination of values based on attention

### Why is this powerful?
- **Dynamic**: Each token can attend to any other token
- **Contextual**: Representations depend on the entire sequence
- **Parallelizable**: All positions are processed at once with matrix multiplications, with no step-by-step recurrence
- **Learnable**: Q, K, V projections are learned from data
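
The **Contextual** point can be seen directly: changing one token's embedding changes the outputs at other positions, because every output row mixes all value vectors. Under a causal mask, the first position attends only to itself, so later tokens cannot affect it. A NumPy sketch with random projection weights (all values here are illustrative):

```python
import numpy as np

def attention(X, W_q, W_k, W_v, causal=False):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if causal:
        # Block future positions (entries above the diagonal)
        future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
        scores = np.where(future, -1e9, scores)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
X = np.array([[1.0, 0.0, 1.0, 0.0],   # "The"
              [0.0, 1.0, 0.0, 1.0],   # "cat"
              [1.0, 1.0, 0.0, 0.0]])  # "sat"
W_q, W_k, W_v = (rng.standard_normal((4, 4)) for _ in range(3))

# Replace the embedding of the last token ("sat") with a different vector
X_changed = X.copy()
X_changed[2] = [0.0, 0.0, 1.0, 1.0]

full_before = attention(X, W_q, W_k, W_v)
full_after = attention(X_changed, W_q, W_k, W_v)
causal_before = attention(X, W_q, W_k, W_v, causal=True)
causal_after = attention(X_changed, W_q, W_k, W_v, causal=True)

print(np.allclose(full_before[0], full_after[0]))      # False: "The" sees the change
print(np.allclose(causal_before[0], causal_after[0]))  # True: "The" only attends to itself
```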

## 9. Multi-Head Attention Preview

In practice, we use multiple attention heads to capture different types of relationships.

In [None]:
# Simulate 2 attention heads
num_heads = 2
d_k = d_model // num_heads

print(f"Original dimension: {d_model}")
print(f"Number of heads: {num_heads}")
print(f"Dimension per head: {d_k}")
print("\nEach head can learn a different attention pattern.")
print("- Head 1 might focus on syntax (e.g., subject-verb agreement)")
print("- Head 2 might focus on semantics (e.g., word meanings)")
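
Going one step past the preview, a common multi-head recipe splits each projected Q, K, and V into `num_heads` chunks of width `d_model // num_heads`, runs scaled dot-product attention on each chunk independently, concatenates the head outputs, and applies one more learned projection. The NumPy sketch below follows that recipe with random weights; the output projection `W_o` is an assumption of this sketch, since this notebook does not define one.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads):
    seq_len, d_model = X.shape
    d_k = d_model // num_heads

    def split_heads(M):
        # (seq_len, d_model) -> (num_heads, seq_len, d_k)
        return M.reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)

    Q, K, V = (split_heads(X @ W) for W in (W_q, W_k, W_v))
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))  # (num_heads, seq_len, seq_len)
    heads = weights @ V                                         # (num_heads, seq_len, d_k)

    # Concatenate heads back to (seq_len, d_model), then apply the output projection
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o, weights

rng = np.random.default_rng(0)
seq_len, d_model, num_heads = 3, 4, 2
X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) for _ in range(4))

out, weights = multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads)
print(out.shape)      # (3, 4)
print(weights.shape)  # (2, 3, 3): one attention matrix per head
```

Each head sees only its own slice of the projected columns, which is what lets different heads produce different attention matrices.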

## Summary

**Self-Attention Formula**:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q = XW_q$ (Queries)
- $K = XW_k$ (Keys)
- $V = XW_v$ (Values)

This mechanism allows each position in a sequence to attend to all positions, creating rich contextual representations!