# Fine-tuning Pruned Models Demo

This notebook demonstrates how to fine-tune pruned models to recover accuracy while maintaining the speed benefits of pruning.

## Overview

1. Setup environment and load models
2. Prune a model with different strategies
3. Benchmark the pruned model (speed and accuracy)
4. Fine-tune the pruned model with per-head learning rates
5. Benchmark the fine-tuned model to show accuracy recovery
6. Visualize the improvements

## Setup

First, let's clone the repository and install dependencies:

In [None]:
!git clone https://github.com/CambrianTech/sentinel-ai.git
%cd sentinel-ai
!pip install -r requirements.txt
!pip install torch matplotlib pandas seaborn

In [None]:
import os
import sys
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Add the project to path
sys.path.append('.')

from models.loaders.loader import load_baseline_model, load_adaptive_model
from datasets.dataset_loader import load_and_tokenize_dataset
from utils.train_utils import compute_loss, compute_perplexity
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.head_lr_manager import HeadLRManager
from utils.head_metrics import compute_attention_entropy, compute_head_importance
from utils.generation_wrapper import generate_text

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Load Model

Let's load a small pre-trained model (distilgpt2) and adapt it with our architecture:

In [None]:
model_name = "distilgpt2"
print(f"Loading baseline model: {model_name}")
baseline_model = load_baseline_model(model_name, device)

print("Creating adaptive model...")
model = load_adaptive_model(model_name, baseline_model, device)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Count initial parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

## Prepare Dataset

Load a small text dataset for training and evaluation:

In [None]:
# Load tiny_shakespeare dataset
dataset_name = "tiny_shakespeare"
max_length = 128  # Sequence length

print(f"Loading dataset: {dataset_name}")
train_ids, val_ids = load_and_tokenize_dataset(model_name, dataset_name, max_length)

# Create data loaders
batch_size = 8  # Increased batch size for faster training
train_dataset = TensorDataset(torch.tensor(train_ids))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = TensorDataset(torch.tensor(val_ids))
val_loader = DataLoader(val_dataset, batch_size=batch_size)

print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(val_dataset)}")

## Evaluate Original Model

Let's establish a baseline by evaluating the original model:

In [None]:
def evaluate_model_metrics(model, val_loader, device):
    """Evaluate a model's perplexity and generate sample text."""
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    # Calculate perplexity
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating", leave=False):
            input_ids = batch[0].to(device)
            targets = input_ids.clone()
            
            # Forward pass
            outputs = model(input_ids)
            
            # Compute loss
            loss = compute_loss(outputs, targets)
            
            # Update totals
            total_loss += loss.item() * input_ids.size(0)
            total_tokens += input_ids.size(0) * input_ids.size(1)
    
    # Calculate perplexity
    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    
    # Generate sample text
    prompt = "The meaning of life is"
    start_time = time.time()
    output = generate_text(
        model, tokenizer, prompt,
        max_length=100,
        temperature=0.8,
        device=device
    )
    generation_time = time.time() - start_time
    tokens_generated = len(tokenizer.encode(output)) - len(tokenizer.encode(prompt))
    tokens_per_sec = tokens_generated / generation_time
    
    # Return metrics and sample
    return {
        "perplexity": perplexity,
        "tokens_per_sec": tokens_per_sec,
        "generation_time": generation_time,
        "generated_text": output
    }

# Evaluate original model
print("Evaluating original model...")
original_metrics = evaluate_model_metrics(model, val_loader, device)
print(f"\nOriginal model perplexity: {original_metrics['perplexity']:.2f}")
print(f"Original model speed: {original_metrics['tokens_per_sec']:.2f} tokens/sec")
print(f"\nSample text from original model:\n{original_metrics['generated_text']}")

## Compute Head Metrics

To decide which heads to prune, we need to compute importance metrics:

In [None]:
# Compute entropy-based metrics for heads
print("Computing entropy metrics...")
entropy_dict = compute_attention_entropy(model, device=device)

# Compute gradient-based importance
print("Computing importance metrics...")
importance_dict = compute_head_importance(model, val_loader, compute_loss, device=device)

