# Pruning Benchmark Notebook

This notebook implements a stable pruning benchmark that works on both Google Colab and local machines (including M1/M2 Macs).

Note: This implementation includes workarounds for the BLAS crashes that can occur on Apple Silicon (M1/M2) Macs.

In [None]:
# Install required packages
!pip install transformers torch numpy matplotlib

In [None]:
import os

# Cap BLAS/OpenMP thread pools at one thread to avoid crashes on Apple Silicon
# Macs. These variables are typically read once, when numpy/torch load their
# native BLAS/OpenMP runtimes, so they must be set BEFORE those imports below;
# assigning them after the import has no effect on already-created pools.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
# Let Intel OpenMP keep running if it detects a second copy of its runtime
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Disable the Hugging Face tokenizers' own parallelism as well
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import torch
import numpy as np
import json
import glob
import re
import matplotlib.pyplot as plt
from datetime import datetime

# Further limit PyTorch's threading
torch.set_num_threads(1)
if hasattr(torch, 'set_num_interop_threads'):
    try:
        torch.set_num_interop_threads(1)
    except RuntimeError as err:
        # PyTorch only allows this once per process and before any inter-op
        # parallel work has started, so re-running this cell would raise.
        print(f"Could not set inter-op threads (already configured?): {err}")

# Import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Pruning Strategies

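We'll implement two pruning strategies:
1. Random - Prune randomly chosen heads
2. Entropy - Prune the least important heads. Despite the name, this simplified version does not compute attention entropy; it uses the magnitude (L2 norm) of each head's output-projection weights as a proxy for importance.

Both strategies are designed to run stably, especially on M1/M2 Macs.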

In [None]:
class PruningStrategy:
    """Base class for attention-head pruning strategies on GPT-2-style models.

    A strategy scores every head in a layer (``get_head_importance``) and
    disables chosen heads (``prune_heads``). Heads are disabled by zeroing
    their weights in the attention output projection, so parameter shapes
    never change.

    Args:
        model: A model exposing ``model.transformer.h[i].attn`` with
            ``num_heads``, ``embed_dim`` and a ``c_proj`` projection
            (e.g. ``GPT2LMHeadModel``).
    """

    def __init__(self, model):
        self.model = model

    def get_head_importance(self, layer_idx):
        """Return importance scores for all heads in layer ``layer_idx``."""
        raise NotImplementedError("Subclasses must implement get_head_importance")

    def prune_heads(self, layer_idx, head_idxs):
        """Prune the heads listed in ``head_idxs`` in layer ``layer_idx``."""
        raise NotImplementedError("Subclasses must implement prune_heads")

    def _zero_out_heads(self, layer_idx, head_idxs):
        """Zero the output-projection weights that carry the given heads' outputs.

        GPT-2's ``c_proj`` is a transformers ``Conv1D`` whose weight is stored
        as ``(in_features, out_features)``, i.e. transposed relative to
        ``nn.Linear``. Its input is the concatenation of all head outputs, so
        head ``h`` owns input ROWS ``h * head_size : (h + 1) * head_size``.
        Zeroing those rows removes exactly that head's contribution. The bias
        is added once after all heads are mixed, so it is not per-head and is
        deliberately left untouched.

        Raises:
            IndexError: if any head index is outside ``[0, num_heads)``. All
                indices are checked before any weight is modified.
        """
        attn = self.model.transformer.h[layer_idx].attn
        num_heads = attn.num_heads
        head_size = attn.embed_dim // num_heads

        invalid = [head_idx for head_idx in head_idxs if not 0 <= head_idx < num_heads]
        if invalid:
            raise IndexError(
                f"Head indices {invalid} out of range for layer {layer_idx} "
                f"with {num_heads} heads"
            )

        with torch.no_grad():
            for head_idx in head_idxs:
                start_idx = head_idx * head_size
                attn.c_proj.weight[start_idx:start_idx + head_size, :] = 0

        print(f"Successfully pruned heads {head_idxs} in layer {layer_idx}")


class RandomPruningStrategy(PruningStrategy):
    """Prune randomly chosen heads.

    Every head receives an independent uniform score from ``np.random.rand``,
    so results depend on numpy's global random state (seed it with
    ``np.random.seed`` for reproducible runs).
    """

    def get_head_importance(self, layer_idx):
        """Return one uniform random score in [0, 1) per head in ``layer_idx``."""
        num_heads = self.model.transformer.h[layer_idx].attn.num_heads
        return np.random.rand(num_heads)

    def prune_heads(self, layer_idx, head_idxs):
        """Zero out ``head_idxs`` in ``layer_idx`` and return the (mutated) model."""
        print(f"Pruning layer {layer_idx}, heads {head_idxs} using Random strategy")
        self._zero_out_heads(layer_idx, head_idxs)
        return self.model


class EntropyPruningStrategy(PruningStrategy):
    """Importance-based pruning strategy (simplified "entropy" proxy).

    Args:
        model: See ``PruningStrategy``.
        tokenizer: Tokenizer matching ``model``; stored but not used by the
            current weight-magnitude proxy.
        sample_text: Optional list of sample sentences; defaults to three
            built-in sentences. Also stored but not used by the current proxy.
    """

    def __init__(self, model, tokenizer, sample_text=None):
        super().__init__(model)
        self.tokenizer = tokenizer

        # Use default sample text if none provided
        if sample_text is None:
            self.sample_text = [
                "The quick brown fox jumps over the lazy dog",
                "Artificial intelligence is transforming the world",
                "Machine learning models can process large amounts of data"
            ]
        else:
            self.sample_text = sample_text

    def get_head_importance(self, layer_idx):
        """Return per-head importance for ``layer_idx``, normalized to sum to 1.

        Despite the class name this is a simplified proxy, not attention
        entropy: each head is scored by the L2 norm of its rows in the output
        projection (the same rows ``_zero_out_heads`` zeroes). This avoids
        running the model and the heavier BLAS work that would involve. If all
        norms are zero the raw (all-zero) scores are returned unnormalized.
        """
        print(f"Calculating entropy-based importance for layer {layer_idx}...")

        attn = self.model.transformer.h[layer_idx].attn
        num_heads = attn.num_heads
        head_size = attn.embed_dim // num_heads

        with torch.no_grad():
            # Rows are grouped per head (see _zero_out_heads), so this reshape
            # gives one (head_size, out_features) slab per head.
            per_head = attn.c_proj.weight.reshape(num_heads, head_size, -1)
            head_norms = per_head.pow(2).sum(dim=(1, 2)).sqrt()
        head_importance = head_norms.cpu().double().numpy()

        total = head_importance.sum()
        if total > 0:
            head_importance = head_importance / total

        return head_importance

    def prune_heads(self, layer_idx, head_idxs):
        """Zero out ``head_idxs`` in ``layer_idx`` and return the (mutated) model."""
        print(f"Pruning layer {layer_idx}, heads {head_idxs} using Entropy strategy")
        self._zero_out_heads(layer_idx, head_idxs)
        return self.model

## Evaluation Functions

Now we'll define functions to evaluate model performance before and after pruning.

In [None]:
def evaluate_perplexity(model, tokenizer, text, device="cpu"):
    """Return the model's perplexity on ``text``: ``exp`` of its language-model loss.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` and returning an object with a ``loss`` tensor.
        tokenizer: Tokenizer returning a mapping with ``input_ids`` (and
            optionally ``attention_mask``) when called with ``return_tensors="pt"``.
        text: Text to score; the input ids double as the labels.
        device: Device the input tensors are moved to.

    Returns:
        float: The perplexity.

    Raises:
        ValueError: if ``text`` encodes to fewer than two tokens. A causal LM
            needs at least one (context, next token) pair; with a single token
            the loss would typically be NaN rather than a usable number.
    """
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    if input_ids.shape[-1] < 2:
        raise ValueError(
            f"Need at least 2 tokens to compute perplexity, got {input_ids.shape[-1]}"
        )

    # Prefer the tokenizer's own mask (it marks padding correctly); an all-ones
    # mask is only equivalent for a single, unpadded sequence.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids
        )

    return torch.exp(outputs.loss).item()

def generate_text(model, tokenizer, prompt, max_length=50, device="cpu"):
    """Sample one continuation of ``prompt`` and return it as decoded text.

    Decoding uses temperature sampling (T=0.7), pads with the tokenizer's EOS
    token id, and forwards ``max_length`` unchanged to ``model.generate``.
    Special tokens are stripped from the returned string.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_ids = encoded["input_ids"]
    # Attend to every prompt position (single sequence, no padding)
    full_mask = torch.ones(prompt_ids.shape, device=device)

    sampling_options = dict(
        max_length=max_length,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(prompt_ids, attention_mask=full_mask, **sampling_options)

    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def save_results(strategy_name, pruning_level, results, results_dir="pruning_results"):
    """Write ``results`` to a pretty-printed JSON file and return its path.

    The file is named ``<strategy_name>_<pruning_level>_<YYYYmmddHHMMSS>.json``
    and placed in ``results_dir``, which is created if it does not exist.
    Two saves with the same strategy and level within the same second map to
    the same path, and the later one overwrites the earlier.

    Args:
        strategy_name: Strategy label used as the filename prefix.
        pruning_level: Pruning level, formatted with ``str()`` into the filename.
        results: JSON-serializable object to write.
        results_dir: Target directory (defaults to ``"pruning_results"``).

    Returns:
        str: Path of the written file.
    """
    os.makedirs(results_dir, exist_ok=True)

    # Timestamp suffix keeps successive runs of the same configuration apart
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    filename = f"{strategy_name}_{pruning_level}_{timestamp}.json"
    filepath = os.path.join(results_dir, filename)

    with open(filepath, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to {filepath}")
    return filepath

## Run Pruning Benchmark

Now we'll implement the main function to run pruning benchmarks.

In [None]:
def run_pruning_benchmark(model_name="distilgpt2", strategy_name="random", 
                         pruning_level=0.3, prompt="Artificial intelligence is",
                         device="cpu"):
    """Run pruning benchmark with specified parameters"""
    print(f"Running pruning benchmark with:\n" + 
          f"  Model: {model_name}\n" +
          f"  Strategy: {strategy_name}\n" +
          f"  Pruning level: {pruning_level}\n" +
          f"  Device: {device}")
    
    # Load model and tokenizer
    print(f"\nLoading {model_name}...")
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    
    # Set model to evaluation mode
    model.eval()
    
    # Display model information
    num_layers = len(model.transformer.h)
    num_heads = model.transformer.h[0].attn.num_heads
    print(f"Model loaded: {model_name}")
    print(f"Layers: {num_layers}, Heads per layer: {num_heads}")
    
    # Create pruning strategy
    if strategy_name.lower() == "random":
        strategy = RandomPruningStrategy(model)
    elif strategy_name.lower() == "entropy":
        strategy = EntropyPruningStrategy(model, tokenizer)
    else:
        raise ValueError(f"Unknown pruning strategy: {strategy_name}")
    
    # Calculate number of heads to prune
    total_heads = num_layers * num_heads
    heads_to_prune = int(total_heads * pruning_level)
    print(f"Pruning level: {pruning_level} ({heads_to_prune} out of {total_heads} heads)")
    
    # Evaluate model before pruning
    print("\nEvaluating model before pruning...")
    
    # Generate text
    print(f"Prompt: '{prompt}'")
    generated_before = generate_text(model, tokenizer, prompt, device=device)
    print(f"Generated (before pruning): {generated_before}")
    
    # Calculate perplexity
    perplexity_before = evaluate_perplexity(model, tokenizer, prompt, device=device)
    print(f"Perplexity (before pruning): {perplexity_before:.4f}")
    
    # Perform pruning
    print("\nPerforming pruning...")
    
    # Get importance scores for all heads
    all_head_importance = []
    for layer_idx in range(num_layers):
        importance = strategy.get_head_importance(layer_idx)
        for head_idx, score in enumerate(importance):
            all_head_importance.append((layer_idx, head_idx, score))
    
    # Sort heads by importance (ascending)
    all_head_importance.sort(key=lambda x: x[2])
    
    # Prune least important heads
    pruned_heads = all_head_importance[:heads_to_prune]
    
    # Group by layer for efficient pruning
    pruned_by_layer = {}
    for layer_idx, head_idx, _ in pruned_heads:
        if layer_idx not in pruned_by_layer:
            pruned_by_layer[layer_idx] = []
        pruned_by_layer[layer_idx].append(head_idx)
    
    # Prune heads layer by layer
    for layer_idx, head_idxs in pruned_by_layer.items():
        strategy.prune_heads(layer_idx, head_idxs)
    
    # Evaluate model after pruning
    print("\nEvaluating model after pruning...")
    
    # Generate text
    generated_after = generate_text(model, tokenizer, prompt, device=device)
    print(f"Generated (after pruning): {generated_after}")
    
    # Calculate perplexity
    perplexity_after = evaluate_perplexity(model, tokenizer, prompt, device=device)
    print(f"Perplexity (after pruning): {perplexity_after:.4f}")
    print(f"Perplexity change: {perplexity_after - perplexity_before:.4f}")
    
    # Prepare results
    results = {
        "model": model_name,
        "strategy": strategy_name,
        "pruning_level": pruning_level,
        "total_heads": total_heads,
        "pruned_heads": heads_to_prune,
        "pruned_head_details": [(l, h) for l, h, _ in pruned_heads],
        "prompt": prompt,
        "generated_before": generated_before,
        "generated_after": generated_after,
        "perplexity_before": perplexity_before,
        "perplexity_after": perplexity_after,
        "perplexity_change": perplexity_after - perplexity_before,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    
    # Save results
    save_results(strategy_name, pruning_level, results)
    
    print("\nPruning benchmark completed successfully!")
    return results

## Run a Single Benchmark Test

Let's run a single benchmark test to make sure everything is working.

In [None]:
# Smoke test: prune 10% of distilgpt2's heads at random to check the pipeline end to end
results = run_pruning_benchmark(
    device=device,
    prompt="Artificial intelligence is",
    pruning_level=0.1,
    strategy_name="random",
    model_name="distilgpt2",
)

## View Results

This function loads and displays results from previous benchmark runs.

In [None]:
def view_results(results_dir="pruning_results"):
    """Load and display results from previous benchmark runs"""
    # Create results directory if it doesn't exist
    os.makedirs(results_dir, exist_ok=True)
    
    # Find all result files
    result_files = glob.glob(os.path.join(results_dir, "*.json"))
    
    if not result_files:
        print(f"No result files found in {results_dir}")
        return []
    
    # Load results
    all_results = []
    
    for filepath in result_files:
        try:
            # Extract strategy and pruning level from filename
            filename = os.path.basename(filepath)
            
            # Try to match both formats:
            # 1. New format: "strategy_level_timestamp.json"
            # 2. Old format: "strategy_pruning_level_results.json"
            match1 = re.match(r"(\w+)_(\d+\.\d+)_\d+\.json", filename)
            match2 = re.match(r"(\w+)_pruning_(\d+\.\d+)_results\.json", filename)
            
            if match1:
                # New format
                strategy = match1.group(1)
                pruning_level = float(match1.group(2))
            elif match2:
                # Old format
                strategy = match2.group(1)
                pruning_level = float(match2.group(2))
            else:
                print(f"Warning: Couldn't parse filename {filename}, skipping...")
                continue
                
            # Load the results file
            with open(filepath, "r") as f:
                results = json.load(f)
            
            # Add strategy and pruning level if not in results
            if "strategy" not in results:
                results["strategy"] = strategy
            if "pruning_level" not in results:
                results["pruning_level"] = pruning_level
                
            all_results.append(results)
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
    
    # Sort results by strategy and pruning level
    all_results.sort(key=lambda x: (x["strategy"], x["pruning_level"]))
    
    # Group results by strategy
    strategies = set(r["strategy"] for r in all_results)
    grouped_results = {strategy: [] for strategy in strategies}
    
    for result in all_results:
        grouped_results[result["strategy"]].append(result)
    
    # Display summary
    print(f"Found {len(all_results)} result files:\n")
    
    for strategy, results in grouped_results.items():
        print(f"Strategy: {strategy}")
        for result in results:
            perplexity_change = result.get("perplexity_change", "N/A")
            if isinstance(perplexity_change, (int, float)):
                perplexity_change = f"{perplexity_change:.4f}"
            
            print(f"  Pruning level: {result['pruning_level']}, " + 
                  f"Perplexity change: {perplexity_change}")
        print()
    
    return all_results

# Read every saved run from the default results folder and print a per-strategy summary
all_results = view_results(results_dir="pruning_results")

## Plot Results

Let's visualize the results by plotting perplexity change vs pruning level for each strategy.

In [None]:
def plot_results(all_results):
    """Plot perplexity change vs pruning level for each strategy"""
    if not all_results:
        print("No results to plot")
        return
    
    # Group results by strategy
    strategies = set(r["strategy"] for r in all_results)
    grouped_results = {strategy: [] for strategy in strategies}
    
    for result in all_results:
        # Only include results with perplexity change
        if "perplexity_change" in result and isinstance(result["perplexity_change"], (int, float)):
            grouped_results[result["strategy"]].append(result)
    
    # Plot results
    plt.figure(figsize=(10, 6))
    
    colors = ["blue", "red", "green", "orange", "purple"]
    markers = ["o", "s", "^", "d", "x"]
    
    for i, (strategy, results) in enumerate(grouped_results.items()):
        if not results:
            continue
            
        # Sort by pruning level
        results.sort(key=lambda x: x["pruning_level"])
        
        # Extract data for plotting
        pruning_levels = [r["pruning_level"] for r in results]
        perplexity_changes = [r["perplexity_change"] for r in results]
        
        # Plot
        color = colors[i % len(colors)]
        marker = markers[i % len(markers)]
        plt.plot(pruning_levels, perplexity_changes, marker=marker, color=color,
                linestyle="-", label=strategy.capitalize())
    
    plt.xlabel("Pruning Level")
    plt.ylabel("Perplexity Change")
    plt.title("Effect of Pruning on Model Perplexity")
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend()
    
    # Add horizontal line at y=0
    plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)
    
    plt.tight_layout()
    plt.show()

# Draw the comparison chart only when at least one saved run was found
if all_results:
    plot_results(all_results=all_results)

## Run Multiple Benchmark Tests

Run benchmarks across several pruning levels and strategies. A full sweep can take a long time, so it is a good candidate for an overnight run on Google Colab.

In [None]:
def run_multiple_benchmarks(model_name="distilgpt2", strategies=("random", "entropy"),
                           pruning_levels=(0.1, 0.2, 0.3, 0.4, 0.5),
                           prompt="Artificial intelligence is",
                           device="cpu"):
    """Run ``run_pruning_benchmark`` for every (strategy, pruning level) pair, then plot.

    A failing run is reported with its traceback and skipped, so one bad
    configuration does not abort a long sweep.

    Args:
        model_name: Model id passed to every run.
        strategies: Sized iterable of strategy names (defaults are immutable
            tuples; lists work too).
        pruning_levels: Sized iterable of pruning fractions.
        prompt: Prompt passed to every run.
        device: Torch device passed to every run.

    Returns:
        list[dict]: Results of the runs that completed successfully.
    """
    import traceback

    all_results = []

    for strategy in strategies:
        for level in pruning_levels:
            print(f"\n{'='*50}\nRunning benchmark: {strategy}, level: {level}\n{'='*50}\n")

            try:
                result = run_pruning_benchmark(
                    model_name=model_name,
                    strategy_name=strategy,
                    pruning_level=level,
                    prompt=prompt,
                    device=device
                )
            except Exception as e:
                # Deliberate best-effort boundary: report and continue the sweep
                print(f"Error in benchmark {strategy}, level {level}: {e}")
                traceback.print_exc()
            else:
                all_results.append(result)

    attempted = len(strategies) * len(pruning_levels)
    print(f"\nCompleted {len(all_results)} benchmarks out of {attempted} attempted")

    plot_results(all_results)

    return all_results

# Uncomment to run multiple benchmarks
# Note: This can take a long time, especially for higher pruning levels
# all_results = run_multiple_benchmarks(
#     model_name="distilgpt2",
#     strategies=["random", "entropy"],
#     pruning_levels=[0.1, 0.3, 0.5],
#     device=device
# )