# Pruning Benchmark Notebook (Colab-Optimized)

This notebook is designed to run pruning benchmarks on Google Colab overnight, leveraging JAX/Flax for stability and performance.

## Features
- Automatically detects and utilizes TPU/GPU when available
- Uses JAX/Flax for stable operation (works on M1/M2 Macs as well)
- Progressive visualization during benchmark runs
- Comprehensive analysis after completion
- Supports multiple models and pruning strategies

In [None]:
# Install dependencies (quietly, via IPython's shell escape `!`).
# NOTE(review): package versions are unpinned, so results may differ between
# Colab images over time — consider pinning versions for reproducible benchmarks.
!pip install -q jax jaxlib flax transformers matplotlib numpy tqdm pandas seaborn

In [None]:
# Clone the repository
!git clone https://github.com/CambrianTech/sentinel-ai.git
%cd sentinel-ai

In [None]:
# Import the pruning library. This is the project-local package from the
# cloned repository, so it resolves only when the working directory is the
# repository root (the clone cell cds into it).
from utils.pruning import (
    Environment,
    ResultsManager,
    PruningBenchmark
)

# Set up plotting
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('ggplot')
# NOTE(review): sns.set_theme() rewrites many of the rcParams the 'ggplot'
# style just set (axes style, font context, color palette), so the effective
# look is largely seaborn "whitegrid". Confirm whether the ggplot line is still wanted.
sns.set_theme(style="whitegrid")

In [None]:
# Initialize environment and detect capabilities, then print what was found.
# `env` is reused by later cells to choose which models to benchmark
# (via env.get_suitable_models()).
env = Environment()
env.print_info()

In [None]:
# Initialize the results manager for the "pruning_results" location
# (presumably a directory on disk — TODO confirm). Load whatever results
# already exist there and print a summary of them.
results_manager = ResultsManager("pruning_results")
results_manager.load_results()
results_manager.print_summary()

In [None]:
# Initialize the benchmark runner with the shared results manager.
# Presumably each benchmark run is recorded through it, which is what the
# plotting/summary cells later read from — confirm in PruningBenchmark.
benchmark = PruningBenchmark(results_manager)

## Run a Single Benchmark

Let's run a single benchmark to test our setup.

In [None]:
# Run a single benchmark as a quick smoke test of the setup:
# random pruning of 10% on one model with a short prompt.
suitable_models = env.get_suitable_models()
if not suitable_models:
    # Fail with an actionable message instead of a bare IndexError from [0].
    raise RuntimeError(
        "No suitable models were found for this environment; "
        "cannot run the single benchmark."
    )

# Take the first suitable model; the intent is to use the smallest one.
# NOTE(review): that relies on get_suitable_models() being ordered
# smallest-first — confirm.
model_name = suitable_models[0]
result = benchmark.run_single_benchmark(
    model_name=model_name,
    strategy_name="random",
    pruning_level=0.1,
    prompt="Artificial intelligence will transform"
)

In [None]:
# Visualize the results so far.
# NOTE(review): presumably this plots everything held by the results manager,
# including results loaded from earlier sessions, not only the run above — confirm.
results_manager.plot_results(figsize=(12, 8))

## Progressive Pruning Test

Test how much we can prune before the model breaks down completely.

In [None]:
# Configuration for the progressive pruning test: sweep increasing pruning
# levels to find where the models break down.
MODELS = env.get_suitable_models()[:2]  # at most the first 2 suitable models
STRATEGIES = ["random", "magnitude"]
PRUNING_LEVELS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
PROMPT = "Artificial intelligence will revolutionize"
MAX_RUNTIME = 3600  # time budget in seconds (1 hour)

# Plain string: the header has nothing to interpolate, so no f-prefix (F541).
print("Running progressive pruning test with:")
print(f"  Models: {MODELS}")
print(f"  Strategies: {STRATEGIES}")
print(f"  Pruning levels: {PRUNING_LEVELS}")
print(f"  Prompt: '{PROMPT}'")
print(f"  Maximum runtime: {MAX_RUNTIME/3600:.1f} hours")

In [None]:
# Run the benchmarks for the progressive pruning configuration defined above.
# max_runtime is the overall time budget in seconds (MAX_RUNTIME).
# NOTE(review): presumably every model × strategy × level combination is tried
# until the budget runs out — confirm in run_multiple_benchmarks.
results = benchmark.run_multiple_benchmarks(
    models=MODELS,
    strategies=STRATEGIES,
    pruning_levels=PRUNING_LEVELS,
    prompt=PROMPT,
    max_runtime=MAX_RUNTIME
)

## Multi-Model Comparison

Compare how different model architectures respond to pruning.

In [None]:
# Get all available models
ALL_MODELS = env.get_suitable_models()
print(f"Available models: {ALL_MODELS}")

# Configuration for the multi-model comparison: one strategy, a few pruning
# levels, across every suitable model.
STRATEGY = "magnitude"  # Most stable strategy
COMPARISON_LEVELS = [0.1, 0.3, 0.5, 0.7]
COMPARISON_PROMPT = "Artificial intelligence will transform society by"
COMPARISON_RUNTIME = 7200  # time budget in seconds (2 hours)

# Plain string: the header has nothing to interpolate, so no f-prefix (F541).
print("\nRunning multi-model comparison with:")
print(f"  Models: {ALL_MODELS}")
print(f"  Strategy: {STRATEGY}")
print(f"  Pruning levels: {COMPARISON_LEVELS}")
print(f"  Prompt: '{COMPARISON_PROMPT}'")
print(f"  Maximum runtime: {COMPARISON_RUNTIME/3600:.1f} hours")

In [None]:
# Run the multi-model comparison: the single STRATEGY applied to ALL_MODELS at
# each of COMPARISON_LEVELS. max_runtime is the time budget in seconds.
comparison_results = benchmark.run_multiple_benchmarks(
    models=ALL_MODELS,
    strategies=[STRATEGY],
    pruning_levels=COMPARISON_LEVELS,
    prompt=COMPARISON_PROMPT,
    max_runtime=COMPARISON_RUNTIME
)

## Overnight Benchmark

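Run a comprehensive overnight benchmark covering every model × strategy × pruning-level combination. The run cell below is disabled by default; enable it to start the run.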

In [None]:
# Configuration for the overnight run: every suitable model, three strategies
# and a fine-grained sweep of pruning levels.
OVERNIGHT_MODELS = env.get_suitable_models()  # Use all available models
OVERNIGHT_STRATEGIES = ["random", "magnitude", "entropy"]
OVERNIGHT_LEVELS = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
OVERNIGHT_PROMPT = "Artificial intelligence will revolutionize industries by"
OVERNIGHT_RUNTIME = 8 * 3600  # time budget in seconds (8 hours)

# Size of the full grid (models × strategies × levels); the time budget may
# stop the run before all of these finish.
TOTAL_BENCHMARKS = len(OVERNIGHT_MODELS) * len(OVERNIGHT_STRATEGIES) * len(OVERNIGHT_LEVELS)

# Plain string: the header has nothing to interpolate, so no f-prefix (F541).
print("Overnight benchmark configuration:")
print(f"  Models: {OVERNIGHT_MODELS}")
print(f"  Strategies: {OVERNIGHT_STRATEGIES}")
print(f"  Pruning levels: {OVERNIGHT_LEVELS}")
print(f"  Prompt: '{OVERNIGHT_PROMPT}'")
print(f"  Maximum runtime: {OVERNIGHT_RUNTIME/3600:.1f} hours")
print(f"  Total benchmarks: {TOTAL_BENCHMARKS}")

In [None]:
# Overnight run. Disabled by default because it can occupy the runtime for up
# to OVERNIGHT_RUNTIME seconds (8 hours). Set RUN_OVERNIGHT = True to start it.
# An explicit toggle replaces commented-out code, so the call stays
# syntax-checked and cannot drift out of sync with the configuration cell.
RUN_OVERNIGHT = False

if RUN_OVERNIGHT:
    overnight_results = benchmark.run_multiple_benchmarks(
        models=OVERNIGHT_MODELS,
        strategies=OVERNIGHT_STRATEGIES,
        pruning_levels=OVERNIGHT_LEVELS,
        prompt=OVERNIGHT_PROMPT,
        max_runtime=OVERNIGHT_RUNTIME,
    )

## Comprehensive Analysis

Analyze all collected benchmark results.

In [None]:
# Load all results and print a summary of them.
# NOTE(review): reloading presumably picks up results the benchmark cells
# above persisted through the results manager — confirm.
results_manager.load_results()
results_manager.print_summary()

In [None]:
# Basic visualization of the loaded results.
# The return value is kept in `fig` (presumably a matplotlib Figure — confirm)
# so it can be saved or adjusted afterwards.
fig = results_manager.plot_results(figsize=(14, 8))

In [None]:
# Advanced analysis (optional). plot_advanced_analysis may not exist on every
# ResultsManager, so call it only when present. When it is absent, say so
# rather than leaving the cell silently empty.
plot_advanced = getattr(results_manager, "plot_advanced_analysis", None)
if callable(plot_advanced):
    plot_advanced(figsize=(14, 10))
else:
    print("ResultsManager has no plot_advanced_analysis(); skipping advanced analysis.")

## Custom Model Comparison Visualization

In [None]:
# Model comparison plot: perplexity change versus pruning level, with one line
# per (model, strategy) combination present in the collected results.
comparison_df = results_manager.results_df
if comparison_df is not None and not comparison_df.empty:
    plt.figure(figsize=(14, 8))

    # A single groupby visits each (model, strategy) pair that actually has rows.
    # This avoids re-masking the whole frame for every model × strategy product,
    # including pairs with no data. Rows whose model or strategy is missing are
    # skipped (groupby's default dropna=True). sort=False keeps the order in
    # which combinations first appear in the data.
    for (model, strategy), group in comparison_df.groupby(["model", "strategy"], sort=False):
        # Sort so each line is drawn left-to-right by pruning level.
        group = group.sort_values("pruning_level")
        plt.plot(
            group["pruning_level"],
            group["perplexity_change"],
            marker="o",
            label=f"{model} - {strategy}",
        )

    plt.xlabel("Pruning Level")
    plt.ylabel("Perplexity Change")
    plt.title("Effect of Pruning on Different Models")
    # Put the legend outside the axes so many model/strategy entries do not
    # cover the curves; tight_layout below makes room for it.
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()