# Notebook 9: Post-Training Quantization

## Inference Engineering Course

---

### What You'll Learn

**Post-Training Quantization (PTQ)** converts a pre-trained model from higher precision (FP32/FP16) to lower precision (INT8/INT4) **without retraining**. This makes it the most practical approach for deploying LLMs: it needs no gradient-based fine-tuning, and the simple methods in this notebook need no data at all (calibration-based methods such as GPTQ need only a small sample of inputs).

In this notebook, we will:

1. **Understand PTQ approaches**: absmax, zero-point, and GPTQ concepts
2. **Implement absmax quantization** and apply to real weights
3. **Implement zero-point quantization** for asymmetric distributions
4. **Apply quantization to a small model** (GPT-2)
5. **Measure quality impact** using perplexity
6. **Compare INT8 vs INT4** quantized outputs

### Prerequisites
- Notebook 08 (Quantization Formats)
- Basic understanding of language model evaluation

### Runtime
- **No GPU required** (using GPT-2 small, which runs on CPU)
- Some cells may take 1-2 minutes on CPU

---

## 1. Setup

In [None]:
!pip install torch transformers datasets matplotlib numpy tqdm -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import copy
import time
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

torch.manual_seed(42)
np.random.seed(42)

plt.rcParams['figure.figsize'] = (14, 5)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 2. Load a Pre-trained Model

We'll use GPT-2 small (124M parameters) as our test model. It's small enough to run on CPU while still being a real Transformer model with meaningful weights.

In [None]:
print("Loading GPT-2 model and tokenizer...")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
model.eval()

# Model statistics
total_params = sum(p.numel() for p in model.parameters())
model_size_fp32 = sum(p.numel() * 4 for p in model.parameters()) / (1024 ** 2)  # MB
model_size_fp16 = model_size_fp32 / 2

print(f"\nModel: GPT-2 Small")
print(f"Parameters: {total_params:,}")
print(f"Size (FP32): {model_size_fp32:.1f} MB")
print(f"Size (FP16): {model_size_fp16:.1f} MB")
print(f"\nLayer structure:")
for name, param in model.named_parameters():
    if 'transformer.h.0' in name:  # Just show first layer
        print(f"  {name}: {param.shape}")

## 3. Analyzing Weight Distributions

Before quantizing, let's understand what the model's weights look like. The distribution of weights tells us how well different quantization schemes will work.

In [None]:
# Collect all weights and analyze their distributions
all_weights = []
layer_stats = []

for name, param in model.named_parameters():
    if 'weight' in name and param.dim() >= 2:  # Only weight matrices
        w = param.data.detach().cpu().flatten()
        all_weights.append(w)
        layer_stats.append({
            'name': name.replace('transformer.', ''),
            'shape': tuple(param.shape),
            'mean': w.mean().item(),
            'std': w.std().item(),
            'min': w.min().item(),
            'max': w.max().item(),
            'abs_max': w.abs().max().item(),
            'outlier_frac': (w.abs() > 3 * w.std()).float().mean().item(),
        })

all_weights_flat = torch.cat(all_weights)

# Print statistics
print("Weight Statistics by Layer:")
print(f"{'Layer':<35s} | {'Shape':>15s} | {'Mean':>8s} | {'Std':>8s} | {'AbsMax':>8s} | {'Outlier%':>8s}")
print("=" * 100)
for s in layer_stats[:12]:  # First 12 layers
    print(f"{s['name']:<35s} | {str(s['shape']):>15s} | {s['mean']:>8.4f} | {s['std']:>8.4f} | "
          f"{s['abs_max']:>8.4f} | {s['outlier_frac']*100:>7.2f}%")

In [None]:
# Visualize weight distributions
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Plot 1: Overall weight distribution
ax = axes[0][0]
ax.hist(all_weights_flat.numpy(), bins=500, alpha=0.7, color='steelblue', density=True)
ax.set_xlabel('Weight Value')
ax.set_ylabel('Density')
ax.set_title(f'All Weight Values ({len(all_weights_flat):,} total)')
mean, std = all_weights_flat.mean().item(), all_weights_flat.std().item()
ax.axvline(x=mean, color='red', linestyle='--', label=f'Mean={mean:.4f}')
ax.axvline(x=mean + 3*std, color='orange', linestyle=':', label=f'+3std={mean+3*std:.4f}')
ax.axvline(x=mean - 3*std, color='orange', linestyle=':')
ax.legend()

# Plot 2: Distribution per layer type
ax = axes[0][1]
attn_weights = []
mlp_weights = []
for s, w in zip(layer_stats, all_weights):
    if 'attn' in s['name']:
        attn_weights.append(w)
    elif 'mlp' in s['name']:
        mlp_weights.append(w)
if attn_weights:
    ax.hist(torch.cat(attn_weights).numpy(), bins=300, alpha=0.5, label='Attention', density=True, color='blue')
if mlp_weights:
    ax.hist(torch.cat(mlp_weights).numpy(), bins=300, alpha=0.5, label='MLP', density=True, color='red')
ax.set_xlabel('Weight Value')
ax.set_ylabel('Density')
ax.set_title('Weight Distribution by Layer Type')
ax.legend()

# Plot 3: AbsMax per layer
ax = axes[1][0]
abs_maxes = [s['abs_max'] for s in layer_stats]
ax.bar(range(len(abs_maxes)), abs_maxes, color='steelblue', alpha=0.7)
ax.set_xlabel('Layer Index')
ax.set_ylabel('Absolute Max Weight')
ax.set_title('Maximum Weight Magnitude per Layer')
ax.axhline(y=np.mean(abs_maxes), color='red', linestyle='--', label=f'Mean={np.mean(abs_maxes):.3f}')
ax.legend()

# Plot 4: Standard deviation per layer
ax = axes[1][1]
stds = [s['std'] for s in layer_stats]
ax.bar(range(len(stds)), stds, color='green', alpha=0.7)
ax.set_xlabel('Layer Index')
ax.set_ylabel('Standard Deviation')
ax.set_title('Weight Std Dev per Layer')

plt.suptitle('GPT-2 Weight Distribution Analysis', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 4. Implementing Absmax Quantization

**Absmax (symmetric) quantization** is the simplest PTQ method:

$$\text{scale} = \frac{\max(|W|)}{2^{b-1} - 1}$$
$$W_{int} = \text{round}\left(\frac{W}{\text{scale}}\right)$$
$$\hat{W} = W_{int} \times \text{scale}$$

The quantized weight matrix $\hat{W}$ approximates the original $W$.
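
As a quick check of the three formulas, here is a minimal worked example on a four-element tensor at $b = 8$. The weight values are made up for illustration, and the variable names are not taken from the notebook's classes.

```python
import torch

# Hypothetical weights chosen so the arithmetic is easy to follow
w = torch.tensor([0.302, -1.27, 0.047, 0.64])

n_bits = 8
qmax = 2 ** (n_bits - 1) - 1             # 127 for INT8
scale = w.abs().max() / qmax             # 1.27 / 127 = 0.01

w_int = torch.round(w / scale).clamp(-qmax, qmax).to(torch.int8)
w_hat = w_int.float() * scale            # dequantized approximation

print(w_int.tolist())                    # [30, -127, 5, 64]
print((w_hat - w).abs().max().item())    # worst-case error, at most scale / 2
```

The largest-magnitude weight maps exactly to $\pm(2^{b-1} - 1)$; every other weight is off by at most half a quantization step.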

In [None]:
class AbsmaxQuantizer:
    """Absmax (symmetric) quantization for weight matrices.
    
    Supports both per-tensor and per-channel quantization.
    """
    def __init__(self, n_bits=8, per_channel=True):
        self.n_bits = n_bits
        self.per_channel = per_channel
        self.qmax = 2 ** (n_bits - 1) - 1
    
    def quantize(self, tensor):
        """Quantize a weight tensor.
        
        Returns:
            quantized: Integer tensor
            scale: Scale factor(s) for dequantization
        """
        if self.per_channel and tensor.dim() >= 2:
            # Per-channel: one scale per slice along dim 0 (per row).
            # For nn.Linear (out, in) a row is an output channel; Hugging Face's GPT-2
            # Conv1D stores weights as (in, out), so there a row is an input feature.
            scale = tensor.abs().amax(dim=list(range(1, tensor.dim())), keepdim=True) / self.qmax
        else:
            # Per-tensor: single scale for entire tensor
            scale = tensor.abs().max() / self.qmax
        scale = scale.clamp(min=1e-8)  # Avoid division by zero
        
        quantized = torch.round(tensor / scale).clamp(-self.qmax, self.qmax)
        return quantized.to(torch.int8 if self.n_bits <= 8 else torch.int16), scale
    
    def dequantize(self, quantized, scale):
        """Dequantize back to floating point."""
        return quantized.float() * scale
    
    def quantize_dequantize(self, tensor):
        """Quantize then immediately dequantize (simulated quantization)."""
        q, s = self.quantize(tensor)
        return self.dequantize(q, s)

# Test on a sample layer
sample_weight = model.transformer.h[0].attn.c_attn.weight.data.clone()
print(f"Sample weight shape: {sample_weight.shape}")
print(f"Original - min: {sample_weight.min():.4f}, max: {sample_weight.max():.4f}")

for bits in [8, 4]:
    quantizer = AbsmaxQuantizer(n_bits=bits, per_channel=True)
    deq = quantizer.quantize_dequantize(sample_weight)
    mse = ((deq - sample_weight) ** 2).mean().item()
    max_err = (deq - sample_weight).abs().max().item()
    print(f"\nINT{bits} (per-channel): MSE={mse:.8f}, Max Error={max_err:.6f}")
    
    quantizer_pt = AbsmaxQuantizer(n_bits=bits, per_channel=False)
    deq_pt = quantizer_pt.quantize_dequantize(sample_weight)
    mse_pt = ((deq_pt - sample_weight) ** 2).mean().item()
    print(f"INT{bits} (per-tensor): MSE={mse_pt:.8f} ({mse_pt/mse:.1f}x worse)")
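
To see why per-tensor scaling tends to do worse, it helps to isolate the effect of a single outlier. The sketch below uses synthetic weights and a hypothetical helper `absmax_qdq` (not part of the notebook's classes): one large value in row 0 inflates the shared per-tensor scale, which hurts every row, while per-row scales confine the damage to row 0.

```python
import torch

def absmax_qdq(x, n_bits, per_row):
    """Simulated symmetric quantization (quantize, then dequantize)."""
    qmax = 2 ** (n_bits - 1) - 1
    absmax = x.abs().amax(dim=1, keepdim=True) if per_row else x.abs().max()
    scale = (absmax / qmax).clamp(min=1e-8)
    return torch.round(x / scale).clamp(-qmax, qmax) * scale

torch.manual_seed(0)
W = torch.randn(4, 256) * 0.02     # synthetic small weights
W[0, 0] = 1.0                      # one outlier, confined to row 0

for per_row in (False, True):
    W_hat = absmax_qdq(W, n_bits=4, per_row=per_row)
    row_mse = ((W_hat - W) ** 2).mean(dim=1)   # error of each row separately
    label = "per-row   " if per_row else "per-tensor"
    print(label, [f"{m:.2e}" for m in row_mse.tolist()])
```

With per-row scales, rows 1 to 3 should show a much smaller error than row 0, whereas with a single scale all four rows are quantized on the coarse grid set by the outlier.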

## 5. Implementing Zero-Point Quantization

**Zero-point (asymmetric) quantization** handles distributions that aren't centered around zero. It maps the range $[\min, \max]$ to $[0, 2^b - 1]$:

$$\text{scale} = \frac{\max(W) - \min(W)}{2^b - 1}$$
$$\text{zero\_point} = \text{round}\left(-\frac{\min(W)}{\text{scale}}\right)$$
$$W_{int} = \text{round}\left(\frac{W}{\text{scale}}\right) + \text{zero\_point}$$
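
Dequantization inverts the mapping: $\hat{W} = (W_{int} - \text{zero\_point}) \times \text{scale}$. The following worked example at $b = 4$ traces each formula on a small, mostly positive tensor. The values are made up for illustration.

```python
import torch

# Hypothetical, mostly positive values (illustration only)
w = torch.tensor([-0.2, 0.03, 0.44, 1.3])

n_bits = 4
qmin, qmax = 0, 2 ** n_bits - 1                      # unsigned range [0, 15]
scale = (w.max() - w.min()) / (qmax - qmin)          # 1.5 / 15 = 0.1
zero_point = torch.round(-w.min() / scale)           # round(0.2 / 0.1) = 2

w_int = (torch.round(w / scale) + zero_point).clamp(qmin, qmax)
w_hat = (w_int - zero_point) * scale                 # dequantize

print(int(zero_point), w_int.to(torch.uint8).tolist())   # 2 [0, 2, 6, 15]
print(w_hat.tolist())                                     # approx. [-0.2, 0.0, 0.4, 1.3]
```

The minimum maps to code 0 and the maximum to code 15, so all 16 levels cover the range the data actually occupies. A symmetric absmax grid on the same data would spend roughly half of its codes on negative values below -0.2 that never occur.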

In [None]:
class ZeroPointQuantizer:
    """Zero-point (asymmetric) quantization.
    
    Better for distributions not centered at zero.
    Uses unsigned integer representation.
    """
    def __init__(self, n_bits=8, per_channel=True):
        self.n_bits = n_bits
        self.per_channel = per_channel
        self.qmin = 0
        self.qmax = 2 ** n_bits - 1
    
    def quantize(self, tensor):
        if self.per_channel and tensor.dim() >= 2:
            reduce_dims = list(range(1, tensor.dim()))
            min_val = tensor.amin(dim=reduce_dims, keepdim=True)
            max_val = tensor.amax(dim=reduce_dims, keepdim=True)
        else:
            min_val = tensor.min()
            max_val = tensor.max()
        
        scale = (max_val - min_val) / (self.qmax - self.qmin)
        scale = torch.where(scale == 0, torch.ones_like(scale) * 1e-8, scale)
        zero_point = torch.round(self.qmin - min_val / scale).clamp(self.qmin, self.qmax)
        
        quantized = torch.round(tensor / scale + zero_point).clamp(self.qmin, self.qmax)
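        # uint8 only holds codes up to 8 bits; use a wider dtype beyond that to avoid overflow
        return quantized.to(torch.uint8 if self.n_bits <= 8 else torch.int32), scale, zero_point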
    
    def dequantize(self, quantized, scale, zero_point):
        return (quantized.float() - zero_point) * scale
    
    def quantize_dequantize(self, tensor):
        q, s, zp = self.quantize(tensor)
        return self.dequantize(q, s, zp)

# Compare absmax vs zero-point on different distributions
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Test distributions
distributions = [
    ('Symmetric (centered at 0)', torch.randn(10000) * 0.5),
    ('Asymmetric (shifted)', torch.randn(10000) * 0.5 + 1.0),
    ('One-sided (ReLU-like)', F.relu(torch.randn(10000) * 0.5)),
]

for ax, (name, data) in zip(axes, distributions):
    abs_q = AbsmaxQuantizer(n_bits=4)
    zp_q = ZeroPointQuantizer(n_bits=4)
    
    deq_abs = abs_q.quantize_dequantize(data)
    deq_zp = zp_q.quantize_dequantize(data)
    
    mse_abs = ((deq_abs - data) ** 2).mean().item()
    mse_zp = ((deq_zp - data) ** 2).mean().item()
    
    ax.hist(data.numpy(), bins=100, alpha=0.3, label='Original', color='blue', density=True)
    ax.hist(deq_abs.numpy(), bins=30, alpha=0.5, label=f'Absmax (MSE={mse_abs:.5f})', color='red', density=True)
    ax.hist(deq_zp.numpy(), bins=30, alpha=0.5, label=f'ZeroPoint (MSE={mse_zp:.5f})', color='green', density=True)
    ax.set_title(f'{name}\n(INT4)', fontsize=11)
    ax.legend(fontsize=8)
    ax.set_xlabel('Value')

plt.suptitle('Absmax vs Zero-Point Quantization on Different Distributions', fontsize=13, y=1.02)
plt.tight_layout()
plt.show()

print("Key insight: Zero-point quantization is better for asymmetric distributions,")
print("but absmax is simpler and works well when weights are roughly symmetric (which they usually are).")

## 6. Quantizing the Full Model

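Now let's apply quantization to the entire GPT-2 model. We'll replace the weights of each linear layer (in Hugging Face's GPT-2, the `Conv1D` projections plus the `nn.Linear` output head) with quantized versions. Because GPT-2 ties the output head to the token embedding, quantizing `lm_head` also changes the embedding table.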

In [None]:
def quantize_model(model, n_bits=8, method='absmax', per_channel=True):
    """Quantize all linear layer weights in a model.
    
    Uses simulated quantization: quantize then dequantize,
    so the model still operates in FP32 but with quantized weight values.
    
    In production, you'd use actual integer arithmetic for speed.
    """
    quantized_model = copy.deepcopy(model)
    
    if method == 'absmax':
        quantizer = AbsmaxQuantizer(n_bits=n_bits, per_channel=per_channel)
    else:
        quantizer = ZeroPointQuantizer(n_bits=n_bits, per_channel=per_channel)
    
    n_quantized = 0
    total_mse = 0
    
    for name, module in quantized_model.named_modules():
        if isinstance(module, (nn.Linear, type(model.transformer.h[0].attn.c_attn))):
            if hasattr(module, 'weight'):
                original_weight = module.weight.data
                quantized_weight = quantizer.quantize_dequantize(original_weight)
                
                mse = ((quantized_weight - original_weight) ** 2).mean().item()
                total_mse += mse
                n_quantized += 1
                
                module.weight.data = quantized_weight
    
    avg_mse = total_mse / max(n_quantized, 1)
    print(f"Quantized {n_quantized} layers to INT{n_bits} ({method}, {'per-channel' if per_channel else 'per-tensor'})")
    print(f"Average layer MSE: {avg_mse:.8f}")
    
    return quantized_model

# Quantize model at different bit widths
print("Quantizing GPT-2...\n")

model_int8 = quantize_model(model, n_bits=8, method='absmax')
print()
model_int4 = quantize_model(model, n_bits=4, method='absmax')
print()
model_int4_zp = quantize_model(model, n_bits=4, method='zeropoint')
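
Note that `AbsmaxQuantizer` returns 4-bit codes in a `torch.int8` tensor, one code per byte, so a real INT4 checkpoint only reaches its nominal size once two codes are packed into each byte. The sketch below shows one way to do that; `pack_int4` and `unpack_int4` are hypothetical helpers (not PyTorch or notebook functions), and production kernels typically use their own packing layouts.

```python
import torch

def pack_int4(q):
    """Pack signed 4-bit codes (one per int8 element) into uint8, two per byte."""
    flat = q.flatten().to(torch.int16)
    if flat.numel() % 2:                               # pad to an even count
        flat = torch.cat([flat, flat.new_zeros(1)])
    nibbles = flat & 0xF                               # low 4 bits (two's complement)
    packed = nibbles[0::2] | (nibbles[1::2] << 4)      # even index -> low nibble
    return packed.to(torch.uint8)

def unpack_int4(packed, numel):
    """Inverse of pack_int4; returns the first `numel` codes as int8."""
    p = packed.to(torch.int16)
    nibbles = torch.stack([p & 0xF, (p >> 4) & 0xF], dim=1).flatten()[:numel]
    signed = torch.where(nibbles > 7, nibbles - 16, nibbles)   # sign-extend
    return signed.to(torch.int8)

q = torch.tensor([-7, -1, 0, 3, 7], dtype=torch.int8)
packed = pack_int4(q)
print(packed.tolist())                          # [249, 48, 7]
print(unpack_int4(packed, q.numel()).tolist())  # [-7, -1, 0, 3, 7]
```

Five codes occupy three bytes instead of five. The per-channel scales still have to be stored alongside the packed codes.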

## 7. Measuring Perplexity

**Perplexity** is the standard metric for evaluating language model quality. It measures how surprised the model is by the test text:

$$\text{Perplexity} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log P(x_i | x_{<i})\right)$$

Lower perplexity = better model. A good quantization should barely increase perplexity.
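
To build intuition for the formula, here is a toy calculation that needs no model. The probabilities are made up; `perplexity` is a hypothetical helper that takes the probability the model assigned to each observed token.

```python
import math

def perplexity(token_probs):
    """Perplexity from the probabilities a model assigned to the observed tokens."""
    avg_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(avg_nll)

# Made-up probabilities for a 4-token sequence
ppl = perplexity([0.5, 0.25, 0.125, 0.5])        # 2 ** 1.75, about 3.36

# A model that guesses uniformly over GPT-2's 50,257-token vocabulary
ppl_uniform = perplexity([1 / 50257] * 10)       # equals the vocabulary size

print(f"{ppl:.2f}  {ppl_uniform:.0f}")           # 3.36  50257
```

One way to read the result: a perplexity of $k$ means the model is, on average, as uncertain as if it were choosing uniformly among $k$ tokens.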

In [None]:
def compute_perplexity(model, text, tokenizer, max_length=256, stride=128):
    """Compute perplexity on a text string.
    
    Uses a sliding window so that each scored token sees up to
    max_length - stride tokens of preceding context.
    """
    model.eval()
    encodings = tokenizer(text, return_tensors='pt', truncation=True, max_length=1024)
    input_ids = encodings.input_ids.to(device)
    
    nlls = []
    n_tokens = 0
    
    for i in range(0, input_ids.size(1) - 1, stride):
        begin = max(i + stride - max_length, 0)
        end = min(i + stride, input_ids.size(1))
        
        target_begin = max(i, begin)
        target_end = end
        
        input_chunk = input_ids[:, begin:end]
        target_chunk = input_ids[:, begin:end].clone()
        target_chunk[:, :target_begin - begin] = -100
        
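        with torch.no_grad():
            outputs = model(input_chunk, labels=target_chunk)
            # outputs.loss is the mean over scored tokens. Labels are shifted by one
            # inside the model, so the first position of a chunk is never scored.
            n_scored = (target_chunk[:, 1:] != -100).sum().item()
            nlls.append(outputs.loss.item() * n_scored)
            n_tokens += n_scored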
        
        if end >= input_ids.size(1):
            break
    
    avg_nll = sum(nlls) / max(n_tokens, 1)
    perplexity = np.exp(avg_nll)
    return perplexity

# Test text for evaluation
test_text = """The history of artificial intelligence began in antiquity, with myths, stories and rumors of 
artificial beings endowed with intelligence or consciousness by master craftsmen. The seeds of modern AI 
were planted by philosophers who attempted to describe the process of human thinking as the mechanical 
manipulation of symbols. This work culminated in the invention of the programmable digital computer in 
the 1940s, a machine based on the abstract essence of mathematical reasoning. This device and the ideas 
behind it inspired a handful of scientists to begin seriously discussing the possibility of building an 
electronic brain. The field of AI research was founded at a workshop held on the campus of Dartmouth 
College during the summer of 1956. Those who attended would become the leaders of AI research for decades. 
Many of them predicted that a machine as intelligent as a human being would exist in no more than a 
generation, and they were given millions of dollars to make this vision come true. Eventually, it became 
obvious that commercial developers and researchers had grossly underestimated the difficulty of the project."""

print("Computing perplexity for each model variant...")
print("(This may take a minute on CPU)\n")

results = {}

for name, m in [
    ('FP32 (original)', model),
    ('INT8 (absmax)', model_int8),
    ('INT4 (absmax)', model_int4),
    ('INT4 (zero-point)', model_int4_zp),
]:
    start = time.time()
    ppl = compute_perplexity(m, test_text, tokenizer)
    elapsed = time.time() - start
    results[name] = ppl
    print(f"{name:>25s}: Perplexity = {ppl:>8.2f} (computed in {elapsed:.1f}s)")

print(f"\nPerplexity increase:")
base_ppl = results['FP32 (original)']
for name, ppl in results.items():
    if name != 'FP32 (original)':
        increase = (ppl - base_ppl) / base_ppl * 100
        print(f"  {name}: {increase:+.1f}%")  # signed format also handles a decrease

In [None]:
# Visualize perplexity comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

names = list(results.keys())
ppls = list(results.values())
colors = ['#2ca02c', '#1f77b4', '#ff7f0e', '#d62728']

bars = ax1.bar(range(len(names)), ppls, color=colors, alpha=0.8)
ax1.set_xticks(range(len(names)))
ax1.set_xticklabels(names, rotation=15, fontsize=9)
ax1.set_ylabel('Perplexity (lower = better)')
ax1.set_title('Model Quality After Quantization')

for bar, ppl in zip(bars, ppls):
    ax1.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.5,
             f'{ppl:.1f}', ha='center', fontsize=10, fontweight='bold')

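# Plot 2: Size vs Quality tradeoff
# Idealized sizes: assume every parameter is stored at the target bit width, ignoring
# scales/zero-points and the tensors quantize_model leaves in FP32 (e.g. LayerNorm, biases)
sizes_mb = [model_size_fp32, model_size_fp32/4, model_size_fp32/8, model_size_fp32/8]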
ax2.scatter(sizes_mb, ppls, c=colors, s=200, zorder=5)
for name, size, ppl, color in zip(names, sizes_mb, ppls, colors):
    ax2.annotate(name, (size, ppl), textcoords='offset points',
                 xytext=(10, 5), fontsize=9, color=color)

ax2.set_xlabel('Model Size (MB)')
ax2.set_ylabel('Perplexity')
ax2.set_title('Size vs Quality Tradeoff')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Comparing Generated Text Quality

Beyond perplexity, let's actually generate text and see the qualitative differences.

In [None]:
def generate_text(model, prompt, max_tokens=50, temperature=0.7, top_p=0.9):
    """Generate text from a model."""
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    return tokenizer.decode(output[0], skip_special_tokens=True)

prompts = [
    "The future of artificial intelligence is",
    "In a surprising discovery, scientists found that",
]

models_to_compare = [
    ('FP32', model),
    ('INT8', model_int8),
    ('INT4', model_int4),
]

for prompt in prompts:
    print(f"\nPrompt: \"{prompt}\"")
    print("=" * 80)
    
    for name, m in models_to_compare:
        # Fix seed for comparable outputs
        torch.manual_seed(42)
        text = generate_text(m, prompt, max_tokens=40)
        print(f"\n[{name}]: {text}")
    print()

## 9. Layer-wise Sensitivity Analysis

Not all layers are equally sensitive to quantization. Some layers can tolerate aggressive quantization while others need higher precision. Let's identify the sensitive layers.

In [None]:
def measure_layer_sensitivity(model, test_text, tokenizer, n_bits=4):
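    """Quantize one layer at a time and measure perplexity impact.

    Visits every module with a 2-D weight, so the embedding tables are included
    (and the tied wte / lm_head weight is measured twice).
    """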
    base_ppl = compute_perplexity(model, test_text, tokenizer)
    sensitivities = []
    
    quantizer = AbsmaxQuantizer(n_bits=n_bits, per_channel=True)
    
    for name, module in model.named_modules():
        if hasattr(module, 'weight') and module.weight.dim() >= 2:
            # Save original weight
            original_weight = module.weight.data.clone()
            
            # Quantize just this layer
            module.weight.data = quantizer.quantize_dequantize(original_weight)
            
            # Measure perplexity
            ppl = compute_perplexity(model, test_text, tokenizer)
            
            # Restore original weight
            module.weight.data = original_weight
            
            sensitivity = ppl - base_ppl
            sensitivities.append({
                'name': name,
                'base_ppl': base_ppl,
                'quantized_ppl': ppl,
                'ppl_increase': sensitivity,
                'ppl_increase_pct': sensitivity / base_ppl * 100,
            })
    
    return sensitivities

print("Measuring layer sensitivity (this takes a few minutes)...")
sensitivities = measure_layer_sensitivity(model, test_text, tokenizer, n_bits=4)

# Sort by sensitivity
sensitivities.sort(key=lambda x: x['ppl_increase'], reverse=True)

print(f"\nLayer Sensitivity to INT4 Quantization:")
print(f"{'Layer':<45s} | {'PPL Increase':>12s} | {'% Increase':>10s}")
print("=" * 75)
for s in sensitivities[:15]:  # Top 15 most sensitive
    print(f"{s['name']:<45s} | {s['ppl_increase']:>+12.2f} | {s['ppl_increase_pct']:>+9.1f}%")

In [None]:
# Visualize layer sensitivity
fig, ax = plt.subplots(1, 1, figsize=(16, 6))

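# Sort by layer name (approximate model order). A plain string sort would place
# h.10 before h.2, so numeric name components are compared as integers.
def layer_order(s):
    return [(0, int(p), '') if p.isdigit() else (1, 0, p) for p in s['name'].split('.')]

sensitivities_sorted = sorted(sensitivities, key=layer_order)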

names = [s['name'].replace('transformer.', '').replace('.weight', '') for s in sensitivities_sorted]
increases = [s['ppl_increase_pct'] for s in sensitivities_sorted]

colors = ['#d62728' if inc > 1.0 else '#ff7f0e' if inc > 0.1 else '#2ca02c' for inc in increases]

bars = ax.bar(range(len(names)), increases, color=colors, alpha=0.8)
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, rotation=90, fontsize=7)
ax.set_ylabel('Perplexity Increase (%)')
ax.set_title('Layer Sensitivity to INT4 Quantization\n(Red = high sensitivity, Green = low sensitivity)')
ax.axhline(y=0, color='black', linewidth=0.5)

plt.tight_layout()
plt.show()

print("\nInsight: We can use mixed-precision quantization:")
print("- Sensitive layers: Keep at INT8 or even FP16")
print("- Insensitive layers: Quantize aggressively to INT4")

## 10. Mixed-Precision Quantization

Based on the sensitivity analysis, we can quantize different layers at different precision levels.

In [None]:
def mixed_precision_quantize(model, sensitivities, int4_threshold=0.5):
    """Apply mixed-precision quantization based on sensitivity.
    
    Layers whose single-layer INT4 perplexity increase (in percent)
    exceeds int4_threshold get INT8. Other layers get INT4.
    """
    mixed_model = copy.deepcopy(model)
    
    sensitive_layers = {s['name'] for s in sensitivities if s['ppl_increase_pct'] > int4_threshold}
    
    quantizer_8 = AbsmaxQuantizer(n_bits=8, per_channel=True)
    quantizer_4 = AbsmaxQuantizer(n_bits=4, per_channel=True)
    
    n_int8 = 0
    n_int4 = 0
    
    for name, module in mixed_model.named_modules():
        if hasattr(module, 'weight') and module.weight.dim() >= 2:
            if name in sensitive_layers:
                module.weight.data = quantizer_8.quantize_dequantize(module.weight.data)
                n_int8 += 1
            else:
                module.weight.data = quantizer_4.quantize_dequantize(module.weight.data)
                n_int4 += 1
    
    print(f"Mixed precision: {n_int8} layers at INT8, {n_int4} layers at INT4")
    return mixed_model

# Apply mixed precision
model_mixed = mixed_precision_quantize(model, sensitivities, int4_threshold=0.5)

# Measure quality
ppl_mixed = compute_perplexity(model_mixed, test_text, tokenizer)

print(f"\nResults Comparison:")
print(f"  FP32 (original):  {results['FP32 (original)']:.2f}")
print(f"  INT8 (all):       {results['INT8 (absmax)']:.2f}")
print(f"  INT4 (all):       {results['INT4 (absmax)']:.2f}")
print(f"  Mixed (INT4+INT8): {ppl_mixed:.2f}")
print(f"\nMixed precision should land between the all-INT4 and all-INT8 results; how close it gets to INT8 depends on the threshold.")
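
How much compression a mixed-precision assignment yields comes down to a weighted average of bit widths. The sketch below works this out for the four weight matrices of one GPT-2 small block. The shapes are GPT-2's, but the choice of which layer stays at INT8 is hypothetical, `average_bits` is an illustrative helper, and the small overhead of per-channel scales is ignored.

```python
def average_bits(layers):
    """layers: iterable of (num_weights, bits). Returns mean bits per weight."""
    total = sum(n for n, _ in layers)
    return sum(n * b for n, b in layers) / total

# Weight shapes of one GPT-2 small block; the bit assignment is hypothetical
block = [
    (768 * 2304, 8),   # attn.c_attn kept at INT8 (assumed to be the sensitive layer)
    (768 * 768, 4),    # attn.c_proj
    (768 * 3072, 4),   # mlp.c_fc
    (3072 * 768, 4),   # mlp.c_proj
]

avg = average_bits(block)
print(f"{avg:.2f} bits/weight, {32 / avg:.1f}x smaller than FP32")
# 5.00 bits/weight, 6.4x smaller than FP32
```

Here `c_attn` holds a quarter of the block's weights, so keeping it at INT8 raises the average from 4 to 5 bits per weight.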

## 11. Introduction to GPTQ Concepts

**GPTQ** (Frantar et al., 2022) is a more sophisticated PTQ method that goes beyond simple rounding. It uses calibration data to minimize the quantization error in a layerwise manner.

### Key Ideas:

1. **Optimal Brain Quantization (OBQ)**: Instead of rounding each weight independently, consider how rounding one weight should adjust others to minimize output error.

2. **Hessian-based correction**: Uses second-order information (the Hessian of each layer's output reconstruction error, $H = 2XX^\top$, built from the calibration inputs $X$) to decide the optimal adjustment.

3. **Layerwise quantization**: Process one layer at a time using calibration data.

Let's implement a simplified version of this idea.

In [None]:
def gptq_simplified(weight, n_bits=4):
    """Simplified GPTQ-style quantization.
    
    Instead of independently rounding each weight, we:
    1. Quantize column by column
    2. After quantizing each column, adjust remaining columns
       to compensate for the rounding error
    
    This is a simplified version of the full GPTQ algorithm
    (full version uses Hessian information from calibration data).
    """
    W = weight.clone().float()
    n_rows, n_cols = W.shape
    
    qmax = 2 ** (n_bits - 1) - 1
    scale = W.abs().amax(dim=1, keepdim=True) / qmax
    scale = scale.clamp(min=1e-8)
    
    Q = torch.zeros_like(W)  # Quantized weights
    
    for col in range(n_cols):
        # Quantize this column
        w_col = W[:, col]
        q_col = torch.round(w_col / scale.squeeze()).clamp(-qmax, qmax)
        Q[:, col] = q_col
        
        # Compute quantization error for this column
        error = w_col - q_col * scale.squeeze()
        
        # Distribute error to remaining columns (simplified)
        # In full GPTQ, this uses the Hessian inverse
        if col < n_cols - 1:
            remaining_cols = n_cols - col - 1
            error_per_col = error / remaining_cols
            W[:, col+1:] += error_per_col.unsqueeze(1)
    
    # Dequantize
    return Q * scale

# Compare simple rounding vs GPTQ-style
sample_weight = model.transformer.h[0].attn.c_attn.weight.data.clone()

# Method 1: Simple absmax rounding
quantizer = AbsmaxQuantizer(n_bits=4, per_channel=True)
simple_deq = quantizer.quantize_dequantize(sample_weight)
simple_mse = ((simple_deq - sample_weight) ** 2).mean().item()

# Method 2: GPTQ-style
gptq_deq = gptq_simplified(sample_weight, n_bits=4)
gptq_mse = ((gptq_deq - sample_weight) ** 2).mean().item()

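print(f"Weight shape: {sample_weight.shape}")
print(f"Simple absmax INT4: weight MSE = {simple_mse:.8f}")
print(f"GPTQ-style INT4:   weight MSE = {gptq_mse:.8f}")
# Both methods use the same grid, and round-to-nearest already minimizes per-weight MSE
# on a fixed grid, so error feedback cannot lower this number. GPTQ's benefit shows up
# in the layer's *output* error on calibration data, which this cell does not measure.
print(f"Change in weight MSE vs simple rounding: {(gptq_mse/simple_mse - 1)*100:+.1f}%")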
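
The Hessian-weighted update from the Key Ideas can be sketched on synthetic data to show what it optimizes. This is a simplified illustration under stated assumptions, not the reference GPTQ implementation: the calibration inputs `X` are random with a shared component (so input features are correlated), the layer sizes are arbitrary, and blocking, grouping and activation reordering are omitted. `rtn` and `gptq_with_hessian` are hypothetical names.

```python
import torch

def rtn(W, scale, qmax):
    """Round-to-nearest onto the symmetric grid defined by a per-row scale."""
    return torch.round(W / scale).clamp(-qmax, qmax) * scale

def gptq_with_hessian(W, X, n_bits=4, damp=0.01):
    """W: (out, in) weights. X: (in, n_samples) calibration inputs."""
    W, X = W.clone().double(), X.double()
    qmax = 2 ** (n_bits - 1) - 1
    scale = (W.abs().amax(dim=1, keepdim=True) / qmax).clamp(min=1e-8)

    H = 2.0 * X @ X.T                                  # Hessian of ||WX - W_hat X||^2
    H += damp * H.diagonal().mean() * torch.eye(H.shape[0], dtype=H.dtype)
    Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
    U = torch.linalg.cholesky(Hinv, upper=True)        # Hinv = U^T U

    Q = torch.zeros_like(W)
    for i in range(W.shape[1]):
        w = W[:, i]
        q = rtn(w, scale.squeeze(1), qmax)             # quantize column i
        Q[:, i] = q
        err = (w - q) / U[i, i]
        # Shift the not-yet-quantized columns to compensate in output space
        W[:, i:] -= err.unsqueeze(1) * U[i, i:].unsqueeze(0)
    return Q, scale

torch.manual_seed(0)
n_out, n_in, n_samples = 16, 32, 512                   # arbitrary toy sizes
W = torch.randn(n_out, n_in)
X = 0.9 * torch.randn(1, n_samples) + 0.45 * torch.randn(n_in, n_samples)

Q, scale = gptq_with_hessian(W, X)
W_rtn = rtn(W.double(), scale, 7)                      # same grid, independent rounding

def weight_mse(W_hat):
    return ((W.double() - W_hat) ** 2).mean().item()

def output_mse(W_hat):
    return (((W.double() - W_hat) @ X.double()) ** 2).mean().item()

print(f"weight MSE  RTN: {weight_mse(W_rtn):.5f}   GPTQ: {weight_mse(Q):.5f}")
print(f"output MSE  RTN: {output_mse(W_rtn):.5f}   GPTQ: {output_mse(Q):.5f}")
```

Round-to-nearest wins on weight MSE by construction, since it picks the closest grid point for every weight. The Hessian-weighted variant accepts slightly larger per-weight errors in exchange for errors that cancel in the layer output, which is the quantity that affects the model's predictions.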

## 12. Key Takeaways

### PTQ Methods Summary

| Method | Complexity | Quality | Quantization speed |
|--------|-----------|---------|-------|
| **Absmax** (per-tensor) | Simplest | Lowest | Fastest |
| **Absmax** (per-channel) | Simple | Good | Fast |
| **Zero-point** | Medium | Good for asymmetric | Fast |
| **GPTQ** | High | Best | Slow (needs calibration) |
| **AWQ** | High | Very good | Slow |

### Practical Recommendations

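1. **INT8 absmax per-channel** is the easiest win: typically negligible quality loss, with 4x compression relative to FP32 (2x relative to FP16)
2. **INT4 GPTQ** is widely used for aggressive quantization: at 4 bits, the weights of a 70B model take roughly 35 GB, which fits across two 24 GB consumer GPUs
3. **Mixed precision** gives the best quality-size tradeoff
4. **Layer sensitivity varies**: first and last layers are often among the most sensitive, but check this against a per-layer analysis like the one in Section 9
5. **Calibration data helps**: GPTQ-style methods significantly outperform naive rounding at INT4 when measured on model outputs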

### The Quantization Quality Ladder

```
FP32 (baseline)
  |
  v
FP16 (nearly lossless)
  |
  v
INT8 (minimal loss)
  |
  v
INT4-GPTQ (small loss)
  |
  v
INT4-naive (moderate loss)

Quality loss grows from top to bottom.
```

---

## Exercises

### Exercise 1: Calibration-Based Scale Selection
Instead of using absmax to determine the scale, use a small calibration dataset to find the scale that minimizes MSE on actual model outputs.

In [None]:
def calibrated_quantize(weight, calibration_inputs, n_bits=4):
    """Find optimal quantization scale using calibration data.
    
    TODO: Implement this
    1. Try different scales (e.g., 0.5x to 1.5x of absmax scale)
    2. For each scale, compute output error on calibration inputs
    3. Select the scale that minimizes output error
    """
    pass

### Exercise 2: Dynamic Quantization
Implement dynamic quantization where the scale is computed per-batch at inference time (based on input activations) rather than fixed at quantization time.

In [None]:
# TODO: Implement dynamic quantization
# Create a DynamicQuantizedLinear module that quantizes activations
# on-the-fly during forward pass

### Exercise 3: Perplexity vs Bits Sweep
Quantize the model at a range of bit widths from 2 to 16 and plot the perplexity curve. Find the "sweet spot": the lowest bit width before perplexity starts degrading rapidly.

In [None]:
# TODO: Sweep bit widths and plot perplexity
# bit_widths = [2, 3, 4, 5, 6, 7, 8, 10, 12, 16]
# For each, quantize the model and compute perplexity

---

**Next up: Notebook 10 - Mixture of Experts Routing** where we'll explore how sparse architectures achieve massive model capacity with limited compute.