In [None]:
# Install required packages (critical to do this before anything else)
print("🔧 Installing dependencies...")
!pip install -q jax jaxlib flax transformers matplotlib numpy pandas seaborn tqdm optax
# Install datasets explicitly with required version to ensure compatibility
!pip install -q 'datasets>=2.0.0' multiprocess

# %%

In [None]:
# Import from HuggingFace datasets and our data_modules without conflicts
# Our internal module was renamed from 'datasets' to 'data_modules' to avoid namespace collisions
import datasets
from datasets import load_dataset
print(f"Using HuggingFace datasets from: {datasets.__file__}")

# Import rest of the libraries
import os
import sys
import json
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import JAX/Flax
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

# Import Hugging Face libraries
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# Rebuild the module search path in three tiers:
#   1. entries that look like interpreter / installed-package locations,
#   2. the current working directory ("."),
#   3. every remaining entry, in its original relative order.
# Searching installed packages before "." is meant to keep the installed
# `datasets` package ahead of anything local, while the local `utils` package
# remains importable from the working directory.
local_paths = ["."]
system_paths = list(filter(
    lambda entry: any(tag in entry for tag in ('/usr/local', 'python3', 'site-packages')),
    sys.path,
))
other_paths = [entry for entry in sys.path if entry != "." and entry not in system_paths]
sys.path = system_paths + local_paths + other_paths
from utils.pruning import (
    Environment,
    ResultsManager,
    PruningModule, 
    get_strategy,
    FineTuner,
    ImprovedFineTuner
)
from utils.pruning.stability import patch_fine_tuner, optimize_fine_tuner

# Set up plotting
# NOTE(review): sns.set_theme() rewrites matplotlib's rcParams, so the 'ggplot'
# style chosen by plt.style.use() below is presumably overridden by the seaborn
# "whitegrid" theme applied right after it — confirm which look is intended.
plt.style.use('ggplot')
sns.set_theme(style="whitegrid")

# %%

In [None]:
# Build the project's Environment helper and let it print what it detected
env = Environment()
env.print_info()

# Report the JAX runtime: library version, visible devices, active backend
jax_version = jax.__version__
print(f"\nJAX version: {jax_version}")
jax_devices = jax.devices()
print(f"Devices: {jax_devices}")
jax_backend = jax.default_backend()
print(f"Default backend: {jax_backend}")

# %%

