# Pruning and Fine-Tuning Benchmark for Google Colab (v1.0.0)

This is the Python script version of our notebook for Google Colab.

Instructions:
1. In Colab, choose File > Upload notebook and upload the notebook (`.ipynb`) version of this file
2. Runtime > Change runtime type > Select GPU or TPU hardware accelerator
3. Run cells to execute pruning and fine-tuning experiments

## Overview

1. **Baseline Evaluation**: Establish the initial model performance
2. **Pruning Phase**: Apply different pruning strategies and evaluate post-pruning performance
3. **Fine-Tuning Phase**: Fine-tune pruned models to recover or improve performance
4. **Analysis**: Compare performance across pruning levels and fine-tuning epochs

The experiment runs until every planned configuration has finished, the `max_runtime` limit is reached, or you interrupt it. The visualizations are updated as each run completes.

## Setup

First, let's install dependencies and clone the repository:

In [ ]:
# Install required packages and make sure HuggingFace datasets is properly installed
!pip install -q jax jaxlib flax transformers matplotlib numpy pandas seaborn tqdm optax
!pip install -q 'datasets>=2.0.0' multiprocess

In [ ]:
# Clone the repository's main branch (this notebook was written against v1.0.0;
# the command below does not pin that version). Don't put it on the Python path yet.
!git clone -b main https://github.com/CambrianTech/sentinel-ai.git
# Don't cd into it yet: the next cell imports HuggingFace `datasets` first so the
# system-installed package is used rather than anything inside the repository.

In [ ]:
# Import huggingface datasets directly before changing directory
# We want to make sure we're using the system package
from datasets import load_dataset
import datasets
print(f"Using datasets from: {datasets.__file__}")

# Now safely change to the repository directory
%cd sentinel-ai

# Import rest of the libraries
import os
import sys
import json
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import JAX/Flax
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

# Import Hugging Face libraries
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Add the current directory to path and import our modules
sys.path.append(".")
from utils.pruning import (
    Environment,
    ResultsManager,
    PruningModule, 
    get_strategy
)

# Set up plotting
plt.style.use('ggplot')
sns.set_theme(style="whitegrid")

In [None]:
# Import libraries
import os
import json
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import JAX/Flax
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

# Import Hugging Face libraries
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Import our pruning library
from utils.pruning import (
    Environment,
    ResultsManager,
    PruningModule, 
    get_strategy
)

# Set up plotting
plt.style.use('ggplot')
sns.set_theme(style="whitegrid")

## Environment Detection

Let's detect our environment capabilities:

In [None]:
# Probe the runtime (Colab, GPU/TPU availability, ...) and print a summary
env = Environment()
env.print_info()

# Report what JAX itself sees: its version, visible devices and default backend
print(f"\nJAX version: {jax.__version__}")
print("Devices: {}".format(jax.devices()))
print("Default backend: " + jax.default_backend())

## Fine-Tuning Implementation

First, let's implement the fine-tuning functionality:

In [ ]:
class FineTuner:
    """Fine-tunes a pruned model to recover performance.

    The wrapped ``pruning_module`` must provide:

    * ``model``: called as ``model(**inputs, params=...)`` and returning an
      object with a ``.logits`` attribute;
    * ``tokenizer``: used for dataset tokenization and padding;
    * ``generate_text(params, prompt, max_length=...)`` and
      ``evaluate_perplexity(params, prompt)`` for periodic evaluation;
    * optionally ``model_name``, used to detect OPT models, which are called
      without the ``train`` argument.

    Training uses Adam on a small slice of a Hugging Face dataset. If the
    dataset cannot be loaded, it falls back to synthetic random-token batches.
    """

    def __init__(self, pruning_module, dataset_name="openwebtext", dataset_config=None,
                 batch_size=4, seed=0):
        """Configure the fine-tuner.

        Args:
            pruning_module: object holding the model, tokenizer and eval helpers.
            dataset_name: dataset name passed to ``load_dataset``.
            dataset_config: optional dataset config name passed to ``load_dataset``.
            batch_size: examples per batch. With several devices it is raised to
                at least the device count and rounded down to a multiple of it.
            seed: seed of the PRNG key used for dropout during training.
        """
        self.pruning_module = pruning_module
        self.dataset_name = dataset_name
        self.dataset_config = dataset_config
        self.batch_size = batch_size
        self.max_seq_length = 128  # Modest sequence length for faster training
        self.train_state = None
        self.metrics_history = []
        # Dropout PRNG key; _train_step splits a fresh sub-key from it every step.
        self._dropout_rng = jax.random.PRNGKey(seed)

        # Detect number of devices
        self.devices = jax.devices()
        self.n_devices = len(self.devices)
        if self.n_devices > 1:
            # NOTE: batches are not sharded across devices here (no pmap); this
            # only keeps the batch size divisible by the device count.
            print(f"Using {self.n_devices} devices for training")
            self.batch_size = max(self.batch_size, self.n_devices)
            # Make batch size divisible by device count
            self.batch_size = (self.batch_size // self.n_devices) * self.n_devices
            print(f"Adjusted batch size to {self.batch_size} for multi-device training")

    def _ensure_pad_token(self):
        """Give the tokenizer a pad token if it has none; return the tokenizer.

        The EOS token is reused when available. Otherwise both pad and EOS are
        set to "[PAD]".
        """
        tokenizer = self.pruning_module.tokenizer
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            else:
                # NOTE(review): "[PAD]" may not be in the vocabulary, in which case
                # pad_token_id may resolve to the unk id or None.
                tokenizer.pad_token = tokenizer.eos_token = "[PAD]"
            print(f"Set pad_token to {tokenizer.pad_token}")
        return tokenizer

    def _prepare_dataset(self):
        """Load, tokenize and batch the fine-tuning dataset.

        Returns:
            A list of batches. Each batch is a dict of numpy arrays
            ``input_ids``, ``attention_mask`` and ``labels``, padded to
            ``max_seq_length``. Batches with no next-token target (all padding)
            are dropped. If loading fails or no usable batch remains, falls back
            to :meth:`_prepare_synthetic_dataset`.
        """
        try:
            # Only load a small portion of the dataset for faster loading
            split = "train[:5000]"
            if self.dataset_config:
                print(f"Loading dataset {self.dataset_name} with config {self.dataset_config}")
                dataset = load_dataset(self.dataset_name, self.dataset_config, split=split)
            else:
                print(f"Loading dataset {self.dataset_name}")
                dataset = load_dataset(self.dataset_name, split=split)

            print(f"Dataset loaded: {len(dataset)} examples")

            tokenizer = self._ensure_pad_token()

            def tokenize_function(examples):
                # Prefer a "text" column; otherwise any column whose name
                # contains "text"; otherwise join all string fields per row.
                if "text" in examples:
                    texts = examples["text"]
                else:
                    keys = examples.keys()
                    text_key = next((k for k in keys if "text" in k.lower()), None)
                    if text_key:
                        texts = examples[text_key]
                    else:
                        n_rows = len(examples[next(iter(keys))])
                        texts = [
                            " ".join(str(examples[k][i]) for k in keys
                                     if isinstance(examples[k][i], str))
                            for i in range(n_rows)
                        ]

                return tokenizer(
                    texts,
                    padding='max_length',
                    truncation=True,
                    max_length=self.max_seq_length,
                    return_tensors="np"
                )

            # Drop every column except numeric/None ones; tokenization adds its own.
            first_row = dataset[0]
            columns_to_remove = [
                col for col in dataset.column_names
                if not (first_row[col] is None or isinstance(first_row[col], (int, float, bool)))
            ]

            tokenized_dataset = dataset.map(
                tokenize_function,
                batched=True,
                num_proc=1,
                remove_columns=columns_to_remove
            )

            batched_dataset = []
            skipped = 0
            n_examples = len(tokenized_dataset)
            for start in range(0, n_examples, self.batch_size):
                batch = tokenized_dataset[start:min(start + self.batch_size, n_examples)]
                input_ids = np.array(batch["input_ids"])
                attention_mask = np.array(batch["attention_mask"])

                # A batch with no real token after position 0 has no
                # next-token target (e.g. wikitext's blank lines); skip it.
                if attention_mask[:, 1:].sum() == 0:
                    skipped += 1
                    continue

                batched_dataset.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    # Labels equal input_ids; the one-token shift happens in the loss.
                    "labels": input_ids.copy(),
                })

            if skipped:
                print(f"Skipped {skipped} batches containing only padding")
            if not batched_dataset:
                print("No usable batches in dataset; falling back to synthetic data for training")
                return self._prepare_synthetic_dataset()

            return batched_dataset

        except Exception as e:
            print(f"Error preparing dataset: {e}")
            print("Falling back to synthetic data for training")
            import traceback
            traceback.print_exc()
            return self._prepare_synthetic_dataset()

    def _prepare_synthetic_dataset(self):
        """Create random-token batches for when dataset loading fails.

        Builds 100 sequences of length 10..max_seq_length-1 that contain no
        special tokens. They are grouped into batches of ``batch_size`` and
        right-padded within each batch (``attention_mask`` is 0 on padding).
        """
        tokenizer = self._ensure_pad_token()

        # Generate random token IDs (avoid special tokens)
        vocab_size = tokenizer.vocab_size

        special_tokens = set()
        for token_name in ['pad_token_id', 'eos_token_id', 'bos_token_id', 'unk_token_id']:
            token_id = getattr(tokenizer, token_name, None)
            if token_id is not None:
                special_tokens.add(token_id)

        print(f"Creating synthetic dataset with vocab_size={vocab_size}, special_tokens={special_tokens}")

        samples = []
        for _ in range(100):
            length = np.random.randint(10, self.max_seq_length)
            token_ids = np.random.randint(0, vocab_size, size=length)

            # Bump any special token to the next non-special id
            for i, token_id in enumerate(token_ids):
                if token_id in special_tokens:
                    token_ids[i] = (token_id + 1) % vocab_size
                    while token_ids[i] in special_tokens:
                        token_ids[i] = (token_ids[i] + 1) % vocab_size

            samples.append({
                "input_ids": token_ids,
                "attention_mask": np.ones_like(token_ids),
                "labels": token_ids.copy()
            })

        # Padding value for ids/labels. The loss masks padding via attention_mask,
        # so 0 is a safe fallback if the tokenizer has no pad id.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

        batches = []
        for i in range(0, len(samples), self.batch_size):
            batch_samples = samples[i:i + self.batch_size]

            # Pad to the longest sample in the batch, capped at max_seq_length
            max_len = min(max(len(s["input_ids"]) for s in batch_samples), self.max_seq_length)

            columns = {"input_ids": [], "attention_mask": [], "labels": []}
            for sample in batch_samples:
                input_ids = sample["input_ids"][:max_len]
                attention_mask = sample["attention_mask"][:max_len]
                labels = sample["labels"][:max_len]

                pad_len = max_len - len(input_ids)
                if pad_len > 0:
                    input_ids = np.pad(input_ids, (0, pad_len), constant_values=pad_id)
                    attention_mask = np.pad(attention_mask, (0, pad_len), constant_values=0)
                    labels = np.pad(labels, (0, pad_len), constant_values=pad_id)

                columns["input_ids"].append(input_ids)
                columns["attention_mask"].append(attention_mask)
                columns["labels"].append(labels)

            batches.append({key: np.array(rows) for key, rows in columns.items()})

        print(f"Created {len(batches)} synthetic batches")
        return batches

    def _create_train_state(self, params, learning_rate=5e-5):
        """Create an Adam ``TrainState`` around ``params``."""
        optimizer = optax.adam(learning_rate)
        model = self.pruning_module.model
        return TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=optimizer
        )

    @staticmethod
    def _masked_lm_loss(logits, labels, mask):
        """Mean next-token cross-entropy over target positions where ``mask`` is set.

        Args:
            logits: (batch, seq, vocab) model outputs.
            labels: (batch, seq) integer token ids, unshifted.
            mask: (batch, seq); nonzero for real tokens, 0 for padding.

        Returns:
            Scalar loss. Returns 0.0 when no position is unmasked, instead of the
            NaN that ``sum / count`` would give and that would poison the params.
        """
        # Position t predicts token t+1
        shift_logits = logits[:, :-1]
        shift_labels = jnp.asarray(labels)[:, 1:]
        shift_mask = jnp.asarray(mask)[:, 1:].astype(shift_logits.dtype)

        log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
        token_nll = -jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
        return (token_nll * shift_mask).sum() / jnp.maximum(shift_mask.sum(), 1.0)

    def _loss_fn(self, params, batch, dropout_rng=None):
        """Causal language-modeling loss of ``params`` on ``batch``.

        ``batch`` is not modified. ``labels`` is read from it but not forwarded
        to the model. Padding is excluded using ``attention_mask`` when present;
        otherwise positions whose label equals ``pad_token_id`` are excluded.

        Args:
            params: model parameters to differentiate.
            batch: dict with ``input_ids``, ``labels`` and optionally ``attention_mask``.
            dropout_rng: PRNG key. When given (and the model is not OPT), the
                model runs with ``train=True`` and this key drives dropout.
                When omitted, the model runs deterministically.
        """
        model = self.pruning_module.model
        labels = batch["labels"]
        model_inputs = {key: value for key, value in batch.items() if key != "labels"}

        mask = batch.get("attention_mask")
        if mask is None:
            mask = labels != self.pruning_module.tokenizer.pad_token_id

        model_name = getattr(self.pruning_module, "model_name", "") or ""
        if "opt" in model_name.lower():
            # OPT models don't accept a 'train' parameter; call with defaults
            outputs = model(**model_inputs, params=params)
        elif dropout_rng is not None:
            # train=True enables dropout, which needs an explicit PRNG key
            outputs = model(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)
        else:
            outputs = model(**model_inputs, params=params)

        return self._masked_lm_loss(outputs.logits, labels, mask)

    def _train_step(self, state, batch):
        """Run one optimizer step; return ``(new_state, loss)``."""
        rng = getattr(self, "_dropout_rng", None)
        if rng is None:
            rng = jax.random.PRNGKey(0)
        # Fresh dropout key per step; keep the other half for next time
        self._dropout_rng, step_rng = jax.random.split(rng)

        loss, grads = jax.value_and_grad(self._loss_fn)(state.params, batch, step_rng)
        return state.apply_gradients(grads=grads), loss

    def fine_tune(self, pruned_params, num_epochs=1, learning_rate=5e-5, evaluate_interval=5):
        """Fine-tune the pruned model.

        Args:
            pruned_params: starting parameters.
            num_epochs: passes over the prepared dataset (shuffled each epoch).
            learning_rate: Adam learning rate.
            evaluate_interval: run text generation and prompt perplexity every
                this many successful steps. 0/None disables it.

        Returns:
            ``(params, metrics_history)``. Each metrics entry holds ``epoch``,
            the mean ``loss`` of successful steps (0 if none succeeded) and that
            epoch's ``perplexity_history`` as a list of ``(global_step, perplexity)``.
            Steps that raise are reported and skipped.
        """
        print(f"\nFine-tuning model with {self.dataset_name} dataset for {num_epochs} epochs...")

        dataset = self._prepare_dataset()

        self.train_state = self._create_train_state(pruned_params, learning_rate)
        self.metrics_history = []

        total_steps = 0
        prompt = "Artificial intelligence will transform"

        for epoch in range(num_epochs):
            if isinstance(dataset, list):
                np.random.shuffle(dataset)
                epoch_dataset = dataset
            else:
                # A datasets.Dataset provides its own shuffle
                epoch_dataset = dataset.shuffle()

            epoch_desc = f"Epoch {epoch+1}/{num_epochs}"
            total = len(epoch_dataset) if hasattr(epoch_dataset, "__len__") else None
            progress_bar = tqdm(epoch_dataset, desc=epoch_desc, total=total)

            epoch_losses = []
            epoch_perplexities = []  # per-epoch, so points are never duplicated
            failed_steps = 0

            for batch in progress_bar:
                try:
                    self.train_state, loss = self._train_step(self.train_state, batch)
                except Exception as e:
                    failed_steps += 1
                    print(f"Error in training step: {e}")
                    continue

                total_steps += 1
                loss_value = float(loss)
                epoch_losses.append(loss_value)
                progress_bar.set_description(f"{epoch_desc} - Loss: {loss_value:.4f}")

                if evaluate_interval and total_steps % evaluate_interval == 0:
                    try:
                        generated = self.pruning_module.generate_text(
                            self.train_state.params, prompt, max_length=30
                        )
                        perplexity = self.pruning_module.evaluate_perplexity(
                            self.train_state.params, prompt
                        )
                        epoch_perplexities.append((total_steps, perplexity))
                        print(f"\nStep {total_steps} - Perplexity: {perplexity:.4f}")
                        print(f"Generated: {generated}")
                    except Exception as e:
                        print(f"Error evaluating model: {e}")

            if failed_steps:
                print(f"\nWarning: {failed_steps} training step(s) failed in epoch {epoch+1}")

            epoch_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0
            print(f"\nEpoch {epoch+1} completed. Average loss: {epoch_loss:.4f}")

            self.metrics_history.append({
                "epoch": epoch + 1,
                "loss": epoch_loss,
                "perplexity_history": epoch_perplexities
            })

        print("\nFine-tuning completed!")
        return self.train_state.params, self.metrics_history

    def plot_training_progress(self, figsize=(10, 5)):
        """Plot per-epoch loss and the periodic perplexity measurements.

        Returns the matplotlib figure, or None if nothing has been recorded.
        The styling is applied only to this figure and does not leak into
        global rcParams.
        """
        if not self.metrics_history:
            print("No training metrics available yet")
            return None

        style = {
            'figure.figsize': figsize,
            'figure.titlesize': 14,
            'axes.titlesize': 12,
            'axes.labelsize': 11,
            'xtick.labelsize': 10,
            'ytick.labelsize': 10,
            'legend.fontsize': 9,
            'font.family': 'sans-serif'
        }

        epochs = [m["epoch"] for m in self.metrics_history]
        losses = [m["loss"] for m in self.metrics_history]

        # Concatenate the (step, perplexity) points recorded in each epoch
        points = [pt for m in self.metrics_history for pt in m.get("perplexity_history", [])]
        steps = [step for step, _ in points]
        perplexities = [ppl for _, ppl in points]

        with plt.rc_context(style):
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

            ax1.plot(epochs, losses, "o-", color="blue")
            ax1.set_xlabel("Epoch")
            ax1.set_ylabel("Loss")
            ax1.set_title("Training Loss")
            ax1.grid(True, linestyle="--", alpha=0.7)

            if steps:
                ax2.plot(steps, perplexities, "o-", color="green")
                ax2.set_xlabel("Step")
                ax2.set_ylabel("Perplexity")
                ax2.set_title("Perplexity During Training")
                ax2.grid(True, linestyle="--", alpha=0.7)
            else:
                ax2.text(0.5, 0.5, "No perplexity data available",
                         ha="center", va="center", fontsize=12)

            plt.tight_layout()
            plt.show()

        return fig

## Experiment Manager

Let's create an experiment manager to run the full experiment:

In [ ]:
class PruningFineTuningExperiment:
    """Manages the pruning + fine-tuning experiment"""
    
    def __init__(self, results_dir="pruning_finetuning_results"):
        """Create the results directory, probe the environment and set model limits.

        ``self.model_size_limits`` maps a model name (full "org/name" or short
        name) to the highest pruning level allowed for it. 1000 means "any level".
        Larger models are added only when a GPU/TPU (and, for OPT-2.7B, high
        RAM) is available.
        """
        # Where per-experiment JSON result files are written
        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.results = []
        self.current_experiment = {}

        # Probe the runtime and list the models that fit it
        self.env = Environment()
        self.available_models = self.env.get_suitable_models()
        print("Models available: " + ", ".join(self.available_models))

        self.results_manager = ResultsManager(str(self.results_dir))
        self.results_df = pd.DataFrame()

        # Small models are always allowed at every pruning level
        no_limit = 1000
        always_allowed = (
            "gpt2",               # GPT-2 (124M params)
            "gpt2-medium",        # GPT-2 Medium (355M params)
            "opt-350m",
            "opt-125m",
            "facebook/opt-125m",
            "facebook/opt-350m",
        )
        limits = {name: no_limit for name in always_allowed}

        has_accelerator = self.env.has_gpu or self.env.has_tpu
        if has_accelerator:
            # Medium-sized models; OPT-1.3B only at low pruning levels
            limits["gpt2-large"] = no_limit
            limits["facebook/opt-1.3b"] = 0.3
            if self.env.has_high_ram:
                # OPT-2.7B only with plenty of RAM, at low pruning levels
                limits["facebook/opt-2.7b"] = 0.2

        self.model_size_limits = limits
    
    def run_experiment(self, strategies, pruning_levels, prompt, fine_tuning_epochs=1, max_runtime=3600):
        """Run the full experiment"""
        if not self.available_models:
            print("No suitable models found for this environment")
            return
        
        # Start time for runtime tracking
        start_time = time.time()
        
        # Generate all experiment combinations
        experiments = []
        for model in self.available_models:
            for strategy in strategies:
                for level in pruning_levels:
                    # Skip model/pruning combinations that would exceed memory limits
                    model_key = model.split('/')[-1] if '/' in model else model
                    model_size_limit = self.model_size_limits.get(model, self.model_size_limits.get(model_key, 0.0))
                    
                    if level > model_size_limit:
                        print(f"Skipping {model} with pruning level {level:.2f} - exceeds memory limits")
                        continue
                        
                    experiments.append({
                        "model": model,
                        "strategy": strategy,
                        "pruning_level": level,
                        "prompt": prompt,
                        "fine_tuning_epochs": fine_tuning_epochs
                    })
        
        # Shuffle to get more diverse results early
        random.shuffle(experiments)
        
        # Create progress bar
        pbar = tqdm(total=len(experiments), desc="Running experiments")
        
        # Run experiments
        for i, exp in enumerate(experiments):
            # Check if we've exceeded the runtime limit
            current_runtime = time.time() - start_time
            if max_runtime is not None and current_runtime > max_runtime:
                print(f"\nReached maximum runtime of {max_runtime/3600:.1f} hours")
                break
                
            # Update progress bar
            pbar.set_description(f"Testing {exp['model']}, {exp['strategy']}, {exp['pruning_level']:.2f}")
            
            # Run experiment
            try:
                result = self.run_single_experiment(**exp)
                if result is not None:
                    self.results.append(result)
                
                # Update progress bar
                pbar.update(1)
                
                # Plot intermediate results every few experiments
                if (i + 1) % 1 == 0 or i == len(experiments) - 1:
                    self.plot_results()
            except Exception as e:
                print(f"Error in experiment {exp['model']}, {exp['strategy']}, {exp['pruning_level']:.2f}: {e}")
                import traceback
                traceback.print_exc()
                # Still update progress bar
                pbar.update(1)
        
        # Close progress bar
        pbar.close()
        
        # Final results
        print(f"\nCompleted {len(self.results)} experiments out of {len(experiments)} attempted")
        runtime = time.time() - start_time
        print(f"Total runtime: {runtime/3600:.2f} hours ({runtime/60:.2f} minutes)")
        
        # Plot final results
        self.plot_results()
        
        return self.results
    
    def run_single_experiment(self, model, strategy, pruning_level, prompt, fine_tuning_epochs=1):
        """Run a single experiment with pruning and fine-tuning"""
        print(f"\n{'='*80}")
        print(f"Experiment: {model}, {strategy} strategy, {pruning_level:.2f} pruning level")
        print(f"{'='*80}")
        
        # Initialize pruning module
        pruning_module = PruningModule(model)
        if not pruning_module.load_model():
            print(f"Failed to load model {model}")
            return None
        
        # Store model name in the module for architecture detection
        pruning_module.model_name = model
        
        # Safety check - if model is too large, skip
        if "opt-1.3b" in model.lower() and pruning_level < 0.3:
            print(f"WARNING: {model} with pruning level {pruning_level:.2f} may cause memory issues")
        
        # Setup experiment record
        self.current_experiment = {
            "model": model,
            "strategy": strategy,
            "pruning_level": pruning_level,
            "prompt": prompt,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "stages": {}
        }
        
        # 1. Evaluate baseline model
        print("\n>> Stage 1: Evaluating baseline model")
        original_params = pruning_module.original_params
        
        # Evaluate perplexity and generation
        perplexity_baseline = pruning_module.evaluate_perplexity(original_params, prompt)
        print(f"Baseline perplexity: {perplexity_baseline:.4f}")
        
        generated_baseline = pruning_module.generate_text(original_params, prompt)
        print(f"Baseline generated: {generated_baseline}")
        
        # Record baseline results
        self.current_experiment["stages"]["baseline"] = {
            "perplexity": float(perplexity_baseline),
            "generated_text": generated_baseline
        }
        
        # 2. Apply pruning
        print("\n>> Stage 2: Applying pruning")
        pruning_strat = get_strategy(strategy, pruning_module, prompt)
        
        # Calculate importance scores
        print("Calculating head importance...")
        all_head_importance = pruning_strat.get_head_importance(original_params)
        
        # Sort by importance (ascending)
        all_head_importance.sort(key=lambda x: x[2])
        
        # Determine number of heads to prune
        total_heads = pruning_module.num_layers * pruning_module.num_heads
        heads_to_prune = int(total_heads * pruning_level)
        print(f"Pruning {heads_to_prune} out of {total_heads} heads")
        
        # Get head indices to prune (least important first)
        head_indices = [(l, h) for l, h, _ in all_head_importance[:heads_to_prune]]
        
        # Prune heads
        print("Pruning heads...")
        pruned_params = pruning_strat.prune_heads(original_params, head_indices)
        
        # Evaluate after pruning
        perplexity_pruned = pruning_module.evaluate_perplexity(pruned_params, prompt)
        print(f"Pruned perplexity: {perplexity_pruned:.4f}")
        
        generated_pruned = pruning_module.generate_text(pruned_params, prompt)
        print(f"Pruned generated: {generated_pruned}")
        
        # Record pruning results
        self.current_experiment["stages"]["pruned"] = {
            "perplexity": float(perplexity_pruned),
            "perplexity_change": float(perplexity_pruned - perplexity_baseline),
            "generated_text": generated_pruned,
            "pruned_heads": heads_to_prune,
            "total_heads": total_heads,
            "head_indices": head_indices
        }
        
        # 3. Fine-tune the pruned model
        print("\n>> Stage 3: Fine-tuning the pruned model")
        
        # Create fine-tuner - use specific wikitext config and OpenWebText as fallback
        dataset_name = "wikitext-2-v1"  # Specify the config name
        dataset_config = "wikitext"
        
        # Adjust batch size based on model size and available hardware
        if self.env.in_colab and self.env.has_tpu:
            # TPUs can handle larger batch sizes
            batch_size = 16
        elif self.env.in_colab and self.env.has_gpu:
            batch_size = 8
            # Reduce batch size for larger models
            if "1.3b" in model.lower() or "large" in model.lower():
                batch_size = 4
        else:
            batch_size = 4
            
        fine_tuner = FineTuner(
            pruning_module, 
            dataset_name=dataset_config, 
            dataset_config=dataset_name, 
            batch_size=batch_size
        )
        
        # Fine-tune model
        try:
            tuned_params, metrics = fine_tuner.fine_tune(
                pruned_params, 
                num_epochs=fine_tuning_epochs,
                learning_rate=5e-5,
                evaluate_interval=5
            )
            
            # Plot training progress
            fine_tuner.plot_training_progress()
            
            # Evaluate fine-tuned model
            perplexity_tuned = pruning_module.evaluate_perplexity(tuned_params, prompt)
            print(f"Fine-tuned perplexity: {perplexity_tuned:.4f}")
            
            generated_tuned = pruning_module.generate_text(tuned_params, prompt)
            print(f"Fine-tuned generated: {generated_tuned}")
            
            # Record fine-tuning results
            self.current_experiment["stages"]["fine_tuned"] = {
                "perplexity": float(perplexity_tuned),
                "perplexity_change_from_baseline": float(perplexity_tuned - perplexity_baseline),
                "perplexity_change_from_pruned": float(perplexity_tuned - perplexity_pruned),
                "generated_text": generated_tuned,
                "training_epochs": fine_tuning_epochs,
                "training_metrics": metrics
            }
            
            # Compute recovery percentage
            if perplexity_pruned > perplexity_baseline:
                # Calculate how much of the perplexity increase was recovered
                perplexity_increase = perplexity_pruned - perplexity_baseline
                perplexity_recovery = perplexity_pruned - perplexity_tuned
                recovery_percentage = (perplexity_recovery / perplexity_increase) * 100 if perplexity_increase > 0 else 0
                
                self.current_experiment["stages"]["fine_tuned"]["recovery_percentage"] = float(recovery_percentage)
                print(f"Recovery percentage: {recovery_percentage:.2f}%")
            else:
                # Pruning improved perplexity, so we measure improvement from baseline
                improvement_percentage = ((perplexity_baseline - perplexity_tuned) / perplexity_baseline) * 100
                
                self.current_experiment["stages"]["fine_tuned"]["improvement_percentage"] = float(improvement_percentage)
                print(f"Improvement percentage: {improvement_percentage:.2f}%")
        
        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            # Continue with partial results
            import traceback
            traceback.print_exc()
        
        # 4. Save results
        print("\n>> Stage 4: Saving results")
        
        # Save to disk
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        result_filename = f"{model.replace('/', '_')}_{strategy}_{pruning_level:.2f}_{timestamp}.json"
        result_path = self.results_dir / result_filename
        
        import json
        with open(result_path, "w") as f:
            json.dump(self.current_experiment, f, indent=2)
            
        print(f"Results saved to {result_path}")
        
        # Update DataFrame for plotting
        self._update_dataframe()
        
        return self.current_experiment
    
    def _update_dataframe(self):
        """Update DataFrame for visualization"""
        # Extract data for DataFrame
        data = []
        
        for result in self.results:
            # Extract model and strategy info
            model = result["model"]
            strategy = result["strategy"]
            pruning_level = result["pruning_level"]
            
            # Add baseline stage
            if "baseline" in result["stages"]:
                baseline = result["stages"]["baseline"]
                data.append({
                    "model": model,
                    "strategy": strategy,
                    "pruning_level": pruning_level,
                    "stage": "baseline",
                    "perplexity": baseline["perplexity"]
                })
            
            # Add pruned stage
            if "pruned" in result["stages"]:
                pruned = result["stages"]["pruned"]
                data.append({
                    "model": model,
                    "strategy": strategy,
                    "pruning_level": pruning_level,
                    "stage": "pruned",
                    "perplexity": pruned["perplexity"],
                    "perplexity_change": pruned.get("perplexity_change", 0)
                })
                
            # Add fine-tuned stage
            if "fine_tuned" in result["stages"]:
                fine_tuned = result["stages"]["fine_tuned"]
                data.append({
                    "model": model,
                    "strategy": strategy,
                    "pruning_level": pruning_level,
                    "stage": "fine_tuned",
                    "perplexity": fine_tuned["perplexity"],
                    "perplexity_change_from_baseline": fine_tuned.get("perplexity_change_from_baseline", 0),
                    "perplexity_change_from_pruned": fine_tuned.get("perplexity_change_from_pruned", 0),
                    "recovery_percentage": fine_tuned.get("recovery_percentage", None),
                    "improvement_percentage": fine_tuned.get("improvement_percentage", None)
                })
        
        self.results_df = pd.DataFrame(data)
    
    def plot_results(self, figsize=(10, 8)):
        """Plot a 2x2 overview of all experiments recorded in ``self.results``.

        Panels:
          1. Perplexity at each stage (baseline, pruned, fine_tuned), one line
             per (model, strategy, pruning level) with at least two stages.
          2. Recovery percentage vs pruning level. Runs without a recovery
             value are shown as the negated improvement percentage, so
             negative values mean the fine-tuned model beat the baseline.
          3. Perplexity change from pruning vs perplexity change from
             fine-tuning, one point per experiment run.
          4. Mean final perplexity change from baseline per pruning level and
             strategy.

        Side effect: updates the global ``plt.rcParams`` styling.

        Args:
            figsize: Figure size in inches.

        Returns:
            The matplotlib figure, or ``None`` when there is nothing to plot.
        """
        if not self.results:
            print("No results available yet")
            return

        # Rebuild the flattened per-stage DataFrame from the raw records
        self._update_dataframe()

        if self.results_df.empty:
            print("No data available for plotting")
            return

        # Set plot styling to fix layout issues
        plt.rcParams.update({
            'figure.figsize': figsize,
            'figure.titlesize': 14,
            'axes.titlesize': 11,
            'axes.labelsize': 10,
            'xtick.labelsize': 9,
            'ytick.labelsize': 9,
            'legend.fontsize': 7,
            'legend.title_fontsize': 8,
            'font.family': 'sans-serif'
        })

        # Create figure
        fig = plt.figure(figsize=figsize)

        # 1. Perplexity across stages by model and strategy
        plt.subplot(2, 2, 1)

        # Get unique models and strategies
        models = self.results_df["model"].unique()
        strategies = self.results_df["strategy"].unique()

        # For display, shorten model names
        model_display = {m: m.split('/')[-1] if '/' in m else m for m in models}

        # Filter to main stages
        stages_df = self.results_df[self.results_df["stage"].isin(["baseline", "pruned", "fine_tuned"])]
        stage_order = {"baseline": 0, "pruned": 1, "fine_tuned": 2}

        # Plot lines connecting stages for each experiment configuration
        for model in models:
            model_df = stages_df[stages_df["model"] == model]

            for strategy in strategies:
                strategy_df = model_df[model_df["strategy"] == strategy]

                for pruning_level in strategy_df["pruning_level"].unique():
                    experiment_df = strategy_df[strategy_df["pruning_level"] == pruning_level]

                    # Sort by stage to ensure correct order
                    experiment_df = experiment_df.sort_values(by="stage", key=lambda x: x.map(stage_order))

                    # Need at least two stages to draw a connecting line
                    if len(experiment_df) >= 2:
                        label = f"{model_display[model][:6]}, {strategy[:3]}, {pruning_level:.1f}"
                        plt.plot(experiment_df["stage"], experiment_df["perplexity"], "o-", label=label)

        plt.title("Perplexity Across Stages")
        plt.xlabel("Stage")
        plt.ylabel("Perplexity")
        plt.xticks(rotation=45)
        plt.legend(fontsize=7, loc='best', ncol=2)
        plt.grid(True, alpha=0.3)

        # 2. Recovery percentage vs pruning level
        plt.subplot(2, 2, 2)

        # Get data with recovery information
        recovery_df = self.results_df[self.results_df["stage"] == "fine_tuned"].copy()

        if not recovery_df.empty:
            # Create recovery column (combining both metrics)
            recovery_df["recovery"] = recovery_df["recovery_percentage"]
            # If improvement percentage exists and recovery is NaN, use negative of improvement
            mask = recovery_df["recovery"].isna() & recovery_df["improvement_percentage"].notna()
            recovery_df.loc[mask, "recovery"] = -recovery_df.loc[mask, "improvement_percentage"]

            # Plot by strategy
            for strategy in strategies:
                strategy_df = recovery_df[recovery_df["strategy"] == strategy]
                if not strategy_df.empty:
                    for model in models:
                        model_strategy_df = strategy_df[strategy_df["model"] == model]
                        if not model_strategy_df.empty:
                            # Sort by pruning level
                            model_strategy_df = model_strategy_df.sort_values("pruning_level")
                            plt.plot(model_strategy_df["pruning_level"], model_strategy_df["recovery"],
                                    "o-", label=f"{model_display[model][:6]}, {strategy[:3]}")

            plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
            plt.axhline(y=100, color="g", linestyle="--", alpha=0.3)
            plt.text(0.01, 100, "Full Recovery", color="green", ha="left", va="bottom", fontsize=8)
            plt.text(0.01, -5, "Improvement", color="blue", ha="left", va="top", fontsize=8)

            plt.title("Recovery Percentage by Pruning Level")
            plt.xlabel("Pruning Level")
            plt.ylabel("Recovery % (negative means improvement)")
            plt.legend(fontsize=7, loc='best')
            plt.grid(True, alpha=0.3)
        else:
            plt.text(0.5, 0.5, "No recovery data available yet",
                    ha="center", va="center", fontsize=12)

        # 3. Perplexity change: pruning vs fine-tuning effect
        plt.subplot(2, 2, 3)

        # Pair the two effects per run straight from the raw records. Merging the
        # flattened DataFrame on (model, strategy, pruning_level) would cross-join
        # repeated runs of the same configuration and plot pairs that never
        # occurred together.
        effects_df = pd.DataFrame([
            {
                "model": result["model"],
                "strategy": result["strategy"],
                "pruning_level": result["pruning_level"],
                "perplexity_change": result["stages"]["pruned"].get("perplexity_change", 0),
                "perplexity_change_from_pruned": result["stages"]["fine_tuned"].get("perplexity_change_from_pruned", 0),
            }
            for result in self.results
            if "pruned" in result["stages"] and "fine_tuned" in result["stages"]
        ])

        if not effects_df.empty:
            # Plot scatter with size based on pruning level
            for strategy in strategies:
                strategy_df = effects_df[effects_df["strategy"] == strategy]
                if not strategy_df.empty:
                    for model in models:
                        model_df = strategy_df[strategy_df["model"] == model]
                        if not model_df.empty:
                            plt.scatter(
                                model_df["perplexity_change"],
                                model_df["perplexity_change_from_pruned"],
                                s=model_df["pruning_level"] * 300,  # Size based on pruning level
                                label=f"{model_display[model][:6]}, {strategy[:3]}",
                                alpha=0.7
                            )

            plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
            plt.axvline(x=0, color="k", linestyle="--", alpha=0.3)

            def symmetric_limit(values):
                # Half-width of a zero-centred axis that fits every value with a margin
                largest = values.abs().max()
                if pd.isna(largest) or largest == 0:
                    return 1.0
                return float(largest) * 1.15

            # Centre both axes on zero so each quadrant is exactly one quarter of
            # the panel; the labels can then be placed in axes-fraction
            # coordinates (text does not take part in autoscaling, so fixed data
            # coordinates could fall outside the visible range).
            ax = plt.gca()
            x_limit = symmetric_limit(effects_df["perplexity_change"])
            y_limit = symmetric_limit(effects_df["perplexity_change_from_pruned"])
            ax.set_xlim(-x_limit, x_limit)
            ax.set_ylim(-y_limit, y_limit)

            # Add quadrant labels (smaller font); negative change = lower perplexity
            quadrant_labels = [
                (0.25, 0.25, "Both improved", "lightgreen"),
                (0.75, 0.25, "Pruning hurt,\nFine-tuning fixed", "lightblue"),
                (0.25, 0.75, "Pruning helped,\nFine-tuning hurt", "lightyellow"),
                (0.75, 0.75, "Both hurt", "lightcoral"),
            ]
            for x_frac, y_frac, text, colour in quadrant_labels:
                ax.text(x_frac, y_frac, text, transform=ax.transAxes, fontsize=8,
                        ha="center", va="center", bbox=dict(facecolor=colour, alpha=0.5))

            plt.title("Effect of Pruning vs. Fine-tuning")
            plt.xlabel("Perplexity Change from Pruning")
            plt.ylabel("Perplexity Change from Fine-tuning")
            plt.legend(fontsize=7, loc='best')
            plt.grid(True, alpha=0.3)
        else:
            plt.text(0.5, 0.5, "No effect data available yet",
                    ha="center", va="center", fontsize=12)

        # 4. Final results: perplexity change from baseline by pruning level and strategy
        plt.subplot(2, 2, 4)

        final_df = self.results_df[self.results_df["stage"] == "fine_tuned"]

        if not final_df.empty:
            # Average the fine-tuned rows directly: each already carries its change
            # from baseline. Joining with the baseline rows on the configuration
            # keys would cross-multiply repeated runs and skew the group means.
            grouped = final_df.groupby(["pruning_level", "strategy"])["perplexity_change_from_baseline"].mean().reset_index()

            # Pivot for grouped bar chart
            pivot_df = grouped.pivot(index="pruning_level", columns="strategy", values="perplexity_change_from_baseline")

            # Plot
            pivot_df.plot(kind="bar", ax=plt.gca())

            plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
            plt.title("Final Perplexity Change from Baseline")
            plt.xlabel("Pruning Level")
            plt.ylabel("Perplexity Change")
            plt.legend(title="Strategy", fontsize=7)
            plt.grid(True, alpha=0.3, axis="y")
        else:
            plt.text(0.5, 0.5, "No final results available yet",
                    ha="center", va="center", fontsize=12)

        # Apply tight layout to reduce white space
        plt.tight_layout(pad=1.5)
        plt.subplots_adjust(bottom=0.15)
        plt.show()

        return fig

## Run the Experiment

Now we can run the full experiment:

In [None]:
# Initialize experiment
# NOTE(review): the argument is presumably the results directory; it matches the
# default results_dir read by visualize_ongoing_experiments() below.
experiment = PruningFineTuningExperiment("pruning_finetuning_results")

In [None]:
# Configuration
PROMPT = "Artificial intelligence will transform society by"
STRATEGIES = ["random", "magnitude", "entropy"]
PRUNING_LEVELS = [0.1, 0.3, 0.5]
FINE_TUNING_EPOCHS = 2  # Kept small so each iteration finishes quickly
MAX_RUNTIME = 6 * 60 * 60  # Runtime budget in seconds: 6 hours

# Start the experiment with the settings above
results = experiment.run_experiment(
    prompt=PROMPT,
    strategies=STRATEGIES,
    pruning_levels=PRUNING_LEVELS,
    max_runtime=MAX_RUNTIME,
    fine_tuning_epochs=FINE_TUNING_EPOCHS,
)

## Longer Overnight Run

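For an extended overnight run, use the "Overnight Configuration" cell further below (after the head importance visualization). It sweeps more pruning levels and allows up to 20 hours of runtime.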

## Real-time Experiment Monitoring

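The cell below reads the saved result files from disk, so it can be run independently of the experiment objects — for example between runs, or after interrupting a long experiment — to visualize the current state of the experiments.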

In [ ]:
# This cell can be run at any time to visualize current experiment progress
import os
import glob
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

def _stage_perplexity(result, stage):
    """Return the perplexity recorded for ``stage`` in a result dict, or None."""
    return result.get("stages", {}).get(stage, {}).get("perplexity", None)


def _load_result_files(result_files):
    """Load the given JSON result files, reporting and skipping unusable ones.

    A file is skipped when it cannot be read or parsed, or when its top-level
    value is not a JSON object (later steps read each result with ``dict.get``).
    """
    results = []
    for file_path in result_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                result = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            print(f"Error loading {file_path}: {e}")
            continue
        if not isinstance(result, dict):
            print(f"Error loading {file_path}: expected a JSON object, got {type(result).__name__}")
            continue
        results.append(result)
    return results


def _results_to_stage_rows(results):
    """Flatten result dicts into one row per stage that recorded a perplexity.

    Model names are shortened to the part after the last '/' for display.
    Fine-tuned rows additionally carry the recovery / improvement percentages.
    """
    data = []
    for result in results:
        model = result.get("model", "unknown")
        if '/' in model:
            model = model.split('/')[-1]
        strategy = result.get("strategy", "unknown")
        pruning_level = result.get("pruning_level", 0)
        timestamp = result.get("timestamp", "unknown")
        fine_tuned = result.get("stages", {}).get("fine_tuned", {})

        for stage in ("baseline", "pruned", "fine_tuned"):
            perplexity = _stage_perplexity(result, stage)
            if perplexity is None:
                continue
            row = {
                "model": model,
                "strategy": strategy,
                "pruning_level": pruning_level,
                "stage": stage,
                "perplexity": perplexity,
            }
            if stage == "fine_tuned":
                row["recovery_percentage"] = fine_tuned.get("recovery_percentage", None)
                row["improvement_percentage"] = fine_tuned.get("improvement_percentage", None)
            row["timestamp"] = timestamp
            data.append(row)
    return data


def _count_run_status(results):
    """Count runs by the furthest stage that recorded a perplexity.

    Every loaded result is counted exactly once, so the counts always sum to
    ``len(results)`` and can never be negative.
    """
    counts = {"Completed": 0, "Pruned": 0, "Baseline": 0, "No data": 0}
    for result in results:
        if _stage_perplexity(result, "fine_tuned") is not None:
            counts["Completed"] += 1
        elif _stage_perplexity(result, "pruned") is not None:
            counts["Pruned"] += 1
        elif _stage_perplexity(result, "baseline") is not None:
            counts["Baseline"] += 1
        else:
            counts["No data"] += 1
    return counts


def visualize_ongoing_experiments(results_dir="pruning_finetuning_results", figsize=(10, 8)):
    """
    Create a real-time visualization of ongoing experiments.

    Reads every ``*.json`` result file in ``results_dir`` (so it can be run
    independently of the experiment object) and draws a 2x2 figure:
    perplexity across stages, perplexity vs pruning level, recovery /
    improvement percentage, and a per-run status overview.

    Args:
        results_dir: Directory containing the saved result JSON files.
        figsize: Figure size in inches (also written into ``plt.rcParams``).

    Returns:
        A DataFrame with one row per (run, stage) that recorded a perplexity,
        or ``None`` when there is nothing to plot.
    """
    # Check if results directory exists
    if not os.path.exists(results_dir):
        print(f"Results directory '{results_dir}' not found")
        return

    # List all result files
    result_files = glob.glob(os.path.join(results_dir, "*.json"))

    if not result_files:
        print(f"No result files found in '{results_dir}'")
        return

    results = _load_result_files(result_files)

    if not results:
        print("No valid result files found")
        return

    print(f"Found {len(results)} experiment results")

    # Convert to dataframe
    df = pd.DataFrame(_results_to_stage_rows(results))

    if df.empty:
        print("No valid data extracted from results")
        return

    # Only touch global plot styling once there is something to draw,
    # and honour the caller's figsize.
    plt.rcParams.update({
        'figure.figsize': figsize,
        'figure.titlesize': 14,
        'axes.titlesize': 11,
        'axes.labelsize': 10,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 7,
        'legend.title_fontsize': 8,
        'font.family': 'sans-serif'
    })

    # Create figure with a more compact layout
    plt.figure(figsize=figsize)

    # 1. Perplexity across stages by model and strategy
    plt.subplot(2, 2, 1)

    # Get unique values for grouping
    models = df["model"].unique()
    strategies = df["strategy"].unique()

    # Filter to main stages
    stages_df = df[df["stage"].isin(["baseline", "pruned", "fine_tuned"])]
    stage_order = {"baseline": 0, "pruned": 1, "fine_tuned": 2}

    # Plot lines connecting stages for each experiment
    for model in models:
        model_df = stages_df[stages_df["model"] == model]

        for strategy in strategies:
            strategy_df = model_df[model_df["strategy"] == strategy]

            for pruning_level in strategy_df["pruning_level"].unique():
                experiment_df = strategy_df[strategy_df["pruning_level"] == pruning_level]

                # Sort by stage
                experiment_df = experiment_df.sort_values(by="stage", key=lambda x: x.map(stage_order))

                # Only plot if we have at least 2 stages
                if len(experiment_df) >= 2:
                    label = f"{model[:6]}-{strategy[:3]}-{pruning_level:.1f}"
                    plt.plot(experiment_df["stage"], experiment_df["perplexity"], "o-", label=label)

    plt.title("Perplexity Across Stages")
    plt.xlabel("Stage")
    plt.ylabel("Perplexity")
    plt.xticks(rotation=45)
    plt.legend(fontsize=7, loc='upper right', ncol=2)
    plt.grid(True, alpha=0.3)

    # 2. Pruning level vs perplexity by strategy
    plt.subplot(2, 2, 2)

    # Filter to specific stages
    baseline_df = df[df["stage"] == "baseline"]
    pruned_df = df[df["stage"] == "pruned"]
    finetuned_df = df[df["stage"] == "fine_tuned"]

    for strategy in strategies:
        # Sort by pruning level so each line runs left to right instead of
        # zig-zagging in file-listing order
        baseline_strategy = baseline_df[baseline_df["strategy"] == strategy].sort_values("pruning_level")
        pruned_strategy = pruned_df[pruned_df["strategy"] == strategy].sort_values("pruning_level")
        finetuned_strategy = finetuned_df[finetuned_df["strategy"] == strategy].sort_values("pruning_level")

        # Plot lines for each stage if data exists
        if not baseline_strategy.empty:
            plt.plot(baseline_strategy["pruning_level"], baseline_strategy["perplexity"],
                    "o--", label=f"Base-{strategy[:3]}", alpha=0.7)

        if not pruned_strategy.empty:
            plt.plot(pruned_strategy["pruning_level"], pruned_strategy["perplexity"],
                    "s--", label=f"Pruned-{strategy[:3]}", alpha=0.7)

        if not finetuned_strategy.empty:
            plt.plot(finetuned_strategy["pruning_level"], finetuned_strategy["perplexity"],
                    "^-", label=f"Tuned-{strategy[:3]}", alpha=0.7)

    plt.title("Perplexity vs Pruning Level")
    plt.xlabel("Pruning Level")
    plt.ylabel("Perplexity")
    plt.legend(fontsize=7, loc='best')
    plt.grid(True, alpha=0.3)

    # 3. Recovery/improvement percentages
    plt.subplot(2, 2, 3)

    # Create dataframe with recovery metrics
    recovery_df = finetuned_df.copy()

    if not recovery_df.empty:
        # Create unified recovery column (negative means improvement)
        recovery_df["recovery"] = recovery_df["recovery_percentage"]
        # If recovery is NaN but improvement exists, use negative of improvement
        mask = recovery_df["recovery"].isna() & recovery_df["improvement_percentage"].notna()
        recovery_df.loc[mask, "recovery"] = -recovery_df.loc[mask, "improvement_percentage"]

        # Plot by strategy
        for strategy in strategies:
            strategy_recovery = recovery_df[recovery_df["strategy"] == strategy]
            if not strategy_recovery.empty:
                # Sort by pruning level
                strategy_recovery = strategy_recovery.sort_values("pruning_level")
                plt.plot(strategy_recovery["pruning_level"], strategy_recovery["recovery"],
                        "o-", label=strategy)

        plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
        plt.axhline(y=100, color="g", linestyle="--", alpha=0.3)
        plt.text(0.01, 100, "Full Recovery", color="green", ha="left", va="bottom", fontsize=8)
        plt.text(0.01, -5, "Improvement", color="blue", ha="left", va="top", fontsize=8)

        plt.title("Recovery/Improvement %")
        plt.xlabel("Pruning Level")
        plt.ylabel("% (negative = improvement)")
        plt.legend(fontsize=7, loc='best')
        plt.grid(True, alpha=0.3)
    else:
        plt.text(0.5, 0.5, "No recovery data available yet",
                ha="center", va="center", fontsize=12)

    # 4. Status overview: every loaded run counted once, by the furthest stage reached
    plt.subplot(2, 2, 4)

    status = _count_run_status(results)
    status_labels = list(status)
    status_counts = list(status.values())
    total_runs = len(results)

    # Create status bar chart
    colors = ["green", "orange", "blue", "gray"]
    plt.bar(status_labels, status_counts, color=colors)

    for i, count in enumerate(status_counts):
        plt.text(i, count + 0.1, str(count), ha="center")

    plt.title(f"Experiment Status (Total: {total_runs})")
    plt.xlabel("Status")
    plt.ylabel("Count")
    plt.xticks(rotation=45)

    # Add timestamp
    plt.figtext(0.5, 0.01, f"Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
               ha="center", fontsize=8)

    # Apply tight layout to reduce white space
    plt.tight_layout(pad=1.5)
    plt.subplots_adjust(bottom=0.15)
    plt.show()

    return df

# Run the visualization - you can run this cell repeatedly to refresh
df = visualize_ongoing_experiments()

# Display success count by strategy if we have data
if df is not None and not df.empty and "fine_tuned" in df["stage"].values:
    # A run is identified by its configuration plus the timestamp stored in its
    # result file.
    # NOTE(review): results saved without a "timestamp" all share "unknown", so
    # repeated runs of one configuration may be paired with each other's baselines.
    run_keys = ["model", "strategy", "pruning_level", "timestamp"]
    finetuned = df.loc[df["stage"] == "fine_tuned", run_keys + ["perplexity"]]
    baseline = (
        df.loc[df["stage"] == "baseline", run_keys + ["perplexity"]]
        .rename(columns={"perplexity": "baseline_perplexity"})
    )
    # merge() returns a new frame, so the status column below is not written
    # into a slice of df
    finetuned = finetuned.merge(baseline, on=run_keys, how="left")

    # Calculate improvement status relative to the same run's baseline
    # (lower perplexity is better)
    finetuned["status"] = "No Change"
    finetuned.loc[finetuned["perplexity"] < finetuned["baseline_perplexity"], "status"] = "Improved"
    finetuned.loc[finetuned["perplexity"] > finetuned["baseline_perplexity"], "status"] = "Degraded"
    finetuned.loc[finetuned["baseline_perplexity"].isna(), "status"] = "No Baseline"

    # Count by strategy and status
    status_by_strategy = pd.crosstab(finetuned["strategy"], finetuned["status"])
    display(status_by_strategy)

## Head Importance Visualization

The cell below can be used to visualize which heads are most important in your model:

In [ ]:
# Visualize attention head importance for different pruning strategies
# This can help identify which heads are most critical for model performance

# Initialize the model
model_name = "gpt2"  # Change to one of the models you're using
pruning_module = PruningModule(model_name)
if not pruning_module.load_model():
    print(f"Failed to load model {model_name}")
else:
    # Get original parameters
    original_params = pruning_module.original_params

    # Set up a sample prompt
    prompt = "Artificial intelligence will transform society by"

    # Calculate importance for different strategies; each get_head_importance()
    # result is iterated below as (layer, head, score) triples
    strategies = {}

    try:
        # Random strategy (baseline)
        random_strategy = get_strategy("random", pruning_module, prompt)
        random_importance = random_strategy.get_head_importance(original_params)
        strategies["random"] = random_importance

        # Magnitude strategy
        magnitude_strategy = get_strategy("magnitude", pruning_module, prompt)
        magnitude_importance = magnitude_strategy.get_head_importance(original_params)
        strategies["magnitude"] = magnitude_importance

        # Entropy strategy
        entropy_strategy = get_strategy("entropy", pruning_module, prompt)
        entropy_importance = entropy_strategy.get_head_importance(original_params)
        strategies["entropy"] = entropy_importance

        # Set better plot styling
        plt.rcParams.update({
            'figure.figsize': (12, 1.5 * pruning_module.num_layers),
            'figure.titlesize': 14,
            'axes.titlesize': 12,
            'axes.labelsize': 10,
            'xtick.labelsize': 9,
            'ytick.labelsize': 9,
            'legend.fontsize': 8,
            'font.family': 'sans-serif'
        })

        # Now visualize the head importance scores.
        # squeeze=False keeps `axes` 2-D even for a single-layer model, so the
        # axes[layer, i] indexing below always works.
        fig, axes = plt.subplots(pruning_module.num_layers, 3,
                                 figsize=(12, 1.5 * pruning_module.num_layers),
                                 squeeze=False)

        # Create title
        fig.suptitle(f"Attention Head Importance by Strategy for {model_name}", fontsize=16)

        # Set column titles
        for i, strategy_name in enumerate(["random", "magnitude", "entropy"]):
            axes[0, i].set_title(f"{strategy_name.capitalize()} Strategy")

        # Create a heatmap for each strategy showing head importance
        for layer in range(pruning_module.num_layers):
            for i, strategy_name in enumerate(["random", "magnitude", "entropy"]):
                # Place each score at its own head index. Heads without a score
                # stay NaN (drawn blank, labelled "nan") instead of shifting the
                # remaining scores left or overrunning the per-head loop below.
                scores_array = np.full((1, pruning_module.num_heads), np.nan)
                for score_layer, score_head, head_score in strategies[strategy_name]:
                    if score_layer == layer:
                        scores_array[0, score_head] = head_score

                # Create heatmap
                cax = axes[layer, i].imshow(scores_array, cmap="viridis", aspect="auto")

                # Add labels
                axes[layer, i].set_yticks([0])
                axes[layer, i].set_yticklabels([f"Layer {layer}"])
                axes[layer, i].set_xticks(range(pruning_module.num_heads))
                axes[layer, i].set_xticklabels([f"H{h}" for h in range(pruning_module.num_heads)],
                                              rotation=90 if pruning_module.num_heads > 8 else 0)

                # Add importance values as text
                for h in range(pruning_module.num_heads):
                    score = scores_array[0, h]
                    if np.isnan(score):
                        text_color = "black"
                    else:
                        text_color = "white" if score > 0.5 else "black"
                    axes[layer, i].text(h, 0, f"{score:.2f}", ha="center", va="center",
                                       color=text_color, fontsize=8)

        # Add a colorbar.
        # NOTE(review): imshow normalises each heatmap independently and `cax` is
        # the last one drawn (entropy column, last layer), so this colorbar's
        # scale only describes that heatmap.
        fig.colorbar(cax, ax=axes.ravel().tolist(), shrink=0.6)

        plt.tight_layout(rect=[0, 0, 1, 0.96])  # Make room for the title
        plt.show()

        # Now show the top 10 most and least important heads according to entropy
        # (usually considered the most accurate measure)
        sorted_entropy = sorted(entropy_importance, key=lambda x: x[2])

        print("Top 10 Least Important Heads (candidates for pruning):")
        for i, (layer, head, score) in enumerate(sorted_entropy[:10]):
            print(f"{i+1}. Layer {layer}, Head {head}: {score:.4f}")

        print("\nTop 10 Most Important Heads (preserved even with aggressive pruning):")
        # Highest score first, so rank 1 is the most important head
        for i, (layer, head, score) in enumerate(reversed(sorted_entropy[-10:])):
            print(f"{i+1}. Layer {layer}, Head {head}: {score:.4f}")

    except Exception as e:
        print(f"Error calculating head importance: {e}")
        import traceback
        traceback.print_exc()

In [ ]:
# Overnight Configuration - More conservative settings
OVERNIGHT_STRATEGIES = ["random", "magnitude", "entropy"]
OVERNIGHT_PRUNING_LEVELS = [0.1, 0.2, 0.3, 0.4, 0.5]  # Skip higher levels to avoid memory issues
OVERNIGHT_PROMPT = "Artificial intelligence will revolutionize industries by"
OVERNIGHT_FINE_TUNING_EPOCHS = 3  # Reduced from 5 to avoid memory issues with larger models
OVERNIGHT_MAX_RUNTIME = 20 * 3600  # 20 hours

# The overnight run is opt-in so that "Run all" does not start a 20-hour job:
# set RUN_OVERNIGHT = True and re-run this cell to launch it.
RUN_OVERNIGHT = False

overnight_experiment = None
overnight_results = None
if RUN_OVERNIGHT:
    # Initialize experiment for overnight run
    overnight_experiment = PruningFineTuningExperiment("overnight_results")

    # Run overnight experiment
    overnight_results = overnight_experiment.run_experiment(
        strategies=OVERNIGHT_STRATEGIES,
        pruning_levels=OVERNIGHT_PRUNING_LEVELS,
        prompt=OVERNIGHT_PROMPT,
        fine_tuning_epochs=OVERNIGHT_FINE_TUNING_EPOCHS,
        max_runtime=OVERNIGHT_MAX_RUNTIME
    )
else:
    print("Overnight run skipped (set RUN_OVERNIGHT = True to start it)")

## Comprehensive Analysis

After collecting results, run a comprehensive analysis:

In [ ]:
# Plot results with improved sizing.
# plot_results() already shows the figure and then returns it; binding the
# return value keeps the notebook from rendering the same figure a second time
# as the cell's output.
analysis_fig = experiment.plot_results(figsize=(10, 8))

In [None]:
# Create a summary table
if not experiment.results_df.empty:
    config_keys = ["model", "strategy", "pruning_level"]
    stage_rows = experiment.results_df

    # Mean perplexity per configuration for each stage. Averaging first keeps one
    # row per configuration even when it was run several times; merging the raw
    # per-run rows would cross-join the repeated runs.
    summary = None
    for stage, column in [
        ("baseline", "baseline_perplexity"),
        ("pruned", "pruned_perplexity"),
        ("fine_tuned", "finetuned_perplexity"),
    ]:
        stage_table = (
            stage_rows[stage_rows["stage"] == stage]
            .groupby(config_keys, as_index=False)["perplexity"]
            .mean()
            .rename(columns={"perplexity": column})
        )
        # Inner merge: keep only configurations that reached every stage
        summary = stage_table if summary is None else pd.merge(summary, stage_table, on=config_keys)

    # Calculate changes
    summary["pruning_effect"] = summary["pruned_perplexity"] - summary["baseline_perplexity"]
    summary["finetuning_effect"] = summary["finetuned_perplexity"] - summary["pruned_perplexity"]
    summary["net_change"] = summary["finetuned_perplexity"] - summary["baseline_perplexity"]

    # Display summary explicitly: an expression inside an `if` block is not
    # rendered as the cell's output
    display(summary)