# Notebook 8: Quantization - Number Formats

## Inference Engineering Course

---

### What You'll Learn

LLM weights are typically stored in FP32 (32 bits per parameter) or FP16/BF16 (16 bits). A 7B parameter model needs 14 GB (16-bit) to 28 GB (32-bit) just for weights. **Quantization** reduces this by using fewer bits per parameter, enabling:
- Smaller model files
- Lower memory usage
- Faster inference (less data to move)

In this notebook, we will:

1. **Understand floating-point formats**: FP32, FP16, BF16, FP8 (E4M3, E5M2)
2. **Visualize bit layouts**: sign, exponent, mantissa
3. **Implement manual quantization/dequantization**
4. **Measure precision loss** across formats
5. **Visualize the dynamic range vs precision tradeoff**
6. **Compare model weights before and after quantization**

### Prerequisites
- Binary number representation basics
- Understanding of floating-point arithmetic

### Runtime
- **No GPU required**

---

## 1. Setup

In [None]:
!pip install torch numpy matplotlib -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import struct
import torch

np.random.seed(42)
torch.manual_seed(42)

plt.rcParams['figure.figsize'] = (14, 6)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

print("Setup complete!")

## 2. How Floating-Point Numbers Work

### The IEEE 754 Standard

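A normal (non-zero, finite) floating-point number with $E$ exponent bits and $M$ mantissa bits is represented as:

$$(-1)^{\text{sign}} \times 2^{\text{exponent} - \text{bias}} \times \left(1 + \frac{\text{mantissa}}{2^{M}}\right), \quad \text{bias} = 2^{E-1} - 1$$

where exponent and mantissa are the unsigned integers stored in their bit fields. When the exponent field is all zeros the number is *subnormal*: the implicit leading 1 is dropped and the exponent is fixed at $1 - \text{bias}$.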

The bits are divided into three fields:

| Field | Purpose | Effect |
|-------|---------|--------|
| **Sign** (1 bit) | Positive or negative | +/- |
| **Exponent** (E bits) | Scale/range | How big or small the number can be |
| **Mantissa** (M bits) | Precision | How many significant digits |

**Key tradeoff**: For a fixed total bit width, every bit moved from the mantissa to the exponent widens the range but coarsens the precision, and vice versa.
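The formula can be checked directly on real FP32 bits. A minimal sketch using the standard-library `struct` module (already imported in the setup cell) to pull the three fields out of a float32 and rebuild the value; `decode_fp32` and `reconstruct_fp32` are illustrative helper names, and the rebuild only covers normal numbers (exponent field not 0 or 255).

```python
import struct

def decode_fp32(x):
    """Split a float32 into its sign bit, biased exponent and stored mantissa."""
    # Pack as big-endian float32, then reinterpret the same 4 bytes as an unsigned int
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF      # 8 exponent bits, bias 127
    mantissa = bits & 0x7FFFFF          # 23 stored fraction bits
    return sign, exponent, mantissa

def reconstruct_fp32(sign, exponent, mantissa):
    """Rebuild the value from its fields (normal numbers only)."""
    return (-1) ** sign * 2.0 ** (exponent - 127) * (1 + mantissa / 2 ** 23)

for v in [1.0, -2.5, 0.15625]:
    s, e, m = decode_fp32(v)
    print(f"{v:>8}: sign={s} exponent={e:08b} ({e - 127:+d}) "
          f"mantissa={m:023b} -> {reconstruct_fp32(s, e, m)}")
```

For example, -2.5 is $-1.25 \times 2^{1}$: sign 1, biased exponent 128, and a mantissa field holding 0.25 as a 23-bit fraction.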

Let's visualize the bit layouts for common formats.

In [None]:
# Define number formats
formats = {
    'FP32':     {'sign': 1, 'exponent': 8,  'mantissa': 23, 'total': 32},
    'FP16':     {'sign': 1, 'exponent': 5,  'mantissa': 10, 'total': 16},
    'BF16':     {'sign': 1, 'exponent': 8,  'mantissa': 7,  'total': 16},
    'FP8 E4M3': {'sign': 1, 'exponent': 4,  'mantissa': 3,  'total': 8},
    'FP8 E5M2': {'sign': 1, 'exponent': 5,  'mantissa': 2,  'total': 8},
    'INT8':     {'sign': 1, 'exponent': 0,  'mantissa': 7,  'total': 8, 'integer': True},
    'INT4':     {'sign': 1, 'exponent': 0,  'mantissa': 3,  'total': 4, 'integer': True},
}

def draw_bit_layout(ax, name, fmt, y_pos):
    """Draw the bit layout for a number format."""
    total = fmt['total']
    colors = {'sign': '#d62728', 'exponent': '#1f77b4', 'mantissa': '#2ca02c'}
    
    x = 0
    bit_width = 0.6
    height = 0.4
    
    # Sign bit
    for i in range(fmt['sign']):
        rect = FancyBboxPatch((x, y_pos), bit_width, height,
                              boxstyle="round,pad=0.02",
                              facecolor=colors['sign'], edgecolor='white', linewidth=1)
        ax.add_patch(rect)
        x += bit_width
    
    # Exponent bits
    for i in range(fmt['exponent']):
        rect = FancyBboxPatch((x, y_pos), bit_width, height,
                              boxstyle="round,pad=0.02",
                              facecolor=colors['exponent'], edgecolor='white', linewidth=1)
        ax.add_patch(rect)
        x += bit_width
    
    # Mantissa bits
    for i in range(fmt['mantissa']):
        rect = FancyBboxPatch((x, y_pos), bit_width, height,
                              boxstyle="round,pad=0.02",
                              facecolor=colors['mantissa'], edgecolor='white', linewidth=1)
        ax.add_patch(rect)
        x += bit_width
    
    # Label
    ax.text(-0.5, y_pos + height/2, name, ha='right', va='center',
            fontsize=11, fontweight='bold')
    
    # Bit counts
    s_mid = fmt['sign'] * bit_width / 2
    e_mid = fmt['sign'] * bit_width + fmt['exponent'] * bit_width / 2
    m_mid = (fmt['sign'] + fmt['exponent']) * bit_width + fmt['mantissa'] * bit_width / 2
    
    if fmt['exponent'] > 0:
        ax.text(s_mid, y_pos - 0.15, f"S({fmt['sign']})", ha='center', fontsize=7, color=colors['sign'])
        ax.text(e_mid, y_pos - 0.15, f"E({fmt['exponent']})", ha='center', fontsize=7, color=colors['exponent'])
        ax.text(m_mid, y_pos - 0.15, f"M({fmt['mantissa']})", ha='center', fontsize=7, color=colors['mantissa'])
    else:
        ax.text(total * bit_width / 2, y_pos - 0.15, f"{total}-bit integer", ha='center', fontsize=7)

fig, ax = plt.subplots(1, 1, figsize=(16, 7))

y_positions = list(range(len(formats) - 1, -1, -1))
for (name, fmt), y in zip(formats.items(), y_positions):
    draw_bit_layout(ax, name, fmt, y * 0.7)

ax.set_xlim(-3, 22)
ax.set_ylim(-0.5, len(formats) * 0.7 + 0.3)
ax.axis('off')

# Legend
legend_elements = [
    mpatches.Patch(facecolor='#d62728', label='Sign (1 bit)'),
    mpatches.Patch(facecolor='#1f77b4', label='Exponent (range)'),
    mpatches.Patch(facecolor='#2ca02c', label='Mantissa (precision)'),
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=11)

ax.set_title('Bit Layouts of Common Number Formats', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

## 3. Properties of Each Format

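Let's compute the key properties: dynamic range (largest finite value vs. smallest positive value), precision (the gap between 1.0 and the next representable value, i.e. machine epsilon), and the number of distinct bit patterns.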

In [None]:
def format_properties(name, exp_bits, man_bits, is_integer=False, no_inf=False):
    """Calculate key properties of a number format.

    no_inf=True models the FP8 E4M3 variant (E4M3FN): no infinities, and only
    the all-ones pattern S.1111.111 is NaN, so the top exponent holds normal values.
    """
    if is_integer:
        total_bits = 1 + exp_bits + man_bits  # sign + value bits
        n_values = 2 ** total_bits
        max_val = 2 ** (total_bits - 1) - 1   # e.g. 127 for INT8 (before scaling)
        min_positive = 1
        precision_at_1 = 1  # Integer step size
        return {
            'name': name,
            'bits': total_bits,
            'max_value': max_val,
            'min_positive': min_positive,
            'precision_at_1': precision_at_1,
            'n_unique_values': n_values,
            'dynamic_range_db': 20 * np.log10(max_val / min_positive),
        }

    bias = 2 ** (exp_bits - 1) - 1
    min_exp = 1 - bias  # Smallest normal exponent

    if no_inf:
        # Top exponent field is usable; only mantissa all-ones is reserved (NaN)
        max_exp = 2 ** exp_bits - 1 - bias
        max_mantissa = 2 - 2 * 2 ** (-man_bits)
    else:
        # IEEE-style: top exponent field is reserved for inf/NaN
        max_exp = 2 ** exp_bits - 2 - bias
        max_mantissa = 2 - 2 ** (-man_bits)

    max_val = 2 ** max_exp * max_mantissa          # Largest finite value
    min_normal = 2 ** min_exp                      # Smallest positive normal value
    min_subnormal = 2 ** (min_exp - man_bits)      # Smallest positive subnormal value
    epsilon = 2 ** (-man_bits)                     # Gap between 1.0 and the next value

    # Number of bit patterns (includes +/-0 and any inf/NaN encodings)
    n_values = 2 ** (1 + exp_bits + man_bits)

    return {
        'name': name,
        'bits': 1 + exp_bits + man_bits,
        'max_value': max_val,
        'min_normal': min_normal,
        'min_positive': min_subnormal,
        'precision_at_1': epsilon,
        'n_unique_values': n_values,
        'dynamic_range_db': 20 * np.log10(max_val / min_subnormal),
    }

# Calculate properties
props = [
    format_properties('FP32', 8, 23),
    format_properties('FP16', 5, 10),
    format_properties('BF16', 8, 7),
    format_properties('FP8 E4M3', 4, 3, no_inf=True),
    format_properties('FP8 E5M2', 5, 2),
    format_properties('INT8', 0, 7, is_integer=True),
    format_properties('INT4', 0, 3, is_integer=True),
]

header = (f"{'Format':<12s} | {'Bits':>4s} | {'Max Value':>12s} | {'Min Positive':>14s} | "
          f"{'Precision@1':>12s} | {'Range (dB)':>10s} | {'Bit Patterns':>14s}")
print(header)
print("=" * len(header))
for p in props:
    print(f"{p['name']:<12s} | {p['bits']:>4d} | {p['max_value']:>12.4g} | {p['min_positive']:>14.2e} | "
          f"{p['precision_at_1']:>12.2e} | {p['dynamic_range_db']:>10.1f} | {p['n_unique_values']:>14,}")
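These numbers follow from a few closed-form expressions. For an IEEE-style format with $E$ exponent bits, $M$ mantissa bits, $\text{bias} = 2^{E-1} - 1$, and the top exponent field reserved for inf/NaN:

$$\text{max} = 2^{(2^E - 2 - \text{bias})} \left(2 - 2^{-M}\right), \quad \text{min normal} = 2^{1 - \text{bias}}, \quad \text{min subnormal} = 2^{1 - \text{bias} - M}, \quad \epsilon = 2^{-M}$$

For FP16 ($E = 5$, $M = 10$, bias 15) this gives $2^{15} (2 - 2^{-10}) = 65504$, a smallest subnormal of $2^{-24} \approx 5.96 \times 10^{-8}$, and $\epsilon = 2^{-10} \approx 9.77 \times 10^{-4}$. Applied literally to E4M3, the same rule gives $2^{7} \times 1.875 = 240$; the E4M3 variant used in practice (often called E4M3FN) has no infinities and reserves only S.1111.111 for NaN, so its top exponent holds normal values and the maximum becomes $2^{8} \times 1.75 = 448$.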

## 4. Visualizing Dynamic Range vs Precision

This is the fundamental tradeoff in quantization: **range** (how large/small numbers can be) vs **precision** (how fine-grained the representation is).

In [None]:
# Visualize representable numbers on a number line
fig, axes = plt.subplots(4, 1, figsize=(16, 10))

def plot_representable_numbers(ax, name, exp_bits, man_bits, x_range=(-4, 4)):
    """Plot all representable numbers for a format in a range."""
    bias = 2 ** (exp_bits - 1) - 1
    values = set()
    
    # Generate all representable positive values
    for e in range(2 ** exp_bits):
        for m in range(2 ** man_bits):
            if e == 0:  # Subnormal
                val = 2 ** (1 - bias) * (m / 2 ** man_bits)
            elif e == 2 ** exp_bits - 1:  # Inf/NaN - skip
                continue
            else:  # Normal
                val = 2 ** (e - bias) * (1 + m / 2 ** man_bits)
            
            if x_range[0] <= val <= x_range[1]:
                values.add(val)
            if x_range[0] <= -val <= x_range[1]:
                values.add(-val)
    
    values = sorted(values)
    
    # Plot
    ax.scatter(values, [0] * len(values), marker='|', s=50, c='steelblue', alpha=0.6, linewidth=1)
    ax.axhline(y=0, color='black', linewidth=0.5)
    ax.set_xlim(x_range)
    ax.set_ylim(-0.5, 0.5)
    ax.set_yticks([])
    ax.set_title(f'{name}: {len(values)} values in [{x_range[0]}, {x_range[1]}]', fontsize=12, fontweight='bold')
    ax.set_xlabel('Value')
    
    return len(values)

x_range = (-4, 4)
n_fp16 = plot_representable_numbers(axes[0], 'FP16 (E5M10)', 5, 10, x_range)
n_bf16 = plot_representable_numbers(axes[1], 'BF16 (E8M7)', 8, 7, x_range)
n_e4m3 = plot_representable_numbers(axes[2], 'FP8 E4M3', 4, 3, x_range)
n_e5m2 = plot_representable_numbers(axes[3], 'FP8 E5M2', 5, 2, x_range)

plt.suptitle('Representable Numbers on the Number Line\n(density = precision)', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

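print("Notice how:")
print("- Absolute spacing is finest near zero and doubles at each power of 2 above the subnormal range")
print("- Relative spacing stays roughly constant (about 2^-M) across the normal range")
print("- More mantissa bits = more marks between consecutive powers of 2")
print("- More exponent bits = more powers of 2 covered (wider range)")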
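The density pattern can be quantified with NumPy's `np.spacing`, which returns the distance from a value to the next representable number of the same dtype (one ULP). A minimal sketch comparing FP16 and FP32: the absolute gap grows with magnitude, while for normal numbers the relative gap (gap / x) stays between $2^{-M-1}$ and $2^{-M}$.

```python
import numpy as np

# np.spacing(x) = gap between x and the next representable value of x's dtype
for dtype in (np.float16, np.float32):
    for v in (0.001, 1.0, 100.0, 1000.0):
        x = dtype(v)
        gap = np.spacing(x)
        print(f"{dtype.__name__:>8s}  x={float(x):>10.4g}  "
              f"gap={float(gap):.3e}  gap/x={float(gap / x):.2e}")
```

At 1.0 the FP16 gap is $2^{-10}$ and the FP32 gap is $2^{-23}$, exactly the machine epsilon values from the properties table.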

## 5. BF16 vs FP16: A Critical Comparison

Both are 16-bit, but they make **very different tradeoffs**:

- **FP16**: 5 exponent bits, 10 mantissa bits - good precision, limited range (max ~65,504)
- **BF16**: 8 exponent bits, 7 mantissa bits - same range as FP32, less precision

BF16 was introduced by Google specifically for deep learning, where **range matters more than precision** (gradients and activations can have large dynamic range).
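Part of BF16's appeal is that it is simply the upper 16 bits of an FP32 value, so conversion is a bit-level rounding step. A minimal NumPy sketch, assuming round-to-nearest-even on the 16 discarded bits (the usual conversion mode); `bf16_round` is an illustrative name, and NaN inputs are not special-cased.

```python
import numpy as np

def bf16_round(x):
    """Round float32 values to the nearest BF16 value (ties to even), returned as float32."""
    x = np.atleast_1d(np.asarray(x, dtype=np.float32))
    bits = x.view(np.uint32).astype(np.uint64)    # widen so the add below cannot overflow
    # Bit 16 is the lowest surviving bit; adding it breaks exact ties toward even
    lsb = (bits >> np.uint64(16)) & np.uint64(1)
    rounded = ((bits + np.uint64(0x7FFF) + lsb) >> np.uint64(16)) << np.uint64(16)
    return rounded.astype(np.uint32).view(np.float32)

print(bf16_round([1.0, 3.14159, 1e5]).tolist())   # [1.0, 3.140625, 99840.0]
```

Note that 1e5, which overflows FP16, survives as 99840.0 in BF16: the range is kept and only precision is lost.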

In [None]:
# Compare FP16 and BF16 precision
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Generate values across a wide range
fp32_values = np.logspace(-6, 5, 1000, dtype=np.float32)

# Round-trip through FP16 and BF16 (both casts round to nearest)
with np.errstate(over='ignore'):
    # Values beyond the FP16 max (65504) overflow to inf; they are masked out below
    fp16_values = fp32_values.astype(np.float16).astype(np.float32)
# BF16 keeps FP32's 8-bit exponent and rounds the mantissa to 7 bits (vectorized cast)
bf16_values = torch.from_numpy(fp32_values).to(torch.bfloat16).to(torch.float32).numpy()

# Relative error
fp16_rel_error = np.abs(fp16_values - fp32_values) / (np.abs(fp32_values) + 1e-30)
bf16_rel_error = np.abs(bf16_values - fp32_values) / (np.abs(fp32_values) + 1e-30)

# Handle overflow (FP16 overflows above ~65504)
fp16_overflow = fp32_values > 65504

# Plot 1: Relative error vs value
ax = axes[0]
ax.semilogy(np.log10(fp32_values[~fp16_overflow]), fp16_rel_error[~fp16_overflow],
            'b-', alpha=0.6, label='FP16', linewidth=1.5)
ax.semilogy(np.log10(fp32_values), bf16_rel_error,
            'r-', alpha=0.6, label='BF16', linewidth=1.5)
ax.axvspan(np.log10(65504), 5, alpha=0.2, color='blue', label='FP16 overflow zone')
ax.set_xlabel('log10(Value)')
ax.set_ylabel('Relative Error')
ax.set_title('Quantization Error: FP16 vs BF16')
ax.legend()

# Plot 2: What happens at the boundaries
ax = axes[1]
test_values = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0, 10000.0, 60000.0]
fp16_errors = []
bf16_errors = []

for v in test_values:
    v32 = np.float32(v)
    v16 = np.float32(np.float16(v))
    vb16 = float(torch.tensor(v, dtype=torch.float32).to(torch.bfloat16).to(torch.float32))
    
    fp16_err = abs(v16 - v32) / abs(v32)
    bf16_err = abs(vb16 - v32) / abs(v32)
    fp16_errors.append(fp16_err)
    bf16_errors.append(bf16_err)

x_pos = np.arange(len(test_values))
width = 0.35
ax.bar(x_pos - width/2, fp16_errors, width, label='FP16', color='#1f77b4', alpha=0.8)
ax.bar(x_pos + width/2, bf16_errors, width, label='BF16', color='#d62728', alpha=0.8)
ax.set_xticks(x_pos)
ax.set_xticklabels([str(v) for v in test_values], rotation=45, fontsize=9)
ax.set_ylabel('Relative Error')
ax.set_xlabel('Original Value')
ax.set_title('Error at Specific Values')
ax.legend()
ax.set_yscale('log')

plt.tight_layout()
plt.show()

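print("Key insight:")
print("- FP16 has BETTER precision (lower error) across its normal range (~6e-5 to 65504)")
print("- Below ~6e-5 FP16 goes subnormal and its relative error climbs; above 65504 it overflows")
print("- BF16 keeps a roughly constant relative error over the whole FP32 range")
print("- For LLM weights (typically small, well inside [-2, 2]), FP16 precision is usually sufficient")
print("- For activations/gradients with outliers, BF16's range is valuable")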

## 6. FP8 Formats: The Frontier of Low-Precision

FP8 (8-bit floating point) comes in two variants:

- **E4M3**: 4 exponent bits, 3 mantissa bits - better precision, smaller range
  - Used for: **weights and activations** (forward pass)
- **E5M2**: 5 exponent bits, 2 mantissa bits - larger range, less precision
  - Used for: **gradients** (backward pass, need larger range)
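A compact way to see E4M3's range is to decode all 256 byte patterns. A minimal sketch, assuming the E4M3 variant from the OCP FP8 specification (no infinities; only S.1111.111 encodes NaN); `decode_e4m3fn` is an illustrative helper name.

```python
import math

def decode_e4m3fn(byte):
    """Decode one FP8 E4M3 byte (sign | 4 exponent bits | 3 mantissa bits, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:        # the only NaN encoding (one per sign)
        return math.nan
    if exp == 0:                         # subnormal: no implicit leading 1
        return sign * 2.0 ** (1 - 7) * (man / 8)
    return sign * 2.0 ** (exp - 7) * (1 + man / 8)

values = [decode_e4m3fn(b) for b in range(256)]
finite = [v for v in values if not math.isnan(v)]
print("NaN encodings:         ", sum(math.isnan(v) for v in values))
print("Distinct finite values:", len(set(finite)))   # +0 and -0 compare equal
print("Max value:             ", max(finite))
print("Min positive subnormal:", min(v for v in finite if v > 0))
```

Only 253 distinct finite values exist, so every FP8 cast is a choice among very few candidates.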

Let's implement these from scratch.

In [None]:
def round_to_fp8(value, exp_bits, man_bits, max_val):
    """Round a float to the nearest value representable with the given FP8 layout.

    Uses round-to-nearest-even and saturates to +/-max_val instead of producing
    inf/NaN (a common choice for FP8 casts). Subnormals are handled: below the
    smallest normal, the spacing stays fixed at the smallest subnormal.
    """
    if value == 0 or not np.isfinite(value):
        return float(value)

    bias = 2 ** (exp_bits - 1) - 1
    min_exp = 1 - bias                       # Exponent of the smallest normal value
    abs_val = abs(value)

    # Exponent of the power-of-2 interval containing abs_val (fixed at min_exp for subnormals)
    exp = max(int(np.floor(np.log2(abs_val))), min_exp)
    # Gap between neighbouring representable values in that interval
    step = 2.0 ** (exp - man_bits)

    # Python's round() breaks ties to even, matching IEEE round-to-nearest-even.
    # Rounding up may land exactly on 2**(exp+1), which is itself representable.
    result = min(round(abs_val / step) * step, max_val)
    return -result if value < 0 else result


def float_to_fp8_e4m3(value):
    """FP8 E4M3: 1 sign, 4 exponent, 3 mantissa bits, bias 7.

    Max value = 448: there are no infinities and only S.1111.111 is NaN,
    so the largest finite pattern is S.1111.110 = 2^8 * 1.75.
    """
    return round_to_fp8(value, exp_bits=4, man_bits=3, max_val=448.0)


def float_to_fp8_e5m2(value):
    """FP8 E5M2: 1 sign, 5 exponent, 2 mantissa bits, bias 15.

    Follows IEEE conventions (has inf/NaN); max finite value = 2^15 * 1.75 = 57344.
    """
    return round_to_fp8(value, exp_bits=5, man_bits=2, max_val=57344.0)

# Test FP8 conversion
test_values = [0.1, 0.5, 1.0, 1.5, 3.14159, 10.0, 100.0, 0.001, -2.5]

print(f"{'Original':>10s} | {'E4M3':>10s} | {'E4M3 Err':>10s} | {'E5M2':>10s} | {'E5M2 Err':>10s}")
print("=" * 60)
for v in test_values:
    e4m3 = float_to_fp8_e4m3(v)
    e5m2 = float_to_fp8_e5m2(v)
    e4m3_err = abs(e4m3 - v) / abs(v) * 100
    e5m2_err = abs(e5m2 - v) / abs(v) * 100
    print(f"{v:>10.5f} | {e4m3:>10.5f} | {e4m3_err:>8.2f}% | {e5m2:>10.5f} | {e5m2_err:>8.2f}%")

## 7. Integer Quantization: INT8 and INT4

Unlike floating-point formats, **integer quantization** maps continuous float values to discrete integer levels. This requires a **scale** and optionally a **zero-point**.

### Absmax (Symmetric) Quantization

$$x_{int} = \text{round}\left(\frac{x}{\text{scale}}\right), \quad \text{scale} = \frac{\max(|x|)}{2^{b-1} - 1}$$

### Zero-Point (Asymmetric) Quantization

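$$x_{int} = \text{round}\left(\frac{x}{\text{scale}}\right) + \text{zero\_point}, \quad \text{scale} = \frac{\max(x) - \min(x)}{2^{b} - 1}, \quad \text{zero\_point} = \text{round}\left(-\frac{\min(x)}{\text{scale}}\right)$$

Dequantization inverts this as $x \approx (x_{int} - \text{zero\_point}) \times \text{scale}$; the zero-point is the integer code that represents 0.0 exactly.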
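To see why the two schemes differ, here is a small NumPy sketch on a skewed toy vector (the values are made up for illustration). It follows the formulas above, with the asymmetric scale taken as (max − min)/(2^b − 1) and zero_point = round(−min/scale), which is what `quantize_zeropoint` in the next cell computes (with qmin = 0).

```python
import numpy as np

x = np.array([-0.3, 0.0, 0.45, 1.2])   # skewed toy data: more positive range than negative
b = 8

# Symmetric (absmax): codes in [-127, 127], grid centred on zero
sym_scale = np.abs(x).max() / (2 ** (b - 1) - 1)
x_sym = np.clip(np.round(x / sym_scale), -127, 127)

# Asymmetric (zero-point): codes in [0, 255] cover exactly [min, max]
asym_scale = (x.max() - x.min()) / (2 ** b - 1)
zero_point = np.round(-x.min() / asym_scale)
x_asym = np.clip(np.round(x / asym_scale) + zero_point, 0, 255)

print(f"symmetric:  scale={sym_scale:.5f}  codes={x_sym.astype(int)}")
print(f"asymmetric: scale={asym_scale:.5f}  zero_point={int(zero_point)}  codes={x_asym.astype(int)}")
print("max |error| sym :", np.abs(x_sym * sym_scale - x).max())
print("max |error| asym:", np.abs((x_asym - zero_point) * asym_scale - x).max())
```

Because the symmetric grid spends codes on negative values down to -1.2 that never occur, the asymmetric step is smaller; the cost is storing and subtracting a zero-point.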

In [None]:
def quantize_absmax(tensor, n_bits=8):
    """Absmax (symmetric) quantization.
    
    Maps the full range [-max, max] to [-2^(b-1)+1, 2^(b-1)-1].
    """
    qmax = 2 ** (n_bits - 1) - 1  # e.g., 127 for INT8
    scale = tensor.abs().max() / qmax
    
    # Quantize
    quantized = torch.round(tensor / scale).clamp(-qmax, qmax).to(torch.int8 if n_bits == 8 else torch.int32)
    
    return quantized, scale

def dequantize_absmax(quantized, scale):
    """Dequantize absmax values back to float."""
    return quantized.float() * scale

def quantize_zeropoint(tensor, n_bits=8):
    """Zero-point (asymmetric) quantization.
    
    Maps [min, max] to [0, 2^b - 1] (unsigned range).
    """
    qmin = 0
    qmax = 2 ** n_bits - 1  # e.g., 255 for INT8
    
    min_val = tensor.min()
    max_val = tensor.max()
    
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = torch.round(qmin - min_val / scale).clamp(qmin, qmax)
    
    quantized = torch.round(tensor / scale + zero_point).clamp(qmin, qmax).to(torch.uint8)
    
    return quantized, scale, zero_point

def dequantize_zeropoint(quantized, scale, zero_point):
    """Dequantize zero-point values back to float."""
    return (quantized.float() - zero_point) * scale

# Demonstrate on sample data
torch.manual_seed(42)
# Simulate model weights (typically small values, roughly Gaussian)
weights = torch.randn(1000) * 0.5

# Quantize with both methods
q_abs8, scale_abs8 = quantize_absmax(weights, n_bits=8)
q_abs4, scale_abs4 = quantize_absmax(weights, n_bits=4)
q_zp8, scale_zp8, zp8 = quantize_zeropoint(weights, n_bits=8)

# Dequantize
deq_abs8 = dequantize_absmax(q_abs8, scale_abs8)
deq_abs4 = dequantize_absmax(q_abs4, scale_abs4)
deq_zp8 = dequantize_zeropoint(q_zp8, scale_zp8, zp8)

# Measure error
print("Quantization Error Analysis")
print("=" * 50)
for name, deq in [('Absmax INT8', deq_abs8), ('Absmax INT4', deq_abs4), ('ZeroPoint INT8', deq_zp8)]:
    mse = ((deq - weights) ** 2).mean().item()
    max_err = (deq - weights).abs().max().item()
    rel_err = ((deq - weights).abs() / (weights.abs() + 1e-8)).mean().item()
    print(f"{name:>15s}: MSE={mse:.6f}, Max Error={max_err:.6f}, Mean Rel Error={rel_err:.4f}")

In [None]:
# Visualize quantization effects
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Plot 1: Original vs quantized values
ax = axes[0][0]
idx = torch.argsort(weights)[:100]  # First 100 sorted values
ax.plot(weights[idx].numpy(), 'b-', label='Original (FP32)', linewidth=1.5)
ax.plot(deq_abs8[idx].numpy(), 'r--', label='INT8 (absmax)', alpha=0.7)
ax.plot(deq_abs4[idx].numpy(), 'g:', label='INT4 (absmax)', alpha=0.7, linewidth=2)
ax.set_xlabel('Index (sorted)')
ax.set_ylabel('Value')
ax.set_title('Original vs Quantized Values')
ax.legend()

# Plot 2: Error distribution
ax = axes[0][1]
errors_int8 = (deq_abs8 - weights).numpy()
errors_int4 = (deq_abs4 - weights).numpy()
ax.hist(errors_int8, bins=50, alpha=0.5, label='INT8 error', color='red')
ax.hist(errors_int4, bins=50, alpha=0.5, label='INT4 error', color='green')
ax.set_xlabel('Quantization Error')
ax.set_ylabel('Count')
ax.set_title('Error Distribution')
ax.legend()

# Plot 3: Weight distribution with quantization levels
ax = axes[1][0]
ax.hist(weights.numpy(), bins=100, alpha=0.5, color='blue', label='Original weights', density=True)

# Show INT4 quantization levels
int4_levels = torch.arange(-7, 8) * scale_abs4
for level in int4_levels:
    ax.axvline(x=level.item(), color='green', alpha=0.3, linewidth=0.5)
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Weight Distribution with INT4 Quantization Levels')
ax.legend()

# Plot 4: Scatter plot - original vs dequantized
ax = axes[1][1]
sample = torch.randperm(len(weights))[:200]
ax.scatter(weights[sample].numpy(), deq_abs8[sample].numpy(), s=10, alpha=0.5, label='INT8', color='red')
ax.scatter(weights[sample].numpy(), deq_abs4[sample].numpy(), s=10, alpha=0.5, label='INT4', color='green')
lims = [weights.min().item() - 0.1, weights.max().item() + 0.1]
ax.plot(lims, lims, 'k--', alpha=0.3, label='Perfect (y=x)')
ax.set_xlabel('Original Value (FP32)')
ax.set_ylabel('Dequantized Value')
ax.set_title('Original vs Dequantized (closer to diagonal = better)')
ax.legend()
ax.set_aspect('equal')

plt.suptitle('Quantization Effects on Model Weights', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 8. Per-Channel vs Per-Tensor Quantization

**Per-tensor quantization** uses a single scale for the entire weight tensor. This is simple but wasteful - if one channel has much larger values, it determines the scale for all channels.

**Per-channel quantization** uses a separate scale per output channel (row of weight matrix). This is much more accurate because each channel's scale is fitted to its own range, so small-magnitude channels are no longer forced onto a grid sized for the largest one.
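Per-channel scales are not free: each one is stored alongside the weights. A back-of-the-envelope sketch, assuming scales are kept in FP16 (16 bits) and using an illustrative 4096×4096 layer; `effective_bits_per_param` is a hypothetical helper. The per-channel overhead works out to scale_bits / n_in bits per weight.

```python
def effective_bits_per_param(n_out, n_in, n_bits, n_scales, scale_bits=16):
    """Average storage bits per weight, counting quantized values plus scales."""
    n_params = n_out * n_in
    return (n_params * n_bits + n_scales * scale_bits) / n_params

n_out, n_in = 4096, 4096   # hypothetical layer shape for illustration
per_tensor = effective_bits_per_param(n_out, n_in, n_bits=8, n_scales=1)
per_channel = effective_bits_per_param(n_out, n_in, n_bits=8, n_scales=n_out)
print(f"per-tensor : {per_tensor:.6f} bits/param")
print(f"per-channel: {per_channel:.6f} bits/param")   # 8 + 16/4096
```

For typical layer widths the extra storage is a fraction of a percent, which is why per-channel scaling is close to free.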

In [None]:
# Create a weight matrix with varying scales per channel
torch.manual_seed(42)
n_out, n_in = 64, 256
# Different channels have different magnitudes (realistic for neural nets)
channel_scales = torch.rand(n_out) * 5 + 0.1  # Scales from 0.1 to 5.1
weight_matrix = torch.randn(n_out, n_in) * channel_scales.unsqueeze(1)

def quantize_per_tensor(tensor, n_bits=8):
    """Per-tensor absmax quantization."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = tensor.abs().max() / qmax
    quantized = torch.round(tensor / scale).clamp(-qmax, qmax)
    dequantized = quantized * scale
    return dequantized, scale

def quantize_per_channel(tensor, n_bits=8):
    """Per-channel (per-row) absmax quantization."""
    qmax = 2 ** (n_bits - 1) - 1
    # Scale per row
    scales = tensor.abs().max(dim=1, keepdim=True).values / qmax
    quantized = torch.round(tensor / scales).clamp(-qmax, qmax)
    dequantized = quantized * scales
    return dequantized, scales

# Compare
deq_per_tensor, _ = quantize_per_tensor(weight_matrix, n_bits=8)
deq_per_channel, _ = quantize_per_channel(weight_matrix, n_bits=8)

# Per-channel errors
errors_per_tensor = ((deq_per_tensor - weight_matrix) ** 2).mean(dim=1).numpy()
errors_per_channel = ((deq_per_channel - weight_matrix) ** 2).mean(dim=1).numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: MSE per channel
ax1.bar(range(n_out), errors_per_tensor, alpha=0.5, label='Per-Tensor', color='red')
ax1.bar(range(n_out), errors_per_channel, alpha=0.5, label='Per-Channel', color='green')
ax1.set_xlabel('Channel Index')
ax1.set_ylabel('MSE')
ax1.set_title('Quantization Error by Channel (INT8)')
ax1.legend()

# Plot 2: Channel scale vs error improvement
improvement = errors_per_tensor / (errors_per_channel + 1e-10)
ax2.scatter(channel_scales.numpy(), improvement, alpha=0.7, c='steelblue')
ax2.set_xlabel('Channel Scale (magnitude)')
ax2.set_ylabel('Error Improvement Factor')
ax2.set_title('Per-Channel Helps Most for Small Channels')
ax2.axhline(y=1, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

total_mse_tensor = ((deq_per_tensor - weight_matrix) ** 2).mean().item()
total_mse_channel = ((deq_per_channel - weight_matrix) ** 2).mean().item()
print(f"Overall MSE - Per-Tensor: {total_mse_tensor:.6f}, Per-Channel: {total_mse_channel:.6f}")
print(f"Per-Channel is {total_mse_tensor/total_mse_channel:.1f}x more accurate")

## 9. Memory Savings from Quantization

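The primary motivation for quantization is **reducing memory**. Let's calculate the weight storage for real model sizes. These figures count weights only: the KV cache, activations, and quantization scales come on top.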
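The arithmetic behind the plots is a single product. As a worked instance for a 70B model at 4 bits (note that `model_memory_gb` in the next cell divides by $1024^3$, so its "GB" are binary gibibytes):

$$\text{bytes} = N_{\text{params}} \times \frac{\text{bits per param}}{8}, \qquad 70 \times 10^{9} \times \frac{4}{8} = 3.5 \times 10^{10}\ \text{bytes} \approx 35\ \text{GB} \approx 32.6\ \text{GiB}$$

The same model in FP16 needs $1.4 \times 10^{11}$ bytes, about 130 GiB, more than a single 80GB GPU holds.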

In [None]:
def model_memory_gb(n_params_billions, bits_per_param):
    """Weight memory in GiB (2**30 bytes); labelled 'GB' in the plots."""
    return n_params_billions * 1e9 * bits_per_param / 8 / (1024 ** 3)

model_sizes = [1.5, 7, 13, 34, 70, 175]
quantization_levels = {
    'FP32 (32-bit)': 32,
    'FP16 (16-bit)': 16,
    'INT8 (8-bit)': 8,
    'INT4 (4-bit)': 4,
    'INT3 (3-bit)': 3,
    'INT2 (2-bit)': 2,
}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(quantization_levels)))

# Plot 1: Memory by model size and quantization
x_pos = np.arange(len(model_sizes))
width = 0.13

for i, (qname, bits) in enumerate(quantization_levels.items()):
    mems = [model_memory_gb(s, bits) for s in model_sizes]
    ax1.bar(x_pos + i * width, mems, width, label=qname, color=colors[i], alpha=0.8)

ax1.set_xticks(x_pos + width * 2.5)
ax1.set_xticklabels([f'{s}B' for s in model_sizes])
ax1.set_ylabel('Memory (GB)')
ax1.set_xlabel('Model Size')
ax1.set_title('Model Memory by Quantization Level')
ax1.legend(fontsize=8)

# Add GPU memory lines
for gpu_mem, gpu_name, color in [(8, '8GB (RTX 3070)', 'red'),
                                   (24, '24GB (RTX 4090)', 'orange'),
                                   (80, '80GB (A100)', 'green')]:
    ax1.axhline(y=gpu_mem, color=color, linestyle=':', alpha=0.4)
    ax1.text(len(model_sizes) - 0.5, gpu_mem + 1, gpu_name, fontsize=8, color=color)

# Plot 2: Which models fit on a single 24GB GPU (RTX 4090)?
# 1 = fits with ~15% headroom left for KV cache/activations, 0.5 = tight fit, 0 = doesn't fit
fit_matrix = np.zeros((len(model_sizes), len(quantization_levels)))
gpu_mem = 24  # RTX 4090

for i, model_size in enumerate(model_sizes):
    for j, (qname, bits) in enumerate(quantization_levels.items()):
        mem = model_memory_gb(model_size, bits)
        if mem <= gpu_mem * 0.85:
            fit_matrix[i, j] = 1    # Fits with headroom
        elif mem <= gpu_mem:
            fit_matrix[i, j] = 0.5  # Tight fit
        else:
            fit_matrix[i, j] = 0    # Doesn't fit
im = ax2.imshow(fit_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
ax2.set_xticks(range(len(quantization_levels)))
ax2.set_xticklabels(quantization_levels.keys(), rotation=45, ha='right', fontsize=9)
ax2.set_yticks(range(len(model_sizes)))
ax2.set_yticklabels([f'{s}B' for s in model_sizes])
ax2.set_title(f'Can It Fit? (RTX 4090, 24GB)')
ax2.set_xlabel('Quantization')
ax2.set_ylabel('Model Size')

# Add memory values to cells
for i in range(len(model_sizes)):
    for j, (qname, bits) in enumerate(quantization_levels.items()):
        mem = model_memory_gb(model_sizes[i], bits)
        color = 'white' if fit_matrix[i, j] < 0.5 else 'black'
        ax2.text(j, i, f'{mem:.1f}GB', ha='center', va='center', fontsize=7, color=color)

plt.tight_layout()
plt.show()

## 10. Quantization Error vs Number of Bits

Let's systematically measure how quantization error scales with the number of bits.

In [None]:
# Sweep across bit widths
torch.manual_seed(42)
weights = torch.randn(10000) * 0.5  # Typical weight distribution

bit_widths = [2, 3, 4, 5, 6, 7, 8, 10, 12, 16]
mse_values = []
max_errors = []
snr_values = []  # Signal-to-noise ratio

for bits in bit_widths:
    qmax = 2 ** (bits - 1) - 1
    scale = weights.abs().max() / qmax
    quantized = torch.round(weights / scale).clamp(-qmax, qmax)
    dequantized = quantized * scale
    
    error = dequantized - weights
    mse = (error ** 2).mean().item()
    max_err = error.abs().max().item()
    signal_power = (weights ** 2).mean().item()
    snr = 10 * np.log10(signal_power / mse) if mse > 0 else float('inf')
    
    mse_values.append(mse)
    max_errors.append(max_err)
    snr_values.append(snr)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

# MSE vs bits
ax1.semilogy(bit_widths, mse_values, 'b-o', linewidth=2, markersize=8)
ax1.set_xlabel('Bit Width')
ax1.set_ylabel('MSE (log scale)')
ax1.set_title('Mean Squared Error vs Bit Width')
ax1.grid(True, alpha=0.3)

# SNR vs bits
ax2.plot(bit_widths, snr_values, 'g-o', linewidth=2, markersize=8)
ax2.set_xlabel('Bit Width')
ax2.set_ylabel('SNR (dB)')
ax2.set_title('Signal-to-Noise Ratio vs Bit Width')
ax2.grid(True, alpha=0.3)

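# Reference slope: the textbook SQNR rule is ~6.02 dB per bit. Its +1.76 dB offset
# assumes a full-scale sine wave; Gaussian data scaled by its absmax (~4 sigma here)
# sits several dB below this line, but with the same slope.
ax2.plot(bit_widths, [6.02 * b + 1.76 for b in bit_widths], 'r--', alpha=0.5,
         label='Full-scale sine: 6.02b + 1.76 dB')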
ax2.legend()

# Memory savings vs quality loss
memory_savings = [32 / b for b in bit_widths]  # Compression ratio vs FP32
ax3.scatter(memory_savings, snr_values, c=bit_widths, cmap='viridis', s=100, zorder=5)
for b, ms, snr in zip(bit_widths, memory_savings, snr_values):
    ax3.annotate(f'{b}-bit', (ms, snr), textcoords='offset points',
                 xytext=(5, 5), fontsize=9)
ax3.set_xlabel('Memory Compression (x vs FP32)')
ax3.set_ylabel('SNR (dB) - higher is better')
ax3.set_title('The Quantization Tradeoff')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey observations:")
print("- Each additional bit halves the step size, so MSE drops ~4x (about 6 dB more SNR)")
print(f"- INT8 gives ~{snr_values[bit_widths.index(8)]:.0f} dB SNR on this Gaussian data")
print("- For weight-only quantization, ~4 bits (usually with finer-grained scales) is a common practical floor")
print("- Below 4 bits, SNR on this data drops below ~10 dB: the error is no longer small relative to the weights")
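The gap between the measured curve and the sine-wave rule of thumb follows from the standard uniform-noise model: if rounding error is spread evenly over one step Δ = max|x| / qmax, its power is Δ²/12. A minimal sketch comparing that estimate with a direct measurement on Gaussian data; it uses NumPy and a different random generator than the sweep above, so the numbers will not match it exactly.

```python
import numpy as np

def predicted_snr_db(x, n_bits):
    """Uniform-noise estimate of absmax quantization SNR: noise power = step^2 / 12."""
    qmax = 2 ** (n_bits - 1) - 1
    step = np.abs(x).max() / qmax
    return 10 * np.log10(np.mean(x ** 2) / (step ** 2 / 12))

def measured_snr_db(x, n_bits):
    """SNR of an actual absmax quantize/dequantize round trip."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = np.abs(x).max() / qmax
    deq = np.clip(np.round(x / scale), -qmax, qmax) * scale
    return 10 * np.log10(np.mean(x ** 2) / np.mean((deq - x) ** 2))

rng = np.random.default_rng(0)
x = rng.normal(0.0, 0.5, size=10_000)   # same distribution as the sweep, different samples
for b in (4, 6, 8, 12):
    print(f"{b:>2d}-bit: predicted {predicted_snr_db(x, b):5.1f} dB, "
          f"measured {measured_snr_db(x, b):5.1f} dB")
```

Rearranged, the estimate is roughly $6.02\,b + 4.77 - 20 \log_{10}(\max|x| / \sigma)$ dB, so heavier tails (a larger max|x| / σ) cost SNR directly.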

## 11. Visualizing the Outlier Problem

A key challenge in LLM quantization is **outliers** - a few weight values that are much larger than the rest. These outliers force the quantization scale to be large, wasting resolution on the majority of small values.
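One way to see the wasted resolution is to count how many integer codes the bulk of the weights actually uses once a single large value sets the absmax scale. A minimal NumPy sketch with one hypothetical outlier planted at 4.0 (200x the bulk's standard deviation):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=10_000)   # bulk of small weights
w[0] = 4.0                               # one hypothetical outlier

for n_bits in (8, 4):
    qmax = 2 ** (n_bits - 1) - 1
    scale = np.abs(w).max() / qmax       # absmax: the outlier alone sets the scale
    codes = np.clip(np.round(w / scale), -qmax, qmax)
    used = len(np.unique(codes))
    zero_frac = np.mean(codes == 0)
    print(f"INT{n_bits}: step={scale:.4f}, distinct codes used={used}, "
          f"fraction rounded to 0={zero_frac:.4f}")
```

At INT4 the step (4/7 ≈ 0.57) is so coarse that every weight except the outlier rounds to zero, leaving only two codes in use.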

In [None]:
# Create weights with outliers (realistic for LLMs)
torch.manual_seed(42)
normal_weights = torch.randn(10000) * 0.02  # Most weights are small
# Add outliers (about 0.1% of weights are much larger)
outlier_mask = torch.rand(10000) < 0.001
outlier_weights = normal_weights.clone()
outlier_weights[outlier_mask] = torch.randn(int(outlier_mask.sum())) * 2.0  # std 100x larger than the bulk

fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Plot 1: Distribution with outliers
ax = axes[0][0]
ax.hist(outlier_weights.numpy(), bins=200, color='steelblue', alpha=0.7)
ax.set_xlabel('Weight Value')
ax.set_ylabel('Count')
ax.set_title('Weight Distribution with Outliers')
outlier_vals = outlier_weights[outlier_mask]
for v in outlier_vals:
    ax.axvline(x=v.item(), color='red', alpha=0.5, linewidth=1)
ax.annotate(f'{outlier_mask.sum().item()} outliers', xy=(outlier_vals.max().item(), 0),
            xytext=(0.7, 0.9), textcoords='axes fraction',
            arrowprops=dict(arrowstyle='->', color='red'),
            fontsize=11, color='red')

# Plot 2: Quantization error WITH outliers
ax = axes[0][1]
for bits, color, label in [(8, 'blue', 'INT8'), (4, 'red', 'INT4')]:
    q, scale = quantize_absmax(outlier_weights, n_bits=bits)
    deq = dequantize_absmax(q, scale)
    errors = (deq - outlier_weights).abs().numpy()
    ax.scatter(outlier_weights.numpy(), errors, s=1, alpha=0.3, color=color, label=label)

ax.set_xlabel('Original Weight Value')
ax.set_ylabel('Absolute Error')
ax.set_title('Quantization Error (with outliers)')
ax.legend()

# Plot 3: Same analysis WITHOUT outliers
ax = axes[1][0]
ax.hist(normal_weights.numpy(), bins=200, color='steelblue', alpha=0.7)
ax.set_xlabel('Weight Value')
ax.set_ylabel('Count')
ax.set_title('Weight Distribution WITHOUT Outliers')

ax = axes[1][1]
for bits, color, label in [(8, 'blue', 'INT8'), (4, 'red', 'INT4')]:
    q, scale = quantize_absmax(normal_weights, n_bits=bits)
    deq = dequantize_absmax(q, scale)
    errors = (deq - normal_weights).abs().numpy()
    ax.scatter(normal_weights.numpy(), errors, s=1, alpha=0.3, color=color, label=label)

ax.set_xlabel('Original Weight Value')
ax.set_ylabel('Absolute Error')
ax.set_title('Quantization Error (without outliers)')
ax.legend()

plt.suptitle('The Outlier Problem in Quantization', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

# Compare MSE
for name, w in [('With outliers', outlier_weights), ('Without outliers', normal_weights)]:
    for bits in [8, 4]:
        q, s = quantize_absmax(w, n_bits=bits)
        deq = dequantize_absmax(q, s)
        mse = ((deq - w) ** 2).mean().item()
        print(f"{name:>20s}, INT{bits}: MSE = {mse:.8f}")

## 12. Key Takeaways

### Format Summary

| Format | Bits | Best For | Max Value | Precision | Memory/Param |
|--------|------|----------|-----------|-----------|-------------|
| FP32 | 32 | Training (gold standard) | ~3.4e38 | Very high | 4 bytes |
| FP16 | 16 | Inference, fine-tuning | 65,504 | Good | 2 bytes |
| BF16 | 16 | Training, large activations | ~3.4e38 (same as FP32) | Moderate | 2 bytes |
| FP8 E4M3 | 8 | Forward pass | 448 | Low | 1 byte |
| FP8 E5M2 | 8 | Backward pass (gradients) | 57,344 | Very low | 1 byte |
| INT8 | 8 | Weight quantization | ±127 × scale | Uniform steps | 1 byte |
| INT4 | 4 | Aggressive quantization | ±7 × scale | Uniform, coarse | 0.5 bytes |

### Practical Guidelines

1. **FP16/BF16** is the default for inference - 2x smaller than FP32 with negligible quality loss
2. **INT8** quantization works well for most models - 4x smaller than FP32 (2x vs FP16) with minimal quality loss
3. **INT4** is aggressive: a 34B model's weights shrink to about 16 GB (fits one 24GB GPU), and a 70B model's to about 33 GB (a pair of 24GB GPUs or one larger card)
4. **Outliers are the enemy** - they waste quantization resolution
5. **Per-channel > per-tensor** quantization for better accuracy, at a negligible storage cost for the extra scales
6. **FP8** is gaining ground for both training and inference (H100 GPUs have native FP8 Tensor Core support)

---

## Exercises

### Exercise 1: Implement Block-wise Quantization
Instead of per-tensor or per-channel, quantize in blocks of 128 values. Each block gets its own scale.

In [None]:
def quantize_blockwise(tensor, block_size=128, n_bits=4):
    """Block-wise quantization.
    
    TODO: Implement quantization where each block of block_size
    values gets its own scale factor.
    
    Hint: Reshape the tensor into blocks, compute per-block scales,
    quantize each block, then reshape back.
    """
    pass

# Test: compare blockwise INT4 vs per-tensor INT4

### Exercise 2: Outlier-Aware Quantization
Implement a quantization scheme that handles outliers specially: keep outlier values in FP16 and quantize the rest to INT4.

In [None]:
def mixed_precision_quantize(tensor, outlier_threshold=3.0, n_bits=4):
    """Mixed-precision quantization: outliers in FP16, rest in INT4.
    
    TODO: Implement this
    1. Identify outliers (values > threshold * std)
    2. Store outliers as FP16 with their indices
    3. Quantize the rest to INT4
    4. Calculate total memory and compare to pure INT4
    """
    pass

### Exercise 3: Format Conversion Roundtrip
Implement a chain: FP32 -> FP16 -> FP8 -> INT8 -> FP32. Measure the accumulated error at each step.

In [None]:
# TODO: Implement the format conversion chain
# Track and plot error accumulation at each conversion step

---

**Next up: Notebook 09 - Post-Training Quantization** where we'll apply these quantization concepts to actual model weights and measure the impact on model quality.