# Pruning and Fine-Tuning Benchmark (v0.0.3)

This notebook demonstrates how pruning followed by fine-tuning can recover, or even improve, model performance while reducing model size. The experiment is designed for long-running Colab sessions and refreshes its visualizations after every completed experiment.

## Version 0.0.3 (April 2025)
- Added ImprovedFineTuner with enhanced stability for large models
- Added automatic handling for OPT model compatibility
- Improved NaN detection and recovery
- Added memory optimization for large models

## Overview

1. **Baseline Evaluation**: Establish the initial model performance
2. **Pruning Phase**: Apply different pruning strategies and evaluate post-pruning performance
3. **Fine-Tuning Phase**: Fine-tune pruned models to recover or improve performance
4. **Analysis**: Compare performance across pruning levels and fine-tuning epochs

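The experiment keeps running until every configured combination has been tried, the runtime limit is reached, or the notebook is interrupted. Results are saved and the visualizations are updated after each completed experiment.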

## Setup

First, let's install dependencies and clone the repository:

In [None]:
# Install required packages.
# The leading `!` runs the line as a shell command from the notebook; `-q` keeps pip's output quiet.
!pip install -q jax jaxlib flax transformers datasets matplotlib numpy pandas seaborn tqdm optax

In [None]:
# Clone the repository and move into it.
# NOTE: this checks out the feature/colab-overnight branch. When working from the
# default branch instead, replace the clone line with
#   !git clone https://github.com/CambrianTech/sentinel-ai.git
# and delete this note.
!git clone -b feature/colab-overnight https://github.com/CambrianTech/sentinel-ai.git
%cd sentinel-ai

In [ ]:
# Import libraries
import os
import sys
import json
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import JAX/Flax
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

# Import Hugging Face libraries
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
from datasets import load_dataset

# Import our pruning library
from utils.pruning import (
    Environment,
    ResultsManager,
    PruningModule,
    PruningBenchmark,
    FineTuner,
    ImprovedFineTuner
)

# Set up plotting: apply matplotlib's "ggplot" style sheet first, then seaborn's
# "whitegrid" theme. The seaborn call runs last.
# NOTE(review): the seaborn theme presumably overrides parts of the ggplot style — confirm which look is intended.
plt.style.use('ggplot')
sns.set_theme(style="whitegrid")

## Environment Detection

Let's detect our environment capabilities:

In [None]:
# Initialize environment and detect capabilities
env = Environment()
env.print_info()

# Check JAX capabilities
print(f"\nJAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

# Note: FineTuner and ImprovedFineTuner are imported from utils.pruning at the top
# of this notebook. Be aware that the next code cell defines a class that is also
# named FineTuner; once that cell runs, it replaces the imported FineTuner in this
# notebook's namespace. Skip that cell to keep using the library implementation.

# FineTuner is intended for:
# - Small models (distilgpt2, gpt2, pythia-70m)
# - Simple fine-tuning tasks
# - Situations where memory is not a concern

# ImprovedFineTuner is intended for:
# - Large models (opt-1.3b, bloom-560m+)
# - Models prone to NaN loss (OPT family)
# - Situations where stability is critical
# - Fine-tuning with limited resources

# Example usage:
# fine_tuner = FineTuner(pruning_module, dataset_name="wikitext", batch_size=8)
# OR
# fine_tuner = ImprovedFineTuner(pruning_module, dataset_name="wikitext", batch_size=4)
#
# tuned_params, metrics = fine_tuner.fine_tune(
#     pruned_params, 
#     num_epochs=fine_tuning_epochs,
#     learning_rate=5e-5,
#     evaluate_interval=5
# )

In [None]:
class FineTuner:
    """Fine-tunes a pruned model to recover performance.

    The tuner reads ``tokenizer`` and ``model`` from the supplied pruning
    module, builds padded causal-language-modeling batches (from a Hugging
    Face dataset, or from random tokens when loading fails) and runs a plain
    Adam training loop over them.

    Args:
        pruning_module: Object exposing ``tokenizer``, ``model``,
            ``generate_text`` and ``evaluate_perplexity``.
        dataset_name: Dataset name passed to ``load_dataset``.
        batch_size: Number of sequences per batch.
        dataset_config: Optional dataset configuration name passed to
            ``load_dataset`` (for example ``"wikitext-2-v1"``).
    """

    def __init__(self, pruning_module, dataset_name="openwebtext", batch_size=4,
                 dataset_config=None):
        self.pruning_module = pruning_module
        self.dataset_name = dataset_name
        self.dataset_config = dataset_config
        self.batch_size = batch_size
        self.max_seq_length = 128  # Modest sequence length for faster training
        self.train_state = None
        self.metrics_history = []
        # Key consumed (split) once per training step to drive dropout.
        self.dropout_rng = jax.random.PRNGKey(0)

        # Detect number of devices
        self.devices = jax.devices()
        self.n_devices = len(self.devices)
        if self.n_devices > 1:
            print(f"Using {self.n_devices} devices for training")
            self.batch_size = max(self.batch_size, self.n_devices)
            # Make batch size divisible by device count
            self.batch_size = (self.batch_size // self.n_devices) * self.n_devices
            print(f"Adjusted batch size to {self.batch_size} for multi-device training")

    def _pad_token_id(self):
        """Return a usable padding id even when the tokenizer defines none.

        Falls back to the EOS id and finally to 0. Padded positions are always
        excluded through the attention mask, so the exact value is harmless.
        """
        tokenizer = self.pruning_module.tokenizer
        for candidate in (getattr(tokenizer, "pad_token_id", None),
                          getattr(tokenizer, "eos_token_id", None)):
            if candidate is not None:
                return int(candidate)
        return 0

    def _collate(self, sequences):
        """Right-pad token-id sequences to a common length.

        Returns:
            Dict with 2-D int32 arrays ``input_ids``, ``attention_mask``
            (1 for real tokens, 0 for padding) and ``labels`` (a copy of
            ``input_ids``).
        """
        pad_id = self._pad_token_id()
        width = max(len(seq) for seq in sequences)
        input_ids = np.full((len(sequences), width), pad_id, dtype=np.int32)
        attention_mask = np.zeros((len(sequences), width), dtype=np.int32)
        for row, seq in enumerate(sequences):
            input_ids[row, :len(seq)] = seq
            attention_mask[row, :len(seq)] = 1
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.copy(),
        }

    def _batches_from(self, sequences):
        """Split sequences into consecutive padded batches of ``batch_size``."""
        return [
            self._collate(sequences[start:start + self.batch_size])
            for start in range(0, len(sequences), self.batch_size)
        ]

    def _prepare_dataset(self):
        """Load and prepare the dataset for fine-tuning.

        Loads the first 5000 training rows, tokenizes the ``text`` column with
        truncation to ``max_seq_length`` and returns a list of padded batches.
        Blank texts and texts shorter than two tokens are dropped, because the
        next-token loss needs at least one (input, target) pair.

        Any failure (missing dataset/config, no network, no usable text) is
        reported and the synthetic dataset is returned instead.
        """
        try:
            load_args = [self.dataset_name]
            if self.dataset_config is not None:
                load_args.append(self.dataset_config)
            # Load only a small slice of the dataset for faster start-up
            dataset = load_dataset(*load_args, split="train[:5000]")

            tokenizer = self.pruning_module.tokenizer
            texts = [text for text in dataset["text"] if text and text.strip()]
            if not texts:
                raise ValueError("dataset contains no non-empty text")

            encoded = tokenizer(texts, truncation=True, max_length=self.max_seq_length)
            sequences = [ids for ids in encoded["input_ids"] if len(ids) >= 2]
            if not sequences:
                raise ValueError("no sequence has at least two tokens")

            return self._batches_from(sequences)

        except Exception as e:
            # Deliberate best-effort: training proceeds on synthetic data.
            print(f"Error preparing dataset: {e}")
            print("Falling back to synthetic data for training")
            return self._prepare_synthetic_dataset()

    def _prepare_synthetic_dataset(self):
        """Create synthetic data for training when dataset loading fails.

        Produces 100 random token sequences (drawn with the global
        ``np.random`` state) that never contain the tokenizer's special
        tokens, grouped into padded batches.
        """
        tokenizer = self.pruning_module.tokenizer
        vocab_size = tokenizer.vocab_size

        # Special ids may be None for some tokenizers; ignore those.
        special_ids = {
            int(token_id)
            for token_id in (
                getattr(tokenizer, "pad_token_id", None),
                getattr(tokenizer, "eos_token_id", None),
                getattr(tokenizer, "bos_token_id", None),
                getattr(tokenizer, "unk_token_id", None),
            )
            if token_id is not None
        }
        # Sample directly from the allowed ids instead of patching special
        # tokens afterwards (a "+1" patch can land on another special token).
        candidate_ids = np.setdiff1d(
            np.arange(vocab_size), np.array(sorted(special_ids), dtype=np.int64)
        )
        if candidate_ids.size == 0:
            candidate_ids = np.arange(vocab_size)

        # Lengths lie in [10, max_seq_length]; bounds are clamped so a very
        # small max_seq_length cannot produce an empty range.
        longest = max(2, self.max_seq_length)
        shortest = min(10, longest)

        sequences = []
        for _ in range(100):
            length = np.random.randint(shortest, longest + 1)
            sequences.append(np.random.choice(candidate_ids, size=length))

        return self._batches_from(sequences)

    def _create_train_state(self, params, learning_rate=5e-5):
        """Create a training state (Adam optimizer) for the fine-tuning process."""
        optimizer = optax.adam(learning_rate)

        model = self.pruning_module.model
        return TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=optimizer
        )

    def _loss_fn(self, params, batch, dropout_rng=None):
        """Mean next-token cross entropy over the non-padded positions.

        Args:
            params: Model parameters to evaluate.
            batch: Dict with ``input_ids`` and optionally ``attention_mask``
                and ``labels`` (defaults to ``input_ids``).
            dropout_rng: PRNG key for dropout. When given, the model runs in
                training mode; when ``None`` it runs deterministically.

        Returns:
            Scalar loss.
        """
        model = self.pruning_module.model

        input_ids = batch["input_ids"]
        attention_mask = batch.get("attention_mask")
        labels = jnp.asarray(batch.get("labels", input_ids))

        # Only forward what the model accepts: `labels` is not a model input,
        # and training mode needs an explicit dropout key.
        model_kwargs = {"params": params, "train": dropout_rng is not None}
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask
        if dropout_rng is not None:
            model_kwargs["dropout_rng"] = dropout_rng
        logits = model(input_ids, **model_kwargs).logits

        # Prefer the attention mask for excluding padding: it stays correct
        # even when the pad id is missing or doubles as the EOS id.
        if attention_mask is not None:
            loss_mask = jnp.asarray(attention_mask) > 0
        else:
            pad_id = self.pruning_module.tokenizer.pad_token_id
            if pad_id is None:
                loss_mask = jnp.ones(labels.shape, dtype=bool)
            else:
                loss_mask = labels != pad_id

        # Shift logits and labels for next token prediction
        shift_logits = logits[:, :-1]
        shift_labels = labels[:, 1:]
        shift_mask = loss_mask[:, 1:]

        token_loss = optax.softmax_cross_entropy_with_integer_labels(
            shift_logits, shift_labels
        )

        # Guard the denominator so a fully padded batch yields 0, not NaN.
        return (token_loss * shift_mask).sum() / jnp.maximum(shift_mask.sum(), 1)

    def _train_step(self, state, batch):
        """Run a single optimization step and return ``(new_state, loss)``."""
        # Advance the dropout key so every step uses fresh randomness.
        step_rng = None
        current_rng = getattr(self, "dropout_rng", None)
        if current_rng is not None:
            self.dropout_rng, step_rng = jax.random.split(current_rng)

        grad_fn = jax.value_and_grad(self._loss_fn)
        loss, grads = grad_fn(state.params, batch, step_rng)
        new_state = state.apply_gradients(grads=grads)
        return new_state, loss

    def fine_tune(self, pruned_params, num_epochs=1, learning_rate=5e-5, evaluate_interval=5):
        """Fine-tune the pruned model.

        Args:
            pruned_params: Parameters to start from.
            num_epochs: Number of passes over the prepared batches.
            learning_rate: Adam learning rate.
            evaluate_interval: Evaluate perplexity/generation every this many
                steps; a falsy value disables periodic evaluation.

        Returns:
            Tuple ``(params, metrics_history)``. Each metrics entry holds
            ``epoch``, the mean ``loss`` and ``perplexity_history``: the
            ``(step, perplexity)`` pairs recorded during that epoch only.
        """
        print(f"\nFine-tuning model with {self.dataset_name} dataset for {num_epochs} epochs...")

        # Prepare dataset
        dataset = self._prepare_dataset()

        # Create training state
        self.train_state = self._create_train_state(pruned_params, learning_rate)
        self.metrics_history = []

        total_steps = 0

        for epoch in range(num_epochs):
            # Shuffle for each epoch
            if isinstance(dataset, list):
                np.random.shuffle(dataset)
                epoch_dataset = dataset
            else:
                # If it's a datasets.Dataset, shuffle
                epoch_dataset = dataset.shuffle()

            # Create progress bar
            epoch_desc = f"Epoch {epoch+1}/{num_epochs}"
            batch_count = len(epoch_dataset) if hasattr(epoch_dataset, "__len__") else None
            progress_bar = tqdm(enumerate(epoch_dataset), desc=epoch_desc, total=batch_count)

            epoch_losses = []
            # Kept per epoch: plot_training_progress concatenates the entries
            # of every epoch, so a shared list would duplicate points.
            epoch_perplexities = []

            for step, batch in progress_bar:
                # Train step
                self.train_state, loss = self._train_step(self.train_state, batch)
                total_steps += 1
                loss_value = float(loss)  # single device-to-host transfer
                epoch_losses.append(loss_value)

                # Update progress bar
                progress_bar.set_description(f"{epoch_desc} - Loss: {loss_value:.4f}")

                # Evaluate periodically
                if evaluate_interval and total_steps % evaluate_interval == 0:
                    prompt = "Artificial intelligence will transform"
                    try:
                        generated = self.pruning_module.generate_text(
                            self.train_state.params, prompt, max_length=30
                        )
                        # Stored as a plain float so the metrics serialize to JSON.
                        perplexity = float(self.pruning_module.evaluate_perplexity(
                            self.train_state.params, prompt
                        ))
                        epoch_perplexities.append((total_steps, perplexity))
                        print(f"\nStep {total_steps} - Perplexity: {perplexity:.4f}")
                        print(f"Generated: {generated}")
                    except Exception as e:
                        # Evaluation is informational; never abort training for it.
                        print(f"Error evaluating model: {e}")

            # End of epoch metrics
            epoch_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0
            print(f"\nEpoch {epoch+1} completed. Average loss: {epoch_loss:.4f}")

            self.metrics_history.append({
                "epoch": epoch + 1,
                "loss": epoch_loss,
                "perplexity_history": epoch_perplexities
            })

        print("\nFine-tuning completed!")
        return self.train_state.params, self.metrics_history

    def plot_training_progress(self, figsize=(12, 6)):
        """Plot per-epoch loss and per-step perplexity.

        Returns:
            The matplotlib figure, or ``None`` when no metrics were recorded.
        """
        if not self.metrics_history:
            print("No training metrics available yet")
            return

        # Extract epoch losses
        epochs = [m["epoch"] for m in self.metrics_history]
        losses = [m["loss"] for m in self.metrics_history]

        # Collect perplexity measurements across all epochs
        steps = []
        perplexities = []
        for m in self.metrics_history:
            for step, perplexity in m.get("perplexity_history", []):
                steps.append(step)
                perplexities.append(perplexity)

        # Create figure
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # Plot losses
        ax1.plot(epochs, losses, "o-", color="blue")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
        ax1.set_title("Training Loss")
        ax1.grid(True, linestyle="--", alpha=0.7)

        # Plot perplexities
        if steps and perplexities:
            ax2.plot(steps, perplexities, "o-", color="green")
            ax2.set_xlabel("Step")
            ax2.set_ylabel("Perplexity")
            ax2.set_title("Perplexity During Training")
            ax2.grid(True, linestyle="--", alpha=0.7)
        else:
            ax2.text(0.5, 0.5, "No perplexity data available",
                    ha="center", va="center", fontsize=12)

        plt.tight_layout()
        plt.show()

        return fig

## Experiment Manager

Let's create an experiment manager to run the full experiment:

In [ ]:
class PruningFineTuningExperiment:
    """Manages the pruning + fine-tuning experiment"""
    
    def __init__(self, results_dir="pruning_finetuning_results"):
        """Prepare the output directory, environment info and result stores.

        Args:
            results_dir: Directory (created if missing, including parents)
                where experiment results are written.
        """
        output_root = Path(results_dir)
        output_root.mkdir(parents=True, exist_ok=True)
        self.results_dir = output_root

        # In-memory bookkeeping for completed and in-flight experiments
        self.results = []
        self.current_experiment = {}

        # Ask the environment which models it considers suitable
        self.env = Environment()
        candidate_models = self.env.get_suitable_models()
        self.available_models = candidate_models
        print(f"Models available: {', '.join(candidate_models)}")

        # Result persistence and the DataFrame used for plotting
        self.results_manager = ResultsManager(str(output_root))
        self.results_df = pd.DataFrame()
    
    def run_experiment(self, strategies, pruning_levels, prompt, fine_tuning_epochs=1, max_runtime=3600):
        """Run the full experiment"""
        if not self.available_models:
            print("No suitable models found for this environment")
            return
        
        # Start time for runtime tracking
        start_time = time.time()
        
        # Generate all experiment combinations
        experiments = []
        for model in self.available_models:
            for strategy in strategies:
                for level in pruning_levels:
                    experiments.append({
                        "model": model,
                        "strategy": strategy,
                        "pruning_level": level,
                        "prompt": prompt,
                        "fine_tuning_epochs": fine_tuning_epochs
                    })
        
        # Shuffle to get more diverse results early
        random.shuffle(experiments)
        
        # Create progress bar
        pbar = tqdm(total=len(experiments), desc="Running experiments")
        
        # Run experiments
        for i, exp in enumerate(experiments):
            # Check if we've exceeded the runtime limit
            current_runtime = time.time() - start_time
            if max_runtime is not None and current_runtime > max_runtime:
                print(f"\nReached maximum runtime of {max_runtime/3600:.1f} hours")
                break
                
            # Update progress bar
            pbar.set_description(f"Testing {exp['model']}, {exp['strategy']}, {exp['pruning_level']:.2f}")
            
            # Run experiment
            try:
                result = self.run_single_experiment(**exp)
                if result is not None:
                    self.results.append(result)
                
                # Update progress bar
                pbar.update(1)
                
                # Plot intermediate results every few experiments
                if (i + 1) % 1 == 0 or i == len(experiments) - 1:
                    self.plot_results()
            except Exception as e:
                print(f"Error in experiment {exp['model']}, {exp['strategy']}, {exp['pruning_level']:.2f}: {e}")
                import traceback
                traceback.print_exc()
                # Still update progress bar
                pbar.update(1)
        
        # Close progress bar
        pbar.close()
        
        # Final results
        print(f"\nCompleted {len(self.results)} experiments out of {len(experiments)} attempted")
        runtime = time.time() - start_time
        print(f"Total runtime: {runtime/3600:.2f} hours ({runtime/60:.2f} minutes)")
        
        # Plot final results
        self.plot_results()
        
        return self.results
    
    def run_single_experiment(self, model, strategy, pruning_level, prompt, fine_tuning_epochs=1):
        """Run a single experiment: baseline, pruning, fine-tuning, saving.

        Args:
            model: Model name passed to ``PruningModule``.
            strategy: Pruning strategy name passed to ``get_strategy``.
            pruning_level: Fraction of attention heads to prune.
            prompt: Prompt used for perplexity evaluation and text generation.
            fine_tuning_epochs: Number of fine-tuning epochs.

        Returns:
            The experiment record (also kept in ``self.current_experiment``),
            or ``None`` when the model could not be loaded.

        Raises:
            Exception: Re-raised from fine-tuning when ImprovedFineTuner was
                the selected tuner and failed, since no fallback remains.
        """
        import json

        def _json_default(value):
            """Serialize NumPy/JAX scalars and arrays; reject anything else."""
            if hasattr(value, "tolist"):
                return value.tolist()
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        print(f"\n{'='*80}")
        print(f"Experiment: {model}, {strategy} strategy, {pruning_level:.2f} pruning level")
        print(f"{'='*80}")

        # Initialize pruning module
        pruning_module = PruningModule(model)
        if not pruning_module.load_model():
            print(f"Failed to load model {model}")
            return None

        # Setup experiment record
        self.current_experiment = {
            "model": model,
            "strategy": strategy,
            "pruning_level": pruning_level,
            "prompt": prompt,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "stages": {}
        }

        # 1. Evaluate baseline model
        print("\n>> Stage 1: Evaluating baseline model")
        original_params = pruning_module.original_params

        perplexity_baseline = pruning_module.evaluate_perplexity(original_params, prompt)
        print(f"Baseline perplexity: {perplexity_baseline:.4f}")

        generated_baseline = pruning_module.generate_text(original_params, prompt)
        print(f"Baseline generated: {generated_baseline}")

        self.current_experiment["stages"]["baseline"] = {
            "perplexity": float(perplexity_baseline),
            "generated_text": generated_baseline
        }

        # 2. Apply pruning
        print("\n>> Stage 2: Applying pruning")
        from utils.pruning.strategies import get_strategy
        pruning_strat = get_strategy(strategy, pruning_module, prompt)

        print("Calculating head importance...")
        all_head_importance = pruning_strat.get_head_importance(original_params)

        # Rank ascending by importance score (third element) without
        # mutating the list the strategy returned.
        ranked_heads = sorted(all_head_importance, key=lambda entry: entry[2])

        # Determine number of heads to prune
        total_heads = pruning_module.num_layers * pruning_module.num_heads
        heads_to_prune = int(total_heads * pruning_level)
        print(f"Pruning {heads_to_prune} out of {total_heads} heads")

        # Least important heads come first
        head_indices = [(layer, head) for layer, head, _ in ranked_heads[:heads_to_prune]]

        print("Pruning heads...")
        pruned_params = pruning_strat.prune_heads(original_params, head_indices)

        # Evaluate after pruning
        perplexity_pruned = pruning_module.evaluate_perplexity(pruned_params, prompt)
        print(f"Pruned perplexity: {perplexity_pruned:.4f}")

        generated_pruned = pruning_module.generate_text(pruned_params, prompt)
        print(f"Pruned generated: {generated_pruned}")

        self.current_experiment["stages"]["pruned"] = {
            "perplexity": float(perplexity_pruned),
            "perplexity_change": float(perplexity_pruned - perplexity_baseline),
            "generated_text": generated_pruned,
            "pruned_heads": heads_to_prune,
            "total_heads": total_heads,
            "head_indices": head_indices
        }

        # 3. Fine-tune the pruned model
        print("\n>> Stage 3: Fine-tuning the pruned model")

        dataset_name = "wikitext"
        dataset_config = "wikitext-2-v1"

        # Determine batch size based on environment
        if self.env.in_colab and self.env.has_tpu:
            # TPUs can handle larger batch sizes
            batch_size = 16
        elif self.env.in_colab and self.env.has_gpu:
            batch_size = 8
        else:
            batch_size = 4

        # Heuristic: these substrings in the model name select the more
        # stable tuner (note that 'opt' matches any name containing it).
        model_name = model.lower()
        use_improved_tuner = any(x in model_name for x in ['opt', 'large', '1.3b', 'bloom'])

        if use_improved_tuner:
            print(f"Using ImprovedFineTuner for model {model} to enhance stability")
            fine_tuner = ImprovedFineTuner(
                pruning_module,
                dataset_name=dataset_name,
                dataset_config=dataset_config,
                batch_size=batch_size
            )
            # Use lower learning rate for large models
            learning_rate = 1e-5
        else:
            print(f"Using standard FineTuner for model {model}")
            fine_tuner = FineTuner(
                pruning_module,
                dataset_name=dataset_name,
                batch_size=batch_size
            )
            learning_rate = 5e-5

        # Fine-tune model
        try:
            tuned_params, metrics = fine_tuner.fine_tune(
                pruned_params,
                num_epochs=fine_tuning_epochs,
                learning_rate=learning_rate,
                evaluate_interval=5
            )
        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            if use_improved_tuner:
                # No fallback is left. Re-raise instead of continuing with
                # tuned_params/metrics unbound (which would surface later as
                # an unrelated NameError).
                raise

            # Standard tuner failed: retry with the improved tuner
            print("Falling back to ImprovedFineTuner after error")
            fine_tuner = ImprovedFineTuner(
                pruning_module,
                dataset_name=dataset_name,
                dataset_config=dataset_config,
                batch_size=max(1, batch_size // 2)  # Reduce batch size
            )
            tuned_params, metrics = fine_tuner.fine_tune(
                pruned_params,
                num_epochs=fine_tuning_epochs,
                learning_rate=1e-5,  # Lower learning rate for stability
                evaluate_interval=5
            )

        # Plot training progress
        fine_tuner.plot_training_progress()

        # Evaluate fine-tuned model
        perplexity_tuned = pruning_module.evaluate_perplexity(tuned_params, prompt)
        print(f"Fine-tuned perplexity: {perplexity_tuned:.4f}")

        generated_tuned = pruning_module.generate_text(tuned_params, prompt)
        print(f"Fine-tuned generated: {generated_tuned}")

        self.current_experiment["stages"]["fine_tuned"] = {
            "perplexity": float(perplexity_tuned),
            "perplexity_change_from_baseline": float(perplexity_tuned - perplexity_baseline),
            "perplexity_change_from_pruned": float(perplexity_tuned - perplexity_pruned),
            "generated_text": generated_tuned,
            "training_epochs": fine_tuning_epochs,
            "training_metrics": metrics
        }

        if perplexity_pruned > perplexity_baseline:
            # Share of the pruning-induced perplexity increase that
            # fine-tuning recovered; the increase is positive in this branch.
            perplexity_increase = perplexity_pruned - perplexity_baseline
            perplexity_recovery = perplexity_pruned - perplexity_tuned
            recovery_percentage = (perplexity_recovery / perplexity_increase) * 100

            self.current_experiment["stages"]["fine_tuned"]["recovery_percentage"] = float(recovery_percentage)
            print(f"Recovery percentage: {recovery_percentage:.2f}%")
        else:
            # Pruning did not hurt perplexity, so measure improvement from baseline
            improvement_percentage = ((perplexity_baseline - perplexity_tuned) / perplexity_baseline) * 100

            self.current_experiment["stages"]["fine_tuned"]["improvement_percentage"] = float(improvement_percentage)
            print(f"Improvement percentage: {improvement_percentage:.2f}%")

        # 4. Save results
        print("\n>> Stage 4: Saving results")

        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        result_filename = f"{model.replace('/', '_')}_{strategy}_{pruning_level:.2f}_{timestamp}.json"
        result_path = self.results_dir / result_filename

        # Metrics and head indices may contain NumPy/JAX scalars, which the
        # json module rejects without a converter.
        with open(result_path, "w") as f:
            json.dump(self.current_experiment, f, indent=2, default=_json_default)

        print(f"Results saved to {result_path}")

        # Update DataFrame for plotting
        self._update_dataframe()

        return self.current_experiment
    
    # [rest of the class remains the same]

## Run the Experiment

Now we can run the full experiment:

In [None]:
# Initialize experiment; results are written under ./pruning_finetuning_results
experiment = PruningFineTuningExperiment("pruning_finetuning_results")

In [None]:
# Configuration
STRATEGIES = ["random", "magnitude", "entropy"]  # pruning strategy names
PRUNING_LEVELS = [0.1, 0.3, 0.5]  # fraction of attention heads to prune
PROMPT = "Artificial intelligence will transform society by"  # used for perplexity and generation
FINE_TUNING_EPOCHS = 2  # Small number for quick iterations
MAX_RUNTIME = 6 * 3600  # 6 hours, expressed in seconds

# Start the experiment: every model x strategy x pruning level combination is
# run (in shuffled order) until all are done or MAX_RUNTIME is exceeded.
results = experiment.run_experiment(
    strategies=STRATEGIES,
    pruning_levels=PRUNING_LEVELS,
    prompt=PROMPT,
    fine_tuning_epochs=FINE_TUNING_EPOCHS,
    max_runtime=MAX_RUNTIME
)

## Longer Overnight Run

For an extended overnight run, uncomment and run this cell:

In [None]:
# Overnight Configuration
OVERNIGHT_STRATEGIES = ["random", "magnitude", "entropy"]
OVERNIGHT_PRUNING_LEVELS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
OVERNIGHT_PROMPT = "Artificial intelligence will revolutionize industries by"
OVERNIGHT_FINE_TUNING_EPOCHS = 5  # More epochs for better recovery
OVERNIGHT_MAX_RUNTIME = 24 * 3600  # 24 hours, expressed in seconds

# Initialize experiment for overnight run.
# Note: this line executes even while the run below stays commented out, so the
# "overnight_results" directory is created as soon as the cell runs.
overnight_experiment = PruningFineTuningExperiment("overnight_results")

# Run overnight experiment (uncomment to run)
# overnight_results = overnight_experiment.run_experiment(
#     strategies=OVERNIGHT_STRATEGIES,
#     pruning_levels=OVERNIGHT_PRUNING_LEVELS,
#     prompt=OVERNIGHT_PROMPT,
#     fine_tuning_epochs=OVERNIGHT_FINE_TUNING_EPOCHS,
#     max_runtime=OVERNIGHT_MAX_RUNTIME
# )

## Comprehensive Analysis

After collecting results, run a comprehensive analysis:

In [None]:
# Plot results collected so far, using a larger figure than the in-run plots
experiment.plot_results(figsize=(16, 12))

In [None]:
# Additional analysis: Compare models
# Draws one line per (model, strategy) pair: final perplexity versus pruning level.
# NOTE(review): the "post_finetuning" stage label must match what the experiment
# writes into results_df; that code is not shown in this notebook — confirm.
if not experiment.results_df.empty:
    finetuned_results = experiment.results_df[experiment.results_df["stage"] == "post_finetuning"]
    
    plt.figure(figsize=(10, 8))
    for model in finetuned_results["model"].unique():
        model_data = finetuned_results[finetuned_results["model"] == model]
        
        for strategy in model_data["strategy"].unique():
            strategy_data = model_data[model_data["strategy"] == strategy]
            # Sort so the line is drawn left to right along the x axis
            strategy_data = strategy_data.sort_values("pruning_level")
            
            plt.plot(strategy_data["pruning_level"], strategy_data["perplexity"],
                    marker="o", label=f"{model} - {strategy}")
    
    plt.title("Final Perplexity After Fine-tuning")
    plt.xlabel("Pruning Level")
    plt.ylabel("Perplexity")
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
# Create a summary table: one row per (model, strategy, pruning_level) with the
# perplexity at each stage and the change between stages.
# NOTE(review): the stage labels ("initial", "post_pruning", "post_finetuning")
# must match what the experiment writes into results_df; that code is not shown
# in this notebook — confirm.
summary = pd.DataFrame()
if not experiment.results_df.empty:
    merge_keys = ["model", "strategy", "pruning_level"]
    stage_columns = merge_keys + ["perplexity"]

    # Get data for different stages
    initial = experiment.results_df[experiment.results_df["stage"] == "initial"][stage_columns]
    initial = initial.rename(columns={"perplexity": "initial_perplexity"})

    pruned = experiment.results_df[experiment.results_df["stage"] == "post_pruning"][stage_columns]
    pruned = pruned.rename(columns={"perplexity": "pruned_perplexity"})

    finetuned = experiment.results_df[experiment.results_df["stage"] == "post_finetuning"][stage_columns]
    finetuned = finetuned.rename(columns={"perplexity": "finetuned_perplexity"})

    # Merge dataframes
    summary = pd.merge(initial, pruned, on=merge_keys)
    summary = pd.merge(summary, finetuned, on=merge_keys)

    # Calculate changes (negative values mean perplexity went down)
    summary["pruning_effect"] = summary["pruned_perplexity"] - summary["initial_perplexity"]
    summary["finetuning_effect"] = summary["finetuned_perplexity"] - summary["pruned_perplexity"]
    summary["net_change"] = summary["finetuned_perplexity"] - summary["initial_perplexity"]

# Display summary. This must be the cell's last top-level expression: a bare
# expression nested inside the `if` body is never rendered by the notebook.
summary