# 🔬 JAX Pruning Benchmark

This notebook implements a modular, robust pruning benchmark using JAX/Flax. It runs efficiently on:
- Google Colab (using GPUs/TPUs when available)
- M1/M2 Macs (avoiding BLAS crashes that occur with PyTorch)
- Any standard environment

## Features
- Auto-installation of dependencies
- Multiple pruning strategies
- Multiple models (automatically selected to fit a per-environment memory budget)
- Progressive visualization as results are collected
- Robust error handling
- Compatible with overnight runs

In [None]:
# Install required packages
!pip install -q jax jaxlib flax transformers matplotlib numpy tqdm pandas seaborn

## Environment Setup

First, let's detect our environment and configure it appropriately.

In [None]:
import os
import sys
import json
import glob
import re
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import platform
from pathlib import Path
from tqdm.auto import tqdm
import warnings
# Silence every Python warning for the rest of the session so the notebook
# output stays readable.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# imported above and below -- confirm that is intended.
warnings.filterwarnings('ignore')

# Set plot style: start from matplotlib's 'ggplot' style sheet, then apply
# seaborn's "whitegrid" theme (applied second, so it is the later setting).
plt.style.use('ggplot')
sns.set_theme(style="whitegrid")

# Environment detection
class Environment:
    """Detect the runtime environment and configure JAX-related settings.

    On construction the instance records whether it runs inside Google Colab
    and/or on an (Apple Silicon) Mac and then calls ``_configure``, which may
    set process environment variables and probe JAX for TPU/GPU devices.

    Attributes:
        in_colab: True when the ``google.colab`` module is already loaded.
        is_mac: True when ``platform.system()`` reports "Darwin".
        is_arm_mac: True on a Mac whose machine string starts with "arm".
        has_gpu / has_tpu: set to True only after a successful device probe.
        default_device: "cpu", "gpu" or "tpu".
        memory_limit: rough memory budget in GB used to filter models.
    """

    # XLA flag added on Apple Silicon, and the prefix used to detect whether
    # the user already chose a host device count themselves.
    _HOST_DEVICE_FLAG_NAME = "--xla_force_host_platform_device_count"
    _HOST_DEVICE_FLAG = "--xla_force_host_platform_device_count=8"

    # Model name -> approximate memory needed in GB.
    _MODEL_MEMORY_GB = {
        "distilgpt2": 0.5,
        "gpt2": 1.5,
        "gpt2-medium": 3.0,
        "gpt2-large": 6.0,
        "gpt2-xl": 12.0,
        "facebook/opt-125m": 0.5,
        "facebook/opt-350m": 1.5,
        "facebook/opt-1.3b": 5.0,
        "EleutherAI/pythia-160m": 0.7,
        "EleutherAI/pythia-410m": 1.8,
        "EleutherAI/pythia-1b": 4.0,
    }

    def __init__(self):
        # Check if we're running in Colab
        self.in_colab = 'google.colab' in sys.modules

        # Check if we're on a Mac
        self.is_mac = platform.system() == "Darwin"
        self.is_arm_mac = self.is_mac and platform.machine().startswith("arm")

        # Conservative defaults; _configure() upgrades them when it finds
        # better hardware.
        self.has_gpu = False
        self.has_tpu = False
        self.default_device = "cpu"
        self.memory_limit = 4  # Default memory limit in GB

        self._configure()

    def _configure(self):
        """Configure the environment based on detected hardware."""
        if self.is_arm_mac:
            self._configure_arm_mac()
        elif self.in_colab:
            # Prefer a TPU, then a GPU; otherwise stay on the CPU defaults.
            if not (self._try_tpu() or self._try_gpu()):
                print("No TPU or GPU detected, using CPU")

    def _configure_arm_mac(self):
        """Apply the Mac-specific settings meant to avoid BLAS issues."""
        # Merge with flags the user already exported instead of overwriting
        # them, and do not add the device-count flag twice (e.g. when the
        # notebook cell is re-run).
        existing_flags = os.environ.get("XLA_FLAGS", "")
        if self._HOST_DEVICE_FLAG_NAME not in existing_flags:
            os.environ["XLA_FLAGS"] = f"{existing_flags} {self._HOST_DEVICE_FLAG}".strip()
        for variable in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
            os.environ[variable] = "1"
        self.memory_limit = 8  # M1/M2 Macs can handle more

    def _try_tpu(self):
        """Try the Colab TPU setup; return True when it succeeded."""
        try:
            import jax
            import jax.tools.colab_tpu
            jax.tools.colab_tpu.setup_tpu()
        except Exception:
            # Best effort: any failure here simply means "no usable TPU".
            return False
        self.has_tpu = True
        self.default_device = "tpu"
        self.memory_limit = 24  # TPUs have more memory
        print("TPU configured for JAX")
        return True

    def _try_gpu(self):
        """Probe JAX for GPU devices; return True when at least one exists."""
        try:
            import jax
            # Probe first and only then pin the platform, so a failed probe
            # does not leave JAX configured for a backend that is missing.
            if not jax.devices('gpu'):
                return False
            jax.config.update('jax_platform_name', 'gpu')
        except Exception:
            return False
        self.has_gpu = True
        self.default_device = "gpu"
        self.memory_limit = 12  # GPUs have decent memory
        print("GPU configured for JAX")
        return True

    def get_suitable_models(self):
        """Return the model names whose memory estimate fits ``memory_limit``.

        Returns:
            A list of model names, smallest estimated footprint first.
        """
        sizes = self._MODEL_MEMORY_GB
        fitting = [name for name, needed in sizes.items() if needed <= self.memory_limit]
        return sorted(fitting, key=lambda name: sizes[name])

    def print_info(self):
        """Print environment information"""
        print(f"Platform: {platform.platform()}")
        print(f"Python version: {platform.python_version()}")
        print(f"Running in Google Colab: {self.in_colab}")
        print(f"Running on Mac: {self.is_mac}, Apple Silicon: {self.is_arm_mac}")
        print(f"Default device: {self.default_device}")
        print(f"Memory limit (GB): {self.memory_limit}")
        print("\nModels available for this environment:")
        for model in self.get_suitable_models():
            print(f"  - {model}")

