# Day 24: GPipe - Efficient Training of Giant Neural Networks

> Huang et al. (2018/2019) - [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)

### What You'll Learn:
1. **Model Partitioning**: How to logically split a sequential model across $K$ stages.
2. **Micro-batching**: The mechanics of slicing mini-batches to minimize the 'Pipeline Bubble'.
3. **Pipeline Scheduling**: Visualizing the flow of data through stages (Section 3.1).
4. **Efficiency Math**: Calculating theoretical bubble overhead vs. empirical timing.
5. **Re-materialization**: Implementing activation checkpointing to break the Memory Wall (Section 3.2).
6. **Gradient Equivalence**: Verifying that synchronous updates match standard mini-batch gradients up to floating-point error.
7. **Memory Scaling**: Benchmarking peak memory reduction on simulated 'giant' networks.

## 1. Setup & Environment
We use PyTorch for the core logic and Matplotlib for visualizing the pipeline schedule.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import time
from implementation import GPipe, get_peak_memory, summarize_results
from visualization import plot_pipeline_schedule

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
print(f"Experiment running on: {device.type.upper()}")

## 2. The Micro-batching Logic (Section 3.1)
In GPipe, a partition never receives the whole mini-batch at once. A mini-batch of size $N$ is sliced into $M$ micro-batches of size $N/M$, which flow through the partitions one after another.
Note how `torch.chunk` slices along the batch dimension (dim 0).

In [None]:
batch_size = 128
n_microbatches = 8
hidden_dim = 512

mini_batch = torch.randn(batch_size, hidden_dim)
micro_batches = torch.chunk(mini_batch, n_microbatches, dim=0)

print(f"Full Batch:  {list(mini_batch.shape)}")
print(f"Micro-batch: {list(micro_batches[0].shape)}")
print(f"Total MBs:   {len(micro_batches)}")
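
The example above uses a batch size that divides evenly by $M$. As a small sketch of what happens otherwise (the sizes 130 and 10 are arbitrary choices for illustration), `torch.chunk` uses a chunk size of $\lceil N/M \rceil$, leaves the remainder in the last chunk, and may return fewer chunks than requested:

```python
import torch

# 130 samples do not divide evenly into 8 micro-batches.
batch = torch.randn(130, 512)
chunks = torch.chunk(batch, 8, dim=0)
sizes = [c.shape[0] for c in chunks]
print(sizes)  # chunk size is ceil(130 / 8) = 17; the last chunk holds the remainder

# With a very small batch, fewer chunks than requested can come back.
tiny = torch.chunk(torch.arange(10), 6, dim=0)
print(len(tiny))  # chunk size ceil(10 / 6) = 2, so only 5 chunks exist
```

A pipeline that assumes exactly $M$ micro-batches should therefore read `len(chunks)` rather than trust the requested count.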

## 3. Visualizing the Pipeline Schedule
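Idle time (the 'bubble') occurs while the pipeline fills at the start and drains at the end.
With $K$ partitions and $M$ micro-batches, the idle fraction of an ideal schedule is $(K-1)/(M+K-1)$, so we want $M$ to be much larger than $K$. The paper reports the overhead as negligible when $M \geq 4K$.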
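
To make the trade-off concrete, here is a minimal sketch that tabulates the bubble for a fixed number of partitions. The helper name `bubble_fraction` is made up for this notebook, and the formula assumes an ideal schedule in which every stage takes the same time per micro-batch:

```python
def bubble_fraction(n_partitions: int, n_microbatches: int) -> float:
    """Idle share of an ideal schedule: (K - 1) / (M + K - 1)."""
    k, m = n_partitions, n_microbatches
    return (k - 1) / (m + k - 1)

k = 4
for m in (1, 4, 8, 16, 32):
    idle = bubble_fraction(k, m)
    print(f"K={k}, M={m:>2}: bubble {idle:.1%}, utilization {1 - idle:.1%}")
```

With a single micro-batch the pipeline degenerates to naive model parallelism, where only one of the $K$ devices is busy at any time.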

In [None]:
K, M = 4, 16
print(f"Theoretical Efficiency: {M / (M + K - 1):.2%}")

plot_pipeline_schedule(n_partitions=K, n_microbatches=M, output_dir='plots')
# With K=4 and M=16, the resulting Gantt chart shows devices active about 84% of the time (16/19).
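
If the plotting helper is not available, the same schedule can be sketched as text. This is an illustrative stand-in, not the `plot_pipeline_schedule` implementation: it covers the forward pass only and assumes every stage takes one time step per micro-batch, so stage $k$ handles micro-batch $m$ at step $m + k$.

```python
def forward_schedule(n_partitions: int, n_microbatches: int) -> list:
    """Return one row per stage; each cell is a micro-batch index or '.' for idle."""
    n_steps = n_microbatches + n_partitions - 1
    rows = []
    for stage in range(n_partitions):
        row = []
        for step in range(n_steps):
            mb = step - stage  # stage k works on micro-batch m at step m + k
            row.append(str(mb) if 0 <= mb < n_microbatches else ".")
        rows.append(row)
    return rows

rows = forward_schedule(n_partitions=4, n_microbatches=6)
for stage, row in enumerate(rows):
    print(f"stage {stage}: " + " ".join(f"{cell:>2}" for cell in row))

busy = sum(cell != "." for row in rows for cell in row)
total = sum(len(row) for row in rows)
print(f"utilization: {busy / total:.2%}")  # equals M / (M + K - 1)
```

The dots in the top-right and bottom-left corners of the printout are the bubble: later stages wait for the pipeline to fill, and earlier stages sit idle while it drains.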

## 4. Re-materialization: Breaking the Memory Wall
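Standard training keeps every layer's activations for the backward pass, which costs $O(N \times L)$ memory for batch size $N$ and $L$ layers.
With re-materialization, each partition keeps only its boundary input and recomputes the remaining activations during the backward pass, which reduces peak activation memory per device to $O(N + \frac{L}{K} \times \frac{N}{M})$.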

We build a 'Giant' 60-layer model to test this.

In [None]:
def build_giant_model(layers=60, dim=1024):
    return nn.Sequential(*[nn.Linear(dim, dim) for _ in range(layers)])

giant_model = build_giant_model()

# 1. GPipe WITHOUT Checkpointing
model_no_ckpt = GPipe(giant_model, n_partitions=4, n_microbatches=4, use_checkpoint=False).to(device)

# 2. GPipe WITH Checkpointing (The Paper's Proposed Method)
model_with_ckpt = GPipe(giant_model, n_partitions=4, n_microbatches=4, use_checkpoint=True).to(device)

print("Models ready for benchmarking.")
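
What `use_checkpoint=True` does inside `GPipe` is not shown in this notebook. As a sketch of the general technique, assuming a recent PyTorch where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant`, a single hypothetical partition could be re-materialized like this. The layer sizes are arbitrary:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# One hypothetical "partition": a small stack of layers.
partition = nn.Sequential(
    *[nn.Sequential(nn.Linear(32, 32), nn.ReLU()) for _ in range(4)]
)
sample = torch.randn(8, 32)

# Plain forward/backward: autograd keeps every intermediate activation.
partition(sample).sum().backward()
plain_grads = [p.grad.clone() for p in partition.parameters()]
partition.zero_grad()

# Checkpointed forward/backward: only the partition input is saved;
# intermediates are recomputed during the backward pass.
checkpoint(partition, sample, use_reentrant=False).sum().backward()
ckpt_grads = [p.grad.clone() for p in partition.parameters()]

max_diff = max((a - b).abs().max().item() for a, b in zip(plain_grads, ckpt_grads))
print(f"max gradient difference: {max_diff:.2e}")
```

The gradients agree because checkpointing changes only when activations are computed, not what is computed. The price is one extra forward pass per partition during the backward pass.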

## 5. Empirical Memory Benchmark
We measure peak memory during a forward+backward pass. 
*Note: Results are most visible on GPU devices.*

In [None]:
x = torch.randn(64, 1024).to(device)
y = torch.randn(64, 1024).to(device)
criterion = nn.MSELoss()

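def run_step(model, data, target):
    # Both wrappers are built from the same giant_model, so clear any
    # gradients left by a previous run before measuring.
    model.zero_grad(set_to_none=True)
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    out = model(data)
    loss = criterion(out, target)
    loss.backward()
    return get_peak_memory()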

mem_no = run_step(model_no_ckpt, x, y)
mem_yes = run_step(model_with_ckpt, x, y)

print(f"Peak Memory (No Checkpoint):   {mem_no:.2f} MB")
print(f"Peak Memory (With Checkpoint): {mem_yes:.2f} MB")
print(f"Reduction Factor:               {mem_no/mem_yes if mem_yes > 0 else 0:.2f}x")
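
The body of `get_peak_memory` lives in `implementation` and is not shown here. A hypothetical stand-in, assuming it reports the CUDA allocator's peak in megabytes, might look like the sketch below. It also shows why a CPU-only run would have nothing to report:

```python
import torch

def peak_memory_mb() -> float:
    """Hypothetical stand-in for `get_peak_memory`: peak CUDA allocation in MB."""
    if not torch.cuda.is_available():
        return 0.0  # the CUDA allocator counters are unavailable on CPU
    torch.cuda.synchronize()  # wait for queued kernels before reading the counter
    return torch.cuda.max_memory_allocated() / (1024 ** 2)

print(f"Peak so far: {peak_memory_mb():.2f} MB")
```

Pairing this with `torch.cuda.reset_peak_memory_stats()` before each run, as `run_step` does, keeps one measurement from leaking into the next.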

## 6. Proving Gradient Equivalence
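One of GPipe's strongest claims is that its synchronous update produces the same gradients as training without pipelining, regardless of the number of partitions.
We check this numerically by comparing gradients of a GPipe model against a standard Sequential model. Small differences from floating-point summation order are expected.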

In [None]:
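import copy

# Create two identical models: the GPipe wrapper gets its own deep copy, so both
# start from the same weights without sharing gradient buffers.
base_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
seq_model = base_model.to(device)
piped_model = GPipe(copy.deepcopy(base_model), n_partitions=2, n_microbatches=4).to(device)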

test_input = torch.randn(16, 10).to(device)

# Forward / Backward on Sequential
seq_out = seq_model(test_input)
seq_out.sum().backward()
seq_grads = [p.grad.clone() for p in seq_model.parameters()]

# Reset
for p in seq_model.parameters(): p.grad.zero_() 

# Forward / Backward on GPipe
piped_out = piped_model(test_input)
piped_out.sum().backward()
piped_grads = [p.grad.clone() for p in piped_model.parameters()]

# Compare
diff = sum(torch.norm(s - p) for s, p in zip(seq_grads, piped_grads))
print(f"Total Gradient Difference: {diff.item():.2e}")
assert diff < 1e-5, "Gradients do not match!"
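
The reason the check above can pass is that gradients accumulated over micro-batches add up to the full-batch gradient. A sketch in plain PyTorch, independent of the `GPipe` wrapper and using made-up variable names, shows this directly:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
reference = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
accumulated = copy.deepcopy(reference)  # same weights, separate gradient buffers
inputs = torch.randn(16, 10)

# Full-batch gradient.
reference(inputs).sum().backward()

# Micro-batch gradients: .backward() adds into .grad, so the loop accumulates.
for micro in torch.chunk(inputs, 4, dim=0):
    accumulated(micro).sum().backward()

gap = sum(
    torch.norm(p.grad - q.grad).item()
    for p, q in zip(reference.parameters(), accumulated.parameters())
)
print(f"accumulated vs. full-batch gradient gap: {gap:.2e}")
```

This works without rescaling because the loss is a sum. With a mean-reduced loss such as `nn.MSELoss()`, each micro-batch loss would need to be scaled by its share of the mini-batch (here $1/M$ for equal-sized chunks) before calling `backward()`.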

## 7. Operational Summary & Best Practices

1. **Partition Balancing**: If Stage 2 is 2x slower than Stage 1, the whole pipeline throttles to Stage 2's speed.
2. **Communication Frequency**: Increasing $M$ improves utilization but means more, smaller activation transfers across partition boundaries, and very small micro-batches can underutilize each device.
3. **Synchronous Pipelining**: GPipe's synchronous nature makes it reliable for research, unlike asynchronous methods that might introduce 'Stale Gradient' noise.
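
The partition-balancing point can be illustrated with a back-of-the-envelope model. The sketch below assumes a forward-only pipeline with fixed per-stage times and free buffering between stages. The helper name and the timings are invented for illustration:

```python
def pipeline_makespan(stage_times, n_microbatches):
    """Ideal makespan: fill the pipeline once, then the slowest stage sets the pace."""
    return sum(stage_times) + (n_microbatches - 1) * max(stage_times)

# Both layouts do 4.0 time units of work per micro-batch.
balanced = pipeline_makespan([1.0, 1.0, 1.0, 1.0], n_microbatches=16)
skewed = pipeline_makespan([0.5, 2.0, 0.75, 0.75], n_microbatches=16)
print(f"balanced: {balanced:.1f}  skewed: {skewed:.1f}")
```

Even though the total work is identical, the skewed layout takes much longer because every micro-batch after the first waits on the 2.0-unit stage.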