# Window Pruning Debug Notebook

This notebook tests the window pruning approach step by step for debugging.

In [None]:
import os

# Device selection
# Both variables are written to the environment before `import torch` below,
# so they are already in place when torch is first imported.
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'  # number GPUs by PCI bus ID
i = 0  # device number to use
os.environ["CUDA_VISIBLE_DEVICES"] = f'{i}'  # expose only GPU `i` to this process

import torch
import sys
# Put the directory two levels up on the import path so the `src` package
# imported below can be found (assumes the notebook is run from its own folder).
sys.path.append('../..')

from src.pruninghealing import Trainer, DatasetLoader, WindowPruner
from src.pruninghealing.utils import load_model_and_tokenizer, calculate_perplexity, get_model_layers
from src.pruninghealing.logger import Logger

# With a single visible device, that device is index 0 inside this process,
# hence get_device_name(0) rather than get_device_name(i).
print(f'Using GPU device {i}: {torch.cuda.get_device_name(0)}' if torch.cuda.is_available() else 'Using CPU')

## Load Model and Dataset

In [None]:
# Configuration
# NOTE(review): MODEL_PATH goes one directory up while WORKSPACE (and the
# sys.path entry in the first cell) go two levels up — confirm both resolve
# from the notebook's working directory.
MODEL_PATH = "../checkpoints/tinyllama"  # Change to your model
# Directory handed to WindowPruner, Trainer and Logger in a later cell.
WORKSPACE = "../../workspace/window_debug"
# Number of consecutive layers in each candidate window.
WINDOW_SIZE = 3

# Load model and tokenizer
print("Loading model...")
model, tokenizer = load_model_and_tokenizer(MODEL_PATH)

# get_model_layers(model) is used throughout the notebook as the layer count.
print(f"Model loaded: {get_model_layers(model)} layers")
print(f"Model type: {model.config.model_type}")

In [None]:
# Load dataset
# The loader is built around the model's tokenizer and is later passed to
# trainer.train(...) for the fine-tuning step.
print("Loading dataset...")
dataset_loader = DatasetLoader(tokenizer)
dataset_loader.load_wikitext(train_size=500, eval_size=50)  # Small for debugging

# train_dataset / eval_dataset are read here only to report their sizes.
print(f"Dataset loaded: {len(dataset_loader.train_dataset)} train, {len(dataset_loader.eval_dataset)} eval")

## Calculate Baseline Perplexity

In [None]:
# Calculate baseline perplexity
# Perplexity of the unmodified model; later cells report their results
# relative to this value and use the same max_samples=20.
# NOTE(review): dataset_loader is not passed in, so calculate_perplexity
# presumably selects its own evaluation text — confirm in utils.
print("Calculating baseline perplexity...")
baseline_ppl = calculate_perplexity(model, tokenizer, max_samples=20)
print(f"Baseline perplexity: {baseline_ppl:.3f}")

## Initialize Components

In [None]:
# Initialize pruner, trainer, and logger
# All three share the same WORKSPACE directory; the pruner and trainer are
# both constructed around the same (still unpruned) model object.
pruner = WindowPruner(model, tokenizer, WORKSPACE)
trainer = Trainer(model, tokenizer, WORKSPACE)
logger = Logger(WORKSPACE)

# Log baseline
# Step 0 records the starting layer count and perplexity so later steps can
# be compared against it.
logger.log_step({
    "step": 0,
    "action": "baseline",
    "layers_total": get_model_layers(model),
    "perplexity": baseline_ppl
})

print("Components initialized")

## Test Question Setup

In [None]:
# Fixed test question for quality evaluation
TEST_PROMPT = "What is the capital of France?"