# Visualize head metrics
def plot_head_metrics(entropy_dict, importance_dict):
    num_layers = len(entropy_dict)
    num_heads = len(entropy_dict[0])
    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Prepare data for heatmaps
    entropy_data = np.zeros((num_layers, num_heads))
    importance_data = np.zeros((num_layers, num_heads))
    
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            entropy_data[layer_idx, head_idx] = entropy_dict[layer_idx][head_idx].item()
            importance_data[layer_idx, head_idx] = importance_dict[layer_idx][head_idx].item()
    
    # Plot entropy heatmap
    sns.heatmap(entropy_data, ax=ax1, cmap="viridis", annot=True, fmt=".2f")
    ax1.set_title("Attention Entropy by Head")
    ax1.set_xlabel("Head Index")
    ax1.set_ylabel("Layer Index")
    
    # Plot importance heatmap
    sns.heatmap(importance_data, ax=ax2, cmap="plasma", annot=True, fmt=".2f")
    ax2.set_title("Head Importance")
    ax2.set_xlabel("Head Index")
    ax2.set_ylabel("Layer Index")
    
    plt.tight_layout()
    plt.show()
    
    return entropy_data, importance_data

entropy_data, importance_data = plot_head_metrics(entropy_dict, importance_dict)

## Prune the Model

Now let's prune the model using entropy-based metrics:

In [None]:
def prune_model_by_metrics(model, prune_ratio=0.5, strategy="entropy"):
    """Prune a percentage of attention heads based on metrics."""
    num_layers = len(model.blocks)
    num_heads = model.blocks[0]["attn"].num_heads
    total_heads = num_layers * num_heads
    heads_to_prune = int(total_heads * prune_ratio)
    
    print(f"Pruning {heads_to_prune} of {total_heads} heads ({prune_ratio*100:.1f}%)")
    
    # Flatten metrics for all heads
    all_heads = []
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            # Choose metric based on strategy
            if strategy == "entropy":
                # Lower entropy is less informative -> prune
                metric = entropy_dict[layer_idx][head_idx].item()
            elif strategy == "importance":
                # Lower importance -> prune
                metric = importance_dict[layer_idx][head_idx].item()
            elif strategy == "random":
                metric = np.random.random()
            else:
                raise ValueError(f"Unknown pruning strategy: {strategy}")
            
            all_heads.append((layer_idx, head_idx, metric))
    
    # Sort heads by metric (lower values first for pruning)
    if strategy == "entropy":
        # For entropy, we prune heads with lowest entropy (least informative)
        sorted_heads = sorted(all_heads, key=lambda x: x[2])
    elif strategy == "importance":
        # For importance, we prune heads with lowest importance
        sorted_heads = sorted(all_heads, key=lambda x: x[2])
    else:
        # For random, we just use the random values
        sorted_heads = sorted(all_heads, key=lambda x: x[2])
    
    # Heads to prune
    heads_to_prune = sorted_heads[:heads_to_prune]
    
    # Actually prune the heads
    with torch.no_grad():
        for layer_idx, head_idx, _ in heads_to_prune:
            # Set gate value to 0 (pruned)
            model.blocks[layer_idx]["attn"].gate[head_idx] = 0.0
            print(f"Pruned layer {layer_idx}, head {head_idx}")
    
    # Count active heads after pruning
    active_heads = 0
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            if model.blocks[layer_idx]["attn"].gate[head_idx] > 0.01:
                active_heads += 1
    
    print(f"Active heads after pruning: {active_heads}/{total_heads} ({active_heads/total_heads*100:.1f}%)")
    return model

# Define pruning parameters
pruning_ratio = 0.5  # Prune 50% of heads
pruning_strategy = "entropy"  # Options: "entropy", "importance", "random"

# Create a copy of the model for pruning
pruned_model = load_adaptive_model(model_name, baseline_model, device)
pruned_model.load_state_dict(model.state_dict())

# Apply pruning
pruned_model = prune_model_by_metrics(pruned_model, pruning_ratio, pruning_strategy)

# Save the pruned model
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
pruned_checkpoint_path = os.path.join(checkpoint_dir, "pruned_model.pth")

optimizer = torch.optim.AdamW(pruned_model.parameters())
save_checkpoint(pruned_checkpoint_path, pruned_model, optimizer, {}, 0, 0)
print(f"Saved pruned model to {pruned_checkpoint_path}")

## Evaluate Pruned Model

Let's see how pruning affects performance:

In [None]:
# Evaluate pruned model
print("Evaluating pruned model...")
pruned_metrics = evaluate_model_metrics(pruned_model, val_loader, device)

print(f"\nPruned model perplexity: {pruned_metrics['perplexity']:.2f} (Original: {original_metrics['perplexity']:.2f})")
print(f"Pruned model speed: {pruned_metrics['tokens_per_sec']:.2f} tokens/sec (Original: {original_metrics['tokens_per_sec']:.2f})")
print(f"Speed improvement: {pruned_metrics['tokens_per_sec'] / original_metrics['tokens_per_sec']:.2f}x")