# Initialize environment.
# Constructing Environment runs its _configure() step, which may set XLA_FLAGS
# and the *_NUM_THREADS environment variables. Keep this ahead of the JAX
# import below.
# NOTE(review): presumably JAX only honours those variables if they are set
# before its backend initialises -- confirm before reordering this cell.
env = Environment()

# Import JAX and transformers (the Flax auto-model class and the tokenizer
# loader used by the cells that follow).
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Report what was detected, then what JAX itself sees.
env.print_info()
print(f"\nJAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

## Data Management

Next, let's implement a class that manages pruning results and their storage on disk.

In [None]:
class ResultsManager:
    """Persist pruning benchmark results as JSON files and summarise them.

    Every result is a dict. It is written to its own JSON file inside
    ``results_dir`` and mirrored in memory, both as a list (``all_results``)
    and as a pandas DataFrame (``results_df``) holding the analysis columns.
    """

    # (column name, default used when a result lacks the key)
    _COLUMNS = (
        ("model", "unknown"),
        ("strategy", "unknown"),
        ("pruning_level", 0),
        ("perplexity_before", 0),
        ("perplexity_after", 0),
        ("perplexity_change", 0),
        ("timestamp", ""),
    )

    def __init__(self, results_dir="pruning_results"):
        """Create the results directory (including parents) if it is missing."""
        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(exist_ok=True, parents=True)
        self.all_results = []
        self.results_df = None

    def save_result(self, result):
        """Write one result to disk and add it to the in-memory collections.

        Args:
            result: dict with at least the keys "strategy", "model" and
                "pruning_level"; it must be JSON serialisable.

        Returns:
            Path of the JSON file that was written.

        Raises:
            KeyError: when one of the required keys is missing.
            TypeError: when ``result`` cannot be serialised to JSON; no file
                is created in that case.
        """
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        strategy = result["strategy"]
        model = result["model"].replace("/", "_")  # keep the name path-safe
        pruning_level = result["pruning_level"]

        # Serialise before opening the file so a non-serialisable result does
        # not leave a truncated JSON file behind.
        payload = json.dumps(result, indent=2)

        filepath = self.results_dir / f"{model}_{strategy}_{pruning_level}_{timestamp}.json"
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(payload)

        self.all_results.append(result)
        self._update_dataframe()

        return filepath

    def load_results(self):
        """Reload every ``*.json`` file in the results directory.

        Files that cannot be read or parsed, or whose top-level JSON value is
        not an object, are reported and skipped.

        Returns:
            The list of loaded result dicts (empty when no file was found).
        """
        self.all_results = []

        # Sorted so the load order does not depend on the file system.
        result_files = sorted(self.results_dir.glob("*.json"))

        if not result_files:
            print(f"No result files found in {self.results_dir}")
            # Refresh the dataframe as well, so it cannot keep stale rows.
            self._update_dataframe()
            return []

        for filepath in result_files:
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    result = json.load(f)
            except Exception as e:
                print(f"Error loading {filepath}: {e}")
                continue
            if not isinstance(result, dict):
                print(f"Error loading {filepath}: expected a JSON object")
                continue
            self.all_results.append(result)

        self._update_dataframe()

        return self.all_results

    def _update_dataframe(self):
        """Rebuild ``results_df`` from ``all_results``."""
        if not self.all_results:
            self.results_df = pd.DataFrame()
            return

        rows = [
            {column: result.get(column, default) for column, default in self._COLUMNS}
            for result in self.all_results
        ]
        self.results_df = pd.DataFrame(rows)

    def print_summary(self):
        """Print the perplexity change per pruning level, grouped by model and strategy."""
        if self.results_df is None or self.results_df.empty:
            print("No results available")
            return

        print(f"Found {len(self.all_results)} result files:\n")

        for (model, strategy), group in self.results_df.groupby(["model", "strategy"]):
            print(f"Model: {model}, Strategy: {strategy}")
            for _, row in group.sort_values("pruning_level").iterrows():
                print(f"  Pruning level: {row['pruning_level']:.2f}, " +
                      f"Perplexity change: {row['perplexity_change']:.4f}")
            print()

    def plot_results(self, figsize=(12, 8)):
        """Draw two matplotlib panels summarising the results.

        The left panel plots perplexity change against pruning level; the
        right panel scatters perplexity before against perplexity after,
        with marker size growing with the pruning level.

        Args:
            figsize: figure size passed to ``plt.subplots``.

        Returns:
            The matplotlib figure, or None when there is nothing to plot.
        """
        if self.results_df is None or self.results_df.empty:
            print("No results to plot")
            return

        df = self.results_df
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize, sharey=False)

        # One colour per strategy, shared by both panels.
        strategies = df["strategy"].unique()
        strategy_colors = dict(zip(strategies, sns.color_palette("colorblind", len(strategies))))

        for model in sorted(df["model"].unique()):
            model_data = df[df["model"] == model]
            for strategy in strategies:
                strategy_data = model_data[model_data["strategy"] == strategy]
                if strategy_data.empty:
                    continue
                label = f"{model} - {strategy}"
                color = strategy_colors[strategy]

                # Panel 1: line through the points ordered by pruning level.
                ordered = strategy_data.sort_values("pruning_level")
                ax1.plot(ordered["pruning_level"], ordered["perplexity_change"],
                         marker="o", linestyle="-", label=label, color=color)

                # Panel 2: point size proportional to pruning level.
                sizes = 100 * strategy_data["pruning_level"] + 20
                ax2.scatter(strategy_data["perplexity_before"], strategy_data["perplexity_after"],
                            s=sizes, alpha=0.7, label=label, color=color)

        # Horizontal reference line at "no change".
        ax1.axhline(y=0, color="gray", linestyle="-", alpha=0.5)
        ax1.set_xlabel("Pruning Level")
        ax1.set_ylabel("Perplexity Change")
        ax1.set_title("Effect of Pruning on Model Perplexity")
        ax1.grid(True, linestyle="--", alpha=0.7)

        # Diagonal reference line (y = x): points above it got worse.
        upper = max(df["perplexity_before"].max(), df["perplexity_after"].max()) * 1.1
        lims = [0, upper]
        ax2.plot(lims, lims, 'k--', alpha=0.5)
        ax2.set_xlabel("Perplexity Before Pruning")
        ax2.set_ylabel("Perplexity After Pruning")
        ax2.set_title("Perplexity Comparison")
        ax2.grid(True, linestyle="--", alpha=0.7)

        # Shared legend: collect entries from both panels and de-duplicate
        # them by label.
        handles, labels = [], []
        for ax in (ax1, ax2):
            panel_handles, panel_labels = ax.get_legend_handles_labels()
            handles.extend(panel_handles)
            labels.extend(panel_labels)
        by_label = dict(zip(labels, handles))
        fig.legend(by_label.values(), by_label.keys(), loc='lower center',
                   ncol=min(5, len(by_label)), bbox_to_anchor=(0.5, -0.05))

        plt.tight_layout()
        plt.subplots_adjust(bottom=0.15)
        plt.show()

        return fig