def test_model_quality(model, tokenizer, prompt=TEST_PROMPT):
    """Generate a short completion for ``prompt`` and return it as text.

    The prompt is tokenized to PyTorch tensors, each tensor is moved to
    ``model.device``, and ``model.generate`` is called with sampling
    disabled and at most 20 new tokens. Gradient tracking is switched off
    for the call.

    Args:
        model: Object exposing ``device`` and ``generate(...)``.
        tokenizer: Callable tokenizer exposing ``eos_token_id`` and
            ``decode(...)``.
        prompt: Text to complete; defaults to the fixed ``TEST_PROMPT``.

    Returns:
        The decoded first sequence returned by ``generate``, with special
        tokens skipped. (For decoder-only models this presumably still
        contains the prompt text — confirm.)
    """
    encoded = tokenizer(prompt, return_tensors="pt")

    # Move every encoded tensor onto the model's device.
    model_inputs = {}
    for key, tensor in encoded.items():
        model_inputs[key] = tensor.to(model.device)

    with torch.no_grad():
        generation_options = {
            "max_new_tokens": 20,
            "do_sample": False,
            # The end-of-sequence token id doubles as the padding id.
            "pad_token_id": tokenizer.eos_token_id,
        }
        generated = model.generate(**model_inputs, **generation_options)

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Test baseline
# Record the unpruned model's answer to the fixed prompt; it is printed next
# to the post-pruning and post-fine-tuning answers in the results summary.
print(f"Test prompt: {TEST_PROMPT}")
baseline_response = test_model_quality(model, tokenizer)
print(f"Baseline response: {baseline_response}")

## Find Unimportant Window

In [None]:
# Find least important window
# find_unimportant_window returns the chosen window together with its score.
# NOTE(review): the comparison cell below treats the lowest score as "best",
# so a lower score presumably means the window is safer to remove — confirm.
print(f"Finding least important window of size {WINDOW_SIZE}...")
best_window, best_score = pruner.find_unimportant_window(WINDOW_SIZE)

print(f"Best window to remove: {best_window}")
print(f"Window score: {best_score}")

# Log window selection
logger.log_step({
    "step": 1,
    "action": "window_selection",
    "window_size": WINDOW_SIZE,
    "selected_window": best_window,
    "window_score": best_score
})

## Test Window Importance Evaluation

In [None]:
# Test different windows for comparison
# Scores every contiguous window of WINDOW_SIZE layers so the pruner's choice
# above can be checked by eye.
num_layers = get_model_layers(model)
print(f"\nTesting all possible windows of size {WINDOW_SIZE}:")

window_scores = []
# start_idx runs over every position where a full window still fits.
for start_idx in range(num_layers - WINDOW_SIZE + 1):
    window = list(range(start_idx, start_idx + WINDOW_SIZE))
    # Calls the pruner's private scoring method directly (debug use only).
    score = pruner._evaluate_window_importance(window)
    window_scores.append((window, score))
    print(f"Window {window}: score = {score}")

# Sort by score
# Ascending order: index 0 is the lowest score, index -1 the highest.
# NOTE: if WINDOW_SIZE exceeds num_layers the list is empty and the [0]/[-1]
# lookups below raise IndexError.
window_scores.sort(key=lambda x: x[1])
print(f"\nBest (lowest score) window: {window_scores[0][0]} with score {window_scores[0][1]}")
print(f"Worst (highest score) window: {window_scores[-1][0]} with score {window_scores[-1][1]}")

## Prune Window

In [None]:
# Prune the selected window
# The layer count is captured first so the before/after comparison stays
# valid whatever prune_window does to `model`.
print(f"\nPruning window {best_window}...")
layers_before = get_model_layers(model)

try:
    pruned_model = pruner.prune_window(best_window)
    layers_after = get_model_layers(pruned_model)

    print(f"Layers before pruning: {layers_before}")
    print(f"Layers after pruning: {layers_after}")
    print(f"Layers removed: {layers_before - layers_after}")

except Exception as e:
    # Best-effort fallback for a debug notebook: report the failure and carry
    # on with the unpruned model so the remaining cells can still run.
    print(f"Error during pruning: {e}")
    pruned_model = model
    # Later cells log `layers_after`; bind it here as well so the fallback
    # path does not fail with NameError. Nothing was removed, so it is
    # reported as equal to the count taken before pruning.
    layers_after = layers_before

## Calculate Perplexity After Pruning

In [None]:
# Calculate perplexity after pruning
# Same measurement settings as the baseline (max_samples=20, same fixed
# prompt) so the values are directly comparable.
print("Calculating perplexity after pruning...")
ppl_after_prune = calculate_perplexity(pruned_model, tokenizer, max_samples=20)
response_after_prune = test_model_quality(pruned_model, tokenizer)

