# Week 11 Lab: Efficiency & Optimization

## Learning Objectives
- Understand model quantization techniques
- Implement knowledge distillation
- Measure inference latency and throughput
- Apply pruning strategies

## Prerequisites
```bash
pip install transformers torch numpy matplotlib
```

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
from transformers import AutoTokenizer, AutoModel

# Setup: print a banner and pick the compute device (CUDA if available, else CPU).
# NOTE(review): in the cells shown here `device` is never passed to a `.to(...)`
# call, so the models below stay on CPU regardless of this choice.
banner = 'Week 11: Efficiency & Optimization'
print(banner)
print('=' * 50)
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_type)
print(f'Using device: {device}')

## Part 1: Model Size Analysis

Let's understand what makes models large and how size affects performance.

In [None]:
def get_model_size(model):
    """Return the in-memory footprint of ``model`` in MB (bytes / 1024**2).

    Both parameters and registered buffers (e.g. normalization running
    statistics) are counted as element count times element size.
    """
    tensors = list(model.parameters()) + list(model.buffers())
    total_bytes = sum(t.nelement() * t.element_size() for t in tensors)
    return total_bytes / 1024 / 1024

def count_parameters(model):
    """Return ``(total, trainable)`` parameter counts for ``model``.

    ``trainable`` only counts parameters whose ``requires_grad`` is True.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return total, trainable

# Load a pretrained encoder and report its parameter count and memory footprint
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

total_params, trainable_params = count_parameters(model)
model_size = get_model_size(model)

print(f"Model: {model_name}")
print(f"Total parameters: {total_params:,}")
print(f"Model size: {model_size:.2f} MB")
# model_size also includes buffers, so this ratio is slightly above the pure
# per-parameter width; 2**20 converts MB back to bytes.
print(f"Bytes per parameter: {model_size * 2**20 / total_params:.1f}")

In [None]:
# Visualize model size by component: sum parameter bytes per leaf-module type
component_sizes = {}
for _name, module in model.named_modules():
    if list(module.children()):
        continue  # containers are skipped; only leaf modules are tallied
    nbytes = sum(p.numel() * p.element_size() for p in module.parameters())
    kind = type(module).__name__
    component_sizes[kind] = component_sizes.get(kind, 0) + nbytes

# Bytes -> MB, then keep the 8 largest component types
mb_by_type = {k: v / 1024 / 1024 for k, v in component_sizes.items()}
largest = sorted(mb_by_type.items(), key=lambda kv: kv[1], reverse=True)[:8]
component_sizes = dict(largest)

type_names = list(component_sizes.keys())
sizes_mb = list(component_sizes.values())

fig, ax = plt.subplots(figsize=(10, 5))
bars = ax.barh(type_names, sizes_mb, color='#3333B2')
ax.set_xlabel('Size (MB)', fontsize=12)
ax.set_title('Model Size by Component Type', fontsize=14, fontweight='bold')

# Annotate each bar just past its end with its size
for bar, val in zip(bars, sizes_mb):
    y_mid = bar.get_y() + bar.get_height() / 2
    ax.text(bar.get_width() + 0.5, y_mid, f'{val:.1f} MB', va='center', fontsize=10)

plt.tight_layout()
plt.show()

## Part 2: Quantization

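Quantization reduces model size by storing numbers (typically weights, and sometimes activations) in a lower-precision format, e.g. 8-bit integers instead of 32-bit floats.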

In [None]:
# Demonstrate quantization concepts
def simulate_quantization(tensor, bits=8):
    """Simulate asymmetric (affine) min/max quantization of ``tensor``.

    Values are mapped onto the unsigned integer grid ``[0, 2**bits - 1]``
    with a scale and zero point derived from the tensor's min and max, then
    mapped back to floats so the rounding error can be inspected.

    Args:
        tensor: Floating-point tensor to quantize (must be non-empty).
        bits: Bit width of the integer grid (>= 1).

    Returns:
        ``(quantized, dequantized, scale)`` where ``quantized`` holds the
        integer codes (``torch.uint8`` for ``bits <= 8``, else
        ``torch.int32``), ``dequantized`` is the float reconstruction and
        ``scale`` is the quantization step.

    Raises:
        ValueError: if ``bits < 1``.
    """
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")

    # Get range
    min_val, max_val = tensor.min(), tensor.max()

    # Calculate scale and zero point
    qmin, qmax = 0, 2**bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0:
        # Constant tensor: any non-zero step reproduces it exactly and avoids
        # dividing by zero (which would turn every value into NaN/inf).
        scale = torch.ones_like(scale)
    # A floating-point zero point keeps this a pure simulation; integer
    # hardware schemes additionally round the zero point.
    zero_point = qmin - min_val / scale

    # Quantize
    quantized = torch.clamp(torch.round(tensor / scale + zero_point), qmin, qmax)

    # Dequantize
    dequantized = (quantized - zero_point) * scale

    # Codes live in [0, 2**bits - 1]; torch.int8 (range -128..127) would wrap
    # codes >= 128 to negative values, so use an unsigned/wider dtype.
    code_dtype = torch.uint8 if bits <= 8 else torch.int32
    return quantized.to(code_dtype), dequantized, scale

# Test quantization: reconstruction error vs. bit width on a Gaussian vector
original = torch.randn(1000)

print("Quantization Comparison:")
print("-" * 60)

for bits in (8, 4, 2):
    _, recon, _ = simulate_quantization(original, bits)
    mse = ((original - recon) ** 2).mean().item()
    # Compression is relative to 32-bit floats
    print(f"{bits}-bit: MSE={mse:.6f}, Compression={32 / bits}x")

In [None]:
# Visualize quantization error: FP32 samples next to 8/4/2-bit reconstructions
original = torch.randn(10000)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Row-major order: top-left shows the original, the other three the quantized data
panel_orig, *quant_panels = axes.flat

panel_orig.hist(original.numpy(), bins=50, color='#3333B2', alpha=0.7)
panel_orig.set_title('Original (FP32)', fontsize=12, fontweight='bold')
panel_orig.set_xlabel('Value')

for panel, bits in zip(quant_panels, (8, 4, 2)):
    _, recon, _ = simulate_quantization(original, bits)
    mse = ((original - recon) ** 2).mean().item()
    panel.hist(recon.numpy(), bins=50, color='#FF7F0E', alpha=0.7)
    panel.set_title(f'{bits}-bit Quantized (MSE: {mse:.4f})', fontsize=12, fontweight='bold')
    panel.set_xlabel('Value')

plt.suptitle('Effect of Quantization on Value Distribution', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 3: Knowledge Distillation

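Knowledge distillation transfers knowledge from a large *teacher* model to a smaller *student* model by training the student to match the teacher's softened output distribution as well as the true labels.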

In [None]:
class TeacherModel(nn.Module):
    """Larger MLP teacher: two hidden layers with ReLU and dropout.

    Args:
        input_size: Feature dimension of the input.
        hidden_size: Width of both hidden layers.
        num_classes: Number of output logits.
    """
    def __init__(self, input_size=768, hidden_size=512, num_classes=10):
        super().__init__()
        # Submodule names/order determine state_dict keys and the order in
        # which the RNG initializes weights, so they are kept as-is.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Hidden blocks: Linear -> ReLU -> Dropout (dropout only acts in train mode)
        for hidden_layer in (self.fc1, self.fc2):
            x = self.dropout(self.relu(hidden_layer(x)))
        # Output head returns raw logits (no softmax)
        return self.fc3(x)

class StudentModel(nn.Module):
    """Small MLP student: a single ReLU hidden layer, no dropout.

    Args:
        input_size: Feature dimension of the input.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """
    def __init__(self, input_size=768, hidden_size=128, num_classes=10):
        super().__init__()
        # Registration order fixes state_dict keys and weight-init RNG order
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits

# Create models and compare their parameter budgets
teacher = TeacherModel()
student = StudentModel()

teacher_params, student_params = (
    sum(p.numel() for p in net.parameters()) for net in (teacher, student)
)

print("Model Comparison:")
print(f"  Teacher parameters: {teacher_params:,}")
print(f"  Student parameters: {student_params:,}")
print(f"  Compression ratio: {teacher_params/student_params:.1f}x")

In [None]:
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """Blend a soft-target distillation term with hard-label cross-entropy.

    Args:
        student_logits: Raw student outputs; classes on the last dimension.
        teacher_logits: Raw teacher outputs, same shape as ``student_logits``.
        labels: Ground-truth class indices for the cross-entropy term.
        temperature: Softening factor T applied to both sets of logits.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the CE term.

    Returns:
        Scalar tensor ``alpha * T**2 * KL(teacher_T || student_T) + (1 - alpha) * CE``.
    """
    F = nn.functional
    T = temperature

    # Soft term: teacher probabilities vs. student log-probabilities at
    # temperature T, averaged per batch element; the conventional T**2
    # factor (Hinton et al., 2015) keeps its gradient scale comparable.
    kd_term = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction='batchmean',
    ) * (T ** 2)

    # Hard term: ordinary cross-entropy on the untempered student logits
    hard_term = F.cross_entropy(student_logits, labels)

    return alpha * kd_term + (1 - alpha) * hard_term

# Demonstrate distillation on one random batch
batch_size = 32
x = torch.randn(batch_size, 768)
labels = torch.randint(0, 10, (batch_size,))

# Teacher acts as a fixed target: eval mode (dropout off), no autograd graph
teacher.eval()
with torch.no_grad():
    teacher_logits = teacher(x)

# Student runs in train mode and keeps its graph so the loss is differentiable
student.train()
student_logits = student(x)

loss = distillation_loss(student_logits, teacher_logits, labels)
loss_value = loss.item()
print(f"Distillation loss: {loss_value:.4f}")

In [None]:
# Visualize effect of temperature on soft targets: same logits, flatter softmax as T grows
logits = torch.tensor([2.0, 1.0, 0.1, -0.5, -1.0])
temperatures = [0.5, 1.0, 2.0, 5.0, 10.0]

fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 3))
class_ids = range(len(logits))

for idx, (ax, temp) in enumerate(zip(axes, temperatures)):
    probs = torch.softmax(logits / temp, dim=0).numpy()
    ax.bar(class_ids, probs, color='#3333B2')
    ax.set_title(f'T = {temp}', fontsize=12, fontweight='bold')
    ax.set_ylim(0, 1)
    ax.set_xlabel('Class')
    if idx == 0:
        # Only the leftmost panel gets a y-axis label
        ax.set_ylabel('Probability')

plt.suptitle('Effect of Temperature on Softmax Distribution', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 4: Pruning

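Pruning removes low-importance weights (here, those with the smallest magnitudes) by setting them to zero. This only reduces model size or latency when the resulting sparse weights are stored or executed in a sparse-aware format.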

In [None]:
def magnitude_prune(weight, sparsity=0.5):
    """Prune weights by magnitude: zero out the smallest-|w| entries.

    Args:
        weight: Tensor of any shape.
        sparsity: Target fraction of entries to prune, in ``[0, 1]``.

    Returns:
        ``(pruned, mask)`` where ``mask`` is a boolean tensor shaped like
        ``weight`` (``True`` = kept) and ``pruned = weight * mask``.
        About ``round(sparsity * weight.numel())`` entries are pruned; more
        may be removed when magnitudes tie at the threshold.

    Raises:
        ValueError: if ``sparsity`` is outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    magnitudes = torch.abs(weight)
    num_prune = int(round(sparsity * magnitudes.numel()))

    if num_prune == 0:
        # Nothing to prune. A strict ``>`` against the minimum magnitude
        # would otherwise still drop the smallest entry.
        mask = torch.ones_like(weight, dtype=torch.bool)
    else:
        # kthvalue picks the exact cut-off and, unlike torch.quantile, has no
        # input-size limit, so large matrices (e.g. embeddings) work too.
        threshold = torch.kthvalue(magnitudes.flatten(), num_prune).values
        mask = magnitudes > threshold

    return weight * mask, mask

# Create a weight matrix and prune it at increasing sparsity levels
weight = torch.randn(256, 256)
base_norm = torch.norm(weight)

print("Pruning Results:")
print("-" * 50)

sparsities = [0.3, 0.5, 0.7, 0.9]
for sparsity in sparsities:
    pruned, mask = magnitude_prune(weight, sparsity)
    kept_fraction = mask.float().mean().item()
    # Fraction of the original Frobenius norm that survives pruning
    norm_ratio = torch.norm(pruned) / base_norm
    print(
        f"Target sparsity {sparsity*100:.0f}%: "
        f"Actual {(1 - kept_fraction)*100:.1f}%, "
        f"Frobenius norm ratio: {norm_ratio:.3f}"
    )

In [None]:
# Visualize weight distribution and pruning: surviving weights at each sparsity
weight = torch.randn(1000, 1000)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Row-major order: top-left is the unpruned matrix, the rest are pruned variants
ax_orig, *ax_pruned = axes.flat

ax_orig.hist(weight.flatten().numpy(), bins=50, color='#3333B2', alpha=0.7)
ax_orig.set_title('Original Weight Distribution', fontsize=12, fontweight='bold')
ax_orig.axvline(0, color='red', linestyle='--', alpha=0.5)

for ax, sparsity in zip(ax_pruned, (0.5, 0.7, 0.9)):
    pruned, _ = magnitude_prune(weight, sparsity)
    survivors = pruned[pruned != 0].flatten().numpy()
    ax.hist(survivors, bins=50, color='#2CA02C', alpha=0.7)
    ax.set_title(f'{sparsity*100:.0f}% Pruned ({len(survivors):,} remaining)',
                 fontsize=12, fontweight='bold')

plt.suptitle('Effect of Pruning on Weight Distribution', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Part 5: Inference Speed Benchmarking

Let's measure and compare inference speeds.

In [None]:
def benchmark_inference(model, input_data, num_runs=10, warmup=3):
    """Benchmark model inference time over repeated forward passes.

    Args:
        model: Module to benchmark; it is switched to eval mode as a side effect.
        input_data: A tensor passed positionally, or a mapping (e.g. tokenizer
            output) whose items are passed as keyword arguments.
        num_runs: Number of timed forward passes (must be >= 1).
        warmup: Number of untimed forward passes run first.

    Returns:
        dict with 'mean', 'std', 'min', 'max' latency in milliseconds.

    Raises:
        ValueError: if ``num_runs`` < 1 (the statistics would be undefined).
    """
    from collections.abc import Mapping

    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    model.eval()

    def _forward():
        # Mappings (dict, tokenizer BatchEncoding) are unpacked as kwargs
        if isinstance(input_data, Mapping):
            return model(**input_data)
        return model(input_data)

    # CUDA kernels execute asynchronously: without a synchronize the timer
    # mostly measures kernel-launch overhead instead of the forward pass.
    sync_device = None
    if torch.cuda.is_available():
        first_param = next(model.parameters(), None)
        if first_param is not None and first_param.is_cuda:
            sync_device = first_param.device
        elif torch.is_tensor(input_data) and input_data.is_cuda:
            sync_device = input_data.device

    def _sync():
        if sync_device is not None:
            torch.cuda.synchronize(sync_device)

    times = []
    with torch.no_grad():
        # Warmup (not timed); drain it before the first timed run
        for _ in range(warmup):
            _forward()
        _sync()

        # Benchmark
        for _ in range(num_runs):
            start = time.perf_counter()
            _forward()
            _sync()
            times.append((time.perf_counter() - start) * 1000)  # Convert to ms

    return {
        'mean': np.mean(times),
        'std': np.std(times),
        'min': np.min(times),
        'max': np.max(times),
    }

# Benchmark different batch sizes at a fixed padded sequence length
batch_sizes = [1, 4, 8, 16]
seq_length = 64

results = []
print("Inference Benchmark Results:")
print("-" * 60)

for batch_size in batch_sizes:
    # Identical short sentences padded out to exactly seq_length tokens
    encoded = tokenizer(
        ["Hello world"] * batch_size,
        padding='max_length',
        max_length=seq_length,
        return_tensors='pt',
    )

    # Only input_ids are fed; the attention mask is not passed here
    stats = benchmark_inference(model, encoded['input_ids'], num_runs=10)
    results.append(dict(batch_size=batch_size, **stats))

    mean_ms, std_ms = stats['mean'], stats['std']
    throughput = batch_size / (mean_ms / 1000)
    print(f"Batch {batch_size:2d}: {mean_ms:.2f} +/- {std_ms:.2f} ms, "
          f"Throughput: {throughput:.1f} samples/sec")

In [None]:
# Visualize benchmark results: latency curve and throughput bars per batch size
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

batch_sizes = [row['batch_size'] for row in results]
latencies = [row['mean'] for row in results]
throughputs = [bs / (lat / 1000) for bs, lat in zip(batch_sizes, latencies)]
positions = range(len(batch_sizes))

# Left: mean latency (ms)
ax1.plot(batch_sizes, latencies, 'b-o', linewidth=2, markersize=8)
ax1.set_xlabel('Batch Size', fontsize=12)
ax1.set_ylabel('Latency (ms)', fontsize=12)
ax1.set_title('Inference Latency vs Batch Size', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Right: throughput, plotted at evenly spaced positions labelled by batch size
ax2.bar(positions, throughputs, color='#2CA02C')
ax2.set_xticks(positions)
ax2.set_xticklabels(batch_sizes)
ax2.set_xlabel('Batch Size', fontsize=12)
ax2.set_ylabel('Throughput (samples/sec)', fontsize=12)
ax2.set_title('Inference Throughput vs Batch Size', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

## Exercises

1. **Quantization**: Implement a more sophisticated quantization scheme with per-channel scaling
2. **Distillation**: Train a student model to match the teacher on a classification task
3. **Pruning**: Implement structured pruning (remove entire neurons/heads)
4. **Benchmarking**: Compare inference speed of different optimization techniques

In [None]:
# Exercise starter: Structured pruning
def structured_prune_neurons(layer, sparsity=0.5):
    """
    Select which output neurons of a linear layer to keep, ranked by the L2
    norm of each neuron's weight row.

    Args:
        layer: Linear layer to prune (needs a 2-D ``weight`` whose rows are
            output neurons, as in ``nn.Linear``).
        sparsity: Fraction of neurons to remove, in ``[0, 1]``.

    Returns:
        Boolean mask over output neurons (``True`` = kept). Exactly
        ``int(out_features * (1 - sparsity))`` entries are ``True``, even when
        several neurons have the same norm.

    Raises:
        ValueError: if ``sparsity`` is outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    with torch.no_grad():
        # Calculate importance of each neuron (L2 norm of its weight row)
        importance = torch.norm(layer.weight, dim=1)

        # Number of neurons to keep
        num_keep = int(importance.numel() * (1 - sparsity))

        # Mark exactly the top-k neurons. Selecting by topk indices (rather
        # than ">= k-th value") avoids keeping extra neurons on ties, and
        # num_keep == 0 simply yields an all-False mask.
        keep_mask = torch.zeros_like(importance, dtype=torch.bool)
        if num_keep > 0:
            keep_mask[torch.topk(importance, num_keep).indices] = True

        return keep_mask

# Test: prune half of the output neurons of a fresh 100 -> 50 linear layer
test_layer = nn.Linear(100, 50)
kept = structured_prune_neurons(test_layer, sparsity=0.5)
# Read the neuron count from the layer instead of hard-coding it
print(f"Original neurons: {test_layer.out_features}")
print(f"Kept neurons: {kept.sum().item()}")

## Summary

In this lab, we explored:

1. **Model size analysis**: Understanding what makes models large
2. **Quantization**: Reducing precision to compress models
3. **Knowledge distillation**: Training smaller models from larger ones
4. **Pruning**: Removing unnecessary weights
5. **Benchmarking**: Measuring inference performance

**Key Takeaways**:
- 8-bit quantization typically has minimal accuracy loss
- Knowledge distillation can achieve 3-10x compression
- Pruning can remove 50-90% of weights with careful tuning (usually followed by fine-tuning to recover accuracy)
- Batch size significantly affects throughput