# Initialize results manager (uses its default "pruning_results" directory,
# created on construction if it does not exist yet).
results_manager = ResultsManager()

# Load existing results if any, so output saved by earlier runs into the same
# directory shows up in the summary printed below.
results_manager.load_results()
results_manager.print_summary()

## Pruning Module

Now let's implement the core pruning functionality.

In [None]:
class PruningModule:
    """Core pruning implementation using JAX/Flax.

    Wraps one causal language model together with its tokenizer and offers
    the operations the benchmark needs: zeroing the output-projection slice
    of an attention head, measuring perplexity and sampling text.

    Attributes:
        model_name: model identifier passed to ``from_pretrained``.
        model_type: "gpt2", "opt" or "pythia"; selects the parameter layout.
        model / tokenizer / original_params: populated by ``load_model``.
        num_layers / num_heads: populated by ``load_model`` (0 before that).
    """

    # model type -> (key path from the params root to the dict of layers,
    #                key of the attention block inside a layer,
    #                key of the attention output projection)
    _LAYOUTS = {
        "gpt2": (("transformer", "h"), "attn", "c_proj"),
        "opt": (("model", "decoder", "layers"), "self_attn", "out_proj"),
        "pythia": (("transformer", "h"), "attn", "proj"),
    }

    # Name-based head counts, used only when the loaded model's config does
    # not expose a head count. First matching marker wins.
    _FALLBACK_HEADS = {
        "gpt2": (("distil", 12), ("medium", 16), ("large", 20), ("xl", 25)),
        "opt": (("125m", 12), ("350m", 16), ("1.3b", 32)),
        "pythia": (("160m", 12), ("410m", 16), ("1b", 16)),
    }
    _DEFAULT_HEADS = 12

    def __init__(self, model_name="distilgpt2"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.original_params = None
        self.model_type = self._get_model_type(model_name)
        self.num_layers = 0
        self.num_heads = 0

    def _get_model_type(self, model_name):
        """Determine the model family from a substring of the model name."""
        lowered = model_name.lower()
        for family in ("gpt2", "opt", "pythia"):
            if family in lowered:
                return family
        # Default to GPT-2 structure
        return "gpt2"

    def _layers(self, params):
        """Return the dict of per-layer parameters for this model family."""
        node = params
        for key in self._LAYOUTS[self.model_type][0]:
            node = node[key]
        return node

    def _resolve_num_heads(self):
        """Return the number of attention heads per layer.

        The loaded model's config is the preferred source; the name-based
        table is only a fallback for configs without ``num_attention_heads``.
        """
        config = getattr(self.model, "config", None)
        heads = getattr(config, "num_attention_heads", None)
        if isinstance(heads, int) and heads > 0:
            return heads
        lowered = self.model_name.lower()
        for marker, count in self._FALLBACK_HEADS.get(self.model_type, ()):
            if marker in lowered:
                return count
        return self._DEFAULT_HEADS

    def load_model(self):
        """Load the model and tokenizer and record the layer and head counts.

        Returns:
            True on success. On any failure the error is printed and False
            is returned instead of raising.
        """
        print(f"Loading model {self.model_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = FlaxAutoModelForCausalLM.from_pretrained(self.model_name)
            self.original_params = self.model.params

            # The layer count comes from the parameter tree itself because
            # prune_head indexes that tree by layer number.
            self.num_layers = len(self._layers(self.original_params))
            self.num_heads = self._resolve_num_heads()

            print(f"Model loaded successfully. Layers: {self.num_layers}, Heads per layer: {self.num_heads}")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False

    def prune_head(self, params, layer_idx, head_idx):
        """Zero the output-projection rows belonging to one attention head.

        Args:
            params: nested parameter dict; it is modified in place.
            layer_idx: index of the layer (its key is ``str(layer_idx)``).
            head_idx: index of the head inside that layer.

        Returns:
            The same ``params`` object, with the projection kernel replaced.
            For a model type without a known layout nothing is changed.
        """
        layout = self._LAYOUTS.get(self.model_type)
        if layout is None:
            return params
        _, attn_key, proj_key = layout

        attn_block = self._layers(params)[str(layer_idx)][attn_key]
        output_proj = attn_block[proj_key]["kernel"]

        # GPT-2 takes the hidden size from the fused QKV kernel when it is
        # present; every other case uses the output projection itself.
        if self.model_type == "gpt2" and "c_attn" in attn_block:
            hidden_size = attn_block["c_attn"]["kernel"].shape[1]
        else:
            hidden_size = output_proj.shape[0]
        head_size = hidden_size // self.num_heads

        start_idx = head_idx * head_size
        end_idx = start_idx + head_size

        # NOTE(review): rows [start_idx, end_idx) are zeroed for every model
        # family. That removes a head only if axis 0 of the kernel is the
        # projection's input axis -- confirm this for the Flax GPT-2 Conv1D
        # layer, whose stored kernel layout may differ from a Dense kernel.
        attn_block[proj_key]["kernel"] = output_proj.at[start_idx:end_idx, :].set(0)

        return params

    def evaluate_perplexity(self, params, text):
        """Return the model's perplexity on ``text`` as a Python float.

        Args:
            params: parameter tree passed to the model call.
            text: input handed to the tokenizer.

        Returns:
            exp of the mean next-token cross-entropy over all positions.
        """
        inputs = self.tokenizer(text, return_tensors="jax")
        logits = self.model(**inputs, params=params).logits
        input_ids = inputs["input_ids"]

        # Position t predicts token t + 1: drop the last logit and the first
        # label so the two line up.
        shift_logits = logits[:, :-1]
        shift_labels = input_ids[:, 1:]

        # Pick the log-probability of each target token directly instead of
        # multiplying by a full-vocabulary one-hot tensor.
        log_probs = jax.nn.log_softmax(shift_logits)
        target_log_probs = jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)
        loss = -jnp.mean(target_log_probs)

        return jnp.exp(loss).item()

    def generate_text(self, params, prompt, max_length=50):
        """Sample a continuation of ``prompt`` and return the decoded text.

        Args:
            params: parameter tree used for generation.
            prompt: text handed to the tokenizer.
            max_length: ``max_length`` forwarded to ``generate``.

        Returns:
            The first decoded sequence, with special tokens removed.
        """
        inputs = self.tokenizer(prompt, return_tensors="jax")

        outputs = self.model.generate(
            **inputs,
            params=params,
            max_length=max_length,
            do_sample=True,
            top_k=40,
            top_p=0.95,
            temperature=0.8
        )

        return self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]

