# Pruning Strategies Comparison with Modular Experiment Framework (v0.1.0)

This notebook demonstrates how to run experiments that compare pruning strategies using the modular experiment framework. It is designed for Google Colab, but should also run in any Python environment where the dependencies below can be installed.

## Overview

In this notebook we'll:
1. Set up the environment and clone the repository
2. Compare different pruning strategies (random, magnitude, entropy) on the same model
3. Visualize and analyze the results
4. Compare recovery rates across strategies

The notebook uses the `PruningExperiment` and `PruningFineTuningExperiment` classes from the modular experiment framework.

## Setup

First, we'll install the required dependencies and clone the repository:

In [None]:
# Install required packages
!pip install -q jax jaxlib flax transformers matplotlib numpy pandas seaborn tqdm optax
!pip install -q 'datasets>=2.0.0' scikit-learn hmmlearn

In [None]:
# Clone the repository - using the refactor/modular-experiment branch
!git clone -b refactor/modular-experiment https://github.com/CambrianTech/sentinel-ai.git
# Don't cd into it yet

In [None]:
# Import Hugging Face datasets before changing directory
# This ensures we use the installed package rather than any same-named module inside the repository
from datasets import load_dataset
import datasets
print(f"Using datasets from: {datasets.__file__}")

# Now safely change to the repository directory
%cd sentinel-ai

# Import standard libraries
import os
import sys
import json
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import JAX/Flax
import jax
import jax.numpy as jnp
import optax

# Import Hugging Face libraries
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Add the current directory to path and import our modules
sys.path.append(".")
from utils.pruning import (
    Environment,
    PruningModule, 
    get_strategy,
    PruningExperiment
)

# Set up plotting
plt.style.use('ggplot')
sns.set_theme(style="whitegrid")

## Environment Detection

Next, we'll detect our environment capabilities to optimize experiments for the available hardware:

In [None]:
# Initialize environment and detect capabilities
env = Environment()
env.print_info()

# Check JAX capabilities
print(f"\nJAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

# Get suitable models based on hardware
suitable_models = env.get_suitable_models()
print(f"\nSuitable models for this environment: {', '.join(suitable_models)}")
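
As a rough illustration of what hardware-aware model selection can involve, the sketch below estimates each candidate's float32 memory footprint from its parameter count and keeps the models that fit a budget. This is not how `Environment.get_suitable_models()` is implemented. The function names, the approximate parameter counts, and the 4x overhead factor for gradients and optimizer state are all assumptions made for illustration.

```python
# Hypothetical sketch: NOT the Environment class's actual logic.
# Parameter counts are approximate figures, in millions of parameters.
APPROX_PARAMS_M = {
    "distilgpt2": 82,
    "gpt2": 124,
    "facebook/opt-125m": 125,
    "EleutherAI/pythia-160m": 160,
}


def estimated_memory_gb(params_millions, bytes_per_param=4, overhead=4.0):
    """Estimate memory in GB: float32 weights times an assumed training overhead."""
    weight_bytes = params_millions * 1e6 * bytes_per_param
    return weight_bytes * overhead / 1e9


def models_that_fit(budget_gb, catalog=APPROX_PARAMS_M):
    """Return the model names whose estimated footprint fits within budget_gb."""
    return [
        name
        for name, params in catalog.items()
        if estimated_memory_gb(params) <= budget_gb
    ]


# Example: which models fit an (assumed) 1.5 GB budget?
print(models_that_fit(1.5))
```

A real detector would also need to account for activation memory, sequence length, and batch size, which this estimate ignores.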

## Strategy Comparison Experiment

Now we'll set up an experiment to compare different pruning strategies on the same model. We'll prefer a small model (distilgpt2 when the environment supports it) and run one experiment per strategy.

In [None]:
# Create a directory for experiment results
results_dir = "strategy_comparison_results"
os.makedirs(results_dir, exist_ok=True)

# Select a model that's suitable for our environment
# We'll prefer smaller models for faster experiments
preferred_models = ["distilgpt2", "gpt2", "facebook/opt-125m", "EleutherAI/pythia-160m"]
model_to_use = None

for model in preferred_models:
    if model in suitable_models:
        model_to_use = model
        break
        
if not model_to_use and suitable_models:
    # If none of our preferred models are available, use the first suitable one
    model_to_use = suitable_models[0]
    
if not model_to_use:
    # Fallback to a small model if no detection was possible
    model_to_use = "distilgpt2"
    
print(f"Selected model for experiments: {model_to_use}")

In [None]:
# Initialize the experiment using our modular experiment framework
# This provides a consistent, reusable approach to running pruning experiments
experiment = PruningExperiment(
    results_dir=results_dir,
    use_improved_fine_tuner=True,  # Use improved fine-tuner with NaN prevention
    detect_environment=True,       # Automatically detect hardware capabilities 
    optimize_memory=True           # Optimize parameters based on model and hardware
)

# Configure the experiment
strategies = ["random", "magnitude", "entropy"]
pruning_level = 0.3  # 30% pruning
prompt = "The future of artificial intelligence is"
fine_tuning_epochs = 1  # Just one epoch for demonstration
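
To make the three strategy names concrete, here is a minimal sketch of how per-head scores are commonly computed for each one, and how a pruning level turns scores into a set of heads to remove. It is an assumption about the general technique, not the code behind `get_strategy`. The function names, the input layouts, and the choice of which end of the ranking to prune are all hypothetical.

```python
import math
import random


def random_scores(num_heads, seed=0):
    """Random strategy: every head gets an arbitrary score."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(num_heads)]


def magnitude_scores(head_weights):
    """Magnitude strategy: L2 norm of each head's flattened weights."""
    return [math.sqrt(sum(w * w for w in weights)) for weights in head_weights]


def entropy_scores(head_attention):
    """Entropy strategy: mean entropy of each head's attention rows.

    head_attention[h] is a list of rows, each a probability distribution.
    """
    scores = []
    for rows in head_attention:
        row_entropies = [-sum(p * math.log(p) for p in row if p > 0) for row in rows]
        scores.append(sum(row_entropies) / len(row_entropies))
    return scores


def heads_to_prune(scores, pruning_level, prune_highest=False):
    """Select round(pruning_level * num_heads) heads from one end of the ranking."""
    k = int(round(len(scores) * pruning_level))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=prune_highest)
    return sorted(order[:k])


# Toy weights for four heads (illustrative values only)
weights = [[3.0, 4.0], [0.1, 0.2], [1.0, 1.0], [0.0, 0.5]]
print(heads_to_prune(magnitude_scores(weights), pruning_level=0.5))
```

In this sketch, magnitude pruning removes the lowest-norm heads. For entropy, one common choice is to remove the heads with the most diffuse attention (`prune_highest=True`), though the framework may use a different convention.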

## Run Experiments

Now we'll run experiments for each strategy:

In [None]:
# Run experiments for each strategy
results = []

for strategy in strategies:
    print(f"\n\nRunning experiment with {strategy} strategy...")
    
    result = experiment.run_single_experiment(
        model=model_to_use,
        strategy=strategy,
        pruning_level=pruning_level,
        prompt=prompt,
        fine_tuning_epochs=fine_tuning_epochs,
        save_results=True
    )
    
    results.append(result)
    
    # Update the visualization after each experiment
    experiment.plot_results()

## Comparative Analysis

Now let's analyze the results to compare the different strategies:

In [None]:
# Create a DataFrame for comparison
comparison_data = []

for result in results:
    strategy = result["strategy"]
    baseline_perplexity = result["stages"]["baseline"]["perplexity"]
    pruned_perplexity = result["stages"]["pruned"]["perplexity"]
    pruning_effect = pruned_perplexity - baseline_perplexity
    
    if "fine_tuned" in result["stages"]:
        fine_tuned_perplexity = result["stages"]["fine_tuned"]["perplexity"]
        fine_tuning_effect = fine_tuned_perplexity - pruned_perplexity
        net_effect = fine_tuned_perplexity - baseline_perplexity
        
        # Get recovery or improvement metrics
        recovery_percentage = result["stages"]["fine_tuned"].get("recovery_percentage", None)
        improvement_percentage = result["stages"]["fine_tuned"].get("improvement_percentage", None)
        
        comparison_data.append({
            "Strategy": strategy,
            "Baseline Perplexity": baseline_perplexity,
            "Pruned Perplexity": pruned_perplexity,
            "Fine-tuned Perplexity": fine_tuned_perplexity,
            "Pruning Effect": pruning_effect,
            "Fine-tuning Effect": fine_tuning_effect,
            "Net Effect": net_effect,
            "Recovery %": recovery_percentage,
            "Improvement %": improvement_percentage
        })
    else:
        comparison_data.append({
            "Strategy": strategy,
            "Baseline Perplexity": baseline_perplexity,
            "Pruned Perplexity": pruned_perplexity,
            "Pruning Effect": pruning_effect
        })

comparison_df = pd.DataFrame(comparison_data)
comparison_df
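
The derived columns are simple differences between stages, so negative values mean perplexity dropped. The sketch below works through the arithmetic, including one plausible definition of the recovery and improvement percentages. The framework computes `recovery_percentage` and `improvement_percentage` itself and its formulas are not shown in this notebook, so the definitions here are assumptions, and the input numbers are placeholders rather than measured results.

```python
def summarize_effects(baseline, pruned, fine_tuned):
    """Compute stage-to-stage perplexity changes (lower perplexity is better)."""
    pruning_effect = pruned - baseline        # > 0 means pruning hurt
    fine_tuning_effect = fine_tuned - pruned  # < 0 means fine-tuning helped
    net_effect = fine_tuned - baseline        # < 0 means better than baseline
    summary = {
        "pruning_effect": pruning_effect,
        "fine_tuning_effect": fine_tuning_effect,
        "net_effect": net_effect,
    }
    if pruning_effect > 0:
        # Assumed definition: share of the pruning damage undone by fine-tuning
        summary["recovery_percentage"] = 100.0 * (pruned - fine_tuned) / pruning_effect
    else:
        # Assumed definition: relative gain over the baseline perplexity
        summary["improvement_percentage"] = 100.0 * (baseline - fine_tuned) / baseline
    return summary


# Placeholder perplexities, chosen only to make the arithmetic easy to follow
print(summarize_effects(baseline=30.0, pruned=50.0, fine_tuned=35.0))
```

Under these definitions, a recovery above 100% would mean the fine-tuned model ended up better than the baseline.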

## Visualizing Strategy Comparison

Let's create a custom visualization to compare the strategies more directly:

In [None]:
# Create a bar chart comparing strategies
plt.figure(figsize=(12, 6))

# Plot baseline, pruned, and fine-tuned perplexity for each strategy
x = np.arange(len(strategies))
width = 0.25

if "Fine-tuned Perplexity" in comparison_df.columns:
    # If we have fine-tuning results
    plt.bar(x - width, comparison_df["Baseline Perplexity"], width, label="Baseline")
    plt.bar(x, comparison_df["Pruned Perplexity"], width, label="Pruned")
    plt.bar(x + width, comparison_df["Fine-tuned Perplexity"], width, label="Fine-tuned")
else:
    # If we only have pruning results
    plt.bar(x - width/2, comparison_df["Baseline Perplexity"], width, label="Baseline")
    plt.bar(x + width/2, comparison_df["Pruned Perplexity"], width, label="Pruned")

plt.xlabel("Pruning Strategy")
plt.ylabel("Perplexity")
plt.title(f"Comparison of Pruning Strategies ({model_to_use}, {pruning_level*100:.0f}% pruning)")
plt.xticks(x, strategies)
plt.legend()
plt.grid(True, alpha=0.3)

# Add value labels
for i, strategy in enumerate(strategies):
    row = comparison_df[comparison_df["Strategy"] == strategy].iloc[0]
    
    # Baseline
    plt.text(i - width, row["Baseline Perplexity"] + 5, 
             f"{row['Baseline Perplexity']:.1f}", 
             ha="center", va="bottom", fontsize=9)
    
    # Pruned
    plt.text(i, row["Pruned Perplexity"] + 5, 
             f"{row['Pruned Perplexity']:.1f}", 
             ha="center", va="bottom", fontsize=9)
    
    # Fine-tuned (if available)
    if "Fine-tuned Perplexity" in comparison_df.columns:
        plt.text(i + width, row["Fine-tuned Perplexity"] + 5, 
                 f"{row['Fine-tuned Perplexity']:.1f}", 
                 ha="center", va="bottom", fontsize=9)

plt.tight_layout()
plt.show()

## Recovery/Improvement Comparison

If we have fine-tuning results, let's compare the recovery or improvement percentages:

In [None]:
if "Recovery %" in comparison_df.columns or "Improvement %" in comparison_df.columns:
    plt.figure(figsize=(10, 5))
    
    # Create a combined metric for display (recovery is positive, improvement is negative)
    recovery_values = []
    for i, strategy in enumerate(strategies):
        row = comparison_df[comparison_df["Strategy"] == strategy].iloc[0]
        
        if pd.notna(row["Recovery %"]):
            # This is a recovery scenario (pruning hurt, fine-tuning helped)
            recovery_values.append(row["Recovery %"])
        elif pd.notna(row["Improvement %"]):
            # This is an improvement scenario (pruning helped, fine-tuning helped more)
            recovery_values.append(-row["Improvement %"])
        else:
            recovery_values.append(0)
    
    # Create bars with different colors based on whether it's recovery or improvement
    colors = ["red" if val >= 0 else "green" for val in recovery_values]
    plt.bar(strategies, recovery_values, color=colors)
    
    # Add horizontal line at 0
    plt.axhline(y=0, color="black", linestyle="--", alpha=0.5)
    
    # Add labels
    plt.text(strategies[0], 50, "Recovery %", color="red", ha="center", va="center", fontsize=10)
    plt.text(strategies[0], -50, "Improvement %", color="green", ha="center", va="center", fontsize=10)
    
    # Add value labels to bars
    for i, val in enumerate(recovery_values):
        if val >= 0:
            plt.text(i, val + 5, f"{val:.1f}%", ha="center", va="bottom", color="red", fontsize=9)
        else:
            plt.text(i, val - 5, f"{-val:.1f}%", ha="center", va="top", color="green", fontsize=9)
    
    plt.xlabel("Pruning Strategy")
    plt.ylabel("Percentage")
    plt.title(f"Recovery or Improvement by Strategy ({model_to_use}, {pruning_level*100:.0f}% pruning)")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

## Generated Text Comparison

Let's compare the quality of generated text from the baseline, pruned, and fine-tuned models:

In [None]:
# Extract generated text examples
for result in results:
    strategy = result["strategy"]
    print(f"\n\n{'='*80}\nStrategy: {strategy}\n{'='*80}")
    
    print("\nPrompt:")
    print(result["prompt"])
    
    print("\nBaseline generated:")
    print(result["stages"]["baseline"]["generated_text"])
    
    print("\nPruned generated:")
    print(result["stages"]["pruned"]["generated_text"])
    
    if "fine_tuned" in result["stages"]:
        print("\nFine-tuned generated:")
        print(result["stages"]["fine_tuned"]["generated_text"])
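
Reading the samples side by side is subjective, so a crude numeric proxy can help. The sketch below computes a distinct-n ratio (unique n-grams divided by total n-grams), which tends to drop when a model produces repetitive output. It is an optional add-on rather than part of the framework; the `distinct_n` helper and its whitespace tokenization are assumptions.

```python
def distinct_n(text, n=2):
    """Ratio of unique n-grams to total n-grams, using whitespace tokens.

    Returns 0.0 when the text is too short to contain any n-gram.
    """
    tokens = text.split()
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    return len(set(ngrams)) / len(ngrams)


# A repetitive string scores lower than a varied one
print(distinct_n("a b a b a b", n=2))
print(distinct_n("the cat sat on the mat", n=2))
```

It could be applied to each `generated_text` entry in `results` to compare stages, bearing in mind that a single short sample gives a noisy estimate.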

## Conclusion

Based on the experiments, we can draw the following conclusions about the different pruning strategies:

1. **Entropy-based Pruning**: [Fill in with observations from your experiments]
2. **Magnitude-based Pruning**: [Fill in with observations from your experiments]
3. **Random Pruning**: [Fill in with observations from your experiments]

The most effective strategy appears to be [fill in based on results], which achieved [summarize key results].
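
One way to fill in the summary above is to rank the strategies by their final perplexity. The sketch below assumes the result layout used earlier in this notebook (`result["strategy"]` and `result["stages"][...]["perplexity"]`). The helper names are hypothetical, and the toy values are placeholders rather than measured results.

```python
def final_perplexity(result):
    """Perplexity of the last available stage: fine-tuned if present, else pruned."""
    stages = result["stages"]
    stage = "fine_tuned" if "fine_tuned" in stages else "pruned"
    return stages[stage]["perplexity"]


def rank_strategies(results):
    """Return (strategy, final perplexity) pairs, best (lowest) first."""
    pairs = [(r["strategy"], final_perplexity(r)) for r in results]
    return sorted(pairs, key=lambda pair: pair[1])


# Placeholder data in the same shape as `results`; not real measurements
toy_results = [
    {"strategy": "random",
     "stages": {"pruned": {"perplexity": 60.0}, "fine_tuned": {"perplexity": 40.0}}},
    {"strategy": "magnitude",
     "stages": {"pruned": {"perplexity": 45.0}, "fine_tuned": {"perplexity": 33.0}}},
    {"strategy": "entropy",
     "stages": {"pruned": {"perplexity": 50.0}}},
]
print(rank_strategies(toy_results))
```

Calling `rank_strategies(results)` on the real results gives a starting point, but a single run with one prompt and one epoch is a thin basis for declaring a winner.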

## Save Results

Finally, let's save our comparison results to a CSV file for further analysis:

In [None]:
# Save comparison results
comparison_csv_path = os.path.join(results_dir, "strategy_comparison.csv")
comparison_df.to_csv(comparison_csv_path, index=False)
print(f"Comparison results saved to {comparison_csv_path}")

# If we're in Colab, download the results
try:
    from google.colab import files
    files.download(comparison_csv_path)
    print("\nDownload initiated. Check your browser downloads.")
except ImportError:
    print("\nNot running in Google Colab. Results saved locally.")