# Pruning and Fine-Tuning Benchmark for Google Colab (v0.0.28.0)

This is the Python script version of our notebook for Google Colab. 

- Version 0.0.28.0 (April 2025) - Enhance profile_full_model.py with improved organization and usability
- Version 0.0.27.0 (April 2025) - Use PruningFineTuningExperiment from utils.pruning
- Version 0.0.26 (April 2025) - `refactor/modular-experiment` branch
   - Added Colab utilities for automatic environment optimization.
   - Fixed import paths for Colab compatibility.

Instructions:
1. Upload to a new Colab notebook using File > Upload notebook > Upload
2. The notebook will automatically configure the environment with:
   - GPU acceleration selection
   - Memory-optimized parameters based on available resources
   - Adaptive model configuration based on memory constraints
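The real heuristic lives in `utils.colab.optimize_for_colab` and is not shown in this notebook. The sketch below only illustrates what "memory-optimized parameters" means. It maps available GPU memory and model size to the four keys the setup cell reads: `batch_size`, `sequence_length`, `stability_level` and `use_fp16`. The function name `pick_params`, the size factors and all thresholds are hypothetical.

```python
def pick_params(gpu_memory_gb: float, model_size: str = "medium",
                prefer_stability: bool = True) -> dict:
    """Hypothetical heuristic returning the keys the notebook reads from `params`.

    The size factors and thresholds are illustrative assumptions, NOT the logic
    of utils.colab.optimize_for_colab.
    """
    # Assumed relative memory cost of each model size
    size_factor = {"small": 1, "medium": 2, "large": 4}[model_size]
    stability_level = 2 if prefer_stability else 1

    if gpu_memory_gb <= 0:
        # No GPU: fall back to the smallest configuration
        return {"batch_size": 1, "sequence_length": 64,
                "stability_level": stability_level, "use_fp16": False}

    # GPU memory available per unit of model size
    budget_gb = gpu_memory_gb / size_factor
    return {
        # Roughly one sample per 2 GB of budget, clamped to [1, 16]
        "batch_size": max(1, min(16, int(budget_gb // 2))),
        # Longer sequences only when there is headroom
        "sequence_length": 256 if budget_gb >= 6 else 128,
        "stability_level": stability_level,
        # Mixed precision saves memory but trades away some numerical stability
        "use_fp16": not prefer_stability,
    }


# Example: a T4 exposes roughly 15 GB of GPU memory
params = pick_params(15, "medium")
```

The design shrinks the batch before the sequence length. Batch size usually has the smaller effect on the quality of a perplexity comparison.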

## Overview

1. **Baseline Evaluation**: Establish the initial model performance
2. **Pruning Phase**: Apply different pruning strategies and evaluate post-pruning performance
3. **Fine-Tuning Phase**: Fine-tune pruned models to recover or improve performance
4. **Analysis**: Compare performance across pruning levels and fine-tuning epochs

The experiment runs until it is interrupted or its `max_runtime` budget is exhausted. As it goes, it keeps fine-tuning the pruned models and updating the visualizations.
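In outline, each pass of the experiment does the following for every strategy and pruning level:

- evaluates the baseline;
- prunes;
- re-evaluates;
- fine-tunes;
- evaluates again.

The sketch below shows one such pass under a wall-clock budget like `max_runtime`. It is not the `PruningFineTuningExperiment` implementation. `run_grid` and the `evaluate`, `prune` and `fine_tune` callables are hypothetical stand-ins. The records simply follow the JSON layout that the monitoring cell later reads: `model`, `strategy`, `pruning_level`, `timestamp` and `stages`.

```python
import time
from datetime import datetime
from typing import Any, Callable, Dict, List


def run_grid(base_state: Any,
             strategies: List[str],
             pruning_levels: List[float],
             evaluate: Callable[[Any], float],
             prune: Callable[[Any, str, float], Any],
             fine_tune: Callable[[Any], Any],
             model_name: str = "gpt2",
             max_runtime: float = 6 * 3600,
             clock: Callable[[], float] = time.monotonic) -> List[Dict]:
    """One hypothetical pass of baseline -> prune -> fine-tune over a strategy grid."""
    start = clock()
    baseline_ppl = evaluate(base_state)  # the baseline is shared by every run
    results = []
    for strategy in strategies:
        for level in pruning_levels:
            # Stop cleanly once the wall-clock budget is used up
            if clock() - start > max_runtime:
                return results
            pruned = prune(base_state, strategy, level)
            tuned = fine_tune(pruned)
            results.append({
                "model": model_name,
                "strategy": strategy,
                "pruning_level": level,
                "timestamp": datetime.now().strftime("%Y%m%d-%H%M%S"),
                "stages": {
                    "baseline": {"perplexity": baseline_ppl},
                    "pruned": {"perplexity": evaluate(pruned)},
                    "fine_tuned": {"perplexity": evaluate(tuned)},
                },
            })
    return results


# Toy usage: the "model state" is just a perplexity value
toy = run_grid(10.0, ["random", "magnitude"], [0.1, 0.5],
               evaluate=lambda s: s,
               prune=lambda s, strategy, level: s * (1 + level),
               fine_tune=lambda s: s - 1)
```

The budget is checked before each run rather than during it, so a pass can overshoot `max_runtime` by at most one run.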

In [None]:
# Install required packages (GPU and memory setup happens in a later cell)

# Install required packages and make sure HuggingFace datasets is properly installed
!pip install -q jax jaxlib flax transformers matplotlib numpy pandas seaborn tqdm optax psutil
!pip install -q 'datasets>=2.0.0' multiprocess

In [None]:
# Clone the refactor/modular-experiment branch of the repository
!git clone -b refactor/modular-experiment https://github.com/CambrianTech/sentinel-ai.git
# Don't cd into it yet; the next cell changes into the directory

In [None]:
# Change to the repository directory
%cd sentinel-ai

# Import Hugging Face datasets now (before any potential conflicts) and print its path to confirm which package was loaded
from datasets import load_dataset
import datasets
print(f"Using datasets from: {datasets.__file__}")

# Set up Colab environment with GPU and memory optimization
try:
    from utils.colab import setup_colab_environment, optimize_for_colab
    
    # Configure Colab environment with GPU preference
    env_info = setup_colab_environment(prefer_gpu=True)
    
    # Get optimized parameters for the model size we'll use (medium is a good default)
    params = optimize_for_colab(model_size="medium", prefer_stability=True)
    
    # Extract parameters for use in experiment
    optimized_batch_size = params["batch_size"]
    sequence_length = params["sequence_length"]
    stability_level = params["stability_level"]
    use_fp16 = params["use_fp16"]
    
    print(f"\n✅ Using optimized parameters for Colab:")
    print(f"  - Batch size: {optimized_batch_size}")
    print(f"  - Sequence length: {sequence_length}")
    print(f"  - Stability level: {stability_level}")
    print(f"  - Mixed precision: {use_fp16}")
    
except ImportError:
    print("⚠️ Colab utilities not available, using default parameters")
    print("This is expected if the cloned branch does not include utils/colab yet")
    
    # Try to auto-select GPU via Google Colab runtime API
    try:
        from google.colab import runtime
        runtime.change_runtime(runtime_type="GPU")
        print("✅ GPU acceleration enabled!")
    except Exception:
        print("⚠️ Could not auto-select GPU. Please set it manually (Runtime > Change runtime type).")
    
    # Check for GPU availability. A failing `!nvidia-smi` prints an error but raises
    # no Python exception, so test for the binary explicitly
    import shutil
    if shutil.which("nvidia-smi"):
        !nvidia-smi
    else:
        print("❌ No GPU detected. Performance will be limited.")
    
    # Default parameters
    optimized_batch_size = 4
    sequence_length = 128
    stability_level = 1
    use_fp16 = False

In [None]:
# Import rest of the libraries
import os
import sys
import json
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import JAX/Flax
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

# Import Hugging Face libraries
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Add the current directory to path and import our modules
sys.path.append(".")
from utils.pruning import (
    Environment,
    ResultsManager,
    PruningModule, 
    get_strategy,
    FineTuner,
    ImprovedFineTuner,
    PruningFineTuningExperiment
)
from utils.pruning.stability import patch_fine_tuner, optimize_fine_tuner

In [None]:
# Initialize environment and detect capabilities
env = Environment()
env.print_info()

# Check JAX capabilities
print(f"\nJAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

## Run Experiment with Modular Framework

Now we'll use the modular experiment framework to run our pruning and fine-tuning experiments.

Note: We use the `PruningFineTuningExperiment` class imported from `utils.pruning` to ensure 
consistent behavior between this notebook and local testing.
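Every stage of the experiment is scored by perplexity. Perplexity is the exponential of the mean negative log-likelihood of the actual next tokens.

The NumPy sketch below computes it from raw logits. It assumes `labels` are already shifted so that row `t` of `logits` predicts `labels[t]`. It is not the evaluation code used by `utils.pruning`.

```python
import numpy as np


def perplexity_from_logits(logits, labels):
    """exp(mean negative log-likelihood) of `labels` under `logits`.

    logits: (seq_len, vocab_size) scores; labels: (seq_len,) next-token ids.
    Assumes the causal shift (row t predicts labels[t]) was already applied.
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to each actual next token
    token_log_probs = log_probs[np.arange(labels.shape[0]), labels]
    return float(np.exp(-token_log_probs.mean()))


# A model that is uniform over a 10-token vocabulary has perplexity 10
# (up to floating-point rounding)
print(perplexity_from_logits(np.zeros((5, 10)), [1, 2, 3, 4, 5]))
```

Lower is better. A pruned model whose perplexity rises from, say, 20 to 30 is behaving like a model that is uniformly unsure among 30 rather than 20 tokens.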

In [None]:
# Initialize experiment with memory optimizations
experiment = PruningFineTuningExperiment(
    results_dir="pruning_finetuning_results",
    use_improved_fine_tuner=True,      # Use the improved fine-tuner with stability enhancements
    detect_environment=True,           # Automatically detect Colab environment
    optimize_memory=True,              # Optimize for T4 GPU memory constraints 
    batch_size=optimized_batch_size,   # Use the optimized batch size
    sequence_length=sequence_length,   # Use the optimized sequence length (matches utils implementation)
    stability_level=stability_level    # Use optimized stability level
)

# Memory optimization information
print(f"\nExperiment configured with optimized parameters:")
print(f"- Batch size: {optimized_batch_size}")
print(f"- Sequence length: {sequence_length}")
print(f"- Stability level: {stability_level}")
print(f"- Mixed precision: {use_fp16} (not passed to the experiment above)")

# Memory optimizations include:
# - Reduced batch sizes for larger models
# - Shorter sequence lengths for memory efficiency
# - Adaptive sample counts based on model size
# - Conservative synthetic data generation
# These optimizations help prevent OOM errors with larger models

In [None]:
# Configuration
STRATEGIES = ["random", "magnitude", "entropy"]
PRUNING_LEVELS = [0.1, 0.3, 0.5]
PROMPT = "Artificial intelligence will transform society by"
FINE_TUNING_EPOCHS = 2  # Small number for quick iterations
MAX_RUNTIME = 6 * 3600  # 6 hours

# Start the experiment
results = experiment.run_experiment(
    strategies=STRATEGIES,
    pruning_levels=PRUNING_LEVELS,
    prompt=PROMPT,
    fine_tuning_epochs=FINE_TUNING_EPOCHS,
    max_runtime=MAX_RUNTIME
)

## Longer Overnight Run

For an extended overnight run, uncomment the `run_experiment` call at the end of the cell below and run it:

In [None]:
# Overnight Configuration - More conservative settings
OVERNIGHT_STRATEGIES = ["random", "magnitude", "entropy"]
OVERNIGHT_PRUNING_LEVELS = [0.1, 0.2, 0.3, 0.4, 0.5]  # Skip higher levels to avoid memory issues
OVERNIGHT_PROMPT = "Artificial intelligence will revolutionize industries by"
OVERNIGHT_FINE_TUNING_EPOCHS = 3  # One more epoch than the quick run (FINE_TUNING_EPOCHS = 2)
OVERNIGHT_MAX_RUNTIME = 20 * 3600  # 20 hours

# Initialize experiment for overnight run
overnight_experiment = PruningFineTuningExperiment(
    results_dir="overnight_results",
    use_improved_fine_tuner=True,
    detect_environment=True,
    optimize_memory=True,
    batch_size=1,                      # Smallest batch size to minimize memory use
    sequence_length=128,               # Fixed length (the same as the fallback default)
    stability_level=3                  # Maximum stability for overnight runs
)

# Uncomment to run overnight experiment
# overnight_results = overnight_experiment.run_experiment(
#     strategies=OVERNIGHT_STRATEGIES,
#     pruning_levels=OVERNIGHT_PRUNING_LEVELS,
#     prompt=OVERNIGHT_PROMPT,
#     fine_tuning_epochs=OVERNIGHT_FINE_TUNING_EPOCHS,
#     max_runtime=OVERNIGHT_MAX_RUNTIME
# )
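Before starting the overnight run, a quick budget check helps. It shows how much time each configuration gets if the grid is run once within the runtime limit. The helper below is a hypothetical convenience. It copies the quick-run and overnight settings defined above so that the snippet runs on its own.

```python
def per_run_budget(strategies, pruning_levels, max_runtime_s):
    """Number of (strategy, level) runs in one pass and seconds available for each."""
    runs = len(strategies) * len(pruning_levels)
    return runs, max_runtime_s / runs


# Settings copied from the cells above
quick = per_run_budget(["random", "magnitude", "entropy"], [0.1, 0.3, 0.5], 6 * 3600)
overnight = per_run_budget(["random", "magnitude", "entropy"],
                           [0.1, 0.2, 0.3, 0.4, 0.5], 20 * 3600)

for name, (runs, seconds) in [("quick", quick), ("overnight", overnight)]:
    print(f"{name}: {runs} runs, {seconds / 60:.0f} min per run for a single pass")
```

For the settings above, this gives 9 runs at 40 minutes each for the quick run and 15 runs at 80 minutes each overnight. If a single baseline, prune and fine-tune cycle takes longer than that on your GPU, one full pass will not finish within the limit.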

## Real-time Experiment Monitoring

The cell below reads the saved result files from disk. It can therefore be rerun at any time, for example after interrupting a long run, to visualize the current state of the experiments.

In [None]:
# This cell can be run at any time to visualize current experiment progress
import os
import glob
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

def visualize_ongoing_experiments(results_dir="pruning_finetuning_results", figsize=(10, 8)):
    """
    Create a visualization of experiment progress from the JSON result files in `results_dir`.
    It only reads files from disk, so it can be rerun at any time to refresh the plots.
    """
    # Set better default styling for plots to fix layout issues
    plt.rcParams.update({
        'figure.figsize': (10, 8),
        'figure.titlesize': 14,
        'axes.titlesize': 11,
        'axes.labelsize': 10,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 7,
        'legend.title_fontsize': 8,
        'font.family': 'sans-serif'
    })
    
    # Check if results directory exists
    if not os.path.exists(results_dir):
        print(f"Results directory '{results_dir}' not found")
        return
    
    # List all result files
    result_files = glob.glob(os.path.join(results_dir, "*.json"))
    
    if not result_files:
        print(f"No result files found in '{results_dir}'")
        return
    
    # Load all result files
    results = []
    for file_path in result_files:
        try:
            with open(file_path, 'r') as f:
                result = json.load(f)
                results.append(result)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
    
    if not results:
        print("No valid result files found")
        return
    
    print(f"Found {len(results)} experiment results")
    
    # Extract data for visualization
    data = []
    
    for result in results:
        # Extract experiment info
        model = result.get("model", "unknown")
        # Shorten model name for better display
        if '/' in model:
            model = model.split('/')[-1]
        strategy = result.get("strategy", "unknown")
        pruning_level = result.get("pruning_level", 0)
        timestamp = result.get("timestamp", "unknown")
        
        # Extract perplexity data from different stages
        stages = result.get("stages", {})
        
        baseline_perplexity = stages.get("baseline", {}).get("perplexity", None)
        pruned_perplexity = stages.get("pruned", {}).get("perplexity", None)
        finetuned_perplexity = stages.get("fine_tuned", {}).get("perplexity", None)
        
        recovery_percentage = stages.get("fine_tuned", {}).get("recovery_percentage", None)
        improvement_percentage = stages.get("fine_tuned", {}).get("improvement_percentage", None)
        
        # Add to dataframe
        if baseline_perplexity is not None:
            data.append({
                "model": model,
                "strategy": strategy,
                "pruning_level": pruning_level,
                "stage": "baseline",
                "perplexity": baseline_perplexity,
                "timestamp": timestamp
            })
        
        if pruned_perplexity is not None:
            data.append({
                "model": model,
                "strategy": strategy,
                "pruning_level": pruning_level,
                "stage": "pruned",
                "perplexity": pruned_perplexity,
                "timestamp": timestamp
            })
        
        if finetuned_perplexity is not None:
            data.append({
                "model": model,
                "strategy": strategy,
                "pruning_level": pruning_level,
                "stage": "fine_tuned",
                "perplexity": finetuned_perplexity,
                "recovery_percentage": recovery_percentage,
                "improvement_percentage": improvement_percentage,
                "timestamp": timestamp
            })
    
    # Convert to dataframe
    df = pd.DataFrame(data)
    
    if df.empty:
        print("No valid data extracted from results")
        return
    
    # Create figure with a more compact layout
    fig = plt.figure(figsize=figsize)
    
    # 1. Perplexity across stages by model and strategy
    plt.subplot(2, 2, 1)
    
    # Get unique values for grouping
    models = df["model"].unique()
    strategies = df["strategy"].unique()
    
    # Filter to main stages
    stages_df = df[df["stage"].isin(["baseline", "pruned", "fine_tuned"])]
    
    # Plot lines connecting stages for each experiment
    for model in models:
        model_df = stages_df[stages_df["model"] == model]
        
        for strategy in strategies:
            strategy_df = model_df[model_df["strategy"] == strategy]
            
            for pruning_level in strategy_df["pruning_level"].unique():
                experiment_df = strategy_df[strategy_df["pruning_level"] == pruning_level]
                
                # Sort by stage
                stage_order = {"baseline": 0, "pruned": 1, "fine_tuned": 2}
                experiment_df = experiment_df.sort_values(by="stage", key=lambda x: x.map(stage_order))
                
                # Only plot if we have at least 2 stages
                if len(experiment_df) >= 2:
                    label = f"{model[:6]}-{strategy[:3]}-{pruning_level:.1f}"
                    plt.plot(experiment_df["stage"], experiment_df["perplexity"], "o-", label=label)
    
    plt.title("Perplexity Across Stages")
    plt.xlabel("Stage")
    plt.ylabel("Perplexity")
    plt.xticks(rotation=45)
    plt.legend(fontsize=7, loc='upper right', ncol=2)
    plt.grid(True, alpha=0.3)
    
    # 2. Pruning level vs perplexity by strategy
    plt.subplot(2, 2, 2)
    
    # Filter to specific stages
    baseline_df = df[df["stage"] == "baseline"]
    pruned_df = df[df["stage"] == "pruned"]
    finetuned_df = df[df["stage"] == "fine_tuned"]
    
    for strategy in strategies:
        # Get strategy data for each stage
        baseline_strategy = baseline_df[baseline_df["strategy"] == strategy]
        pruned_strategy = pruned_df[pruned_df["strategy"] == strategy]
        finetuned_strategy = finetuned_df[finetuned_df["strategy"] == strategy]
        
        # Plot lines for each stage if data exists
        if not baseline_strategy.empty:
            plt.plot(baseline_strategy["pruning_level"], baseline_strategy["perplexity"], 
                    "o--", label=f"Base-{strategy[:3]}", alpha=0.7)
        
        if not pruned_strategy.empty:
            plt.plot(pruned_strategy["pruning_level"], pruned_strategy["perplexity"], 
                    "s--", label=f"Pruned-{strategy[:3]}", alpha=0.7)
        
        if not finetuned_strategy.empty:
            plt.plot(finetuned_strategy["pruning_level"], finetuned_strategy["perplexity"], 
                    "^-", label=f"Tuned-{strategy[:3]}", alpha=0.7)
    
    plt.title("Perplexity vs Pruning Level")
    plt.xlabel("Pruning Level")
    plt.ylabel("Perplexity")
    plt.legend(fontsize=7, loc='best')
    plt.grid(True, alpha=0.3)
    
    # 3. Recovery/improvement percentages
    plt.subplot(2, 2, 3)
    
    # Create dataframe with recovery metrics
    recovery_df = finetuned_df.copy()
    
    if not recovery_df.empty:
        # Create unified recovery column (negative means improvement)
        recovery_df["recovery"] = recovery_df["recovery_percentage"]
        # If recovery is NaN but improvement exists, use negative of improvement
        mask = recovery_df["recovery"].isna() & recovery_df["improvement_percentage"].notna()
        recovery_df.loc[mask, "recovery"] = -recovery_df.loc[mask, "improvement_percentage"]
        
        # Plot by strategy
        for strategy in strategies:
            strategy_recovery = recovery_df[recovery_df["strategy"] == strategy]
            if not strategy_recovery.empty:
                # Sort by pruning level
                strategy_recovery = strategy_recovery.sort_values("pruning_level")
                plt.plot(strategy_recovery["pruning_level"], strategy_recovery["recovery"], 
                        "o-", label=strategy)
        
        plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
        plt.axhline(y=100, color="g", linestyle="--", alpha=0.3)
        plt.text(0.01, 100, "Full Recovery", color="green", ha="left", va="bottom", fontsize=8)
        plt.text(0.01, -5, "Improvement", color="blue", ha="left", va="top", fontsize=8)
        
        plt.title("Recovery/Improvement %")
        plt.xlabel("Pruning Level")
        plt.ylabel("% (negative = improvement)")
        plt.legend(fontsize=7, loc='best')
        plt.grid(True, alpha=0.3)
    else:
        plt.text(0.5, 0.5, "No recovery data available yet", 
                ha="center", va="center", fontsize=12)
    
    # 4. Status overview
    plt.subplot(2, 2, 4)
    
    # Count experiments by status (use .get so incomplete result files don't raise KeyError)
    total_exps = len({(r.get("model"), r.get("strategy"), r.get("pruning_level")) for r in results})
    completed_exps = len(finetuned_df)
    pruned_only = max(0, len(set(pruned_df["timestamp"])) - completed_exps)
    baseline_only = max(0, len(set(baseline_df["timestamp"])) - pruned_only - completed_exps)
    
    # Create status labels and counts (clamped so no bar can go negative)
    status_labels = ["Completed", "Pruned", "Baseline", "Planned"]
    status_counts = [
        completed_exps,
        pruned_only,
        baseline_only,
        max(0, total_exps - completed_exps - pruned_only - baseline_only)
    ]
    
    # Create status bar chart
    colors = ["green", "orange", "blue", "gray"]
    plt.bar(status_labels, status_counts, color=colors)
    
    for i, count in enumerate(status_counts):
        plt.text(i, count + 0.1, str(count), ha="center")
    
    plt.title(f"Experiment Status (Total: {total_exps})")
    plt.xlabel("Status")
    plt.ylabel("Count")
    plt.xticks(rotation=45)
    
    # Add timestamp
    plt.figtext(0.5, 0.01, f"Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", 
               ha="center", fontsize=8)
    
    # Apply tight layout to reduce white space
    plt.tight_layout(pad=1.5)
    plt.subplots_adjust(bottom=0.15)
    plt.show()
    
    return df

# Run the visualization - you can run this cell repeatedly to refresh
df = visualize_ongoing_experiments()

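# Display success count by strategy if we have data
if df is not None and not df.empty and "fine_tuned" in df["stage"].values:
    from IPython.display import display

    keys = ["model", "strategy", "pruning_level"]
    finetuned = df[df["stage"] == "fine_tuned"].copy()
    baseline = (df[df["stage"] == "baseline"][keys + ["perplexity"]]
                .rename(columns={"perplexity": "baseline_perplexity"})
                .drop_duplicates(subset=keys))

    # Compare each fine-tuned model against the baseline of the same experiment
    finetuned = finetuned.merge(baseline, on=keys, how="left")

    # Calculate improvement status (lower perplexity is better)
    finetuned["status"] = "No Change"
    finetuned.loc[finetuned["perplexity"] < finetuned["baseline_perplexity"], "status"] = "Improved"
    finetuned.loc[finetuned["perplexity"] > finetuned["baseline_perplexity"], "status"] = "Degraded"

    # Count by strategy and status
    status_by_strategy = pd.crosstab(finetuned["strategy"], finetuned["status"])
    display(status_by_strategy)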
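You can check the monitoring cell before any real results exist. To do that, write a synthetic result file in the layout that `visualize_ongoing_experiments` parses:

- `model`, `strategy`, `pruning_level` and `timestamp` at the top level;
- a per-stage `perplexity` under `stages`.

The helper below is a hypothetical test utility. The perplexity numbers are placeholders, not measurements, so keep them in a temporary directory rather than the real results folder.

```python
import json
import os
import tempfile


def write_synthetic_result(results_dir, strategy, pruning_level, perplexities,
                           recovery_percentage=None, model="gpt2"):
    """Write one placeholder result file in the layout the monitoring cell parses.

    `perplexities` is a (baseline, pruned, fine_tuned) tuple of placeholder
    values for testing the plots; they are not measurements.
    """
    baseline, pruned, fine_tuned = perplexities
    fine_tuned_stage = {"perplexity": fine_tuned}
    if recovery_percentage is not None:
        fine_tuned_stage["recovery_percentage"] = recovery_percentage
    record = {
        "model": model,
        "strategy": strategy,
        "pruning_level": pruning_level,
        "timestamp": f"synthetic-{strategy}-{pruning_level}",
        "stages": {
            "baseline": {"perplexity": baseline},
            "pruned": {"perplexity": pruned},
            "fine_tuned": fine_tuned_stage,
        },
    }
    os.makedirs(results_dir, exist_ok=True)
    path = os.path.join(results_dir, f"synthetic_{strategy}_{pruning_level}.json")
    with open(path, "w") as f:
        json.dump(record, f)
    return path


# Write two placeholder runs to a throwaway directory
tmp_dir = tempfile.mkdtemp()
for strategy, level in [("magnitude", 0.3), ("entropy", 0.3)]:
    write_synthetic_result(tmp_dir, strategy, level, (30.0, 45.0, 33.0),
                           recovery_percentage=80.0)
```

You can then call `visualize_ongoing_experiments(results_dir=tmp_dir)` to exercise the plots without touching `pruning_finetuning_results`.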

## Head Importance Visualization

The cell below can be used to visualize which heads are most important in your model:

In [None]:
# Visualize attention head importance for different pruning strategies
# This can help identify which heads are most critical for model performance

# Initialize the model
model_name = "gpt2"  # Change to one of the models you're using
pruning_module = PruningModule(model_name)
if not pruning_module.load_model():
    print(f"Failed to load model {model_name}")
else:
    # Get original parameters
    original_params = pruning_module.original_params
    
    # Set up a sample prompt
    prompt = "Artificial intelligence will transform society by"
    
    # Calculate importance for different strategies
    strategies = {}
    
    try:
        # Random strategy (baseline)
        random_strategy = get_strategy("random", pruning_module, prompt)
        random_importance = random_strategy.get_head_importance(original_params)
        strategies["random"] = random_importance
        
        # Magnitude strategy
        magnitude_strategy = get_strategy("magnitude", pruning_module, prompt)
        magnitude_importance = magnitude_strategy.get_head_importance(original_params)
        strategies["magnitude"] = magnitude_importance
        
        # Entropy strategy
        entropy_strategy = get_strategy("entropy", pruning_module, prompt)
        entropy_importance = entropy_strategy.get_head_importance(original_params)
        strategies["entropy"] = entropy_importance
        
        # Set better plot styling
        plt.rcParams.update({
            'figure.figsize': (12, 1.5 * pruning_module.num_layers),
            'figure.titlesize': 14,
            'axes.titlesize': 12,
            'axes.labelsize': 10,
            'xtick.labelsize': 9,
            'ytick.labelsize': 9,
            'legend.fontsize': 8,
            'font.family': 'sans-serif'
        })
        
        # Now visualize the head importance scores (squeeze=False keeps `axes` 2-D even for a 1-layer model)
        fig, axes = plt.subplots(pruning_module.num_layers, 3, figsize=(12, 1.5 * pruning_module.num_layers),
                                 squeeze=False)
        
        # Create title
        fig.suptitle(f"Attention Head Importance by Strategy for {model_name}", fontsize=16)
        
        # Set column titles
        for i, strategy_name in enumerate(["random", "magnitude", "entropy"]):
            axes[0, i].set_title(f"{strategy_name.capitalize()} Strategy")
        
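        # Give each strategy its own color scale. The strategies' scores can lie in very
        # different ranges, and imshow would otherwise rescale every layer independently
        scales = {}
        for strategy_name in ["random", "magnitude", "entropy"]:
            values = np.array([score for _, _, score in strategies[strategy_name]], dtype=float)
            scales[strategy_name] = (np.nanmin(values), np.nanmax(values))
        images = {}
        
        # Create a heatmap for each strategy showing head importance
        for layer in range(pruning_module.num_layers):
            for i, strategy_name in enumerate(["random", "magnitude", "entropy"]):
                # Extract importance scores for this layer, ordered by head index
                layer_scores = [score for l, h, score in sorted(strategies[strategy_name]) if l == layer]
                
                # Create array for visualization
                scores_array = np.array(layer_scores, dtype=float).reshape(1, -1)
                
                # Create heatmap on the strategy's shared color scale
                vmin, vmax = scales[strategy_name]
                images[strategy_name] = axes[layer, i].imshow(
                    scores_array, cmap="viridis", aspect="auto", vmin=vmin, vmax=vmax)
                
                # Add labels
                axes[layer, i].set_yticks([0])
                axes[layer, i].set_yticklabels([f"Layer {layer}"])
                axes[layer, i].set_xticks(range(pruning_module.num_heads))
                axes[layer, i].set_xticklabels([f"H{h}" for h in range(pruning_module.num_heads)], 
                                              rotation=90 if pruning_module.num_heads > 8 else 0)
                
                # Add importance values as text (viridis is dark at the low end, so use white text there)
                midpoint = (vmin + vmax) / 2
                for h in range(pruning_module.num_heads):
                    score = scores_array[0, h]
                    if np.isnan(score):
                        text_color = "black"
                    else:
                        text_color = "black" if score > midpoint else "white"
                    axes[layer, i].text(h, 0, f"{score:.2f}", ha="center", va="center", 
                                       color=text_color, fontsize=8)
        
        # Add one colorbar per strategy column
        for i, strategy_name in enumerate(["random", "magnitude", "entropy"]):
            fig.colorbar(images[strategy_name], ax=axes[:, i].tolist(), shrink=0.6)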
        
        plt.tight_layout(rect=[0, 0, 1, 0.96])  # Make room for the title
        plt.show()
        
        # Show the 10 least and 10 most important heads according to the entropy strategy
        sorted_entropy = sorted(entropy_importance, key=lambda x: x[2])
        
        print("Top 10 Least Important Heads (candidates for pruning):")
        for i, (layer, head, score) in enumerate(sorted_entropy[:10]):
            print(f"{i+1}. Layer {layer}, Head {head}: {score:.4f}")
            
        print("\nTop 10 Most Important Heads (last to be pruned):")
        for i, (layer, head, score) in enumerate(reversed(sorted_entropy[-10:])):
            print(f"{i+1}. Layer {layer}, Head {head}: {score:.4f}")
    
    except Exception as e:
        print(f"Error calculating head importance: {e}")
        import traceback
        traceback.print_exc()
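The importance scores above come from the strategies in `utils.pruning`, whose internals are not shown in this notebook. The NumPy sketch below illustrates the two non-random ideas, returning `(layer, head, score)` tuples like those plotted above:

- **Magnitude:** scores each head by the L2 norm of its slice of a projection matrix.
- **Entropy:** scores each head by the mean entropy of its attention distributions.

Several details are assumptions: which weights are used, the sign convention that treats focused (low-entropy) heads as more important, and reading a pruning level as the fraction of heads to remove.

```python
import numpy as np


def magnitude_head_importance(layer_weights, num_heads):
    """(layer, head, score) tuples where score is the L2 norm of a head's weights.

    Assumes each matrix in `layer_weights` is a (d_model, d_model) projection whose
    output columns are grouped by head, with head_dim = d_model // num_heads.
    """
    scores = []
    for layer, w in enumerate(layer_weights):
        head_dim = w.shape[1] // num_heads
        for head in range(num_heads):
            block = w[:, head * head_dim:(head + 1) * head_dim]
            scores.append((layer, head, float(np.linalg.norm(block))))
    return scores


def entropy_head_importance(attn_probs):
    """(layer, head, score) tuples from attention weights of shape
    (layers, heads, queries, keys), where each row over keys sums to 1.

    Score = negative mean entropy, so focused (low-entropy) heads rank as more
    important. This sign convention is an assumption.
    """
    p = np.clip(np.asarray(attn_probs, dtype=np.float64), 1e-12, 1.0)
    # Entropy of each query's distribution, averaged over queries -> (layers, heads)
    entropy = -(p * np.log(p)).sum(axis=-1).mean(axis=-1)
    num_layers, num_heads = entropy.shape
    return [(l, h, float(-entropy[l, h]))
            for l in range(num_layers) for h in range(num_heads)]


def heads_to_prune(importance, pruning_level):
    """Return (layer, head) pairs for the lowest-scoring `pruning_level` fraction of heads."""
    n_prune = int(round(len(importance) * pruning_level))
    ranked = sorted(importance, key=lambda item: item[2])
    return [(layer, head) for layer, head, _ in ranked[:n_prune]]


# Toy example: one layer, two heads; head 1 carries all the weight
w = np.zeros((4, 4))
w[:, 2:] = 1.0
print(heads_to_prune(magnitude_head_importance([w], num_heads=2), 0.5))  # prints [(0, 0)]
```

Magnitude scores need only the weights. Entropy scores need a forward pass on a prompt, which is why the strategies above are constructed with `prompt`.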

## Comprehensive Analysis

After collecting results, run a comprehensive analysis:

In [None]:
# Plot results with improved sizing
experiment.plot_results(figsize=(10, 8))

In [None]:
# Create a summary table
if not experiment.results_df.empty:
    # Get data for different stages
    baseline_df = experiment.results_df[experiment.results_df["stage"] == "baseline"][["model", "strategy", "pruning_level", "perplexity"]]
    baseline_df = baseline_df.rename(columns={"perplexity": "baseline_perplexity"})
    
    pruned_df = experiment.results_df[experiment.results_df["stage"] == "pruned"][["model", "strategy", "pruning_level", "perplexity"]]
    pruned_df = pruned_df.rename(columns={"perplexity": "pruned_perplexity"})
    
    finetuned_df = experiment.results_df[experiment.results_df["stage"] == "fine_tuned"][["model", "strategy", "pruning_level", "perplexity"]]
    finetuned_df = finetuned_df.rename(columns={"perplexity": "finetuned_perplexity"})
    
    # Merge dataframes
    summary = pd.merge(baseline_df, pruned_df, on=["model", "strategy", "pruning_level"])
    summary = pd.merge(summary, finetuned_df, on=["model", "strategy", "pruning_level"])
    
    # Calculate changes
    summary["pruning_effect"] = summary["pruned_perplexity"] - summary["baseline_perplexity"]
    summary["finetuning_effect"] = summary["finetuned_perplexity"] - summary["pruned_perplexity"]
    summary["net_change"] = summary["finetuned_perplexity"] - summary["baseline_perplexity"]
    
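    # Display summary; a bare expression inside an `if` block is not rendered automatically
    from IPython.display import display
    display(summary.head())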
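The summary reports absolute perplexity changes. A relative measure is often easier to compare across pruning levels: the share of the pruning-induced perplexity increase that fine-tuning removes.

The function below is one possible definition, written for illustration. The `recovery_percentage` values stored by `utils.pruning` may be computed differently. For instance, the monitoring cell plots improvements as negative values.

```python
def recovery_percentage(baseline_ppl, pruned_ppl, finetuned_ppl):
    """Percent of the pruning damage removed by fine-tuning (one possible definition).

    100 means back to baseline, 0 means no recovery, and above 100 means the
    fine-tuned model beats the baseline. Returns None when pruning did not
    increase perplexity, since there is nothing to recover.
    """
    damage = pruned_ppl - baseline_ppl
    if damage <= 0:
        return None
    return 100.0 * (pruned_ppl - finetuned_ppl) / damage
```

To add it to the summary table, use `summary["recovery_pct"] = [recovery_percentage(b, p, f) for b, p, f in zip(summary["baseline_perplexity"], summary["pruned_perplexity"], summary["finetuned_perplexity"])]`.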