## Pruning Strategies

Let's implement different strategies for selecting heads to prune.

In [None]:
class PruningStrategy:
    """Common interface for head-selection strategies.

    A strategy is bound to a pruning module. Subclasses provide the scoring
    in ``get_head_importance``; this base class supplies ``prune_heads``.
    """

    def __init__(self, pruning_module):
        self.pruning_module = pruning_module

    def get_head_importance(self, params):
        """Score every head; concrete strategies must override this."""
        raise NotImplementedError("Subclasses must implement get_head_importance")

    def prune_heads(self, params, head_indices):
        """Return a pruned copy of ``params``.

        ``head_indices`` is an iterable of ``(layer_idx, head_idx)`` pairs;
        each pair is handed to the pruning module's ``prune_head`` in order.
        """
        # Identity map over every leaf yields a fresh parameter tree to work on.
        working_params = jax.tree_util.tree_map(lambda leaf: leaf, params)

        for pair in head_indices:
            layer_idx, head_idx = pair
            working_params = self.pruning_module.prune_head(working_params, layer_idx, head_idx)

        return working_params

class RandomStrategy(PruningStrategy):
    """Baseline strategy that scores heads with random numbers."""

    def get_head_importance(self, params):
        """Return ``(layer_idx, head_idx, score)`` for every head.

        ``params`` is ignored; each score is a fresh ``random.random()``
        draw, generated layer by layer and head by head.
        """
        module = self.pruning_module
        return [
            (layer_idx, head_idx, random.random())
            for layer_idx in range(module.num_layers)
            for head_idx in range(module.num_heads)
        ]

