# Day 13: Attention Is All You Need - The Transformer

Interactive exploration of the architecture that changed everything.

**Paper:** [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017)

**In this notebook:**
1. Why Attention?
2. Scaled Dot-Product Attention
3. Why Scale by sqrt(d_k)?
4. Multi-Head Attention
5. Positional Encoding
6. Complete Transformer
7. The Causal Mask

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
print("Ready to explore Transformers!")

## Part 1: Why Attention?

### The Problem with RNNs

RNNs process sequences step-by-step:
- Sequential = slow training (can't parallelize across time steps)
- Long paths = vanishing gradients
- Fixed hidden state = memory bottleneck

### The Solution: Attention

Allow every position to directly attend to every other position!

In [None]:
# Visualize the difference

# RNN: sequential information flow
print("RNN Information Flow:")
print("word_1 -> word_2 -> word_3 -> ... -> word_n")
print("(Each word only sees previous context via hidden state)")

print("\nTransformer Information Flow:")
print("All words connected directly via attention!")

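# Simple adjacency matrix comparison (row = receiving position, column = source).
# The RNN matrix shows which positions can reach position i through the
# hidden-state chain; only adjacent steps are connected directly.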
seq_len = 5
rnn_flow = np.tril(np.ones((seq_len, seq_len)))
transformer_flow = np.ones((seq_len, seq_len))

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].imshow(rnn_flow, cmap='Blues')
axes[0].set_title('RNN: Sequential Flow', fontsize=12)
axes[0].set_xlabel('Position')
axes[0].set_ylabel('Position')

axes[1].imshow(transformer_flow, cmap='Blues')
axes[1].set_title('Attention: All-to-All', fontsize=12)
axes[1].set_xlabel('Position')
axes[1].set_ylabel('Position')

plt.tight_layout()
plt.show()

## Part 2: Scaled Dot-Product Attention

The core computation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

**Key insight:** Attention is a learnable, differentiable lookup!
- **Q (Query):** What am I looking for?
- **K (Key):** What do I contain?
- **V (Value):** What do I return if matched?
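
To make the "lookup" analogy concrete, here is a minimal sketch of a single query against three stored key/value pairs. The numbers are made up for illustration, and the values are scalars to keep the arithmetic easy to follow. The query points in the direction of key 1, so most (but not all) of the weight lands on value 1.

```python
import numpy as np

# Three stored items: each key describes an item, each value is what it returns.
# (Toy numbers chosen for illustration only.)
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [-1.0, 0.0]])
values = np.array([10.0, 20.0, 60.0])

# A query pointing in the same direction as key 1
query = np.array([0.0, 4.0])

# Similarity of the query to every key, scaled by sqrt(d_k)
scores = keys @ query / np.sqrt(keys.shape[-1])

# Softmax turns the scores into weights that sum to 1
weights = np.exp(scores - scores.max())
weights /= weights.sum()

# Soft lookup: a weighted average of all values instead of picking exactly one
result = weights @ values

print("weights:", weights.round(3))
print("result: ", round(float(result), 3))
```

A hard dictionary lookup would return exactly 20. The soft version returns something close to 20, pulled slightly toward the other values, and every step is differentiable.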

In [None]:
def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute attention.
    
    Args:
        Q: Queries (batch, seq_q, d_k)
        K: Keys (batch, seq_k, d_k)
        V: Values (batch, seq_k, d_v)
        mask: Optional boolean array, broadcastable to the score shape
            (batch, seq_q, seq_k). True marks positions to block.
    """
    d_k = K.shape[-1]
    
    # Step 1: Compute scores
    scores = np.matmul(Q, K.transpose(0, 2, 1))
    
    # Step 2: Scale
    scores = scores / np.sqrt(d_k)
    
    # Step 3: Mask (optional)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    
    # Step 4: Softmax
    weights = softmax(scores, axis=-1)
    
    # Step 5: Weighted sum
    output = np.matmul(weights, V)
    
    return output, weights

In [None]:
# Demo: Self-attention on a simple sequence

# Imagine 3 tokens with 4-dimensional representations
seq = np.array([[
    [1, 0, 1, 0],  # Token 0
    [0, 1, 0, 1],  # Token 1
    [1, 1, 0, 0],  # Token 2
]])  # Shape: (1, 3, 4)

# In a real model, Q, K, and V are learned linear projections of the input.
# Here we skip the projections and use the input directly.
Q = K = V = seq

output, weights = scaled_dot_product_attention(Q, K, V)

print("Input sequence:")
print(seq[0])

print("\nAttention weights (who attends to whom):")
print(weights[0].round(3))

print("\nOutput (weighted sum of values):")
print(output[0].round(3))

In [None]:
# Visualize attention weights

plt.figure(figsize=(6, 5))
plt.imshow(weights[0], cmap='Blues', vmin=0, vmax=1)
plt.colorbar(label='Attention Weight')

# Add values
for i in range(3):
    for j in range(3):
        plt.text(j, i, f'{weights[0, i, j]:.2f}', 
                ha='center', va='center', fontsize=12)

plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Self-Attention Weights')
plt.xticks([0, 1, 2], ['Token 0', 'Token 1', 'Token 2'])
plt.yticks([0, 1, 2], ['Token 0', 'Token 1', 'Token 2'])
plt.tight_layout()
plt.show()

## Part 3: Why Scale by sqrt(d_k)?

Without scaling, the magnitude of the dot products grows with $d_k$, pushing softmax into saturated regions where gradients are extremely small!
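
A short derivation shows where the $\sqrt{d_k}$ comes from. It assumes, as the paper does, that the components of $q$ and $k$ are independent with mean 0 and variance 1:

$$\mathbb{E}[q \cdot k] = \sum_{i=1}^{d_k} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \qquad \text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k$$

$$\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$$

So the raw scores have standard deviation $\sqrt{d_k}$ (about 22.6 for $d_k = 512$), while the scaled scores keep a standard deviation of 1 regardless of dimension.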

In [None]:
# Demonstrate the scaling problem

d_k_values = [8, 64, 512]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

for d_k in d_k_values:
    # Random vectors whose components have zero mean and unit variance
    q = np.random.randn(1000, d_k)
    k = np.random.randn(1000, d_k)
    
    # Dot products
    dots = (q * k).sum(axis=1)
    dots_scaled = dots / np.sqrt(d_k)
    
    axes[0].hist(dots, bins=50, alpha=0.5, label=f'd_k={d_k}', density=True)
    axes[1].hist(dots_scaled, bins=50, alpha=0.5, label=f'd_k={d_k}', density=True)

axes[0].set_title('Without Scaling: Variance grows with d_k', fontsize=11)
axes[0].set_xlabel('Dot Product Value')
axes[0].legend()

axes[1].set_title('With Scaling: Variance is stable', fontsize=11)
axes[1].set_xlabel('Scaled Dot Product Value')
axes[1].legend()

plt.tight_layout()
plt.show()

print("Without scaling, large d_k -> large dot products -> extreme softmax!")
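
The histograms show the spread of the scores. The sketch below looks at what that spread does to softmax itself. It is a small experiment under assumed settings (10 keys, 200 random queries, `d_k = 512`, all chosen arbitrarily) that compares the average largest attention weight with and without scaling.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d_k, n_keys, n_queries = 512, 10, 200

q = rng.standard_normal((n_queries, d_k))   # independent random queries
K = rng.standard_normal((n_keys, d_k))      # keys shared by all queries

raw_scores = q @ K.T                        # shape (n_queries, n_keys)

# Average of the largest attention weight in each row
peak_unscaled = softmax(raw_scores).max(axis=-1).mean()
peak_scaled = softmax(raw_scores / np.sqrt(d_k)).max(axis=-1).mean()

print(f"Mean largest weight without scaling: {peak_unscaled:.3f}")
print(f"Mean largest weight with scaling:    {peak_scaled:.3f}")
```

Without scaling, almost all of the weight collapses onto a single key even though the vectors are random. With scaling, the weights stay spread out, which keeps gradients flowing to all keys.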

## Part 4: Multi-Head Attention

A single attention head computes one attention distribution per query, so it captures one kind of relationship at a time. But language has MANY kinds!

Solution: Run multiple attention heads in parallel, each focusing on different patterns.

In [None]:
class MultiHeadAttention:
    def __init__(self, d_model, n_heads):
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Projection matrices
        scale = np.sqrt(2.0 / d_model)
        self.W_Q = np.random.randn(d_model, d_model) * scale
        self.W_K = np.random.randn(d_model, d_model) * scale
        self.W_V = np.random.randn(d_model, d_model) * scale
        self.W_O = np.random.randn(d_model, d_model) * scale
    
    def forward(self, query, key, value, mask=None):
        batch = query.shape[0]
        seq_q, seq_k = query.shape[1], key.shape[1]
        
        # Project
        Q = query @ self.W_Q
        K = key @ self.W_K
        V = value @ self.W_V
        
        # Reshape to (batch, n_heads, seq, d_k)
        Q = Q.reshape(batch, seq_q, self.n_heads, self.d_k).transpose(0, 2, 1, 3)
        K = K.reshape(batch, seq_k, self.n_heads, self.d_k).transpose(0, 2, 1, 3)
        V = V.reshape(batch, seq_k, self.n_heads, self.d_k).transpose(0, 2, 1, 3)
        
        # Attention per head
        scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(self.d_k)
        if mask is not None:
            scores = np.where(mask, -1e9, scores)
        self.weights = softmax(scores, axis=-1)
        attn = self.weights @ V
        
        # Concatenate heads
        attn = attn.transpose(0, 2, 1, 3).reshape(batch, seq_q, self.d_model)
        
        return attn @ self.W_O

# Test
mha = MultiHeadAttention(d_model=64, n_heads=8)
x = np.random.randn(1, 5, 64)  # 5 tokens, 64 dims
out = mha.forward(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weights shape: {mha.weights.shape} (batch, heads, seq, seq)")
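
The reshape and transpose in `forward` can look like magic. The sketch below (shapes chosen to match the test above, random data) checks that splitting into heads is just slicing the feature dimension into `n_heads` contiguous chunks of size `d_k`, and that the "concatenate heads" step undoes it exactly.

```python
import numpy as np

batch, seq, n_heads, d_k = 1, 5, 8, 8
d_model = n_heads * d_k

rng = np.random.default_rng(0)
Q = rng.standard_normal((batch, seq, d_model))   # stands in for a projected query

# Split: (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
heads = Q.reshape(batch, seq, n_heads, d_k).transpose(0, 2, 1, 3)

# Head h is exactly columns [h*d_k, (h+1)*d_k) of the projected matrix
for h in range(n_heads):
    assert np.array_equal(heads[:, h], Q[:, :, h * d_k:(h + 1) * d_k])

# Merge: (batch, n_heads, seq, d_k) -> (batch, seq, d_model)
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)

print("heads shape:", heads.shape)
print("round trip recovers Q:", np.array_equal(merged, Q))
```

Because each head only sees its own `d_k` columns, the total compute is similar to a single head with the full `d_model` dimension.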

In [None]:
# Visualize different heads (weights are randomly initialized here, not trained)

fig, axes = plt.subplots(2, 4, figsize=(14, 6))

for i, ax in enumerate(axes.flat):
    ax.imshow(mha.weights[0, i], cmap='Blues', vmin=0, vmax=1)
    ax.set_title(f'Head {i}')
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

plt.suptitle('Multi-Head Attention: Each Head Has Its Own Pattern (Random Init)', fontsize=14)
plt.tight_layout()
plt.show()

## Part 5: Positional Encoding

Without recurrence, the model has no sense of position!

"cat sat mat" and "mat cat sat" look the same to pure attention: reordering the input tokens just reorders the outputs in the same way.

Solution: Add sinusoidal position information.
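
Before adding positions, it helps to see the problem in code. This is a minimal sketch using unprojected single-head self-attention (Q = K = V = input) on random stand-in embeddings. It shows that shuffling the tokens only shuffles the output rows, so the model cannot tell the two orderings apart.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x):
    """Single-head self-attention without projections (Q = K = V = x)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    return softmax(scores) @ x

rng = np.random.default_rng(0)
tokens = rng.standard_normal((3, 4))   # stand-ins for "cat", "sat", "mat"
perm = np.array([2, 0, 1])             # reorder to "mat", "cat", "sat"

out = self_attention(tokens)
out_perm = self_attention(tokens[perm])

# The output for each token is identical; only the row order changed
print("Permuted output == permuted original output:",
      np.allclose(out_perm, out[perm]))
```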

In [None]:
def create_positional_encoding(max_len, d_model):
    """Create sinusoidal positional encoding."""
    pe = np.zeros((max_len, d_model))
    position = np.arange(max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    
    return pe

# Create and visualize
pe = create_positional_encoding(100, 64)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap
im = axes[0].imshow(pe.T, cmap='RdBu', aspect='auto', vmin=-1, vmax=1)
axes[0].set_xlabel('Position')
axes[0].set_ylabel('Dimension')
axes[0].set_title('Positional Encoding Heatmap')
plt.colorbar(im, ax=axes[0])

# Curves
for dim in range(0, 8, 2):
    axes[1].plot(pe[:50, dim], label=f'dim {dim}')
axes[1].set_xlabel('Position')
axes[1].set_ylabel('Encoding Value')
axes[1].set_title('Sinusoidal Curves at Different Dimensions')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Low dimensions = fast waves (fine-grained, local position)")
print("High dimensions = slow waves (coarse, global position)")
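
For reference, `create_positional_encoding` follows the formula from the paper, where $pos$ is the position and $i$ indexes the sine/cosine pairs:

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

The `div_term` in the code is $10000^{-2i/d_{\text{model}}}$, computed as $\exp\left(-2i \cdot \ln(10000) / d_{\text{model}}\right)$. The wavelength of pair $i$ is $2\pi \cdot 10000^{2i/d_{\text{model}}}$, so wavelengths grow geometrically from $2\pi$ toward $10000 \cdot 2\pi$ as the dimension index increases.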

In [None]:
# Position similarity: nearby positions should be similar

ref_positions = [0, 10, 25, 50]

plt.figure(figsize=(10, 5))

for ref in ref_positions:
    similarities = [np.dot(pe[ref], pe[i]) for i in range(100)]
    plt.plot(similarities, label=f'Position {ref}')

plt.xlabel('Position')
plt.ylabel('Similarity (dot product)')
plt.title('Position Similarity: Nearby Positions Are More Similar')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
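
The curves above all have the same shape, just shifted. That is no accident: since $\sin a \sin b + \cos a \cos b = \cos(a - b)$, the dot product of two encodings depends only on the distance between the positions. The sketch below re-creates the encoding so it runs on its own and checks this numerically for an arbitrary offset of 5.

```python
import numpy as np

def create_positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding (same construction as above)."""
    pe = np.zeros((max_len, d_model))
    position = np.arange(max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

pe = create_positional_encoding(100, 64)

# Same offset, different starting positions
offset = 5
sims = [pe[p] @ pe[p + offset] for p in (0, 10, 25, 50)]
print("Similarity at offset 5 from positions 0, 10, 25, 50:")
print(np.round(sims, 4))

# Self-similarity is sin^2 + cos^2 summed over d_model / 2 frequency pairs
self_sim = pe[7] @ pe[7]
print("Self-similarity:", round(float(self_sim), 4))
```

All four similarities come out equal, and the self-similarity is `d_model / 2 = 32` at every position.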

## Part 6: Putting It Together - Complete Transformer

Let's see the full architecture in action!

In [None]:
# Import our implementation
import sys
sys.path.insert(0, '.')

from implementation import Transformer, create_causal_mask

# Create a small Transformer
transformer = Transformer(
    src_vocab_size=50,
    tgt_vocab_size=50,
    d_model=64,
    n_heads=4,
    n_encoder_layers=2,
    n_decoder_layers=2,
    d_ff=256,
    dropout_p=0.0  # No dropout for demo
)
transformer.eval()

print("Transformer created!")
print("  d_model: 64")
print("  n_heads: 4")
print("  n_layers: 2 encoder, 2 decoder")

In [None]:
# Forward pass demo

# Source: "tokens" 1-5
src = np.array([[1, 2, 3, 4, 5]])

# Target: start token + first 3 tokens
tgt = np.array([[0, 1, 2, 3]])

# Create causal mask
tgt_mask = create_causal_mask(tgt.shape[1])

# Forward pass
logits = transformer.forward(src, tgt, tgt_mask=tgt_mask)

print(f"Source tokens: {src[0]}")
print(f"Target tokens: {tgt[0]}")
print(f"Output logits shape: {logits.shape}")

# Predictions
predictions = logits.argmax(axis=-1)
print(f"Predicted next tokens: {predictions[0]}")

In [None]:
# Greedy decoding

def greedy_decode(model, src, max_len=10, start_token=0):
    """Generate output sequence greedily."""
    model.eval()
    
    # Start with start token
    tgt = np.array([[start_token]])
    
    for _ in range(max_len - 1):
        tgt_mask = create_causal_mask(tgt.shape[1])
        logits = model.forward(src, tgt, tgt_mask=tgt_mask)
        
        # Get next token
        next_token = logits[:, -1, :].argmax(axis=-1, keepdims=True)
        tgt = np.concatenate([tgt, next_token], axis=1)
    
    return tgt

# Generate
src = np.array([[1, 2, 3, 4, 5]])
output = greedy_decode(transformer, src, max_len=5)

print(f"Input: {src[0]}")
print(f"Generated: {output[0]}")
print("\n(Random weights, so output is random - but the architecture works!)")
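
`greedy_decode` always generates `max_len` tokens. In practice, decoding usually stops once an end-of-sequence token appears. The sketch below shows one way that could look. Everything in it is hypothetical: `ToyModel` is a stand-in that simply predicts "last token + 1" so the example runs without `implementation.py`, and `end_token=3` is an arbitrary choice. It also assumes a batch of one sequence and omits the causal mask because the toy model does not use it.

```python
import numpy as np

class ToyModel:
    """Hypothetical stand-in model: always predicts (token + 1) mod vocab_size."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def forward(self, src, tgt, tgt_mask=None):
        batch, seq = tgt.shape
        logits = np.zeros((batch, seq, self.vocab_size))
        next_ids = (tgt + 1) % self.vocab_size
        # Put a 1.0 at the predicted token for every (batch, position) pair
        logits[np.arange(batch)[:, None], np.arange(seq)[None, :], next_ids] = 1.0
        return logits

def greedy_decode_with_eos(model, src, max_len=10, start_token=0, end_token=3):
    """Greedy decoding that stops early at end_token (single-sequence batch)."""
    tgt = np.array([[start_token]])
    for _ in range(max_len - 1):
        logits = model.forward(src, tgt)
        next_token = logits[:, -1, :].argmax(axis=-1, keepdims=True)
        tgt = np.concatenate([tgt, next_token], axis=1)
        if next_token[0, 0] == end_token:
            break  # stop as soon as the end token is produced
    return tgt

out = greedy_decode_with_eos(ToyModel(vocab_size=10), np.array([[1, 2, 3]]))
print("Generated:", out[0])
```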

## Part 7: The Causal Mask

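The decoder must not look at future tokens! During training the whole target sequence is fed in at once, so a mask blocks each position from attending to later positions.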

In [None]:
# Visualize causal mask

seq_len = 6
mask = create_causal_mask(seq_len)

plt.figure(figsize=(6, 5))
plt.imshow(~mask[0, 0], cmap='Greens')  # Show where attention IS allowed

for i in range(seq_len):
    for j in range(seq_len):
        text = 'OK' if not mask[0, 0, i, j] else 'X'
        color = 'white' if not mask[0, 0, i, j] else 'red'
        plt.text(j, i, text, ha='center', va='center', fontsize=10, color=color)

plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Causal Mask: Decoder Can Only See Past')
plt.xticks(range(seq_len))
plt.yticks(range(seq_len))
plt.tight_layout()
plt.show()

print("Each position can only attend to itself and previous positions!")
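
`create_causal_mask` lives in `implementation.py`, which is not shown here. The sketch below is a guess at an equivalent, named `make_causal_mask` to make clear it is hypothetical. It assumes the convention the plotting code implies: a boolean array of shape `(1, 1, seq, seq)` where `True` means "blocked". It then applies the mask to uniform scores to show the effect on the attention weights.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def make_causal_mask(seq_len):
    """Hypothetical stand-in for create_causal_mask.

    Returns a boolean array of shape (1, 1, seq_len, seq_len) where True
    marks future positions (key index > query index) that must be blocked.
    """
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    return future[np.newaxis, np.newaxis, :, :]

seq_len = 4
mask = make_causal_mask(seq_len)

# Uniform scores, so any difference in the weights comes from the mask alone
scores = np.zeros((1, 1, seq_len, seq_len))
masked_scores = np.where(mask, -1e9, scores)
weights = softmax(masked_scores, axis=-1)

print(weights[0, 0].round(2))
```

Row `i` spreads its weight evenly over positions `0..i` and gives zero weight to everything after it.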

## Summary

### Key Components

1. **Scaled Dot-Product Attention**: The core mechanism - Q, K, V with sqrt(d_k) scaling

2. **Multi-Head Attention**: Parallel attention heads for different patterns

3. **Positional Encoding**: Sinusoidal position information

4. **Encoder Block**: Self-attention + FFN + residuals + layer norm

5. **Decoder Block**: Masked self-attention + cross-attention + FFN + residuals + layer norm
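
To show how the encoder-block pieces fit together, here is a minimal NumPy sketch. It is not the code from `implementation.py`: the function names and the parameter dictionary are made up for illustration. It uses a single attention head, skips dropout and the learned gain/bias in layer norm, and follows the paper's post-norm ordering, `LayerNorm(x + Sublayer(x))`.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    """Normalize each token vector to zero mean, unit variance (no learned gain/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def self_attention(x, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V

def feed_forward(x, W1, b1, W2, b2):
    """Position-wise FFN: linear -> ReLU -> linear."""
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

def encoder_block(x, p):
    """Self-attention and FFN, each wrapped in a residual connection + layer norm."""
    x = layer_norm(x + self_attention(x, p["W_q"], p["W_k"], p["W_v"]))
    x = layer_norm(x + feed_forward(x, p["W1"], p["b1"], p["W2"], p["b2"]))
    return x

# Random parameters, just to run a forward pass
rng = np.random.default_rng(0)
d_model, d_ff = 16, 32
params = {
    "W_q": rng.standard_normal((d_model, d_model)) * 0.1,
    "W_k": rng.standard_normal((d_model, d_model)) * 0.1,
    "W_v": rng.standard_normal((d_model, d_model)) * 0.1,
    "W1": rng.standard_normal((d_model, d_ff)) * 0.1,
    "b1": np.zeros(d_ff),
    "W2": rng.standard_normal((d_ff, d_model)) * 0.1,
    "b2": np.zeros(d_model),
}

x = rng.standard_normal((1, 5, d_model))   # (batch, seq, d_model)
out = encoder_block(x, params)
print("Input shape: ", x.shape)
print("Output shape:", out.shape)
```

The output has the same shape as the input, which is what lets encoder blocks be stacked. A decoder block would add a masked self-attention step and a cross-attention step over the encoder output, wrapped in the same residual + layer norm pattern.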

### Why Transformers Won

- Parallelizable (fast training on GPUs)
- Direct long-range connections
- Flexible attention patterns
- Scale well with data and compute

This architecture, or attention blocks derived from it, powers: BERT, GPT, T5, ViT, DALL-E, Stable Diffusion, and more!

In [None]:
print("Day 13 Complete!")
print("\nYou now understand:")
print("  - How attention replaces recurrence")
print("  - Why scaling by sqrt(d_k) is crucial")
print("  - How multi-head attention captures multiple patterns")
print("  - Why positional encoding is needed")
print("  - The complete Transformer architecture")
print("\nNext steps:")
print("  - Try the exercises in exercises/")
print("  - Train on a real task with train_minimal.py")
print("  - Explore pre-trained models (BERT, GPT, etc.)")