# Lab 3.2: Wanda Pruning - Apply Pruning

**Goal:** Implement the Wanda (Weights and Activations) pruning algorithm.

**You will learn to:**
- Collect activation statistics from calibration data
- Calculate weight importance using Wanda metric
- Apply layer-wise unstructured pruning
- Verify target sparsity is achieved
- Save pruned sparse model

---

## Wanda Algorithm Overview

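**Key Innovation**: Prune weights based on **weight magnitude × input activation norm**

```
For each weight W_ij (output row i, input feature j):
  importance(W_ij) = |W_ij| × ||X_j||_2

  where ||X_j||_2 = L2 norm of input feature j
                    over all calibration tokens

Within each output row i, prune the weights
with the lowest importance scores
```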

**Why this works**:
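- **Weight magnitude** alone ignores how large each input feature is
- **Activation magnitude** captures input feature importance: LLMs have a few input features with very large activations, so a small weight on such a feature can matter more than a large weight on a quiet one
- **Wanda combines both** and needs no retraining or weight reconstruction afterwards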
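To see how the two scores can disagree, here is a toy sketch with made-up weights and calibration inputs (all values are hypothetical). It computes ||X_j||_2 over tokens, scores each weight, and keeps the top half of each output row, once with the Wanda score and once with plain weight magnitude. The helper `keep_mask_per_row` is an illustrative name, not part of this lab.

```python
import torch

# Hypothetical toy layer: 2 output rows x 4 input features
W = torch.tensor([[0.9, -0.1, 0.5, 0.2],
                  [-0.3, 0.8, -0.05, 0.6]])

# Hypothetical calibration inputs: 3 tokens x 4 features.
# Feature 1 is a large-magnitude "outlier" feature.
X = torch.tensor([[0.1, 10.0, 1.0, 0.5],
                  [0.2, 9.0, 1.0, 0.4],
                  [0.1, 11.0, 1.0, 0.6]])

# ||X_j||_2: L2 norm of each input feature over all tokens -> shape [4]
feature_norm = torch.linalg.vector_norm(X, ord=2, dim=0)

# Wanda score |W_ij| * ||X_j||_2 (norms broadcast over output rows) vs. magnitude-only score
wanda_score = W.abs() * feature_norm.unsqueeze(0)
magnitude_score = W.abs()

def keep_mask_per_row(score, sparsity):
    """Keep-mask that drops the lowest-scoring `sparsity` fraction of each output row."""
    num_prune = int(score.shape[1] * sparsity)
    prune_idx = torch.argsort(score, dim=1)[:, :num_prune]
    keep = torch.ones_like(score, dtype=torch.bool)
    keep.scatter_(1, prune_idx, False)
    return keep

wanda_keep = keep_mask_per_row(wanda_score, 0.5)
magnitude_keep = keep_mask_per_row(magnitude_score, 0.5)
print("Wanda keeps:    ", wanda_keep.tolist())
print("Magnitude keeps:", magnitude_keep.tolist())
```

In row 0, magnitude pruning drops the small weight -0.1 that multiplies the outlier feature, while Wanda keeps it and drops 0.9 instead, because feature 0 barely activates on these calibration tokens.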

---

## Prerequisites

Make sure you have completed **01-Setup.ipynb** and have:
- `model`: Baseline dense Llama-2-7B model
- `tokenizer`: Tokenizer for text processing
- `calibration_batch`: Tokenized calibration data
- `calculate_sparsity()`: Function to measure sparsity

If not, run 01-Setup.ipynb first!
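`calculate_sparsity()` is defined in 01-Setup.ipynb and not shown here; from how this notebook uses it, it returns `(sparsity, total_params, zero_params)`. The following is a hypothetical stand-in with that return signature, assuming it counts exact zeros over every parameter (weights, biases, embeddings, norms). If the real function works that way, overall sparsity will land below the target when only linear layers are pruned.

```python
import torch
import torch.nn as nn

def calculate_sparsity_sketch(model):
    """Hypothetical stand-in for calculate_sparsity(): (sparsity, total, zeros) over ALL parameters."""
    total = 0
    zeros = 0
    for p in model.parameters():
        total += p.numel()
        zeros += (p == 0).sum().item()
    return zeros / total, total, zeros

# Toy check: a 2x4 linear layer whose weights are all zero but whose bias is not
toy = nn.Linear(4, 2, bias=True)
with torch.no_grad():
    toy.weight.zero_()
    toy.bias.fill_(1.0)

toy_sparsity, toy_total, toy_zeros = calculate_sparsity_sketch(toy)
print(f"{toy_sparsity:.0%} sparse: {toy_zeros} of {toy_total} parameters are zero")
```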

---
## Step 1: Verify Prerequisites

Check that all required variables are available.

In [None]:
import torch

print("=" * 60)
print("Prerequisites Check")
print("=" * 60)

# Check required variables
try:
    assert 'model' in dir(), "model not found. Run 01-Setup.ipynb first!"
    assert 'tokenizer' in dir(), "tokenizer not found. Run 01-Setup.ipynb first!"
    assert 'calibration_batch' in dir(), "calibration_batch not found. Run 01-Setup.ipynb first!"
    assert 'calculate_sparsity' in dir(), "calculate_sparsity not found. Run 01-Setup.ipynb first!"
    
    print("✅ Model loaded")
    print("✅ Tokenizer loaded")
    print(f"✅ Calibration data ready: {calibration_batch.shape}")
    print("✅ Sparsity calculation function available")
    
    # Check current sparsity
    baseline_sparsity, total_params, zero_params = calculate_sparsity(model)
    print(f"\n📊 Current Model Stats:")
    print(f"   Total parameters: {total_params / 1e9:.2f}B")
    print(f"   Current sparsity: {baseline_sparsity:.4%}")
    
    print("\n✅ All prerequisites satisfied!")
    
except AssertionError as e:
    print(f"❌ {e}")
    print("\nPlease run 01-Setup.ipynb first to set up the environment.")

print("=" * 60)

---
## Step 2: Configure Pruning Parameters

Set target sparsity and calibration configuration.
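Why does 50% sparsity leave roughly 3.5B non-zero parameters rather than exactly half of the model? Only the linear layers inside the decoder blocks are pruned. The sketch below does the arithmetic from Llama-2-7B's published shape (hidden size 4096, MLP size 11008, 32 layers, 32000-token vocabulary, untied `lm_head`); the split into module types assumes the standard Hugging Face Llama architecture.

```python
# Llama-2-7B shapes (from the public model config)
hidden, intermediate, n_layers, vocab = 4096, 11008, 32, 32000

attn_per_layer = 4 * hidden * hidden          # q_proj, k_proj, v_proj, o_proj
mlp_per_layer = 3 * hidden * intermediate     # gate_proj, up_proj, down_proj
prunable = n_layers * (attn_per_layer + mlp_per_layer)

embeddings = 2 * vocab * hidden               # embed_tokens + lm_head (not tied)
norms = (2 * n_layers + 1) * hidden           # two RMSNorms per block + final norm
total = prunable + embeddings + norms

target_sparsity = 0.5
pruned = int(prunable * target_sparsity)
overall_sparsity = pruned / total
nonzero = total - pruned

print(f"Total parameters:    {total / 1e9:.2f}B")
print(f"Prunable (linear):   {prunable / 1e9:.2f}B")
print(f"Overall sparsity:    {overall_sparsity:.2%}")
print(f"Non-zero parameters: {nonzero / 1e9:.2f}B")
```

So a 50% target on the linear layers corresponds to roughly 48% sparsity over all parameters.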

In [None]:
# Pruning configuration
TARGET_SPARSITY = 0.5  # 50% sparsity (prune half the weights)
NSAMPLES = 128         # Calibration samples (should match calibration_batch from 01-Setup.ipynb)
SEQLEN = 2048          # Sequence length (should match calibration_batch from 01-Setup.ipynb)

print("=" * 60)
print("Pruning Configuration")
print("=" * 60)
print(f"Target sparsity: {TARGET_SPARSITY:.1%}")
print(f"Calibration samples: {NSAMPLES}")
print(f"Sequence length: {SEQLEN}")
print("\n📝 Note:")
print("   50% sparsity means half of the weights in every pruned linear layer are set to 0")
print("   Embeddings, lm_head, and norm weights stay dense, so about 3.5B of")
print("   Llama-2-7B's ~6.7B parameters remain non-zero (~48% overall sparsity)")
print("   Expect some perplexity increase; measure it in 03-Inference.ipynb")
print("=" * 60)

---
## Step 3: Identify Target Layers for Pruning

Wanda prunes the **linear layers** inside the transformer blocks (attention q/k/v/o and MLP gate/up/down projections); the token embedding and `lm_head` stay dense.
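The filter in `find_linear_layers` relies on Hugging Face Llama module names. A minimal sketch with a hypothetical toy model that mimics that naming (`layers.N.self_attn.*_proj`, `layers.N.mlp.*_proj`, `embed_tokens`, `lm_head`) shows which modules the same rule selects: seven per decoder block, so Llama-2-7B's 32 blocks should yield 224 layers. All `Toy*` classes and `select_prunable` are illustrative names.

```python
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.q_proj = nn.Linear(d, d, bias=False)
        self.k_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d, bias=False)
        self.o_proj = nn.Linear(d, d, bias=False)

class ToyMLP(nn.Module):
    def __init__(self, d, ff):
        super().__init__()
        self.gate_proj = nn.Linear(d, ff, bias=False)
        self.up_proj = nn.Linear(d, ff, bias=False)
        self.down_proj = nn.Linear(ff, d, bias=False)

class ToyBlock(nn.Module):
    def __init__(self, d, ff):
        super().__init__()
        self.self_attn = ToyAttention(d)
        self.mlp = ToyMLP(d, ff)

class ToyLlama(nn.Module):
    """Hypothetical miniature with Llama-style module names (no forward pass needed here)."""
    def __init__(self, vocab=100, d=8, ff=16, n_layers=2):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, d)
        self.layers = nn.ModuleList([ToyBlock(d, ff) for _ in range(n_layers)])
        self.lm_head = nn.Linear(d, vocab, bias=False)

def select_prunable(model):
    """Same rule as find_linear_layers: nn.Linear modules, excluding embed/lm_head names."""
    return [name for name, m in model.named_modules()
            if isinstance(m, nn.Linear)
            and "embed" not in name.lower() and "lm_head" not in name.lower()]

names = select_prunable(ToyLlama())
print(len(names), names[:3])
```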

In [None]:
# Find all linear layers in the model
def find_linear_layers(model):
    """
    Find all nn.Linear layers in the model.
    Returns list of (name, module) tuples.
    """
    linear_layers = []
    
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Skip embedding and final output layers (typically not pruned)
            if 'embed' not in name.lower() and 'lm_head' not in name.lower():
                linear_layers.append((name, module))
    
    return linear_layers

print("=" * 60)
print("Identifying Prunable Layers")
print("=" * 60)

linear_layers = find_linear_layers(model)

print(f"Found {len(linear_layers)} linear layers to prune\n")
print("Sample layers:")
for i, (name, module) in enumerate(linear_layers[:5]):
    print(f"  {i+1}. {name}")
    print(f"     Shape: {module.weight.shape}")
    print(f"     Params: {module.weight.numel() / 1e6:.2f}M\n")

if len(linear_layers) > 5:
    print(f"  ... and {len(linear_layers) - 5} more layers")

# Calculate total prunable parameters
total_prunable = sum(module.weight.numel() for _, module in linear_layers)
print(f"\n📊 Total prunable parameters: {total_prunable / 1e9:.2f}B")
print(f"   Parameters to prune ({TARGET_SPARSITY:.0%}): {total_prunable * TARGET_SPARSITY / 1e9:.2f}B")
print("=" * 60)

---
## Step 4: Collect Activation Statistics

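Run the calibration data through the model and record, for each input feature of every linear layer, how large that feature's activations are across all calibration tokens (Wanda uses the L2 norm ||X_j||_2). This is a one-pass simplification: the reference Wanda implementation prunes block by block and feeds each block the outputs of the already-pruned blocks before it, whereas here all statistics come from the dense model.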
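Two details matter when accumulating statistics over several calibration batches. Averaging pairwise with `(old + new) / 2` is an exponential moving average that weights the last batch most heavily, not a mean; and an L2 norm over all tokens can be accumulated exactly by summing squares and taking the square root at the end. A toy check with random tensors (the shapes are arbitrary stand-ins for `[batch, seq_len, in_features]`):

```python
import torch

torch.manual_seed(0)
X = torch.randn(6, 5, 8)     # hypothetical calibration inputs: [samples, seq_len, in_features]
batches = X.split(2, dim=0)  # three batches of 2 samples

# Reference: ||X_j||_2 over all 6 * 5 tokens at once
reference = torch.linalg.vector_norm(X.reshape(-1, 8), ord=2, dim=0)

# Streaming: accumulate the per-feature sum of squares batch by batch, then take sqrt
sq_sum = torch.zeros(8)
for b in batches:
    sq_sum += b.reshape(-1, 8).pow(2).sum(dim=0)
streamed = sq_sum.sqrt()

# Pairwise "(old + new) / 2" update applied to per-batch means
batch_means = [b.abs().mean(dim=(0, 1)) for b in batches]
ema = batch_means[0]
for m in batch_means[1:]:
    ema = (ema + m) / 2
# With three batches this equals 0.25*m0 + 0.25*m1 + 0.5*m2, not (m0 + m1 + m2) / 3
true_mean = X.abs().mean(dim=(0, 1))

print("streamed == reference:", torch.allclose(streamed, reference, atol=1e-5))
print("EMA == true mean:     ", torch.allclose(ema, true_mean, atol=1e-5))
```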

In [None]:
import torch.nn as nn
from tqdm import tqdm

# Per-layer statistics: running per-feature sum of squared inputs, later turned into L2 norms
activation_stats = {}
token_counts = {}

# Hook function to capture activations
def get_activation_hook(name):
    """
    Create a forward hook that accumulates, for each input feature,
    the sum of squared input activations over all calibration tokens.
    """
    def hook(module, inputs, output):
        # inputs is a tuple of positional args; the first is the layer input.
        # Cast to float32 so squaring FP16 activations cannot overflow.
        x = inputs[0].detach().float()
        # Shape: [batch, seq_len, in_features] -> [tokens, in_features]
        x = x.reshape(-1, x.shape[-1])
        sq_sum = x.pow(2).sum(dim=0).cpu()

        # Accumulate exact sums (a pairwise (old + new) / 2 average would
        # over-weight the last batches)
        if name not in activation_stats:
            activation_stats[name] = sq_sum
            token_counts[name] = x.shape[0]
        else:
            activation_stats[name] += sq_sum
            token_counts[name] += x.shape[0]

    return hook

print("=" * 60)
print("Collecting Activation Statistics")
print("=" * 60)
print("⏳ Running calibration forward passes...\n")

# Register hooks
hooks = []
for name, module in linear_layers:
    hooks.append(module.register_forward_hook(get_activation_hook(name)))

# Run calibration
model.eval()
try:
    with torch.no_grad():
        # Process in smaller batches to avoid OOM
        batch_size = 4
        for i in tqdm(range(0, len(calibration_batch), batch_size), desc="Calibration"):
            batch = calibration_batch[i:i+batch_size].to(model.device)
            model(batch, use_cache=False)
finally:
    # Remove hooks even if a forward pass fails
    for hook in hooks:
        hook.remove()

# Convert sums of squares into per-feature L2 norms ||X_j||_2 (as in Wanda)
for name in activation_stats:
    activation_stats[name] = activation_stats[name].sqrt()

print(f"\n✅ Activation statistics collected for {len(activation_stats)} layers")

# Show sample statistics
sample_layer = next(iter(activation_stats))
sample_stats = activation_stats[sample_layer]
print(f"\n📊 Sample layer: {sample_layer}")
print(f"   Tokens seen: {token_counts[sample_layer]}")
print(f"   Norm vector shape: {sample_stats.shape}")
print(f"   Mean feature norm: {sample_stats.mean().item():.6f}")
print(f"   Max feature norm: {sample_stats.max().item():.6f}")
print(f"   Min feature norm: {sample_stats.min().item():.6f}")

print("=" * 60)

---
## Step 5: Implement Wanda Pruning Function

Calculate importance scores and apply pruning masks.
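The Wanda paper compares scores within each output row rather than across the whole layer, and selecting a fixed count per row also sidesteps ties. This toy sketch (the scores are hypothetical) contrasts a per-layer threshold with `>` against per-row selection:

```python
import torch

# Hypothetical importance scores for a 2x4 layer; row 0 is uniformly larger than row 1
score = torch.tensor([[10., 20., 30., 40.],
                      [ 1.,  2.,  3.,  4.]])
sparsity = 0.5

# Per-layer threshold: k-th smallest score over the whole matrix, keep strictly greater
k = int(sparsity * score.numel())
threshold = torch.kthvalue(score.flatten(), k).values
layer_keep = score > threshold

# Per-output selection: drop exactly the lowest `sparsity` fraction of every row
k_row = int(sparsity * score.shape[1])
prune_idx = torch.topk(score, k_row, dim=1, largest=False).indices
row_keep = torch.ones_like(score, dtype=torch.bool)
row_keep.scatter_(1, prune_idx, False)

print("per-layer row sparsity:", (1 - layer_keep.float().mean(dim=1)).tolist())  # [0.0, 1.0]
print("per-row   row sparsity:", (1 - row_keep.float().mean(dim=1)).tolist())    # [0.5, 0.5]

# Ties: with many equal scores (e.g. weights already at zero), "> threshold" over-prunes
tie_score = torch.tensor([[0., 0., 0., 5.]])
tie_threshold = torch.kthvalue(tie_score.flatten(), 1).values  # target: prune 1 of 4
tie_keep = tie_score > tie_threshold
print("weights kept with ties:", int(tie_keep.sum()), "of 4")  # 1 of 4
```

The per-layer threshold wipes out row 1 entirely, and the tie case prunes three weights when only one was requested.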

In [None]:
def wanda_prune_layer(weight, activation, sparsity):
    """
    Apply Wanda pruning to a single layer.

    Args:
        weight: Weight tensor [out_features, in_features]
        activation: Per-feature input norms ||X_j||_2 [in_features]
        sparsity: Target sparsity ratio (0-1)

    Returns:
        mask: Boolean keep-mask [out_features, in_features] (True = keep, False = prune)
    """
    # 1. Importance score: |W_ij| × ||X_j||_2, computed in float32
    #    Broadcast the per-feature norms across the output dimension
    importance = weight.abs().float() * activation.float().unsqueeze(0)

    # 2. Number of weights to prune in each output row
    num_prune = int(sparsity * importance.shape[1])

    mask = torch.ones_like(importance, dtype=torch.bool)
    if num_prune == 0:
        # No pruning needed
        return mask

    # 3. Per-output comparison (as in the Wanda paper):
    #    indices of the lowest-scoring weights in each row
    prune_idx = torch.topk(importance, num_prune, dim=1, largest=False).indices

    # 4. Mark those positions as pruned (exactly num_prune per row, even with ties)
    mask.scatter_(1, prune_idx, False)

    return mask

print("=" * 60)
print("Wanda Pruning Function Defined")
print("=" * 60)
print("\n📝 Algorithm:")
print("   1. Calculate importance = |weight| × ||activation||_2")
print("   2. Rank the importance scores within each output row")
print("   3. Select the lowest-scoring fraction (= sparsity) of each row")
print("   4. Create boolean keep-mask (False = pruned)")
print("\n✅ Function ready to use")
print("=" * 60)

---
## Step 6: Apply Wanda Pruning to All Layers

Prune each layer and apply the mask.
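Keeping a mask for every layer is not free. A back-of-the-envelope sketch, assuming the Llama-2-7B linear-layer shapes (32 blocks, hidden size 4096, MLP size 11008), shows why the dtype and device used for `pruning_masks` matter:

```python
# Prunable linear weights in Llama-2-7B's decoder blocks (4 attention + 3 MLP projections per block)
prunable = 32 * (4 * 4096 * 4096 + 3 * 4096 * 11008)

# Bytes needed per mask entry for a few storage choices
bytes_per_entry = {"float32": 4, "float16": 2, "bool": 1, "packed bits": 1 / 8}

mask_gb = {dtype: prunable * nbytes / 1e9 for dtype, nbytes in bytes_per_entry.items()}
for dtype, gb in mask_gb.items():
    print(f"{dtype:>12}: {gb:6.2f} GB")
```

Since the pruned weights themselves are zero, a mask can also be recovered later with `module.weight != 0` (up to weights that were already exactly zero), so storing masks is optional.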

In [None]:
import time

print("=" * 60)
print("Applying Wanda Pruning")
print("=" * 60)
print(f"Target sparsity: {TARGET_SPARSITY:.1%}\n")

start_time = time.time()

# Store masks (bool, on CPU) for verification
pruning_masks = {}

with torch.no_grad():
    for name, module in tqdm(linear_layers, desc="Pruning layers"):
        # Get weight and activation statistics
        weight = module.weight.data
        activation = activation_stats[name].to(weight.device)

        # Apply Wanda pruning
        mask = wanda_prune_layer(weight, activation, TARGET_SPARSITY)

        # Zero out pruned weights in place, keeping the weight's dtype
        weight.mul_(mask.to(weight.dtype))

        # Store a compact bool copy on CPU
        # (float32 masks for all Llama-2-7B linear layers would need ~26 GB)
        pruning_masks[name] = mask.to(torch.bool).cpu()

end_time = time.time()
pruning_time = end_time - start_time

print(f"\n✅ Pruning completed in {pruning_time:.2f} seconds")
print(f"   Average time per layer: {pruning_time / len(linear_layers):.3f}s")
print("=" * 60)

---
## Step 7: Verify Target Sparsity

Check that pruning achieved the target sparsity.
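Overall and per-layer percentages can hide uneven rows. When pruning is done per output row, as in the Wanda paper, every row should carry the same fraction of zeros. A minimal helper (the name `row_sparsity_range` is illustrative) to check that, shown on a hypothetical 3×4 pruned weight:

```python
import torch

def row_sparsity_range(weight):
    """Min and max fraction of exact zeros across the output rows of an [out, in] weight."""
    per_row = (weight == 0).float().mean(dim=1)
    return per_row.min().item(), per_row.max().item()

# Hypothetical weight after 50% per-row pruning
w = torch.tensor([[0.0, 1.2, 0.0, -0.7],
                  [0.3, 0.0, 0.0, 0.9],
                  [0.0, 0.0, 2.0, 0.1]])
print(row_sparsity_range(w))  # (0.5, 0.5)
```

On the real model, `row_sparsity_range(module.weight.data)` for each entry in `linear_layers` should return values at or very slightly above `TARGET_SPARSITY` with per-row pruning (weights that were already zero can push a row a little higher); a per-layer threshold would instead show a spread.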

In [None]:
print("=" * 60)
print("Sparsity Verification")
print("=" * 60)

# Overall sparsity across ALL parameters (includes dense embeddings, lm_head, norms)
pruned_sparsity, total, zeros = calculate_sparsity(model)

# Sparsity over the pruned linear layers only
linear_total = sum(m.weight.numel() for _, m in linear_layers)
linear_zeros = sum((m.weight == 0).sum().item() for _, m in linear_layers)
linear_sparsity = linear_zeros / linear_total

print(f"\n📊 Pruning Results:")
print(f"   Total parameters: {total / 1e9:.2f}B")
print(f"   Zero parameters: {zeros / 1e9:.2f}B")
print(f"   Overall sparsity (all parameters): {pruned_sparsity:.2%}")
print(f"   Linear-layer sparsity: {linear_sparsity:.2%}")
print(f"   Target sparsity: {TARGET_SPARSITY:.2%}")

# Check the target against the layers that were actually pruned
if abs(linear_sparsity - TARGET_SPARSITY) < 0.005:
    print(f"\n✅ Target sparsity achieved on the pruned linear layers!")
else:
    print(f"\n⚠️  Linear-layer sparsity deviates by {abs(linear_sparsity - TARGET_SPARSITY):.2%}")
print("   (Overall sparsity is lower because embeddings, lm_head, and norm weights are not pruned)")

# Layer-wise sparsity analysis
print(f"\n📊 Layer-wise Sparsity Analysis:")
layer_sparsities = []
for i, (name, module) in enumerate(linear_layers):
    weight = module.weight.data
    layer_sparsity = (weight == 0).sum().item() / weight.numel()
    layer_sparsities.append(layer_sparsity)
    if i < 5:  # Print only the first 5 layers
        print(f"   {name[:50]:50s} {layer_sparsity:.2%}")

if len(linear_layers) > 5:
    print(f"   ... and {len(linear_layers) - 5} more layers")

avg_layer_sparsity = sum(layer_sparsities) / len(layer_sparsities)
print(f"\n   Average sparsity over all {len(layer_sparsities)} linear layers: {avg_layer_sparsity:.2%}")

print("=" * 60)

---
## Step 8: Quick Inference Test

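Spot-check that the pruned model still generates coherent text. The zeros are still stored in dense tensors, so latency should be similar to the dense model's.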

In [None]:
# Test prompt
prompt = "The impact of artificial intelligence on society is"

print("=" * 60)
print("Pruned Model Inference Test")
print("=" * 60)
print(f"Prompt: {prompt}\n")

# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate
start_time = time.time()

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )

end_time = time.time()
latency = end_time - start_time

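# Decode output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Count only newly generated tokens (outputs[0] also contains the prompt)
new_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]

print(f"Output:\n{generated_text}\n")
print(f"⏱️  Latency: {latency:.2f} seconds")
print(f"📊 Tokens/sec (generated tokens only): {new_tokens / latency:.2f}")
print("\n👀 Inspect the output above to judge coherence")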
print("=" * 60)

---
## Step 9: Memory Comparison

Compare memory usage before and after pruning.
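How much a sparse format actually saves depends on its index overhead. The sketch below estimates average bytes per weight position for one 4096 × 11008 MLP projection at a given density (fraction of non-zeros), assuming FP16 values and 32-bit indices. Both are assumptions; PyTorch's sparse COO tensors use 64-bit indices by default, which makes COO even larger.

```python
def bytes_per_weight(fmt, density, rows=4096, cols=11008, value_bytes=2, index_bytes=4):
    """Average storage bytes per weight position for one [rows, cols] matrix."""
    n = rows * cols
    nnz = density * n
    if fmt == "dense":      # every position stores a value, zeros included
        total = n * value_bytes
    elif fmt == "coo":      # (row index, column index, value) per non-zero
        total = nnz * (2 * index_bytes + value_bytes)
    elif fmt == "csr":      # (column index, value) per non-zero + (rows + 1) row pointers
        total = nnz * (index_bytes + value_bytes) + (rows + 1) * index_bytes
    elif fmt == "bitmask":  # 1 bit per position + packed non-zero values
        total = n / 8 + nnz * value_bytes
    else:
        raise ValueError(f"unknown format: {fmt}")
    return total / n

for fmt in ["dense", "coo", "csr", "bitmask"]:
    print(f"{fmt:>8}: {bytes_per_weight(fmt, density=0.5):.3f} bytes/weight")
```

At 50% density, COO (5 bytes) and CSR (about 3 bytes) exceed the 2 bytes/weight of dense FP16, while a bitmask needs 1.125 bytes/weight, roughly 44% smaller than dense.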

In [None]:
print("=" * 60)
print("Memory Analysis")
print("=" * 60)

# Calculate model size
def calculate_model_size(model):
    """
    Calculate model size in GB (all parameters, using each parameter's actual dtype).
    Note: The pruned model still stores its zeros in dense format.
    """
    total_params = sum(p.numel() for p in model.parameters())
    size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9
    return size_gb, total_params

model_size, total_params = calculate_model_size(model)

print(f"\n📊 Model Size (Dense Format):")
print(f"   Total parameters: {total_params / 1e9:.2f}B")
print(f"   Model size: {model_size:.2f} GB")
print(f"   Non-zero parameters: {total_params * (1 - pruned_sparsity) / 1e9:.2f}B")

# Idealized sparse size: non-zero values only, ignoring index/metadata overhead
sparse_size = model_size * (1 - pruned_sparsity)
print(f"\n💾 Idealized Sparse Size (values only, no index overhead):")
print(f"   Estimated size: {sparse_size:.2f} GB")
print(f"   Size reduction: {model_size - sparse_size:.2f} GB ({(1 - sparse_size/model_size):.1%})")
print(f"\n⚠️  Note:")
print(f"   The model currently stores its zeros in dense format, so it is no smaller.")
print(f"   Real sparse formats add index overhead: at 50% sparsity with FP16 values,")
print(f"   CSR/COO with 32-bit indices would be LARGER than dense. A bitmask encoding")
print(f"   (1 bit per weight plus the non-zero values) would save space; hardware-friendly")
print(f"   formats such as 2:4 require pruning with an N:M pattern instead.")

# GPU memory usage
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"\n🖥️  GPU Memory:")
    print(f"   Allocated: {allocated:.2f} GB")
    print(f"   Reserved: {reserved:.2f} GB")

print("=" * 60)

---
## Step 10: Save Pruned Model

Save the pruned model for inference and evaluation.

In [None]:
import os

# Output directory
OUTPUT_DIR = "./pruned_model"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("=" * 60)
print("Saving Pruned Model")
print("=" * 60)
print(f"Output directory: {OUTPUT_DIR}\n")

# Save model
print("⏳ Saving model...")
model.save_pretrained(OUTPUT_DIR)
print("✅ Model saved")

# Save tokenizer
print("⏳ Saving tokenizer...")
tokenizer.save_pretrained(OUTPUT_DIR)
print("✅ Tokenizer saved")

# Save pruning configuration
import json

pruning_config = {
    "method": "wanda",
    "target_sparsity": TARGET_SPARSITY,
    # float()/int() keep json.dump working if calculate_sparsity returns tensors
    "achieved_sparsity": float(pruned_sparsity),
    "calibration_samples": NSAMPLES,
    "sequence_length": SEQLEN,
    "total_parameters": int(total),
    "zero_parameters": int(zeros),
    "pruned_layers": len(linear_layers),
    "pruning_time": pruning_time
}

config_path = os.path.join(OUTPUT_DIR, "pruning_config.json")
with open(config_path, 'w') as f:
    json.dump(pruning_config, f, indent=2)

print(f"✅ Pruning config saved to {config_path}")

# List saved files
saved_files = os.listdir(OUTPUT_DIR)
print(f"\n📁 Saved files ({len(saved_files)}):")
for file in sorted(saved_files)[:5]:
    file_path = os.path.join(OUTPUT_DIR, file)
    if os.path.isfile(file_path):
        size_mb = os.path.getsize(file_path) / 1e6
        print(f"   {file:40s} {size_mb:>10.2f} MB")

if len(saved_files) > 5:
    print(f"   ... and {len(saved_files) - 5} more files")

print("\n" + "=" * 60)
print("✅ Pruned model saved successfully!")
print("=" * 60)

---
## Step 11: Visualize Pruning Distribution

Analyze the distribution of pruned weights across layers.
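Beyond the per-layer plot, it can help to average sparsity by projection type (`q_proj`, `down_proj`, ...). A small sketch that works on any list of `(name, sparsity)` pairs; the helper name is illustrative and the three entries in the example are made-up values:

```python
from collections import defaultdict

def sparsity_by_module_type(named_sparsities):
    """Average sparsity per projection type, keyed by the last component of the module name."""
    groups = defaultdict(list)
    for name, sparsity in named_sparsities:
        groups[name.rsplit(".", 1)[-1]].append(sparsity)
    return {module_type: sum(vals) / len(vals) for module_type, vals in groups.items()}

# Hypothetical (name, sparsity) pairs
example = [
    ("model.layers.0.self_attn.q_proj", 0.50),
    ("model.layers.1.self_attn.q_proj", 0.52),
    ("model.layers.0.mlp.down_proj", 0.48),
]
print(sparsity_by_module_type(example))
```

On the pruned model, the input could be built as `[(n, (m.weight == 0).float().mean().item()) for n, m in linear_layers]`.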

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Collect sparsity for all layers
layer_names = []
layer_sparsities = []

for name, module in linear_layers:
    weight = module.weight.data
    sparsity = (weight == 0).sum().item() / weight.numel()
    layer_names.append(name.split('.')[-2] + '.' + name.split('.')[-1])  # Shortened name
    layer_sparsities.append(sparsity * 100)  # Convert to percentage

print("=" * 60)
print("Pruning Distribution Visualization")
print("=" * 60)

# Plot 1: Layer-wise sparsity (first 20 layers)
fig, axes = plt.subplots(2, 1, figsize=(14, 10))

# Bar plot for first 20 layers
n_display = min(20, len(layer_sparsities))
axes[0].bar(range(n_display), layer_sparsities[:n_display], color='steelblue', alpha=0.7)
axes[0].axhline(y=TARGET_SPARSITY * 100, color='red', linestyle='--', 
                label=f'Target: {TARGET_SPARSITY:.0%}')
axes[0].set_xlabel('Layer Index', fontsize=12)
axes[0].set_ylabel('Sparsity (%)', fontsize=12)
axes[0].set_title('Layer-wise Sparsity Distribution (First 20 Layers)', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(axis='y', alpha=0.3)
axes[0].set_ylim([0, 100])

# Plot 2: Histogram of sparsity distribution
axes[1].hist(layer_sparsities, bins=30, color='coral', alpha=0.7, edgecolor='black')
axes[1].axvline(x=TARGET_SPARSITY * 100, color='red', linestyle='--', 
                label=f'Target: {TARGET_SPARSITY:.0%}', linewidth=2)
axes[1].axvline(x=np.mean(layer_sparsities), color='green', linestyle='--', 
                label=f'Mean: {np.mean(layer_sparsities):.1f}%', linewidth=2)
axes[1].set_xlabel('Sparsity (%)', fontsize=12)
axes[1].set_ylabel('Number of Layers', fontsize=12)
axes[1].set_title('Sparsity Distribution Across All Layers', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'pruning_distribution.png'), dpi=150, bbox_inches='tight')
plt.show()

# Statistics
print(f"\n📊 Sparsity Statistics:")
print(f"   Mean: {np.mean(layer_sparsities):.2f}%")
print(f"   Std: {np.std(layer_sparsities):.2f}%")
print(f"   Min: {np.min(layer_sparsities):.2f}%")
print(f"   Max: {np.max(layer_sparsities):.2f}%")
print(f"\n✅ Visualization saved to {OUTPUT_DIR}/pruning_distribution.png")
print("=" * 60)

---
## ✅ Pruning Complete!

**Summary**:
- ✅ Collected activation statistics from calibration data
- ✅ Calculated Wanda importance scores (|weight| × input activation magnitude)
- ✅ Pruned every linear layer in the transformer blocks to the target sparsity
- ✅ Verified ~50% sparsity on the pruned linear layers
- ✅ Spot-checked generation with the pruned model
- ✅ Saved pruned model to `./pruned_model/`
- ✅ Visualized pruning distribution

**Key Results**:
- Target sparsity: 50%
- Achieved sparsity: ~50% on the pruned linear layers (~48% over all parameters, since embeddings, `lm_head`, and norms stay dense)
- Non-zero parameters: ~3.5B (from ~6.7B)
- Sample generation runs; judge its coherence from the Step 8 output

**Next Steps**:
1. Proceed to **03-Inference.ipynb** for detailed quality evaluation
2. Compare dense vs sparse model outputs
3. Measure performance metrics (latency, throughput)

**Key Variables Available**:
- `model`: Pruned sparse Llama-2-7B model (50% sparsity in its linear layers)
- `pruning_masks`: Dictionary of binary masks for each layer
- `activation_stats`: Collected activation statistics
- `OUTPUT_DIR`: Path to saved pruned model

---

**⏭️ Continue to**: [03-Inference.ipynb](./03-Inference.ipynb)