In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1udrvT-zDGGZC2pSzOCozq04RTKGDagCg", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/06_00_intro.mp3"))


In [None]:
#@title 🎧 Code Walkthrough: Setup
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_01_setup.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and seed the random number generators

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# The 5D Grid: Composing All Parallelism Dimensions

*Part 6 of 6 in the Vizuara series on 5D Parallelism from Scratch*
*Estimated time: 40 minutes*

In this final notebook, we bring everything together. We have learned five parallelism strategies individually; now we will see how they **compose** into a single, unified system that spans thousands of GPUs. By the end, you will build an interactive 5D parallelism planner that recommends configurations for real-world models like Llama 3 405B and DeepSeek-V3.

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**. It has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/5d-parallelism-from-scratch/practice/6/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
# Setup: import everything we need
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mpl_toolkits.mplot3d import Axes3D
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline
plt.rcParams['figure.dpi'] = 120
plt.rcParams['font.size'] = 11

print("Setup complete! Ready to compose the 5D grid.")

In [None]:
#@title 🎧 Listen: Why Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_02_why_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 1. Why Does This Matter?

In practice, no single parallelism strategy is enough to train a frontier model. Real training runs **compose all five dimensions** simultaneously:

| Model | Total GPUs | How They Got There |
|-------|------------|-------------------|
| **Llama 3 405B** | 16,384 H100s | 128 DP x 8 TP x 16 PP |
| **DeepSeek-V3** | 2,048 H800s | Dense + 256 MoE experts |
| **GPT-4** (estimated) | ~10,000-25,000 | Unknown exact config |
| **Gemini Ultra** (estimated) | ~10,000+ TPUs | Multi-dimensional parallelism |

Understanding how these dimensions compose is what separates a practitioner from someone who just read a blog post. Let us build that understanding from the ground up.

In [None]:
# Let us start with a concrete motivating example:
# WHY does Llama 3 405B need 16,384 GPUs?

model_params_B = 405  # billion parameters
bytes_per_param_training = 16  # mixed precision Adam: 2 + 2 + 12 bytes

total_memory_GB = model_params_B * bytes_per_param_training
single_gpu_memory_GB = 80  # H100 80GB

gpus_for_weights_alone = total_memory_GB / single_gpu_memory_GB

print(f"Llama 3 405B Training Memory Requirements")
print(f"=" * 50)
print(f"Parameters:              {model_params_B}B")
print(f"Memory per param:        {bytes_per_param_training} bytes")
print(f"Total training memory:   {total_memory_GB:,} GB")
print(f"Single H100 memory:      {single_gpu_memory_GB} GB")
print(f"Minimum GPUs (states):   {gpus_for_weights_alone:.0f}")  # weights + gradients + optimizer
print(f"\nBut with activations, micro-batches, and overhead,")
print(f"Meta used: 16,384 GPUs")
print(f"That is {16384 * single_gpu_memory_GB / 1000:.0f} TB of aggregate GPU memory!")
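The 16 bytes per parameter used above can be unpacked term by term. The following is a minimal sketch of that accounting; the dictionary keys and the helper name `model_state_gb` are illustrative, and 1 GB is taken as $10^9$ bytes, matching the cell above.

```python
# Bytes per parameter for mixed-precision Adam training (illustrative breakdown).
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment": 4,
    "fp32 Adam second moment": 4,
}


def model_state_gb(params_billion: float) -> float:
    """Model-state memory in GB (1 GB = 1e9 bytes) for a given parameter count."""
    return params_billion * sum(BYTES_PER_PARAM.values())


for name, nbytes in BYTES_PER_PARAM.items():
    print(f"{name:<26} {nbytes} bytes")
print(f"Total: {sum(BYTES_PER_PARAM.values())} bytes per parameter")
print(f"405B model states: {model_state_gb(405):,.0f} GB")
```

The three fp32 entries are the "12" in the `2 + 2 + 12` comment; activations are not included in this figure.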

In [None]:
#@title 🎧 Listen: Mega Kitchen
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_03_mega_kitchen.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 2. Building Intuition: The Mega-Kitchen

Let us return to our restaurant analogy one final time and see how all five strategies work together.

Imagine a **mega-restaurant** that needs to serve 16,000 plates per hour. Here is how they organize 16,384 chefs:

- **Groups of 8 chefs** share one large table. Each chef handles a different part of the recipe: one chops, one sautés, one seasons. They pass ingredients constantly across the table. This is **Tensor Parallelism**: splitting each layer across 8 GPUs connected by NVLink (the "shared table").

- **16 tables** are arranged in an **assembly line**. Table 1 handles appetizers, Table 2 handles the first course, Table 3 the main, and so on. Each table passes its finished dish to the next. This is **Pipeline Parallelism**: splitting the model's depth across nodes connected by InfiniBand.

- **128 identical assembly lines** operate in parallel, each working on different customer orders. At the end of a round, they share notes on what they learned. This is **Data Parallelism**: processing different data across the entire cluster.

- **Within each table of 8**, the chefs also split long order lists between them. Chef A handles items 1-500, Chef B handles items 501-1000, and so on. This is **Sequence Parallelism**, which shares the same GPU group as TP.

- Some tables have **specialist chefs** for different cuisines: Italian, Indian, Japanese. A host routes each order to the right specialist. This is **Expert Parallelism**: distributing MoE experts with All-to-All communication.
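To make the seating plan concrete, here is a small sketch that maps a chef's global number (a GPU rank) to its assembly line, table, and seat. The rank ordering, with TP varying fastest, then PP, then DP, is an assumption chosen so that the 8 GPUs of a TP group share a node; real frameworks let you configure this order, and `GridCoord` and `rank_to_coord` are hypothetical names.

```python
from typing import NamedTuple


class GridCoord(NamedTuple):
    dp: int  # which assembly line (data-parallel replica)
    pp: int  # which table on the line (pipeline stage)
    tp: int  # which seat at the table (tensor-parallel rank)


def rank_to_coord(rank: int, n_dp: int = 128, n_pp: int = 16,
                  n_tp: int = 8) -> GridCoord:
    """Map a global rank to (dp, pp, tp), assuming TP is the fastest-varying index."""
    if not 0 <= rank < n_dp * n_pp * n_tp:
        raise ValueError(f"rank {rank} is outside the {n_dp}x{n_pp}x{n_tp} grid")
    tp = rank % n_tp
    pp = (rank // n_tp) % n_pp
    dp = rank // (n_tp * n_pp)
    return GridCoord(dp=dp, pp=pp, tp=tp)


for r in (0, 7, 8, 128, 16383):
    print(r, rank_to_coord(r))
```

With this ordering, ranks 0 to 7 sit at the same table, rank 8 starts the next table on the same line, and rank 128 starts the second assembly line.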

In [None]:
#@title 🎧 Code Walkthrough: Mega Kitchen Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_04_mega_kitchen_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


Now let us visualize this mega-kitchen hierarchy. We will draw the cluster structure showing DP replicas, PP stages, and TP groups as nested rectangles.

In [None]:
# Set up the figure and draw the outer cluster boundary (DP level)
fig, ax = plt.subplots(1, 1, figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')
ax.set_title("The Mega-Kitchen: How 16,384 Chefs Organize",
             fontsize=16, fontweight='bold', pad=20)

cluster_rect = plt.Rectangle((0.5, 0.3), 13, 9.2, linewidth=2.5,
                               edgecolor='#2196F3', facecolor='#E3F2FD',
                               linestyle='--', alpha=0.4)
ax.add_patch(cluster_rect)
ax.text(7, 9.2, "Data Parallelism: 128 identical assembly lines",
        ha='center', fontsize=12, fontweight='bold', color='#1565C0')

Next we draw the individual assembly lines (PP stages) and their TP node groups.

In [None]:
# Draw 3 assembly lines (representing 128)
for i, x_start in enumerate([1.0, 5.0, 9.0]):
    label = f"Line {i+1}" if i < 2 else "Line 128"
    pp_rect = plt.Rectangle((x_start, 1.0), 3.5, 7.0, linewidth=2,
                              edgecolor='#FF9800', facecolor='#FFF3E0', alpha=0.5)
    ax.add_patch(pp_rect)
    # Draw 4 nodes (representing 16 PP stages)
    for j, y_pos in enumerate([1.5, 3.2, 4.9, 6.6]):
        stage_label = f"Stage {j+1}" if j < 3 else "Stage 16"
        node_rect = plt.Rectangle((x_start + 0.3, y_pos), 2.9, 1.3,
                                   linewidth=1.5, edgecolor='#4CAF50',
                                   facecolor='#E8F5E9', alpha=0.7)
        ax.add_patch(node_rect)
        ax.text(x_start + 1.75, y_pos + 0.65, f"8 GPUs\n(TP group)",
                ha='center', va='center', fontsize=7, color='#2E7D32')
        ax.text(x_start + 1.75, y_pos + 1.15, stage_label,
                ha='center', fontsize=7, fontweight='bold', color='#1B5E20')
        if j < 3:  # Arrows between stages
            ax.annotate('', xy=(x_start + 1.75, y_pos + 1.3),
                       xytext=(x_start + 1.75, y_pos + 1.6),
                       arrowprops=dict(arrowstyle='->', color='#FF9800', lw=1.5))
    ax.text(x_start + 1.75, 8.3, label, ha='center', fontsize=10,
            fontweight='bold', color='#E65100')
    if i == 1:  # Ellipsis between line 2 and 128
        ax.text(7.5, 4.5, "...", ha='center', fontsize=24,
                fontweight='bold', color='#666')

Finally, add the legend and display the complete mega-kitchen diagram.

In [None]:
# Legend
legend_items = [
    ("Data Parallelism (128x)", '#2196F3'),
    ("Pipeline Parallelism (16x)", '#FF9800'),
    ("Tensor Parallelism (8x per node)", '#4CAF50'),
]
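for text, color in legend_items:
    ax.plot([], [], 's', color=color, markersize=10, label=text)
ax.legend(loc='lower center', ncol=3, fontsize=9, framealpha=0.9)

fig.tight_layout()
# 📊 Visualization: display the chart.
# The figure was created in an earlier cell, and the inline backend closes
# figures after each cell, so we display `fig` explicitly instead of plt.show().
from IPython.display import display
display(fig)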

In [None]:
#@title 🎧 Listen: Math Composition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_05_math_composition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 3. The Mathematics of Composition

The total number of GPUs required is:

$$N_{\text{total}} = N_{\text{DP}} \times N_{\text{TP}} \times N_{\text{PP}} \times N_{\text{EP}}$$

Sequence Parallelism shares the same GPU group as Tensor Parallelism (both operate within a node), so it does **not** add an independent dimension to the GPU count.

Let us plug in some real numbers.

**Llama 3 405B (Meta, 2024):**
- $N_{\text{TP}} = 8$: one full node of 8 H100 GPUs, connected by NVLink
- $N_{\text{PP}} = 16$: model split across 16 pipeline stages
- $N_{\text{DP}} = 128$: 128 data-parallel replicas
- $N_{\text{SP}} = 8$: shares the TP group (within the same node)
- $N_{\text{EP}} = 1$: Llama 3 is a dense model (not MoE)
- $N_{\text{total}} = 128 \times 8 \times 16 \times 1 = 16{,}384$ GPUs

**DeepSeek-V3 (DeepSeek, 2024):**
- $N_{\text{TP}} = 1$: no tensor parallelism (clever design choice)
- $N_{\text{PP}} = 8$: 8 pipeline stages
- $N_{\text{EP}} = 32$: 256 experts distributed across 32 EP groups (8 experts per GPU)
- $N_{\text{DP}} = 8$: 8 data-parallel replicas
- $N_{\text{total}} = 8 \times 1 \times 8 \times 32 = 2{,}048$ GPUs

This is a simplified factorization for illustration. The DeepSeek-V3 technical report describes 16-way PP, 64-way EP spanning 8 nodes, and ZeRO-1 DP on the same 2,048 GPUs.

We define a `ParallelismConfig` dataclass to hold the parallelism degrees for any model, along with GPU metadata.

In [None]:
# Define the core configuration dataclass
@dataclass
class ParallelismConfig:
    """Configuration for a 5D parallelism setup."""
    name: str
    n_dp: int        # Data Parallelism degree
    n_tp: int        # Tensor Parallelism degree
    n_pp: int        # Pipeline Parallelism degree
    n_sp: int        # Sequence Parallelism degree (shares TP group)
    n_ep: int        # Expert Parallelism degree
    gpu_type: str
    gpu_memory_gb: int
    gpus_per_node: int = 8  # assumes standard 8-GPU nodes

    @property
    def total_gpus(self) -> int:
        return self.n_dp * self.n_tp * self.n_pp * self.n_ep

    @property
    def total_nodes(self) -> int:
        # A TP group lives inside one node, but a node still holds
        # gpus_per_node GPUs when n_tp is smaller (e.g. n_tp = 1).
        return self.total_gpus // self.gpus_per_node

Now let us instantiate the real-world configurations and verify the GPU counts match published numbers.

In [None]:
# Real-world configurations
llama3_405b = ParallelismConfig(
    name="Llama 3 405B",
    n_dp=128, n_tp=8, n_pp=16, n_sp=8, n_ep=1,
    gpu_type="H100", gpu_memory_gb=80
)

deepseek_v3 = ParallelismConfig(
    name="DeepSeek-V3",
    n_dp=8, n_tp=1, n_pp=8, n_sp=1, n_ep=32,
    gpu_type="H800", gpu_memory_gb=80
)

for config in [llama3_405b, deepseek_v3]:
    print(f"\n{'=' * 55}")
    print(f"  {config.name}")
    print(f"{'=' * 55}")
    print(f"  DP = {config.n_dp:>4}  (data-parallel replicas)")
    print(f"  TP = {config.n_tp:>4}  (tensor-parallel within node)")
    print(f"  PP = {config.n_pp:>4}  (pipeline stages)")
    print(f"  SP = {config.n_sp:>4}  (sequence-parallel, shares TP)")
    print(f"  EP = {config.n_ep:>4}  (expert-parallel groups)")
    formula = (f"  {config.n_dp} x {config.n_tp} x "
               f"{config.n_pp} x {config.n_ep}")
    print(f"  Total = {formula} = {config.total_gpus:,} GPUs")
    print(f"  Nodes = {config.total_nodes:,} "
          f"({config.gpus_per_node} {config.gpu_type}s per node)")
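The product formula also works in reverse: given a GPU budget, we can list every way to factor it. Below is a rough sketch for a dense model ($N_{\text{EP}} = 1$). The function name `enumerate_dense_configs`, the power-of-two TP degrees, and the `max_pp` cap are all assumptions made for illustration, not constraints stated in this notebook.

```python
from typing import List, Tuple


def enumerate_dense_configs(total_gpus: int, gpus_per_node: int = 8,
                            max_pp: int = 32) -> List[Tuple[int, int, int]]:
    """List (n_dp, n_tp, n_pp) triples whose product equals total_gpus.

    TP is restricted to powers of two that fit inside one node.
    """
    configs = []
    n_tp = 1
    while n_tp <= gpus_per_node:
        for n_pp in range(1, max_pp + 1):
            if total_gpus % (n_tp * n_pp) == 0:
                configs.append((total_gpus // (n_tp * n_pp), n_tp, n_pp))
        n_tp *= 2
    return configs


configs = enumerate_dense_configs(16384)
print(f"{len(configs)} candidate factorizations of 16,384 GPUs")
print("Llama 3 layout present:", (128, 8, 16) in configs)
```

Counting candidates is the easy part; the planner's real job is ranking them by memory fit and communication cost.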

In [None]:
#@title 🎧 Listen: Comm Hierarchy
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_06_comm_hierarchy.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 4. The Communication Hierarchy

Not all communication is equal. The key insight behind 5D parallelism is to **match each strategy to the right level of the hardware hierarchy**:

| Parallelism | Communication Pattern | Placement | Bandwidth | Why Here? |
|-------------|----------------------|-----------|-----------|-----------|
| **TP** | AllReduce every layer | Within node | NVLink: 900 GB/s | Highest frequency: needs the fastest link |
| **SP** | Reduce-Scatter / All-Gather | Within node | NVLink: 900 GB/s | Shares TP group |
| **PP** | Point-to-point per micro-batch | Across nearby nodes | InfiniBand: 400 Gb/s (~50 GB/s) | Moderate frequency |
| **EP** | All-to-All dispatch + collect | Flexible | InfiniBand: ~50 GB/s | Carefully placed to minimize hops |
| **DP** | AllReduce once per step | Entire cluster | Ethernet/IB: varies | Lowest frequency: can tolerate latency |
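To get a feel for what these bandwidths mean, we can time a single activation tensor on two of the links. This is a back-of-the-envelope sketch: it uses the bandwidths from the table, ignores latency and protocol overhead, and assumes a bf16 activation of shape 1 x 8192 x 16384 (the Llama 3 405B sequence length and hidden dimension). The helper name `transfer_time_ms` is illustrative.

```python
def transfer_time_ms(num_bytes: float, bandwidth_gb_per_s: float) -> float:
    """Ideal transfer time in milliseconds (bandwidth in GB/s, 1 GB = 1e9 bytes)."""
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3


# One bf16 activation tensor: micro-batch 1, 8192 tokens, hidden size 16384
activation_bytes = 1 * 8192 * 16384 * 2

nvlink_ms = transfer_time_ms(activation_bytes, 900.0)   # NVLink, from the table
infiniband_ms = transfer_time_ms(activation_bytes, 50.0)  # InfiniBand, from the table

print(f"Activation size: {activation_bytes / 1e6:.0f} MB")
print(f"NVLink:     {nvlink_ms:.2f} ms")
print(f"InfiniBand: {infiniband_ms:.2f} ms")
print(f"Slowdown:   {infiniband_ms / nvlink_ms:.0f}x")
```

The same tensor takes 18 times longer across nodes, which is why the per-layer TP traffic stays on NVLink while the much rarer PP and DP traffic crosses the network.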

Let us plug in some numbers to see how communication volume differs across dimensions.

We start by defining the key model and training hyperparameters for Llama 3 405B, then compute per-dimension communication volumes.

In [None]:
# Communication volume comparison for Llama 3 405B
# These are approximate per-step volumes

hidden_dim = 16384       # Llama 3 405B hidden dimension
num_layers = 126         # Llama 3 405B layers
num_heads = 128          # attention heads
seq_len = 8192           # training sequence length
micro_batch_size = 1     # per micro-batch
global_batch_size = 1024 # total batch size (in sequences)
bytes_per_element = 2    # fp16 / bf16
total_params_B = 405     # billion parameters

n_tp = 8
n_pp = 16
n_dp = 128

With these parameters set, we can compute the communication volume for TP, PP, and DP.

In [None]:
# Micro-batches each DP replica processes per training step
num_microbatches = global_batch_size // (n_dp * micro_batch_size)

# TP: AllReduce on the activations of every layer.
# The leading 2x models one AllReduce as reduce-scatter + all-gather.
tp_per_layer = (2 * micro_batch_size * seq_len * hidden_dim
                * bytes_per_element)
tp_total = tp_per_layer * (num_layers // n_pp) * 2  # fwd + bwd
tp_total *= num_microbatches  # repeated for every micro-batch in the step
tp_total_GB = tp_total / (1024**3)

# PP: Point-to-point at stage boundaries (activation tensor)
pp_per_microbatch = (micro_batch_size * seq_len * hidden_dim
                     * bytes_per_element)
pp_total = pp_per_microbatch * num_microbatches * 2  # fwd + bwd
pp_total_GB = pp_total / (1024**3)

# DP: AllReduce gradients once per step (2x for reduce-scatter + all-gather)
dp_total = 2 * (total_params_B * 1e9 / (n_tp * n_pp)) * bytes_per_element
dp_total_GB = dp_total / (1024**3)

print("Communication Volume Per Training Step (Llama 3 405B)")
print("=" * 55)
print(f"  TP (AllReduce per layer):   {tp_total_GB:>8.2f} GB")
print(f"  PP (point-to-point):        {pp_total_GB:>8.2f} GB")
print(f"  DP (AllReduce gradients):   {dp_total_GB:>8.2f} GB")
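The factor of 2 applied to each AllReduce above is the large-group limit of the ring AllReduce cost. A minimal sketch of the exact per-rank traffic, assuming a bandwidth-optimal ring implementation (the function name is illustrative):

```python
def ring_allreduce_bytes_per_rank(message_bytes: float, n_ranks: int) -> float:
    """Bytes each rank sends in a ring AllReduce of `message_bytes`.

    Reduce-scatter and all-gather each move (n - 1) / n of the message.
    """
    if n_ranks < 1:
        raise ValueError("n_ranks must be at least 1")
    return 2 * (n_ranks - 1) / n_ranks * message_bytes


for n in (1, 2, 8, 128):
    factor = ring_allreduce_bytes_per_rank(1.0, n)
    print(f"n_ranks = {n:>3}: each rank sends {factor:.3f} x message size")
```

For a TP group of 8 the true factor is 1.75, and for 128 DP replicas it is about 1.98, so the flat 2x used in this notebook is a slight overestimate that becomes tight for large groups.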

In [None]:
#@title 🎧 Code Walkthrough: Comm Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_07_comm_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


Now let us visualize the communication hierarchy with two charts: available bandwidth per dimension (left) and communication frequency per step (right).

In [None]:
# Visualization: Communication hierarchy, bandwidth comparison

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: bandwidth by parallelism dimension
dimensions = ['TP\n(NVLink)', 'SP\n(NVLink)', 'PP\n(InfiniBand)',
              'EP\n(IB/Ethernet)', 'DP\n(Cluster)']
bandwidths = [900, 900, 50, 50, 25]  # GB/s (bidirectional effective)
colors = ['#4CAF50', '#8BC34A', '#FF9800', '#9C27B0', '#2196F3']

bars = axes[0].bar(dimensions, bandwidths, color=colors, edgecolor='white',
                   linewidth=1.5, width=0.6)
axes[0].set_ylabel("Effective Bandwidth (GB/s)", fontsize=11)
axes[0].set_title("Available Bandwidth by Dimension", fontsize=13,
                  fontweight='bold')
axes[0].set_yscale('log')
axes[0].set_ylim(10, 2000)

for bar, bw in zip(bars, bandwidths):
    axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() * 1.15,
                f'{bw} GB/s', ha='center', va='bottom', fontsize=9,
                fontweight='bold')

The right panel shows how frequently each dimension communicates during a single training step.

In [None]:
# Right panel: communication frequency
freq_labels = ['TP', 'SP', 'PP', 'EP', 'DP']
freq_values = [num_layers // n_pp, num_layers // n_pp,
               num_microbatches, num_microbatches, 1]
freq_desc = ['per layer', 'per layer', 'per micro-batch',
             'per micro-batch', 'per step']

bars2 = axes[1].barh(freq_labels, freq_values, color=colors,
                     edgecolor='white', linewidth=1.5, height=0.5)
axes[1].set_xlabel("Communication Events per Step", fontsize=11)
axes[1].set_title("Communication Frequency by Dimension", fontsize=13,
                  fontweight='bold')
axes[1].set_xscale('log')

for bar, val, desc in zip(bars2, freq_values, freq_desc):
    axes[1].text(bar.get_width() * 1.3, bar.get_y() + bar.get_height()/2,
                f'{val}x ({desc})', va='center', fontsize=9)

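fig.tight_layout()
# 📊 Visualization: display the chart.
# The figure was created in the previous cell, and the inline backend closes
# figures after each cell, so we display `fig` explicitly instead of plt.show().
from IPython.display import display
display(fig)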

In [None]:
#@title 🎧 Listen: Planner Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_08_planner_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 5. Let Us Build It: The 5D Configuration Planner

Now we will build the full planner step by step. We start with the foundational data structures.

### 5.1 GPU and Node Specifications

We define dataclasses for GPU hardware specs and interconnect specs, then populate them with real-world values for A100, H100, and H800 GPUs.

In [None]:
@dataclass
class GPUSpec:
    """Hardware specifications for a GPU type."""
    name: str
    memory_gb: float
    peak_tflops_fp16: float  # fp16/bf16 peak TFLOPS
    nvlink_bw_gbps: float    # NVLink bandwidth in GB/s (gigabytes, bidirectional) despite the _gbps suffix
    gpus_per_node: int

@dataclass
class InterconnectSpec:
    """Interconnect specifications between nodes."""
    name: str
    bandwidth_gbps: float  # GB/s (gigabytes per second) per link
    latency_us: float      # microseconds

Now we populate the hardware catalog with real specs and print a summary table.

In [None]:
# Define real GPU specs
GPU_SPECS = {
    "A100_80GB": GPUSpec("A100 80GB", 80.0, 312.0, 600.0, 8),
    "H100_80GB": GPUSpec("H100 80GB", 80.0, 989.0, 900.0, 8),
    "H800_80GB": GPUSpec("H800 80GB", 80.0, 989.0, 400.0, 8),
}

INTERCONNECTS = {
    "InfiniBand_HDR": InterconnectSpec("InfiniBand HDR", 25.0, 1.0),
    "InfiniBand_NDR": InterconnectSpec("InfiniBand NDR", 50.0, 0.8),
    "Ethernet_100G":  InterconnectSpec("100G Ethernet", 12.5, 5.0),
}

# Display available hardware
print("Available GPU Types:")
print(f"{'GPU':<15} {'Memory':<10} {'FP16 TFLOPS':<13} "
      f"{'NVLink BW':<12} {'GPUs/Node'}")
print("-" * 62)
for key, gpu in GPU_SPECS.items():
    print(f"{gpu.name:<15} {gpu.memory_gb:<10.0f} "
          f"{gpu.peak_tflops_fp16:<13.0f} "
          f"{gpu.nvlink_bw_gbps:<12.0f} {gpu.gpus_per_node}")
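Network links are usually quoted in gigabits per second, while the catalog above stores gigabytes per second. The short sketch below shows the conversion behind the interconnect values; the line rates (200 Gb/s for HDR, 400 Gb/s for NDR) are the nominal per-port figures, and the helper name is illustrative.

```python
def gbit_to_gbyte_per_s(gbit_per_s: float) -> float:
    """Convert a line rate in Gb/s to GB/s (8 bits per byte, no protocol overhead)."""
    return gbit_per_s / 8


LINE_RATES_GBIT = {
    "InfiniBand HDR": 200.0,
    "InfiniBand NDR": 400.0,
    "100G Ethernet": 100.0,
}

for name, rate in LINE_RATES_GBIT.items():
    print(f"{name:<16} {rate:>5.0f} Gb/s = {gbit_to_gbyte_per_s(rate):>5.1f} GB/s")
```

Achieved throughput is lower than these nominal numbers once encoding and protocol overhead are included, so treat the catalog values as upper bounds.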

In [None]:
#@title 🎧 Code Walkthrough: Memory Calculator
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_09_memory_calculator.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 5.2 Memory Budget Calculator

This is the heart of the planner. Given a model configuration and a parallelism configuration, we compute how much memory each GPU needs.

The memory on each GPU consists of:
- **Weights**: $\frac{\text{total\_params}}{N_{\text{TP}} \times N_{\text{PP}}} \times \text{bytes\_per\_param}$
- **Gradients**: same size as weights
- **Optimizer states**: $3 \times$ weights in fp32 (Adam: first moment, second moment, fp32 copy)
- **Activations**: depends on batch size, sequence length, hidden dimension, and checkpointing

With ZeRO Stage 1, optimizer states are sharded across $N_{\text{DP}}$.
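Before wiring ZeRO into the full calculator, it helps to see its effect on the per-parameter byte count alone. This is a minimal sketch using the 2 + 2 + 12 byte accounting from the list above; `zero_bytes_per_param` is a hypothetical helper, and it ignores the temporary buffers a real ZeRO implementation allocates.

```python
def zero_bytes_per_param(zero_stage: int, n_dp: int) -> float:
    """Model-state bytes per (local) parameter under a given ZeRO stage."""
    weights, grads, optimizer = 2.0, 2.0, 12.0  # bf16, bf16, fp32 Adam states
    if zero_stage >= 1:
        optimizer /= n_dp  # stage 1 shards optimizer states
    if zero_stage >= 2:
        grads /= n_dp      # stage 2 also shards gradients
    if zero_stage >= 3:
        weights /= n_dp    # stage 3 also shards the weights
    return weights + grads + optimizer


for stage in range(4):
    print(f"ZeRO-{stage}, DP=128: {zero_bytes_per_param(stage, 128):.3f} bytes/param")
```

With 128 data-parallel replicas, Stage 1 alone cuts the footprint from 16 bytes to just over 4 bytes per parameter, because the optimizer states dominate.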

First, we define the `ModelConfig` dataclass that holds all architecture parameters for any model.

In [None]:
@dataclass
class ModelConfig:
    """Model architecture configuration."""
    name: str
    total_params_B: float    # billions of parameters
    num_layers: int
    hidden_dim: int
    num_heads: int
    num_experts: int         # 1 for dense models
    expert_params_B: float   # params per expert (for MoE)
    seq_len: int
    vocab_size: int

Now the core function: `compute_memory_per_gpu` calculates the weight, gradient, optimizer, and activation memory given a model and parallelism configuration.

In [None]:
def compute_memory_per_gpu(
    model: ModelConfig,
    n_dp: int, n_tp: int, n_pp: int, n_ep: int,
    micro_batch_size: int = 1,
    zero_stage: int = 1,
    activation_checkpointing: bool = True
) -> Dict[str, float]:
    """
    Compute per-GPU memory breakdown in GB.
    Returns a dict with weight, gradient, optimizer, activation memory.
    """
    # Total params on this GPU (sharded by TP and PP)
    if model.num_experts > 1:
        # MoE: shared params split by TP*PP; expert params split by EP, and
        # by PP as well because each expert's layers are spread over the stages
        shared_params = model.total_params_B - (model.expert_params_B
                        * model.num_experts)
        shared_per_gpu = shared_params * 1e9 / (n_tp * n_pp)
        experts_per_gpu = model.num_experts // n_ep
        expert_per_gpu = (model.expert_params_B * 1e9 * experts_per_gpu
                          / n_pp)
        params_per_gpu = shared_per_gpu + expert_per_gpu
    else:
        params_per_gpu = model.total_params_B * 1e9 / (n_tp * n_pp)

    # Weight memory (bf16 = 2 bytes)
    weight_mem = params_per_gpu * 2 / (1024**3)

    # Gradient memory (bf16 = 2 bytes)
    grad_mem = params_per_gpu * 2 / (1024**3)

We continue the memory computation with optimizer states (applying ZeRO sharding) and activation memory (with optional checkpointing).

In [None]:
    # Optimizer memory (Adam: fp32 weights + fp32 m + fp32 v = 12 bytes)
    optimizer_mem = params_per_gpu * 12 / (1024**3)

    # Apply ZeRO sharding
    if zero_stage >= 1:
        optimizer_mem /= n_dp
    if zero_stage >= 2:
        grad_mem /= n_dp
    if zero_stage >= 3:
        weight_mem /= n_dp

    # Activation memory (simplified estimate)
    layers_per_gpu = model.num_layers // n_pp
    seq_per_gpu = model.seq_len // n_tp  # SP shares TP group
    act_per_layer = (10 * micro_batch_size * seq_per_gpu
                     * model.hidden_dim * 2 / (1024**3))
    if activation_checkpointing:
        act_mem = act_per_layer * (layers_per_gpu ** 0.5)
    else:
        act_mem = act_per_layer * layers_per_gpu

Assemble the memory breakdown into a dictionary and return it.

In [None]:
    return {
        "weights": weight_mem,
        "gradients": grad_mem,
        "optimizer": optimizer_mem,
        "activations": act_mem,
        "total": weight_mem + grad_mem + optimizer_mem + act_mem,
        "params_per_gpu_B": params_per_gpu / 1e9,
    }
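The activation estimate is the least obvious term, so here it is worked through for one Llama 3 405B GPU. This sketch simply replays the heuristic used in the function above (a factor of 10 tensors per layer, and a square-root rule for checkpointing); both are this notebook's simplifications rather than measured values, and the variable names are illustrative.

```python
import math

# Llama 3 405B settings used in this notebook
seq_len, hidden_dim, num_layers = 8192, 16384, 126
n_tp, n_pp, micro_batch = 8, 16, 1

seq_per_gpu = seq_len // n_tp        # sequence is split across the TP/SP group
layers_per_gpu = num_layers // n_pp  # floor division, as in the function above

# Roughly 10 bf16 tensors of shape (micro_batch, seq_per_gpu, hidden_dim) per layer
act_per_layer_gib = 10 * micro_batch * seq_per_gpu * hidden_dim * 2 / (1024**3)


def retained_layers(layers: int, checkpointing: bool) -> float:
    """Layer-equivalents of activations kept in memory under the sqrt heuristic."""
    return math.sqrt(layers) if checkpointing else float(layers)


for ckpt in (False, True):
    total = act_per_layer_gib * retained_layers(layers_per_gpu, ckpt)
    print(f"checkpointing={ckpt!s:<5}: {total:.2f} GiB of activations")
```

Checkpointing trades this memory saving for extra compute, since the discarded activations are recomputed during the backward pass.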

In [None]:
#@title 🎧 Code Walkthrough: Memory Results
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_10_memory_results.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


Let us define the two real-world model configurations and compute the memory breakdown for Llama 3 405B.

In [None]:
# Define real model configurations
llama3_model = ModelConfig(
    name="Llama 3 405B",
    total_params_B=405.0,
    num_layers=126,
    hidden_dim=16384,
    num_heads=128,
    num_experts=1,
    expert_params_B=0.0,
    seq_len=8192,
    vocab_size=128256
)

deepseek_v3_model = ModelConfig(
    name="DeepSeek-V3",
    total_params_B=671.0,
    num_layers=61,
    hidden_dim=7168,
    num_heads=128,
    num_experts=256,
    expert_params_B=2.55,   # ~2.55B params per routed expert, summed over the 58 MoE layers
    seq_len=4096,
    vocab_size=129280
)

Now we run the memory calculator on Llama 3 405B and check whether the result fits in an H100 GPU.

In [None]:
# Compute memory for Llama 3 405B
mem_llama = compute_memory_per_gpu(
    llama3_model, n_dp=128, n_tp=8, n_pp=16, n_ep=1,
    micro_batch_size=1, zero_stage=1
)

print("Memory Per GPU: Llama 3 405B (DP=128, TP=8, PP=16)")
print("=" * 50)
for key, val in mem_llama.items():
    if key == "params_per_gpu_B":
        print(f"  Params per GPU:     {val:.2f} B")
    else:
        print(f"  {key:<20s} {val:>8.2f} GB")
print(f"\n  H100 capacity:         80.00 GB")
print(f"  Fits? {'Yes' if mem_llama['total'] < 80 else 'No'} "
      f"({mem_llama['total']/80*100:.0f}% utilized)")

We also compute DeepSeek-V3's memory and compare both models side by side with pie charts showing per-GPU memory composition.

In [None]:
# Visualization: Memory breakdown pie charts for both models

mem_ds = compute_memory_per_gpu(
    deepseek_v3_model, n_dp=8, n_tp=1, n_pp=8, n_ep=32,
    micro_batch_size=1, zero_stage=1
)

fig, axes = plt.subplots(1, 2, figsize=(13, 5))
categories = ['weights', 'gradients', 'optimizer', 'activations']
pie_colors = ['#2196F3', '#4CAF50', '#FF9800', '#9C27B0']

for ax, mem, title in zip(axes, [mem_llama, mem_ds],
                          ["Llama 3 405B\n(DP=128, TP=8, PP=16)",
                           "DeepSeek-V3\n(DP=8, TP=1, PP=8, EP=32)"]):
    sizes = [mem[c] for c in categories]
    total = sum(sizes)
    labels = [f"{c.capitalize()}\n{s:.1f} GB ({s/total*100:.0f}%)"
              for c, s in zip(categories, sizes)]

    wedges, texts = ax.pie(sizes, labels=labels, colors=pie_colors,
                           startangle=90, textprops={'fontsize': 9})
    ax.set_title(f"{title}\nTotal: {total:.1f} GB / 80 GB",
                 fontsize=11, fontweight='bold')

plt.suptitle("Per-GPU Memory Breakdown", fontsize=14, fontweight='bold',
             y=1.02)
plt.tight_layout()
# 📊 Visualization: display the chart
plt.show()

In [None]:
#@title 🎧 Code Walkthrough: Comm Estimator
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_11_comm_estimator.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 5.3 Communication Volume Estimator

For each parallelism dimension, let us estimate how much data moves per training step.

The `compute_communication_volume` function estimates bytes transferred for TP, PP, DP, EP, and SP per training step. We start with the function signature and the TP calculation.

In [None]:
def compute_communication_volume(
    model: ModelConfig,
    n_dp: int, n_tp: int, n_pp: int, n_ep: int,
    micro_batch_size: int = 1,
    global_batch_size: int = 1024
) -> Dict[str, Dict[str, float]]:
    """
    Estimate communication volume per training step for each dimension.
    Returns dict of {dimension: {volume_GB, num_ops, pattern}}.
    """
    layers_per_stage = model.num_layers // n_pp
    params_per_gpu = model.total_params_B * 1e9 / (n_tp * n_pp)
    seq_per_tp = model.seq_len  # full seq needed for attention
    num_microbatches = global_batch_size // (n_dp * micro_batch_size)

    results = {}

    # TP: AllReduce per layer (forward + backward)
    if n_tp > 1:
        tp_per_layer = (2 * micro_batch_size * seq_per_tp
                       * model.hidden_dim * 2)  # bytes
        tp_total = tp_per_layer * layers_per_stage * 2  # fwd + bwd
        tp_total *= num_microbatches
        results["TP"] = {
            "volume_GB": tp_total / (1024**3),
            "num_ops": layers_per_stage * 2 * num_microbatches,
            "pattern": "AllReduce"
        }
    else:
        results["TP"] = {"volume_GB": 0, "num_ops": 0, "pattern": "N/A"}

Next we compute the PP and DP communication volumes.

In [None]:
    # PP: Point-to-point at stage boundaries
    if n_pp > 1:
        pp_per_mb = (micro_batch_size * model.seq_len
                    * model.hidden_dim * 2)  # bytes
        pp_total = pp_per_mb * num_microbatches * 2  # fwd + bwd
        results["PP"] = {
            "volume_GB": pp_total / (1024**3),
            "num_ops": num_microbatches * 2,
            "pattern": "Point-to-Point"
        }
    else:
        results["PP"] = {"volume_GB": 0, "num_ops": 0, "pattern": "N/A"}

    # DP: AllReduce gradients
    if n_dp > 1:
        dp_total = 2 * params_per_gpu * 2  # 2x for reduce-scatter + all-gather
        results["DP"] = {
            "volume_GB": dp_total / (1024**3),
            "num_ops": 1,
            "pattern": "AllReduce"
        }
    else:
        results["DP"] = {"volume_GB": 0, "num_ops": 0, "pattern": "N/A"}

Finally, we add the EP (All-to-All) and SP (Reduce-Scatter) volumes and return the results.

In [None]:
    # EP: All-to-All dispatch + collect
    if n_ep > 1 and model.num_experts > 1:
        # Tokens dispatched to experts
        tokens_per_step = (micro_batch_size * model.seq_len
                          * num_microbatches)
        # Each token's hidden state sent to an expert
        ep_volume = (tokens_per_step * model.hidden_dim * 2 * 2)  # dispatch+collect
        results["EP"] = {
            "volume_GB": ep_volume / (1024**3),
            "num_ops": num_microbatches * 2,
            "pattern": "All-to-All"
        }
    else:
        results["EP"] = {"volume_GB": 0, "num_ops": 0, "pattern": "N/A"}

    # SP: Reduce-scatter / All-gather (same group as TP)
    if n_tp > 1:
        sp_per_layer = (micro_batch_size * model.seq_len
                       * model.hidden_dim * 2)  # bytes
        sp_total = sp_per_layer * layers_per_stage * 2  # fwd+bwd
        sp_total *= num_microbatches
        results["SP"] = {
            "volume_GB": sp_total / (1024**3),
            "num_ops": layers_per_stage * 2 * num_microbatches,
            "pattern": "Reduce-Scatter"
        }
    else:
        results["SP"] = {"volume_GB": 0, "num_ops": 0, "pattern": "N/A"}

    return results
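One detail worth noticing: `model.num_layers // n_pp` uses floor division, so for 126 layers and 16 stages it yields 7 and accounts for only 112 layers. A small sketch of one way to spread the remainder is shown below. The function `split_layers` is hypothetical, and the even split is an assumption for illustration, not a description of how any particular model was actually partitioned.

```python
from typing import List


def split_layers(num_layers: int, n_pp: int) -> List[int]:
    """Distribute layers over pipeline stages as evenly as possible."""
    base, extra = divmod(num_layers, n_pp)
    # The first `extra` stages take one additional layer each
    return [base + 1 if stage < extra else base for stage in range(n_pp)]


stages = split_layers(126, 16)
print("Layers per stage:", stages)
print("Floor estimate covers:", (126 // 16) * 16, "of", sum(stages), "layers")
```

Because the busiest stage sets the pace of the pipeline, using the maximum of this list (8 here) rather than the floor gives a more conservative per-stage estimate.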

Let us call the function for Llama 3 405B and print the per-dimension communication table.

In [None]:
# Compute for Llama 3 405B
comm_llama = compute_communication_volume(
    llama3_model, n_dp=128, n_tp=8, n_pp=16, n_ep=1,
    global_batch_size=2048
)

print("Communication Volume Per Step - Llama 3 405B")
print("=" * 60)
print(f"{'Dim':<5} {'Volume (GB)':<14} {'Ops/Step':<12} {'Pattern'}")
print("-" * 60)
for dim, info in comm_llama.items():
    if info['volume_GB'] > 0:
        print(f"{dim:<5} {info['volume_GB']:>10.2f}    "
              f"{info['num_ops']:>8}     {info['pattern']}")
    else:
        print(f"{dim:<5} {'--':>10}    {'--':>8}     {info['pattern']}")

In [None]:
#@title 🎧 Code Walkthrough: Comm Comparison
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_12_comm_comparison.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


Now we visualize the communication volumes side by side for both models to see how different architectures distribute their communication load.

In [None]:
# Visualization: Communication volume comparison

comm_ds = compute_communication_volume(
    deepseek_v3_model, n_dp=8, n_tp=1, n_pp=8, n_ep=32,
    global_batch_size=512
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
dims = ['TP', 'SP', 'PP', 'DP', 'EP']
bar_colors = ['#4CAF50', '#8BC34A', '#FF9800', '#2196F3', '#9C27B0']

We plot each model's communication volume bar chart, annotating active dimensions with their GB values and inactive ones with "N/A".

In [None]:
for ax, comm, title in zip(axes, [comm_llama, comm_ds],
                           ["Llama 3 405B", "DeepSeek-V3"]):
    volumes = [comm[d]['volume_GB'] for d in dims]
    volumes_plot = [max(v, 0.001) for v in volumes]  # for log scale
    active = [v > 0 for v in volumes]
    bars = ax.bar(dims, volumes_plot, color=[c if a else '#E0E0E0'
                  for c, a in zip(bar_colors, active)],
                  edgecolor='white', linewidth=1.5)
    for bar, vol, a in zip(bars, volumes, active):
        if a and vol > 0:
            ax.text(bar.get_x() + bar.get_width()/2,
                   bar.get_height() * 1.1,
                   f'{vol:.1f} GB', ha='center', va='bottom',
                   fontsize=9, fontweight='bold')
        elif not a:
            ax.text(bar.get_x() + bar.get_width()/2,
                   bar.get_height() * 1.1, 'N/A',
                   ha='center', va='bottom', fontsize=9, color='gray')
    ax.set_ylabel("Communication Volume (GB)", fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_yscale('log')
    ax.set_ylim(0.001, max(volumes_plot) * 5)

plt.suptitle("Communication Volume Per Training Step",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
# 📊 Visualization: display the chart
plt.show()

In [None]:
#@title 🎧 Code Walkthrough: Validator
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_13_validator.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 5.4 The Composition Validator

Before we use a configuration, we must verify it actually works. There are hard constraints that must be satisfied.

**Common pitfall**: Blindly picking parallelism numbers without checking divisibility constraints will result in silent correctness bugs or crashes. Always validate before launching a training run.

The validator checks divisibility constraints (TP divides heads, PP divides layers, EP divides experts), memory capacity, and batch size alignment. It returns a validity flag together with a list of error and warning messages.

In [None]:
def validate_config(
    model: ModelConfig,
    n_dp: int, n_tp: int, n_pp: int, n_ep: int,
    gpu_spec: GPUSpec,
    micro_batch_size: int = 1,
    global_batch_size: int = 1024,
    zero_stage: int = 1
) -> Tuple[bool, List[str]]:
    """
    Validate a parallelism configuration. Returns (is_valid, messages).
    """
    errors = []
    warnings_list = []

    # Check: TP must divide num_heads
    if model.num_heads % n_tp != 0:
        errors.append(
            f"TP={n_tp} does not divide num_heads={model.num_heads}")

    # Check: PP must divide num_layers
    if model.num_layers % n_pp != 0:
        errors.append(
            f"PP={n_pp} does not divide num_layers={model.num_layers}")

    # Check: TP should not exceed GPUs per node
    if n_tp > gpu_spec.gpus_per_node:
        errors.append(
            f"TP={n_tp} exceeds GPUs per node={gpu_spec.gpus_per_node}")

We continue the validator with EP divisibility, memory capacity checks, and batch size alignment.

In [None]:
    # Check: EP must divide num_experts (for MoE)
    if model.num_experts > 1 and model.num_experts % n_ep != 0:
        errors.append(
            f"EP={n_ep} does not divide num_experts={model.num_experts}")

    # Check: total GPUs is reasonable
    total_gpus = n_dp * n_tp * n_pp * n_ep
    if total_gpus > 100000:
        warnings_list.append(
            f"Total GPUs={total_gpus:,} is extremely large")

    # Check: memory fits
    mem = compute_memory_per_gpu(
        model, n_dp, n_tp, n_pp, n_ep,
        micro_batch_size, zero_stage
    )
    if mem['total'] > gpu_spec.memory_gb:
        errors.append(
            f"Memory {mem['total']:.1f} GB exceeds GPU capacity "
            f"{gpu_spec.memory_gb} GB")
    elif mem['total'] > gpu_spec.memory_gb * 0.9:
        warnings_list.append(
            f"Memory {mem['total']:.1f} GB is >90% of GPU capacity")

    # Check: global batch size is divisible
    effective_dp = n_dp
    if global_batch_size % effective_dp != 0:
        warnings_list.append(
            f"Global batch {global_batch_size} not divisible by DP={n_dp}")

Finally we aggregate errors and warnings, then run the validator on both real-world configs.

In [None]:
    is_valid = len(errors) == 0
    messages = ([f"ERROR: {e}" for e in errors] +
                [f"WARNING: {w}" for w in warnings_list])

    if is_valid and not warnings_list:
        messages.append("All checks passed!")

    return is_valid, messages

# Validate Llama 3 405B
print("Validating Llama 3 405B configuration...")
valid, msgs = validate_config(
    llama3_model, n_dp=128, n_tp=8, n_pp=16, n_ep=1,
    gpu_spec=GPU_SPECS["H100_80GB"], global_batch_size=2048
)
for msg in msgs:
    print(f"  {msg}")
print(f"  Valid: {valid}")

print()

# Validate DeepSeek-V3
print("Validating DeepSeek-V3 configuration...")
valid2, msgs2 = validate_config(
    deepseek_v3_model, n_dp=8, n_tp=1, n_pp=8, n_ep=32,
    gpu_spec=GPU_SPECS["H800_80GB"], global_batch_size=512
)
for msg in msgs2:
    print(f"  {msg}")
print(f"  Valid: {valid2}")
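
The divisibility checks in the validator can also be turned around to list which degrees are legal in the first place. This is a minimal sketch, assuming a hypothetical 64-head model on 8-GPU nodes; the helper name `valid_degrees` is illustrative and not part of the notebook.

```python
from typing import List


def valid_degrees(count: int, upper: int) -> List[int]:
    """Return every degree d in [1, upper] that divides `count` evenly."""
    return [d for d in range(1, upper + 1) if count % d == 0]


num_heads = 64       # hypothetical attention head count
gpus_per_node = 8    # hypothetical node size, caps the TP degree
print("Valid TP degrees:", valid_degrees(num_heads, gpus_per_node))
```

The same helper works for PP (pass the layer count) or EP (pass the expert count), which makes it easy to shortlist candidates before calling `validate_config`.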

In [None]:
#@title 🎧 Code Walkthrough: 3d Grid
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_14_3d_grid.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 5.5 3D GPU Grid Visualization

This is where it gets visually satisfying. Let us render a 3D grid showing how GPUs are organized across the three primary spatial dimensions: DP, TP, and PP.

The `visualize_3d_gpu_grid` function creates a 3D scatter plot where each dot is a GPU. TP maps to the x-axis (within-node), PP to y (across stages), and DP to z (replicas). We draw node boundaries and pipeline connections as line overlays.

In [None]:
def visualize_3d_gpu_grid(
    n_dp: int, n_tp: int, n_pp: int,
    title: str = "5D Parallelism GPU Grid",
    max_dp_show: int = 8,
    max_pp_show: int = 8
):
    """
    Visualize the GPU layout as a 3D grid.
    TP = x-axis (within node), PP = y-axis (across stages),
    DP = z-axis (replicas).
    """
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Limit display for readability
    dp_show = min(n_dp, max_dp_show)
    pp_show = min(n_pp, max_pp_show)
    tp_show = n_tp

    # Create grid coordinates
    tp_coords = np.arange(tp_show)
    pp_coords = np.arange(pp_show)
    dp_coords = np.arange(dp_show)

We place each GPU as a scatter point, colored by pipeline stage, and draw NVLink / pipeline connection lines.

In [None]:
    # Color by pipeline stage
    for dp_idx in dp_coords:
        for pp_idx in pp_coords:
            for tp_idx in tp_coords:
                # Color encodes the pipeline stage; opacity fades with DP index
                color_val = pp_idx / max(pp_show - 1, 1)
                alpha = 0.4 + 0.5 * (1 - dp_idx / max(dp_show - 1, 1))

                ax.scatter(tp_idx, pp_idx, dp_idx,
                          c=[plt.cm.viridis(color_val)],
                          s=120, alpha=alpha, edgecolors='black',
                          linewidth=0.5, zorder=5)

    # Draw node boundaries (TP groups)
    for dp_idx in dp_coords:
        for pp_idx in pp_coords:
            xs = [0, tp_show - 1]
            ys = [pp_idx, pp_idx]
            zs = [dp_idx, dp_idx]
            ax.plot(xs, ys, zs, color='#4CAF50', linewidth=1.5,
                   alpha=0.3)

    # Draw pipeline connections
    for dp_idx in dp_coords:
        for tp_idx in tp_coords:
            xs = [tp_idx, tp_idx]
            ys = [0, pp_show - 1]
            zs = [dp_idx, dp_idx]
            ax.plot(xs, ys, zs, color='#FF9800', linewidth=0.8,
                   alpha=0.2)

Finally we label the axes, add a title showing total vs. displayed GPUs, and render the plot.

In [None]:
    ax.set_xlabel(f'Tensor Parallel ({n_tp}x)\n[NVLink, within node]',
                  fontsize=10, labelpad=10)
    ax.set_ylabel(f'Pipeline Parallel ({n_pp}x)\n[InfiniBand, across nodes]',
                  fontsize=10, labelpad=10)
    ax.set_zlabel(f'Data Parallel ({n_dp}x)\n[Cluster-wide]',
                  fontsize=10, labelpad=10)

    # Annotate totals
    total = n_dp * n_tp * n_pp
    shown = dp_show * tp_show * pp_show
    subtitle = (f"Showing {shown} of {total:,} GPUs "
                f"({n_dp}x DP x {n_tp}x TP x {n_pp}x PP)")

    ax.set_title(f"{title}\n{subtitle}", fontsize=13, fontweight='bold',
                 pad=20)

    ax.view_init(elev=25, azim=135)

    plt.tight_layout()
    # 📊 Visualization: display the chart
    plt.show()

# Visualize Llama 3 405B grid
visualize_3d_gpu_grid(
    n_dp=128, n_tp=8, n_pp=16,
    title="Llama 3 405B - 16,384 GPU Grid"
)

Let us also see the DeepSeek-V3 grid shape, which looks very different because it uses EP instead of TP.

In [None]:
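# Visualize DeepSeek-V3 grid (different shape: EP replaces TP)
# The plot has no EP axis, so n_tp=8 stands in for the GPUs within a node;
# the subtitle therefore counts the 8 x 8 x 8 = 512 plotted GPUs
visualize_3d_gpu_grid(
    n_dp=8, n_tp=8, n_pp=8,
    title="DeepSeek-V3 - 2,048 GPU Grid\n(EP=32 maps onto node groups)",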
    max_dp_show=8, max_pp_show=8
)
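
The grid picture corresponds to a simple index calculation. The sketch below assumes ranks are laid out with TP varying fastest, then PP, then DP. That is one common convention, not the only one, and the notebook does not state which order it uses. `rank_to_coords` is a hypothetical helper.

```python
from typing import Tuple


def rank_to_coords(rank: int, n_tp: int, n_pp: int, n_dp: int) -> Tuple[int, int, int]:
    """Map a global rank to (dp_idx, pp_idx, tp_idx), TP fastest-varying."""
    if not 0 <= rank < n_tp * n_pp * n_dp:
        raise ValueError(f"rank {rank} outside grid of {n_tp * n_pp * n_dp}")
    tp_idx = rank % n_tp
    pp_idx = (rank // n_tp) % n_pp
    dp_idx = rank // (n_tp * n_pp)
    return dp_idx, pp_idx, tp_idx


# Llama 3 405B grid from above: DP=128, TP=8, PP=16
for rank in (0, 7, 8, 127, 128):
    print(rank, rank_to_coords(rank, n_tp=8, n_pp=16, n_dp=128))
```

With this ordering, ranks 0 to 7 differ only in their TP index, so a TP group lands on one 8-GPU node and its traffic stays on NVLink.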

In [None]:
#@title 🎧 Before You Start: Todo1 Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_15_todo1_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 6. Your Turn: TODO Sections

Now it is your turn. We have built the building blocks ‚Äî let us see if you can put them together.

### TODO 1: Implement the Parallelism Config Optimizer

Given a model size and GPU count, find the optimal DP / TP / PP split. The heuristic is:
1. Set TP to the maximum that fits within a node (usually 8)
2. Set PP to the minimum needed so the model fits in memory
3. Maximize DP with the remaining GPUs
4. Verify the configuration is valid

In [None]:
def recommend_parallelism(
    model: ModelConfig,
    total_gpus: int,
    gpu_spec: GPUSpec,
    max_tp: int = 8
) -> Optional[ParallelismConfig]:
    """
    Recommend a parallelism configuration.
    Heuristic: Set TP to node size, find min PP for memory,
    maximize DP with remaining GPUs, then validate.

    Args:
        model: Model architecture config
        total_gpus: Total available GPUs
        gpu_spec: GPU hardware specs
        max_tp: Maximum tensor parallelism degree
    Returns:
        ParallelismConfig or None if no valid config found.
    """
    # ============ TODO ============
    # Step 1: Set n_tp = min(max_tp, gpus_per_node)
    # Step 2: Set n_ep from model.num_experts if MoE (capped by GPUs), else 1
    # Step 3: Try n_pp from 1 upward. For each n_pp:
    #         a) Check n_pp divides model.num_layers
    #         b) Compute n_dp = total_gpus // (n_tp * n_pp * n_ep)
    #         c) Check n_dp >= 1 and memory fits
    # Step 4: Return the ParallelismConfig
    # ==============================

    n_tp = min(max_tp, gpu_spec.gpus_per_node)

Complete the MoE expert parallelism setup and the search loop over PP values.

In [None]:
    # For MoE, set EP based on expert count
    if model.num_experts > 1:
        # Try to place experts across available GPUs
        n_ep = min(model.num_experts,
                   total_gpus // (n_tp * 2))  # leave room for PP,DP
        # Ensure n_ep divides num_experts
        while n_ep > 1 and model.num_experts % n_ep != 0:
            n_ep -= 1
    else:
        n_ep = 1

    # YOUR CODE: find the right n_pp and n_dp
    n_pp = ???  # Try values from 1 upward
    n_dp = ???  # Remaining GPUs

    return None  # Replace with your ParallelismConfig

In [None]:
#@title 🎧 Code Walkthrough: Todo1 Solution
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_16_todo1_solution.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


Here is the reference solution. It searches for the minimum PP that satisfies memory constraints, then maximizes DP with the remaining GPUs.

In [None]:
# Verification cell: reference solution
# Run this cell to define the solution, then compare it with your answer

def recommend_parallelism_solution(
    model: ModelConfig,
    total_gpus: int,
    gpu_spec: GPUSpec,
    max_tp: int = 8
) -> Optional[ParallelismConfig]:
    """Reference solution for the config optimizer."""
    n_tp = min(max_tp, gpu_spec.gpus_per_node)

    if model.num_experts > 1:
        n_ep = min(model.num_experts, total_gpus // (n_tp * 2))
        while n_ep > 1 and model.num_experts % n_ep != 0:
            n_ep -= 1
    else:
        n_ep = 1

    best_config = None

    for n_pp in range(1, model.num_layers + 1):
        if model.num_layers % n_pp != 0:
            continue

        remaining = total_gpus // (n_tp * n_pp * n_ep)
        if remaining < 1:
            break

        n_dp = remaining

We check memory and validity for each candidate PP, then return the first valid config (which minimizes PP and maximizes DP).

In [None]:
        # Check memory
        mem = compute_memory_per_gpu(
            model, n_dp, n_tp, n_pp, n_ep,
            micro_batch_size=1, zero_stage=1
        )

        if mem['total'] <= gpu_spec.memory_gb * 0.9:
            # Validate
            valid, _ = validate_config(
                model, n_dp, n_tp, n_pp, n_ep, gpu_spec
            )
            if valid:
                best_config = ParallelismConfig(
                    name=f"{model.name} (recommended)",
                    n_dp=n_dp, n_tp=n_tp, n_pp=n_pp,
                    n_sp=n_tp, n_ep=n_ep,
                    gpu_type=gpu_spec.name,
                    gpu_memory_gb=int(gpu_spec.memory_gb)
                )
                break  # Take first valid (minimizes PP)

    return best_config

Let us test: does our optimizer recover the actual Llama 3 405B configuration?

In [None]:
# Test: does our optimizer recover the Llama 3 405B config?
recommended = recommend_parallelism_solution(
    llama3_model,
    total_gpus=16384,
    gpu_spec=GPU_SPECS["H100_80GB"]
)

if recommended:
    print("Recommended config for Llama 3 405B on 16,384 H100s:")
    print(f"  DP={recommended.n_dp}, TP={recommended.n_tp}, "
          f"PP={recommended.n_pp}, EP={recommended.n_ep}")
    print(f"  Total GPUs: {recommended.total_gpus:,}")

    # Compare to actual
    actual_dp, actual_tp, actual_pp = 128, 8, 16
    match = (recommended.n_dp == actual_dp and
             recommended.n_tp == actual_tp and
             recommended.n_pp == actual_pp)
    if match:
        print("\n  Matches Meta's actual configuration!")
    else:
        print(f"\n  Actual config: DP={actual_dp}, TP={actual_tp}, "
              f"PP={actual_pp}")
        print("  (Heuristic may differ from actual - that is OK)")
else:
    print("No valid configuration found.")
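
The search stops at the first PP degree that fits because every doubling of PP halves the DP degree left over. Here is a small sketch of that trade-off, using the GPU count from the test above and hypothetical power-of-two PP candidates (the real search walks the divisors of `num_layers`).

```python
total_gpus = 16384   # cluster size from the Llama 3 example
n_tp, n_ep = 8, 1    # TP fills one node; dense model, so EP=1

dp_by_pp = {}
for n_pp in (1, 2, 4, 8, 16, 32):  # hypothetical candidate PP degrees
    dp_by_pp[n_pp] = total_gpus // (n_tp * n_pp * n_ep)
    print(f"PP={n_pp:>2} -> DP={dp_by_pp[n_pp]:>4}")
```

Since DP is what scales throughput, the heuristic spends only as much PP as memory demands.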

In [None]:
#@title 🎧 Listen: Think About It
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_17_think_about_it.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### Think About This

Before moving to TODO 2, consider:
- Why did we set TP first and then find PP?
- What would happen if we maximized PP instead of DP?
- Why is TP always limited to within a single node?

**Key insight**: The answer lies in the communication hierarchy. TP requires the most bandwidth (AllReduce every layer), so it **must** sit on the fastest link (NVLink). PP has moderate requirements, and DP can tolerate the most latency. By setting TP first to the node boundary, we automatically respect the hardware hierarchy.
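
To make the hierarchy concrete, here is a rough sketch of how long the same message takes on each link, using the nominal bandwidths quoted in this notebook's cheat sheet (900 GB/s NVLink, 50 GB/s InfiniBand). It ignores latency and protocol overhead, and the 0.5 GB payload is a made-up figure, so treat the numbers as order-of-magnitude only.

```python
LINK_GB_PER_S = {"NVLink": 900.0, "InfiniBand": 50.0}  # nominal figures


def transfer_ms(message_gb: float, link: str) -> float:
    """Bandwidth-only transfer time in milliseconds (latency ignored)."""
    return message_gb / LINK_GB_PER_S[link] * 1000.0


message_gb = 0.5  # hypothetical per-layer activation payload
for link in LINK_GB_PER_S:
    print(f"{link:<11} {transfer_ms(message_gb, link):6.2f} ms")
```

TP pays this cost at every layer, so an 18x slower link would dominate the step time. DP pays it once per step and can overlap it with the backward pass.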

In [None]:
#@title 🎧 Before You Start: Todo2 Mfu
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_18_todo2_mfu.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### TODO 2: Compute Model FLOPS Utilization (MFU)

MFU measures what fraction of theoretical peak FLOPS the training run actually achieves. Llama 3 reports 38-43% MFU. Let us compute it.

The formula:
$$\text{MFU} = \frac{\text{Achieved FLOPS per GPU}}{\text{Peak FLOPS per GPU}}$$

Where achieved FLOPS per GPU depends on:
- Model FLOPS per token: approximately $6 \times P$ (forward + backward, for $P$ parameters)
- Tokens processed per second per GPU
- Pipeline bubble overhead reduces effective throughput

In [None]:
def compute_mfu(
    model: ModelConfig,
    config: ParallelismConfig,
    gpu_spec: GPUSpec,
    tokens_per_second_per_gpu: float,
    bubble_fraction: float = 0.0
) -> float:
    """
    Compute Model FLOPS Utilization.
    Args: model config, parallelism config, GPU specs,
          achieved tokens/sec/GPU, bubble fraction (0-1).
    Returns: MFU as a fraction (0 to 1).
    """
    # ============ TODO ============
    # Step 1: flops_per_token = 6 * total_params
    # Step 2: achieved_flops = flops_per_token_per_gpu * tokens/sec
    #         (divide by TP*PP since each GPU computes its shard)
    # Step 3: effective_flops = achieved * (1 - bubble_fraction)
    # Step 4: MFU = effective_flops / peak_flops
    # ==============================

    mfu = ???  # YOUR CODE HERE
    return mfu

In [None]:
#@title 🎧 Code Walkthrough: Todo2 Solution
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_19_todo2_solution.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


Here is the reference solution. The key subtlety is that each GPU only computes its shard of the model (divided by TP and PP), so `flops_per_token` must be divided accordingly.

In [None]:
# Verification: reference solution and check

def compute_mfu_solution(
    model: ModelConfig,
    config: ParallelismConfig,
    gpu_spec: GPUSpec,
    tokens_per_second_per_gpu: float,
    bubble_fraction: float = 0.0
) -> float:
    """Reference solution for MFU computation."""
    # FLOPS per token for the full model
    flops_per_token = 6 * model.total_params_B * 1e9

    # tokens_per_second_per_gpu is the rate at which tokens pass through this
    # GPU's shard (the per-replica rate). FLOPS per token is for the FULL
    # model, but each GPU only computes its shard, so we divide by (TP * PP)
    flops_per_token_per_gpu = flops_per_token / (config.n_tp * config.n_pp)

    # Achieved FLOPS per GPU
    achieved_flops = flops_per_token_per_gpu * tokens_per_second_per_gpu

    # Account for pipeline bubble
    effective_flops = achieved_flops * (1.0 - bubble_fraction)

    # Peak FLOPS (convert TFLOPS to FLOPS)
    peak_flops = gpu_spec.peak_tflops_fp16 * 1e12

    mfu = effective_flops / peak_flops
    return mfu
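
One way to sanity-check the shard division is to confirm that it agrees with the whole-cluster view of MFU. The sketch below assumes `tokens_per_second_per_gpu` means the rate at which tokens pass through one GPU's shard (the per-replica rate). The throughput is a made-up number and the peak is an assumed H100 dense BF16 figure, both for illustration only.

```python
params = 405e9            # model parameters
n_dp, n_tp, n_pp = 128, 8, 16
cluster_tps = 2.7e6       # hypothetical cluster-wide tokens/sec
peak_flops = 989e12       # assumed H100 dense BF16 peak, FLOPS

# Per-shard view: each GPU does 1/(TP*PP) of the work for every token
# that flows through its model replica.
replica_tps = cluster_tps / n_dp
mfu_shard = (6 * params / (n_tp * n_pp)) * replica_tps / peak_flops

# Whole-cluster view: total model FLOPS over total peak FLOPS.
mfu_cluster = 6 * params * cluster_tps / (n_dp * n_tp * n_pp * peak_flops)

print(f"shard view:   {mfu_shard:.3f}")
print(f"cluster view: {mfu_cluster:.3f}")
```

Both views give the same number. The division by TP * PP is only correct when the token rate is per replica, not cluster throughput divided by the total GPU count.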

Let us test with Llama 3 405B's reported numbers and see if we land near Meta's 38-43% MFU.

In [None]:
# Llama 3 405B: Meta reports roughly 380-430 TFLOPs/GPU, i.e. 38-43% BF16 MFU
n_pp_llama = 16
n_microbatches_llama = 16
bubble_frac_llama = (n_pp_llama - 1) / (n_pp_llama - 1 + n_microbatches_llama)

# Cluster throughput backed out from the reported ~400 TFLOPs/GPU on
# 16,384 GPUs (the paper's "16M" figure is tokens per batch, not per second)
cluster_tps = 16384 * 400e12 / (6 * llama3_model.total_params_B * 1e9)
# compute_mfu_solution divides FLOPS/token by TP*PP, so it needs the rate at
# which tokens pass through one GPU's shard: the per-replica rate
tps_per_gpu = cluster_tps / llama3_405b.n_dp

mfu = compute_mfu_solution(
    llama3_model, llama3_405b, GPU_SPECS["H100_80GB"],
    tokens_per_second_per_gpu=tps_per_gpu,
    bubble_fraction=0.0  # measured throughput already includes the bubble
)

print("Llama 3 405B - MFU Estimation")
print("=" * 45)
print(f"  Pipeline bubble fraction:  {bubble_frac_llama:.1%}")
print(f"  Tokens/sec per replica:    {tps_per_gpu:,.0f}")
print(f"  Estimated MFU:             {mfu:.1%}")
print("  Meta's reported MFU:       38-43%")

if 0.30 <= mfu <= 0.50:
    print("\n  Our estimate is in the right ballpark!")
else:
    print("\n  Our estimate differs from reported - that is expected")
    print(f"  (real MFU depends on many factors we simplified)")
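
The bubble formula used above, $(p - 1) / (p - 1 + m)$ for $p$ stages and $m$ micro-batches, shrinks quickly as $m$ grows. Here is a short sketch with the same 16-stage pipeline and a few hypothetical micro-batch counts, assuming a simple non-interleaved schedule.

```python
def bubble_fraction(n_stages: int, n_microbatches: int) -> float:
    """Idle fraction of a simple (non-interleaved) pipeline schedule."""
    if n_stages <= 1:
        return 0.0
    return (n_stages - 1) / (n_stages - 1 + n_microbatches)


for m in (16, 64, 256):  # hypothetical micro-batch counts
    print(f"m={m:>3}: bubble = {bubble_fraction(16, m):.1%}")
```

Going from 16 to 256 micro-batches cuts the idle fraction from roughly half to about 5%, at the cost of holding more activations in flight.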

In [None]:
#@title 🎧 Listen: Cheat Sheet
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_20_cheat_sheet.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 7. The Big Picture: Summary Table

Let us step back and see all five dimensions in one place. This is the table you should commit to memory.

In [None]:
# Print a comprehensive summary table
print("=" * 95)
print("  THE 5D PARALLELISM CHEAT SHEET")
print("=" * 95)

headers = ["Dimension", "Splits", "Why Needed", "Communication",
           "Placement", "Bandwidth"]
rows = [
    ["DP (Data)", "Training data", "Throughput",
     "AllReduce (grads)", "Entire cluster", "Low OK"],
    ["TP (Tensor)", "Weight matrices", "Layer too large",
     "AllReduce (activs)", "Within node", "NVLink 900 GB/s"],
    ["PP (Pipeline)", "Model depth", "Too many layers",
     "Point-to-Point", "Across nodes", "InfiniBand 50 GB/s"],
    ["SP (Sequence)", "Sequence length", "Long contexts",
     "Reduce-Scatter", "Within node", "NVLink 900 GB/s"],
    ["EP (Expert)", "MoE experts", "Specialist nets",
     "All-to-All", "Flexible", "Medium ~50 GB/s"],
]

Format and print the cheat sheet table along with the key insight.

In [None]:
col_widths = [14, 16, 16, 19, 15, 17]
header_line = " | ".join(h.ljust(w) for h, w in zip(headers, col_widths))
print(f"  {header_line}")
print(f"  {'-' * len(header_line)}")
for row in rows:
    row_line = " | ".join(val.ljust(w) for val, w in zip(row, col_widths))
    print(f"  {row_line}")

print("=" * 95)
print()
print("  Key Insight: Map communication-hungry dimensions to fast links.")
print("  TP (every layer) -> NVLink | PP (per micro-batch) -> InfiniBand")
print("  DP (once per step) -> Cluster-wide")

Now let us build a radar chart comparing the five dimensions across communication frequency, bandwidth requirements, memory savings, implementation complexity, and scalability.

In [None]:
# Visualization: Radar chart comparing the 5 dimensions

fig, ax = plt.subplots(1, 1, figsize=(8, 8),
                        subplot_kw=dict(polar=True))

categories = ['Comm\nFrequency', 'Bandwidth\nNeeded', 'Memory\nSaved',
              'Impl\nComplexity', 'Scalability']
N = len(categories)

# Scores for each dimension (1-5 scale)
dimension_scores = {
    'DP': [1, 1, 2, 1, 5],
    'TP': [5, 5, 4, 3, 2],
    'PP': [3, 3, 4, 4, 3],
    'SP': [5, 5, 3, 3, 2],
    'EP': [3, 3, 2, 4, 4],
}

dim_colors = {
    'DP': '#2196F3', 'TP': '#4CAF50', 'PP': '#FF9800',
    'SP': '#8BC34A', 'EP': '#9C27B0'
}

angles = [n / float(N) * 2 * np.pi for n in range(N)]
angles += angles[:1]  # close the polygon

We plot each dimension as a filled polygon on the radar chart.

In [None]:
for dim, scores in dimension_scores.items():
    values = scores + scores[:1]
    ax.plot(angles, values, 'o-', linewidth=2, label=dim,
           color=dim_colors[dim], markersize=6)
    ax.fill(angles, values, alpha=0.1, color=dim_colors[dim])

ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories, fontsize=10)
ax.set_ylim(0, 5.5)
ax.set_yticks([1, 2, 3, 4, 5])
ax.set_yticklabels(['1', '2', '3', '4', '5'], fontsize=8)
ax.set_title("5D Parallelism - Dimension Characteristics", fontsize=14,
             fontweight='bold', pad=25)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize=10)

plt.tight_layout()
# 📊 Visualization: display the chart
plt.show()

In [None]:
#@title 🎧 Code Walkthrough: Full Planner
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_21_full_planner.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 8. Putting It All Together

We have now studied each parallelism dimension in isolation and seen the math behind composing them. In this section, we bring all five together into a single unified planner that takes any model and hardware setup and produces a complete training plan.

The key principle: **composition is multiplicative for GPUs but hierarchical for communication**. The total GPU count is simply $N_{DP} \times N_{TP} \times N_{PP} \times N_{EP}$, but the communication patterns nest according to hardware topology: TP and SP ride the fastest NVLink bus within a node, PP uses InfiniBand between nearby nodes, and DP spans the entire cluster where latency tolerance is highest.

This is also where we see the power of the heuristic we built in TODO 1: set TP first (to fill the node), then find the minimum PP that fits memory, and give everything left to DP for maximum throughput. The planner below automates this entire workflow, adds communication and efficiency estimation, and presents a complete training recommendation.
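
The multiplicative GPU count from the key principle can be checked directly against the two reference configurations used throughout this notebook. This is a minimal sketch; the dictionary layout is illustrative and not the notebook's `ParallelismConfig`.

```python
from math import prod

configs = {
    "Llama 3 405B": {"DP": 128, "TP": 8, "PP": 16, "EP": 1},
    "DeepSeek-V3": {"DP": 8, "TP": 1, "PP": 8, "EP": 32},
}

# Total GPUs is the product of all parallelism degrees
totals = {name: prod(degrees.values()) for name, degrees in configs.items()}
for name, total in totals.items():
    print(f"{name}: {total:,} GPUs")
```

SP does not appear in the product because it shares the TP group rather than adding GPUs of its own.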

The `full_5d_planner` function ties together every component: it calls `recommend_parallelism_solution` for the config, `compute_memory_per_gpu` for the memory breakdown, and `compute_communication_volume` for the comm analysis.

In [None]:
def full_5d_planner(
    model: ModelConfig,
    gpu_spec: GPUSpec,
    total_gpus: int,
    global_batch_size: int = 1024,
    zero_stage: int = 1,
    micro_batch_size: int = 1,
    verbose: bool = True
) -> Dict:
    """
    The complete 5D parallelism planner. Recommends optimal config
    and reports memory, communication, and efficiency estimates.
    """
    # Step 1: Recommend parallelism config
    config = recommend_parallelism_solution(model, total_gpus, gpu_spec)
    if config is None:
        if verbose:
            print(f"Could not find valid config for {model.name} "
                  f"on {total_gpus} {gpu_spec.name} GPUs")
        return {}

    # Step 2: Compute memory breakdown
    mem = compute_memory_per_gpu(
        model, config.n_dp, config.n_tp, config.n_pp, config.n_ep,
        micro_batch_size, zero_stage
    )

We continue with communication volume estimation, pipeline bubble analysis, and MFU projection.

In [None]:
    # Step 3: Compute communication volumes
    comm = compute_communication_volume(
        model, config.n_dp, config.n_tp, config.n_pp, config.n_ep,
        micro_batch_size, global_batch_size
    )

    # Step 4: Compute bubble fraction and MFU estimate
    n_microbatches = max(1, global_batch_size //
                        (config.n_dp * micro_batch_size))
    bubble_frac = ((config.n_pp - 1) /
                   (config.n_pp - 1 + n_microbatches)
                   if config.n_pp > 1 else 0.0)

    # Rough MFU estimate (35-45% is typical for well-optimized runs)
    estimated_mfu = 0.42 * (1 - bubble_frac) * (1 - 0.05)  # 5% comm overhead

    # Step 5: Estimate training time
    # Total FLOPS for training = 6 * P * T (P=params, T=total tokens)
    total_tokens = 15e12  # 15T tokens (typical for frontier models)
    total_flops = 6 * model.total_params_B * 1e9 * total_tokens

    achieved_flops_total = (total_gpus * gpu_spec.peak_tflops_fp16
                           * 1e12 * estimated_mfu)
    training_time_seconds = total_flops / achieved_flops_total
    training_days = training_time_seconds / 86400

Finally, we assemble the result dictionary and print the complete planner report.

In [None]:
    result = {
        "config": config,
        "memory": mem,
        "communication": comm,
        "bubble_fraction": bubble_frac,
        "estimated_mfu": estimated_mfu,
        "training_days": training_days,
    }

    if verbose:
        print(f"\n{'=' * 65}")
        print(f"  5D PARALLELISM PLANNER - {model.name}")
        print(f"{'=' * 65}")

        print(f"\n  Model: {model.total_params_B:.0f}B params, "
              f"{model.num_layers} layers, "
              f"hidden={model.hidden_dim}, "
              f"heads={model.num_heads}")
        if model.num_experts > 1:
            print(f"  MoE: {model.num_experts} experts, "
                  f"{model.expert_params_B:.1f}B params each")
        print(f"  Hardware: {total_gpus:,} x {gpu_spec.name}")

Print the recommended parallelism split, per-GPU memory breakdown, and efficiency metrics.

In [None]:
        print(f"\n  Recommended Parallelism:")
        print(f"    DP={config.n_dp}, TP={config.n_tp}, PP={config.n_pp}, "
              f"SP={config.n_sp}, EP={config.n_ep}")
        print(f"    Total = {config.total_gpus:,} GPUs "
              f"({config.total_nodes:,} nodes)")

        print(f"\n  Memory Per GPU:")
        print(f"    Weights:      {mem['weights']:>7.2f} GB")
        print(f"    Gradients:    {mem['gradients']:>7.2f} GB")
        print(f"    Optimizer:    {mem['optimizer']:>7.2f} GB")
        print(f"    Activations:  {mem['activations']:>7.2f} GB")
        print(f"    Total:        {mem['total']:>7.2f} GB / "
              f"{gpu_spec.memory_gb:.0f} GB "
              f"({mem['total']/gpu_spec.memory_gb*100:.0f}%)")

        print(f"\n  Efficiency:")
        print(f"    Pipeline bubble:  {bubble_frac:.1%}")
        print(f"    Estimated MFU:    {estimated_mfu:.1%}")
        print(f"    Training time:    ~{training_days:.0f} days "
              f"(on 15T tokens)")
        print(f"\n{'=' * 65}")

    return result
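
The training-time estimate in Step 5 is plain arithmetic, and it helps to see it once with numbers plugged in. The sketch below assumes a 405B-parameter model, the planner's 15T-token budget, an H100 dense BF16 peak of 989 TFLOPS, and a flat 40% MFU. The last two are assumptions for illustration, not values read from `GPU_SPECS`.

```python
params = 405e9          # model parameters
tokens = 15e12          # training tokens, same assumption as the planner
n_gpus = 16384
peak_flops = 989e12     # assumed H100 dense BF16 peak, FLOPS per GPU
mfu = 0.40              # assumed utilization

total_flops = 6 * params * tokens              # 6 * P * T rule of thumb
seconds = total_flops / (n_gpus * peak_flops * mfu)
days = seconds / 86400
print(f"Total compute: {total_flops:.2e} FLOPs")
print(f"Training time: ~{days:.0f} days")
```

Because the planner scales its MFU guess down by the bubble fraction, its own estimate will typically come out longer than this flat-MFU figure.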

In [None]:
#@title 🎧 Code Walkthrough: Planner Results
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_22_planner_results.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 8.1 Running the Planner on Real Models

Now let us run the planner on our two reference models and a custom model to see the recommendations in action.

In [None]:
# Plan for Llama 3 405B
print("=" * 70)
print("  CASE 1: Llama 3 405B on 16,384 H100 GPUs")
print("=" * 70)
result_llama = full_5d_planner(
    llama3_model, GPU_SPECS["H100_80GB"],
    total_gpus=16384, global_batch_size=2048
)

In [None]:
# Plan for DeepSeek-V3
print("=" * 70)
print("  CASE 2: DeepSeek-V3 on 2,048 H800 GPUs")
print("=" * 70)
result_ds = full_5d_planner(
    deepseek_v3_model, GPU_SPECS["H800_80GB"],
    total_gpus=2048, global_batch_size=512
)

Now a custom 70B dense model on a smaller cluster, to show how the planner adapts to different scales.

In [None]:
# Plan for a custom model: 70B dense on 256 GPUs
custom_model = ModelConfig(
    name="Custom 70B Dense",
    total_params_B=70.0,
    num_layers=80,
    hidden_dim=8192,
    num_heads=64,
    num_experts=1,
    expert_params_B=0.0,
    seq_len=8192,
    vocab_size=128000
)

print("=" * 70)
print("  CASE 3: Custom 70B Dense on 256 A100 GPUs")
print("=" * 70)
result_custom = full_5d_planner(
    custom_model, GPU_SPECS["A100_80GB"],
    total_gpus=256, global_batch_size=512
)

In [None]:
#@title 🎧 Code Walkthrough: Final Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_23_final_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 8.2 Final Comparison Visualization

With all three planner results in hand, we create a four-panel comparison: GPU counts, memory utilization, parallelism breakdown, and efficiency metrics.

In [None]:
# Final visualization: Side-by-side comparison
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
models = ['Llama 3\n405B', 'DeepSeek\nV3', 'Custom\n70B']

# --- Panel 1: GPU count comparison ---
ax = axes[0, 0]
gpu_counts = [16384, 2048, 256]
bar_colors_models = ['#2196F3', '#4CAF50', '#FF9800']
bars = ax.bar(models, gpu_counts, color=bar_colors_models,
              edgecolor='white', linewidth=2)
for bar, count in zip(bars, gpu_counts):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 200,
           f'{count:,}', ha='center', fontsize=12, fontweight='bold')
ax.set_ylabel("Total GPUs", fontsize=12)
ax.set_title("GPU Count", fontsize=13, fontweight='bold')

Panel 2 shows per-GPU memory utilization (used vs. available) for each model.

In [None]:
# --- Panel 2: Memory utilization ---
ax = axes[0, 1]
mem_totals = [result_llama['memory']['total'],
              result_ds['memory']['total'],
              result_custom['memory']['total']]
bar_width = 0.35
x = np.arange(3)
ax.bar(x - bar_width/2, mem_totals, bar_width, label='Used',
       color='#F44336', alpha=0.8, edgecolor='white')
ax.bar(x + bar_width/2, [80, 80, 80], bar_width, label='Available',
       color='#E0E0E0', edgecolor='white')
ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=10)
ax.set_ylabel("Memory (GB)", fontsize=12)
ax.set_title("Per-GPU Memory", fontsize=13, fontweight='bold')
ax.legend(fontsize=10)

Panel 3 shows the parallelism degree breakdown as grouped bars on a log scale, since the degrees range from 1 to 128.

In [None]:
# --- Panel 3: Parallelism breakdown (grouped bar) ---
ax = axes[1, 0]
configs = [
    ('Llama 3 405B', 128, 8, 16, 1),
    ('DeepSeek-V3', 8, 1, 8, 32),
    ('Custom 70B', result_custom['config'].n_dp,
     result_custom['config'].n_tp,
     result_custom['config'].n_pp,
     result_custom['config'].n_ep),
]
x = np.arange(3)
dp_vals = [c[1] for c in configs]
tp_vals = [c[2] for c in configs]
pp_vals = [c[3] for c in configs]
ep_vals = [c[4] for c in configs]

w = 0.2
ax.bar(x - 1.5*w, dp_vals, w, label='DP', color='#2196F3')
ax.bar(x - 0.5*w, tp_vals, w, label='TP', color='#4CAF50')
ax.bar(x + 0.5*w, pp_vals, w, label='PP', color='#FF9800')
ax.bar(x + 1.5*w, ep_vals, w, label='EP', color='#9C27B0')
ax.set_xticks(x)
ax.set_xticklabels([c[0] for c in configs], fontsize=10)
ax.set_ylabel("Parallelism Degree", fontsize=12)
ax.set_title("Parallelism Breakdown", fontsize=13, fontweight='bold')
ax.set_yscale('log')
ax.set_ylim(bottom=0.5)  # keep degree-1 bars visible on the log axis
ax.legend(fontsize=10)

The efficiency panel shows how pipeline bubbles and MFU trade off across the three models.

In [None]:
# --- Panel 4: Efficiency metrics ---
ax = axes[1, 1]
results = [result_llama, result_ds, result_custom]
bubbles = [r['bubble_fraction'] for r in results]
mfus = [r['estimated_mfu'] for r in results]

x = np.arange(3)
ax.bar(x - 0.15, [b * 100 for b in bubbles], 0.3,
       label='Bubble %', color='#F44336', alpha=0.7)
ax.bar(x + 0.15, [m * 100 for m in mfus], 0.3,
       label='Est. MFU %', color='#4CAF50', alpha=0.7)
ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=10)
ax.set_ylabel("Percentage (%)", fontsize=12)
ax.set_title("Efficiency Metrics", fontsize=13, fontweight='bold')
ax.legend(fontsize=10)

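# The figure was created in an earlier cell. With the inline backend, pyplot
# stops tracking a figure once its cell finishes, so use `fig` directly.
from IPython.display import display

fig.suptitle("5D Parallelism Planner: Model Comparison",
             fontsize=16, fontweight='bold', y=1.02)
fig.tight_layout()
# 📊 Visualization: display the completed four-panel chart
display(fig)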

Let us also visualize the GPU grid for our custom 70B model configuration.

In [None]:
# Final 3D grid visualization for the custom model
if result_custom.get('config'):
    c = result_custom['config']
    visualize_3d_gpu_grid(
        n_dp=c.n_dp, n_tp=c.n_tp, n_pp=c.n_pp,
        title=f"Custom 70B Dense: {c.total_gpus} GPU Grid"
    )

In [None]:
#@title 🎧 Wrap-Up: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_24_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


And now, the moment we have been building toward for 6 notebooks:

## 🎯 Final Output

This is the culmination of the entire 6-notebook series. The cell below prints the closing message; the key numbers, the communication hierarchy, and the golden rule of 5D parallelism are summarized in the final cell of this notebook.

In [None]:
# The grand finale print
print()
print("=" * 70)
print()
print("  Congratulations!")
print()
print("  You now understand how modern LLMs are trained")
print("  across thousands of GPUs.")
print()
print("  You have built every parallelism dimension from scratch,")
print("  from Data Parallelism to the full 5D grid.")
print()
print("  The next time you read that a model was trained on")
print("  16,000 GPUs, you know exactly what is happening.")
print()
print("=" * 70)
print()
print("  Series Complete: 5D Parallelism from Scratch")
print("  Notebooks 1-6 by Vizuara")
print()
print("=" * 70)

## 9. Reflection and Next Steps

### Reflection Questions

1. **Why is TP always used within a node and not across nodes?**
   Because TP requires AllReduce communication *within every single layer*. If the interconnect were slow (like InfiniBand or Ethernet across nodes), the communication overhead would dominate compute time. NVLink provides 900 GB/s within a node, roughly 18x faster than InfiniBand NDR at about 50 GB/s. That speed difference is why TP is confined to a node in practice.

2. **If you had 256 GPUs and a 70B dense model (not MoE), what parallelism config would you choose?**
   Run the planner above! The likely answer: TP=8 (one node), PP=2 or PP=4 (split the 80 layers), DP = 256 / (8 x PP). With PP=4 this gives DP=8, so 8 x 4 x 8 = 256 GPUs in total. The exact split depends on your memory constraints and desired batch size.

3. **What are the trade-offs between using more PP stages vs more DP replicas?**
   More PP stages mean lower per-GPU memory (fewer layers per GPU) but a larger pipeline bubble, $\frac{P-1}{P-1+M}$ for $P$ stages and $M$ micro-batches. More DP replicas mean higher throughput and a larger effective batch size, but plain DP does not reduce the number of layers each GPU must hold, so the model still has to fit in memory. The art is finding the sweet spot where memory fits and the bubble is small.
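
The second and third answers can be made concrete with a small sketch. It assumes TP=8, a global batch of 512 sequences, and a micro-batch size of 1, so each DP replica runs M = 512 / DP micro-batches per step. The helper names `bubble_fraction` and `enumerate_pp_dp` are illustrative and not part of the planner above, and memory is deliberately not modeled.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Pipeline bubble fraction (P - 1) / (P - 1 + M)."""
    return (num_stages - 1) / (num_stages - 1 + num_microbatches)


def enumerate_pp_dp(total_gpus=256, tp=8, num_layers=80,
                    global_batch_size=512, micro_batch_size=1):
    """List (pp, dp, layers_per_stage, microbatches, bubble) for a dense model.

    Assumption: every DP replica processes global_batch_size / dp sequences
    per step, split into micro-batches of micro_batch_size sequences.
    """
    options = []
    for pp in (1, 2, 4, 8, 16):
        # Skip splits that do not divide the GPUs or the layers evenly.
        if total_gpus % (tp * pp) or num_layers % pp:
            continue
        dp = total_gpus // (tp * pp)
        if global_batch_size % (dp * micro_batch_size):
            continue
        m = global_batch_size // (dp * micro_batch_size)
        options.append((pp, dp, num_layers // pp, m, bubble_fraction(pp, m)))
    return options


for pp, dp, layers, m, bubble in enumerate_pp_dp():
    print(f"TP=8  PP={pp:>2}  DP={dp:>2}  layers/stage={layers:>2}  "
          f"micro-batches={m:>3}  bubble={bubble:.1%}")
```

Under these assumptions the bubble stays in single digits even at PP=16, because shrinking DP raises the number of micro-batches per replica. The binding constraint is then per-GPU memory, which this sketch ignores and the planner accounts for.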

### Optional Challenges

1. **Gradient Checkpointing**: Modify the `compute_memory_per_gpu` function to support a `checkpointing_ratio` parameter that controls what fraction of layers are checkpointed. How does this change the recommended PP?

2. **Cost Estimator**: Add a cloud cost estimator to the planner. H100 instances cost approximately $3.50/GPU/hour on-demand. How much would training Llama 3 405B cost? (Spoiler: it is measured in millions of dollars.)

3. **Research Deep Dive**: Compare 3D parallelism between Megatron-LM and DeepSpeed. What are the key differences in their PP schedules? How does Megatron's interleaved 1F1B schedule compare to DeepSpeed's pipeline schedule, and to zero-bubble pipeline schedules?
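
As a possible starting point for the first challenge, here is a minimal sketch of how a `checkpointing_ratio` could enter an activation-memory estimate. It is a standalone helper, not the notebook's `compute_memory_per_gpu`. The per-layer formula is a commonly cited estimate for 16-bit transformer activations under tensor parallelism (without sequence parallelism or fused attention) and is an assumption here; the notebook's own estimate may differ.

```python
def activation_bytes_per_layer(seq_len, micro_batch, hidden_dim, num_heads,
                               tp=1, checkpointed=False):
    """Approximate 16-bit activation bytes stored for one transformer layer.

    Assumed estimate (s = seq_len, b = micro_batch, h = hidden_dim,
    a = num_heads, t = tp):
      full storage : s * b * h * (10 + 24 / t + 5 * a * s / (h * t))
      checkpointed : 2 * s * b * h   (only the layer input is kept)
    """
    sbh = seq_len * micro_batch * hidden_dim
    if checkpointed:
        return 2 * sbh
    return sbh * (10 + 24 / tp + 5 * num_heads * seq_len / (hidden_dim * tp))


def stage_activation_gb(layers_per_stage, checkpointing_ratio, seq_len=8192,
                        micro_batch=1, hidden_dim=8192, num_heads=64, tp=8):
    """Stored activations (GB) for one pipeline stage and one micro-batch."""
    if not 0.0 <= checkpointing_ratio <= 1.0:
        raise ValueError("checkpointing_ratio must be in [0, 1]")
    n_ckpt = round(layers_per_stage * checkpointing_ratio)
    n_full = layers_per_stage - n_ckpt
    full = activation_bytes_per_layer(seq_len, micro_batch, hidden_dim,
                                      num_heads, tp=tp)
    ckpt = activation_bytes_per_layer(seq_len, micro_batch, hidden_dim,
                                      num_heads, tp=tp, checkpointed=True)
    return (n_full * full + n_ckpt * ckpt) / 1e9


# Hypothetical 70B-like stage: 20 layers per stage (80 layers, PP=4).
for ratio in (0.0, 0.5, 1.0):
    print(f"checkpointing_ratio={ratio:.1f}: "
          f"{stage_activation_gb(20, ratio):.1f} GB per micro-batch")
```

The sketch only counts stored activations. It leaves out the temporary peak while a checkpointed layer is recomputed during the backward pass, and the extra compute that recomputation costs, both of which a fuller answer to the challenge would need to weigh.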

In [None]:
# Quick cost estimate for Llama 3 405B
cost_per_gpu_hour = 3.50  # USD, H100 on-demand
num_gpus = 16384
training_days = 54  # approximate
training_hours = training_days * 24

total_cost = cost_per_gpu_hour * num_gpus * training_hours
print("Estimated training cost for Llama 3 405B:")
print(f"  {num_gpus:,} GPUs x {training_hours:,} hours x "
      f"${cost_per_gpu_hour:.2f}/GPU/hr")
print(f"  = ${total_cost:,.0f}")
print(f"  = ~${total_cost / 1e6:.1f} million")
print("\nFor context, DeepSeek-V3 reportedly cost ~$5.5M,")
print(f"roughly {total_cost / 5.5e6:.0f}x less than this Llama 3 405B estimate.")

### Series Recap: 5D Parallelism from Scratch

Congratulations on completing all 6 notebooks! Here is what we covered:

| Notebook | Topic | Key Takeaway |
|----------|-------|-------------|
| **1** | Why Parallelism | A 7B model needs 112 GB for training, so a single GPU is not enough |
| **2** | Data Parallelism & ZeRO | Replicate model, split data, AllReduce gradients. ZeRO eliminates redundancy |
| **3** | Tensor Parallelism | Split weight matrices column-wise and row-wise. Needs NVLink |
| **4** | Pipeline Parallelism | Split layers into stages. Micro-batching reduces the bubble |
| **5** | Sequence & Expert Parallelism | SP splits the sequence (shares TP group). EP routes tokens to specialist experts |
| **6** | The 5D Grid | All 5 compose: N = DP x TP x PP x EP. Match dimensions to the hardware hierarchy |

The core lesson: **every parallelism dimension solves a specific bottleneck**, and the art of distributed training is composing them to match the communication hierarchy of your hardware.
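
To put rough numbers on that hierarchy, the sketch below estimates one TP AllReduce over NVLink versus InfiniBand with a bandwidth-only ring model. The 900 GB/s and 50 GB/s figures are the ones used in this notebook. The message size (one micro-batch of 8,192 tokens at hidden size 8,192 in 16-bit precision) and the helper `ring_allreduce_seconds` are assumptions for illustration, and latency and compute overlap are ignored.

```python
def ring_allreduce_seconds(message_bytes, num_ranks, bandwidth_gb_per_s):
    """Bandwidth-only ring AllReduce time.

    Each rank sends and receives 2 * (N - 1) / N of the message.
    """
    if num_ranks == 1:
        return 0.0
    traffic = 2 * (num_ranks - 1) / num_ranks * message_bytes
    return traffic / (bandwidth_gb_per_s * 1e9)


# Activation tensor for one micro-batch in 16-bit precision (2 bytes/element).
seq_len, hidden_dim, micro_batch = 8192, 8192, 1
message_bytes = micro_batch * seq_len * hidden_dim * 2

# Link bandwidths in GB/s, taken from the hierarchy used in this notebook.
LINKS_GB_PER_S = {"NVLink": 900, "InfiniBand NDR": 50}

for name, bw in LINKS_GB_PER_S.items():
    t = ring_allreduce_seconds(message_bytes, num_ranks=8,
                               bandwidth_gb_per_s=bw)
    print(f"TP=8 AllReduce over {name:<15}: {t * 1e3:.2f} ms per call")
```

Because the model is bandwidth-only, the ratio between the two links equals the bandwidth ratio (18x). TP pays this cost in every layer, while DP pays a comparable cost only once per step, which is why TP gets the fastest link.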

In [None]:
# Final summary: one last look at the numbers

print("5D PARALLELISM: THE NUMBERS THAT MATTER")
print("=" * 60)
print()
print("  Llama 3 405B:")
print("    128 DP x 8 TP x 16 PP = 16,384 H100 GPUs")
print("    ~38-43% MFU | ~54 days training | ~$74M at $3.50/GPU-hr")
print()
print("  DeepSeek-V3 (671B, 256 experts):")
print("    8 DP x 1 TP x 8 PP x 32 EP = 2,048 H800 GPUs")
print("    ~$5.5M training cost (remarkably efficient)")
print()
print("  Communication Hierarchy:")
print("    NVLink  900 GB/s  ->  TP, SP  (within node)")
print("    IB NDR   50 GB/s  ->  PP, EP  (across nodes)")
print("    Cluster  ~25 GB/s ->  DP      (everywhere)")
print()
print("  The Golden Rule:")
print("  Map communication-hungry ops to fast links.")
print("  TP every layer -> NVLink")
print("  PP every micro-batch -> InfiniBand")
print("  DP once per step -> anything")
print()
print("=" * 60)
print("  Thank you for learning with Vizuara!")
print("=" * 60)