# Lab 3.2.1: Data Type Exploration

**Module:** 3.2 - Model Quantization & Optimization  
**Time:** 1.5 hours  
**Difficulty:** ⭐⭐☆☆☆

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand different numerical data types (FP32, FP16, BF16, INT8, INT4, FP8, FP4)
- [ ] Visualize precision loss at each quantization level
- [ ] Calculate memory savings for various model sizes
- [ ] Appreciate why DGX Spark's Blackwell architecture is special for quantization

---

## 📚 Prerequisites

- Completed: Module 3.1 (LLM Fine-tuning)
- Knowledge of: PyTorch basics, neural network fundamentals
- Hardware: DGX Spark with 128GB unified memory (ideal) or any CUDA GPU

---

## 🌍 Real-World Context

**The Problem:** You've built an amazing 70B parameter model, but deploying it is incredibly expensive!

Consider the math:
- **70B parameters × 2 bytes (FP16) = 140GB** just for weights
- Plus activations, KV cache, framework overhead...
- That's $25,000+/month for a cloud GPU instance!

**The Solution:** Quantization reduces memory by 2-4× with minimal quality loss.

| Company | Use Case | Quantization Win |
|---------|----------|------------------|
| Google | On-device Gemini Nano | 4-bit enables running on phones |
| Meta | Llama deployment | INT8 halves serving costs |
| Apple | CoreML models | 16→4 bit for Neural Engine |
| **You** | DGX Spark | NVFP4 gives 3.5× compression! |

---

## 🧒 ELI5: What is Quantization?

> **Imagine you're taking notes in class...**
>
> You could write down every single word the teacher says (FP32 - full precision).  
> That's accurate, but your notebook fills up fast and your hand gets tired!
>
> Instead, you could:
> - Write only the key points (FP16 - half precision)
> - Use abbreviations like "b/c" for "because" (INT8 - 8-bit integers)
> - Just draw simple diagrams (INT4 - 4-bit integers)
>
> Each shorthand:
> - Uses less notebook space (memory)
> - Lets you write faster (compute)
> - Might lose some details (precision)
>
> **In AI terms:** Quantization is using fewer bits to store each number in a neural network.  
> Just like your notes, fewer bits = smaller models = faster inference, but potentially less accurate.

---

## 🧒 ELI5: What's Special About FP4?

> **Think of it like JPEG compression for photos...**
>
> Regular integer quantization (INT4) is like converting a color photo to just 16 shades of gray.  
> It works, but important details can get lost.
>
> FP4 (floating-point 4-bit) is smarter! It's like JPEG - it keeps more detail where it matters  
> and simplifies where you won't notice. The result looks almost like the original.
>
> **The catch?** FP4 needs special hardware to work fast.  
> **The good news?** DGX Spark's Blackwell GPU has native FP4 support - you have that hardware!

---

## Part 1: Understanding Data Types

Neural networks are just massive collections of numbers (weights). How we store those numbers matters!

### The Number Line Analogy

Imagine each format as a ruler: the bit count fixes how many tick marks (distinct bit patterns) it has, and the format decides where those ticks go.

| Type | Bits | Tick Marks (bit patterns) | Precision | Use Case |
|------|------|---------------------------|-----------|----------|
| FP32 | 32 | ~4.3 billion | Highest | Training (legacy) |
| FP16 | 16 | 65,536 | High | Training/Inference |
| BF16 | 16 | 65,536 (wider range, coarser spacing) | Medium | Training (preferred) |
| FP8 | 8 | 256 (non-uniform) | Medium | Inference (Hopper/Blackwell) |
| INT8 | 8 | 256 (uniform) | Medium | Inference |
| INT4 | 4 | 16 (uniform) | Low | Inference (quantized) |
| FP4 | 4 | 16 (non-uniform) | Medium-Low | Inference (Blackwell!) |

Let's explore each one!

In [None]:
# First, let's check our environment
import torch
import numpy as np
import matplotlib.pyplot as plt

print("=" * 60)
print("DGX Spark Environment Check")
print("=" * 60)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    print(f"Total Memory: {props.total_memory / 1e9:.1f} GB")
    cc = torch.cuda.get_device_capability()
    print(f"Compute Capability: {cc[0]}.{cc[1]}")
    
    # Check for special features
    print(f"\nBF16 supported: {torch.cuda.is_bf16_supported()}")
    
    if cc[0] >= 10:
        print("FP8 native: Yes (Blackwell!)")
        print("FP4 native: Yes (Blackwell exclusive!)")
    elif cc >= (8, 9):
        # Hopper (9.x) and Ada Lovelace (8.9) both have FP8 Tensor Cores
        print("FP8 native: Yes (Hopper / Ada Lovelace)")
        print("FP4 native: No (emulation only)")
    else:
        print("FP8 native: No (emulation only)")
        print("FP4 native: No (emulation only)")

print("=" * 60)

### Understanding Floating-Point Representation

A normal floating-point number is stored as: **(-1)^sign × (1 + mantissa) × 2^(exponent - bias)**, where the mantissa bits are read as a binary fraction and bias = 2^(exponent bits - 1) - 1.

```
FP32: [1 sign][8 exponent][23 mantissa] = 32 bits
FP16: [1 sign][5 exponent][10 mantissa] = 16 bits
BF16: [1 sign][8 exponent][ 7 mantissa] = 16 bits (Google's format!)
FP8:  [1 sign][4 exponent][ 3 mantissa] = 8 bits (E4M3) or
      [1 sign][5 exponent][ 2 mantissa] = 8 bits (E5M2)
FP4:  [1 sign][2 exponent][ 1 mantissa] = 4 bits (E2M1, with block scaling)
```
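To connect these layouts to actual numbers, here is a small sketch that pulls the sign, exponent, and mantissa fields out of an FP32 value with Python's `struct` module, plus a generic decoder for IEEE-style ExMy bit patterns. The helper names `fp32_fields` and `decode_minifloat` are illustrative, and the decoder deliberately treats every exponent code as finite (no infinity/NaN handling). Enumerating every E2M1 pattern recovers exactly the FP4 magnitudes used later in this notebook.

```python
import struct

def fp32_fields(x: float):
    """Split an FP32 value into its sign, biased exponent, and mantissa bit fields."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF      # 8 exponent bits, bias 127
    mantissa = bits & 0x7FFFFF          # 23 explicit mantissa bits
    return sign, exponent, mantissa

def decode_minifloat(sign, exponent, mantissa, exp_bits, mant_bits):
    """Decode an IEEE-style ExMy bit pattern (every exponent code treated as finite)."""
    bias = 2 ** (exp_bits - 1) - 1
    frac = mantissa / 2 ** mant_bits
    if exponent == 0:                   # subnormal: no implicit leading 1
        value = frac * 2.0 ** (1 - bias)
    else:                               # normal: implicit leading 1
        value = (1 + frac) * 2.0 ** (exponent - bias)
    return -value if sign else value

# -6.0 = (-1)^1 * 1.5 * 2^2
s, e, m = fp32_fields(-6.0)
print(s, e - 127, 1 + m / 2**23)

# Enumerate every non-negative E2M1 (FP4) value: 2 exponent bits x 1 mantissa bit
e2m1 = sorted({decode_minifloat(0, e, m, 2, 1) for e in range(4) for m in range(2)})
print(e2m1)
```

The same decoder also reproduces the FP8 maxima from Part 3: E4M3 with exponent code 15 and mantissa 110 gives 448, and E5M2 with exponent code 30 and mantissa 11 gives 57344.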

Let's visualize what each format can represent!

In [None]:
# Visualize the precision of different data types

def analyze_dtype_precision(dtype_name, dtype):
    """Analyze precision of a PyTorch dtype."""
    # Create test value
    pi = 3.141592653589793
    tensor = torch.tensor([pi], dtype=dtype)
    stored = tensor.item()
    
    # Calculate error
    error = abs(pi - stored)
    relative_error = error / pi * 100
    
    # Get info
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    
    return {
        'name': dtype_name,
        'bits': tensor.element_size() * 8,
        'stored_pi': stored,
        'error': error,
        'relative_error': relative_error,
        'min': info.min,
        'max': info.max,
    }

# Test common dtypes
dtypes = [
    ('FP64 (double)', torch.float64),
    ('FP32 (float)', torch.float32),
    ('FP16 (half)', torch.float16),
    ('BF16 (bfloat)', torch.bfloat16),
]

print("How different precisions store π = 3.141592653589793")
print("=" * 70)
print(f"{'Type':<15} {'Bits':>6} {'Stored Value':>20} {'Error':>15}")
print("-" * 70)

for name, dtype in dtypes:
    result = analyze_dtype_precision(name, dtype)
    print(f"{result['name']:<15} {result['bits']:>6} {result['stored_pi']:>20.15f} {result['error']:>15.2e}")

print("=" * 70)

In [None]:
# Let's visualize precision loss with a distribution of random values

def compare_precision(original, dtype):
    """Compare original FP32 tensor with converted version."""
    converted = original.to(dtype).to(torch.float32)  # Round trip
    error = (original - converted).abs()
    return error

# Create random weights (like from a neural network)
torch.manual_seed(42)
original_weights = torch.randn(10000)  # 10k weights

# Compare precision for each dtype
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Precision Loss by Data Type (Error Distribution)', fontsize=14)

dtypes_to_compare = [
    ('FP16', torch.float16),
    ('BF16', torch.bfloat16),
]

for idx, (name, dtype) in enumerate(dtypes_to_compare):
    ax = axes[idx // 2, idx % 2]
    errors = compare_precision(original_weights, dtype)
    
    ax.hist(errors.numpy(), bins=50, alpha=0.7, color='steelblue', edgecolor='black')
    ax.axvline(errors.mean(), color='red', linestyle='--', label=f'Mean: {errors.mean():.2e}')
    ax.set_xlabel('Absolute Error')
    ax.set_ylabel('Count')
    ax.set_title(f'{name} Precision Loss')
    ax.legend()
    ax.set_yscale('log')

# Add summary statistics
ax = axes[1, 0]
ax.clear()
ax.axis('off')

summary_text = "Summary Statistics\n" + "="*30 + "\n\n"
for name, dtype in dtypes_to_compare:
    errors = compare_precision(original_weights, dtype)
    summary_text += f"{name}:\n"
    summary_text += f"  Mean Error: {errors.mean():.2e}\n"
    summary_text += f"  Max Error:  {errors.max():.2e}\n"
    summary_text += f"  Std Dev:    {errors.std():.2e}\n\n"

ax.text(0.1, 0.9, summary_text, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

# Memory comparison
ax = axes[1, 1]
ax.clear()
sizes = ['FP32', 'FP16/BF16', 'INT8/FP8', 'INT4/FP4']
memory = [4, 2, 1, 0.5]
colors = ['#ff6b6b', '#feca57', '#48dbfb', '#1dd1a1']

bars = ax.bar(sizes, memory, color=colors, edgecolor='black')
ax.set_ylabel('Bytes per Parameter')
ax.set_title('Memory per Parameter')

for bar, mem in zip(bars, memory):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
            f'{mem}B', ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

### üîç What Just Happened?

We compared how accurately different data types can represent the same values:

1. **FP16** has good precision but limited range (values above 65,504 overflow to infinity)
2. **BF16** has the same range as FP32 but less precision (great for training!)
3. Both are **2× smaller** than FP32, so the same memory holds a model with **2× the parameters**!
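The range-versus-precision trade-off between FP16 and BF16 is easy to check with `torch.finfo`. The sketch below prints each format's largest value, smallest normal value, and machine epsilon, then pushes two probe values through each format: one larger than FP16's maximum and one that needs fine spacing near 1.0. The probe values are arbitrary choices for illustration.

```python
import torch

# Dynamic range (max, smallest normal) vs. precision (eps) for FP32 and the 16-bit formats
for name, dtype in [("FP32", torch.float32), ("FP16", torch.float16), ("BF16", torch.bfloat16)]:
    info = torch.finfo(dtype)
    print(f"{name}: max={info.max:.3e}  smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")

# Range probe: 70,000 exceeds FP16's maximum (65,504) but is well within BF16's range
big = torch.tensor([70000.0])
fp16_big = big.to(torch.float16)
bf16_big = big.to(torch.bfloat16)
print(f"70000 -> FP16: {fp16_big.item()}, BF16: {bf16_big.item()}")

# Precision probe: FP16 (eps = 2^-10) can tell 1.001 from 1.0; BF16 (eps = 2^-7) cannot
near_one = torch.tensor([1.001])
fp16_step = near_one.to(torch.float16).item()
bf16_step = near_one.to(torch.bfloat16).item()
print(f"1.001 -> FP16: {fp16_step}, BF16: {bf16_step}")
```

FP16 overflows to `inf` while BF16 rounds 70,000 to 70,144; conversely, FP16 keeps 1.001 as 1.0009765625 while BF16 rounds it to exactly 1.0.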

---

## Part 2: Integer Quantization (INT8 and INT4)

Now let's explore integer quantization - the foundation of most model compression.

### The Quantization Formula

**Symmetric Quantization:**
```
scale = max(|weights|) / 127                                    (for INT8)
quantized = clamp(round(float_value / scale), -128, 127)
dequantized = quantized × scale
```

**Asymmetric Quantization:**
```
scale = (max - min) / 255                                       (for UINT8)
zero_point = round(-min / scale)
quantized = clamp(round(float_value / scale) + zero_point, 0, 255)
dequantized = (quantized - zero_point) × scale
```
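The notebook implements only the symmetric variant in the next cell. As a minimal sketch of the asymmetric formulas, assuming unsigned 8-bit storage (0..255) plus a tweak common in practice (widening the range to include 0 so that 0.0 maps exactly to an integer), consider the following. `asymmetric_quantize` is a hypothetical helper, not a library API.

```python
import torch

def asymmetric_quantize(tensor, bits=8):
    """Asymmetric (affine) quantization to unsigned `bits`-bit integers."""
    qmin, qmax = 0, 2 ** bits - 1                      # 0..255 for 8 bits
    t_min, t_max = tensor.min().item(), tensor.max().item()
    # Widen the range to include 0 so that 0.0 is exactly representable
    t_min, t_max = min(t_min, 0.0), max(t_max, 0.0)
    scale = max((t_max - t_min) / (qmax - qmin), 1e-10)
    zero_point = int(round(-t_min / scale))            # integer code that represents 0.0
    quantized = torch.clamp(torch.round(tensor / scale) + zero_point, qmin, qmax)
    dequantized = (quantized - zero_point) * scale
    return quantized, scale, zero_point, dequantized

# ReLU-style activations are all non-negative
torch.manual_seed(0)
acts = torch.relu(torch.randn(1000))
q, scale, zp, deq = asymmetric_quantize(acts, bits=8)
print(f"scale={scale:.5f}, zero_point={zp}, max error={(acts - deq).abs().max():.5f}")
```

For non-negative data like this, the zero point lands at 0 and all 256 codes cover the values that actually occur, whereas a symmetric scheme would spend half its codes on negative values that never appear.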

In [None]:
# Implement and visualize INT8 and INT4 quantization

def symmetric_quantize(tensor, bits):
    """Symmetric quantization to specified bit-width."""
    qmax = 2 ** (bits - 1) - 1
    qmin = -2 ** (bits - 1)
    
    scale = tensor.abs().max() / qmax
    scale = max(scale, 1e-10)  # Avoid division by zero
    
    quantized = torch.round(tensor / scale).clamp(qmin, qmax)
    dequantized = quantized * scale
    
    return quantized, scale, dequantized


# Test with sample weights
torch.manual_seed(42)
weights = torch.randn(16)  # 16 sample weights for visibility

print("Original Weights (FP32):")
print(weights.numpy().round(4))
print()

# Quantize to different bit-widths
for bits in [8, 4]:
    q, scale, deq = symmetric_quantize(weights, bits)
    error = (weights - deq).abs()
    
    print(f"INT{bits} Quantization:")
    print(f"  Scale: {scale:.6f}")
    print(f"  Quantized values: {q.int().numpy()}")
    print(f"  Dequantized: {deq.numpy().round(4)}")
    print(f"  Mean Error: {error.mean():.6f}")
    print(f"  Max Error: {error.max():.6f}")
    print()

In [None]:
# Visualize the quantization effect

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original weights
ax = axes[0]
x = np.arange(len(weights))
ax.bar(x, weights.numpy(), alpha=0.7, label='Original', color='steelblue')
ax.set_xlabel('Weight Index')
ax.set_ylabel('Value')
ax.set_title('Original FP32 Weights')
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

# INT8
ax = axes[1]
q8, s8, deq8 = symmetric_quantize(weights, 8)
ax.bar(x - 0.2, weights.numpy(), width=0.4, alpha=0.7, label='Original', color='steelblue')
ax.bar(x + 0.2, deq8.numpy(), width=0.4, alpha=0.7, label='INT8', color='coral')
ax.set_xlabel('Weight Index')
ax.set_ylabel('Value')
ax.set_title('INT8 Quantization (256 levels)')
ax.legend()

# INT4
ax = axes[2]
q4, s4, deq4 = symmetric_quantize(weights, 4)
ax.bar(x - 0.2, weights.numpy(), width=0.4, alpha=0.7, label='Original', color='steelblue')
ax.bar(x + 0.2, deq4.numpy(), width=0.4, alpha=0.7, label='INT4', color='gold')
ax.set_xlabel('Weight Index')
ax.set_ylabel('Value')
ax.set_title('INT4 Quantization (16 levels)')
ax.legend()

plt.tight_layout()
plt.show()

print(f"\nQuantization Summary:")
print(f"  INT8: 256 possible values, error: {(weights - deq8).abs().mean():.6f}")
print(f"  INT4: 16 possible values, error:  {(weights - deq4).abs().mean():.6f}")
print(f"  INT4 has {(weights - deq4).abs().mean() / (weights - deq8).abs().mean():.1f}x more error than INT8")

---

## Part 3: FP8 - The Blackwell/Hopper Sweet Spot

FP8 is a floating-point format with 8 bits. There are two variants:

| Format | Exponent | Mantissa | Best For | Range |
|--------|----------|----------|----------|-------|
| E4M3 | 4 bits | 3 bits | **Weights & activations** (inference, forward pass) | ±448 |
| E5M2 | 5 bits | 2 bits | **Gradients** (training, backward pass) | ±57344 |

E4M3 has more precision (3 mantissa bits), while E5M2 has a larger range (5 exponent bits).

In [None]:
# Simulate FP8 quantization

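class FP8Format:
    """Configuration for an FP8 format."""
    def __init__(self, name, exp_bits, mant_bits, has_infinity):
        self.name = name
        self.exp_bits = exp_bits
        self.mant_bits = mant_bits
        self.bias = 2 ** (exp_bits - 1) - 1
        
        # Calculate max representable value
        if has_infinity:
            # E5M2 (IEEE-style): the all-ones exponent is reserved for inf/NaN
            max_exp = 2 ** exp_bits - 2
            max_mant = (2 ** mant_bits - 1) / 2 ** mant_bits
        else:
            # E4M3 ("fn" variant): no infinities; only exponent=1111, mantissa=111 is NaN
            max_exp = 2 ** exp_bits - 1
            max_mant = (2 ** mant_bits - 2) / 2 ** mant_bits
        self.max_value = (1 + max_mant) * 2 ** (max_exp - self.bias)
        self.min_normal = 2.0 ** (1 - self.bias)  # below this, values are subnormal

E4M3 = FP8Format("E4M3", 4, 3, has_infinity=False)  # Inference, max 448
E5M2 = FP8Format("E5M2", 5, 2, has_infinity=True)   # Training, max 57344

def quantize_to_fp8(tensor, fp8_format):
    """Simulate FP8 quantization: per-tensor scale, then round to the FP8 grid."""
    # Compute scaling factor so the largest magnitude maps to the FP8 max
    max_val = tensor.abs().max()
    scale = max(max_val.item() / fp8_format.max_value, 1e-10)
    
    # Scale to FP8 range and clip
    scaled = torch.clamp(tensor / scale, -fp8_format.max_value, fp8_format.max_value)
    
    # Simulate reduced precision: the spacing between FP8 values depends on each
    # value's exponent (and is constant below min_normal, in the subnormal range)
    exponent = torch.floor(torch.log2(scaled.abs().clamp(min=fp8_format.min_normal)))
    step = 2.0 ** (exponent - fp8_format.mant_bits)
    quantized = torch.round(scaled / step) * step
    quantized = torch.clamp(quantized, -fp8_format.max_value, fp8_format.max_value)
    
    # Scale back
    dequantized = quantized * scale
    
    return quantized, scale, dequantized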


# Compare FP8 formats
print("FP8 Format Comparison")
print("=" * 50)
print(f"{'Format':<10} {'Exp Bits':>10} {'Mant Bits':>10} {'Max Value':>12}")
print("-" * 50)
print(f"{'E4M3':<10} {4:>10} {3:>10} {E4M3.max_value:>12.1f}")
print(f"{'E5M2':<10} {5:>10} {2:>10} {E5M2.max_value:>12.1f}")
print("=" * 50)

# Test with weights
print("\nQuantization Comparison:")
torch.manual_seed(42)
test_weights = torch.randn(10000)

for fp8_format in [E4M3, E5M2]:
    _, scale, deq = quantize_to_fp8(test_weights, fp8_format)
    error = (test_weights - deq).abs()
    print(f"  {fp8_format.name}: Mean Error = {error.mean():.6f}, Max Error = {error.max():.6f}")

---

## Part 4: FP4 - The Blackwell Exclusive

FP4 (NVFP4) is NVIDIA's 4-bit floating-point format. Among NVIDIA GPUs, native FP4 Tensor Core support is **new with Blackwell**.

### What Makes FP4 Special?

Unlike INT4, which spaces its 16 levels uniformly, NVFP4 uses:
1. **Micro-block scaling**: Each block of 16 weights has its own scale (stored as FP8 E4M3)
2. **Dual-level scaling**: A per-tensor FP32 scale on top of the per-block scales
3. **Non-linear values**: 8 non-negative magnitudes (0, 0.5, 1, 1.5, 2, 3, 4, 6), i.e. the E2M1 grid

This results in roughly **3.5× memory reduction versus FP16**, with NVIDIA reporting only small accuracy loss on benchmarks like MMLU.
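The ~3.5× figure makes more sense once the scale overhead is counted. A back-of-the-envelope sketch, assuming NVFP4 stores 4-bit E2M1 values with one 8-bit (E4M3) scale per 16 values, and treating the single per-tensor FP32 scale as negligible for large tensors; `effective_bits` is a hypothetical helper:

```python
def effective_bits(element_bits, scale_bits, block_size):
    """Average storage bits per parameter when every `block_size` elements share one scale."""
    return element_bits + scale_bits / block_size

# Assumed NVFP4 layout: 4-bit elements + one 8-bit scale per 16 elements
nvfp4_bits = effective_bits(element_bits=4, scale_bits=8, block_size=16)

print(f"NVFP4: {nvfp4_bits} bits/param")
print(f"vs FP16: {16 / nvfp4_bits:.2f}x smaller")
print(f"vs FP8:  {8 / nvfp4_bits:.2f}x smaller")
```

At 4.5 bits per parameter, the compression is about 3.56× relative to FP16 rather than the 4× that the raw 4-bit element size suggests.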

In [None]:
# Simulate NVFP4 quantization

# FP4 representable values (normalized)
FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_to_fp4(tensor, block_size=16):
    """
    Simulate NVFP4 quantization with micro-block scaling.
    
    Args:
        tensor: Input tensor
        block_size: Size of micro-blocks for fine-grained scaling
    """
    original_shape = tensor.shape
    flat = tensor.flatten()
    
    # Pad to multiple of block_size
    n = flat.numel()
    pad = (block_size - n % block_size) % block_size
    if pad > 0:
        flat = torch.nn.functional.pad(flat, (0, pad))
    
    # Reshape into blocks
    n_blocks = flat.numel() // block_size
    blocks = flat.view(n_blocks, block_size)
    
    # Tensor-level scale (coarse)
    tensor_scale = blocks.abs().max()
    blocks_normalized = blocks / max(tensor_scale.item(), 1e-10)
    
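    # Block-level scales (fine). Real NVFP4 stores these as FP8 (E4M3) values;
    # this simulation keeps them in full precision for simplicity.
    block_max = blocks_normalized.abs().amax(dim=1, keepdim=True)
    block_scales = block_max / 6.0  # 6.0 is max FP4 value
    block_scales = torch.clamp(block_scales, min=1e-10)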
    
    # Normalize by block scale
    normalized = blocks_normalized / block_scales
    
    # Quantize to nearest FP4 value
    signs = torch.sign(normalized)
    abs_vals = normalized.abs()
    
    # Find nearest FP4 value
    distances = (abs_vals.unsqueeze(-1) - FP4_VALUES.unsqueeze(0).unsqueeze(0)).abs()
    indices = distances.argmin(dim=-1)
    quantized = signs * FP4_VALUES[indices]
    
    # Dequantize
    dequantized = quantized * block_scales * tensor_scale
    dequantized = dequantized.flatten()[:n].view(original_shape)
    
    return quantized, block_scales, tensor_scale, dequantized


# Compare FP4 with INT4
print("FP4 vs INT4 Comparison")
print("=" * 60)

torch.manual_seed(42)
weights = torch.randn(1024)  # 1024 weights

# INT4
_, _, deq_int4 = symmetric_quantize(weights, 4)
error_int4 = (weights - deq_int4).abs()

# FP4
_, _, _, deq_fp4 = quantize_to_fp4(weights)
error_fp4 = (weights - deq_fp4).abs()

print(f"{'Method':<15} {'Mean Error':>15} {'Max Error':>15} {'RMSE':>15}")
print("-" * 60)
print(f"{'INT4':<15} {error_int4.mean():>15.6f} {error_int4.max():>15.6f} {error_int4.pow(2).mean().sqrt():>15.6f}")
print(f"{'FP4 (NVFP4)':<15} {error_fp4.mean():>15.6f} {error_fp4.max():>15.6f} {error_fp4.pow(2).mean().sqrt():>15.6f}")
print("=" * 60)

improvement = (error_int4.mean() - error_fp4.mean()) / error_int4.mean() * 100
print(f"\nFP4 reduces mean error by {improvement:.1f}% compared to INT4!")
print("(Note: this INT4 baseline uses one per-tensor scale; FP4 uses 16-element block scales.)")

In [None]:
# Visualize the FP4 vs INT4 difference

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Error distributions
ax = axes[0]
ax.hist(error_int4.numpy(), bins=50, alpha=0.7, label='INT4', color='coral')
ax.hist(error_fp4.numpy(), bins=50, alpha=0.7, label='FP4', color='steelblue')
ax.set_xlabel('Absolute Error')
ax.set_ylabel('Count')
ax.set_title('Error Distribution: INT4 vs FP4')
ax.legend()

# FP4 representable values
ax = axes[1]
all_fp4 = torch.cat([-FP4_VALUES.flip(0)[:-1], FP4_VALUES])
ax.stem(all_fp4.numpy(), [1]*len(all_fp4), linefmt='steelblue', markerfmt='o')
ax.set_xlabel('Normalized Value')
ax.set_ylabel('Representable')
ax.set_title('FP4 Representable Values (16 codes, 15 distinct)')
ax.set_ylim(0, 1.5)

# INT4 vs FP4 coverage
ax = axes[2]
int4_vals = torch.linspace(-7, 7, 15)  # -7 to 7 for INT4
int4_normalized = int4_vals / 7  # Normalized to [-1, 1]

ax.stem(int4_normalized.numpy() * 6, [0.8]*len(int4_vals), linefmt='coral', 
        markerfmt='s', label='INT4 (linear)')
ax.stem(all_fp4.numpy(), [1.2]*len(all_fp4), linefmt='steelblue', 
        markerfmt='o', label='FP4 (non-linear)')
ax.set_xlabel('Value (scaled)')
ax.set_title('INT4 vs FP4 Value Distribution')
ax.legend()
ax.set_ylim(0, 1.5)

plt.tight_layout()
plt.show()

print("\nKey Insight: FP4's non-linear values cluster near zero where most weights are!")
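One caveat on the INT4 vs FP4 comparison: the `symmetric_quantize` INT4 baseline uses a single per-tensor scale, while `quantize_to_fp4` scales every 16-element block, so part of FP4's advantage comes from block scaling rather than from the non-uniform value grid. A fairer check gives both formats the same 16-element block scaling and changes only the grid. The `blockwise_quantize` helper below is a hypothetical, simplified sketch (one full-precision scale per block, no FP8 scale quantization); run it and compare the two mean errors yourself rather than assuming which one wins.

```python
import torch

def blockwise_quantize(x, grid, block_size=16):
    """Hypothetical helper: scale each block so its max magnitude maps to grid.max(),
    then round every element to the nearest signed grid value and scale back."""
    blocks = x.reshape(-1, block_size)          # assumes x.numel() is a multiple of block_size
    scale = (blocks.abs().amax(dim=1, keepdim=True) / grid.max()).clamp(min=1e-10)
    normalized = blocks / scale
    # Distance from each |value| to each grid magnitude -> index of the nearest one
    idx = (normalized.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return (torch.sign(normalized) * grid[idx] * scale).reshape(x.shape)

int4_grid = torch.arange(0, 8, dtype=torch.float32)                 # symmetric INT4: 0..7
fp4_grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # E2M1 magnitudes

torch.manual_seed(42)
w = torch.randn(1024)
for name, grid in [("INT4, 16-element blocks", int4_grid), ("FP4, 16-element blocks", fp4_grid)]:
    err = (w - blockwise_quantize(w, grid)).abs()
    print(f"{name:<24} mean error = {err.mean():.6f}")
```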

---

## Part 5: Memory Savings Calculator

Let's calculate how much weight memory you save with each quantization method.

In [None]:
# Memory calculator for model sizes

def calculate_model_memory(params_billions, precision):
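    """
    Calculate weight memory for a model in GB (1 GB = 1e9 bytes).
    
    Args:
        params_billions: Number of parameters in billions
        precision: 'fp32', 'fp16', 'bf16', 'int8', 'fp8', 'int4', 'fp4', 'nvfp4'
    """
    bytes_per_param = {
        'fp32': 4,
        'fp16': 2,
        'bf16': 2,
        'int8': 1,
        'fp8': 1,
        'int4': 0.5,
        'fp4': 0.5,
        'nvfp4': 0.5625,  # 4-bit values + one 8-bit scale per 16 values = 4.5 bits
    }
    
    # Raises KeyError on an unknown precision instead of silently assuming 2 bytes
    bpp = bytes_per_param[precision.lower()]
    return params_billions * bpp  # (params_billions * 1e9 * bpp) bytes / 1e9 = GB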


# Calculate for common model sizes
model_sizes = [1, 3, 7, 13, 34, 70, 100, 200]
precisions = ['FP32', 'FP16', 'INT8', 'INT4']

print("Model Memory Requirements (GB)")
print("=" * 70)
print(f"{'Model':<10}", end="")
for p in precisions:
    print(f"{p:>12}", end="")
print(f"{'Fits Spark?':>14}")
print("-" * 70)

for size in model_sizes:
    print(f"{str(size) + 'B':<10}", end="")
    for p in precisions:
        mem = calculate_model_memory(size, p)
        print(f"{mem:>12.1f}", end="")
    
    # Check if FP4 weights fit in DGX Spark (128GB), keeping ~8GB of headroom
    fp4_mem = calculate_model_memory(size, 'fp4')
    fits = " Yes (FP4)" if fp4_mem < 120 else " No"
    print(f"{fits:>14}")

print("=" * 70)
print("\nDGX Spark has 128GB unified memory.")
print("With FP4/INT4, the weights of models up to ~200B parameters fit (~100GB), leaving room for KV cache!")
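The "Fits Spark?" column counts weights only. At inference time the KV cache also grows with context length: every token stores one key and one value vector per layer and per KV head. A rough estimate, assuming a Llama-2-70B-style configuration (80 layers, 8 KV heads with grouped-query attention, head dimension 128) and an FP16 cache; `kv_cache_gb` is a hypothetical helper:

```python
def kv_cache_gb(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """Approximate KV-cache size in GB (1e9 bytes): keys + values for every layer and token."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_value   # 2 = K and V
    return per_token * seq_len * batch_size / 1e9

# Assumed Llama-2-70B-style config, 4-bit weights (scale overhead ignored)
weights_gb = 70 * 0.5
cache_gb = kv_cache_gb(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=32_768)
print(f"Weights: {weights_gb:.1f} GB, KV cache (32K tokens, FP16): {cache_gb:.1f} GB")
print(f"Total: {weights_gb + cache_gb:.1f} GB of 128 GB")
```

The cache scales linearly with both `seq_len` and `batch_size`, so long contexts or many concurrent requests can consume the headroom that quantized weights free up.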

In [None]:
# Visualize the memory savings

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Memory by model size
ax = axes[0]
x = np.arange(len(model_sizes))
width = 0.2

colors = ['#ff6b6b', '#feca57', '#48dbfb', '#1dd1a1']
for i, p in enumerate(precisions):
    memories = [calculate_model_memory(s, p) for s in model_sizes]
    ax.bar(x + i*width, memories, width, label=p, color=colors[i])

ax.axhline(y=128, color='red', linestyle='--', linewidth=2, label='DGX Spark (128GB)')
ax.set_xlabel('Model Size (Billions)')
ax.set_ylabel('Memory (GB)')
ax.set_title('Model Memory by Precision')
ax.set_xticks(x + width * 1.5)
ax.set_xticklabels([f'{s}B' for s in model_sizes])
ax.legend(loc='upper left')
ax.set_ylim(0, 400)

# Compression ratio
ax = axes[1]
compression = {
    'FP32 → FP16': 2.0,
    'FP32 → INT8': 4.0,
    'FP32 → INT4': 8.0,
    'FP16 → INT8': 2.0,
    'FP16 → INT4/FP4': 4.0,
}

bars = ax.barh(list(compression.keys()), list(compression.values()), color='steelblue')
ax.set_xlabel('Compression Ratio')
ax.set_title('Memory Compression Ratios')

for bar in bars:
    width = bar.get_width()
    ax.text(width + 0.1, bar.get_y() + bar.get_height()/2,
            f'{width:.0f}×', va='center', fontweight='bold')

ax.set_xlim(0, 10)

plt.tight_layout()
plt.show()

---

## ✋ Try It Yourself

### Exercise 1: Custom Quantization Analysis

Pick a model size relevant to your work and calculate:
1. Memory at each precision level
2. Whether it fits on DGX Spark at each level
3. The recommended precision for your use case

<details>
<summary>Hint</summary>

Use the `calculate_model_memory()` function we defined. Consider:
- For training: you need ~3-4× the FP32 weight memory for weights, gradients, and optimizer states (mixed-precision Adam is ~16 bytes/param before activations)
- For inference: just the model weights plus ~10-20% for KV cache
</details>

In [None]:
# Exercise 1: Your analysis here

# Example: Analyze a 34B model
my_model_size = 34  # billion parameters

# TODO: Calculate memory at each precision
# TODO: Determine if it fits on DGX Spark (128GB)
# TODO: What's your recommendation?

# Your code here...

### Exercise 2: Signal-to-Noise Analysis

Calculate the Signal-to-Noise Ratio (SNR) for each quantization method.

**Formula:** `SNR_dB = 10 * log10(signal_power / noise_power)`

Where:
- signal_power = mean(original^2)
- noise_power = mean(error^2)

<details>
<summary>Hint</summary>

Higher SNR = better quality preservation. Good quantization should have SNR > 20 dB.
</details>
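As a quick sanity check on the formula (toy numbers chosen for illustration, not one of the exercise's methods): take the signal x = [1, -1, 1, -1] and a reconstruction whose errors e are all ±0.1.

$$
P_{\text{signal}} = \frac{1}{4}\sum_{i} x_i^2 = 1, \qquad
P_{\text{noise}} = \frac{1}{4}\sum_{i} e_i^2 = 0.01, \qquad
\text{SNR}_{\text{dB}} = 10 \log_{10}\frac{P_{\text{signal}}}{P_{\text{noise}}} = 10 \log_{10} 100 = 20\ \text{dB}
$$

Each 10× reduction in noise power adds another 10 dB, so an implementation that returns about 20 dB on this toy case is on the right track.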

In [None]:
# Exercise 2: Calculate SNR for each method

def calculate_snr(original, quantized):
    """Calculate Signal-to-Noise Ratio in dB."""
    # TODO: Implement this function
    pass

# Test with our weights
torch.manual_seed(42)
test_weights = torch.randn(10000)

# TODO: Calculate SNR for INT8, INT4, FP8, FP4
# TODO: Which method has the best quality preservation?

# Your code here...

---

## Common Mistakes

### Mistake 1: Confusing Bits with Precision

```python
# Wrong: Assuming equal bit-width means equal quality
# INT8 has 256 evenly spaced levels; FP8 E4M3 has 253 distinct finite values
# (256 codes minus 2 NaNs and a duplicate zero), dense near zero, sparse far out.
# That non-uniform spacing can give BETTER results for bell-shaped weights.

# Right: Measure on your own data instead of assuming
# - FP formats' wider dynamic range tends to help when outliers are present
# - INT's uniform grid can match or beat FP on well-bounded data
```

### Mistake 2: Ignoring Scale Overhead

```python
# Wrong: Claiming 8× compression for INT4 vs FP32
naive_compression = 4 / 0.5  # = 8× (ignoring scales)

# Right: Account for scale factors
# With group_size=128 and one FP16 scale (2 bytes) per group, scales add
# 2/128 ≈ 0.016 bytes/param (~3% on top of INT4's 0.5 bytes/param)
actual_compression = 4 / (0.5 + 2 / 128)  # ≈ 7.8×
```

### Mistake 3: Using Wrong FP8 Format

```python
# Wrong: Using E5M2 for inference weights
model.to(torch.float8_e5m2)  # Larger range, less precision

# Right: E4M3 for inference, E5M2 for gradients/training
weights.to(torch.float8_e4m3fn)  # More precision for weights
gradients.to(torch.float8_e5m2)   # Larger range for gradients
```

---

## Checkpoint

You've learned:

- **Data types matter**: FP32→FP16→INT8→INT4 each halve memory
- **Floating-point vs Integer**: FP formats preserve precision better for normally-distributed weights
- **FP8 is special**: Native FP8 Tensor Cores (Hopper and Blackwell) offer roughly 2× the FP16 throughput at 50% of the memory
- **FP4 needs Blackwell**: NVFP4 gives ~3.5× compression vs FP16 with small reported quality loss
- **DGX Spark advantage**: 128GB fits 70B models in FP8/INT8 (~70GB of weights) and 200B-class models in FP4 (~100GB)!

---

## Challenge (Optional)

**Build a Precision Advisor**

Create a function that recommends the optimal precision for a given model and hardware:

```python
def recommend_precision(
    model_params_billions: float,
    memory_budget_gb: float = 128,
    task: str = 'inference',  # or 'training'
    quality_priority: str = 'balanced'  # 'quality', 'balanced', 'memory'
) -> dict:
    """
    Recommend optimal precision for deployment.
    
    Returns:
        dict with 'precision', 'memory_gb', 'fits', 'reasoning'
    """
    # Your implementation here
    pass
```

In [None]:
# Challenge: Build the precision advisor

def recommend_precision(
    model_params_billions: float,
    memory_budget_gb: float = 128,
    task: str = 'inference',
    quality_priority: str = 'balanced'
) -> dict:
    """
    Recommend optimal precision for deployment.
    """
    # TODO: Implement this advisor
    pass

# Test your advisor
# result = recommend_precision(70, memory_budget_gb=128, task='inference')
# print(result)

---

## Further Reading

- [NVIDIA Transformer Engine User Guide (FP8 formats and recipes)](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html)
- [A Survey of Quantization Methods](https://arxiv.org/abs/2103.13630)
- [Blackwell Architecture Overview](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/)
- [TensorRT Model Optimizer](https://developer.nvidia.com/tensorrt)

---

## Cleanup

In [None]:
# Clear any GPU memory
import gc

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory after cleanup: {torch.cuda.memory_allocated()/1e9:.2f} GB")

print("\nNotebook complete! Ready for Lab 3.2.2: NVFP4 Quantization")

---

## Next Steps

In the next notebook, we'll dive deep into **NVFP4 Quantization** - applying FP4 to real 70B models on DGX Spark!

➡️ Continue to: [Lab 3.2.2: NVFP4 Quantization](lab-3.2.2-nvfp4-quantization.ipynb)