class MagnitudeStrategy(PruningStrategy):
    """Magnitude-based pruning strategy.

    A head's importance is the L2 norm of its slice of the attention output
    projection kernel.
    """

    # model type -> (key path from the params root to the dict of layers,
    #                key of the attention block inside a layer,
    #                key of the attention output projection)
    _PROJECTION_PATHS = {
        "gpt2": (("transformer", "h"), "attn", "c_proj"),
        "opt": (("model", "decoder", "layers"), "self_attn", "out_proj"),
        "pythia": (("transformer", "h"), "attn", "proj"),
    }

    def _output_projection(self, params, layers_path, layer_idx, attn_key, proj_key):
        """Return the output projection kernel of one layer."""
        node = params
        for key in layers_path:
            node = node[key]
        return node[str(layer_idx)][attn_key][proj_key]["kernel"]

    def get_head_importance(self, params):
        """Calculate importance based on weight magnitudes.

        Args:
            params: nested parameter dict of the model.

        Returns:
            A list of ``(layer_idx, head_idx, importance)`` tuples ordered by
            layer and then by head; ``importance`` is a Python float.

        Raises:
            ValueError: when the pruning module's model type has no known
                parameter layout.
        """
        module = self.pruning_module
        layout = self._PROJECTION_PATHS.get(module.model_type)
        if layout is None:
            raise ValueError(f"Unsupported model type: {module.model_type}")
        layers_path, attn_key, proj_key = layout

        num_heads = module.num_heads
        all_head_importance = []

        for layer_idx in range(module.num_layers):
            # The kernel is the same for every head of a layer: look it up
            # once per layer rather than once per head.
            output_proj = self._output_projection(params, layers_path, layer_idx, attn_key, proj_key)
            head_size = output_proj.shape[0] // num_heads

            for head_idx in range(num_heads):
                start_idx = head_idx * head_size
                head_weights = output_proj[start_idx:start_idx + head_size, :]

                # Importance is the L2 norm of this head's weights.
                importance = jnp.linalg.norm(head_weights).item()
                all_head_importance.append((layer_idx, head_idx, importance))

        return all_head_importance

class AttentionEntropyStrategy(PruningStrategy):
    """Strategy registered under the name "entropy".

    It stores sample text intended for attention-entropy scoring, but the
    scores it currently returns are delegated to ``MagnitudeStrategy``.
    """

    def __init__(self, pruning_module, sample_text=None):
        """Store ``sample_text`` as a list, or a built-in default set when None."""
        super().__init__(pruning_module)

        if sample_text is not None:
            # A single value is wrapped so the attribute is always a list.
            self.sample_text = sample_text if isinstance(sample_text, list) else [sample_text]
            return

        self.sample_text = [
            "The quick brown fox jumps over the lazy dog",
            "Artificial intelligence is transforming the world",
            "Machine learning models can process large amounts of data",
            "The future of technology depends on sustainable practices",
            "Researchers are working on new methods to improve efficiency"
        ]

    def get_head_importance(self, params):
        """Return the magnitude-based scores for ``params``.

        The stored sample text is not used here: for simplicity and
        compatibility across model types, weight magnitude serves as a proxy.
        """
        return MagnitudeStrategy(self.pruning_module).get_head_importance(params)

