# Week 2 Day 10: Minimal Transformer Implementation - Part 1

## Overview
This notebook implements the core components of a minimal transformer from scratch, focusing on:
- Minimal attention mechanism
- Feed-forward networks
- Layer normalization and residual connections
- Complete transformer block

## Learning Objectives
By the end of this notebook, you will:
1. Understand how to implement attention from scratch
2. Build feed-forward networks for transformers
3. Combine components into a complete transformer block
4. Visualize attention patterns and understand their meaning

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from typing import Optional

# Seed both random number sources (NumPy and PyTorch) so reruns give the same numbers.
np.random.seed(42)
torch.manual_seed(42)

# Pick the compute device: CUDA when PyTorch reports one available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Minimal Attention Implementation

Let's start by implementing the core attention mechanism from scratch.

In [None]:
class MinimalAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of width ``d_model // num_heads``, attended per head,
    and the concatenated heads are passed through an output projection.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V (no bias) and the output projection
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.
            mask: Optional tensor that broadcasts against the score tensor of
                shape ``(batch_size, num_heads, seq_len, seq_len)``. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(batch_size, seq_len, d_model)`` and ``attention_weights`` has
            shape ``(batch_size, num_heads, seq_len, seq_len)``.
        """
        batch_size, seq_len, d_model = x.shape

        # Generate Q, K, V: each (batch_size, seq_len, d_model)
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_v(x)

        # Split the feature dimension into heads and move the head axis
        # forward: (batch_size, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        attention_output, attention_weights = self.scaled_dot_product_attention(
            Q, K, V, mask
        )

        # Concatenate heads; transpose() leaves the tensor non-contiguous, so
        # contiguous() is required before view()
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model
        )

        # Final linear projection
        output = self.w_o(attention_output)

        return output, attention_weights

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``dropout(softmax(Q K^T / sqrt(d_k))) V``.

        Args:
            Q, K, V: Tensors whose last two dimensions are ``(seq_len, d_k)``.
            mask: Optional tensor broadcastable to the score tensor; entries
                equal to 0 mark positions that must not be attended to.

        Returns:
            A tuple ``(output, attention_weights)``. The returned weights have
            dropout already applied, so in training mode a row may contain
            zeroed or rescaled entries and need not sum to 1.
        """
        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided
        if mask is not None:
            # Fill with the most negative finite value of the score dtype. A
            # hard-coded -1e9 is not representable in float16, so it breaks
            # masking under half precision.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        # Normalise over the key dimension
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

## 2. Feed-Forward Network

The feed-forward network applies the same two-layer transformation independently at every sequence position.

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied along the last dimension of its input.

    Computes ``linear2(dropout(gelu(linear1(x))))``: an expansion from
    ``d_model`` to ``d_ff`` features, a GELU non-linearity followed by
    dropout, and a projection back to ``d_model`` features.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Expansion layer first, contraction layer second.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand, activate and regularise in one expression ...
        hidden = self.dropout(F.gelu(self.linear1(x)))
        # ... then project back to the model width.
        return self.linear2(hidden)

## 3. Complete Transformer Block

Now let's combine attention and feed-forward into a complete transformer block.

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention, then a feed-forward network.

    Each sub-layer receives a layer-normalised copy of its input, and its
    output (after dropout) is added back onto the un-normalised input.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # The two sub-layers
        self.attention = MinimalAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # One LayerNorm in front of each sub-layer, plus the dropout that is
        # applied to both sub-layer outputs
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Return ``(output, attention_weights)`` for input ``x`` and optional ``mask``."""
        # Attention sub-layer: normalise, attend, add back onto the input
        attended, attn_weights = self.attention(self.norm1(x), mask)
        after_attention = x + self.dropout(attended)

        # Feed-forward sub-layer: normalise, transform, add back
        transformed = self.feed_forward(self.norm2(after_attention))
        result = after_attention + self.dropout(transformed)

        return result, attn_weights

## 4. Testing the Components

Let's test our transformer components with sample data.

In [None]:
# Test parameters
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8
d_ff = 2048

# Create sample input
x = torch.randn(batch_size, seq_len, d_model)
print(f"Input shape: {x.shape}")

# Build the modules and switch them to eval mode. In training mode dropout is
# active, and the attention weights returned by the block would contain
# randomly zeroed and rescaled entries. These weights are plotted later in the
# notebook, so they should be the plain softmax output.
attention = MinimalAttention(d_model, num_heads).eval()
ff = FeedForward(d_model, d_ff).eval()
transformer_block = TransformerBlock(d_model, num_heads, d_ff).eval()

# These are shape checks only, so no autograd graph is needed.
with torch.no_grad():
    # Test attention
    attn_output, attn_weights = attention(x)
    print(f"Attention output shape: {attn_output.shape}")
    print(f"Attention weights shape: {attn_weights.shape}")

    # Test feed-forward
    ff_output = ff(x)
    print(f"Feed-forward output shape: {ff_output.shape}")

    # Test complete transformer block
    block_output, block_attn_weights = transformer_block(x)
    print(f"Transformer block output shape: {block_output.shape}")

## 5. Attention Visualization

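Let's visualize the attention patterns. The block has not been trained yet, so these heatmaps show what randomly initialized attention looks like rather than anything the model has learned.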

In [None]:
def visualize_attention(attention_weights, head_idx=0, batch_idx=0):
    """Plot one head's attention matrix as an annotated heatmap.

    Args:
        attention_weights: Tensor indexed as ``[batch, head, query, key]``.
        head_idx: Index of the head to plot.
        batch_idx: Index of the batch element to plot.
    """
    # Select the 2-D matrix for the requested batch element and head. The
    # tensor is moved to the CPU first because Tensor.numpy() only works on
    # CPU tensors; for tensors already on the CPU this is a no-op.
    attn = attention_weights[batch_idx, head_idx].detach().cpu().numpy()

    plt.figure(figsize=(10, 8))
    # Rows are query positions, columns are key positions
    sns.heatmap(attn, annot=True, fmt='.3f', cmap='Blues',
                xticklabels=range(attn.shape[1]),
                yticklabels=range(attn.shape[0]))
    plt.title(f'Attention Weights - Head {head_idx}, Batch {batch_idx}')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.tight_layout()
    plt.show()

# Plot the first two heads of the first batch element so their patterns can be compared.
visualize_attention(block_attn_weights, batch_idx=0, head_idx=0)
visualize_attention(block_attn_weights, batch_idx=0, head_idx=1)

## 6. Gradient Flow Analysis

Let's analyze how gradients flow through our transformer block.

In [None]:
def analyze_gradient_flow(model, x):
    """Run one forward/backward pass and report each parameter's gradient norm.

    Args:
        model: Module whose call returns a 2-tuple; the first element is the
            tensor the loss is built from.
        x: Input passed to ``model``.

    Returns:
        Dict mapping parameter name to the norm (``Tensor.norm()`` with its
        default settings, as a Python float) of that parameter's gradient.
        Parameters that have no gradient after the backward pass are omitted.
    """
    # Clear gradients left over from earlier backward passes. backward()
    # accumulates into .grad, so without this a repeated call would report
    # the sum of several passes instead of this one.
    model.zero_grad()

    # Forward pass
    output, _ = model(x)

    # Create a simple loss (sum of all outputs) and back-propagate it
    loss = output.sum()
    loss.backward()

    # Collect the gradient norm of every parameter that received a gradient
    return {
        name: param.grad.norm().item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }

# Build a fresh block and a small random input, then measure per-parameter gradient norms
x_grad = torch.randn((1, 5, d_model), requires_grad=True)
transformer_grad = TransformerBlock(d_model, num_heads, d_ff)
gradients = analyze_gradient_flow(transformer_grad, x_grad)

# Bar chart: one bar per parameter, in the order the parameters were reported
names = list(gradients)
values = [gradients[key] for key in names]

plt.figure(figsize=(12, 6))
plt.bar(range(len(names)), values)
plt.title('Gradient Flow Through Transformer Block')
plt.ylabel('Gradient Norm')
plt.xticks(range(len(names)), names, rotation=45, ha='right')
plt.tight_layout()
plt.show()

# Same numbers as text
print("Gradient norms:")
for name, norm in gradients.items():
    print(f"{name}: {norm:.6f}")

## 7. Parameter Initialization Analysis

Let's examine the importance of proper parameter initialization.

In [None]:
def initialize_weights(model, init_type='xavier'):
    """Re-initialise a model's parameters in place.

    Parameters whose name contains ``'weight'`` and that have more than one
    dimension are filled using the chosen strategy. Any other parameter whose
    name contains ``'bias'`` is set to zero. All remaining parameters (for
    example one-dimensional weights) are left untouched.

    Args:
        model: Module exposing ``named_parameters()``.
        init_type: One of ``'xavier'`` (``xavier_uniform_``), ``'kaiming'``
            (``kaiming_uniform_``) or ``'normal'`` (mean 0, std 0.02).

    Raises:
        ValueError: If ``init_type`` is not a known strategy. The check runs
            before any parameter is modified.
    """
    initializers = {
        'xavier': nn.init.xavier_uniform_,
        'kaiming': nn.init.kaiming_uniform_,
        'normal': lambda tensor: nn.init.normal_(tensor, 0, 0.02),
    }
    # Fail loudly on a typo instead of silently keeping the default weights
    # while still zeroing the biases.
    try:
        init_fn = initializers[init_type]
    except KeyError:
        raise ValueError(
            f"Unknown init_type {init_type!r}; expected one of {sorted(initializers)}"
        ) from None

    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() > 1:
            init_fn(param)
        elif 'bias' in name:
            nn.init.zeros_(param)

# Test different initialization strategies
init_strategies = ['xavier', 'kaiming', 'normal']
results = {}

# Use one shared input for every strategy so that differences in the output
# statistics come from the initialization and not from a different random
# sample.
x_test = torch.randn(1, 10, d_model)

for init_type in init_strategies:
    model = TransformerBlock(d_model, num_heads, d_ff)
    initialize_weights(model, init_type)
    # Disable dropout; otherwise its random masks add noise to the statistics.
    model.eval()

    # Forward pass (statistics only, so no autograd graph is needed)
    with torch.no_grad():
        output, _ = model(x_test)

    # Calculate statistics
    results[init_type] = {
        'output_mean': output.mean().item(),
        'output_std': output.std().item(),
        'output_range': (output.min().item(), output.max().item())
    }

# Display results
print("Initialization Strategy Comparison:")
print("-" * 50)
for init_type, stats in results.items():
    print(f"{init_type.upper()}:")
    print(f"  Mean: {stats['output_mean']:.6f}")
    print(f"  Std:  {stats['output_std']:.6f}")
    print(f"  Range: ({stats['output_range'][0]:.3f}, {stats['output_range'][1]:.3f})")
    print()

## 8. Summary and Key Insights

In this notebook, we've implemented and analyzed the core components of a transformer:

### Key Components:
1. **Minimal Attention**: Scaled dot-product attention with multi-head support
2. **Feed-Forward Network**: Position-wise transformations with GELU activation
3. **Transformer Block**: Complete block with residual connections and layer normalization

### Key Insights:
1. **Attention Patterns**: Attention weights show how tokens relate to each other
2. **Gradient Flow**: Residual connections help maintain healthy gradient flow
3. **Initialization**: Proper weight initialization is crucial for training stability

### Next Steps:
In Part 2, we'll:
- Build a complete minimal transformer model
- Implement training and evaluation loops
- Test on simple language modeling tasks
- Analyze performance and scaling properties