# Pure Pruning Benchmark - Colab Notebook

This notebook runs a comprehensive benchmark of the pure pruning approach across different models, pruning levels, and strategies. It's designed to run in Google Colab and will save results to your Google Drive.

## Setup Instructions

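1. Open this notebook in Google Colab (a GPU runtime is used automatically when one is available).
2. Run the setup cells and mount your Google Drive when prompted, so results are saved there.
3. Run the remaining cells, then choose configuration options in the UI that appears.
4. Click **Run Benchmark** to start. Re-running with the same output directory skips configurations that have already completed.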

## 1. Set Up Environment

In [None]:
# Install dependencies
# ipywidgets builds the configuration UI, matplotlib draws the comparison
# charts, pandas tabulates loaded results and tqdm is imported by the
# benchmark cell; thop is not referenced in this notebook directly
# (presumably needed by the project's benchmark code — TODO confirm).
!pip install -q ipywidgets matplotlib pandas tqdm thop

# Mount Google Drive
# Results default to a folder under /content/drive/MyDrive/, so the mount
# has to exist before a benchmark run starts.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the repository
# Change "main" to your desired branch 
!git clone --single-branch --branch main https://github.com/CambrianTech/sentinel-ai.git
%cd sentinel-ai

In [None]:
# Install project requirements
# The relative path resolves against the current directory, i.e. the cloned
# repository entered by the previous cell's %cd.
!pip install -q -r requirements.txt

## 2. Define Benchmark Configuration

In [None]:
# Build the interactive configuration panel. Nothing runs here: the benchmark
# functions read these widgets' current values when a run starts.
import ipywidgets as widgets
from IPython.display import display
import os
import datetime

# Every notebook session gets its own timestamped results folder on Drive.
default_output_dir = (
    "/content/drive/MyDrive/sentinel_ai_benchmarks/"
    f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
)

# Single model to benchmark (run_overnight_benchmark wraps it in a list).
model_dropdown = widgets.Dropdown(
    description='Model:',
    options=['gpt2', 'gpt2-medium', 'gpt2-large'],
    value='gpt2',
    disabled=False,
)

# Pruning fractions to sweep; kept as strings and converted with float() later.
pruning_levels_select = widgets.SelectMultiple(
    description='Pruning Levels:',
    options=['0.1', '0.3', '0.5', '0.7', '0.9'],
    value=['0.3', '0.5', '0.7'],
    disabled=False,
)

# Pruning strategies to compare at each level.
strategies_select = widgets.SelectMultiple(
    description='Strategies:',
    options=['entropy', 'random', 'magnitude'],
    value=['entropy', 'random', 'magnitude'],
    disabled=False,
)

# Where results, charts and the progress file are written; editable before a run.
output_dir_text = widgets.Text(
    description='Output Dir:',
    value=default_output_dir,
    placeholder='Type output directory path',
    disabled=False,
)

# Toggles that run_benchmark forwards to PruningBenchmark as
# `visualize` / `hardware_metrics`.
visualize_checkbox = widgets.Checkbox(
    description='Generate Visualizations',
    value=True,
    disabled=False,
)
hw_metrics_checkbox = widgets.Checkbox(
    description='Hardware Metrics',
    value=True,
    disabled=False,
)

# Shared log area; the helper functions print into it via `with status_output:`.
status_output = widgets.Output()

display(
    model_dropdown,
    pruning_levels_select,
    strategies_select,
    output_dir_text,
    visualize_checkbox,
    hw_metrics_checkbox,
    status_output,
)

## 3. Define Functions for Benchmark

In [None]:
import sys
import os
import time
import json
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
from tqdm.notebook import tqdm

In [None]:
# Import the benchmark class
from scripts.pure_pruning_benchmark import PruningBenchmark
from models.loaders.loader import load_baseline_model

In [None]:
def get_baseline_performance(model_name, device, output_dir):
    """Benchmark ``model_name`` with no pruning applied and report speed/memory.

    Args:
        model_name: Model identifier passed to ``load_baseline_model``.
        device: Torch device string, e.g. "cuda" or "cpu".
        output_dir: Directory handed to ``PruningBenchmark``; created if missing.

    Returns:
        Dict with ``speed`` (the ``tokens_per_second`` metric), ``memory``
        (the ``memory_usage`` metric, printed as MB) and ``results`` (the raw
        mapping from ``measure_baseline_performance``). Missing metrics are 0.
    """
    with status_output:
        print(f"Measuring baseline performance for {model_name}...")

    os.makedirs(output_dir, exist_ok=True)

    # The model is loaded here and handed over through the private `_model`
    # option so PruningBenchmark does not load it a second time.
    benchmark = PruningBenchmark(
        model_name=model_name,
        pruning_level=0.0,
        strategy="none",
        device=device,
        output_dir=output_dir,
        visualize=False,
        hardware_metrics=True,
        _model=load_baseline_model(model_name, device),
    )
    measured = benchmark.measure_baseline_performance()

    speed = measured.get("tokens_per_second", 0)
    memory = measured.get("memory_usage", 0)

    with status_output:
        print(f"Baseline performance for {model_name}: {speed:.2f} tokens/sec, {memory:.1f}MB memory")

    return dict(speed=speed, memory=memory, results=measured)

In [None]:
def run_benchmark(model_name, pruning_level, strategy, device, output_dir):
    """Run one (model, pruning level, strategy) benchmark and save its raw results.

    Results are written to
    ``<output_dir>/<model_name>/data/<strategy>_pruning_<percent>.json``; a
    sibling ``charts`` directory is created for the benchmark's plots.

    Args:
        model_name: Model identifier forwarded to ``PruningBenchmark``.
        pruning_level: Pruning fraction as a number or numeric string (e.g. "0.3").
        strategy: Pruning strategy name forwarded to ``PruningBenchmark``.
        device: Torch device string, e.g. "cuda" or "cpu".
        output_dir: Root results directory for the whole run.

    Returns:
        The result mapping returned by ``PruningBenchmark.run()``.
    """
    model_output_dir = os.path.join(output_dir, model_name)
    charts_dir = os.path.join(model_output_dir, "charts")
    data_dir = os.path.join(model_output_dir, "data")
    os.makedirs(charts_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    config = {
        "model_name": model_name,
        "pruning_level": pruning_level,
        "strategy": strategy,
        "device": device,
        "output_dir": model_output_dir,
        "visualize": visualize_checkbox.value,
        "baseline_comparison": True,
        "hardware_metrics": hw_metrics_checkbox.value,
    }

    benchmark = PruningBenchmark(**config)
    results = benchmark.run()

    def to_json(value):
        # NOTE(review): in case run() returns numpy/torch scalars or arrays
        # (not visible from here), convert them via .tolist() instead of
        # failing after the benchmark has already finished. Anything else is
        # rejected exactly as json would.
        if hasattr(value, "tolist"):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # round() rather than int(): int(0.29 * 100) truncates to 28 and would
    # mislabel the file.
    level_percent = round(float(pruning_level) * 100)
    result_file = os.path.join(data_dir, f"{strategy}_pruning_{level_percent}.json")
    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=to_json)

    with status_output:
        print(f"Benchmark complete for {model_name} with {strategy} pruning at {pruning_level} level")
        print(f"Results saved to: {result_file}")

    return results

In [None]:
def create_benchmark_summary(output_dir, benchmark_results):
    """Write an HTML report of the benchmark results and return its path.

    The report lists the unpruned baseline and every completed (pruning
    level, strategy) configuration for each model. It links any ``*.png``
    charts found in ``<output_dir>/<model_name>/charts``. These are relative
    links, so the report must stay in ``output_dir``. The report ends with a
    conclusion computed from the measured numbers.

    Args:
        output_dir: Directory where ``benchmark_summary.html`` is written.
        benchmark_results: Mapping of model name to a dict with
            ``baseline_speed``, ``baseline_memory``, ``pruning_levels``,
            ``strategies`` and ``results`` (keyed ``"<strategy>_<level>"``,
            values holding ``speed``/``quality``/``memory``), as assembled by
            ``run_overnight_benchmark``. May be empty.

    Returns:
        Path of the written report.
    """
    import html  # escape model/strategy/file names interpolated into markup

    summary_path = os.path.join(output_dir, "benchmark_summary.html")

    def as_percent(level):
        # round() rather than int(): int(0.29 * 100) truncates to 28.
        return f"{round(float(level) * 100)}%"

    def completed_configs(model_data):
        """Yield (level, strategy, result) for each finished configuration, in table order."""
        results = model_data.get("results", {})
        for level in model_data.get("pruning_levels", []):
            for strategy in model_data.get("strategies", []):
                key = f"{strategy}_{level}"
                if key in results:
                    yield level, strategy, results[key]

    def speedup(result, baseline_speed):
        # A zero baseline (metric missing) would divide by zero; report 0 instead.
        return result.get("speed", 0) / baseline_speed if baseline_speed > 0 else 0

    # The overview describes the sweep using the first model's configuration
    # (run_overnight_benchmark gives every model the same lists).
    first_model = next(iter(benchmark_results.values()), {})
    strategies_text = ", ".join(html.escape(str(s)) for s in first_model.get("strategies", []))
    levels_text = ", ".join(as_percent(x) for x in first_model.get("pruning_levels", []))

    parts = [
        """
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Pruning Benchmark Summary</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            .container { max-width: 1200px; margin: 0 auto; }
            table { border-collapse: collapse; width: 100%; margin-bottom: 20px; }
            th, td { padding: 8px; text-align: left; border-bottom: 1px solid #ddd; }
            th { background-color: #f2f2f2; }
            tr:hover { background-color: #f5f5f5; }
            .summary-card { border: 1px solid #ddd; padding: 15px; margin-bottom: 20px; border-radius: 4px; }
            .chart-container { display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; }
            .chart { margin: 10px; border: 1px solid #eee; padding: 10px; border-radius: 4px; }
            h2 { color: #333; }
            .highlight { font-weight: bold; color: #2c5282; }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Pruning Benchmark Summary</h1>
            <p>Generated on: """ + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + """</p>

            <div class="summary-card">
                <h2>Overall Findings</h2>
                <ul>
                    <li>Number of models tested: """ + str(len(benchmark_results)) + """</li>
                    <li>Pruning strategies evaluated: """ + strategies_text + """</li>
                    <li>Pruning levels tested: """ + levels_text + """</li>
                </ul>
            </div>

            <h2>Results by Model</h2>
    """
    ]

    for model_name, model_data in benchmark_results.items():
        baseline_speed = model_data.get("baseline_speed", 0)
        parts.append(f"""
            <div class="summary-card">
                <h2>{html.escape(str(model_name))}</h2>
                <table>
                    <tr>
                        <th>Pruning Level</th>
                        <th>Strategy</th>
                        <th>Speed (tokens/sec)</th>
                        <th>Speedup Factor</th>
                        <th>Quality Score</th>
                        <th>Memory Usage</th>
                    </tr>
                    <tr>
                        <td>0%</td>
                        <td>None (Baseline)</td>
                        <td>{baseline_speed:.2f}</td>
                        <td>1.00x</td>
                        <td>100%</td>
                        <td>{model_data.get('baseline_memory', 0):.1f} MB</td>
                    </tr>
        """)

        for level, strategy, result in completed_configs(model_data):
            parts.append(f"""
                    <tr>
                        <td>{as_percent(level)}</td>
                        <td>{html.escape(str(strategy))}</td>
                        <td>{result.get('speed', 0):.2f}</td>
                        <td>{speedup(result, baseline_speed):.2f}x</td>
                        <td>{result.get('quality', 0):.1f}%</td>
                        <td>{result.get('memory', 0):.1f} MB</td>
                    </tr>
            """)

        parts.append("""
                </table>

                <div class="chart-container">
        """)

        charts_dir = os.path.join(output_dir, model_name, "charts")
        if os.path.isdir(charts_dir):
            # Sorted so chart order is stable across report regenerations.
            for image_file in sorted(os.listdir(charts_dir)):
                if image_file.endswith(".png"):
                    image_path = html.escape(f"{model_name}/charts/{image_file}")
                    parts.append(f"""
                    <div class="chart">
                        <img src="{image_path}" alt="{html.escape(image_file)}" style="max-width: 500px;">
                    </div>
                    """)

        parts.append("""
                </div>
            </div>
        """)

    # Conclusion derived from the data rather than a fixed claim.
    findings = []
    for model_name, model_data in benchmark_results.items():
        configs = list(completed_configs(model_data))
        if not configs:
            continue
        baseline_speed = model_data.get("baseline_speed", 0)
        fast_level, fast_strategy, fastest = max(configs, key=lambda c: c[2].get("speed", 0))
        best_level, best_strategy, best = max(configs, key=lambda c: c[2].get("quality", 0))
        findings.append(
            f"<li><b>{html.escape(str(model_name))}</b>: fastest configuration was "
            f"{html.escape(str(fast_strategy))} pruning at {as_percent(fast_level)} "
            f"({speedup(fastest, baseline_speed):.2f}x speedup, "
            f"{fastest.get('quality', 0):.1f}% quality); highest quality was "
            f"{html.escape(str(best_strategy))} pruning at {as_percent(best_level)} "
            f"({best.get('quality', 0):.1f}% quality, "
            f"{speedup(best, baseline_speed):.2f}x speedup).</li>"
        )
    if findings:
        conclusion = "<ul>" + "".join(findings) + "</ul>"
    else:
        conclusion = "<p>No pruned configurations have completed yet.</p>"

    parts.append(f"""
            <div class="summary-card">
                <h2>Conclusion</h2>
                {conclusion}
            </div>
        </div>
    </body>
    </html>
    """)

    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("".join(parts))

    with status_output:
        print(f"Summary report generated at: {summary_path}")

    return summary_path

In [None]:
def update_progress(progress_file, model_name, pruning_level, strategy, completed=False, result=None):
    """Record a benchmark configuration in the JSON progress file and return the new state.

    The file is created on first use with a ``started_at`` timestamp. The
    level and strategy are added to the model's tracked lists. When
    ``completed`` is true, ``"<strategy>_<pruning_level>"`` is marked done.
    When a non-empty ``result`` is also supplied, its ``tokens_per_second`` /
    ``quality_score`` / ``memory_usage`` values are stored under ``results``
    with a completion timestamp.

    Args:
        progress_file: Path of the JSON progress file.
        model_name: Model the configuration belongs to.
        pruning_level: Pruning level value; used verbatim in the key.
        strategy: Strategy name; used verbatim in the key.
        completed: Whether to mark this configuration as finished.
        result: Optional metrics mapping from the benchmark run.

    Returns:
        The full progress dict as written to disk.
    """
    if os.path.exists(progress_file):
        with open(progress_file, "r", encoding="utf-8") as f:
            progress = json.load(f)
    else:
        progress = {
            "started_at": datetime.now().isoformat(),
            "models": {},
            "completed": {},
        }

    model_entry = progress["models"].setdefault(
        model_name, {"pruning_levels": [], "strategies": [], "completed": {}}
    )

    if pruning_level not in model_entry["pruning_levels"]:
        model_entry["pruning_levels"].append(pruning_level)
    if strategy not in model_entry["strategies"]:
        model_entry["strategies"].append(strategy)

    if completed:
        key = f"{strategy}_{pruning_level}"
        model_entry["completed"][key] = True

        if result:
            model_entry.setdefault("results", {})[key] = {
                "speed": result.get("tokens_per_second", 0),
                "quality": result.get("quality_score", 0),
                "memory": result.get("memory_usage", 0),
                "completed_at": datetime.now().isoformat(),
            }

    # Write to a sibling temp file and swap it in. Rewriting the file in
    # place would leave it truncated or corrupt if the runtime dies mid-write,
    # breaking every later resume. os.replace is atomic on POSIX filesystems;
    # NOTE(review): atomicity on the Drive FUSE mount is not guaranteed, but
    # this is still far safer than truncate-then-write.
    tmp_file = progress_file + ".tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(progress, f, indent=2)
    os.replace(tmp_file, progress_file)

    return progress

In [None]:
def is_benchmark_complete(progress_file, model_name, pruning_level, strategy):
    """Tell whether ``progress_file`` marks this configuration as finished.

    A missing progress file means nothing has run yet, so the answer is
    False. Otherwise, completion is looked up under
    ``models[model_name]["completed"]``, using the same
    ``"<strategy>_<pruning_level>"`` key that ``update_progress`` writes.
    """
    if os.path.exists(progress_file):
        with open(progress_file, "r") as handle:
            recorded_models = json.load(handle).get("models", {})
        if model_name in recorded_models:
            finished = recorded_models[model_name].get("completed", {})
            return f"{strategy}_{pruning_level}" in finished
    return False

In [None]:
def generate_comparative_charts(output_dir, benchmark_results):
    """Plot best-per-level speedup and quality for every model on shared axes.

    Writes ``comparative_speedup.png`` and ``comparative_quality.png`` to
    ``<output_dir>/comparative_charts``. At each pruning level, the best value
    across strategies is plotted. Some points have no data: a level with no
    completed configuration or, for speedup, a model without a positive
    baseline speed. These are plotted as gaps (NaN) rather than a misleading 0.
    Does nothing when ``benchmark_results`` is empty.

    Args:
        output_dir: Root results directory.
        benchmark_results: Mapping of model name to a dict with
            ``baseline_speed``, ``pruning_levels``, ``strategies`` and
            ``results`` keyed ``"<strategy>_<level>"``.
    """
    if not benchmark_results:
        return

    charts_dir = os.path.join(output_dir, "comparative_charts")
    os.makedirs(charts_dir, exist_ok=True)

    models = list(benchmark_results)
    reference = benchmark_results[models[0]]
    # Keep the level values exactly as supplied. Result keys are built as
    # f"{strategy}_{level}" from these same values, so re-formatting them
    # (e.g. str(float("0.50")) == "0.5") would miss results.
    levels = list(reference["pruning_levels"])
    strategies = reference["strategies"]
    # round() rather than int(): int(0.29 * 100) truncates to 28.
    x_values = [round(float(level) * 100) for level in levels]

    def best_per_level(model_data, score):
        """Best score(result) across strategies at each level; NaN where none completed."""
        results = model_data.get("results", {})
        best = []
        for level in levels:
            scores = [
                score(results[f"{strategy}_{level}"])
                for strategy in strategies
                if f"{strategy}_{level}" in results
            ]
            best.append(max(scores) if scores else float("nan"))
        return best

    def speedups(model_data):
        baseline_speed = model_data.get("baseline_speed", 0)
        if not baseline_speed or baseline_speed <= 0:
            # No usable baseline: speedup is undefined (and would divide by zero).
            return [float("nan")] * len(levels)
        return best_per_level(model_data, lambda r: r.get("speed", 0) / baseline_speed)

    def qualities(model_data):
        return best_per_level(model_data, lambda r: r.get("quality", 0))

    def save_chart(series_for, title, ylabel, filename):
        """Draw one line per model from series_for(model_data) and save it under charts_dir."""
        plt.figure(figsize=(12, 6))
        try:
            for model_name in models:
                plt.plot(x_values, series_for(benchmark_results[model_name]), 'o-',
                         label=model_name, linewidth=2)
            plt.title(title)
            plt.xlabel('Pruning Level (%)')
            plt.ylabel(ylabel)
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend()
            plt.savefig(os.path.join(charts_dir, filename), dpi=150)
        finally:
            # Release the figure even if saving to Drive fails.
            plt.close()

    save_chart(speedups, 'Best Speedup by Pruning Level Across Models',
               'Speedup Factor (×)', "comparative_speedup.png")
    save_chart(qualities, 'Best Quality Retention by Pruning Level Across Models',
               'Quality Score (%)', "comparative_quality.png")

    with status_output:
        print(f"Comparative charts generated in: {charts_dir}")

## 4. Run Benchmark

In [None]:
def run_overnight_benchmark(output_dir=None, model_names=None, pruning_levels=None, strategies=None):
    """Run every (model, pruning level, strategy) combination and write the reports.

    Any argument left as None falls back to the configuration widgets.
    Progress is checkpointed in ``<output_dir>/benchmark_progress.json`` after
    each configuration. Re-running with the same ``output_dir`` skips finished
    configurations and reuses their recorded metrics, so the HTML summary and
    comparative charts still cover the whole sweep after a resume. A failure
    in one model's baseline or in one configuration is printed with its
    traceback, and the run moves on.

    Returns:
        Dict mapping each model whose baseline was measured to its baseline
        speed/memory, the configuration lists, and ``results`` keyed
        ``"<strategy>_<level>"`` with ``speed``/``quality``/``memory``/``flops``.
    """
    import traceback  # only needed to report failures

    # Use parameters from UI if not explicitly provided
    output_dir = output_dir or output_dir_text.value
    model_names = model_names or [model_dropdown.value]
    pruning_levels = pruning_levels or list(pruning_levels_select.value)
    strategies = strategies or list(strategies_select.value)

    # Get device (prefer CUDA if available)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(output_dir, exist_ok=True)
    progress_file = os.path.join(output_dir, "benchmark_progress.json")

    def load_saved_result(model_name, key):
        """Return the metrics a previous run recorded for this configuration, or None."""
        try:
            with open(progress_file, "r") as f:
                progress = json.load(f)
        except (OSError, ValueError):
            return None
        saved = progress.get("models", {}).get(model_name, {}).get("results", {}).get(key)
        if not saved:
            return None
        return {
            "speed": saved.get("speed", 0),
            "quality": saved.get("quality", 0),
            "memory": saved.get("memory", 0),
            # update_progress does not persist FLOPs, so they default to 0 here.
            "flops": saved.get("flops", 0),
        }

    benchmark_results = {}
    start_time = time.time()

    with status_output:
        print(f"\n{'='*80}")
        print(f"Starting benchmark with configuration:")
        print(f"Models: {model_names}")
        print(f"Pruning levels: {pruning_levels}")
        print(f"Strategies: {strategies}")
        print(f"Output directory: {output_dir}")
        print(f"Device: {device}")
        print(f"{'='*80}\n")

    for model_name in model_names:
        with status_output:
            print(f"\n{'='*80}\nBenchmarking model: {model_name}\n{'='*80}")

        if model_name not in benchmark_results:
            try:
                baseline = get_baseline_performance(model_name, device, output_dir)
            except Exception:
                # Without a baseline no speedup can be computed; skip this
                # model instead of aborting the remaining ones.
                with status_output:
                    print(f"Error measuring baseline for {model_name}; skipping this model:")
                    print(traceback.format_exc())
                continue

            benchmark_results[model_name] = {
                "baseline_speed": baseline["speed"],
                "baseline_memory": baseline["memory"],
                "pruning_levels": pruning_levels,
                "strategies": strategies,
                "results": {},
            }

        update_progress(progress_file, model_name, "0.0", "baseline")
        model_results = benchmark_results[model_name]["results"]

        for pruning_level in pruning_levels:
            for strategy in strategies:
                key = f"{strategy}_{pruning_level}"

                if is_benchmark_complete(progress_file, model_name, pruning_level, strategy):
                    # Pull the earlier run's metrics back in so the reports
                    # generated below still include this configuration.
                    saved = load_saved_result(model_name, key)
                    if saved is not None:
                        model_results[key] = saved
                    with status_output:
                        print(f"Skipping {model_name} with {strategy} pruning at {pruning_level} level (already completed)")
                    continue

                with status_output:
                    print(f"\n{'-'*60}\nRunning {model_name} with {strategy} pruning at {pruning_level} level\n{'-'*60}")

                # Record that this configuration is starting.
                update_progress(progress_file, model_name, pruning_level, strategy)

                # Broad catch on purpose: one failing configuration must not
                # end an overnight run.
                try:
                    result = run_benchmark(model_name, pruning_level, strategy, device, output_dir)

                    model_results[key] = {
                        "speed": result.get("tokens_per_second", 0),
                        "quality": result.get("quality_score", 0),
                        "memory": result.get("memory_usage", 0),
                        "flops": result.get("flops", 0),
                    }

                    update_progress(
                        progress_file, model_name, pruning_level, strategy,
                        completed=True, result=result
                    )

                    # Intermediate summary so partial results survive a disconnect.
                    create_benchmark_summary(output_dir, benchmark_results)

                except Exception as e:
                    with status_output:
                        print(f"Error running benchmark for {model_name} with {strategy} pruning at {pruning_level} level:")
                        print(f"  {str(e)}")
                        print(traceback.format_exc())

    # Comparative charts only make sense across several successfully measured models.
    if len(benchmark_results) > 1:
        generate_comparative_charts(output_dir, benchmark_results)

    summary_path = None
    if benchmark_results:
        summary_path = create_benchmark_summary(output_dir, benchmark_results)

    total_time = time.time() - start_time
    hours, remainder = divmod(total_time, 3600)
    minutes, seconds = divmod(remainder, 60)

    with status_output:
        print(f"\n{'='*80}")
        print(f"Benchmark suite completed!")
        print(f"Total runtime: {int(hours)}h {int(minutes)}m {int(seconds)}s")
        print(f"Results saved to: {output_dir}")
        if summary_path is not None:
            print(f"Summary report: {summary_path}")
        else:
            print("Summary report: not generated (no model completed its baseline)")
        print(f"{'='*80}")

    return benchmark_results

In [None]:
# Create and display a run button
run_button = widgets.Button(
    description='Run Benchmark',
    button_style='success',
    tooltip='Start the benchmark with the selected configuration'
)

def on_run_clicked(b):
    """Run the benchmark for the current widget selection.

    Args:
        b: The clicked button (supplied by ``Button.on_click``).

    The button is disabled for the duration of the run, to discourage queuing
    a second run while the first is still going. Errors are printed into
    ``status_output``: an exception escaping a widget callback is not shown in
    the cell output.
    """
    import traceback  # only needed to report failures

    b.disabled = True
    try:
        with status_output:
            print("Starting benchmark...")
        run_overnight_benchmark()
    except Exception:
        with status_output:
            print("Benchmark aborted with an error:")
            print(traceback.format_exc())
    finally:
        b.disabled = False

run_button.on_click(on_run_clicked)
display(run_button)

## 5. Analyze Results (Run After Benchmark Completes)

In [None]:
# Function to load and display results from a completed benchmark
def load_benchmark_results(results_dir):
    """Load a run's progress file and tabulate its completed configurations.

    Prints the run's start time and tested models. Then returns a DataFrame
    with one row per configuration that has recorded results: model,
    strategy, pruning level, speed, quality and completion time.

    Args:
        results_dir: Output directory of a benchmark run.

    Returns:
        A pandas DataFrame, or None when the progress file is missing or no
        configuration has recorded results.
    """
    progress_file = os.path.join(results_dir, "benchmark_progress.json")

    if not os.path.exists(progress_file):
        print(f"No benchmark results found at {results_dir}")
        return None

    with open(progress_file, "r") as f:
        progress = json.load(f)

    print(f"Benchmark started at: {progress['started_at']}")
    print(f"Models tested: {list(progress['models'].keys())}")

    rows = []
    for model_name, model_data in progress['models'].items():
        for config, result in model_data.get('results', {}).items():
            # Keys are "<strategy>_<level>". Split on the LAST underscore so
            # strategy names that themselves contain "_" still parse
            # (split('_') raised ValueError for those).
            strategy, sep, level = config.rpartition('_')
            if not sep:
                continue
            rows.append({
                'Model': model_name,
                'Strategy': strategy,
                # round() rather than int(): int(0.29 * 100) truncates to 28.
                'Pruning Level': f"{round(float(level) * 100)}%",
                'Speed (tokens/sec)': result.get('speed', 0),
                'Quality (%)': result.get('quality', 0),
                'Completed At': result.get('completed_at', ''),
            })

    if not rows:
        print("No completed benchmark results found")
        return None
    return pd.DataFrame(rows)

In [None]:
# Path to your benchmark results
# (defaults to the output directory selected when this cell was run)
results_path_input = widgets.Text(
    value=output_dir_text.value,
    placeholder='Enter path to benchmark results',
    description='Results Path:',
    disabled=False
)

load_button = widgets.Button(
    description='Load Results',
    button_style='info',
    tooltip='Load and display benchmark results'
)

results_output = widgets.Output()

def on_load_clicked(b):
    """Show the results table and the HTML summary for the entered results path.

    Args:
        b: The clicked button (supplied by ``Button.on_click``).
    """
    with results_output:
        results_output.clear_output()
        results_dir = results_path_input.value
        print(f"Loading results from {results_dir}...")
        results_df = load_benchmark_results(results_dir)
        if results_df is None:
            return
        display(results_df)

        summary_path = os.path.join(results_dir, "benchmark_summary.html")
        if os.path.exists(summary_path):
            print(f"\nSummary report available at: {summary_path}")
            # An IFrame whose src is a kernel-side file path cannot be fetched
            # by the browser in Colab, so render the report's HTML inline.
            # The chart <img> links are relative to the report file and will
            # not resolve here; open the file from Drive to see the charts.
            from IPython.display import HTML
            with open(summary_path, "r", encoding="utf-8") as f:
                display(HTML(f.read()))

load_button.on_click(on_load_clicked)

display(results_path_input, load_button, results_output)