# Chapter 3: Memory Optimization Techniques

## 🎯 Learning Objectives

By the end of this chapter, you will:
- **Master gradient checkpointing** theory and implementation
- **Understand memory-compute tradeoffs** in transformer training
- **Implement activation recomputation** strategies
- **Optimize memory allocation patterns** for maximum efficiency
- **Build dynamic memory management** systems for variable-length sequences

---

## 🧠 The Memory Crisis in LLM Training

### **Why Memory Optimization Matters**

Modern LLMs face a **memory wall** that fundamentally limits training:

#### **Memory Requirements Add Up Fast**
- **Model Parameters**: 7B model = ~28GB (FP32) or ~14GB (FP16)
- **Optimizer States**: Adam keeps two values per parameter (momentum + variance) = ~28GB additional in FP16, ~56GB in FP32
- **Gradients**: Same size as parameters = ~14GB additional (FP16)
- **Activations**: Grow with sequence length and batch size = 10-100GB+
- **Total**: ~66-156GB for 7B model training (with FP16 optimizer states)!

#### **Hardware Constraints**
- **Consumer GPUs**: RTX 4090 = 24GB VRAM
- **Professional GPUs**: A100 = 40GB or 80GB VRAM
- **High-end**: H100 = 80GB VRAM

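**Result**: Even a "small" 7B model exceeds a 24GB consumer GPU several times over and, except in the very best case, an 80GB data-center GPU too. Without optimization, it won't fit on a single GPU!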
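The totals above are simple bytes-per-parameter arithmetic. A minimal sketch, assuming the 7B parameter count and the 10-100GB activation range from the list above; the helper and dictionary names are hypothetical. The second recipe (FP16 weights and gradients plus FP32 master weights, momentum, and variance, i.e. 16 bytes per parameter) is the common mixed-precision Adam setup, not the all-FP16 accounting used in the table:

```python
# Hypothetical helper: static training state (bytes/param) plus activations, in GB (1e9 bytes)
def training_memory_gb(num_params: float, bytes_per_param: int, activation_gb: float) -> float:
    return num_params * bytes_per_param / 1e9 + activation_gb

RECIPES = {
    "all FP16 (as in the table)": 2 + 2 + 4,          # weights + grads + Adam m/v in FP16
    "mixed precision, FP32 Adam": 2 + 2 + 4 + 4 + 4,  # + FP32 master weights, m, v
}
GPU_CAPACITY_GB = {"RTX 4090": 24, "A100 80GB": 80, "H100": 80}

for recipe, bytes_per_param in RECIPES.items():
    for activation_gb in (10, 100):
        total = training_memory_gb(7e9, bytes_per_param, activation_gb)
        fits = [gpu for gpu, cap in GPU_CAPACITY_GB.items() if total <= cap]
        print(f"{recipe:28s} activations={activation_gb:3d} GB -> "
              f"{total:5.0f} GB, fits on: {', '.join(fits) or 'nothing'}")
```

Only the all-FP16 case with small activations (66GB) squeezes onto an 80GB card, and that is before counting the CUDA context, allocator fragmentation, and temporary buffers.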

---

## 📊 Memory Breakdown in Transformer Training

### **Memory Components (7B Parameter Model)**

```
┌─────────────────────────────────────────────────┐
│                  Model Memory                   │
├─────────────────┬───────────────────────────────┤
│ Parameters      │ 14GB (FP16)                   │
│ Gradients       │ 14GB (FP16)                   │
│ Optimizer States│ 28GB (Adam: momentum + var)   │
│ Activations     │ 10-100GB (depends on batch)   │
├─────────────────┼───────────────────────────────┤
│ TOTAL           │ 66-156GB                      │
└─────────────────┴───────────────────────────────┘
```

### **Activation Memory Deep Dive**

**Activations are the memory killer** in transformer training:

#### **Per-Layer Activations (Typical)**
- **Input embeddings**: `[batch, seq_len, hidden_dim]`
- **Attention QKV**: `3 × [batch, seq_len, hidden_dim]`
- **Attention scores**: `[batch, num_heads, seq_len, seq_len]` ← **Quadratic!**
- **Attention output**: `[batch, seq_len, hidden_dim]`
- **FFN intermediate**: `[batch, seq_len, 4 × hidden_dim]` ← **4x expansion!**
- **Layer output**: `[batch, seq_len, hidden_dim]`
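Summing the shapes in the list above gives a quick per-layer estimate. A minimal sketch, assuming FP16 activations and a 7B-class configuration (`hidden_dim=4096`, 32 heads, batch 8, sequence length 2048; these values are assumptions, not from the text). It counts only the listed tensors; real training also keeps LayerNorm inputs, softmax outputs, and dropout masks, so measured usage is higher:

```python
def per_layer_activation_bytes(batch, seq_len, hidden_dim, num_heads,
                               bytes_per_element=2, ffn_mult=4):
    """Bytes for each activation tensor listed above, for one transformer layer."""
    bsh = batch * seq_len * hidden_dim  # elements in one [batch, seq_len, hidden_dim] tensor
    num_elements = {
        "layer_input": bsh,
        "qkv": 3 * bsh,
        "attention_scores": batch * num_heads * seq_len * seq_len,  # quadratic term
        "attention_output": bsh,
        "ffn_intermediate": ffn_mult * bsh,                          # 4x expansion
        "layer_output": bsh,
    }
    return {name: n * bytes_per_element for name, n in num_elements.items()}

sizes = per_layer_activation_bytes(batch=8, seq_len=2048, hidden_dim=4096, num_heads=32)
for name, nbytes in sizes.items():
    print(f"{name:18s} {nbytes / 2**20:8.0f} MiB")
print(f"{'total':18s} {sum(sizes.values()) / 2**20:8.0f} MiB")
```

At this configuration the attention-score matrix alone accounts for 2 GiB of the 3.25 GiB per layer, which is why the quadratic term dominates as sequences grow.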

#### **Memory Scaling**
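```python
# Attention-score memory scales quadratically with sequence length
def attention_memory_bytes(batch_size, num_heads, seq_len, bytes_per_element):
    return batch_size * num_heads * seq_len ** 2 * bytes_per_element

# Example: GPT-3 scale
# batch=8, heads=96, seq_len=2048, FP16=2 bytes
# = 8 × 96 × 2048² × 2 ≈ 6.4 GB per layer, for the score matrix alone!
print(f"{attention_memory_bytes(8, 96, 2048, 2) / 1e9:.1f} GB")  # 6.4 GB
```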

---

## 🔄 Gradient Checkpointing: Theory and Practice

### **The Core Insight**

**Trade computation for memory** by recomputing activations during the backward pass:

#### **Standard Training (Memory Expensive)**
```
Forward Pass:  Store all activations
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ L1  │ │ L2  │ │ L3  │ │ L4  │
└──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘
   │      │      │      │
   ▼      ▼      ▼      ▼
  💾     💾     💾     💾   ← All stored

Backward Pass: Use stored activations
   ▲      ▲      ▲      ▲
   │      │      │      │
   💾     💾     💾     💾   ← Retrieved from memory
```

#### **Gradient Checkpointing (Compute Intensive)**
```
Forward Pass:  Store only checkpoints
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ L1  │ │ L2  │ │ L3  │ │ L4  │
└──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘
   │              │      
   ▼              ▼      
  💾              💾       ← Only some stored

Backward Pass: Recompute missing activations
   ▲      🔄      ▲      🔄
   │    (recomp)  │    (recomp)
   💾             💾          ← Compute on demand
```
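The schedule in the diagram can be simulated without a deep-learning framework. A toy sketch, assuming each "layer" is a plain Python function on a scalar and that a segment's stored input is enough to recompute everything inside it; `checkpointed_pass` is a hypothetical name, and this is not how `torch.utils.checkpoint` is implemented internally. It counts layer-forward calls and the peak number of activations held at once:

```python
from typing import Callable, Dict, List, Tuple

def checkpointed_pass(layers: List[Callable[[float], float]], x: float,
                      segment_size: int) -> Tuple[float, int, int]:
    """Return (output, total layer-forward calls, peak number of stored activations)."""
    forward_calls = 0
    checkpoints: Dict[int, float] = {}  # segment start index -> that segment's input

    # Forward pass: run every layer, but keep only each segment's input
    for i, layer in enumerate(layers):
        if i % segment_size == 0:
            checkpoints[i] = x
        x = layer(x)
        forward_calls += 1
    peak_stored = len(checkpoints)

    # Backward pass: visit segments last-to-first and recompute their activations
    for start in sorted(checkpoints, reverse=True):
        h = checkpoints[start]
        segment_activations = []
        for layer in layers[start:start + segment_size]:
            segment_activations.append(h)  # this layer's input, needed for its gradient
            h = layer(h)
            forward_calls += 1
        peak_stored = max(peak_stored, len(checkpoints) + len(segment_activations))
        # A real backward pass would consume segment_activations here, then free them.
    return x, forward_calls, peak_stored

layers = [lambda v: v + 1.0] * 16
out, calls, peak = checkpointed_pass(layers, 0.0, segment_size=4)
print(out, calls, peak)  # 16.0 32 8
```

Sixteen layers normally store 16 activations after 16 forward calls; with segments of 4, the peak drops to 8 stored values at the cost of running every layer's forward twice.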

### **Mathematical Framework**

#### **Memory-Compute Tradeoff**
- **Memory Reduction**: `O(√n)` instead of `O(n)` for n layers
- **Compute Overhead**: ~33% additional FLOPs
- **Net Benefit**: Enable much larger models/batches
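Where the ~33% figure comes from, as a back-of-the-envelope derivation: assume (a common rule of thumb, not a measurement) that the backward pass costs about twice the forward pass $F$, and that checkpointing recomputes each segment's forward exactly once:

$$
\underbrace{F + 2F}_{\text{standard}} = 3F,
\qquad
\underbrace{F + \underbrace{F}_{\text{recompute}} + 2F}_{\text{checkpointed}} = 4F,
\qquad
\text{overhead} \approx \frac{4F - 3F}{3F} = \frac{1}{3} \approx 33\%
$$

Selective checkpointing, which recomputes only part of the forward pass, lands below this bound.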

#### **Optimal Checkpointing Strategy**
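For `n` layers, the classic strategy splits the network into `√n` segments of `√n` layers each and stores only the segment inputs. During backward, you hold all checkpoints plus one recomputed segment at a time:
```
checkpoint_interval = sqrt(num_layers)                        # layers per segment
num_checkpoints     = num_layers / checkpoint_interval        # = sqrt(num_layers)
peak_activations    ≈ num_checkpoints + checkpoint_interval   # = 2 * sqrt(num_layers)
memory_reduction    ≈ num_layers / (2 * sqrt(num_layers))     # = sqrt(num_layers) / 2
```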
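The `√n` rule can be checked numerically. A minimal sketch under an assumed cost model: peak memory is one stored input per segment plus the activations of the single segment being recomputed, measured in "layer activations". Function names are hypothetical:

```python
import math
from typing import Tuple

def peak_activation_units(num_layers: int, segment_size: int) -> int:
    """Assumed cost model: one stored input per segment, plus the activations
    of the single segment being recomputed during backward."""
    num_checkpoints = math.ceil(num_layers / segment_size)
    return num_checkpoints + segment_size

def best_segment_size(num_layers: int) -> Tuple[int, int]:
    """Brute-force the segment size with the smallest peak (first one wins ties)."""
    return min(
        ((k, peak_activation_units(num_layers, k)) for k in range(1, num_layers + 1)),
        key=lambda pair: pair[1],
    )

for n in (16, 64, 96):
    k, peak = best_segment_size(n)
    print(f"{n:3d} layers: segment size {k:2d}, ~{peak} activations live "
          f"(vs {n} without checkpointing; 2*sqrt(n) = {2 * math.sqrt(n):.1f})")
```

The brute-force optimum lands at or next to `√n`, and the peak tracks `2√n`, matching the formulas above.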

Let's implement a comprehensive gradient checkpointing system:

In [None]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Callable, Optional, Dict, Any
from dataclasses import dataclass
import time
import gc
from contextlib import contextmanager
import warnings
warnings.filterwarnings('ignore')

@dataclass
class MemorySnapshot:
    """Capture memory usage at a specific point"""
    timestamp: float
    allocated_mb: float
    reserved_mb: float
    max_allocated_mb: float
    description: str = ""

class MemoryProfiler:
    """
    Professional memory profiling for gradient checkpointing analysis
    
    Educational Focus:
    This class demonstrates how to systematically measure memory usage
    and analyze the effectiveness of memory optimization techniques.
    """
    
    def __init__(self, enabled: bool = True):
        self.enabled = enabled and torch.cuda.is_available()
        self.snapshots: List[MemorySnapshot] = []
        
        if self.enabled:
            # Reset memory stats
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
    
    def snapshot(self, description: str = "") -> MemorySnapshot:
        """Take a memory snapshot"""
        
        if not self.enabled:
            return MemorySnapshot(time.time(), 0, 0, 0, description)
        
        snapshot = MemorySnapshot(
            timestamp=time.time(),
            allocated_mb=torch.cuda.memory_allocated() / (1024**2),
            reserved_mb=torch.cuda.memory_reserved() / (1024**2),
            max_allocated_mb=torch.cuda.max_memory_allocated() / (1024**2),
            description=description
        )
        
        self.snapshots.append(snapshot)
        return snapshot
    
    @contextmanager
    def profile_block(self, description: str):
        """Context manager for profiling a code block"""
        self.snapshot(f"{description} - start")
        try:
            yield
        finally:
            self.snapshot(f"{description} - end")
    
    def get_peak_memory_mb(self) -> float:
        """Get peak memory usage across all snapshots"""
        if not self.snapshots:
            return 0.0
        return max(s.max_allocated_mb for s in self.snapshots)
    
    def print_summary(self):
        """Print memory usage summary"""
        if not self.enabled:
            print("❌ Memory profiling not available (no CUDA GPU)")
            return
        
        print("\n📊 Memory Usage Summary:")
        print("-" * 60)
        
        for i, snapshot in enumerate(self.snapshots):
            print(f"{i+1:2d}. {snapshot.description:30s} "
                  f"Allocated: {snapshot.allocated_mb:6.1f} MB, "
                  f"Peak: {snapshot.max_allocated_mb:6.1f} MB")
        
        if self.snapshots:
            peak = self.get_peak_memory_mb()
            print(f"\n🏔️  Overall Peak Memory: {peak:.1f} MB")

class SimpleTransformerLayer(nn.Module):
    """
    Simplified transformer layer for memory optimization demonstrations
    
    Educational Purpose:
    This implementation focuses on memory patterns rather than
    performance, making it ideal for learning optimization techniques.
    """
    
    def __init__(self, hidden_dim: int, num_heads: int, ff_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        # Multi-head attention
        self.qkv_proj = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        
        # Feed-forward network
        self.ff1 = nn.Linear(hidden_dim, ff_dim, bias=False)
        self.ff2 = nn.Linear(ff_dim, hidden_dim, bias=False)
        
        # Layer normalization
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        
        self.activation = nn.GELU()
    
    def forward(self, x: torch.Tensor, use_checkpointing: bool = False) -> torch.Tensor:
        """Forward pass with optional gradient checkpointing"""
        
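        if use_checkpointing:
            # Checkpoint each sub-block: only its input is kept, internals are recomputed in backward.
            # use_reentrant=False is the variant the PyTorch docs recommend.
            attn_out = checkpoint.checkpoint(self._attention_block, x, use_reentrant=False)
            output = checkpoint.checkpoint(self._ffn_block, attn_out, use_reentrant=False)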
        else:
            # Standard forward pass
            attn_out = self._attention_block(x)
            output = self._ffn_block(attn_out)
        
        return output
    
    def _attention_block(self, x: torch.Tensor) -> torch.Tensor:
        """Multi-head attention block"""
        batch_size, seq_len, _ = x.shape
        
        # Layer norm + QKV projection
        normed = self.ln1(x)
        qkv = self.qkv_proj(normed)
        
        # Reshape for multi-head attention
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Attention computation (memory intensive!)
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        attention_probs = torch.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_probs, v)
        
        # Reshape and project
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.hidden_dim)
        output = self.out_proj(attention_output)
        
        # Residual connection
        return x + output
    
    def _ffn_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward network block"""
        # Layer norm + FFN
        normed = self.ln2(x)
        ff_intermediate = self.activation(self.ff1(normed))  # 4x expansion!
        ff_output = self.ff2(ff_intermediate)
        
        # Residual connection
        return x + ff_output

class SimpleTransformer(nn.Module):
    """
    Simple transformer model for memory optimization experiments
    """
    
    def __init__(self, vocab_size: int, hidden_dim: int, num_layers: int, 
                 num_heads: int, max_seq_len: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Embedding layers
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
        
        # Transformer layers
        ff_dim = 4 * hidden_dim  # Standard 4x expansion
        self.layers = nn.ModuleList([
            SimpleTransformerLayer(hidden_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])
        
        # Output projection
        self.ln_final = nn.LayerNorm(hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)
    
    def forward(self, input_ids: torch.Tensor, use_checkpointing: bool = False) -> torch.Tensor:
        """Forward pass with optional gradient checkpointing"""
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        
        # Embeddings
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        
        # Transformer layers
        for layer in self.layers:
            x = layer(x, use_checkpointing=use_checkpointing)
        
        # Final layer norm and output projection
        x = self.ln_final(x)
        logits = self.output_proj(x)
        
        return logits

print("✅ Memory Optimization Framework Initialized!")
print("🧠 Ready to explore gradient checkpointing and memory optimization")

## 🧪 Gradient Checkpointing Experiment

Let's conduct a comprehensive experiment comparing standard training vs gradient checkpointing:

In [None]:
def memory_comparison_experiment():
    """
    Comprehensive experiment comparing memory usage patterns
    between standard training and gradient checkpointing.
    
    Educational Focus:
    This experiment quantifies the memory-compute tradeoff
    and demonstrates practical optimization techniques.
    """
    
    print("🧪 Starting Memory Optimization Experiment")
    print("=" * 60)
    
    # Experimental configuration
    configs = [
        {
            "name": "Small Model",
            "vocab_size": 10000,
            "hidden_dim": 512,
            "num_layers": 6,
            "num_heads": 8,
            "seq_len": 256,
            "batch_size": 4
        },
        {
            "name": "Medium Model", 
            "vocab_size": 20000,
            "hidden_dim": 1024,
            "num_layers": 12,
            "num_heads": 16,
            "seq_len": 512,
            "batch_size": 2
        }
    ]
    
    results = {}
    
    for config in configs:
        print(f"\n🔬 Testing Configuration: {config['name']}")
        print(f"   Hidden Dim: {config['hidden_dim']}, Layers: {config['num_layers']}")
        print(f"   Sequence Length: {config['seq_len']}, Batch Size: {config['batch_size']}")
        
        config_results = {}
        
        # Test both standard and checkpointed training
        for use_checkpointing in [False, True]:
            mode = "Gradient Checkpointing" if use_checkpointing else "Standard Training"
            print(f"\n  📊 Testing: {mode}")
            
            try:
                # Initialize memory profiler
                profiler = MemoryProfiler()
                profiler.snapshot("Initial state")
                
                # Create model
                with profiler.profile_block("Model creation"):
                    model = SimpleTransformer(
                        vocab_size=config['vocab_size'],
                        hidden_dim=config['hidden_dim'],
                        num_layers=config['num_layers'],
                        num_heads=config['num_heads'],
                        max_seq_len=config['seq_len']
                    )
                    
                    if torch.cuda.is_available():
                        model = model.cuda()
                    
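                    # Create optimizer. Note: Adam allocates its momentum/variance
                    # buffers lazily on the first optimizer.step(), which this
                    # experiment never calls, so peak memory excludes optimizer states.
                    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)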
                
                # Create sample data
                device = next(model.parameters()).device
                input_ids = torch.randint(
                    0, config['vocab_size'], 
                    (config['batch_size'], config['seq_len']),
                    device=device
                )
                labels = torch.randint(
                    0, config['vocab_size'],
                    (config['batch_size'], config['seq_len']),
                    device=device
                )
                
                profiler.snapshot("Data prepared")
                
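                # CUDA kernels run asynchronously; synchronize before reading the clock
                sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
                
                # Measure forward pass
                with profiler.profile_block("Forward pass"):
                    sync()
                    start_time = time.perf_counter()
                    logits = model(input_ids, use_checkpointing=use_checkpointing)
                    sync()
                    forward_time = time.perf_counter() - start_time
                    
                    # Compute loss
                    loss = nn.functional.cross_entropy(
                        logits.view(-1, config['vocab_size']),
                        labels.view(-1)
                    )
                
                # Measure backward pass (recomputation happens here when checkpointing)
                with profiler.profile_block("Backward pass"):
                    sync()
                    start_time = time.perf_counter()
                    loss.backward()
                    sync()
                    backward_time = time.perf_counter() - start_time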
                
                profiler.snapshot("Training complete")
                
                # Store results
                peak_memory = profiler.get_peak_memory_mb()
                config_results[mode] = {
                    "peak_memory_mb": peak_memory,
                    "forward_time_ms": forward_time * 1000,
                    "backward_time_ms": backward_time * 1000,
                    "total_time_ms": (forward_time + backward_time) * 1000,
                    "loss_value": loss.item(),
                    "profiler": profiler
                }
                
                print(f"     ✅ Peak Memory: {peak_memory:.1f} MB")
                print(f"     ⏱️  Forward: {forward_time*1000:.1f} ms, Backward: {backward_time*1000:.1f} ms")
                print(f"     📉 Loss: {loss.item():.4f}")
                
                # Clean up
                del model, optimizer, logits, loss
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
                
            except Exception as e:
                print(f"     ❌ Failed: {e}")
                config_results[mode] = {"error": str(e)}
                
                # Emergency cleanup
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
        
        results[config['name']] = config_results
    
    return results

def analyze_memory_results(results: Dict[str, Dict[str, Any]]):
    """
    Analyze and visualize memory optimization results
    
    Educational Focus:
    This analysis demonstrates how to quantify optimization benefits
    and make data-driven decisions about memory techniques.
    """
    
    print("\n📊 Memory Optimization Analysis")
    print("=" * 60)
    
    # Prepare data for visualization
    config_names = []
    standard_memory = []
    checkpointed_memory = []
    memory_savings = []
    
    standard_time = []
    checkpointed_time = []
    time_overhead = []
    
    for config_name, config_results in results.items():
        if ("Standard Training" in config_results and 
            "Gradient Checkpointing" in config_results and
            "error" not in config_results["Standard Training"] and
            "error" not in config_results["Gradient Checkpointing"]):
            
            std_result = config_results["Standard Training"]
            chk_result = config_results["Gradient Checkpointing"]
            
            config_names.append(config_name)
            
            # Memory analysis
            std_mem = std_result["peak_memory_mb"]
            chk_mem = chk_result["peak_memory_mb"]
            
            standard_memory.append(std_mem)
            checkpointed_memory.append(chk_mem)
            
            savings = ((std_mem - chk_mem) / std_mem) * 100 if std_mem > 0 else 0
            memory_savings.append(savings)
            
            # Time analysis
            std_time = std_result["total_time_ms"]
            chk_time = chk_result["total_time_ms"]
            
            standard_time.append(std_time)
            checkpointed_time.append(chk_time)
            
            overhead = ((chk_time - std_time) / std_time) * 100 if std_time > 0 else 0
            time_overhead.append(overhead)
            
            print(f"\n🔍 {config_name}:")
            print(f"   Memory Reduction: {savings:.1f}% ({std_mem:.1f} → {chk_mem:.1f} MB)")
            print(f"   Time Overhead: {overhead:.1f}% ({std_time:.1f} → {chk_time:.1f} ms)")
    
    # Create visualization
    if config_names:
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('🧠 Memory Optimization Analysis: Standard vs Gradient Checkpointing', 
                     fontsize=16, fontweight='bold')
        
        # 1. Memory Usage Comparison
        x = np.arange(len(config_names))
        width = 0.35
        
        bars1 = ax1.bar(x - width/2, standard_memory, width, label='Standard Training', 
                       color='#FF6B6B', alpha=0.8)
        bars2 = ax1.bar(x + width/2, checkpointed_memory, width, label='Gradient Checkpointing',
                       color='#4ECDC4', alpha=0.8)
        
        ax1.set_xlabel('Model Configuration')
        ax1.set_ylabel('Peak Memory (MB)')
        ax1.set_title('💾 Memory Usage Comparison')
        ax1.set_xticks(x)
        ax1.set_xticklabels(config_names)
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Add value labels
        for bars in [bars1, bars2]:
            for bar in bars:
                height = bar.get_height()
                ax1.text(bar.get_x() + bar.get_width()/2., height + max(standard_memory + checkpointed_memory)*0.01,
                        f'{height:.0f}', ha='center', va='bottom', fontsize=10)
        
        # 2. Memory Savings
        bars3 = ax2.bar(config_names, memory_savings, color='#95E1D3', alpha=0.8, edgecolor='black')
        ax2.set_ylabel('Memory Savings (%)')
        ax2.set_title('📈 Memory Reduction Achieved')
        ax2.grid(True, alpha=0.3)
        
        for bar, savings in zip(bars3, memory_savings):
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height + max(memory_savings)*0.01,
                    f'{savings:.1f}%', ha='center', va='bottom', fontsize=12, fontweight='bold')
        
        # 3. Time Comparison
        bars4 = ax3.bar(x - width/2, standard_time, width, label='Standard Training',
                       color='#FF6B6B', alpha=0.8)
        bars5 = ax3.bar(x + width/2, checkpointed_time, width, label='Gradient Checkpointing',
                       color='#4ECDC4', alpha=0.8)
        
        ax3.set_xlabel('Model Configuration')
        ax3.set_ylabel('Total Time (ms)')
        ax3.set_title('⏱️ Training Time Comparison')
        ax3.set_xticks(x)
        ax3.set_xticklabels(config_names)
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        
        # 4. Time Overhead
        bars6 = ax4.bar(config_names, time_overhead, color='#F8C471', alpha=0.8, edgecolor='black')
        ax4.set_ylabel('Time Overhead (%)')
        ax4.set_title('⚡ Computational Overhead')
        ax4.grid(True, alpha=0.3)
        
        for bar, overhead in zip(bars6, time_overhead):
            height = bar.get_height()
            ax4.text(bar.get_x() + bar.get_width()/2., height + max(time_overhead)*0.01,
                    f'{overhead:.1f}%', ha='center', va='bottom', fontsize=12, fontweight='bold')
        
        plt.tight_layout()
        plt.show()
        
        # Summary statistics
        if memory_savings and time_overhead:
            avg_memory_savings = np.mean(memory_savings)
            avg_time_overhead = np.mean(time_overhead)
            
            print(f"\n🎯 Summary Statistics:")
            print(f"   Average Memory Savings: {avg_memory_savings:.1f}%")
            print(f"   Average Time Overhead: {avg_time_overhead:.1f}%")
            
            # Efficiency analysis
            efficiency_ratio = avg_memory_savings / avg_time_overhead if avg_time_overhead > 0 else float('inf')
            print(f"   Efficiency Ratio: {efficiency_ratio:.2f} (memory saved per % time cost)")
            
            if avg_memory_savings > 20 and avg_time_overhead < 50:
                print("   ✅ Gradient checkpointing is highly beneficial!")
            elif avg_memory_savings > 10:
                print("   🟡 Gradient checkpointing provides moderate benefits")
            else:
                print("   🔴 Gradient checkpointing benefits are limited for this configuration")
    
    else:
        print("❌ No successful results to analyze")

# Run the comprehensive memory experiment
print("🚀 Starting Comprehensive Memory Optimization Experiment")
print("This will compare standard training vs gradient checkpointing...")

memory_results = memory_comparison_experiment()
analyze_memory_results(memory_results)

## 🔄 Advanced Memory Optimization Techniques

### **Selective Activation Checkpointing**

Not all layers benefit equally from checkpointing. Let's implement intelligent checkpointing:

In [None]:
class AdaptiveCheckpointingStrategy:
    """
    Intelligent checkpointing that selects optimal layers for recomputation
    
    Educational Focus:
    This demonstrates advanced optimization where we analyze each layer's
    memory vs compute characteristics to make optimal decisions.
    """
    
    def __init__(self, model: nn.Module, memory_budget_mb: float = 8000):
        self.model = model
        self.memory_budget_mb = memory_budget_mb
        self.layer_profiles = {}
        self.optimal_strategy = None
    
    def profile_layers(self, sample_input: torch.Tensor) -> Dict[str, Dict[str, float]]:
        """
        Profile each layer to understand memory vs compute characteristics
        
        Returns:
            Dictionary mapping layer names to their resource profiles
        """
        
        print("🔍 Profiling individual layers for optimization strategy...")
        
        self.layer_profiles = {}
        
        # Hook to capture activations and measure memory
        activation_sizes = {}
        computation_times = {}
        
        def memory_hook(name):
            def hook_fn(module, input, output):
                # Record the size of this module's output activation (works on CPU or GPU)
                if isinstance(output, torch.Tensor):
                    size_mb = output.numel() * output.element_size() / (1024**2)
                    activation_sizes[name] = size_mb
                elif isinstance(output, (tuple, list)):
                    total_size = sum(o.numel() * o.element_size() for o in output if isinstance(o, torch.Tensor))
                    activation_sizes[name] = total_size / (1024**2)
            return hook_fn
        
        # Register hooks for all layers
        hooks = []
        layer_names = []
        
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.LayerNorm, SimpleTransformerLayer)):
                hook = module.register_forward_hook(memory_hook(name))
                hooks.append(hook)
                layer_names.append(name)
        
        was_training = self.model.training
        try:
            # Profile forward pass
            self.model.eval()
            with torch.no_grad():
                start_time = time.time()
                output = self.model(sample_input)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()  # include queued GPU work in the timing
                total_time = time.time() - start_time
            
            # Estimate computation time per hooked module inside the transformer layers.
            # Simplified: every module gets the same estimate, so the ranking below is
            # effectively by activation size.
            num_layer_modules = len([n for n in layer_names if 'layers.' in n])
            estimated_time_per_layer = total_time / max(num_layer_modules, 1)
            
            # Build layer profiles
            for name in layer_names:
                if name in activation_sizes:
                    # Calculate memory-to-compute ratio (higher = better candidate for checkpointing)
                    memory_mb = activation_sizes[name]
                    compute_cost = estimated_time_per_layer  # Simplified estimate
                    
                    memory_to_compute_ratio = memory_mb / (compute_cost * 1000) if compute_cost > 0 else 0
                    
                    self.layer_profiles[name] = {
                        'activation_memory_mb': memory_mb,
                        'estimated_compute_ms': compute_cost * 1000,
                        'memory_compute_ratio': memory_to_compute_ratio,
                        'checkpoint_priority': memory_to_compute_ratio  # Higher = should checkpoint
                    }
        
        finally:
            # Clean up hooks and restore the model's original train/eval mode
            for hook in hooks:
                hook.remove()
            self.model.train(was_training)
        
        return self.layer_profiles
    
    def compute_optimal_strategy(self) -> List[str]:
        """
        Compute optimal checkpointing strategy based on memory budget
        
        Returns:
            List of layer names that should use checkpointing
        """
        
        if not self.layer_profiles:
            print("⚠️ No layer profiles available. Run profile_layers() first.")
            return []
        
        # Sort layers by checkpoint priority (memory-to-compute ratio)
        sorted_layers = sorted(
            self.layer_profiles.items(),
            key=lambda x: x[1]['checkpoint_priority'],
            reverse=True
        )
        
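        # Greedy selection: memory_budget_mb caps how much activation memory we
        # trade for recomputation. Skip layers that would overshoot and keep
        # trying smaller ones. Note: nested modules (a SimpleTransformerLayer and
        # its Linear/LayerNorm children) are profiled separately, so their
        # estimated savings overlap.
        total_memory_saved = 0
        selected_layers = []
        
        for layer_name, profile in sorted_layers:
            memory_saving = profile['activation_memory_mb']
            
            if total_memory_saved + memory_saving <= self.memory_budget_mb:
                selected_layers.append(layer_name)
                total_memory_saved += memory_saving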
        
        self.optimal_strategy = selected_layers
        
        print(f"\n🎯 Optimal Checkpointing Strategy:")
        print(f"   Layers to checkpoint: {len(selected_layers)}")
        print(f"   Estimated memory saved: {total_memory_saved:.1f} MB")
        
        # Show top candidates
        print(f"\n   Top checkpoint candidates:")
        for i, (layer_name, profile) in enumerate(sorted_layers[:5]):
            selected = "✅" if layer_name in selected_layers else "❌"
            print(f"   {i+1}. {selected} {layer_name}: {profile['activation_memory_mb']:.1f} MB, "
                  f"Ratio: {profile['memory_compute_ratio']:.2f}")
        
        return selected_layers
    
    def print_analysis(self):
        """Print detailed analysis of checkpointing opportunities"""
        
        if not self.layer_profiles:
            print("❌ No layer profiles available")
            return
        
        print("\n📊 Layer-by-Layer Analysis:")
        print("-" * 80)
        print(f"{'Layer Name':<30} {'Memory (MB)':<12} {'Compute (ms)':<12} {'Ratio':<8} {'Priority'}")
        print("-" * 80)
        
        # Sort by memory usage
        sorted_layers = sorted(
            self.layer_profiles.items(),
            key=lambda x: x[1]['activation_memory_mb'],
            reverse=True
        )
        
        for layer_name, profile in sorted_layers:
            priority = "High" if profile['checkpoint_priority'] > 1.0 else "Medium" if profile['checkpoint_priority'] > 0.5 else "Low"
            print(f"{layer_name:<30} {profile['activation_memory_mb']:>8.1f}     "
                  f"{profile['estimated_compute_ms']:>8.1f}     "
                  f"{profile['memory_compute_ratio']:>6.2f}   {priority}")

# Demonstrate adaptive checkpointing
def demonstrate_adaptive_checkpointing():
    """
    Demonstrate adaptive checkpointing strategy selection
    """
    
    print("🧠 Adaptive Checkpointing Strategy Demonstration")
    print("=" * 60)
    
    try:
        # Create a model for analysis
        model = SimpleTransformer(
            vocab_size=10000,
            hidden_dim=768,
            num_layers=8,
            num_heads=12,
            max_seq_len=512
        )
        
        if torch.cuda.is_available():
            model = model.cuda()
        
        # Create sample input
        device = next(model.parameters()).device
        sample_input = torch.randint(0, 10000, (2, 256), device=device)
        
        # Initialize adaptive strategy
        strategy = AdaptiveCheckpointingStrategy(model, memory_budget_mb=500)
        
        # Profile layers
        layer_profiles = strategy.profile_layers(sample_input)
        
        # Print analysis
        strategy.print_analysis()
        
        # Compute optimal strategy
        optimal_layers = strategy.compute_optimal_strategy()
        
        print(f"\n🎯 Adaptive checkpointing analysis complete!")
        print(f"   Analyzed {len(layer_profiles)} layers")
        print(f"   Recommended {len(optimal_layers)} layers for checkpointing")
        
        return strategy
        
    except Exception as e:
        print(f"❌ Adaptive checkpointing demonstration failed: {e}")
        return None

# Run adaptive checkpointing demonstration
adaptive_strategy = demonstrate_adaptive_checkpointing()

## 🎯 Key Takeaways from Memory Optimization

### **Memory is the Critical Constraint**
- **Activations dominate** memory usage in transformer training
- **Quadratic attention** creates severe memory pressure
- **Batch size** is often limited by memory, not compute

### **Gradient Checkpointing Trade-offs**
- **Memory reduction**: 20-60% typical savings
- **Compute overhead**: 20-40% additional time
- **Net benefit**: Enables larger models/batches
- **Sweet spot**: Models with 6+ layers

### **Optimization Strategy**
- **Profile first**: Understand your memory bottlenecks
- **Selective checkpointing**: Not all layers need it
- **Memory-compute analysis**: Optimize based on actual ratios
- **Adaptive strategies**: Adjust based on hardware constraints

### **Production Implications**
- **Training efficiency**: More data per GPU hour
- **Cost reduction**: Fewer GPUs needed for same model
- **Scalability**: Enables larger context lengths
- **Flexibility**: Better resource utilization

---

## 💡 Advanced Memory Techniques

### **Dynamic Memory Allocation**
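```python
# Sequence length adaptive batching
def estimate_memory(seq_len, hidden_dim=4096, num_heads=32, num_layers=32, bytes_per_element=2):
    # Hypothetical per-sample activation cost (defaults assume a 7B-class model):
    # ~10 x seq_len x hidden_dim for the per-layer tensors listed earlier,
    # plus the quadratic attention-score matrix
    per_layer = (10 * seq_len * hidden_dim + num_heads * seq_len ** 2) * bytes_per_element
    return num_layers * per_layer

def adaptive_batch_size(seq_lengths, memory_budget):
    # Every sample is padded to the longest sequence, so size the batch for it
    max_seq_len = max(seq_lengths)
    memory_per_sample = estimate_memory(max_seq_len)
    return min(memory_budget // memory_per_sample, len(seq_lengths))
```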
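A related variant, swapped in here rather than taken from the text: instead of estimating memory per sample, cap the padded token count (batch size × longest sequence) per batch and sort by length, so similar-length sequences share a batch and padding shrinks. `bucket_by_length` and `max_tokens_per_batch` are hypothetical names; a single sequence longer than the budget still gets its own over-budget batch:

```python
from typing import List

def bucket_by_length(seq_lengths: List[int], max_tokens_per_batch: int) -> List[List[int]]:
    """Greedily group sequence indices so that batch_size * longest_sequence
    (the padded token count) stays within max_tokens_per_batch."""
    order = sorted(range(len(seq_lengths)), key=lambda i: seq_lengths[i])
    batches: List[List[int]] = []
    current: List[int] = []
    current_max = 0
    for idx in order:
        new_max = max(current_max, seq_lengths[idx])
        # Adding this sequence would exceed the padded-token budget: close the batch
        if current and new_max * (len(current) + 1) > max_tokens_per_batch:
            batches.append(current)
            current, new_max = [], seq_lengths[idx]
        current.append(idx)
        current_max = new_max
    if current:
        batches.append(current)
    return batches

lengths = [512, 128, 130, 1000, 120, 500]
print(bucket_by_length(lengths, max_tokens_per_batch=1024))  # [[4, 1, 2], [5, 0], [3]]
```

Short sequences end up three to a batch while the 1000-token sequence runs alone, so the padded token count per batch stays roughly constant.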

### **Activation Compression**
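```python
import torch
from torch.utils import checkpoint

# Run the checkpointed segment under autocast: eligible ops (e.g. matmuls) produce
# FP16 tensors, so the activations recomputed for backward take half the FP32 memory.
# checkpoint() restores the autocast state when it recomputes the segment.
def compressed_checkpoint(func, *args):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return checkpoint.checkpoint(func, *args, use_reentrant=False)
```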

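### **Activation Offloading**
```python
import torch

# Offload tensors saved for backward to CPU memory; autograd copies them back
# to their original device when the backward pass needs them. (Offloading to
# disk would need custom torch.autograd.graph.saved_tensors_hooks.)
def cpu_offload_forward(func, *args):
    with torch.autograd.graph.save_on_cpu(pin_memory=torch.cuda.is_available()):
        return func(*args)
```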

---

## 🔬 Exercises

### **Exercise 1: Custom Checkpointing**
Implement a custom checkpointing strategy that selectively checkpoints attention vs feed-forward layers.

### **Exercise 2: Memory Budget Optimizer**
Create a system that automatically determines optimal batch sizes given hardware memory constraints.

### **Exercise 3: Sequence Length Analysis**
Analyze how memory usage scales with sequence length and implement dynamic batching.

---

**Next: Chapter 4 - DeepSpeed ZeRO Deep Dive** ⚡

*In the next chapter, we'll explore DeepSpeed ZeRO's revolutionary approach to parameter partitioning and how it enables training models that don't fit on single GPUs.*