print(f"\nSample text from pruned model:\n{pruned_metrics['generated_text']}")

## Fine-tune the Pruned Model

Now, let's fine-tune the pruned model to recover accuracy. We'll use an aggressive training approach to show improvement quickly:

In [None]:
# Import the fine-tuning function
from scripts.finetune_pruned_model import create_optimizer_with_head_params

# Fine-tuning parameters - more aggressive for faster convergence
epochs = 4  # More epochs to ensure meaningful recovery
learning_rate = 1e-4  # Higher learning rate for faster convergence 
boost_factor = 10.0  # Much higher boost for active heads
warmup_steps = 50  # Shorter warmup for faster initial progress
cooldown_steps = 300  # Shorter cooldown cycle for more high-LR training
eval_interval = 50  # Evaluate more frequently to track progress

# Create optimizer with per-head parameters
optimizer = create_optimizer_with_head_params(pruned_model, learning_rate)

# Create HeadLRManager
head_lr_manager = HeadLRManager(
    model=pruned_model,
    optimizer=optimizer,
    base_lr=learning_rate,
    boost_factor=boost_factor,
    decay_factor=0.9,
    warmup_steps=warmup_steps,
    cooldown_steps=cooldown_steps
)

# Create dummy gate status to initialize head status
with torch.no_grad():
    dummy_gates = torch.zeros((len(pruned_model.blocks), pruned_model.blocks[0]["attn"].num_heads))
    
    # Set gates based on actual model gates
    for layer_idx in range(len(pruned_model.blocks)):
        for head_idx in range(pruned_model.blocks[0]["attn"].num_heads):
            gate_value = pruned_model.blocks[layer_idx]["attn"].gate[head_idx].item()
            dummy_gates[layer_idx, head_idx] = 1.0 if gate_value > 0.01 else 0.0

# Initialize head status with current gates
head_lr_manager.update_head_status(dummy_gates)

# Create learning rate scheduler - use OneCycleLR for faster convergence
from torch.optim.lr_scheduler import OneCycleLR
steps_per_epoch = len(train_loader)
lr_scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate * 3,  # Peak learning rate 3x the base rate
    total_steps=epochs * steps_per_epoch,
    pct_start=0.3,  # Spend 30% of steps in warmup
    div_factor=25.0,  # Initial LR is max_lr/25
    final_div_factor=10000.0  # Final LR is max_lr/10000
)

# Training loop
pruned_model.train()
step = 0
history = {
    "loss": [],
    "perplexity": [],
    "lr": [],
    "eval_steps": []
}