In [None]:
# Define a custom PruningFineTuningExperiment class for the Colab notebook
# Note: This is different from the one in utils.pruning.experiment to support additional parameters
class PruningFineTuningExperiment:
    """Manages the pruning + fine-tuning experiment"""
    
    def __init__(self, 
                 results_dir="pruning_finetuning_results",
                 use_improved_fine_tuner=True,
                 detect_environment=True,
                 optimize_memory=True,
                 batch_size=None,
                 sequence_length=None,
                 stability_level=None,
                 stability_level=None):
        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(exist_ok=True, parents=True)
        self.results = []
        self.current_experiment = {}
        
        # Store configuration parameters
        self.use_improved_fine_tuner = use_improved_fine_tuner
        self.optimize_memory = optimize_memory
        self.batch_size = batch_size
        # Get sequence length from environment if set
        sequence_length = getattr(self.env, "seq_length", 64) if hasattr(self.env, "seq_length") else 64
        # Use ImprovedFineTuner with stability level if provided
        stability_level = getattr(self.env, "stability_level", 2) if hasattr(self.env, "stability_level") else 2
        
        # Initialize environment
        self.env = Environment()
        
        # Add has_high_ram attribute if not present in Environment class
        # This ensures backward compatibility while also providing better memory detection
        if not hasattr(self.env, 'has_high_ram'):
            self.env.has_high_ram = False
            try:
                import psutil
                total_ram = psutil.virtual_memory().total / (1024**3)  # Convert to GB
                self.env.has_high_ram = total_ram > 12
                print(f"Detected {total_ram:.1f}GB RAM, high RAM: {self.env.has_high_ram}")
            except:
                print("Could not detect RAM, assuming standard memory")
                
        # Detect GPU memory if possible
        self.gpu_memory_gb = 0
        if self.env.has_gpu:
            try:
                import torch
                gpu_props = torch.cuda.get_device_properties(0)
                self.gpu_memory_gb = gpu_props.total_memory / (1024**3)
                print(f"Detected GPU with {self.gpu_memory_gb:.1f}GB VRAM")
            except:
                try:
                    # Alternative method for JAX
                    device = jax.devices()[0]
                    if hasattr(device, 'memory_stats') and callable(device.memory_stats):
                        memory_stats = device.memory_stats()
                        if 'bytes_limit' in memory_stats:
                            self.gpu_memory_gb = memory_stats['bytes_limit'] / (1024**3)
                            print(f"Detected GPU with approximately {self.gpu_memory_gb:.1f}GB VRAM")
                except:
                    # Estimate based on common GPU types
                    if self.env.in_colab:
                        # T4 in Colab typically has 16GB
                        self.gpu_memory_gb = 16
                        print(f"Estimating Colab GPU with {self.gpu_memory_gb}GB VRAM")
        
        # Get suitable models for this environment
        self.available_models = self.env.get_suitable_models()
        print(f"Models available: {', '.join(self.available_models)}")
        
        # Apply manual configuration overrides if provided
        if self.batch_size is not None or self.sequence_length is not None or self.stability_level is not None:
            print(f"Using custom optimization parameters:")
            if self.batch_size is not None:
                print(f"  - Batch size: {self.batch_size}")
                self.env.batch_size = self.batch_size
            if self.sequence_length is not None:
                print(f"  - Sequence length: {self.sequence_length}")
                self.env.seq_length = self.sequence_length
            if self.stability_level is not None:
                print(f"  - Stability level: {self.stability_level}")
                self.env.stability_level = self.stability_level
        
        # Model size limits based on environment - adapt based on detected resources
        self.model_size_limits = {
            "gpt2": 1.0,  # Always allow GPT-2 (124M params)
            "gpt2-medium": 1.0,  # Always allow GPT-2 Medium (355M params)
            "opt-350m": 1.0,  # Always allow OPT-350M
            "opt-125m": 1.0,  # Always allow OPT-125M
            "facebook/opt-125m": 1.0,  # Always allow OPT-125M
            "facebook/opt-350m": 1.0,  # Always allow OPT-350M
            "EleutherAI/pythia-160m": 1.0,  # Always allow Pythia-160M
            "EleutherAI/pythia-410m": 1.0,  # Always allow Pythia-410M
        }
        
        # Add larger models only if we have enough resources
        if self.env.has_gpu and self.gpu_memory_gb >= 8:
            # If we have a GPU with 8+ GB, allow medium-sized models
            self.model_size_limits.update({
                "gpt2-large": 1.0,  # Allow GPT-2 Large (774M params) with sufficient GPU
                "EleutherAI/pythia-1b": 0.5,  # Allow Pythia-1B with pruning
                "facebook/opt-1.3b": 0.3,  # Allow OPT-1.3B with significant pruning only
            })
            
            if self.gpu_memory_gb >= 16:
                # If we have 16+ GB VRAM, allow larger models
                self.model_size_limits.update({
                    "gpt2-xl": 0.3,  # Allow GPT-2 XL with pruning
                    "facebook/opt-1.3b": 0.5,  # Allow OPT-1.3B with moderate pruning
                    "facebook/opt-2.7b": 0.2,  # Allow OPT-2.7B with heavy pruning
                })
        
        # Setup Results Manager
        self.results_manager = ResultsManager(str(self.results_dir))
        self.results_df = pd.DataFrame()
    
    def run_experiment(self, strategies, pruning_levels, prompt, fine_tuning_epochs=1, max_runtime=3600):
        """Run the full experiment"""
        if not self.available_models:
            print("No suitable models found for this environment")
            return
        
        # Start time for runtime tracking
        start_time = time.time()
        
        # Generate all experiment combinations
        experiments = []
        for model in self.available_models:
            for strategy in strategies:
                for level in pruning_levels:
                    experiments.append({
                        "model": model,
                        "strategy": strategy,
                        "pruning_level": level,
                        "prompt": prompt,
                        "fine_tuning_epochs": fine_tuning_epochs
                    })
        
        # Shuffle to get more diverse results early
        random.shuffle(experiments)
        
        # Create progress bar
        pbar = tqdm(total=len(experiments), desc="Running experiments")
        
        # Run experiments
        for i, exp in enumerate(experiments):
            # Check if we've exceeded the runtime limit
            current_runtime = time.time() - start_time
            if max_runtime is not None and current_runtime > max_runtime:
                print(f"\nReached maximum runtime of {max_runtime/3600:.1f} hours")
                break
                
            # Update progress bar
            pbar.set_description(f"Testing {exp['model']}, {exp['strategy']}, {exp['pruning_level']:.2f}")
            
            # Run experiment
            try:
                result = self.run_single_experiment(**exp)
                if result is not None:
                    self.results.append(result)
                
                # Update progress bar
                pbar.update(1)
                
                # Plot intermediate results every few experiments
                if (i + 1) % 1 == 0 or i == len(experiments) - 1:
                    self.plot_results()
            except Exception as e:
                print(f"Error in experiment {exp['model']}, {exp['strategy']}, {exp['pruning_level']:.2f}: {e}")
                import traceback
                traceback.print_exc()
                # Still update progress bar
                pbar.update(1)
        
        # Close progress bar
        pbar.close()
        
        # Final results
        print(f"\nCompleted {len(self.results)} experiments out of {len(experiments)} attempted")
        runtime = time.time() - start_time
        print(f"Total runtime: {runtime/3600:.2f} hours ({runtime/60:.2f} minutes)")
        
        # Plot final results
        self.plot_results()
        
        return self.results
    
    def run_single_experiment(self, model, strategy, pruning_level, prompt, fine_tuning_epochs=1):
        """Run a single experiment with pruning and fine-tuning"""
        print(f"\n{'='*80}")
        print(f"Experiment: {model}, {strategy} strategy, {pruning_level:.2f} pruning level")
        print(f"{'='*80}")
        
        # Initialize pruning module
        pruning_module = PruningModule(model)
        if not pruning_module.load_model():
            print(f"Failed to load model {model}")
            return None
        
        # Store model name for architecture detection
        pruning_module.model_name = model
        
        # Setup experiment record
        self.current_experiment = {
            "model": model,
            "strategy": strategy,
            "pruning_level": pruning_level,
            "prompt": prompt,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "stages": {}
        }
        
        # 1. Evaluate baseline model
        print("\n>> Stage 1: Evaluating baseline model")
        original_params = pruning_module.original_params
        
        # Evaluate perplexity and generation
        perplexity_baseline = pruning_module.evaluate_perplexity(original_params, prompt)
        print(f"Baseline perplexity: {perplexity_baseline:.4f}")
        
        generated_baseline = pruning_module.generate_text(original_params, prompt)
        print(f"Baseline generated: {generated_baseline}")
        
        # Record baseline results
        self.current_experiment["stages"]["baseline"] = {
            "perplexity": float(perplexity_baseline),
            "generated_text": generated_baseline
        }
        
        # 2. Apply pruning
        print("\n>> Stage 2: Applying pruning")
        pruning_strat = get_strategy(strategy, pruning_module, prompt)
        
        # Calculate importance scores
        print("Calculating head importance...")
        all_head_importance = pruning_strat.get_head_importance(original_params)
        
        # Sort by importance (ascending)
        all_head_importance.sort(key=lambda x: x[2])
        
        # Determine number of heads to prune
        total_heads = pruning_module.num_layers * pruning_module.num_heads
        heads_to_prune = int(total_heads * pruning_level)
        print(f"Pruning {heads_to_prune} out of {total_heads} heads")
        
        # Get head indices to prune (least important first)
        head_indices = [(l, h) for l, h, _ in all_head_importance[:heads_to_prune]]
        
        # Prune heads
        print("Pruning heads...")
        pruned_params = pruning_strat.prune_heads(original_params, head_indices)
        
        # Evaluate after pruning
        perplexity_pruned = pruning_module.evaluate_perplexity(pruned_params, prompt)
        print(f"Pruned perplexity: {perplexity_pruned:.4f}")
        
        generated_pruned = pruning_module.generate_text(pruned_params, prompt)
        print(f"Pruned generated: {generated_pruned}")
        
        # Record pruning results
        self.current_experiment["stages"]["pruned"] = {
            "perplexity": float(perplexity_pruned),
            "perplexity_change": float(perplexity_pruned - perplexity_baseline),
            "generated_text": generated_pruned,
            "pruned_heads": heads_to_prune,
            "total_heads": total_heads,
            "head_indices": head_indices
        }
        
        # 3. Fine-tune the pruned model - with improved stability
        print("\n>> Stage 3: Fine-tuning the pruned model")
        
        # Create fine-tuner with dataset config
        dataset_name = "wikitext"
        dataset_config = "wikitext-2-v1"
        
        # Use custom parameters if provided, otherwise auto-detect
        if hasattr(self.env, 'batch_size') and self.env.batch_size is not None:
            batch_size = self.env.batch_size
            print(f"Using configured batch size: {batch_size}")
        else:
            # Determine batch size based on model size and environment
            if self.env.in_colab and self.env.has_tpu:
                # TPUs can handle larger batch sizes
                batch_size = 16
                # But still reduce for large models
                if "1.3b" in model.lower() or "large" in model.lower():
                    batch_size = 8
            elif self.env.in_colab and self.env.has_gpu:
                batch_size = 8
                # Reduce batch size for larger models
                if "1.3b" in model.lower() or "large" in model.lower():
                    batch_size = 4
                elif "2.7b" in model.lower() or "xl" in model.lower():
                    batch_size = 2
            else:
                # CPU-only case
                batch_size = 4
                # Even smaller for large models on CPU
                if "1.3b" in model.lower() or "large" in model.lower():
                    batch_size = 2
                elif "2.7b" in model.lower() or "xl" in model.lower():
                    batch_size = 1
        
        # Get sequence length from environment if set
        sequence_length = getattr(self.env, 'seq_length', 64) if hasattr(self.env, 'seq_length') else 64
        
        # Use ImprovedFineTuner with stability level if provided
        stability_level = getattr(self.env, 'stability_level', 2) if hasattr(self.env, 'stability_level') else 2
        print(f"Using ImprovedFineTuner (stability level: {stability_level})")
        print(f"Training parameters: batch_size={batch_size}, sequence_length={sequence_length}")
        
        fine_tuner = ImprovedFineTuner(
            pruning_module, 
            dataset_name=dataset_name,
            dataset_config=dataset_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            stability_level=stability_level
        )
        
        # Adjust learning rate based on model size
        model_name = model.lower()
        is_large_model = any(x in model_name for x in ['opt', '1.3b', 'large', 'bloom', '2.7b', 'xl'])
        learning_rate = 1e-5 if is_large_model else 5e-5
        
        # Fine-tune model
        try:
            # Install a emergency NaN detection system
            def emergency_fix_loss_fn(params, batch):
                """Safe loss function that prevents NaN values"""
                model = pruning_module.model
                
                # Extract labels from batch but don't pass them to the model
                labels = batch.pop("labels", None)
                
                # Check for NaNs in input
                for k, v in batch.items():
                    if jnp.isnan(v).any() or jnp.isinf(v).any():
                        # Replace NaNs with zeros
                        batch[k] = jnp.nan_to_num(v, nan=0.0, posinf=0.0, neginf=0.0)
                        print(f"Warning: Found NaN in {k}, replaced with zeros")
                
                try:
                    # Handle different model architectures
                    is_opt_model = 'opt' in model_name
                    
                    if is_opt_model:
                        # OPT models don't accept 'train' parameter
                        outputs = model(**batch, params=params)
                    else:
                        # Other models might need the 'train' parameter
                        outputs = model(**batch, params=params, train=True)
                        
                    logits = outputs.logits
                    
                    # Add labels back to batch
                    batch["labels"] = labels
                    
                    # Create mask for padding tokens
                    pad_token_id = pruning_module.tokenizer.pad_token_id  
                    loss_mask = (labels != pad_token_id)
                    
                    # Shift logits and labels
                    shift_logits = logits[:, :-1]
                    shift_labels = labels[:, 1:]
                    shift_mask = loss_mask[:, 1:]
                    
                    # Check for NaNs in logits
                    if jnp.isnan(shift_logits).any() or jnp.isinf(shift_logits).any():
                        print("Warning: Found NaN/Inf in logits - replacing with finite values")
                        shift_logits = jnp.nan_to_num(shift_logits, nan=0.0, posinf=1.0, neginf=-1.0)
                    
                    # Safe cross entropy computation
                    loss = optax.softmax_cross_entropy_with_integer_labels(
                        shift_logits, shift_labels
                    )
                    
                    # Check and fix NaNs in loss
                    if jnp.isnan(loss).any() or jnp.isinf(loss).any():
                        print("Warning: Found NaN/Inf in loss - replacing with finite values")
                        loss = jnp.nan_to_num(loss, nan=1.0, posinf=1.0, neginf=1.0)
                    
                    # Safe mean calculation
                    masked_loss = loss * shift_mask
                    mask_sum = shift_mask.sum()
                    
                    # Safe division
                    computed_loss = jnp.where(
                        mask_sum > 0,
                        masked_loss.sum() / mask_sum,
                        jnp.array(0.0, dtype=loss.dtype)
                    )
                    
                    if jnp.isnan(computed_loss) or jnp.isinf(computed_loss):
                        print("CRITICAL: NaN/Inf in final loss - using fallback value")
                        return jnp.array(1.0, dtype=loss.dtype)
                        
                    return computed_loss
                    
                except Exception as e:
                    print(f"Error in loss function: {e}")
                    # Return a safe default value
                    batch["labels"] = labels  # Restore labels
                    return jnp.array(1.0)  # Safe fallback
            
            # Patch the fine tuner with our stability utilities
            print("⚠️ Installing NaN-safe loss function")
            fine_tuner = patch_fine_tuner(fine_tuner, model_name=model)
            
            # Optimize memory usage based on model size and available GPU memory
            print(f"🧠 Optimizing memory usage for {model}")
            fine_tuner = optimize_fine_tuner(fine_tuner, model_name=model, gpu_memory_gb=self.gpu_memory_gb)
            
            # Proceed with fine-tuning
            tuned_params, metrics = fine_tuner.fine_tune(
                pruned_params, 
                num_epochs=fine_tuning_epochs,
                learning_rate=learning_rate,
                evaluate_interval=5
            )
            
            # Plot training progress
            fine_tuner.plot_training_progress()
            
            # Evaluate fine-tuned model
            perplexity_tuned = pruning_module.evaluate_perplexity(tuned_params, prompt)
            print(f"Fine-tuned perplexity: {perplexity_tuned:.4f}")
            
            generated_tuned = pruning_module.generate_text(tuned_params, prompt)
            print(f"Fine-tuned generated: {generated_tuned}")
            
            # Record fine-tuning results
            self.current_experiment["stages"]["fine_tuned"] = {
                "perplexity": float(perplexity_tuned),
                "perplexity_change_from_baseline": float(perplexity_tuned - perplexity_baseline),
                "perplexity_change_from_pruned": float(perplexity_tuned - perplexity_pruned),
                "generated_text": generated_tuned,
                "training_epochs": fine_tuning_epochs,
                "training_metrics": metrics
            }
            
            # Compute recovery percentage
            if perplexity_pruned > perplexity_baseline:
                # Calculate how much of the perplexity increase was recovered
                perplexity_increase = perplexity_pruned - perplexity_baseline
                perplexity_recovery = perplexity_pruned - perplexity_tuned
                recovery_percentage = (perplexity_recovery / perplexity_increase) * 100 if perplexity_increase > 0 else 0
                
                self.current_experiment["stages"]["fine_tuned"]["recovery_percentage"] = float(recovery_percentage)
                print(f"Recovery percentage: {recovery_percentage:.2f}%")
            else:
                # Pruning improved perplexity, so we measure improvement from baseline
                improvement_percentage = ((perplexity_baseline - perplexity_tuned) / perplexity_baseline) * 100
                
                self.current_experiment["stages"]["fine_tuned"]["improvement_percentage"] = float(improvement_percentage)
                print(f"Improvement percentage: {improvement_percentage:.2f}%")
        
        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            # Continue with partial results
            import traceback
            traceback.print_exc()
        
        # 4. Save results
        print("\n>> Stage 4: Saving results")
        
        # Save to disk
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        result_filename = f"{model.replace('/', '_')}_{strategy}_{pruning_level:.2f}_{timestamp}.json"
        result_path = self.results_dir / result_filename
        
        import json
        with open(result_path, "w") as f:
            json.dump(self.current_experiment, f, indent=2)
            
        print(f"Results saved to {result_path}")
        
        # Update DataFrame for plotting
        self._update_dataframe()
        
        return self.current_experiment
    
    def _update_dataframe(self):
        """Update DataFrame for visualization"""
        # Extract data for DataFrame
        data = []
        
        for result in self.results:
            # Extract model and strategy info
            model = result["model"]
            strategy = result["strategy"]
            pruning_level = result["pruning_level"]
            
            # Add baseline stage
            if "baseline" in result["stages"]:
                baseline = result["stages"]["baseline"]
                data.append({
                    "model": model,
                    "strategy": strategy,
                    "pruning_level": pruning_level,
                    "stage": "baseline",
                    "perplexity": baseline["perplexity"]
                })
            
            # Add pruned stage
            if "pruned" in result["stages"]:
                pruned = result["stages"]["pruned"]
                data.append({
                    "model": model,
                    "strategy": strategy,
                    "pruning_level": pruning_level,
                    "stage": "pruned",
                    "perplexity": pruned["perplexity"],
                    "perplexity_change": pruned.get("perplexity_change", 0)
                })
                
            # Add fine-tuned stage
            if "fine_tuned" in result["stages"]:
                fine_tuned = result["stages"]["fine_tuned"]
                data.append({
                    "model": model,
                    "strategy": strategy,
                    "pruning_level": pruning_level,
                    "stage": "fine_tuned",
                    "perplexity": fine_tuned["perplexity"],
                    "perplexity_change_from_baseline": fine_tuned.get("perplexity_change_from_baseline", 0),
                    "perplexity_change_from_pruned": fine_tuned.get("perplexity_change_from_pruned", 0),
                    "recovery_percentage": fine_tuned.get("recovery_percentage", None),
                    "improvement_percentage": fine_tuned.get("improvement_percentage", None)
                })
        
        self.results_df = pd.DataFrame(data)
    
    def plot_results(self, figsize=(15, 12)):
        """Plot comprehensive experiment results"""
        if not self.results:
            print("No results available yet")
            return
            
        # Update DataFrame
        self._update_dataframe()
            
        if self.results_df.empty:
            print("No data available for plotting")
            return
        
        # Create figure
        fig = plt.figure(figsize=figsize)
        
        # 1. Perplexity across stages by model and strategy
        plt.subplot(2, 2, 1)
        
        # Get unique models and strategies
        models = self.results_df["model"].unique()
        strategies = self.results_df["strategy"].unique()
        
        # Filter to main stages
        stages_df = self.results_df[self.results_df["stage"].isin(["baseline", "pruned", "fine_tuned"])]
        
        # Plot lines connecting stages for each experiment
        for model in models:
            model_df = stages_df[stages_df["model"] == model]
            
            for strategy in strategies:
                strategy_df = model_df[model_df["strategy"] == strategy]
                
                for pruning_level in strategy_df["pruning_level"].unique():
                    experiment_df = strategy_df[strategy_df["pruning_level"] == pruning_level]
                    
                    # Sort by stage to ensure correct order
                    stage_order = {"baseline": 0, "pruned": 1, "fine_tuned": 2}
                    experiment_df = experiment_df.sort_values(by="stage", key=lambda x: x.map(stage_order))
                    
                    # Plot if we have all stages
                    if len(experiment_df) >= 2:
                        label = f"{model}, {strategy}, {pruning_level:.2f}"
                        plt.plot(experiment_df["stage"], experiment_df["perplexity"], "o-", label=label)
        
        plt.title("Perplexity Across Stages")
        plt.xlabel("Stage")
        plt.ylabel("Perplexity")
        plt.xticks(rotation=45)
        plt.legend(fontsize=8)
        plt.grid(True)
        
        # 2. Recovery percentage vs pruning level
        plt.subplot(2, 2, 2)
        
        # Get data with recovery information
        recovery_df = self.results_df[self.results_df["stage"] == "fine_tuned"].copy()
        
        if not recovery_df.empty:
            # Create recovery column (combining both metrics)
            recovery_df["recovery"] = recovery_df["recovery_percentage"]
            # If improvement percentage exists and recovery is NaN, use negative of improvement
            mask = recovery_df["recovery"].isna() & recovery_df["improvement_percentage"].notna()
            recovery_df.loc[mask, "recovery"] = -recovery_df.loc[mask, "improvement_percentage"]
            
            # Plot by strategy
            for strategy in strategies:
                strategy_df = recovery_df[recovery_df["strategy"] == strategy]
                if not strategy_df.empty:
                    for model in models:
                        model_strategy_df = strategy_df[strategy_df["model"] == model]
                        if not model_strategy_df.empty:
                            # Sort by pruning level
                            model_strategy_df = model_strategy_df.sort_values("pruning_level")
                            plt.plot(model_strategy_df["pruning_level"], model_strategy_df["recovery"], 
                                    "o-", label=f"{model}, {strategy}")
            
            plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
            plt.axhline(y=100, color="g", linestyle="--", alpha=0.3)
            plt.text(0.01, 100, "Full Recovery", color="green", ha="left", va="bottom")
            plt.text(0.01, -5, "Improvement", color="blue", ha="left", va="top")
            
            plt.title("Recovery Percentage by Pruning Level")
            plt.xlabel("Pruning Level")
            plt.ylabel("Recovery % (negative means improvement)")
            plt.legend(fontsize=8)
            plt.grid(True)
        else:
            plt.text(0.5, 0.5, "No recovery data available yet", 
                    ha="center", va="center", fontsize=12)
        
        # 3. Perplexity change: pruning vs fine-tuning effect
        plt.subplot(2, 2, 3)
        
        if "perplexity_change" in self.results_df.columns and "perplexity_change_from_pruned" in self.results_df.columns:
            # Get pruning change
            pruned_df = self.results_df[self.results_df["stage"] == "pruned"].copy()
            pruned_df = pruned_df[["model", "strategy", "pruning_level", "perplexity_change"]]
            
            # Get fine-tuning change
            finetuned_df = self.results_df[self.results_df["stage"] == "fine_tuned"].copy()
            finetuned_df = finetuned_df[["model", "strategy", "pruning_level", "perplexity_change_from_pruned"]]
            
            # Merge
            effects_df = pd.merge(
                pruned_df, finetuned_df,
                on=["model", "strategy", "pruning_level"],
                suffixes=("_pruning", "_finetuning")
            )
            
            if not effects_df.empty:
                # Plot scatter with size based on pruning level
                for strategy in strategies:
                    strategy_df = effects_df[effects_df["strategy"] == strategy]
                    if not strategy_df.empty:
                        for model in models:
                            model_df = strategy_df[strategy_df["model"] == model]
                            if not model_df.empty:
                                plt.scatter(
                                    model_df["perplexity_change"], 
                                    model_df["perplexity_change_from_pruned"],
                                    s=model_df["pruning_level"] * 500,  # Size based on pruning level
                                    label=f"{model}, {strategy}",
                                    alpha=0.7
                                )
                
                plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
                plt.axvline(x=0, color="k", linestyle="--", alpha=0.3)
                
                # Add quadrant labels
                plt.text(-5, -5, "Both improved", fontsize=10, ha="center", va="center",
                        bbox=dict(facecolor="lightgreen", alpha=0.5))
                plt.text(5, -5, "Pruning hurt,\nFine-tuning fixed", fontsize=10, ha="center", va="center",
                        bbox=dict(facecolor="lightblue", alpha=0.5))
                plt.text(-5, 5, "Pruning helped,\nFine-tuning hurt", fontsize=10, ha="center", va="center",
                        bbox=dict(facecolor="lightyellow", alpha=0.5))
                plt.text(5, 5, "Both hurt", fontsize=10, ha="center", va="center",
                        bbox=dict(facecolor="lightcoral", alpha=0.5))
                
                plt.title("Effect of Pruning vs. Fine-tuning")
                plt.xlabel("Perplexity Change from Pruning")
                plt.ylabel("Perplexity Change from Fine-tuning")
                plt.legend(fontsize=8)
                plt.grid(True)
            else:
                plt.text(0.5, 0.5, "No effect data available yet", 
                        ha="center", va="center", fontsize=12)
        else:
            plt.text(0.5, 0.5, "No effect data available yet", 
                    ha="center", va="center", fontsize=12)
        
        # 4. Final results: perplexity reduction by pruning level and strategy
        plt.subplot(2, 2, 4)
        
        if "perplexity_change_from_baseline" in self.results_df.columns:
            # Get baseline and final results
            baseline_df = self.results_df[self.results_df["stage"] == "baseline"].copy()
            baseline_df = baseline_df[["model", "strategy", "pruning_level", "perplexity"]]
            baseline_df = baseline_df.rename(columns={"perplexity": "baseline_perplexity"})
            
            final_df = self.results_df[self.results_df["stage"] == "fine_tuned"].copy()
            final_df = final_df[["model", "strategy", "pruning_level", "perplexity", "perplexity_change_from_baseline"]]
            final_df = final_df.rename(columns={"perplexity": "final_perplexity"})
            
            # Merge
            final_results = pd.merge(
                baseline_df, final_df,
                on=["model", "strategy", "pruning_level"]
            )
            
            if not final_results.empty:
                # Plot as bar chart
                # Group by pruning level and strategy
                grouped = final_results.groupby(["pruning_level", "strategy"])["perplexity_change_from_baseline"].mean().reset_index()
                
                # Pivot for grouped bar chart
                pivot_df = grouped.pivot(index="pruning_level", columns="strategy", values="perplexity_change_from_baseline")
                
                # Plot
                pivot_df.plot(kind="bar", ax=plt.gca())
                
                plt.axhline(y=0, color="k", linestyle="--", alpha=0.3)
                plt.title("Final Perplexity Change from Baseline")
                plt.xlabel("Pruning Level")
                plt.ylabel("Perplexity Change")
                plt.legend(title="Strategy")
                plt.grid(True, axis="y")
            else:
                plt.text(0.5, 0.5, "No final results available yet", 
                        ha="center", va="center", fontsize=12)
        else:
            plt.text(0.5, 0.5, "No final results available yet", 
                    ha="center", va="center", fontsize=12)
        
        plt.tight_layout()
        plt.show()
        
        return fig

