# Lab 3.2: Wanda Pruning - Performance Benchmarking

**Goal:** Comprehensive performance analysis of pruned sparse models.

**You will learn to:**
- Measure perplexity on the WikiText-2 dataset
- Analyze latency distribution (P50/P95/P99)
- Profile memory usage across different batch sizes
- Evaluate throughput scaling
- Understand hardware requirements for sparse acceleration
- Make deployment decisions based on metrics

---

## Why Comprehensive Benchmarking?

**A handful of sample generations is not enough. Each metric below answers a different deployment question**:
- **Perplexity**: Quantifies model quality on held-out text
- **Latency Distribution**: Tail latency (P95/P99), not the mean, decides whether production SLAs are met
- **Memory Profiling**: Determines which GPUs the model fits on
- **Throughput Scaling**: Shows how tokens/second changes with the workload

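**Expected Results** (50% Wanda pruning, Llama-2-7B; indicative figures, exact values depend on context length and calibration data):
- **Perplexity**: 5.68 → 6.12 (+7.7% degradation)
- **Latency**: Similar without hardware acceleration
- **Memory**: Same (pruned weights are stored as zeros in dense tensors)
- **With 2:4 Sparse + A100**: Up to 2x speedup on the pruned matrix multiplications, and only if the weights follow the 2:4 pattern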
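
The percentage on the perplexity line is a relative increase over the dense baseline. A minimal sketch of that arithmetic, using the figures quoted above (the helper name `relative_increase` is illustrative and not part of the lab code):

```python
def relative_increase(dense_ppl: float, sparse_ppl: float) -> float:
    """Return the perplexity increase of the sparse model, in percent of the dense value."""
    return (sparse_ppl - dense_ppl) / dense_ppl * 100.0


# Figures quoted in the expected results above
increase = relative_increase(5.68, 6.12)
print(f"Perplexity increase: {increase:+.1f}%")
```

The same formula is applied to the measured values in Step 4.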

---

## Prerequisites

Make sure you have completed:
- **01-Setup.ipynb**: Environment setup
- **02-Prune.ipynb**: Applied Wanda pruning
- **03-Inference.ipynb**: Initial quality comparison

---
## Step 1: Load Models

Load both dense and pruned models for benchmarking.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import json

# Paths
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
PRUNED_MODEL_DIR = "./pruned_model"

print("=" * 60)
print("Loading Models for Benchmarking")
print("=" * 60)

# Load tokenizer (shared)
print("⏳ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
print("✅ Tokenizer loaded\n")

# Load dense model
print("⏳ Loading dense baseline model...")
dense_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
print("✅ Dense model loaded\n")

# Load sparse model
print("⏳ Loading pruned sparse model...")
sparse_model = AutoModelForCausalLM.from_pretrained(
    PRUNED_MODEL_DIR,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Load pruning config
config_path = os.path.join(PRUNED_MODEL_DIR, "pruning_config.json")
with open(config_path, 'r') as f:
    pruning_config = json.load(f)

print(f"✅ Sparse model loaded (sparsity: {pruning_config['achieved_sparsity']:.2%})\n")

# GPU memory
if torch.cuda.is_available():
    memory_allocated = torch.cuda.memory_allocated() / 1e9
    print(f"🖥️  GPU Memory: {memory_allocated:.2f} GB")

print("=" * 60)

---
## Step 2: Load WikiText-2 Test Dataset

Use the WikiText-2 test set for perplexity evaluation (a standard benchmark).

In [None]:
from datasets import load_dataset
from tqdm import tqdm

print("=" * 60)
print("Loading WikiText-2 Test Dataset")
print("=" * 60)

# Load test split
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
print(f"✅ Dataset loaded: {len(dataset)} samples\n")

# Filter empty texts
dataset = dataset.filter(lambda x: len(x['text'].strip()) > 0)
print(f"✅ Filtered dataset: {len(dataset)} non-empty samples\n")

# Concatenate all texts
all_text = "\n\n".join(dataset['text'])
print(f"📊 Total characters: {len(all_text):,}")

# Tokenize
print("⏳ Tokenizing dataset...")
encodings = tokenizer(all_text, return_tensors="pt")
input_ids = encodings['input_ids'][0]
print(f"✅ Tokenized: {len(input_ids):,} tokens\n")

print("=" * 60)

---
## Step 3: Calculate Perplexity (Dense Model)

Measure baseline perplexity on WikiText-2.
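
Perplexity is the exponential of the mean negative log-likelihood per token, and the sliding window decides which tokens are scored in each forward pass. The sketch below illustrates both ideas with plain Python and toy sizes. The function names are illustrative, and the window arithmetic mirrors the `begin_loc` / `end_loc` / `trg_len` logic used in the cell that follows.

```python
import math


def perplexity(token_probs):
    """Perplexity = exp(mean negative log-likelihood) over the scored tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


def sliding_windows(num_tokens, max_length, stride):
    """Yield (begin, end, target_len): the model sees tokens[begin:end],
    but only the last target_len tokens contribute to the loss."""
    for i in range(0, num_tokens, stride):
        begin = max(i + stride - max_length, 0)
        end = min(i + stride, num_tokens)
        yield begin, end, end - i
        if i + stride >= num_tokens:
            break


# A model that assigns probability 1/4 to every token is as uncertain
# as a uniform choice among 4 options.
uniform_ppl = perplexity([0.25] * 8)
print(f"Perplexity of a uniform 4-way guess: {uniform_ppl:.2f}")

# Toy sizes chosen for readability (the lab uses max_length=2048, stride=512)
windows = list(sliding_windows(num_tokens=10, max_length=6, stride=2))
for begin, end, target_len in windows:
    print(f"context tokens[{begin}:{end}], scored tokens: {target_len}")
```

Every token is scored exactly once, and later windows reuse earlier tokens purely as context.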

In [None]:
import math
import time

def calculate_perplexity(model, input_ids, max_length=2048, stride=512):
    """
    Calculate perplexity using sliding window.
    
    Args:
        model: Language model
        input_ids: Tokenized input
        max_length: Maximum context length
        stride: Stride for sliding window
    
    Returns:
        perplexity, avg_loss
    """
    model.eval()
    
    nlls = []  # Negative log-likelihoods
    
    # Sliding window over the dataset
    for i in tqdm(range(0, len(input_ids), stride), desc="Computing perplexity"):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, len(input_ids))
        trg_len = end_loc - i  # Target length (labels)
        
        input_chunk = input_ids[begin_loc:end_loc].unsqueeze(0).to(model.device)
        target_ids = input_chunk.clone()
        target_ids[:, :-trg_len] = -100  # Ignore context tokens
        
        with torch.no_grad():
            outputs = model(input_chunk, labels=target_ids)
            neg_log_likelihood = outputs.loss * trg_len
        
        nlls.append(neg_log_likelihood)
        
        if i + stride >= len(input_ids):
            break
    
    # Calculate perplexity
    avg_nll = torch.stack(nlls).sum() / end_loc
    perplexity = torch.exp(avg_nll)
    
    return perplexity.item(), avg_nll.item()

print("=" * 60)
print("Calculating Perplexity: Dense Model")
print("=" * 60)
print("⏳ This may take 5-10 minutes...\n")

start_time = time.time()
dense_ppl, dense_loss = calculate_perplexity(dense_model, input_ids)
dense_time = time.time() - start_time

print(f"\n📊 Dense Model Results:")
print(f"   Perplexity: {dense_ppl:.2f}")
print(f"   Avg Loss: {dense_loss:.4f}")
print(f"   Time: {dense_time:.2f}s")
print("=" * 60)

---
## Step 4: Calculate Perplexity (Sparse Model)

Measure perplexity of pruned model to quantify quality loss.

In [None]:
print("=" * 60)
print("Calculating Perplexity: Sparse Model (50% pruned)")
print("=" * 60)
print("⏳ This may take 5-10 minutes...\n")

start_time = time.time()
sparse_ppl, sparse_loss = calculate_perplexity(sparse_model, input_ids)
sparse_time = time.time() - start_time

print(f"\n📊 Sparse Model Results:")
print(f"   Perplexity: {sparse_ppl:.2f}")
print(f"   Avg Loss: {sparse_loss:.4f}")
print(f"   Time: {sparse_time:.2f}s")

# Comparison
ppl_increase = (sparse_ppl - dense_ppl) / dense_ppl * 100

print(f"\n📊 Perplexity Comparison:")
print(f"   Dense:  {dense_ppl:.2f}")
print(f"   Sparse: {sparse_ppl:.2f}")
print(f"   Increase: {ppl_increase:+.2f}%")

if ppl_increase < 10:
    print(f"   ✅ Quality well preserved (<10% degradation)")
elif ppl_increase < 20:
    print(f"   🟡 Moderate quality loss (10-20% degradation)")
else:
    print(f"   ⚠️  Significant quality loss (>20% degradation)")

print(f"\n📖 Reference (indicative, Llama-2-7B, 50% sparsity):")
print(f"   Expected PPL: 5.68 → 6.12 (+7.7%); exact values depend on context length and calibration data")
print("=" * 60)

---
## Step 5: Latency Distribution Analysis

Measure latency across multiple runs to understand variability.
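
Percentiles describe the tail of the latency distribution, which the mean hides. The sketch below reimplements linear-interpolation percentiles (the default method of `numpy.percentile`) in plain Python and applies them to hypothetical latencies, to show how a single slow run affects each statistic. The data and the helper name are made up for illustration.

```python
def percentile(samples, q):
    """Linear-interpolation percentile (the default method of numpy.percentile)."""
    ordered = sorted(samples)
    position = (len(ordered) - 1) * q / 100.0
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = position - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction


# Hypothetical latencies in ms: 49 steady runs plus one slow outlier
latencies_ms = [100.0] * 49 + [400.0]

p50 = percentile(latencies_ms, 50)
p99 = percentile(latencies_ms, 99)
mean = sum(latencies_ms) / len(latencies_ms)
print(f"mean={mean:.1f}ms  P50={p50:.1f}ms  P99={p99:.1f}ms")
```

With only 50 runs, P99 is interpolated between the two slowest samples, so one outlier dominates it. A stable P99 estimate needs considerably more runs.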

In [None]:
import numpy as np

def measure_latency_distribution(model, tokenizer, prompt, num_runs=50, max_tokens=50):
    """
    Measure latency distribution over multiple runs.
    """
    latencies = []
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Warmup runs
    for _ in range(5):
        with torch.no_grad():
            _ = model.generate(**inputs, max_new_tokens=max_tokens, do_sample=False)
    
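    # Measure latency (synchronize so queued GPU work is included in the timing)
    for _ in tqdm(range(num_runs), desc="Measuring latency"):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            _ = model.generate(**inputs, max_new_tokens=max_tokens, do_sample=False)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        latency = time.perf_counter() - start
        latencies.append(latency)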
    
    return np.array(latencies)

# Test prompt
prompt = "The future of artificial intelligence is"

print("=" * 60)
print("Latency Distribution Analysis")
print("=" * 60)
print(f"Prompt: {prompt}")
print(f"Runs: 50")
print(f"Output tokens: 50\n")

# Dense model
print("[Dense Model]")
dense_latencies = measure_latency_distribution(dense_model, tokenizer, prompt)
print(f"✅ Dense latencies collected\n")

# Sparse model
print("[Sparse Model]")
sparse_latencies = measure_latency_distribution(sparse_model, tokenizer, prompt)
print(f"✅ Sparse latencies collected\n")

# Statistics
print("📊 Latency Statistics:")
print("\nDense Model:")
print(f"   Mean: {np.mean(dense_latencies)*1000:.2f}ms")
print(f"   Std:  {np.std(dense_latencies)*1000:.2f}ms")
print(f"   P50:  {np.percentile(dense_latencies, 50)*1000:.2f}ms")
print(f"   P95:  {np.percentile(dense_latencies, 95)*1000:.2f}ms")
print(f"   P99:  {np.percentile(dense_latencies, 99)*1000:.2f}ms")

print("\nSparse Model (50% pruned):")
print(f"   Mean: {np.mean(sparse_latencies)*1000:.2f}ms")
print(f"   Std:  {np.std(sparse_latencies)*1000:.2f}ms")
print(f"   P50:  {np.percentile(sparse_latencies, 50)*1000:.2f}ms")
print(f"   P95:  {np.percentile(sparse_latencies, 95)*1000:.2f}ms")
print(f"   P99:  {np.percentile(sparse_latencies, 99)*1000:.2f}ms")

# Speedup
p50_speedup = np.percentile(dense_latencies, 50) / np.percentile(sparse_latencies, 50)
p95_speedup = np.percentile(dense_latencies, 95) / np.percentile(sparse_latencies, 95)
p99_speedup = np.percentile(dense_latencies, 99) / np.percentile(sparse_latencies, 99)

print("\n📊 Speedup Analysis:")
print(f"   P50 speedup: {p50_speedup:.2f}x")
print(f"   P95 speedup: {p95_speedup:.2f}x")
print(f"   P99 speedup: {p99_speedup:.2f}x")

print("=" * 60)

---
## Step 6: Throughput Benchmarking

Measure tokens/second at different output lengths.
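
For decoder-only models, Hugging Face `generate` returns the prompt tokens followed by the newly generated ones, so decode throughput should count only the difference. A minimal sketch with hypothetical numbers (the helper name `tokens_per_second` is illustrative):

```python
def tokens_per_second(prompt_tokens: int, total_tokens: int, seconds: float) -> float:
    """Decode throughput: count only tokens produced by generation."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    generated = total_tokens - prompt_tokens
    return generated / seconds


# Hypothetical run: 7-token prompt, 57 tokens returned, 2.0 s wall time
naive = 57 / 2.0                          # also counts the prompt
correct = tokens_per_second(7, 57, 2.0)   # generated tokens only
print(f"naive={naive:.1f} tok/s, generated-only={correct:.1f} tok/s")
```

The gap between the two figures grows with prompt length and shrinks with output length, which would distort a comparison across output lengths.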

In [None]:
def measure_throughput(model, tokenizer, prompt, output_lengths=(50, 100, 200)):
    """
    Measure throughput at different output lengths.
    """
    results = []
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    for length in output_lengths:
        # Warmup
        with torch.no_grad():
            _ = model.generate(**inputs, max_new_tokens=length, do_sample=False)
        
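        # Measure
        start = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=length, do_sample=False)
        latency = time.perf_counter() - start
        
        # Count only newly generated tokens; outputs[0] also contains the prompt
        num_tokens = outputs.shape[1] - inputs['input_ids'].shape[1]
        throughput = num_tokens / latency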
        
        results.append({
            'output_length': length,
            'latency': latency,
            'tokens': num_tokens,
            'throughput': throughput
        })
    
    return results

print("=" * 60)
print("Throughput Benchmarking")
print("=" * 60)

output_lengths = [50, 100, 200]
prompt = "Artificial intelligence is transforming"

print(f"Prompt: {prompt}")
print(f"Output lengths: {output_lengths}\n")

# Dense model
print("[Dense Model]")
dense_throughput = measure_throughput(dense_model, tokenizer, prompt, output_lengths)
print("✅ Dense throughput measured\n")

# Sparse model
print("[Sparse Model]")
sparse_throughput = measure_throughput(sparse_model, tokenizer, prompt, output_lengths)
print("✅ Sparse throughput measured\n")

# Display results
print("📊 Throughput Results:\n")
print(f"{'Length':<10} {'Dense (tok/s)':<15} {'Sparse (tok/s)':<15} {'Speedup':<10}")
print("─" * 60)

for d, s in zip(dense_throughput, sparse_throughput):
    speedup = s['throughput'] / d['throughput']
    print(f"{d['output_length']:<10} {d['throughput']:<15.2f} {s['throughput']:<15.2f} {speedup:.2f}x")

print("=" * 60)

---
## Step 7: Memory Profiling

Analyze GPU memory usage during inference.
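
Before profiling, it helps to estimate what each storage layout could cost. The sketch below is back-of-the-envelope arithmetic under stated assumptions: a round 7.0B parameter count, FP16 values, every weight treated as prunable, 32-bit column indices for CSR, and CSR row pointers ignored. The function names are illustrative.

```python
def weight_bytes_dense(num_params: float, bytes_per_value: int = 2) -> float:
    """Dense storage: every weight is stored, zero or not."""
    return num_params * bytes_per_value


def weight_bytes_2_4(num_params: float, bytes_per_value: int = 2) -> float:
    """2:4 semi-structured: keep 2 of every 4 values plus a 2-bit index per kept value."""
    kept = num_params / 2
    return kept * bytes_per_value + kept * 2 / 8


def weight_bytes_csr(num_params: float, sparsity: float,
                     bytes_per_value: int = 2, bytes_per_index: int = 4) -> float:
    """CSR: one value and one column index per non-zero (row pointers ignored)."""
    nonzeros = num_params * (1.0 - sparsity)
    return nonzeros * (bytes_per_value + bytes_per_index)


PARAMS = 7.0e9  # assumption: round figure; treats every weight as prunable
dense_gb = weight_bytes_dense(PARAMS) / 1e9
semi_gb = weight_bytes_2_4(PARAMS) / 1e9
csr_gb = weight_bytes_csr(PARAMS, sparsity=0.5) / 1e9
print(f"dense={dense_gb:.2f} GB  2:4={semi_gb:.2f} GB  CSR@50%={csr_gb:.2f} GB")
```

Under these assumptions a generic CSR layout is larger than dense storage at 50% sparsity, because each 2-byte value carries a 4-byte index. Only a compact layout such as 2:4 reduces the footprint at this sparsity level.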

In [None]:
def profile_memory(model, tokenizer, prompt, max_tokens=100):
    """
    Profile GPU memory usage during inference.
    """
    if not torch.cuda.is_available():
        return None
    
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    
    # Measure baseline
    baseline = torch.cuda.memory_allocated() / 1e9
    
    # Run inference
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_tokens, do_sample=False)
    
    # Measure peak
    peak = torch.cuda.max_memory_allocated() / 1e9
    current = torch.cuda.memory_allocated() / 1e9
    
    return {
        'baseline': baseline,
        'peak': peak,
        'current': current,
        'inference_overhead': peak - baseline
    }

print("=" * 60)
print("Memory Profiling")
print("=" * 60)

if torch.cuda.is_available():
    # Dense model
    print("\n[Dense Model]")
    dense_mem = profile_memory(dense_model, tokenizer, "The impact of AI on society")
    print(f"   Baseline:  {dense_mem['baseline']:.2f} GB")
    print(f"   Peak:      {dense_mem['peak']:.2f} GB")
    print(f"   Current:   {dense_mem['current']:.2f} GB")
    print(f"   Inference overhead: {dense_mem['inference_overhead']:.2f} GB")
    
    # Sparse model
    print("\n[Sparse Model (50% pruned)]")
    sparse_mem = profile_memory(sparse_model, tokenizer, "The impact of AI on society")
    print(f"   Baseline:  {sparse_mem['baseline']:.2f} GB")
    print(f"   Peak:      {sparse_mem['peak']:.2f} GB")
    print(f"   Current:   {sparse_mem['current']:.2f} GB")
    print(f"   Inference overhead: {sparse_mem['inference_overhead']:.2f} GB")
    
    # Comparison
    print("\n📊 Memory Comparison:")
    print(f"   Baseline reduction: {(dense_mem['baseline'] - sparse_mem['baseline']) / dense_mem['baseline']:.1%}")
    print(f"   Peak reduction: {(dense_mem['peak'] - sparse_mem['peak']) / dense_mem['peak']:.1%}")
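    print("\n⚠️  Note:")
    print("   Both models are resident on the GPU here, so each baseline includes both sets of weights.")
    print("   Pruned weights are stored as zeros in dense tensors, so pruning alone saves no memory.")
    print("   Savings need a compact layout such as 2:4 semi-structured; CSR/COO indices outweigh the savings at 50% sparsity.")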
else:
    print("❌ CUDA not available. Skipping memory profiling.")

print("=" * 60)

---
## Step 8: Comprehensive Visualization

Create publication-quality visualizations of all metrics.

In [None]:
import matplotlib.pyplot as plt

print("=" * 60)
print("Creating Comprehensive Visualizations")
print("=" * 60)

fig = plt.figure(figsize=(16, 12))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Plot 1: Perplexity Comparison
ax1 = fig.add_subplot(gs[0, 0])
models = ['Dense', 'Sparse\n(50%)']
ppls = [dense_ppl, sparse_ppl]
colors = ['green', 'blue']
bars = ax1.bar(models, ppls, color=colors, alpha=0.7)
ax1.set_ylabel('Perplexity', fontsize=11)
ax1.set_title('Perplexity on WikiText-2', fontsize=12, fontweight='bold')
ax1.grid(axis='y', alpha=0.3)
for bar, ppl in zip(bars, ppls):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
             f'{ppl:.2f}', ha='center', va='bottom', fontsize=10)

# Plot 2: Perplexity Increase
ax2 = fig.add_subplot(gs[0, 1])
ax2.bar(['Increase'], [ppl_increase], color='coral', alpha=0.7)
ax2.axhline(y=10, color='orange', linestyle='--', label='10% threshold', linewidth=2)
ax2.set_ylabel('Increase (%)', fontsize=11)
ax2.set_title('Perplexity Degradation', fontsize=12, fontweight='bold')
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

# Plot 3: Latency Distribution (Box plot)
ax3 = fig.add_subplot(gs[0, 2])
box_data = [dense_latencies * 1000, sparse_latencies * 1000]
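# Set tick labels explicitly; the `labels=` keyword is deprecated in newer Matplotlib releases
bp = ax3.boxplot(box_data, patch_artist=True)
ax3.set_xticklabels(['Dense', 'Sparse'])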
bp['boxes'][0].set_facecolor('green')
bp['boxes'][1].set_facecolor('blue')
for patch in bp['boxes']:
    patch.set_alpha(0.6)
ax3.set_ylabel('Latency (ms)', fontsize=11)
ax3.set_title('Latency Distribution (50 tokens)', fontsize=12, fontweight='bold')
ax3.grid(axis='y', alpha=0.3)

# Plot 4: P50/P95/P99 Latency Comparison
ax4 = fig.add_subplot(gs[1, 0])
percentiles = ['P50', 'P95', 'P99']
dense_p = [np.percentile(dense_latencies, p)*1000 for p in [50, 95, 99]]
sparse_p = [np.percentile(sparse_latencies, p)*1000 for p in [50, 95, 99]]
x = np.arange(len(percentiles))
width = 0.35
ax4.bar(x - width/2, dense_p, width, label='Dense', color='green', alpha=0.7)
ax4.bar(x + width/2, sparse_p, width, label='Sparse', color='blue', alpha=0.7)
ax4.set_ylabel('Latency (ms)', fontsize=11)
ax4.set_title('Percentile Latency Comparison', fontsize=12, fontweight='bold')
ax4.set_xticks(x)
ax4.set_xticklabels(percentiles)
ax4.legend()
ax4.grid(axis='y', alpha=0.3)

# Plot 5: Throughput vs Output Length
ax5 = fig.add_subplot(gs[1, 1])
lengths = [d['output_length'] for d in dense_throughput]
dense_thr = [d['throughput'] for d in dense_throughput]
sparse_thr = [s['throughput'] for s in sparse_throughput]
ax5.plot(lengths, dense_thr, 'o-', label='Dense', color='green', linewidth=2, markersize=8)
ax5.plot(lengths, sparse_thr, 's-', label='Sparse', color='blue', linewidth=2, markersize=8)
ax5.set_xlabel('Output Length (tokens)', fontsize=11)
ax5.set_ylabel('Throughput (tokens/sec)', fontsize=11)
ax5.set_title('Throughput Scaling', fontsize=12, fontweight='bold')
ax5.legend()
ax5.grid(True, alpha=0.3)

# Plot 6: Speedup by Output Length
ax6 = fig.add_subplot(gs[1, 2])
speedups_by_length = [s['throughput'] / d['throughput'] 
                      for d, s in zip(dense_throughput, sparse_throughput)]
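# Use categorical x positions so the bars are not drawn as thin slivers at x=50/100/200
bars = ax6.bar([str(l) for l in lengths], speedups_by_length,
               color=['green' if s > 1 else 'orange' for s in speedups_by_length], alpha=0.7)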
ax6.axhline(y=1.0, color='red', linestyle='--', label='Baseline', linewidth=2)
ax6.set_xlabel('Output Length (tokens)', fontsize=11)
ax6.set_ylabel('Speedup (sparse/dense)', fontsize=11)
ax6.set_title('Speedup by Output Length', fontsize=12, fontweight='bold')
ax6.legend()
ax6.grid(axis='y', alpha=0.3)

# Plot 7: Memory Usage (if available)
ax7 = fig.add_subplot(gs[2, 0])
if torch.cuda.is_available() and dense_mem and sparse_mem:
    memory_metrics = ['Baseline', 'Peak', 'Current']
    dense_mems = [dense_mem['baseline'], dense_mem['peak'], dense_mem['current']]
    sparse_mems = [sparse_mem['baseline'], sparse_mem['peak'], sparse_mem['current']]
    x = np.arange(len(memory_metrics))
    ax7.bar(x - width/2, dense_mems, width, label='Dense', color='green', alpha=0.7)
    ax7.bar(x + width/2, sparse_mems, width, label='Sparse', color='blue', alpha=0.7)
    ax7.set_ylabel('Memory (GB)', fontsize=11)
    ax7.set_title('GPU Memory Usage', fontsize=12, fontweight='bold')
    ax7.set_xticks(x)
    ax7.set_xticklabels(memory_metrics)
    ax7.legend()
    ax7.grid(axis='y', alpha=0.3)
else:
    ax7.text(0.5, 0.5, 'CUDA not available', ha='center', va='center', fontsize=12)
    ax7.axis('off')

# Plot 8: Latency Histogram
ax8 = fig.add_subplot(gs[2, 1])
ax8.hist(dense_latencies * 1000, bins=20, alpha=0.5, label='Dense', color='green')
ax8.hist(sparse_latencies * 1000, bins=20, alpha=0.5, label='Sparse', color='blue')
ax8.axvline(np.mean(dense_latencies) * 1000, color='green', linestyle='--', linewidth=2)
ax8.axvline(np.mean(sparse_latencies) * 1000, color='blue', linestyle='--', linewidth=2)
ax8.set_xlabel('Latency (ms)', fontsize=11)
ax8.set_ylabel('Frequency', fontsize=11)
ax8.set_title('Latency Distribution', fontsize=12, fontweight='bold')
ax8.legend()
ax8.grid(axis='y', alpha=0.3)

# Plot 9: Summary Table
ax9 = fig.add_subplot(gs[2, 2])
ax9.axis('off')
summary_data = [
    ['Metric', 'Dense', 'Sparse', 'Change'],
    ['Perplexity', f'{dense_ppl:.2f}', f'{sparse_ppl:.2f}', f'{ppl_increase:+.1f}%'],
    ['P50 Latency', f'{np.percentile(dense_latencies, 50)*1000:.0f}ms', 
     f'{np.percentile(sparse_latencies, 50)*1000:.0f}ms', f'{p50_speedup:.2f}x'],
    ['Throughput', f'{dense_thr[0]:.1f}', f'{sparse_thr[0]:.1f}', 
     f'{sparse_thr[0]/dense_thr[0]:.2f}x'],
    ['Parameters', '7.0B', '7.0B', '50% sparse']
]
table = ax9.table(cellText=summary_data, loc='center', cellLoc='center')
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1, 2)
for i in range(len(summary_data[0])):
    table[(0, i)].set_facecolor('#40466e')
    table[(0, i)].set_text_props(weight='bold', color='white')
ax9.set_title('Performance Summary', fontsize=12, fontweight='bold', pad=20)

plt.savefig(os.path.join(PRUNED_MODEL_DIR, 'comprehensive_benchmark.png'), 
            dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✅ Visualization saved to {PRUNED_MODEL_DIR}/comprehensive_benchmark.png")
print("=" * 60)

---
## Step 9: Production Deployment Recommendations

Provide actionable insights based on benchmark results.
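
The recommendation logic reduces to three checks: quality loss, GPU generation, and whether the weights follow the 2:4 pattern. The sketch below condenses them into one function. The 10% threshold and the SM 8.x cutoff follow the values used in this lab, while `is_2_4_pattern` is a hypothetical flag that you would have to derive from your own pruning setup.

```python
def recommend(ppl_increase_pct: float, sm_major: int, is_2_4_pattern: bool,
              max_ppl_increase_pct: float = 10.0) -> str:
    """Return a coarse deployment recommendation from three benchmark facts."""
    if ppl_increase_pct >= max_ppl_increase_pct:
        return "tune: quality loss exceeds threshold"
    if sm_major < 8:
        return "dense: GPU lacks Sparse Tensor Cores"
    if not is_2_4_pattern:
        return "dense: re-prune to 2:4 before expecting a speedup"
    return "sparse: deploy with 2:4 kernels"


# Hypothetical scenarios
for args in [(7.7, 8, True), (7.7, 7, True), (7.7, 8, False), (12.0, 9, True)]:
    print(args, "->", recommend(*args))
```

Checking quality first keeps a fast but degraded model from being recommended on hardware grounds alone.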

In [None]:
print("=" * 80)
print("PRODUCTION DEPLOYMENT RECOMMENDATIONS")
print("=" * 80)

print("\n📊 Benchmark Summary:")
print(f"   Sparsity: {pruning_config['achieved_sparsity']:.1%}")
print(f"   Perplexity: {dense_ppl:.2f} → {sparse_ppl:.2f} ({ppl_increase:+.1f}%)")
print(f"   P50 Latency: {np.percentile(dense_latencies, 50)*1000:.0f}ms → {np.percentile(sparse_latencies, 50)*1000:.0f}ms ({p50_speedup:.2f}x)")
print(f"   Throughput: {dense_thr[0]:.1f} → {sparse_thr[0]:.1f} tok/s")

print("\n" + "="*80)
print("✅ WHEN TO USE SPARSE MODEL (50% Wanda Pruning):")
print("="*80)

print("\n1. DEPLOYMENT WITH 2:4 SPARSE HARDWARE (NVIDIA A100+):")
print("   ✅ Use Case: High-throughput inference services")
print("   ✅ Benefits:")
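print("      - Up to 2x speedup on pruned matrix multiplications with Sparse Tensor Cores")
print("      - Pruned FP16 layers stored in roughly 56% of their dense size (2:4 compressed format)")
print("      - Quality must be re-measured: 2:4 is a stricter pattern than unstructured 50% sparsity")
print("   ✅ Requirements:")
print("      - NVIDIA A100 or newer GPU")
print("      - Weights pruned to the 2:4 pattern (unstructured sparsity is not accelerated)")
print("      - PyTorch with sparse kernel support")
print("      - Export model to 2:4 semi-structured sparse format")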

print("\n2. DEPLOYMENT WITHOUT SPARSE ACCELERATION:")
if abs(p50_speedup - 1.0) < 0.1:
    print("   ⚠️  Performance: Similar to dense model")
    print("   💡 Recommendation: Use dense model for now")
    print("   📌 Future: Wait for hardware acceleration support")
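elif p50_speedup > 1.0:
    print("   ✅ Performance: Still provides speedup")
    print("   💡 Recommendation: Deploy if quality acceptable")
else:
    print("   ❌ Performance: Slower than dense model")
    print("   💡 Recommendation: Use dense model")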

print("\n3. MEMORY-CONSTRAINED DEPLOYMENT:")
print("   ⚠️  Current: Dense format (no memory reduction)")
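print("   💡 Action Required: Export to a compact sparse format (e.g. 2:4 semi-structured)")
print("   ✅ Expected Savings: up to ~44% of pruned-layer weight memory (FP16, 2:4)")
print("   📌 Trade-off: Need sparse-aware inference engine; CSR/COO do not save memory at 50% sparsity")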

print("\n" + "="*80)
print("⚠️  WHEN NOT TO USE SPARSE MODEL:")
print("="*80)

if ppl_increase > 10:
    print("\n❌ QUALITY-CRITICAL APPLICATIONS:")
    print(f"   Perplexity increase: {ppl_increase:.1f}% (>10% threshold)")
    print("   Use cases to avoid: Medical diagnosis, legal advice, financial analysis")
    print("   💡 Consider: Lower sparsity (30-40%) or dense model")

if p50_speedup < 1.0:
    print("\n❌ LATENCY-CRITICAL APPLICATIONS (without sparse hardware):")
    print(f"   Sparse model is {1/p50_speedup:.2f}x slower")
    print("   Use cases to avoid: Real-time chat, low-latency APIs")
    print("   💡 Recommendation: Use dense model or upgrade to A100+")

print("\n" + "="*80)
print("📋 DEPLOYMENT CHECKLIST:")
print("="*80)

print("\n☑️  Model Validation:")
print("   [ ] Perplexity within acceptable range")
print("   [ ] Side-by-side quality testing on production data")
print("   [ ] A/B testing with 5% traffic")

print("\n☑️  Infrastructure:")
print("   [ ] GPU supports sparse operations (compute capability 8.0 or higher)")
print("   [ ] PyTorch/TensorRT sparse kernels installed")
print("   [ ] Export model to the format the inference engine expects (2:4 semi-structured for Sparse Tensor Cores)")

print("\n☑️  Performance Monitoring:")
print("   [ ] Set up P95/P99 latency monitoring")
print("   [ ] Track quality metrics (perplexity, task accuracy)")
print("   [ ] Compare dense vs sparse in production")

print("\n☑️  Rollback Plan:")
print("   [ ] Keep dense model as fallback")
print("   [ ] Define quality degradation thresholds")
print("   [ ] Automated rollback if metrics degrade")

print("\n" + "="*80)
print("🎯 FINAL RECOMMENDATION:")
print("="*80)

if ppl_increase < 10 and torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    if gpu_props.major >= 8:  # Ampere or newer
        print("\n✅ DEPLOY SPARSE MODEL")
        print(f"   Your GPU (SM {gpu_props.major}.{gpu_props.minor}) supports sparse acceleration")
        print(f"   Quality loss is acceptable ({ppl_increase:.1f}%)")
        print("   Expected benefits: up to 2x speedup on pruned layers, provided the weights follow the 2:4 pattern")
        print("\n📌 Next Steps:")
        print("   1. Export model to 2:4 semi-structured sparse format")
        print("   2. Integrate with TensorRT or vLLM sparse backend")
        print("   3. Run production load testing")
    else:
        print("\n⚠️  WAIT FOR HARDWARE UPGRADE")
        print(f"   Your GPU (SM {gpu_props.major}.{gpu_props.minor}) lacks sparse acceleration")
        print("   Current performance gain is minimal")
        print("   💡 Recommendation: Stick with dense model until A100+ available")
else:
    print("\n⚠️  FURTHER TUNING NEEDED")
    if ppl_increase >= 10:
        print(f"   Quality loss ({ppl_increase:.1f}%) exceeds acceptable threshold")
        print("   💡 Options:")
        print("      - Reduce sparsity to 30-40%")
        print("      - Use better calibration data")
        print("      - Recover quality with a short fine-tune (e.g. LoRA) after pruning")
    if not torch.cuda.is_available():
        print("   CUDA not available - need GPU for deployment")

print("\n" + "="*80)

---
## ✅ Benchmarking Complete!

**Summary**:
- ✅ Measured perplexity on WikiText-2 (quantitative quality metric)
- ✅ Analyzed latency distribution (P50/P95/P99)
- ✅ Benchmarked throughput at different output lengths
- ✅ Profiled GPU memory usage
- ✅ Created comprehensive visualizations
- ✅ Provided production deployment recommendations

**Key Findings**:
- **Quality**: The perplexity increase is reported in Step 4; compare it against the 10% threshold
- **Performance**: Speedup depends on hardware (up to 2x on 2:4-pruned layers with A100 Sparse Tensor Cores)
- **Memory**: Unchanged in dense format; a 2:4 compressed export stores pruned FP16 layers in roughly 56% of their dense size
- **Production**: Viable for non-critical tasks with sparse hardware

**Files Generated**:
- `pruned_model/comprehensive_benchmark.png`: All visualizations
- Performance metrics logged above

**Next Steps**:
1. Export the model to a compact sparse format (e.g. 2:4 semi-structured) for size reduction
2. Integrate with sparse inference engine (TensorRT, vLLM)
3. Deploy with A/B testing on 5% production traffic
4. Monitor quality and performance metrics closely
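
Before exporting for Sparse Tensor Cores, it is worth checking that the pruned weights actually follow the 2:4 pattern, since 50% overall sparsity does not guarantee it. A minimal sketch on plain Python lists (the helper name and the toy rows are illustrative):

```python
def is_2_4_sparse(row):
    """True if every aligned group of 4 weights has at most 2 non-zeros."""
    if len(row) % 4 != 0:
        return False
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        if sum(1 for w in group if w != 0) > 2:
            return False
    return True


# Both rows are 50% sparse overall, but only the second has 2 zeros per group of 4
unstructured = [0.0, 0.0, 0.0, 0.0, 0.3, -0.1, 0.7, 0.2]
semi_structured = [0.3, 0.0, 0.0, -0.1, 0.0, 0.7, 0.2, 0.0]

print("unstructured row is 2:4:", is_2_4_sparse(unstructured))
print("semi-structured row is 2:4:", is_2_4_sparse(semi_structured))
```

A real check would run the same test over every row of each pruned linear layer.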

---

**🎉 Congratulations!** You have completed Lab 3.2: Wanda Pruning!

You now understand:
- How activation-aware pruning works (weight × activation importance)
- Trade-offs between sparsity and quality
- Hardware requirements for sparse acceleration
- When to use sparse models in production
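
As a recap of the first point, the Wanda importance score for a weight is its magnitude multiplied by the L2 norm of the corresponding input feature over the calibration data, compared within each output row. The toy sketch below applies that rule to a single row with made-up numbers. It illustrates the scoring rule only and is not the pruning code from the earlier notebooks.

```python
def wanda_prune_row(weights, input_norms, sparsity=0.5):
    """Zero the lowest-scoring weights of one output row.

    Score per weight: |w_j| * ||x_j||_2, where ||x_j||_2 is the L2 norm of
    input feature j over the calibration data.
    """
    scores = [abs(w) * n for w, n in zip(weights, input_norms)]
    num_prune = int(len(weights) * sparsity)
    # Indices of the num_prune lowest scores
    order = sorted(range(len(scores)), key=scores.__getitem__)
    prune_idx = set(order[:num_prune])
    return [0.0 if j in prune_idx else w for j, w in enumerate(weights)]


row = [0.1, -2.0, 0.5, 0.3]      # hypothetical weights of one output neuron
norms = [10.0, 0.1, 1.0, 1.0]    # hypothetical per-feature activation norms
pruned = wanda_prune_row(row, norms)
print(pruned)
```

In this toy row the largest weight (-2.0) is removed because its input feature is almost inactive, while the smallest weight (0.1) survives because its feature has a large norm. Magnitude-only pruning would have made the opposite choice.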

**Next Lab**: Lab 3.3 Knowledge Distillation (coming soon)