print(f"Fine-tuning pruned model for {epochs} epochs...")
for epoch in range(epochs):
    epoch_loss = 0.0
    
    # Training
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_idx, batch in enumerate(progress_bar):
        step += 1
        
        # Get batch
        input_ids = batch[0].to(device)
        targets = input_ids.clone()
        
        # Forward pass
        outputs = pruned_model(input_ids)
        
        # Compute loss
        loss = compute_loss(outputs, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update head learning rates every 5 steps
        if step % 5 == 0:
            lr_info = head_lr_manager.update_learning_rates()
        
        # Update learning rate scheduler
        lr_scheduler.step()
        
        # Update progress
        epoch_loss += loss.item()
        progress_bar.set_postfix({
            "loss": loss.item(),
            "avg_loss": epoch_loss / (batch_idx + 1),
            "lr": optimizer.param_groups[0]["lr"]
        })
        
        # Save metrics
        history["loss"].append(loss.item())
        history["lr"].append(optimizer.param_groups[0]["lr"])
        
        # Evaluate periodically
        if step % eval_interval == 0 or batch_idx == len(train_loader) - 1:
            pruned_model.eval()
            val_loss = 0.0
            val_tokens = 0
            
            with torch.no_grad():
                # Only evaluate on a subset of validation data to save time
                eval_subset = list(val_loader)[:10]  # Just use first 10 batches
                for val_batch in tqdm(eval_subset, desc="Evaluating", leave=False):
                    val_input_ids = val_batch[0].to(device)
                    val_targets = val_input_ids.clone()
                    
                    # Forward pass
                    val_outputs = pruned_model(val_input_ids)
                    
                    # Compute loss
                    val_loss_batch = compute_loss(val_outputs, val_targets)
                    
                    # Update totals
                    val_loss += val_loss_batch.item() * val_input_ids.size(0)
                    val_tokens += val_input_ids.size(0) * val_input_ids.size(1)
            
            # Calculate perplexity
            val_avg_loss = val_loss / val_tokens
            val_perplexity = torch.exp(torch.tensor(val_avg_loss)).item()
            
            print(f"\nStep {step}, Validation perplexity: {val_perplexity:.2f}")
            history["perplexity"].append(val_perplexity)
            history["eval_steps"].append(step)
            
            # Return to training mode
            pruned_model.train()
            
            # Early stopping check - if perplexity is at or better than original model, we can stop
            if val_perplexity <= original_metrics["perplexity"] * 1.05:  # Within 5% of original
                print(f"Perplexity recovered to within 5% of original model! Early stopping.")
                break
    
    # End of epoch
    print(f"Epoch {epoch+1}/{epochs} completed. Avg loss: {epoch_loss / len(train_loader):.4f}")
    
    # Early stopping check at epoch level
    if len(history["perplexity"]) > 0 and history["perplexity"][-1] <= original_metrics["perplexity"] * 1.05:
        print("Early stopping after sufficient accuracy recovery.")
        break

# Save fine-tuned model
finetuned_checkpoint_path = os.path.join(checkpoint_dir, "finetuned_model.pth")
save_checkpoint(
    finetuned_checkpoint_path,
    pruned_model,
    optimizer,
    head_lr_manager.save_state_dict(),
    epochs,
    step
)
print(f"Saved fine-tuned model to {finetuned_checkpoint_path}")

## Visualize Learning Progress

Let's plot the learning curve to see how fine-tuning improves perplexity:

In [None]:
# Plot learning curves
plt.figure(figsize=(12, 5))

# Plot perplexity
plt.subplot(1, 2, 1)
plt.plot(history["eval_steps"], history["perplexity"], marker='o')
plt.axhline(y=original_metrics["perplexity"], color='r', linestyle='--', label="Original Model")
plt.axhline(y=pruned_metrics["perplexity"], color='g', linestyle='--', label="Pruned Model")
plt.xlabel("Step")
plt.ylabel("Perplexity (lower is better)")
plt.title("Perplexity During Fine-tuning")
plt.legend()

# Plot loss
plt.subplot(1, 2, 2)
plt.plot(history["loss"], alpha=0.3, label="Raw")
window_size = 20
smoothed_loss = [sum(history["loss"][max(0, i-window_size):i]) / min(i, window_size) for i in range(1, len(history["loss"])+1)]
plt.plot(smoothed_loss, label="Smoothed")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.legend()

plt.tight_layout()
plt.show()

## Evaluate Fine-tuned Model

Now let's see how much accuracy we've recovered:

In [None]:
# Evaluate fine-tuned model
print("Evaluating fine-tuned model...")
finetuned_metrics = evaluate_model_metrics(pruned_model, val_loader, device)

print(f"\nFine-tuned model perplexity: {finetuned_metrics['perplexity']:.2f}")
print(f"Fine-tuned model speed: {finetuned_metrics['tokens_per_sec']:.2f} tokens/sec")

print(f"\nSample text from fine-tuned model:\n{finetuned_metrics['generated_text']}")

# Compare all models
models = ["Original", "Pruned", "Fine-tuned"]
perplexities = [
    original_metrics["perplexity"],
    pruned_metrics["perplexity"],
    finetuned_metrics["perplexity"]
]
speeds = [
    original_metrics["tokens_per_sec"],
    pruned_metrics["tokens_per_sec"],
    finetuned_metrics["tokens_per_sec"]
]

# Calculate relative improvements
perplexity_change_pruned = (pruned_metrics["perplexity"] - original_metrics["perplexity"]) / original_metrics["perplexity"] * 100
perplexity_change_finetuned = (finetuned_metrics["perplexity"] - original_metrics["perplexity"]) / original_metrics["perplexity"] * 100
perplexity_recovery = (pruned_metrics["perplexity"] - finetuned_metrics["perplexity"]) / (pruned_metrics["perplexity"] - original_metrics["perplexity"]) * 100 if pruned_metrics["perplexity"] != original_metrics["perplexity"] else 0

speed_change_pruned = (pruned_metrics["tokens_per_sec"] - original_metrics["tokens_per_sec"]) / original_metrics["tokens_per_sec"] * 100
speed_change_finetuned = (finetuned_metrics["tokens_per_sec"] - original_metrics["tokens_per_sec"]) / original_metrics["tokens_per_sec"] * 100

print("\nModel Comparison:")
print(f"Pruned model perplexity change: {perplexity_change_pruned:.1f}% (higher is worse)")
print(f"Fine-tuned model perplexity change: {perplexity_change_finetuned:.1f}% (higher is worse)")
print(f"Accuracy recovery: {perplexity_recovery:.1f}% of the gap closed")
print(f"\nPruned model speed change: +{speed_change_pruned:.1f}%")
print(f"Fine-tuned model speed change: +{speed_change_finetuned:.1f}%")

## Results Visualization

Let's create some visualizations to summarize our results:

In [None]:
# Create a comparison plot
plt.figure(figsize=(15, 6))

# Perplexity comparison
plt.subplot(1, 2, 1)
colors = ['blue', 'red', 'green']
plt.bar(models, perplexities, color=colors)
plt.ylabel("Perplexity (lower is better)")
plt.title("Model Performance Comparison")
for i, perplexity in enumerate(perplexities):
    plt.text(i, perplexity + 1, f"{perplexity:.2f}", ha='center')

# Speed comparison
plt.subplot(1, 2, 2)
plt.bar(models, speeds, color=colors)
plt.ylabel("Generation Speed (tokens/sec)")
plt.title("Model Speed Comparison")
for i, speed in enumerate(speeds):
    plt.text(i, speed + 0.5, f"{speed:.2f}", ha='center')

plt.tight_layout()
plt.show()

# Create a head usage visualization
def plot_head_usage(model):
    num_layers = len(model.blocks)
    num_heads = model.blocks[0]["attn"].num_heads
    
    # Get gate values
    gate_values = np.zeros((num_layers, num_heads))
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            gate_values[layer_idx, head_idx] = model.blocks[layer_idx]["attn"].gate[head_idx].item()
    
    # Plot heatmap
    plt.figure(figsize=(10, 6))
    sns.heatmap(gate_values, cmap="RdYlGn", annot=True, fmt=".2f")
    plt.title("Attention Head Gate Values (0 = pruned)")
    plt.xlabel("Head Index")
    plt.ylabel("Layer Index")
    plt.show()
    
    # Count active heads
    active_heads = np.sum(gate_values > 0.01)
    total_heads = num_layers * num_heads
    print(f"Active heads: {active_heads}/{total_heads} ({active_heads/total_heads*100:.1f}%)")

print("\nFine-tuned model head usage:")
plot_head_usage(pruned_model)

## Side-by-Side Text Generation Comparison

Let's compare text generated by all three models with the same prompts:

In [None]:
# Compare generation results
prompts = [
    "The meaning of life is",
    "In a world where technology",
    "Once upon a time in a distant kingdom"
]

# Reload original model
original_model = load_adaptive_model(model_name, baseline_model, device)

# Save pruned outputs separately for later comparison
pruned_outputs = []
for i, prompt in enumerate(prompts):
    pruned_output = generate_text(pruned_model, tokenizer, prompt, max_length=100, temperature=0.8, device=device)
    pruned_outputs.append(pruned_output)

# Now compare side by side
for i, prompt in enumerate(prompts):
    print(f"\nPrompt {i+1}: '{prompt}'\n")
    
    # Generate text with original model
    original_text = generate_text(original_model, tokenizer, prompt, max_length=100, temperature=0.8, device=device)
    
    # Retrieve previously generated pruned text
    pruned_text = pruned_outputs[i]
    
    # Generate text with fine-tuned model
    finetuned_text = generate_text(pruned_model, tokenizer, prompt, max_length=100, temperature=0.8, device=device)
    
    # Print side by side
    print("Original model:")
    print(original_text)
    print("\nPruned model (before fine-tuning):")
    print(pruned_text)
    print("\nFine-tuned model:")
    print(finetuned_text)
    print("\n" + "-"*80)

## Conclusion

This notebook demonstrated how to fine-tune pruned transformer models to recover accuracy while maintaining the speed benefits. We showed that:

1. Pruning models can significantly improve inference speed (typically 1.5-2x faster)
2. Pruning initially degrades accuracy (higher perplexity)
3. With specialized fine-tuning using per-head learning rates, we can recover much of the lost accuracy
4. The fine-tuned model retains the speed benefits while producing higher quality outputs

This approach is particularly important for deployment to resource-constrained environments where both speed and quality matter.

Key optimizations in this implementation:
1. Higher boost factor (10.0) for active heads to accelerate learning
2. OneCycleLR scheduler for faster convergence
3. Early stopping when quality is sufficiently recovered
4. Aggressive parameter tuning for fast demonstration

In production settings, you might want to use more epochs and a less aggressive learning rate schedule for maximum quality recovery, especially for higher pruning ratios.