# %%

In [None]:
# Create the experiment object with memory-conscious settings
experiment = PruningFineTuningExperiment(
    results_dir="pruning_finetuning_results",
    # Environment handling: auto-detect a Colab environment and tune memory
    # usage to whatever hardware was detected
    detect_environment=True,
    optimize_memory=True,
    # Fine-tuner selection: the improved fine-tuner adds stability enhancements;
    # stability_level ranges from 1 to 3, with 3 being the most stable
    use_improved_fine_tuner=True,
    stability_level=2,
    # Explicit overrides applied to fine-tuning
    batch_size=2,
    sequence_length=64,
)

# %%

# ## Longer Overnight Run
# 
# For an extended overnight run, uncomment and run this cell:

# %%

# ## Optimization Recommendations
# 
# Based on our extensive profiling and research, we have the following recommendations:

# %%

# ## Comprehensive Analysis
# 
# After collecting results, run a comprehensive analysis:

# %%

In [None]:
# Create a summary table: one row per (model, strategy, pruning_level) holding the
# perplexity recorded at each stage plus the differences between stages.
merge_keys = ["model", "strategy", "pruning_level"]

# Start from an empty frame with the expected columns so `summary` is always
# defined and the cell still renders a table before any results exist.
summary = pd.DataFrame(
    columns=merge_keys + [
        "baseline_perplexity",
        "pruned_perplexity",
        "finetuned_perplexity",
        "pruning_effect",
        "finetuning_effect",
        "net_change",
    ]
)

