# Lab 3.2.3: FP8 Training and Inference

**Module:** 3.2 - Model Quantization & Optimization  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐☆☆

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand FP8 E4M3 (inference) vs E5M2 (training) formats
- [ ] Train a model using FP8 precision with Transformer Engine
- [ ] Compare FP8 vs FP16/BF16 training speed and quality
- [ ] Apply FP8 for inference optimization

---

## 📚 Prerequisites

- Completed: Lab 3.2.1 and 3.2.2
- Hardware: DGX Spark (Blackwell for native FP8) or Hopper GPU
- Software: NVIDIA Transformer Engine

---

## 🌍 Real-World Context

**The Challenge:** Training large models is expensive and slow!
- FP32 training: Maximum precision, but 4 bytes per parameter
- FP16/BF16: 2× faster, but can have numerical stability issues

**The Solution:** FP8 training gives you:
- **Up to 2× compute throughput** vs FP16/BF16 on Tensor Cores
- **2× memory efficiency** for activations stored in FP8
- **Native Blackwell support** - no emulation overhead!
- **Typically <1% quality loss** with proper scaling

---

## 🧒 ELI5: Why Two FP8 Formats?

> **Think of it like camera settings for different lighting...**
>
> **E4M3 (4 exponent, 3 mantissa)** - like your indoor camera mode:
> - Higher precision (3 mantissa bits = more detail)
> - Smaller range (±448)
> - Best for: **weights and activations during inference**
>
> **E5M2 (5 exponent, 2 mantissa)** - like your outdoor camera mode:
> - Lower precision (2 mantissa bits)
> - Larger range (±57344) to handle bright/dark extremes
> - Best for: **gradients during training** (can be very large or small)
>
> **In AI terms:**
> - Inference uses E4M3 (precision matters more)
> - Training gradients use E5M2 (range matters more)
> - Blackwell tensor cores support BOTH natively!

---

## Part 1: Understanding FP8 Formats

Let's explore the two FP8 formats in detail.
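
Before running the lab cells, it can help to see how a single byte maps to a value in each format. The sketch below is a minimal pure-Python decoder, not code from any library: `decode_fp8`, `e4m3`, and `e5m2` are hypothetical helper names. It assumes the common convention in which E5M2 reserves the all-ones exponent for inf/NaN, while E4M3 has no inf and treats only the all-ones exponent with all-ones mantissa as NaN.

```python
def decode_fp8(byte, exp_bits, mant_bits, ieee_specials=True):
    """Decode one 8-bit pattern (sign | exponent | mantissa) into a Python float."""
    bias = 2 ** (exp_bits - 1) - 1
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp_field = (byte >> mant_bits) & ((1 << exp_bits) - 1)
    mant_field = byte & ((1 << mant_bits) - 1)
    all_ones_exp = (1 << exp_bits) - 1
    all_ones_mant = (1 << mant_bits) - 1

    if exp_field == all_ones_exp:
        if ieee_specials:
            # E5M2: all-ones exponent encodes inf (mantissa 0) or NaN
            return sign * float("inf") if mant_field == 0 else float("nan")
        if mant_field == all_ones_mant:
            # E4M3: only the all-ones mantissa is NaN; there is no inf
            return float("nan")

    if exp_field == 0:
        # Subnormal: no implicit leading 1
        return sign * (mant_field / 2 ** mant_bits) * 2.0 ** (1 - bias)
    return sign * (1 + mant_field / 2 ** mant_bits) * 2.0 ** (exp_field - bias)


def e4m3(byte):
    return decode_fp8(byte, 4, 3, ieee_specials=False)


def e5m2(byte):
    return decode_fp8(byte, 5, 2, ieee_specials=True)


print(e4m3(0b0_1111_110))                      # largest finite E4M3 value: 448.0
print(e5m2(0b0_11110_11))                      # largest finite E5M2 value: 57344.0
print(e4m3(0b0_0111_000), e4m3(0b0_0111_001))  # 1.0 and its next neighbor: 1.0 1.125
print(e5m2(0b0_01111_00), e5m2(0b0_01111_01))  # 1.0 and its next neighbor: 1.0 1.25
```

The last two lines show the trade-off directly: the step above 1.0 is 0.125 in E4M3 but 0.25 in E5M2, while the largest finite value is 128 times bigger in E5M2.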

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
import gc

print("=" * 60)
print("FP8 Training Lab - Environment Check")
print("=" * 60)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    cc = torch.cuda.get_device_capability()
    print(f"Compute Capability: {cc[0]}.{cc[1]}")
    
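    # Check FP8 support (FP8 Tensor Cores first appeared at compute capability 8.9)
    if cc[0] >= 10:
        print("\nNative FP8 support: Yes (Blackwell)")
    elif cc[0] >= 9:
        print("\nNative FP8 support: Yes (Hopper)")
    elif cc >= (8, 9):
        print("\nNative FP8 support: Yes (Ada)")
    else:
        print("\nNative FP8 support: No (emulation only)")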

# Check for Transformer Engine
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
    HAS_TE = True
    print("Transformer Engine: Available")
except ImportError:
    HAS_TE = False
    print("Transformer Engine: Not installed")
    print("  Install with: pip install transformer_engine[pytorch]")

print("=" * 60)

In [None]:
# FP8 format specifications
# E4M3: 1 sign + 4 exponent + 3 mantissa = 8 bits
# E5M2: 1 sign + 5 exponent + 2 mantissa = 8 bits

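class FP8Format:
    def __init__(self, name, exp_bits, mant_bits, ieee_specials=True):
        self.name = name
        self.exp_bits = exp_bits
        self.mant_bits = mant_bits
        self.bias = 2 ** (exp_bits - 1) - 1

        # Calculate range
        if ieee_specials:
            # E5M2 follows IEEE 754: the all-ones exponent is reserved for inf/NaN
            max_exp = 2 ** exp_bits - 2
            max_mant = (2 ** mant_bits - 1) / 2 ** mant_bits
        else:
            # E4M3 has no inf and a single NaN mantissa pattern, so the all-ones
            # exponent still encodes normal values (all but the all-ones mantissa)
            max_exp = 2 ** exp_bits - 1
            max_mant = (2 ** mant_bits - 2) / 2 ** mant_bits
        self.max_value = (1 + max_mant) * 2 ** (max_exp - self.bias)
        self.min_positive = 2 ** (1 - self.bias - mant_bits)  # Smallest subnormal

E4M3 = FP8Format("E4M3", 4, 3, ieee_specials=False)  # max = 448
E5M2 = FP8Format("E5M2", 5, 2)                       # max = 57344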

# Comparison table
print("FP8 Format Comparison")
print("=" * 70)
print(f"{'Property':<25} {'E4M3 (Inference)':>20} {'E5M2 (Training)':>20}")
print("-" * 70)
print(f"{'Exponent bits':<25} {E4M3.exp_bits:>20} {E5M2.exp_bits:>20}")
print(f"{'Mantissa bits':<25} {E4M3.mant_bits:>20} {E5M2.mant_bits:>20}")
print(f"{'Exponent bias':<25} {E4M3.bias:>20} {E5M2.bias:>20}")
print(f"{'Max value':<25} {E4M3.max_value:>20.1f} {E5M2.max_value:>20.1f}")
print(f"{'Min positive':<25} {E4M3.min_positive:>20.2e} {E5M2.min_positive:>20.2e}")
print(f"{'Dynamic range (decades)':<25} {np.log10(E4M3.max_value/E4M3.min_positive):>20.1f} {np.log10(E5M2.max_value/E5M2.min_positive):>20.1f}")
print(f"{'Best for':<25} {'weights, activations':>20} {'gradients':>20}")
print("=" * 70)

In [None]:
# Visualize FP8 precision at different magnitudes

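def simulate_fp8(value, fp8_format):
    """Simulate FP8 quantization of a value (round to the nearest FP8 grid point)."""
    # Clamp to range
    clamped = np.clip(value, -fp8_format.max_value, fp8_format.max_value)

    # Simulate precision loss: the grid spacing scales with the value's exponent
    # and is constant below the smallest normal exponent (subnormal range).
    # In real FP8, this is done by the hardware
    magnitude = np.maximum(np.abs(clamped), fp8_format.min_positive)
    exponent = np.maximum(np.floor(np.log2(magnitude)), 1 - fp8_format.bias)
    step = 2.0 ** (exponent - fp8_format.mant_bits)
    quantized = np.round(clamped / step) * step

    return quantized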


# Test precision across magnitude range
test_values = np.logspace(-3, 2, 100)  # 0.001 to 100

e4m3_errors = []
e5m2_errors = []

for v in test_values:
    e4m3_q = simulate_fp8(v, E4M3)
    e5m2_q = simulate_fp8(v, E5M2)
    
    e4m3_errors.append(abs(v - e4m3_q) / v * 100)  # Relative error %
    e5m2_errors.append(abs(v - e5m2_q) / v * 100)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
ax.loglog(test_values, e4m3_errors, 'b-', label='E4M3', linewidth=2)
ax.loglog(test_values, e5m2_errors, 'r-', label='E5M2', linewidth=2)
ax.set_xlabel('Value Magnitude')
ax.set_ylabel('Relative Error (%)')
ax.set_title('FP8 Precision by Magnitude')
ax.legend()
ax.grid(True, alpha=0.3)
ax.axvline(x=1.0, color='gray', linestyle='--', alpha=0.5)
ax.text(1.1, ax.get_ylim()[1]*0.5, 'Typical\nweight\nrange', fontsize=9)

ax = axes[1]
# Show representable values in [1, 2), where the spacing is 2^-mant_bits
e4m3_vals = [1 + i/8 for i in range(8)]  # 1.0 to 1.875 in steps of 0.125
e5m2_vals = [1 + i/4 for i in range(4)]  # 1.0 to 1.75 in steps of 0.25

ax.scatter(e4m3_vals, [1]*len(e4m3_vals), s=100, label='E4M3', marker='o')
ax.scatter(e5m2_vals, [0.5]*len(e5m2_vals), s=100, label='E5M2', marker='s')
ax.set_xlabel('Value')
ax.set_yticks([0.5, 1])
ax.set_yticklabels(['E5M2', 'E4M3'])
ax.set_title('Representable Values in [1, 2)')
ax.set_xlim(0.9, 2.0)
ax.legend()
ax.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.show()

print("Key insight: E4M3 has 8 values in [1,2), E5M2 has only 4.")
print("This is why E4M3 is preferred for inference (more precision).")

---

## Part 2: FP8 Training with Transformer Engine

NVIDIA's Transformer Engine makes FP8 training seamless. Let's see how!

In [None]:
# Simple model for demonstration
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLP(nn.Module):
    """Simple MLP for demonstrating FP8 training."""
    
    def __init__(self, input_dim=768, hidden_dim=3072, output_dim=768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.gelu = nn.GELU()
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        return x


# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_bf16 = SimpleMLP().to(device).to(torch.bfloat16)

param_count = sum(p.numel() for p in model_bf16.parameters())
print(f"Model parameters: {param_count / 1e6:.1f}M")
print(f"Model dtype: {next(model_bf16.parameters()).dtype}")

In [None]:
# If Transformer Engine is available, create an FP8 version

if HAS_TE:
    class SimpleMLP_FP8(nn.Module):
        """FP8-enabled MLP using Transformer Engine."""
        
        def __init__(self, input_dim=768, hidden_dim=3072, output_dim=768):
            super().__init__()
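            # Use TE's FP8-aware Linear layers; keep parameters in BF16 so they
            # match the BF16 inputs and the BF16 baseline model
            self.fc1 = te.Linear(input_dim, hidden_dim, params_dtype=torch.bfloat16)
            self.fc2 = te.Linear(hidden_dim, output_dim, params_dtype=torch.bfloat16)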
        
        def forward(self, x):
            x = self.fc1(x)
            x = F.gelu(x)
            x = self.fc2(x)
            return x
    
    # Create FP8 model
    model_fp8 = SimpleMLP_FP8().to(device)
    print(f"FP8 model created with Transformer Engine!")
    print(f"Parameters: {sum(p.numel() for p in model_fp8.parameters()) / 1e6:.1f}M")
else:
    print("Transformer Engine not available. Simulating FP8 behavior.")
    model_fp8 = None

In [None]:
# Training loop comparison: BF16 vs FP8

def train_step(model, optimizer, x, y, use_fp8=False):
    """Single training step."""
    optimizer.zero_grad()
    
    if use_fp8 and HAS_TE:
        # Use FP8 autocast context
        fp8_recipe = DelayedScaling(
            fp8_format=Format.HYBRID,  # E4M3 for forward, E5M2 for backward
            amax_history_len=16,
            amax_compute_algo="most_recent",
        )
        
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            output = model(x)
            loss = F.mse_loss(output, y)
    else:
        output = model(x)
        loss = F.mse_loss(output, y)
    
    loss.backward()
    optimizer.step()
    
    return loss.item()


def benchmark_training(model, num_steps=100, batch_size=32, use_fp8=False):
    """Benchmark training speed."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    # Generate dummy data
    x = torch.randn(batch_size, 768, device=device, dtype=torch.bfloat16)
    y = torch.randn(batch_size, 768, device=device, dtype=torch.bfloat16)
    
    # Warmup
    for _ in range(10):
        _ = train_step(model, optimizer, x, y, use_fp8=use_fp8)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    # Benchmark
    losses = []
    start = time.perf_counter()
    
    for i in range(num_steps):
        loss = train_step(model, optimizer, x, y, use_fp8=use_fp8)
        losses.append(loss)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    elapsed = time.perf_counter() - start
    
    return {
        'time_per_step_ms': elapsed / num_steps * 1000,
        'steps_per_second': num_steps / elapsed,
        'final_loss': losses[-1],
        'avg_loss': sum(losses) / len(losses),
    }


# Benchmark BF16
print("Benchmarking BF16 training...")
bf16_results = benchmark_training(model_bf16, num_steps=100, use_fp8=False)

print(f"\nBF16 Results:")
print(f"  Time per step: {bf16_results['time_per_step_ms']:.2f} ms")
print(f"  Steps/second: {bf16_results['steps_per_second']:.1f}")
print(f"  Final loss: {bf16_results['final_loss']:.4f}")

In [None]:
# Benchmark FP8 if available

if model_fp8 is not None:
    print("Benchmarking FP8 training...")
    fp8_results = benchmark_training(model_fp8, num_steps=100, use_fp8=True)
    
    print(f"\nFP8 Results:")
    print(f"  Time per step: {fp8_results['time_per_step_ms']:.2f} ms")
    print(f"  Steps/second: {fp8_results['steps_per_second']:.1f}")
    print(f"  Final loss: {fp8_results['final_loss']:.4f}")
    
    # Comparison
    speedup = bf16_results['time_per_step_ms'] / fp8_results['time_per_step_ms']
    print(f"\n{'='*50}")
    print(f"FP8 speedup vs BF16: {speedup:.2f}x (values below 1.0 mean FP8 was slower)")
    print(f"Loss difference: {abs(fp8_results['final_loss'] - bf16_results['final_loss']):.6f}")
    print(f"{'='*50}")
else:
    print("\nFP8 benchmark skipped (Transformer Engine not available).")
    print("\nExpected FP8 performance on Blackwell:")
    print("  - 1.5-2x speedup vs BF16")
    print("  - 2x reduced activation memory")
    print("  - <1% loss difference")

---

## Part 3: FP8 Inference Optimization

FP8 also shines for inference with E4M3 format.
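
As a back-of-the-envelope check on the memory side, the arithmetic below counts the parameters of the `SimpleMLP` used in this lab (768 → 3072 → 768, weights plus biases) and converts that to storage at 4, 2, and 1 bytes per parameter. It is only a rough sketch: it ignores per-tensor scale factors and any higher-precision copies a framework may keep alongside the FP8 weights.

```python
# Parameter count of SimpleMLP(768 -> 3072 -> 768): weights plus biases
input_dim, hidden_dim, output_dim = 768, 3072, 768
n_params = (input_dim * hidden_dim + hidden_dim) + (hidden_dim * output_dim + output_dim)

# Storage cost per parameter in each precision
bytes_per_param = {"FP32": 4, "BF16": 2, "FP8": 1}
weight_mib = {name: n_params * size / 2**20 for name, size in bytes_per_param.items()}

print(f"Parameters: {n_params:,}")
for name, mib in weight_mib.items():
    print(f"{name}: {mib:.1f} MiB")
```

For this model the weights come to roughly 18.0 MiB in FP32, 9.0 MiB in BF16, and 4.5 MiB in FP8.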

In [None]:
# Convert model weights to FP8 for inference

def simulate_fp8_inference(model, x, fp8_format=E4M3):
    """
    Simulate FP8 inference by quantizing weights.
    
    In practice, this is done by TensorRT or the hardware.
    """
    model.eval()
    
    # Store original weights
    original_weights = {}
    
    with torch.no_grad():
        # Quantize each weight tensor
        for name, param in model.named_parameters():
            original_weights[name] = param.data.clone()
            
            # Simulate FP8 quantization
            max_val = param.abs().max()
            scale = max_val / fp8_format.max_value
            scale = max(scale.item(), 1e-10)
            
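            # Quantize and dequantize on the FP8 grid: the spacing scales with each
            # value's exponent (constant in the subnormal range)
            scaled = (param / scale).float().clamp(-fp8_format.max_value, fp8_format.max_value)
            exponent = torch.floor(torch.log2(scaled.abs().clamp(min=fp8_format.min_positive)))
            step = torch.exp2(exponent.clamp(min=1 - fp8_format.bias) - fp8_format.mant_bits)
            quantized = torch.round(scaled / step) * step * scale
            param.data = quantized.to(param.dtype)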
        
        # Run inference
        output = model(x)
        
        # Restore original weights
        for name, param in model.named_parameters():
            param.data = original_weights[name]
    
    return output


# Test FP8 vs BF16 inference
test_input = torch.randn(32, 768, device=device, dtype=torch.bfloat16)

model_bf16.eval()
with torch.no_grad():
    bf16_output = model_bf16(test_input)
    fp8_output = simulate_fp8_inference(model_bf16, test_input)

# Compare outputs
diff = (bf16_output - fp8_output).abs()
print("FP8 Inference Quality (Simulated)")
print("=" * 50)
print(f"Mean absolute difference: {diff.mean():.6f}")
print(f"Max absolute difference: {diff.max():.6f}")
print(f"Relative error: {(diff.mean() / bf16_output.abs().mean() * 100):.4f}%")

In [None]:
# Benchmark FP8 inference performance

def benchmark_inference(model, num_runs=100, batch_size=32, input_dim=768):
    """Benchmark inference speed."""
    model.eval()
    x = torch.randn(batch_size, input_dim, device=device, dtype=torch.bfloat16)
    
    # Warmup
    with torch.no_grad():
        for _ in range(10):
            _ = model(x)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    # Benchmark
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(x)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    elapsed = time.perf_counter() - start
    
    return {
        'time_per_inference_ms': elapsed / num_runs * 1000,
        'inferences_per_second': num_runs / elapsed,
        'throughput_samples_per_second': batch_size * num_runs / elapsed,
    }


print("Benchmarking inference...")
inf_results = benchmark_inference(model_bf16, num_runs=100)

print(f"\nBF16 Inference Results:")
print(f"  Time per inference: {inf_results['time_per_inference_ms']:.3f} ms")
print(f"  Inferences/second: {inf_results['inferences_per_second']:.1f}")
print(f"  Throughput: {inf_results['throughput_samples_per_second']:.0f} samples/s")

print("\nRough FP8 expectation on Blackwell (not measured here):")
print(f"  Time per inference: as low as ~{inf_results['time_per_inference_ms'] / 2:.3f} ms (up to 2x faster if compute-bound)")
print("  Memory: about 50% reduction in activation memory")

---

## Part 4: Delayed Scaling for Numerical Stability

FP8's limited range requires dynamic scaling. Let's understand how it works.
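
The "delayed" part means the scale applied at a given step is derived from amax values recorded at earlier steps, so a sudden spike can saturate until the history catches up. The pure-Python sketch below illustrates that timing only. `DelayedScale` and its `step` method are hypothetical names, the scale follows this lab's convention (`amax / fp8_max`, applied by division), and it is not Transformer Engine's implementation.

```python
from collections import deque

E4M3_MAX = 448.0


class DelayedScale:
    """Scale for step t is derived from amax values seen at steps before t."""

    def __init__(self, history_len=4, fp8_max=E4M3_MAX):
        self.history = deque(maxlen=history_len)
        self.fp8_max = fp8_max

    def step(self, values):
        # Largest magnitude representable under the scale chosen from earlier steps
        # (before any history exists, fall back to a scale of 1.0)
        limit = max(self.history) if self.history else self.fp8_max
        scale = limit / self.fp8_max
        amax = max(abs(v) for v in values)
        saturated = amax > limit
        # Record the current amax only after the scale has been chosen
        self.history.append(amax)
        return scale, saturated


scaler = DelayedScale(history_len=4)
results = []
for amax in [1.0, 1.0, 1.0, 8.0, 8.0, 1.0]:  # a sudden 8x spike at step 3
    scale, saturated = scaler.step([amax, -amax / 2])
    results.append(saturated)
    print(f"amax={amax:4.1f}  scale={scale:.5f}  saturated={saturated}")
```

Only the first spike step saturates: by the next step the history contains 8.0 and the scale has adapted. A longer history keeps the larger scale around for more steps after the spike has passed, which costs precision on the smaller tensors that follow.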

In [None]:
# Demonstrate delayed scaling

class DelayedScalingSimulator:
    """
    Simulate the delayed scaling algorithm used in FP8 training.
    
    The idea:
    1. Track the maximum absolute value (amax) over recent steps
    2. Compute scale factor based on amax history
    3. This allows FP8 to adapt to changing tensor magnitudes
    """
    
    def __init__(self, history_len=16, fp8_format=E4M3):
        self.history_len = history_len
        self.fp8_format = fp8_format
        self.amax_history = []
    
    def update_and_get_scale(self, tensor):
        """Update history and compute scale factor."""
        # Get current amax
        current_amax = tensor.abs().max().item()
        
        # Update history
        self.amax_history.append(current_amax)
        if len(self.amax_history) > self.history_len:
            self.amax_history.pop(0)
        
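        # Compute scale from max of history.
        # Simplification: the history already includes the current tensor, so this
        # simulator never clips. Real delayed scaling uses only earlier steps.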
        amax = max(self.amax_history)
        scale = amax / self.fp8_format.max_value
        
        return max(scale, 1e-10)
    
    def quantize(self, tensor):
        """Quantize tensor using delayed scaling."""
        scale = self.update_and_get_scale(tensor)
        
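        # Quantize: round to the FP8 grid, whose spacing scales with the exponent
        scaled = tensor / scale
        clipped = torch.clamp(scaled, -self.fp8_format.max_value, self.fp8_format.max_value)
        exponent = torch.floor(torch.log2(clipped.abs().clamp(min=self.fp8_format.min_positive)))
        step = torch.exp2(exponent.clamp(min=1 - self.fp8_format.bias) - self.fp8_format.mant_bits)
        quantized = torch.round(clipped / step) * step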
        
        # Dequantize
        return quantized * scale, scale


# Simulate training with varying tensor magnitudes
scaler = DelayedScalingSimulator(history_len=8)

# Simulate gradients that change magnitude over time
print("Delayed Scaling Simulation")
print("=" * 60)
print(f"{'Step':<8} {'Tensor Max':>15} {'Scale':>15} {'Quantization Error':>20}")
print("-" * 60)

torch.manual_seed(42)
for step in range(16):
    # Simulate gradients with varying magnitude
    magnitude = 0.1 + step * 0.1  # Gradually increasing
    tensor = torch.randn(1000) * magnitude
    
    quantized, scale = scaler.quantize(tensor)
    error = (tensor - quantized).abs().mean().item()
    
    print(f"{step:<8} {tensor.abs().max().item():>15.4f} {scale:>15.6f} {error:>20.6f}")

print("\nNote: Scale adapts to track the increasing tensor magnitude!")

---

## ✋ Try It Yourself

### Exercise 1: Custom FP8 Recipe

Experiment with different delayed scaling parameters:
- `history_len`: How many steps to track (try 4, 8, 16, 32)
- What happens with rapidly changing tensor magnitudes?

### Exercise 2: Compare E4M3 vs E5M2

Quantize the same gradients with both formats. Which has lower error?

In [None]:
# Helper function for quantizing tensors with specific FP8 formats
# ================================================================
#
# For the exercises below, you'll need to quantize tensors with different
# FP8 formats (E4M3 vs E5M2) and compare the results.

def quantize_tensor_fp8(tensor, fp8_format):
    """
    Quantize a tensor using a specific FP8 format.
    
    Args:
        tensor: PyTorch tensor to quantize
        fp8_format: FP8Format object (E4M3 or E5M2)
    
    Returns:
        tuple: (dequantized_tensor, quantization_error)
    """
    # Compute scale to fit tensor in FP8 range
    max_val = tensor.abs().max()
    scale = max_val / fp8_format.max_value
    scale = max(scale.item(), 1e-10)
    
    # Scale tensor to FP8 range
    scaled = tensor / scale
    
    # Clip to FP8 representable range
    clipped = torch.clamp(scaled, -fp8_format.max_value, fp8_format.max_value)
    
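    # Simulate reduced mantissa precision: round to the FP8 grid, whose spacing
    # scales with each value's exponent (constant in the subnormal range)
    exponent = torch.floor(torch.log2(clipped.abs().clamp(min=fp8_format.min_positive)))
    step = torch.exp2(exponent.clamp(min=1 - fp8_format.bias) - fp8_format.mant_bits)
    quantized = torch.round(clipped / step) * step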
    
    # Scale back (dequantize)
    dequantized = quantized * scale
    
    # Calculate error
    error = (tensor - dequantized).abs()
    
    return dequantized, error


# Demonstrate with a sample tensor
torch.manual_seed(42)
sample_data = torch.randn(1000) * 5  # Some test data

# Quantize with both formats
deq_e4m3, err_e4m3 = quantize_tensor_fp8(sample_data, E4M3)
deq_e5m2, err_e5m2 = quantize_tensor_fp8(sample_data, E5M2)

print("quantize_tensor_fp8() Demo")
print("=" * 50)
print(f"Input tensor range: [{sample_data.min():.3f}, {sample_data.max():.3f}]")
print(f"\nE4M3 (3 mantissa bits - more precision):")
print(f"  Mean error: {err_e4m3.mean():.6f}")
print(f"  Max error:  {err_e4m3.max():.6f}")
print(f"\nE5M2 (2 mantissa bits - larger range):")
print(f"  Mean error: {err_e5m2.mean():.6f}")
print(f"  Max error:  {err_e5m2.max():.6f}")
print(f"\nE4M3 error is {err_e5m2.mean()/err_e4m3.mean():.1f}x lower (better for typical values)")

In [None]:
# Exercise 1: Experiment with history lengths

# TODO: Create scalers with different history lengths
# TODO: Simulate rapidly changing tensor magnitudes
# TODO: Compare quantization errors

# Your code here...

In [None]:
# Exercise 2: Compare E4M3 vs E5M2 for gradients

# Simulate gradients (can have large range)
torch.manual_seed(42)
gradients = torch.randn(10000) * 10  # Larger magnitude gradients

# TODO: Quantize with E4M3
# TODO: Quantize with E5M2
# TODO: Compare errors
# TODO: Which format is better for gradients? Why?

# Your code here...

---

## Common Mistakes

### Mistake 1: Using Wrong FP8 Format

```python
# Wrong: Using E5M2 for inference weights
weights = quantize_to_fp8(weights, format="E5M2")  # Less precision!

# Right: Use E4M3 for inference (more precision)
weights = quantize_to_fp8(weights, format="E4M3")
```

### Mistake 2: Forgetting to Scale Gradients

```python
# Wrong: Directly clipping gradients to the FP8 range (E5M2 max = 57344)
grad_fp8 = grad.clamp(-57344, 57344)  # Saturates large gradients; tiny ones still underflow!

# Right: Use delayed scaling so the tracked amax maps onto the FP8 max
scale = amax_history.max() / 57344
grad_fp8 = (grad / scale).clamp(-57344, 57344)
```
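
To see why the scale matters, the sketch below rounds a few small gradients to E5M2 with and without scaling. It is a pure-Python illustration under stated assumptions: `round_to_e5m2` is a hypothetical helper (round-to-nearest, saturating), E5M2 is assumed for gradients as described earlier in this lab, and the scale is taken from the current values rather than from an amax history to keep the example short.

```python
import math

E5M2_MAX = 57344.0
E5M2_MIN_NORMAL_EXP = -14  # 1 - bias, with bias = 15
E5M2_MANT_BITS = 2


def round_to_e5m2(x):
    """Round a Python float to the nearest E5M2 value, saturating at the max."""
    if x == 0.0:
        return 0.0
    x = max(-E5M2_MAX, min(E5M2_MAX, x))
    # frexp gives x = m * 2**e with 0.5 <= |m| < 1, so floor(log2|x|) = e - 1
    exponent = max(math.frexp(x)[1] - 1, E5M2_MIN_NORMAL_EXP)
    step = 2.0 ** (exponent - E5M2_MANT_BITS)
    return round(x / step) * step


grads = [3e-6, -1e-6, 2.5e-6]  # all below E5M2's smallest subnormal (2**-16)

# Without scaling, every gradient rounds to zero
unscaled = [round_to_e5m2(g) for g in grads]

# With scaling, the largest gradient maps onto the FP8 max before rounding
amax = max(abs(g) for g in grads)
scale = amax / E5M2_MAX
scaled = [round_to_e5m2(g / scale) * scale for g in grads]

print("unscaled:", unscaled)
print("scaled:  ", scaled)
```

Without the scale, the update is lost entirely. With it, each gradient survives with a relative error bounded by E5M2's two mantissa bits.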

### Mistake 3: Not Using FP8 Autocast

```python
# Wrong: Manual FP8 conversion everywhere
x_fp8 = x.to(torch.float8_e4m3fn)
y = model(x_fp8).to(torch.bfloat16)

# Right: Use Transformer Engine's autocast
with te.fp8_autocast(enabled=True):
    y = model(x)  # Automatic conversion!
```

---

## Checkpoint

You've learned:

- **E4M3 vs E5M2**: Precision vs range trade-off
- **FP8 training**: Up to 2× speedup with Transformer Engine
- **Delayed scaling**: How FP8 adapts to changing magnitudes
- **Blackwell advantage**: Native FP8 tensor cores!

---

## Further Reading

- [NVIDIA Transformer Engine](https://github.com/NVIDIA/TransformerEngine)
- [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433)
- [Mixed Precision Training](https://arxiv.org/abs/1710.03740)

---

## Cleanup

In [None]:
# Clean up
del model_bf16
if model_fp8 is not None:
    del model_fp8

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Notebook complete! Ready for Lab 3.2.4: GPTQ Quantization")

---

## Next Steps

In the next notebook, we'll explore **GPTQ Quantization** - one of the most widely used 4-bit quantization methods for GPU inference!

➡️ Continue to: [Lab 3.2.4: GPTQ Quantization](lab-3.2.4-gptq-quantization.ipynb)