# JAX Pruning Benchmark

This notebook implements a stable pruning benchmark using JAX/Flax instead of PyTorch. This is particularly useful for M1/M2 Macs that may experience BLAS crashes with PyTorch.

In [None]:
# Install required packages
!pip install -q jax jaxlib flax transformers matplotlib numpy

In [None]:
import os
import sys
import json
import glob
import re
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Set environment variables for JAX.
# XLA reads XLA_FLAGS when its backend is initialized, so this assignment must run
# before `import jax` below. The flag makes the host (CPU) platform expose 8 virtual
# devices; nothing in this notebook appears to use more than one device, so it is
# presumably a leftover from multi-device experiments (NOTE(review): confirm it is needed).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

# Import JAX and transformers (after XLA_FLAGS is set, see above)
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Report the runtime so results can be tied to a JAX version / device set
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## Pruning Implementation

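Next, we implement the head-pruning, perplexity, and text-generation helpers in JAX/Flax, which is more stable on M1/M2 Macs than the PyTorch equivalents. "Pruning" a head here means zeroing that head's slice of the layer's attention output projection (`c_proj`), so the head no longer contributes to the layer's output. No parameters are actually removed, so model size and inference speed are unchanged.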

In [None]:
def _replace_at_path(tree, path, value):
    """Return a shallow copy of nested mapping ``tree`` with the entry at ``path`` set to ``value``.

    Only the mappings along ``path`` are copied; every other subtree is shared,
    so the caller's original tree is never mutated.
    """
    key, *rest = path
    child = _replace_at_path(tree[key], rest, value) if rest else value
    return {**tree, key: child}


def prune_head_in_params(params, layer_idx, head_idx, model_type="gpt2", num_heads=12):
    """Zero out one attention head's contribution in Flax GPT-2 parameters.

    Flax GPT-2 stores its ``FlaxConv1D`` kernels as ``(out_features, in_features)``
    and transposes them when applied (this is also why ``c_attn.kernel.shape[1]``
    equals the hidden size). The input of ``c_proj`` is the concatenation of all
    head outputs, so head ``h`` feeds input columns
    ``[h * head_size, (h + 1) * head_size)``. Zeroing those *columns* removes
    exactly that head's contribution. Zeroing rows would instead wipe a block of
    output dimensions for every head.

    Args:
        params: Nested parameter mapping, e.g. ``model.params``. It is not
            mutated: only the path down to the modified kernel is shallow-copied,
            so frozen mappings work too.
        layer_idx: Index of the transformer block.
        head_idx: Index of the head within that block, ``0 <= head_idx < num_heads``.
        model_type: Only ``"gpt2"`` (which also covers DistilGPT-2) is supported.
        num_heads: Attention heads per layer; 12 for GPT-2 small and DistilGPT-2.

    Returns:
        A new params mapping in which the head's ``c_proj`` input columns are zero.

    Raises:
        ValueError: If ``model_type`` is unsupported, ``head_idx`` is out of
            range, or the hidden size is not divisible by ``num_heads``.
    """
    if model_type != "gpt2":
        raise ValueError(f"Unsupported model type: {model_type}")

    layer_key = str(layer_idx)  # Flax stores layer indices as string keys
    kernel_path = ("transformer", "h", layer_key, "attn", "c_proj", "kernel")
    output_proj = params["transformer"]["h"][layer_key]["attn"]["c_proj"]["kernel"]

    # in_features of c_proj == concatenated head outputs == hidden size
    hidden_size = output_proj.shape[1]
    if num_heads <= 0 or hidden_size % num_heads != 0:
        raise ValueError(f"Hidden size {hidden_size} is not divisible into {num_heads} heads")
    if not 0 <= head_idx < num_heads:
        raise ValueError(f"head_idx must be in [0, {num_heads}), got {head_idx}")

    head_size = hidden_size // num_heads
    start = head_idx * head_size
    end = start + head_size

    # Columns = inputs coming from this head (kernel layout is (out, in)).
    pruned_kernel = jnp.asarray(output_proj).at[:, start:end].set(0)
    new_params = _replace_at_path(params, kernel_path, pruned_kernel)

    print(f"Successfully pruned layer {layer_idx}, head {head_idx}")
    return new_params

def evaluate_perplexity(model, params, tokenizer, text):
    """Compute the model's perplexity on ``text``.

    Perplexity is ``exp`` of the mean next-token cross-entropy. The logits at
    position ``t`` are scored against the actual token at position ``t + 1``.

    Args:
        model: Flax causal LM; called as ``model(**inputs, params=params)``.
        params: Parameter tree to evaluate with (e.g. original or pruned).
        tokenizer: Tokenizer that accepts ``return_tensors="jax"``.
        text: Text to score.

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: If ``text`` tokenizes to fewer than two tokens. There is then
            nothing to predict, and the mean over zero tokens would be NaN.
    """
    inputs = tokenizer(text, return_tensors="jax")
    input_ids = inputs["input_ids"]
    num_tokens = input_ids.shape[-1]
    if num_tokens < 2:
        raise ValueError(
            f"Perplexity needs at least two tokens, but {text!r} tokenized to {num_tokens}"
        )

    logits = model(**inputs, params=params).logits

    # Shift logits and labels for next-token prediction
    shift_logits = logits[:, :-1]
    shift_labels = input_ids[:, 1:]

    # Gather the log-probability of each actual next token directly, rather than
    # multiplying by a dense (seq, vocab) one-hot matrix.
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    loss = -jnp.mean(token_log_probs)

    return jnp.exp(loss).item()

def generate_text(model, params, tokenizer, prompt, max_length=50, seed=0):
    """Sample a continuation of ``prompt`` with top-k / top-p sampling.

    Args:
        model: Flax causal LM exposing ``generate``.
        params: Parameter tree to generate with (e.g. original or pruned).
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_length: Maximum total sequence length in tokens, prompt included.
        seed: Seed for the sampling PRNG key. It is passed explicitly so the
            randomness is visible and tunable. The default 0 reproduces the key
            Flax uses when none is given, so before/after-pruning samples stay
            comparable.

    Returns:
        The decoded text of the first generated sequence (prompt included).
    """
    inputs = tokenizer(prompt, return_tensors="jax")

    outputs = model.generate(
        **inputs,
        params=params,
        prng_key=jax.random.PRNGKey(seed),
        max_length=max_length,
        do_sample=True,
        top_k=40,
        top_p=0.95,
    )

    return tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]

## Importance Calculation Strategies

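Next, we implement two strategies for scoring head importance; heads with the lowest scores are pruned first.

- `random` is a baseline that assigns random scores.
- `entropy`, despite its name, does not compute attention entropy. It uses the L2 norm of each head's slice of the output projection as a cheap weight-based proxy.

In both strategies, scores are normalized to sum to 1 within each layer.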

In [None]:
def calculate_random_importance(params, num_layers, num_heads, seed=None):
    """Assign random importance scores to every head (baseline strategy).

    Scores are drawn uniformly from [0, 1) and normalized to sum to 1 within
    each layer.

    Args:
        params: Unused; accepted so all strategies share one signature.
        num_layers: Number of transformer layers.
        num_heads: Number of heads per layer.
        seed: If given, scores come from ``np.random.default_rng(seed)`` and are
            reproducible. If ``None`` (default), NumPy's global RNG is used, so
            ``np.random.seed`` still controls the result.

    Returns:
        List of ``(layer_idx, head_idx, score)`` tuples, ordered by layer then head.
    """
    rng = None if seed is None else np.random.default_rng(seed)

    all_head_importance = []
    for layer_idx in range(num_layers):
        scores = np.random.rand(num_heads) if rng is None else rng.random(num_heads)
        scores = scores / np.sum(scores)
        all_head_importance.extend(
            (layer_idx, head_idx, score) for head_idx, score in enumerate(scores)
        )

    return all_head_importance

def calculate_entropy_importance(params, num_layers, num_heads):
    """Score each head by the L2 norm of its slice of the attention output projection.

    Despite the strategy name, no attention entropy is computed. This is a cheap,
    data-free weight-norm proxy. Flax GPT-2 stores ``c_proj.kernel`` as
    ``(out_features, in_features)``, and head ``h`` feeds input columns
    ``[h * head_size, (h + 1) * head_size)``. The score is therefore the
    Frobenius norm of those columns, i.e. the slice that pruning zeroes.
    Scores are normalized to sum to 1 within each layer (left as zeros if the
    whole layer is zero).

    Args:
        params: Flax GPT-2 parameter tree.
        num_layers: Number of transformer layers.
        num_heads: Number of heads per layer.

    Returns:
        List of ``(layer_idx, head_idx, score)`` tuples, ordered by layer then head.
    """
    all_head_importance = []

    for layer_idx in range(num_layers):
        kernel = np.asarray(params["transformer"]["h"][str(layer_idx)]["attn"]["c_proj"]["kernel"])
        out_features, in_features = kernel.shape
        head_size = in_features // num_heads

        # Group each head's input columns: (out, num_heads, head_size)
        per_head = kernel[:, : num_heads * head_size].reshape(out_features, num_heads, head_size)
        importance = np.sqrt(np.sum(np.square(per_head.astype(np.float64)), axis=(0, 2)))

        total = np.sum(importance)
        if total > 0:
            importance = importance / total

        all_head_importance.extend(
            (layer_idx, head_idx, score) for head_idx, score in enumerate(importance)
        )

    return all_head_importance

## Save and Load Results

Let's implement functions to save and load results.

In [None]:
def save_results(strategy_name, pruning_level, results, results_dir="pruning_results"):
    """Write one benchmark's results to a timestamped JSON file.

    The filename is ``<strategy>_<pruning_level>_<YYYYmmddHHMMSS>.json``, which
    ``view_results`` parses back.

    Args:
        strategy_name: Importance strategy used (first filename component).
        pruning_level: Fraction of heads pruned (second filename component).
        results: JSON-serializable dict to write.
        results_dir: Directory to write into; created if missing. The default
            matches ``view_results``.

    Returns:
        Path of the written file.
    """
    os.makedirs(results_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    filename = f"{strategy_name}_{pruning_level}_{timestamp}.json"
    filepath = os.path.join(results_dir, filename)

    with open(filepath, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to {filepath}")
    return filepath

def view_results(results_dir="pruning_results"):
    """Load saved benchmark results and print a per-strategy summary.

    Two filename formats are recognized:
      * ``<strategy>_<level>_<timestamp>.json`` (written by ``save_results``)
      * ``<strategy>_pruning_<level>_results.json`` (legacy)
    Files that match neither, or that do not contain a JSON object, are skipped
    with a warning. ``strategy`` / ``pruning_level`` are filled in from the
    filename when the file itself lacks them.

    Args:
        results_dir: Directory to scan; created if missing.

    Returns:
        List of result dicts sorted by ``(strategy, pruning_level)``.
    """
    os.makedirs(results_dir, exist_ok=True)

    # Sorted so warnings and load order are the same on every run
    result_files = sorted(glob.glob(os.path.join(results_dir, "*.json")))
    if not result_files:
        print(f"No result files found in {results_dir}")
        return []

    filename_patterns = (
        re.compile(r"(\w+)_(\d+\.\d+)_\d+\.json"),              # new format
        re.compile(r"(\w+)_pruning_(\d+\.\d+)_results\.json"),  # old format
    )

    all_results = []
    for filepath in result_files:
        filename = os.path.basename(filepath)
        try:
            match = next(
                (m for m in (p.fullmatch(filename) for p in filename_patterns) if m),
                None,
            )
            if match is None:
                print(f"Warning: Couldn't parse filename {filename}, skipping...")
                continue
            strategy, pruning_level = match.group(1), float(match.group(2))

            with open(filepath, "r") as f:
                results = json.load(f)
            if not isinstance(results, dict):
                print(f"Warning: {filename} does not contain a JSON object, skipping...")
                continue

            results.setdefault("strategy", strategy)
            results.setdefault("pruning_level", pruning_level)
            all_results.append(results)
        except Exception as e:
            print(f"Error loading {filepath}: {e}")

    all_results.sort(key=lambda r: (r["strategy"], r["pruning_level"]))

    # all_results is sorted, so insertion order gives a stable, alphabetical
    # strategy order (a set would print in hash order).
    grouped_results = {}
    for result in all_results:
        grouped_results.setdefault(result["strategy"], []).append(result)

    print(f"Found {len(all_results)} result files:\n")

    for strategy, results in grouped_results.items():
        print(f"Strategy: {strategy}")
        for result in results:
            perplexity_change = result.get("perplexity_change", "N/A")
            if isinstance(perplexity_change, (int, float)):
                perplexity_change = f"{perplexity_change:.4f}"

            print(f"  Pruning level: {result['pruning_level']}, " +
                  f"Perplexity change: {perplexity_change}")
        print()

    return all_results

## Run Pruning Benchmark

Now let's implement the main function to run the pruning benchmark.

In [None]:
def run_pruning_benchmark(model_name="distilgpt2", strategy_name="random", 
                         pruning_level=0.3, prompt="Artificial intelligence is"):
    """Prune the least important attention heads of a Flax GPT-2 model and report the effect.

    Steps:
      1. Measure perplexity and sample text on ``prompt``.
      2. Score every head with the chosen strategy.
      3. Zero out the lowest-scoring ``pruning_level`` fraction of heads.
      4. Measure again.
      5. Save the results as JSON.

    Args:
        model_name: Hugging Face model id; must contain "gpt2".
        strategy_name: "random" or "entropy" (case-insensitive).
        pruning_level: Fraction of all heads to prune, in [0, 1].
        prompt: Text used for both perplexity and generation.

    Returns:
        The results dict that was saved to disk.

    Raises:
        ValueError: If ``pruning_level`` is outside [0, 1], the strategy is unknown,
            the model is not GPT-2-family, or its head count is not the 12 heads
            the pruning code assumes.
    """
    # Validate cheap arguments before the (slow) model download/load
    if not 0.0 <= pruning_level <= 1.0:
        raise ValueError(f"pruning_level must be in [0, 1], got {pruning_level}")
    importance_fns = {
        "random": calculate_random_importance,
        "entropy": calculate_entropy_importance,
    }
    importance_fn = importance_fns.get(strategy_name.lower())
    if importance_fn is None:
        raise ValueError(f"Unknown pruning strategy: {strategy_name}")

    print(f"Running JAX/Flax pruning benchmark with:\n" + 
          f"  Model: {model_name}\n" +
          f"  Strategy: {strategy_name}\n" +
          f"  Pruning level: {pruning_level}\n" +
          f"  Prompt: {prompt}\n")

    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = FlaxAutoModelForCausalLM.from_pretrained(model_name)

    if "gpt2" not in model_name:
        raise ValueError(f"Unsupported model type: {model_name}")
    model_type = "gpt2"
    num_layers = len(model.params["transformer"]["h"])

    # prune_head_in_params is called without a head count and assumes 12 heads
    # (GPT-2 small / DistilGPT-2). Refuse other sizes rather than silently
    # zeroing the wrong weight slices.
    num_heads = 12
    config_heads = getattr(model.config, "n_head", num_heads)
    if config_heads != num_heads:
        raise ValueError(
            f"{model_name} has {config_heads} heads per layer; this benchmark assumes {num_heads}"
        )

    print(f"Model has {num_layers} layers with {num_heads} heads per layer")

    # Rebuild the container tree so in-place edits to `params` cannot reach
    # `model.params`. The array leaves are shared, which is safe because JAX
    # arrays are immutable.
    original_params = model.params
    params = jax.tree_util.tree_map(lambda x: x, original_params)

    print("\nEvaluating model before pruning...")
    perplexity_before = evaluate_perplexity(model, params, tokenizer, prompt)
    print(f"Perplexity before pruning: {perplexity_before:.4f}")

    generated_before = generate_text(model, params, tokenizer, prompt)
    print(f"Generated (before pruning): {generated_before}")

    print("\nCalculating head importance...")
    all_head_importance = importance_fn(params, num_layers, num_heads)

    # Least important heads first
    all_head_importance.sort(key=lambda x: x[2])

    total_heads = num_layers * num_heads
    heads_to_prune = int(total_heads * pruning_level)
    print(f"Pruning {heads_to_prune} out of {total_heads} heads")

    pruned_heads = all_head_importance[:heads_to_prune]

    print("\nPruning heads...")
    for layer_idx, head_idx, _ in pruned_heads:
        params = prune_head_in_params(params, layer_idx, head_idx, model_type)

    print("\nEvaluating model after pruning...")
    perplexity_after = evaluate_perplexity(model, params, tokenizer, prompt)
    print(f"Perplexity after pruning: {perplexity_after:.4f}")
    print(f"Perplexity change: {perplexity_after - perplexity_before:.4f}")

    generated_after = generate_text(model, params, tokenizer, prompt)
    print(f"Generated (after pruning): {generated_after}")

    results = {
        "model": model_name,
        "strategy": strategy_name,
        "pruning_level": pruning_level,
        "pruned_heads": heads_to_prune,
        "total_heads": total_heads,
        # Which heads were pruned, as [layer, head] pairs, so runs can be compared
        "pruned_head_indices": [[int(layer), int(head)] for layer, head, _ in pruned_heads],
        "prompt": prompt,
        "perplexity_before": float(perplexity_before),
        "perplexity_after": float(perplexity_after),
        "perplexity_change": float(perplexity_after - perplexity_before),
        "generated_before": generated_before,
        "generated_after": generated_after,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    save_results(strategy_name, pruning_level, results)

    print("\nPruning benchmark completed successfully!")
    return results

## Plot Results

Let's implement a function to visualize the results.

In [None]:
def plot_results(all_results):
    """Plot perplexity change against pruning level, one line per strategy.

    Only results with a numeric ``perplexity_change`` are plotted. Strategies
    are drawn in sorted order, so each keeps the same colour and marker across
    runs.

    Args:
        all_results: List of result dicts with ``strategy``, ``pruning_level``
            and (ideally) ``perplexity_change`` keys.
    """
    if not all_results:
        print("No results to plot")
        return

    grouped_results = {}
    for result in all_results:
        if isinstance(result.get("perplexity_change"), (int, float)):
            grouped_results.setdefault(result["strategy"], []).append(result)

    if not grouped_results:
        # Avoid drawing an empty figure with a legend that has no entries
        print("No results with a numeric perplexity change to plot")
        return

    colors = ["blue", "red", "green", "orange", "purple"]
    markers = ["o", "s", "^", "d", "x"]

    fig, ax = plt.subplots(figsize=(10, 6))

    for i, strategy in enumerate(sorted(grouped_results)):
        results = sorted(grouped_results[strategy], key=lambda r: r["pruning_level"])
        ax.plot(
            [r["pruning_level"] for r in results],
            [r["perplexity_change"] for r in results],
            marker=markers[i % len(markers)],
            color=colors[i % len(colors)],
            linestyle="-",
            label=strategy.capitalize(),
        )

    ax.set_xlabel("Pruning Level")
    ax.set_ylabel("Perplexity Change")
    ax.set_title("Effect of Pruning on Model Perplexity")
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.legend()

    # Reference line: no change in perplexity
    ax.axhline(y=0, color="gray", linestyle="-", alpha=0.5)

    fig.tight_layout()
    plt.show()

## Run Single Benchmark

Let's run a single benchmark test with a low pruning level.

In [None]:
# Seed NumPy's global RNG so the "random" strategy selects the same heads on
# every Restart & Run All.
np.random.seed(0)

# Run a single test with minimal pruning
results = run_pruning_benchmark(
    model_name="distilgpt2",
    strategy_name="random",
    pruning_level=0.1,
    prompt="Artificial intelligence is"
)

## View and Plot Existing Results

Load and visualize existing results.

In [None]:
# Load every saved result file and print a per-strategy summary
all_results = view_results(results_dir="pruning_results")

In [None]:
# Plot only when at least one result was loaded
if len(all_results) > 0:
    plot_results(all_results)

## Run Multiple Benchmarks

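Run a sweep over several strategies and pruning levels. Each run reloads the model and writes its own results file, so a full sweep can take a long time; it is a good candidate for an overnight Colab session.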

In [None]:
def run_multiple_benchmarks(model_name="distilgpt2", 
                           strategies=("random", "entropy"),
                           pruning_levels=(0.1, 0.2, 0.3, 0.4, 0.5),
                           prompt="Artificial intelligence is"):
    """Run ``run_pruning_benchmark`` for every (strategy, pruning level) pair, then plot.

    A failing run is reported with its traceback and skipped, so one bad
    configuration does not abort the whole sweep. Each run reloads the model.

    Args:
        model_name: Hugging Face model id passed to every run.
        strategies: Iterable of strategy names. Defaults are tuples rather than
            lists to avoid shared mutable defaults.
        pruning_levels: Iterable of pruning fractions.
        prompt: Prompt passed to every run.

    Returns:
        List of result dicts from the successful runs.
    """
    import traceback  # used only to report failed runs without stopping the sweep

    # Materialize so generators work and len() below is valid
    strategies = list(strategies)
    pruning_levels = list(pruning_levels)

    all_results = []
    for strategy in strategies:
        for level in pruning_levels:
            print(f"\n{'='*50}\nRunning benchmark: {strategy}, level: {level}\n{'='*50}\n")

            try:
                result = run_pruning_benchmark(
                    model_name=model_name,
                    strategy_name=strategy,
                    pruning_level=level,
                    prompt=prompt
                )
                all_results.append(result)
            except Exception as e:
                print(f"Error in benchmark {strategy}, level {level}: {e}")
                traceback.print_exc()

    print(f"\nCompleted {len(all_results)} benchmarks out of {len(strategies) * len(pruning_levels)} attempted")

    plot_results(all_results)

    return all_results

# Uncomment to run multiple benchmarks
# Note: This can take a long time, especially for higher pruning levels
# all_results = run_multiple_benchmarks(
#     model_name="distilgpt2",
#     strategies=["random", "entropy"],
#     pruning_levels=[0.1, 0.3, 0.5]
# )