if not experiment.results_df.empty:
    results = experiment.results_df
    stage_columns = merge_keys + ["perplexity"]

    # Get data for different stages; "perplexity" is renamed per stage so the
    # columns do not collide once the stages are merged side by side.
    baseline_df = results[results["stage"] == "baseline"][stage_columns]
    baseline_df = baseline_df.rename(columns={"perplexity": "baseline_perplexity"})

    pruned_df = results[results["stage"] == "pruned"][stage_columns]
    pruned_df = pruned_df.rename(columns={"perplexity": "pruned_perplexity"})

    finetuned_df = results[results["stage"] == "fine_tuned"][stage_columns]
    finetuned_df = finetuned_df.rename(columns={"perplexity": "finetuned_perplexity"})

    # Merge dataframes. pd.merge defaults to an inner join, so only key
    # combinations present in all three stages end up in the summary.
    summary = pd.merge(baseline_df, pruned_df, on=merge_keys)
    summary = pd.merge(summary, finetuned_df, on=merge_keys)

    # Calculate changes (a positive value means perplexity went up)
    summary["pruning_effect"] = summary["pruned_perplexity"] - summary["baseline_perplexity"]
    summary["finetuning_effect"] = summary["finetuned_perplexity"] - summary["pruned_perplexity"]
    summary["net_change"] = summary["finetuned_perplexity"] - summary["baseline_perplexity"]

# Display summary. This has to be the cell's last top-level expression: Jupyter
# only auto-renders a trailing top-level expression, so a bare `summary.head()`
# nested inside the `if` block above would be evaluated and silently discarded.
summary.head()