print(f"Perplexity after pruning: {ppl_after_prune:.3f}")
# A positive change means perplexity went up after pruning.
print(f"Perplexity change: {ppl_after_prune - baseline_ppl:.3f}")
print(f"Response after pruning: {response_after_prune}")

# Log pruning results
# layers_before / layers_after are the counts recorded by the pruning cell.
logger.log_step({
    "step": 2,
    "action": "window_pruning",
    "pruned_window": best_window,
    "layers_before": layers_before,
    "layers_after": layers_after,
    "ppl_before_prune": baseline_ppl,
    "ppl_after_prune": ppl_after_prune,
    "ppl_degradation": ppl_after_prune - baseline_ppl,
    "response_after_prune": response_after_prune
})

## Fine-tune Pruned Model

In [None]:
# Fine-tune the pruned model
print("\nFine-tuning pruned model...")
# The trainer was constructed around the original model, so point it at the
# pruned one before training.
trainer.model = pruned_model
# train(...) returns the model that the following cells evaluate.
finetuned_model = trainer.train(dataset_loader, max_steps=200)  # Small for debugging

print("Fine-tuning completed")

## Calculate Final Perplexity

In [None]:
# Calculate final perplexity
# Same measurement settings as the baseline and post-pruning cells
# (max_samples=20, same fixed prompt) so all three stages are comparable.
print("Calculating final perplexity...")
final_ppl = calculate_perplexity(finetuned_model, tokenizer, max_samples=20)
final_response = test_model_quality(finetuned_model, tokenizer)

print(f"Final perplexity: {final_ppl:.3f}")
print(f"Final response: {final_response}")

# Calculate improvements
# healing_improvement > 0: fine-tuning lowered perplexity relative to the
# freshly pruned model.
# total_change > 0: the final model is still worse than the baseline.
healing_improvement = ppl_after_prune - final_ppl
total_change = final_ppl - baseline_ppl

# Plain string: there is nothing to interpolate in the banner.
print("\n=== RESULTS ===")
print(f"Baseline: {baseline_ppl:.3f} | {baseline_response}")
print(f"After pruning: {ppl_after_prune:.3f} (change: {ppl_after_prune - baseline_ppl:+.3f}) | {response_after_prune}")
print(f"After fine-tuning: {final_ppl:.3f} (healing: {healing_improvement:+.3f}) | {final_response}")
print(f"Total change: {total_change:+.3f}")

# Log final results
logger.log_step({
    "step": 3,
    "action": "fine_tuning",
    "ppl_before_finetune": ppl_after_prune,
    "ppl_after_finetune": final_ppl,
    "healing_improvement": healing_improvement,
    "total_change": total_change,
    "final_response": final_response
})

## Results Analysis

In [None]:
# Show all logs
import pandas as pd
# Tabulate what the logger collected.
# NOTE(review): assumes logger.logs holds the dicts passed to log_step, one
# per step — confirm in Logger.
df = pd.DataFrame(logger.logs)
print("\n=== DETAILED LOGS ===")
print(df)

# Plot if possible
# matplotlib is optional: only a failed import is handled, any other plotting
# error still surfaces.
try:
    import matplotlib.pyplot as plt

    perplexities = [baseline_ppl, ppl_after_prune, final_ppl]
    stages = ['Baseline', 'After Pruning', 'After Fine-tuning']

    plt.figure(figsize=(10, 6))
    plt.plot(stages, perplexities, 'bo-', linewidth=2, markersize=8)
    plt.ylabel('Perplexity')
    plt.title('Window Pruning and Healing Process')
    plt.grid(True, alpha=0.3)

    # Add value labels
    # String x-values are drawn at positions 0, 1, 2, ..., so the list index
    # is the x coordinate of each point. A dedicated loop name is used so the
    # notebook-level `i` (the CUDA device number) is not overwritten.
    for point_idx, ppl in enumerate(perplexities):
        plt.annotate(f'{ppl:.3f}', (point_idx, ppl), textcoords="offset points",
                     xytext=(0, 10), ha='center')

    plt.tight_layout()
    plt.show()

except ImportError:
    print("Matplotlib not available for plotting")