# Factory function to get strategy by name
def get_strategy(name, pruning_module, sample_text=None):
    """Build the pruning strategy registered under ``name``.

    The lookup is case-insensitive. Known names are "random", "magnitude"
    and "entropy"; only the entropy strategy receives ``sample_text``.
    Any other name raises ValueError.
    """
    key = name.lower()
    if key not in ("random", "magnitude", "entropy"):
        raise ValueError(f"Unknown strategy: {name}")

    if key == "entropy":
        return AttentionEntropyStrategy(pruning_module, sample_text)
    if key == "magnitude":
        return MagnitudeStrategy(pruning_module)
    return RandomStrategy(pruning_module)

## Benchmark Runner

Now let's implement the main benchmark runner.

In [None]:
class PruningBenchmark:
    """Main benchmark runner.

    Runs prune-and-evaluate experiments and hands every result dictionary to
    the results manager supplied at construction time.
    """

    # Intermediate plots are refreshed after every this-many attempted runs.
    _PLOT_EVERY = 3

    def __init__(self, results_manager):
        """Keep a reference to the object used to save and plot results."""
        self.results_manager = results_manager

    def run_single_benchmark(self, model_name, strategy_name, pruning_level, prompt):
        """Prune one model with one strategy and measure the effect.

        Progress is printed to stdout and the result is saved through
        ``results_manager.save_result``.

        Args:
            model_name: Name passed to ``PruningModule``.
            strategy_name: Strategy name understood by ``get_strategy``.
            pruning_level: Fraction of attention heads to prune, in [0, 1].
            prompt: Text used for perplexity evaluation and generation; it is
                also handed to ``get_strategy`` as the sample text.

        Returns:
            The saved result dictionary, or ``None`` if the model failed to
            load.

        Raises:
            ValueError: If ``pruning_level`` is outside [0, 1] (NaN included).
        """
        # Fail fast, before any model is loaded: a negative level would turn
        # the slice below into ``[:-k]`` and silently prune almost every
        # head, and a level above 1 would report more heads than exist.
        if not 0.0 <= pruning_level <= 1.0:
            raise ValueError(
                f"pruning_level must be between 0 and 1, got {pruning_level!r}"
            )

        print(f"\nRunning benchmark for {model_name} with {strategy_name} strategy at {pruning_level:.2f} pruning level")

        # Initialize pruning module
        pruning_module = PruningModule(model_name)
        if not pruning_module.load_model():
            print(f"Failed to load model {model_name}")
            return None

        # Get strategy
        strategy = get_strategy(strategy_name, pruning_module, prompt)

        # The identity tree_map rebuilds the pytree containers only; the
        # leaves are the same objects as in original_params, not copies.
        original_params = pruning_module.original_params
        params = jax.tree_util.tree_map(lambda x: x, original_params)

        # Evaluate model before pruning
        print("Evaluating model before pruning...")
        perplexity_before = pruning_module.evaluate_perplexity(params, prompt)
        print(f"Perplexity before pruning: {perplexity_before:.4f}")

        generated_before = pruning_module.generate_text(params, prompt)
        print(f"Generated (before pruning): {generated_before}")

        # Calculate importance scores
        print("\nCalculating head importance...")
        all_head_importance = strategy.get_head_importance(params)

        # Rank ascending by importance; sorted() leaves the list returned by
        # the strategy untouched.
        ranked_heads = sorted(all_head_importance, key=lambda entry: entry[2])

        # Determine number of heads to prune
        total_heads = pruning_module.num_layers * pruning_module.num_heads
        heads_to_prune = int(total_heads * pruning_level)
        print(f"Pruning {heads_to_prune} out of {total_heads} heads")

        # Least important heads first
        head_indices = [
            (layer_idx, head_idx)
            for layer_idx, head_idx, _ in ranked_heads[:heads_to_prune]
        ]
        if len(head_indices) != heads_to_prune:
            # The strategy scored fewer heads than were requested.
            print(f"Warning: only {len(head_indices)} heads have importance scores; pruning those")

        # Prune heads
        print("\nPruning heads...")
        pruned_params = strategy.prune_heads(params, head_indices)

        # Evaluate model after pruning
        print("\nEvaluating model after pruning...")
        perplexity_after = pruning_module.evaluate_perplexity(pruned_params, prompt)
        perplexity_change = perplexity_after - perplexity_before
        print(f"Perplexity after pruning: {perplexity_after:.4f}")
        print(f"Perplexity change: {perplexity_change:.4f}")

        generated_after = pruning_module.generate_text(pruned_params, prompt)
        print(f"Generated (after pruning): {generated_after}")

        # Prepare result; "pruned_heads" is the number actually selected.
        result = {
            "model": model_name,
            "strategy": strategy_name,
            "pruning_level": pruning_level,
            "pruned_heads": len(head_indices),
            "total_heads": total_heads,
            "prompt": prompt,
            "perplexity_before": float(perplexity_before),
            "perplexity_after": float(perplexity_after),
            "perplexity_change": float(perplexity_change),
            "generated_before": generated_before,
            "generated_after": generated_after,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

        # Save result
        self.results_manager.save_result(result)

        print("\nBenchmark completed successfully!")
        return result

    def _plot_intermediate(self):
        """Refresh the results plot; report, but do not raise, plot errors."""
        try:
            self.results_manager.plot_results()
            plt.close()
        except Exception as e:
            print(f"Error while plotting intermediate results: {e}")

    def run_multiple_benchmarks(self, models=None, strategies=None, pruning_levels=None, prompt=None, max_runtime=None):
        """Run every model x strategy x pruning-level combination.

        Combinations are shuffled so that early results are diverse. A
        failure in one benchmark is printed and does not stop the sweep.

        Args:
            models: Model names; defaults to ``env.get_suitable_models()``.
            strategies: Strategy names; defaults to random and magnitude.
            pruning_levels: Fractions of heads to prune; defaults to
                0.1 through 0.5 in steps of 0.1.
            prompt: Prompt used for every run; a fixed default if ``None``.
            max_runtime: Optional limit in seconds. It is checked before each
                benchmark starts, so a running benchmark is never cut short.

        Returns:
            The list of result dictionaries from the successful runs.
        """
        # Default values
        if models is None:
            models = env.get_suitable_models()
        if strategies is None:
            strategies = ["random", "magnitude"]
        if pruning_levels is None:
            pruning_levels = [0.1, 0.2, 0.3, 0.4, 0.5]
        if prompt is None:
            prompt = "Artificial intelligence will transform"

        # Start time for runtime tracking
        start_time = time.time()

        # Generate all benchmark combinations
        benchmarks = []
        for model in models:
            for strategy in strategies:
                for level in pruning_levels:
                    benchmarks.append((model, strategy, level, prompt))

        # Shuffle to get more diverse results early
        random.shuffle(benchmarks)

        results = []
        attempted = 0

        # The context manager closes the bar even if the run is interrupted.
        with tqdm(total=len(benchmarks), desc="Running benchmarks") as pbar:
            for model, strategy, level, bench_prompt in benchmarks:
                # Check if we've exceeded the runtime limit
                if max_runtime is not None and time.time() - start_time > max_runtime:
                    print(f"\nReached maximum runtime of {max_runtime/3600:.1f} hours")
                    break

                pbar.set_description(f"Running {model}, {strategy}, {level:.2f}")
                attempted += 1

                try:
                    result = self.run_single_benchmark(model, strategy, level, bench_prompt)
                except Exception as e:
                    # Keep the sweep alive; one bad combination must not
                    # abort an overnight run.
                    print(f"Error in benchmark {model}, {strategy}, {level:.2f}: {e}")
                    result = None
                finally:
                    pbar.update(1)

                if result is not None:
                    results.append(result)

                # Periodic refresh; the last benchmark is covered by the
                # final plot below.
                if attempted % self._PLOT_EVERY == 0 and attempted < len(benchmarks):
                    self._plot_intermediate()

        # Final results
        print(f"\nCompleted {len(results)} benchmarks out of {attempted} attempted")
        runtime = time.time() - start_time
        print(f"Total runtime: {runtime/3600:.2f} hours ({runtime/60:.2f} minutes)")

        # Plot final results
        self.results_manager.plot_results()

        return results

# Create the benchmark runner used by the cells below
benchmark = PruningBenchmark(results_manager=results_manager)

## Run a Single Benchmark

Let's run a single benchmark to test our implementation.

In [None]:
# Smoke test: one low-pruning run on the first model reported as suitable
model_name = env.get_suitable_models()[0]
result = benchmark.run_single_benchmark(
    model_name,
    "random",  # strategy_name
    0.1,  # pruning_level
    "Artificial intelligence will transform",  # prompt
)

## Run Multiple Benchmarks

Now let's run multiple benchmarks to collect comprehensive results.

In [None]:
# Ask the environment which models it considers runnable
available_models = env.get_suitable_models()
print(f"Available models: {available_models}")

# Restrict the sweep to the first two entries
models_to_test = available_models[0:2]
print(f"Using models: {models_to_test}")

# Strategies to compare
strategies_to_test = ["random", "magnitude"]

# Fractions of heads to prune
pruning_levels_to_test = [0.1, 0.3, 0.5]

# Prompt shared by every run
prompt = "Artificial intelligence will transform society by"

# Runtime budget in seconds: 1 hour on Colab, 20 minutes elsewhere
if env.in_colab:
    max_runtime = 3600
else:
    max_runtime = 1200

# Launch the sweep
results = benchmark.run_multiple_benchmarks(
    models=models_to_test,
    strategies=strategies_to_test,
    pruning_levels=pruning_levels_to_test,
    prompt=prompt,
    max_runtime=max_runtime,
)

## View and Analyze Results

Let's visualize and analyze our benchmark results.

In [None]:
# Load all results
# NOTE(review): presumably this refreshes results_manager.results_df, which
# the analysis cell below reads -- confirm against the results manager.
results_manager.load_results()

# Print summary
results_manager.print_summary()

# Plot results with figsize=(14, 8); the return value is kept in `fig`
fig = results_manager.plot_results(figsize=(14, 8))

## Additional Analysis

Let's create some advanced visualizations to better understand our results.

In [None]:
# Draw the 2x2 analysis dashboard only when there is data to show
if not (results_manager.results_df is None or results_manager.results_df.empty):
    # One figure shared by all four panels
    plt.figure(figsize=(14, 8))

    # Panel 1: spread of perplexity change for each strategy
    plt.subplot(221)
    sns.boxplot(data=results_manager.results_df, x="strategy", y="perplexity_change")
    plt.title("Perplexity Change by Strategy")
    plt.grid(True, alpha=0.7, linestyle="--")

    # Panel 2: spread of perplexity change for each model
    plt.subplot(222)
    sns.boxplot(data=results_manager.results_df, x="model", y="perplexity_change")
    plt.title("Perplexity Change by Model")
    plt.xticks(rotation=45, ha="right")
    plt.grid(True, alpha=0.7, linestyle="--")

    # Panel 3: mean perplexity change, strategy rows by pruning-level columns
    plt.subplot(223)
    pivot_df = results_manager.results_df.pivot_table(
        values="perplexity_change",
        index="strategy",
        columns="pruning_level",
        aggfunc="mean",
    )
    sns.heatmap(pivot_df, annot=True, center=0, cmap="RdYlGn_r")
    plt.title("Average Perplexity Change by Strategy and Pruning Level")

    # Panel 4: perplexity change against the pre-pruning perplexity
    plt.subplot(224)
    sns.scatterplot(
        data=results_manager.results_df,
        x="perplexity_before",
        y="perplexity_change",
        hue="strategy",
        size="pruning_level",
        sizes=(50, 200),
    )
    plt.title("Perplexity Change vs Initial Perplexity")
    plt.grid(True, alpha=0.7, linestyle="--")

    plt.tight_layout()
    plt.show()

## Run Overnight Benchmarks

Use this cell to configure and run comprehensive benchmarks overnight in Colab.

In [None]:
# Sweep definition for an unattended overnight run
OVERNIGHT_MODELS = env.get_suitable_models()  # every model the environment reports
OVERNIGHT_STRATEGIES = [
    "random",
    "magnitude",
    "entropy",
]
OVERNIGHT_PRUNING_LEVELS = [
    0.05, 0.1, 0.2, 0.3,
    0.4, 0.5, 0.6, 0.7,
]
OVERNIGHT_PROMPTS = [
    "Artificial intelligence will transform society by",
    "The future of technology depends on new innovations that",
    "Scientists are developing advanced systems that can",
]

# Runtime budget: 8 hours, expressed in seconds
OVERNIGHT_MAX_RUNTIME = 8 * 60 * 60

# Echo the configuration so it is visible in the cell output
print("Overnight benchmark configuration:")
print(f"Models: {OVERNIGHT_MODELS}")
print(f"Strategies: {OVERNIGHT_STRATEGIES}")
print(f"Pruning levels: {OVERNIGHT_PRUNING_LEVELS}")
print(f"Number of prompts: {len(OVERNIGHT_PROMPTS)}")
print(f"Maximum runtime: {OVERNIGHT_MAX_RUNTIME/3600:.1f} hours")

# One run per (model, strategy, pruning level); a single prompt is used
total_benchmarks = (
    len(OVERNIGHT_PRUNING_LEVELS)
    * len(OVERNIGHT_STRATEGIES)
    * len(OVERNIGHT_MODELS)
)
print(f"Total benchmark combinations: {total_benchmarks}")

# Uncomment to run overnight benchmarks
# overnight_results = benchmark.run_multiple_benchmarks(
#     models=OVERNIGHT_MODELS,
#     strategies=OVERNIGHT_STRATEGIES,
#     pruning_levels=OVERNIGHT_PRUNING_LEVELS,
#     prompt=OVERNIGHT_PROMPTS[0],  # Use the first prompt
#     max_runtime=OVERNIGHT_MAX_RUNTIME
# )