# Make a GPT-2 Model Smaller and More Powerful (v0.0.32)

This notebook demonstrates how to make a GPT-2 model both smaller and more powerful by:
1. Applying pruning to remove less important attention heads
2. Fine-tuning the pruned model to recover and improve performance
3. Showing clear metrics of improvement throughout the process

We use real data (Wikitext) rather than synthetic data for realistic evaluation.

Version History:
- v0.0.32 (April 2025): Added CUDA error handling for Colab compatibility and memory management
- v0.0.31 (April 2025): Fixed get_strategy parameters issue and improved Colab compatibility
- v0.0.30 (April 2025): Added OPT model support and chart improvements

---
**Note**: This notebook is part of the SentinelAI project. For detailed documentation, see `PruningAndFineTuningColab.md`.

## Setup

Let's start by installing the required dependencies and configuring our environment.

In [ ]:
# Memory management utility for Colab
def display_available_memory():
    """Display available memory in Colab.

    Prints GPU memory (via ``nvidia-smi``) and system memory (via ``free -h``)
    by shelling out. Does nothing unless the module-level ``IS_COLAB`` flag is
    truthy; ``IS_COLAB`` is assigned later in this cell, so it must be set
    before this function is called.

    NOTE(review): the ``!`` lines are IPython shell escapes, so this function
    only parses inside a notebook kernel, not as a plain ``.py`` module.
    """
    if IS_COLAB:
        # Get GPU memory info
        try:
            !nvidia-smi --query-gpu=memory.total,memory.used --format=csv
        except:
            pass
        
        # Get system memory info
        !free -h

# Install required packages
!pip install -q transformers==4.38.0 datasets==2.17.0 torch matplotlib tqdm

# Check if we're running in Colab
try:
    import google.colab
    IS_COLAB = True
    print("Running in Google Colab!")
    
    # Add file download helper for Colab
    from google.colab import files
    
    def download_files(file_paths):
        """Helper function to download files from Colab."""
        for file_path in file_paths:
            if os.path.exists(file_path):
                files.download(file_path)
                print(f"Downloaded: {file_path}")
            else:
                print(f"File not found: {file_path}")
    
    # Free up memory for Colab
    import gc
    import torch
    gc.collect()
    torch.cuda.empty_cache()
    
    # Display memory status
    display_available_memory()
    
except:
    IS_COLAB = False
    print("Not running in Google Colab")
    
    # Dummy function for non-Colab environments
    def download_files(file_paths):
        print("File download only works in Google Colab")
        print(f"Files would be downloaded: {file_paths}")
        
    def display_available_memory():
        print("Memory display not available outside Colab")

In [ ]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import json
from tqdm.notebook import tqdm
from datetime import datetime
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, 
    get_linear_schedule_with_warmup, 
    GPT2LMHeadModel
)

# Plot styling used by every figure in the notebook
plt.style.use('seaborn-v0_8-pastel')

# Pick the compute device; half precision is only switched on for CUDA,
# where it reduces memory usage
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
USE_FP16 = (DEVICE == "cuda")

# A TPU runtime is assumed when COLAB_TPU_ADDR is present in the environment
if 'COLAB_TPU_ADDR' in os.environ:
    try:
        import torch_xla.core.xla_model as xm
        DEVICE = xm.xla_device()
        print("TPU detected and configured!")
        USE_FP16 = False  # TPUs have their own optimization
    except ImportError:
        print("TPU environment detected but torch_xla not installed.")

# Locations for results, cached models and datasets
OUTPUT_DIR, MODEL_CACHE_DIR, DATA_DIR = "pruning_results", "model_cache", "data"

# Make sure all of them exist before anything tries to write there
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# Report the resolved configuration
print(f"Using device: {DEVICE}")
print(f"Using FP16: {USE_FP16}")
print(f"PyTorch version: {torch.__version__}")

# CUDA memory management helper
def clear_gpu_memory():
    """Clear GPU memory to avoid CUDA out of memory errors.

    Does nothing when CUDA is unavailable. Garbage is collected *before* the
    CUDA cache is emptied so that blocks released by the collector are handed
    back to the driver in the same call rather than staying cached.
    """
    if not torch.cuda.is_available():
        return

    gc.collect()
    torch.cuda.empty_cache()
    print("GPU memory cleared")

# Import garbage collector for memory management
import gc

# For better GPU memory management, we'll use a context manager
import contextlib


@contextlib.contextmanager
def autocast_if_available():
    """Run the enclosed block under CUDA autocast when FP16 is enabled.

    Autocast is entered only when the module-level ``USE_FP16`` flag is truthy
    and ``torch.cuda.amp.autocast`` exists; otherwise the block runs unchanged.

    ``contextlib`` is part of the standard library, so no import guard is
    needed. The previous fallback branch itself relied on ``contextlib`` and
    could never have worked if that import had failed.
    """
    amp = getattr(torch.cuda, "amp", None)
    if USE_FP16 and amp is not None and hasattr(amp, "autocast"):
        with amp.autocast():
            yield
    else:
        yield

## Progress Tracking

We'll create a class to track metrics and visualize progress throughout the pruning and fine-tuning process.

In [None]:
class ProgressMetrics:
    """Track metrics throughout the pruning and fine-tuning process.

    All recorded values live in ``self.metrics`` (a plain dict) so they can be
    dumped to JSON. A two-panel matplotlib figure (loss, perplexity) is created
    on construction and redrawn on every ``update`` call.
    """

    def __init__(self):
        self.metrics = {
            "loss": [],
            "perplexity": [],
            "steps": [],
            "pruning_level": None,
            "strategy": None,
            "pruned_heads": [],
            "gate_values": [],
            "head_importance": [],
            "generation_samples": []
        }

        # Create visualizations
        self.fig, self.axes = plt.subplots(1, 2, figsize=(15, 5))
        self.loss_line = None
        self.ppl_line = None

    def update(self, step, loss, perplexity, head_info=None, gate_values=None,
               generation_sample=None):
        """Record one measurement and redraw the plots.

        Args:
            step: Step number for the x-axis.
            loss: Loss value at this step.
            perplexity: Perplexity value at this step.
            head_info: If given, replaces the stored head importance data.
            gate_values: If given, replaces the stored gate values.
            generation_sample: If given, appended (with ``step``) to the
                stored generation samples.
        """
        self.metrics["steps"].append(step)
        self.metrics["loss"].append(loss)
        self.metrics["perplexity"].append(perplexity)

        if head_info is not None:
            self.metrics["head_importance"] = head_info

        if gate_values is not None:
            self.metrics["gate_values"] = gate_values

        if generation_sample is not None:
            self.metrics["generation_samples"].append({
                "step": step,
                "text": generation_sample
            })

        # Update visualization
        self._update_plots()

    def set_pruning_info(self, strategy, level, pruned_heads):
        """Store the pruning strategy, pruning level and list of pruned heads."""
        self.metrics["strategy"] = strategy
        self.metrics["pruning_level"] = level
        self.metrics["pruned_heads"] = pruned_heads

    def _update_plots(self):
        """Redraw the loss and perplexity panels from the stored metrics."""
        steps = self.metrics["steps"]
        if not steps:
            return

        panels = (
            (self.axes[0], self.metrics["loss"], 'b-', 'Training Loss', 'Loss'),
            (self.axes[1], self.metrics["perplexity"], 'r-',
             'Perplexity (lower is better)', 'Perplexity'),
        )
        for axis, values, line_style, title, y_label in panels:
            # Clear the previous curve before drawing the updated one
            axis.clear()
            axis.plot(steps, values, line_style)
            axis.set_title(title)
            axis.set_xlabel('Step')
            axis.set_ylabel(y_label)
            axis.grid(True)

        self.fig.tight_layout()

    def save_plots(self, path):
        """Save this tracker's figure to *path*.

        Saves ``self.fig`` explicitly; ``plt.savefig`` would write whichever
        figure happens to be current, which may be an unrelated one.
        """
        self.fig.savefig(path)

    @staticmethod
    def _json_default(value):
        """Convert numpy/torch values (anything with ``tolist``) for JSON."""
        to_list = getattr(value, "tolist", None)
        if callable(to_list):
            return to_list()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    def save_metrics(self, path):
        """Write the metrics dict to *path* as indented JSON.

        Array-like and numpy scalar values are converted through their
        ``tolist`` method; other unsupported types still raise ``TypeError``.
        """
        with open(path, 'w', encoding="utf-8") as f:
            json.dump(self.metrics, f, indent=2, default=self._json_default)

    def get_summary(self):
        """Return a summary of key metrics.

        Returns:
            A dict with the pruning settings, first/last loss and perplexity
            and the relative perplexity improvement in percent, or
            ``{"error": ...}`` when fewer than two measurements were recorded.
        """
        perplexities = self.metrics["perplexity"]
        if len(perplexities) <= 1:
            return {"error": "Not enough data points for summary"}

        losses = self.metrics["loss"]
        first_ppl, last_ppl = perplexities[0], perplexities[-1]
        return {
            "strategy": self.metrics["strategy"],
            "pruning_level": self.metrics["pruning_level"],
            "pruned_heads_count": len(self.metrics["pruned_heads"]),
            "initial_loss": losses[0],
            "final_loss": losses[-1],
            "initial_perplexity": first_ppl,
            "final_perplexity": last_ppl,
            "improvement_percent": (first_ppl - last_ppl) / first_ppl * 100
        }

## Data Loading

We'll use the Wikitext-2 dataset, which provides real-world text content. Only its validation split is loaded; it is shuffled and divided 80/20 into a fine-tuning set and an evaluation set.

In [None]:
def _write_wikitext_file(texts, destination):
    """Write every non-blank text (plus a newline) to *destination* atomically."""
    partial_path = destination + ".partial"
    try:
        with open(partial_path, "w", encoding="utf-8") as f:
            for text in texts:
                if text.strip():
                    f.write(text + "\n")
        # Only expose the file once it is complete, so an interrupted write
        # is never mistaken for a finished download on the next call.
        os.replace(partial_path, destination)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_wikitext():
    """Download Wikitext dataset if not already present.

    Tries the Hugging Face ``datasets`` library first and falls back to
    downloading and extracting the raw zip archive. Both routes end by
    writing the validation text to
    ``DATA_DIR/wikitext-2-raw-v1-validation.txt``.

    Returns:
        True if that file exists when the function returns, False if both
        download routes failed.
    """
    wikitext_file = os.path.join(DATA_DIR, "wikitext-2-raw-v1-validation.txt")

    if os.path.exists(wikitext_file):
        return True

    print("Downloading Wikitext-2 dataset...")
    try:
        # Using HF datasets library
        from datasets import load_dataset
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

        # Save validation text
        rows = tqdm(dataset["validation"], desc="Saving dataset")
        _write_wikitext_file((item["text"] for item in rows), wikitext_file)

        print(f"Dataset saved to {wikitext_file}")
        return True
    except Exception as e:
        print(f"Error downloading dataset: {e}")

    # Fallback: download using requests
    try:
        import requests
        import zipfile

        url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip"
        r = requests.get(url, timeout=60)
        # Fail here rather than saving an HTTP error page as a "zip" file
        r.raise_for_status()

        # Save zip file
        zip_path = os.path.join(DATA_DIR, "wikitext-2-raw-v1.zip")
        with open(zip_path, "wb") as f:
            f.write(r.content)

        # Extract
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(DATA_DIR)

        print(f"Dataset downloaded and extracted to {DATA_DIR}")

        # The archive only yields raw split files; locate the validation split
        # and convert it into the file the loaders actually read.
        # NOTE(review): the member name `wiki.valid.raw` is assumed from the
        # published archive layout - confirm against the downloaded zip.
        raw_path = next(
            (os.path.join(root, "wiki.valid.raw")
             for root, _dirs, names in os.walk(DATA_DIR)
             if "wiki.valid.raw" in names),
            None,
        )
        if raw_path is None:
            raise FileNotFoundError("wiki.valid.raw not found in extracted archive")

        with open(raw_path, "r", encoding="utf-8") as raw_file:
            _write_wikitext_file(raw_file, wikitext_file)
        print(f"Dataset saved to {wikitext_file}")
        return True
    except Exception as e2:
        print(f"Fallback download also failed: {e2}")
        return False

def prepare_dataset(paragraphs, tokenizer, max_length, batch_size):
    """Turn raw paragraphs into a shuffled DataLoader of token tensors.

    Every paragraph is truncated/padded to ``max_length`` tokens. Each batch
    yielded by the returned loader is an ``(input_ids, attention_mask)`` pair.
    """
    encoded = tokenizer(
        paragraphs,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    # Pair ids with their masks so both are batched and shuffled together
    tensor_pairs = TensorDataset(encoded["input_ids"], encoded["attention_mask"])
    return DataLoader(tensor_pairs, batch_size=batch_size, shuffle=True)

def load_wikitext_data(tokenizer, max_length=512, batch_size=4):
    """Load and prepare Wikitext data for fine-tuning and evaluation.

    Reads ``DATA_DIR/wikitext-2-raw-v1-validation.txt`` (downloading it first
    if needed), keeps paragraphs longer than 100 characters, shuffles them
    with a fixed seed and splits them 80/20.

    Args:
        tokenizer: Tokenizer passed through to ``prepare_dataset``.
        max_length: Token length every paragraph is padded/truncated to.
        batch_size: Batch size of both returned loaders.

    Returns:
        ``(train_loader, val_loader)``, or ``(None, None)`` when the data
        cannot be obtained or is too small to split.
    """
    wikitext_file = os.path.join(DATA_DIR, "wikitext-2-raw-v1-validation.txt")

    if not os.path.exists(wikitext_file):
        # Re-check the file itself: a download can report success without
        # having produced the file that is read below.
        if not download_wikitext() or not os.path.exists(wikitext_file):
            print("Failed to download dataset")
            return None, None

    # Read the data
    print("Loading Wikitext-2 data...")
    with open(wikitext_file, "r", encoding="utf-8") as f:
        text = f.read()

    # Blank-line separated paragraphs of reasonable length
    paragraphs = [p for p in text.split("\n\n") if p.strip() and len(p) > 100]

    if len(paragraphs) < 100:
        # Fall back to splitting by newline if needed
        paragraphs = [p for p in text.split("\n") if len(p.strip()) > 100]

    if len(paragraphs) < 2:
        # With fewer than two paragraphs one side of the split would be empty
        # and tokenization would fail on an empty batch.
        print("Not enough usable paragraphs in dataset")
        return None, None

    # Shuffle and split (80/20). The global NumPy seed is set on purpose so
    # the split is reproducible.
    np.random.seed(42)
    np.random.shuffle(paragraphs)

    split_idx = int(len(paragraphs) * 0.8)
    train_paragraphs = paragraphs[:split_idx]
    val_paragraphs = paragraphs[split_idx:]

    print(f"Tokenizing {len(train_paragraphs)} training and {len(val_paragraphs)} validation paragraphs...")

    # Tokenize and prepare datasets
    train_data = prepare_dataset(train_paragraphs, tokenizer, max_length, batch_size)
    val_data = prepare_dataset(val_paragraphs, tokenizer, max_length, batch_size)

    return train_data, val_data

## Model Loading and Analysis

We'll load a pre-trained GPT-2 model and add functionality to analyze its attention heads.

In [ ]:
def load_model_and_tokenizer(model_name, cache_dir=None):
    """Load pre-trained model and tokenizer.

    The architecture family is guessed from ``model_name`` and stored on the
    model as ``model.model_type`` ("gpt2", "opt" or "pythia"). The weights are
    loaded through ``AutoModelForCausalLM`` so that non-GPT-2 checkpoints get
    their own model class instead of being forced into ``GPT2LMHeadModel``.

    Args:
        model_name: Hugging Face model id or local path.
        cache_dir: Optional download cache directory.

    Returns:
        ``(model, tokenizer)`` with the model moved to ``DEVICE``.
    """
    print(f"Loading model: {model_name}")

    # Determine model type from name
    lowered_name = model_name.lower()
    if "gpt2" in lowered_name:
        model_type = "gpt2"
    elif "opt" in lowered_name or "facebook" in lowered_name:
        model_type = "opt"
    elif "pythia" in lowered_name or "eleutherai" in lowered_name:
        model_type = "pythia"
    else:
        model_type = "gpt2"  # Default to gpt2

    print(f"Detected model type: {model_type}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

    # Ensure padding token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model with potential FP16 optimization
    load_kwargs = {"cache_dir": cache_dir}
    if USE_FP16:
        print("Using FP16 for model loading")
        # For FP16, we need to set torch_dtype
        load_kwargs["torch_dtype"] = torch.float16

    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    model.to(DEVICE)

    # Store model type for later use
    model.model_type = model_type

    # Print model size information
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model loaded with {param_count/1e6:.2f}M parameters")

    return model, tokenizer

def get_attention_modules(model):
    """Extract attention modules from model.

    Sets ``model.model_type`` to "gpt2" when the attribute is missing, then
    walks the layer container that belongs to that model type.

    Returns:
        List of ``(layer_idx, attention_module)`` tuples; empty (with a
        printed warning) when the architecture is not recognised.
    """
    # Set default model type if not already set
    if not hasattr(model, "model_type"):
        model.model_type = "gpt2"

    # Per model type: candidate (attribute path to the layer list, name of the
    # attention sub-module on each layer), tried in order.
    gpt2_layout = (("transformer", "h"), "attn")
    layouts = {
        "gpt2": [gpt2_layout],
        "opt": [(("model", "decoder", "layers"), "self_attn")],
        # Pythia checkpoints are GPT-NeoX models, which keep their layers
        # under `gpt_neox.layers`; the GPT-2 layout stays as a fallback.
        "pythia": [(("gpt_neox", "layers"), "attention"), gpt2_layout],
    }

    attention_modules = []
    for path, attn_name in layouts.get(model.model_type, []):
        blocks = model
        for part in path:
            blocks = getattr(blocks, part, None)
            if blocks is None:
                break
        if blocks is None:
            continue

        attention_modules = [
            (i, getattr(block, attn_name))
            for i, block in enumerate(blocks)
            if hasattr(block, attn_name)
        ]
        if attention_modules:
            break

    # Not a supported model
    if not attention_modules:
        print("Warning: Could not find attention modules. Unsupported model architecture.")

    return attention_modules

# Important Note for v0.0.31 

Version 0.0.31 fixed an issue with the `get_strategy()` function that caused errors with some models (particularly Pythia models). If you encountered an error like the following in an earlier version:

```
TypeError: get_strategy() takes 2 positional arguments but 3 were given
```

The issue has been fixed in this notebook version by:

1. Adding proper model-type detection
2. Improving error handling in the importance calculation
3. Adding fallback strategies when a method fails
4. Supporting Pythia models specifically

If a larger model such as `EleutherAI/pythia-1b` runs out of memory, switch to the smaller `EleutherAI/pythia-70m` model.

## Head Importance Calculation

We'll implement different strategies for determining which attention heads are important.

In [ ]:
def _get_num_heads(attn, default=12):
    """Return the head count declared on an attention module (or *default*)."""
    for attr_name in ("num_heads", "num_attention_heads"):
        if hasattr(attn, attr_name):
            return getattr(attn, attr_name)
    print(f"Warning: Could not determine num_heads, using default: {default}")
    return default


def _split_qkv_weights(attn, num_heads):
    """Return one ``(q, k, v)`` weight-slice tuple per head, or None.

    The layout is recognised from the projection attributes present on the
    module rather than from a model-type label.
    """
    if hasattr(attn, "c_attn"):
        # Fused GPT-2 style projection. The original code sliced columns in
        # [Q | K | V] order, which is kept here; the head size is derived from
        # the weight shape instead of a `head_size` attribute.
        weight = attn.c_attn.weight
        hidden = weight.shape[1] // 3
        head_size = hidden // num_heads
        if head_size == 0:
            return None
        return [
            tuple(
                weight[:, offset + h * head_size:offset + (h + 1) * head_size]
                for offset in (0, hidden, 2 * hidden)
            )
            for h in range(num_heads)
        ]

    if all(hasattr(attn, name) for name in ("q_proj", "k_proj", "v_proj")):
        # Separate projections (OPT style): each head owns a block of rows.
        head_size = attn.q_proj.weight.shape[0] // num_heads
        if head_size == 0:
            return None
        return [
            tuple(
                proj.weight[h * head_size:(h + 1) * head_size, :]
                for proj in (attn.q_proj, attn.k_proj, attn.v_proj)
            )
            for h in range(num_heads)
        ]

    if hasattr(attn, "query_key_value"):
        # Fused GPT-NeoX style projection.
        # NOTE(review): rows are assumed to be grouped per head as [q | k | v],
        # matching how GPT-NeoX reshapes the projection output - confirm for
        # the installed transformers version.
        weight = attn.query_key_value.weight
        head_size = weight.shape[0] // (3 * num_heads)
        if head_size == 0:
            return None
        return [
            tuple(
                weight[(3 * h + part) * head_size:(3 * h + part + 1) * head_size, :]
                for part in range(3)
            )
            for h in range(num_heads)
        ]

    return None


def _random_importances(attention_modules):
    """Assign a uniform random importance to every head."""
    importances = []
    for layer_idx, attn in attention_modules:
        for head_idx in range(_get_num_heads(attn)):
            importances.append((layer_idx, head_idx, np.random.random()))
    return importances


def _magnitude_importances(attention_modules):
    """Score each head by the mean L2 norm of its Q, K and V weights."""
    importances = []
    for layer_idx, attn in attention_modules:
        num_heads = _get_num_heads(attn)
        per_head_weights = _split_qkv_weights(attn, num_heads)
        if per_head_weights is None:
            print(f"Warning: Layer {layer_idx} doesn't have expected attributes")
            continue

        for head_idx, qkv in enumerate(per_head_weights):
            # Norms are taken in float32 so FP16 weights do not lose precision
            norms = [torch.norm(part.detach().float()).item() for part in qkv]
            importances.append((layer_idx, head_idx, sum(norms) / 3))
    return importances


def _entropy_importances(model, val_dataloader, attention_modules, max_batches=5):
    """Score each head by the negated mean entropy of its attention rows.

    Attention probabilities are requested from the model itself with
    ``output_attentions=True`` and averaged over up to *max_batches* batches.
    Query positions whose ``attention_mask`` entry is 0 (padding) are
    excluded. Returns an empty list when no attention weights were obtained.
    """
    model.eval()

    wanted_layers = {layer_idx for layer_idx, _ in attention_modules}
    entropy_sums = {}   # layer_idx -> per-head sum of row entropies
    row_counts = {}     # layer_idx -> number of rows that went into the sum
    eps = 1e-10         # avoids log(0); entropy is computed in float32

    with torch.no_grad():
        for batch_idx, (input_ids, attention_mask) in enumerate(val_dataloader):
            if batch_idx >= max_batches:
                break

            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)

            try:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True,
                )
            except Exception as e:
                print(f"Error during forward pass: {e}")
                continue

            attentions = getattr(outputs, "attentions", None) or ()
            row_weight = attention_mask.to(torch.float32).unsqueeze(1)
            num_rows = float(row_weight.sum().item())
            if num_rows == 0:
                continue

            for layer_idx, probs in enumerate(attentions):
                if layer_idx not in wanted_layers:
                    continue
                # Expect [batch, num_heads, seq_len, seq_len]; skip anything else
                if not isinstance(probs, torch.Tensor) or probs.dim() != 4:
                    continue
                if probs.shape[2] != row_weight.shape[-1]:
                    continue

                probs = probs.detach().to(torch.float32)
                row_entropy = -(probs * torch.log(probs + eps)).sum(dim=-1)
                head_total = (row_entropy * row_weight).sum(dim=(0, 2)).cpu()

                if layer_idx in entropy_sums:
                    entropy_sums[layer_idx] += head_total
                    row_counts[layer_idx] += num_rows
                else:
                    entropy_sums[layer_idx] = head_total
                    row_counts[layer_idx] = num_rows

    importances = []
    for layer_idx, _ in attention_modules:
        if layer_idx not in entropy_sums:
            continue
        mean_entropy = (entropy_sums[layer_idx] / row_counts[layer_idx]).tolist()
        for head_idx, entropy in enumerate(mean_entropy):
            # Negated entropy, so that higher values = more important
            # (focused attention)
            importances.append((layer_idx, head_idx, -entropy))
    return importances


def get_head_importances(model, val_dataloader, strategy="entropy"):
    """
    Calculate importance scores for each attention head.

    If a strategy yields no scores it falls back to the next simpler one
    (entropy -> magnitude -> random).

    Args:
        model: The model to analyze
        val_dataloader: Validation data for computing metrics; must yield
            (input_ids, attention_mask) batches. Only used by 'entropy'.
        strategy: Pruning strategy ('entropy', 'magnitude', 'random')

    Returns:
        List of (layer_idx, head_idx, importance) tuples sorted by ascending
        importance (least important first). Empty if the model exposes no
        recognised attention modules.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """
    if strategy not in ("entropy", "magnitude", "random"):
        raise ValueError(f"Unknown strategy: {strategy}")

    print(f"Calculating head importances using {strategy} strategy...")
    attention_modules = get_attention_modules(model)

    if not attention_modules:
        # Without modules every strategy comes back empty; stop here instead
        # of falling back to "random" recursively forever.
        print("Warning: No attention modules found. Returning no head importances.")
        return []

    head_importances = []

    if strategy == "entropy":
        head_importances = _entropy_importances(model, val_dataloader, attention_modules)
        if not head_importances:
            print("Warning: No attention outputs captured. Falling back to magnitude strategy.")
            strategy = "magnitude"

    if strategy == "magnitude":
        head_importances = _magnitude_importances(attention_modules)
        if not head_importances:
            print("Warning: No head importances calculated. Falling back to random strategy.")
            strategy = "random"

    if strategy == "random":
        head_importances = _random_importances(attention_modules)

    # Sort by importance (ascending order, so lowest importance first)
    head_importances.sort(key=lambda x: x[2])

    return head_importances

## Pruning Implementation

Now we'll implement a function to prune the attention heads based on calculated importance scores.

In [None]:
def _count_gated_heads(attn):
    """Return the number of heads of an attention module."""
    for attr_name in ("num_heads", "num_attention_heads"):
        if hasattr(attn, attr_name):
            return getattr(attn, attr_name)
    raise AttributeError(
        f"{type(attn).__name__} has neither 'num_heads' nor 'num_attention_heads'"
    )


def _make_gate_hook(attn):
    """Build a pre-hook that scales each head's slice of the merged output."""
    def apply_gates(module, inputs):
        hidden = inputs[0]
        gates = attn.head_gates
        per_head = hidden.shape[-1] // gates.numel()
        if per_head * gates.numel() != hidden.shape[-1]:
            # Unexpected layout: leave the input untouched rather than guess
            return None
        # Match the activation dtype/device so FP16 models keep working
        mask = gates.to(device=hidden.device, dtype=hidden.dtype)
        mask = mask.repeat_interleave(per_head)
        return (hidden * mask,) + tuple(inputs[1:])

    return apply_gates


def _install_head_gates(attn, num_heads):
    """Attach ``head_gates`` and the gating hook to *attn* (idempotent)."""
    # Create gates with default value 1.0 if they don't exist yet. A
    # non-persistent buffer follows model.to()/cpu() without being added to
    # the state dict.
    if not hasattr(attn, "head_gates"):
        first_param = next(attn.parameters(), None)
        device = first_param.device if first_param is not None else DEVICE
        attn.register_buffer(
            "head_gates", torch.ones(num_heads, device=device), persistent=False
        )

    # Kept for backward compatibility with code that looks for this attribute;
    # the forward method itself is no longer replaced.
    if not hasattr(attn, "original_forward"):
        attn.original_forward = attn.forward

    if getattr(attn, "_head_gate_hook", None) is not None:
        return

    # Gate the input of the output projection. NOTE(review): this assumes the
    # projection receives the per-head outputs concatenated head by head along
    # the last dimension - confirm for any architecture beyond GPT-2
    # (`c_proj`), OPT (`out_proj`) and GPT-NeoX (`dense`).
    out_proj = next(
        (getattr(attn, name) for name in ("c_proj", "out_proj", "dense")
         if hasattr(attn, name)),
        None,
    )
    if out_proj is None:
        print(f"Warning: No output projection found on {type(attn).__name__}; gates will have no effect.")
        return
    attn._head_gate_hook = out_proj.register_forward_pre_hook(_make_gate_hook(attn))


def prune_heads(model, head_importances, pruning_level=0.3):
    """
    Prune specified fraction of attention heads.

    Every attention module gets a ``head_gates`` vector (1.0 = active,
    0.0 = pruned). The gates are applied to each head's contribution just
    before the attention output projection, so a pruned head no longer
    influences the model output. Gating the returned attention weights, as
    done previously, left the output unchanged.

    Args:
        model: The model to prune
        head_importances: List of (layer_idx, head_idx, importance) tuples,
            least important first
        pruning_level: Fraction of heads to prune (0.0 to 1.0)

    Returns:
        List of pruned heads as (layer_idx, head_idx) tuples
    """
    attention_modules = get_attention_modules(model)
    modules_by_layer = dict(attention_modules)
    head_counts = {
        layer_idx: _count_gated_heads(attn) for layer_idx, attn in attention_modules
    }

    # Count total heads
    total_heads = sum(head_counts.values())

    # Calculate how many heads to prune
    num_to_prune = int(total_heads * pruning_level)

    # Get heads to prune (lowest importance first)
    heads_to_prune = [(layer_idx, head_idx) for layer_idx, head_idx, _ in head_importances[:num_to_prune]]

    print(f"Pruning {len(heads_to_prune)}/{total_heads} attention heads ({pruning_level:.1%})")

    # Create/initialize gates and hooks if they don't exist
    for layer_idx, attn in attention_modules:
        _install_head_gates(attn, head_counts[layer_idx])

    # Apply pruning by setting gates to 0
    with torch.no_grad():
        for layer_idx, head_idx in heads_to_prune:
            attn = modules_by_layer.get(layer_idx)
            if attn is not None:
                attn.head_gates[head_idx] = 0.0

    return heads_to_prune

def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.7):
    """Generate text from a prompt using the model.

    Sampling (top-p 0.9 at the given temperature) is tried first. If it
    raises a ``RuntimeError`` whose message mentions CUDA, the same sampling
    call is retried with the model and inputs on the CPU. If that also fails,
    or if the original ``RuntimeError`` was not CUDA-related, greedy decoding
    is attempted on ``DEVICE``. When every attempt fails, the prompt is
    returned with an error marker appended instead of raising.

    The model is put in eval mode for generation and switched back to
    training mode afterwards if it was training when this was called, so
    sampling in the middle of a training loop does not leave the model in
    eval mode for the remaining steps.

    Args:
        model: Model exposing ``generate`` (and ``cpu`` / ``to``).
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_length: Passed to ``model.generate`` as ``max_length``.
        temperature: Sampling temperature for the sampling attempts.

    Returns:
        The decoded text, or ``"<prompt> [Error: Text generation failed]"``
        if no generation attempt succeeded.

    Raises:
        Exception: Errors other than ``RuntimeError`` from the first
            generation attempt are not handled and propagate to the caller.
    """
    was_training = getattr(model, "training", False)
    model.eval()

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    def _run(gen_model, ids, sample):
        """Run one generate call and decode the first returned sequence."""
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": sample,
            "pad_token_id": tokenizer.eos_token_id,
            "num_return_sequences": 1,
        }
        if sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.9
        with torch.no_grad():
            output = gen_model.generate(ids, **gen_kwargs)
        return tokenizer.decode(output[0], skip_special_tokens=True)

    try:
        try:
            # First try with all sampling options enabled
            return _run(model, input_ids, sample=True)
        except RuntimeError as e:
            if "CUDA" in str(e):
                print("CUDA error during generation, attempting fallback to CPU...")
                try:
                    try:
                        return _run(model.cpu(), input_ids.cpu(), sample=True)
                    finally:
                        # model.cpu() moved the model itself, so it has to go
                        # back to DEVICE whether or not the CPU attempt worked.
                        model.to(DEVICE)
                except Exception as cpu_error:
                    print(f"CPU fallback also failed: {cpu_error}")
            else:
                print(f"Error during generation: {e}, trying with safer parameters...")

            # Last resort: greedy decoding on the original device.
            try:
                return _run(model, input_ids, sample=False)
            except Exception as safe_error:
                print(f"Safe generation also failed: {safe_error}")
                return f"{prompt} [Error: Text generation failed]"
    finally:
        if was_training:
            model.train()

def evaluate_model(model, dataloader):
    """Evaluate model on dataloader and return loss and perplexity.

    Each batch is run with ``labels=input_ids``. A batch that raises a
    ``RuntimeError`` mentioning CUDA is retried once on the CPU; the model is
    always moved back to ``DEVICE`` afterwards, even when the CPU attempt
    fails, so later batches are not hit by a model/input device mismatch.
    Batches that still fail are skipped.

    Args:
        model: Model to evaluate; it is left in eval mode.
        dataloader: Iterable yielding ``(input_ids, attention_mask)`` pairs.

    Returns:
        Tuple ``(avg_loss, perplexity)`` over the batches that succeeded,
        where perplexity is ``exp(avg_loss)``. Returns ``(999.0, 999.0)`` if
        no batch succeeded (including an empty dataloader).
    """
    model.eval()
    total_loss = 0
    total_batches = 0

    with torch.no_grad():
        for input_ids, attention_mask in dataloader:
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)

            try:
                # .item() stays inside the try, as in the primary path before,
                # so an error raised while reading the loss is handled too.
                batch_loss = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids
                ).loss.item()
            except RuntimeError as e:
                if "CUDA" not in str(e):
                    print(f"Error during evaluation: {e}")
                    # Skip this batch
                    continue

                print(f"CUDA error during evaluation, attempting fallback: {e}")
                try:
                    try:
                        # Try with CPU
                        cpu_input_ids = input_ids.cpu()
                        batch_loss = model.cpu()(
                            input_ids=cpu_input_ids,
                            attention_mask=attention_mask.cpu(),
                            labels=cpu_input_ids
                        ).loss.item()
                    finally:
                        # Move model back, whether or not the CPU pass worked
                        model.to(DEVICE)
                except Exception as cpu_error:
                    print(f"CPU fallback failed: {cpu_error}")
                    # Skip this batch
                    continue

            total_loss += batch_loss
            total_batches += 1

    # Calculate average loss and perplexity
    if total_batches == 0:
        return 999.0, 999.0  # Return high values if all batches failed

    avg_loss = total_loss / total_batches
    perplexity = torch.exp(torch.tensor(avg_loss)).item()

    return avg_loss, perplexity

def get_gate_values(model):
    """Collect the head gates of every gated attention module.

    Returns:
        Dict mapping ``"layer_<idx>"`` to that module's ``head_gates`` as a
        NumPy array (detached and moved to the CPU). Modules without a
        ``head_gates`` attribute are left out.
    """
    return {
        f"layer_{idx}": module.head_gates.detach().cpu().numpy()
        for idx, module in get_attention_modules(model)
        if hasattr(module, "head_gates")
    }

In [None]:
def evaluate_model(model, dataloader):
    """Evaluate model on dataloader and return loss and perplexity.

    Args:
        model: Model to evaluate; it is left in eval mode.
        dataloader: Iterable yielding ``(input_ids, attention_mask)`` pairs.

    Returns:
        Tuple ``(avg_loss, perplexity)`` where perplexity is
        ``exp(avg_loss)``. Returns ``(999.0, 999.0)`` for an empty
        dataloader, matching the sentinel used by the earlier definition of
        this function in the notebook.
    """
    # NOTE(review): this cell redefines evaluate_model. An earlier cell in
    # this notebook defines a version that tolerates CUDA errors per batch;
    # running this cell replaces it. Confirm which one is meant to be kept.
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for input_ids, attention_mask in dataloader:
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )
            batch_losses.append(outputs.loss.item())

    # Nothing to average: avoid dividing by zero
    if not batch_losses:
        return 999.0, 999.0

    # Calculate average loss and perplexity
    avg_loss = sum(batch_losses) / len(batch_losses)
    perplexity = torch.exp(torch.tensor(avg_loss)).item()

    return avg_loss, perplexity

def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.7):
    """Generate text from a prompt using the model.

    Samples one sequence with top-p 0.9 at the given temperature. The model
    is put in eval mode for generation and switched back to training mode
    afterwards if it was training when this was called, so sampling in the
    middle of a training loop does not leave the model in eval mode.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_length: Passed to ``model.generate`` as ``max_length``.
        temperature: Sampling temperature.

    Returns:
        The decoded first returned sequence, with special tokens skipped.
    """
    # NOTE(review): this cell redefines generate_text. An earlier cell in this
    # notebook defines a version with CUDA/CPU fallbacks; running this cell
    # replaces it. Confirm which one is meant to be kept.
    was_training = getattr(model, "training", False)
    model.eval()

    try:
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_length=max_length,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )

        # Decode the output
        return tokenizer.decode(output[0], skip_special_tokens=True)
    finally:
        # Restore the caller's mode even if generation raised
        if was_training:
            model.train()

def get_gate_values(model):
    """Return the head gates of each gated attention module, keyed by layer.

    Returns:
        Dict mapping ``"layer_<idx>"`` to the module's ``head_gates`` as a
        NumPy array; modules that have no ``head_gates`` are skipped.
    """
    collected = {}
    for idx, module in get_attention_modules(model):
        if not hasattr(module, "head_gates"):
            continue
        gates = module.head_gates
        collected[f"layer_{idx}"] = gates.detach().cpu().numpy()
    return collected

def fine_tune_model(model, train_dataloader, val_dataloader, tokenizer, 
                   learning_rate=5e-5, num_epochs=3, progress_tracker=None):
    """
    Fine-tune model after pruning.

    Trains with AdamW and a linear schedule whose warmup is 10% of the total
    steps, clipping gradients to a max norm of 1.0. The model is evaluated
    on the validation data after every epoch and once more at the end.

    When a progress tracker is supplied, the current gate values and a
    generation sample are recorded every 50 steps, and validation metrics
    are recorded after each epoch. Text generation is best-effort: if it
    raises, the metrics are recorded without a sample. No samples are
    generated when there is no tracker to receive them.

    Args:
        model: The model to fine-tune
        train_dataloader: Training data, yielding (input_ids, attention_mask)
        val_dataloader: Validation data
        tokenizer: Tokenizer
        learning_rate: Learning rate for optimization
        num_epochs: Number of training epochs
        progress_tracker: ProgressMetrics object for tracking progress

    Returns:
        Dictionary with training results: "final_loss", "final_perplexity"
        and "steps" (the number of training steps taken)
    """
    def _try_generate(prompt, error_label):
        """Return a generation sample, or None if generation raised."""
        try:
            return generate_text(model, tokenizer, prompt=prompt)
        except Exception as e:
            print(f"{error_label}: {e}")
            return None
        finally:
            # generate_text puts the model in eval mode; switch back so the
            # remaining training steps run in training mode.
            model.train()

    model.train()

    # Prepare optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Calculate total steps and prepare scheduler
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # Train the model
    step = 0
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Training loop
        model.train()
        epoch_losses = []

        # Create progress bar
        progress_bar = tqdm(train_dataloader, desc=f"Training epoch {epoch+1}")

        for input_ids, attention_mask in progress_bar:
            # Prepare data
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )

            loss = outputs.loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Track metrics
            loss_val = loss.item()
            epoch_losses.append(loss_val)
            perplexity = torch.exp(torch.tensor(loss_val)).item()

            progress_bar.set_postfix(loss=f"{loss_val:.4f}", ppl=f"{perplexity:.2f}")

            # Record metrics, gate values and a sample every 50 steps
            if progress_tracker and step % 50 == 0:
                tracker_kwargs = {
                    "step": step,
                    "loss": loss_val,
                    "perplexity": perplexity,
                    "gate_values": get_gate_values(model),
                }
                sample_text = _try_generate(
                    "A large language model is",
                    "Error during text generation"
                )
                # Still update metrics when generation failed, just
                # without the generation sample.
                if sample_text is not None:
                    tracker_kwargs["generation_sample"] = sample_text
                progress_tracker.update(**tracker_kwargs)

            step += 1

        # Evaluate after each epoch
        eval_loss, eval_ppl = evaluate_model(model, val_dataloader)

        # Print epoch summary (NaN when the training data yielded no batches)
        if epoch_losses:
            epoch_loss = sum(epoch_losses) / len(epoch_losses)
        else:
            epoch_loss = float("nan")
        epoch_ppl = torch.exp(torch.tensor(epoch_loss)).item()

        print(f"Epoch {epoch+1} summary:")
        print(f"  Train loss: {epoch_loss:.4f}, perplexity: {epoch_ppl:.2f}")
        print(f"  Val loss: {eval_loss:.4f}, perplexity: {eval_ppl:.2f}")

        # Track validation metrics
        if progress_tracker:
            tracker_kwargs = {
                "step": step,
                "loss": eval_loss,
                "perplexity": eval_ppl,
            }
            sample_text = _try_generate(
                "In recent years, artificial intelligence has",
                "Error during evaluation text generation"
            )
            if sample_text is not None:
                tracker_kwargs["generation_sample"] = sample_text
            progress_tracker.update(**tracker_kwargs)

    # Final evaluation
    final_loss, final_ppl = evaluate_model(model, val_dataloader)
    print(f"Final evaluation - Loss: {final_loss:.4f}, Perplexity: {final_ppl:.2f}")

    return {
        "final_loss": final_loss,
        "final_perplexity": final_ppl,
        "steps": step
    }

In [None]:
def fine_tune_model(model, train_dataloader, val_dataloader, tokenizer, 
                   learning_rate=5e-5, num_epochs=3, progress_tracker=None):
    """
    Fine-tune model after pruning.

    Uses AdamW with a linear schedule (warmup over the first 10% of steps)
    and gradient clipping at a max norm of 1.0. Validation loss and
    perplexity are computed after each epoch and again at the end.

    With a progress tracker, gate values and a generation sample are
    recorded every 50 steps and validation metrics after each epoch. A
    failing text generation is reported and the metrics are recorded without
    a sample; training continues. Without a tracker no samples are generated.

    Args:
        model: The model to fine-tune
        train_dataloader: Training data, yielding (input_ids, attention_mask)
        val_dataloader: Validation data
        tokenizer: Tokenizer
        learning_rate: Learning rate for optimization
        num_epochs: Number of training epochs
        progress_tracker: ProgressMetrics object for tracking progress

    Returns:
        Dictionary with training results: "final_loss", "final_perplexity"
        and "steps" (the number of training steps taken)
    """
    # NOTE(review): this cell redefines fine_tune_model, replacing the
    # definition from the earlier cell when the notebook is run top to
    # bottom. Confirm which one is meant to be kept.

    def _sample_or_none(prompt, error_label):
        """Generate a sample for the tracker; None means generation failed."""
        try:
            return generate_text(model, tokenizer, prompt=prompt)
        except Exception as e:
            print(f"{error_label}: {e}")
            return None
        finally:
            # generate_text calls model.eval(); go back to training mode so
            # the steps that follow are not run in eval mode.
            model.train()

    def _record(prompt, error_label, **metrics):
        """Send metrics to the tracker, adding a sample when one was made."""
        sample_text = _sample_or_none(prompt, error_label)
        if sample_text is not None:
            metrics["generation_sample"] = sample_text
        progress_tracker.update(**metrics)

    model.train()

    # Prepare optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Calculate total steps and prepare scheduler
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # Train the model
    step = 0
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Training loop
        model.train()
        epoch_losses = []

        # Create progress bar
        progress_bar = tqdm(train_dataloader, desc=f"Training epoch {epoch+1}")

        for input_ids, attention_mask in progress_bar:
            # Prepare data
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )

            loss = outputs.loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Track metrics
            loss_val = loss.item()
            epoch_losses.append(loss_val)
            perplexity = torch.exp(torch.tensor(loss_val)).item()

            progress_bar.set_postfix(loss=f"{loss_val:.4f}", ppl=f"{perplexity:.2f}")

            # Record metrics, gate values and a sample every 50 steps
            if progress_tracker and step % 50 == 0:
                _record(
                    "A large language model is",
                    "Error during text generation",
                    step=step,
                    loss=loss_val,
                    perplexity=perplexity,
                    gate_values=get_gate_values(model)
                )

            step += 1

        # Evaluate after each epoch
        eval_loss, eval_ppl = evaluate_model(model, val_dataloader)

        # Print epoch summary (NaN when the training data yielded no batches)
        if epoch_losses:
            epoch_loss = sum(epoch_losses) / len(epoch_losses)
        else:
            epoch_loss = float("nan")
        epoch_ppl = torch.exp(torch.tensor(epoch_loss)).item()

        print(f"Epoch {epoch+1} summary:")
        print(f"  Train loss: {epoch_loss:.4f}, perplexity: {epoch_ppl:.2f}")
        print(f"  Val loss: {eval_loss:.4f}, perplexity: {eval_ppl:.2f}")

        # Track validation metrics
        if progress_tracker:
            _record(
                "In recent years, artificial intelligence has",
                "Error during evaluation text generation",
                step=step,
                loss=eval_loss,
                perplexity=eval_ppl
            )

    # Final evaluation
    final_loss, final_ppl = evaluate_model(model, val_dataloader)
    print(f"Final evaluation - Loss: {final_loss:.4f}, Perplexity: {final_ppl:.2f}")

    return {
        "final_loss": final_loss,
        "final_perplexity": final_ppl,
        "steps": step
    }

## Visualization Functions

Let's implement visualization for head importance and gate values.

In [None]:
def create_head_importance_visualization(head_importances, pruned_heads, output_path=None):
    """Plot every head's importance as a horizontal bar, pruned heads in red.

    Args:
        head_importances: Iterable of (layer_idx, head_idx, importance).
        pruned_heads: Iterable of (layer_idx, head_idx) pairs that were pruned.
        output_path: If given, the figure is also saved to this path.

    Returns:
        The matplotlib figure. Bars are ordered by layer, then by head index,
        reading top to bottom.
    """
    # Bucket (head, importance) pairs under their layer index
    by_layer = {}
    for layer_idx, head_idx, importance in head_importances:
        by_layer.setdefault(layer_idx, []).append((head_idx, importance))

    # Set membership keeps the per-bar "was it pruned?" check cheap
    pruned_lookup = {(layer, head) for layer, head in pruned_heads}

    # Figure height follows the number of layers, with a floor of 6
    fig, ax = plt.subplots(figsize=(12, max(6, len(by_layer))))

    # One label / value / colour triple per head
    bar_labels, bar_values, bar_colors = [], [], []
    for layer_idx in sorted(by_layer):
        ordered_heads = sorted(by_layer[layer_idx], key=lambda pair: pair[0])
        for head_idx, importance in ordered_heads:
            was_pruned = (layer_idx, head_idx) in pruned_lookup
            bar_labels.append(f"L{layer_idx}-H{head_idx}")
            bar_values.append(importance)
            bar_colors.append('red' if was_pruned else 'blue')

    # Horizontal bars, first head at the top
    positions = np.arange(len(bar_labels))
    ax.barh(positions, bar_values, color=bar_colors)

    ax.set_yticks(positions)
    ax.set_yticklabels(bar_labels)
    ax.invert_yaxis()
    ax.set_xlabel('Importance Score')
    ax.set_title('Attention Head Importance (red = pruned)')

    plt.tight_layout()
    if output_path:
        plt.savefig(output_path)

    return fig

def visualize_gate_values(gate_values, output_path=None):
    """Create visualization of gate values across layers.

    Each layer is drawn as one row of dots, one dot per head. A dot is red
    when its gate is below 0.01 and blue otherwise, and its size grows with
    the gate value.

    Args:
        gate_values: Dict mapping layer names (e.g. "layer_3") to a sequence
            of per-head gate values.
        output_path: If given, the figure is also saved to this path.

    Returns:
        The matplotlib figure, or None when ``gate_values`` is empty.
    """
    if not gate_values:
        return None

    def _layer_order(name):
        """Sort key that orders "layer_<n>" names by n, not alphabetically."""
        suffix = name.rsplit("_", 1)[-1]
        if suffix.isdigit():
            return (0, int(suffix), name)
        # Names without a numeric suffix go last, in alphabetical order
        return (1, 0, name)

    # Prepare data. A plain sort would put "layer_10" before "layer_2".
    layers = sorted(gate_values, key=_layer_order)
    data = [gate_values[layer] for layer in layers]

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot heatmap-like visualization
    for i, values in enumerate(data):
        # Create scatter plot for each layer
        x = np.arange(len(values))
        y = np.ones_like(x) * i

        # Use values to determine color and size
        colors = ['red' if v < 0.01 else 'blue' for v in values]
        sizes = [10 + 40 * v for v in values]

        ax.scatter(x, y, c=colors, s=sizes, alpha=0.7)

    # Customize plot
    ax.set_yticks(np.arange(len(layers)))
    ax.set_yticklabels([layer.replace('layer_', 'Layer ') for layer in layers])
    ax.set_xlabel('Head Index')
    ax.set_ylabel('Layer')
    ax.set_title('Attention Head Gate Values (Red = Pruned)')

    # Add colour legend
    import matplotlib.patches as mpatches
    red_patch = mpatches.Patch(color='red', label='Pruned (gate ≈ 0)')
    blue_patch = mpatches.Patch(color='blue', label='Active (gate = 1)')
    ax.legend(handles=[red_patch, blue_patch], loc='upper right')

    # Save figure if path provided
    plt.tight_layout()
    if output_path:
        plt.savefig(output_path)

    return fig

## Main Experiment

Now we'll run the complete experiment to make a GPT-2 model smaller and more powerful.

In [ ]:
# Configure experiment parameters
model_name = "distilgpt2"  # Use distilgpt2 for faster experimentation

# Uncomment ONE of the following lines to use an alternative model:
# model_name = "facebook/opt-125m"  # Smaller OPT model
# model_name = "EleutherAI/pythia-70m"  # Smaller Pythia model

# If your error was with the pythia-1b model, use the 70m version for less memory usage:
# model_name = "EleutherAI/pythia-70m"  # Smaller Pythia model that should avoid memory issues

strategy = "entropy"       # Options: "random", "magnitude", "entropy"
pruning_level = 0.3        # Fraction of heads to prune (0.0 to 1.0)
learning_rate = 5e-5       # Learning rate for fine-tuning
num_epochs = 3             # Number of training epochs
max_length = 256           # Maximum sequence length
batch_size = 4             # Batch size for training and evaluation

# Reduce batch size if running on CPU or with limited memory.
# Compare via str() so the check works whether DEVICE is the string "cpu"
# or a torch.device object.
if str(DEVICE).startswith("cpu"):
    print("Running on CPU - reducing batch size and max_length")
    batch_size = 2
    max_length = 128

In [None]:
# Output directory for this run: <OUTPUT_DIR>/<model>_<strategy>_<timestamp>
# ("/" in the model name is replaced so it cannot create nested directories)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join(
    OUTPUT_DIR,
    "_".join([model_name.replace("/", "_"), strategy, timestamp]),
)
os.makedirs(run_dir, exist_ok=True)

# Tracker that collects metrics and samples over the whole experiment
progress = ProgressMetrics()

# Model and tokenizer for the configured model name
model, tokenizer = load_model_and_tokenizer(model_name, cache_dir=MODEL_CACHE_DIR)

# Training and validation dataloaders
train_dataloader, val_dataloader = load_wikitext_data(
    tokenizer, max_length=max_length, batch_size=batch_size
)

In [None]:
# Baseline: measure the unpruned model on the validation data
print("Evaluating initial model performance...")
initial_loss, initial_ppl = evaluate_model(model, val_dataloader)
print(f"Initial loss: {initial_loss:.4f}, perplexity: {initial_ppl:.2f}")

# Sample a continuation from the unpruned model for later comparison
initial_generation = generate_text(
    model,
    tokenizer,
    prompt="Artificial intelligence is becoming increasingly important because",
    max_length=100,
)
print("\nInitial generation example:", initial_generation, sep="\n")

# Store the baseline as step 0 of the tracked metrics
progress.update(
    step=0,
    loss=initial_loss,
    perplexity=initial_ppl,
    generation_sample=initial_generation,
)

In [None]:
# Calculate head importances
head_importances = get_head_importances(model, val_dataloader, strategy=strategy)

# Prune heads
pruned_heads = prune_heads(model, head_importances, pruning_level=pruning_level)

# Record pruning information
progress.set_pruning_info(strategy, pruning_level, pruned_heads)

# Evaluate the pruned model before fine-tuning. Later cells report
# pruned_ppl in the experiment summary and write pruned_generation to
# text_samples.txt, so these names must be defined here.
print("Evaluating pruned model performance...")
pruned_loss, pruned_ppl = evaluate_model(model, val_dataloader)
print(f"Pruned loss: {pruned_loss:.4f}, perplexity: {pruned_ppl:.2f}")

pruned_generation = generate_text(
    model, tokenizer, 
    prompt="Artificial intelligence is becoming increasingly important because",
    max_length=100
)
print("\nPruned generation example:")
print(pruned_generation)

# Record pruned metrics (step=1, the same step run_complete_experiment uses
# for its post-pruning measurement)
progress.update(step=1, loss=pruned_loss, perplexity=pruned_ppl, 
               generation_sample=pruned_generation)

# Save gate value visualization (the download cell looks for gate_values.png)
gate_viz_path = os.path.join(run_dir, "gate_values.png")
visualize_gate_values(get_gate_values(model), gate_viz_path)

# Save head importance visualization
importance_viz_path = os.path.join(run_dir, "head_importances.png")
create_head_importance_visualization(
    head_importances, pruned_heads, importance_viz_path
)

In [ ]:
# Fine-tune pruned model
print("\nFine-tuning pruned model...")
try:
    # Clear GPU memory before fine-tuning
    clear_gpu_memory()
    
    # Use the context manager for better memory efficiency
    with autocast_if_available():
        fine_tune_results = fine_tune_model(
            model, 
            train_dataloader, 
            val_dataloader, 
            tokenizer,
            learning_rate=learning_rate,
            num_epochs=num_epochs,
            progress_tracker=progress
        )
except RuntimeError as e:
    if "CUDA" in str(e):
        print(f"\nCUDA error during fine-tuning: {e}")
        print("\nAttempting to continue with CPU...\n")
        # Move to CPU and continue; the helper functions read DEVICE, so it
        # is switched too to keep inputs and model on the same device
        model = model.cpu()
        DEVICE = "cpu"
        
        # Try with reduced parameters
        fine_tune_results = fine_tune_model(
            model, 
            train_dataloader, 
            val_dataloader, 
            tokenizer,
            learning_rate=learning_rate,
            num_epochs=1,  # Reduce to single epoch on CPU
            progress_tracker=progress
        )
    else:
        print(f"Error during fine-tuning: {e}")
        # Provide minimal output to continue: report the model as it stands
        # now. Evaluating here avoids depending on metrics from other cells
        # that may not have been computed.
        current_loss, current_ppl = evaluate_model(model, val_dataloader)
        fine_tune_results = {
            "final_loss": current_loss,
            "final_perplexity": current_ppl,
            "steps": 0
        }

In [None]:
# Fine-tune pruned model (no error handling; any exception stops the cell)
# NOTE(review): the previous cell already fine-tunes the model with CUDA
# fallbacks. Running both cells trains the model twice and overwrites
# fine_tune_results. Confirm whether this cell should remain.
print("\nFine-tuning pruned model...")
fine_tune_results = fine_tune_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    tokenizer=tokenizer,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    progress_tracker=progress,
)

In [None]:
# Generate final examples
final_generation = generate_text(
    model, tokenizer, 
    prompt="Artificial intelligence is becoming increasingly important because",
    max_length=100,
    temperature=0.7
)
print("\nFinal generation example:")
print(final_generation)

# Compare results
summary = progress.get_summary()

# The post-pruning perplexity only exists if the pruned model was evaluated
# before fine-tuning; look it up so a missing value does not raise NameError
# and hide the rest of the summary.
pruned_ppl_value = globals().get("pruned_ppl")
final_ppl_value = fine_tune_results['final_perplexity']

print("\n" + "="*50)
print("EXPERIMENT SUMMARY")
print("="*50)
print(f"Model: {model_name}")
print(f"Pruning strategy: {strategy}")
print(f"Pruning level: {pruning_level:.1%}")
print(f"Pruned heads: {len(pruned_heads)}")
print("\nPerformance:")
print(f"  Initial perplexity: {initial_ppl:.2f}")
if pruned_ppl_value is None:
    print("  After pruning: n/a (pruned model was not evaluated)")
else:
    print(f"  After pruning: {pruned_ppl_value:.2f} ({(pruned_ppl_value-initial_ppl)/initial_ppl*100:+.2f}%)")
print(f"  After fine-tuning: {final_ppl_value:.2f} ({(final_ppl_value-initial_ppl)/initial_ppl*100:+.2f}%)")

In [None]:
# Save final plots
progress.save_plots(os.path.join(run_dir, "training_progress.png"))
progress.save_metrics(os.path.join(run_dir, "metrics.json"))

# The post-pruning sample only exists if the pruned model was sampled before
# fine-tuning; fall back to a placeholder instead of raising NameError.
pruned_sample = globals().get("pruned_generation", "[no sample was generated after pruning]")

# Save text samples. Generated text can contain non-ASCII characters, so the
# encoding is set explicitly rather than left to the platform default.
with open(os.path.join(run_dir, "text_samples.txt"), "w", encoding="utf-8") as f:
    f.write("INITIAL MODEL\n")
    f.write("============\n")
    f.write(initial_generation)
    f.write("\n\nAFTER PRUNING\n")
    f.write("============\n")
    f.write(pruned_sample)
    f.write("\n\nAFTER FINE-TUNING\n")
    f.write("===============\n")
    f.write(final_generation)

print(f"\nAll results and visualizations saved to: {run_dir}")

## Extended Experiment: Try Different Strategies and Pruning Levels

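The cells below let you experiment with different pruning strategies and levels to find the optimal configuration: the first defines `run_complete_experiment`, and the second (commented out by default) runs it over a grid of strategies and pruning levels.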

In [None]:
def run_complete_experiment(model_name, strategy, pruning_level, epochs=3):
    """Run one prune-then-fine-tune experiment and summarise its perplexities.

    Loads a fresh model, tokenizer and dataloaders (using the notebook-level
    ``max_length`` and ``batch_size``), evaluates the model, prunes it,
    evaluates it again, fine-tunes it, and saves the progress plot and
    metrics into a new timestamped directory under ``OUTPUT_DIR``.

    Args:
        model_name: Name of the model to load.
        strategy: Head-importance strategy passed to ``get_head_importances``.
        pruning_level: Pruning level passed to ``prune_heads``.
        epochs: Number of fine-tuning epochs.

    Returns:
        Dict with "model", "strategy", "pruning_level", the perplexities
        "initial_ppl", "pruned_ppl" and "final_ppl", and the percentage
        changes relative to the initial perplexity, "pruning_impact" and
        "final_improvement".
    """
    # Directory for this experiment's outputs
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_name = model_name.replace('/', '_')
    run_dir = os.path.join(OUTPUT_DIR, f"{safe_name}_{strategy}_{pruning_level:.1f}_{stamp}")
    os.makedirs(run_dir, exist_ok=True)

    tracker = ProgressMetrics()

    # Fresh model and tokenizer for this experiment
    model, tokenizer = load_model_and_tokenizer(model_name, cache_dir=MODEL_CACHE_DIR)

    # Fresh dataloaders built with this tokenizer
    train_loader, val_loader = load_wikitext_data(
        tokenizer, max_length=max_length, batch_size=batch_size
    )

    # Baseline measurement
    print(f"\nRunning experiment with {strategy} strategy at {pruning_level:.1%} pruning level")
    print("Evaluating initial model...")
    initial_loss, initial_ppl = evaluate_model(model, val_loader)
    tracker.update(step=0, loss=initial_loss, perplexity=initial_ppl)

    # Score the heads, then prune
    print("Calculating head importances...")
    importances = get_head_importances(model, val_loader, strategy=strategy)

    print("Applying pruning...")
    pruned = prune_heads(model, importances, pruning_level=pruning_level)
    tracker.set_pruning_info(strategy, pruning_level, pruned)

    # Measurement right after pruning
    pruned_loss, pruned_ppl = evaluate_model(model, val_loader)
    tracker.update(step=1, loss=pruned_loss, perplexity=pruned_ppl)

    # Fine-tune the pruned model
    print("Fine-tuning...")
    results = fine_tune_model(
        model,
        train_loader,
        val_loader,
        tokenizer,
        num_epochs=epochs,
        progress_tracker=tracker,
    )
    final_ppl = results["final_perplexity"]

    # Persist plots and metrics
    tracker.save_plots(os.path.join(run_dir, "training_progress.png"))
    tracker.save_metrics(os.path.join(run_dir, "metrics.json"))

    # Percentage changes are relative to the initial perplexity
    pruning_impact = (pruned_ppl - initial_ppl) / initial_ppl * 100
    final_improvement = (final_ppl - initial_ppl) / initial_ppl * 100

    return {
        "model": model_name,
        "strategy": strategy,
        "pruning_level": pruning_level,
        "initial_ppl": initial_ppl,
        "pruned_ppl": pruned_ppl,
        "final_ppl": final_ppl,
        "pruning_impact": pruning_impact,
        "final_improvement": final_improvement,
    }

In [None]:
# Uncomment to run experiments with different strategies and pruning levels
# This can take a while to run

# strategies = ["random", "magnitude", "entropy"]
# pruning_levels = [0.1, 0.3, 0.5]
# results = []

# for strategy in strategies:
#     for level in pruning_levels:
#         result = run_complete_experiment("distilgpt2", strategy, level, epochs=2)
#         results.append(result)

# # Display results in a table
# import pandas as pd
# results_df = pd.DataFrame(results)
# results_df

In [ ]:
# Download results
# run_dir only exists once an experiment cell has run; look it up so this
# cell prints the hint below instead of raising NameError.
results_dir = globals().get("run_dir")
if IS_COLAB and results_dir:
    print(f"Experiment results are saved to: {results_dir}")
    result_files = [
        os.path.join(results_dir, file_name)
        for file_name in (
            "metrics.json",
            "training_progress.png",
            "head_importances.png",
            "gate_values.png",
            "text_samples.txt",
        )
    ]
    download_files(result_files)
else:
    print("Results download is only available in Google Colab or after running an experiment.")

## Conclusion

This notebook demonstrates how we can make a GPT-2 model both smaller and more powerful by:

1. **Identifying and pruning less important attention heads** using entropy-based, magnitude-based, or random pruning strategies
2. **Fine-tuning the pruned model** to recover and potentially improve performance
3. **Measuring and visualizing the model's improvement** throughout the process

Key findings:
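- Pruning reduces the number of active attention heads, which can reduce model size and improve inference speed. Note that this notebook masks pruned heads by setting their gates to 0 rather than deleting their weights, so those gains are only realized once the masked heads are physically removed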
- Fine-tuning after pruning can recover and sometimes improve model performance
- The entropy strategy typically produces the best results by identifying truly unimportant heads
- The optimal pruning level is typically around 30%, balancing size reduction and quality

This approach provides a practical way to optimize transformer-based language models, making them more efficient while maintaining or improving performance.