## Introduction

This tutorial shows how to use PyTorch's fully-sharded data parallel (FSDP) implementation. We train two models: (i) a shallow model, and (ii) a deep model. We first train both models locally (no distributed/FSDP training) on a standard benchmark dataset (we chose MNIST for simplicity). Then, we show the code changes necessary to use FSDP in PyTorch. In both local and distributed training, we log the time and memory, as well as the losses. We will observe how FSDP training reduces both training time and memory consumption compared to local training.

## Local training

We start off with local training. Please make sure you have installed the following before proceeding:
- `torch` >= 2.0 with CUDA
- `torchvision`
- `tqdm`

These should already be installed if you use the conda environment from `environment_312.yml` (instructions in the README).

Additionally, the distributed training shown in this tutorial requires at least 2 NVIDIA GPUs on your machine. You can check this with `nvidia-smi`.

In [1]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms # for datasets
from tqdm import tqdm # for progress bar

First, we define our models. We will be using
- a shallow model, ~1.2M params
- a deep model, ~94M params

In [2]:
class ShallowNet(nn.Module):
  def __init__(self):
    super(ShallowNet, self).__init__()
    self.conv1 = nn.Conv2d(1, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 64, 3, 1)
    self.dropout1 = nn.Dropout(0.25)
    self.dropout2 = nn.Dropout(0.5)
    self.fc1 = nn.Linear(9216, 128)
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    x = F.max_pool2d(x, 2)
    x = self.dropout1(x)
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.dropout2(x)
    x = self.fc2(x)
    output = F.log_softmax(x, dim=1)
    return output
  
class DeepNet(nn.Module):
  def __init__(self):
    super(DeepNet, self).__init__()
    self.conv1 = nn.Conv2d(1, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 64, 3, 1)
    self.dropout1 = nn.Dropout(0.25)
    self.dropout2 = nn.Dropout(0.5)
    self.fc1 = nn.Linear(9216, 9000)
    self.fc1a = nn.Linear(9000, 1000)
    self.fc1b = nn.Linear(1000, 1000)
    self.fc1c = nn.Linear(1000, 1000)
    self.fc1d = nn.Linear(1000, 128)
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    x = F.max_pool2d(x, 2)
    x = self.dropout1(x)
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.dropout2(x)
    x = self.fc1a(x)
    x = self.fc1b(x)
    x = self.fc1c(x)
    x = self.fc1d(x)
    x = self.fc2(x)
    output = F.log_softmax(x, dim=1)
    return output
  
# helper function
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
  
print("# params:")
print(f"  shallow: {count_params(ShallowNet())}")
print(f"  deep: {count_params(DeepNet())}")

# params:
  shallow: 1199882
  deep: 94104234
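
The counts printed above can be reproduced by hand, which also shows where the `9216` input size of `fc1` comes from. The following is a minimal sketch in plain Python; the helper names `conv2d_params` and `linear_params` are hypothetical and not part of PyTorch.

```python
# Hand-count of trainable parameters for ShallowNet and DeepNet.
# Helper names are illustrative only.

def conv2d_params(in_ch, out_ch, kernel):
    # weight tensor (out_ch x in_ch x k x k) plus one bias per output channel
    return out_ch * in_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    # weight matrix plus one bias per output feature
    return in_features * out_features + out_features

# 28x28 MNIST image -> 26x26 after conv1 -> 24x24 after conv2 -> 12x12 after 2x2 max-pool
flat_features = 64 * 12 * 12  # input size of fc1

# both models share the same two convolutional layers
conv_total = conv2d_params(1, 32, 3) + conv2d_params(32, 64, 3)

shallow_total = conv_total + linear_params(flat_features, 128) + linear_params(128, 10)

# fc1 -> fc1a -> fc1b -> fc1c -> fc1d -> fc2
deep_fc_sizes = [flat_features, 9000, 1000, 1000, 1000, 128, 10]
deep_total = conv_total + sum(
    linear_params(n_in, n_out) for n_in, n_out in zip(deep_fc_sizes, deep_fc_sizes[1:])
)

print(flat_features)   # 9216
print(shallow_total)   # 1199882
print(deep_total)      # 94104234
```

Almost all of the deep model's parameters sit in `fc1` (9216 x 9000 weights), which is why it is so much larger than the shallow model despite sharing the same convolutional layers.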


We will use the standard MNIST benchmark dataset. This choice is arbitrary and made for simplicity.

In [3]:
_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
  ])

train_data = datasets.MNIST('../data', train=True, download=True, transform=_transform)
test_data = datasets.MNIST('../data', train=False, transform=_transform)

train_data, test_data 

(Dataset MNIST
     Number of datapoints: 60000
     Root location: ../data
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
                Normalize(mean=(0.1307,), std=(0.3081,))
            ),
 Dataset MNIST
     Number of datapoints: 10000
     Root location: ../data
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
                Normalize(mean=(0.1307,), std=(0.3081,))
            ))

Now we construct our dataloaders. Notice how we set `'persistent_workers': True`. If we set this to `False`, our training time massively slows down on Allen HPC.

Additionally, one can optimize their per-epoch training time by tuning `num_workers`, but we do not go into that in this tutorial.

In [4]:
train_batch_size = 256
test_batch_size = 1000


loader_kwargs = {
    'num_workers': 2,
    'pin_memory': True,
    'shuffle': False,
    'drop_last': True,
    'persistent_workers': True, # warning: on some SLURM clusters, disabling this massively slows down training
}

train_loader = DataLoader(train_data, batch_size=train_batch_size, **loader_kwargs)
test_loader = DataLoader(test_data, batch_size=test_batch_size, **loader_kwargs)

next(iter(train_loader)), next(iter(test_loader))

([tensor([[[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            ...,
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],
  
  
          [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            ...,
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],
  
  
          [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242

Finally, our `train()` and `test()` functions. These functions compute the average per-sample loss over all batches of data. The `train()` function additionally records the memory allocated, if training on an NVIDIA GPU. Lastly, we use a `tqdm` progress bar during training to monitor our training time.

In [5]:
def train(model, train_loader, opt, device):
  total_loss = 0
  num_batches = 0

  model.train()
  for (data, target) in tqdm(train_loader, total=len(train_loader)):
    data, target = data.to(device), target.to(device)
    opt.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target, reduction='sum')
    memory = torch.cuda.memory_allocated(device) / 1e6 if torch.cuda.is_available() else float('nan') # memory in MB; NaN (rather than a string) so the ':.2f' log format still works without CUDA
    loss.backward()
    opt.step()
    total_loss += loss.item()
    num_batches += 1

  avg_loss = total_loss / (num_batches * train_loader.batch_size) # average loss per datapoint
  return {
    'avg_loss': avg_loss,
    'memory': memory
  }

def test(model, test_loader, device):
  total_loss = 0
  num_batches = 0
  total_correct = 0
  num_datapoints = 0

  model.eval()
  with torch.no_grad():
    for (data, target) in tqdm(test_loader, total=len(test_loader)):
      data, target = data.to(device), target.to(device)
      output = model(data)
      total_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.argmax(dim=1, keepdim=True)
      total_correct += pred.eq(target.view_as(pred)).sum().item()
      num_datapoints += len(data)
      num_batches += 1
  
  avg_loss = total_loss / (num_batches * test_loader.batch_size) # average loss per datapoint
  accuracy = total_correct / num_datapoints
  return {
    'avg_loss': avg_loss, 
    'total_correct': total_correct, 
    'num_datapoints': num_datapoints, 
    'accuracy': accuracy
    }

First, we train the shallow model.

In [6]:
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
lr = 0.001

shallow_model = ShallowNet().to(device)
shallow_optimizer = optim.Adam(shallow_model.parameters(), lr=lr)

print(f"> training shallow model on {device}\n")
for epoch in range(epochs):
  start_time = time.time()
  train_losses = train(shallow_model, train_loader, shallow_optimizer, device)
  test_losses = test(shallow_model, test_loader, device)
  total_time = time.time() - start_time
  print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']:.2f} MB | time: {total_time:.2f}s")
  print('-' * 100)

> training shallow model on cuda



100%|██████████| 234/234 [00:04<00:00, 53.89it/s]
100%|██████████| 10/10 [00:00<00:00, 11.56it/s]


epoch: 0 | train loss: 0.2755 | test loss: 0.0539 | test acc: 9827/10000 (0.9827) | memory: 124.96 MB | time: 5.21s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 56.46it/s]
100%|██████████| 10/10 [00:00<00:00, 11.74it/s]


epoch: 1 | train loss: 0.0869 | test loss: 0.0427 | test acc: 9855/10000 (0.9855) | memory: 124.96 MB | time: 5.00s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 56.69it/s]
100%|██████████| 10/10 [00:00<00:00, 11.53it/s]


epoch: 2 | train loss: 0.0672 | test loss: 0.0326 | test acc: 9892/10000 (0.9892) | memory: 124.96 MB | time: 5.00s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 56.31it/s]
100%|██████████| 10/10 [00:00<00:00, 11.44it/s]


epoch: 3 | train loss: 0.0548 | test loss: 0.0365 | test acc: 9884/10000 (0.9884) | memory: 124.96 MB | time: 5.03s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 56.31it/s]
100%|██████████| 10/10 [00:00<00:00, 11.33it/s]


epoch: 4 | train loss: 0.0456 | test loss: 0.0324 | test acc: 9886/10000 (0.9886) | memory: 124.96 MB | time: 5.04s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 56.18it/s]
100%|██████████| 10/10 [00:00<00:00, 11.26it/s]


epoch: 5 | train loss: 0.0418 | test loss: 0.0338 | test acc: 9896/10000 (0.9896) | memory: 124.96 MB | time: 5.06s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 56.56it/s]
100%|██████████| 10/10 [00:00<00:00, 11.03it/s]


epoch: 6 | train loss: 0.0361 | test loss: 0.0344 | test acc: 9898/10000 (0.9898) | memory: 124.96 MB | time: 5.05s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 55.89it/s]
100%|██████████| 10/10 [00:00<00:00, 11.31it/s]


epoch: 7 | train loss: 0.0340 | test loss: 0.0316 | test acc: 9901/10000 (0.9901) | memory: 124.96 MB | time: 5.07s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 56.13it/s]
100%|██████████| 10/10 [00:00<00:00, 11.26it/s]


epoch: 8 | train loss: 0.0278 | test loss: 0.0395 | test acc: 9878/10000 (0.9878) | memory: 124.96 MB | time: 5.06s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:04<00:00, 55.64it/s]
100%|██████████| 10/10 [00:00<00:00, 11.38it/s]

epoch: 9 | train loss: 0.0272 | test loss: 0.0308 | test acc: 9918/10000 (0.9918) | memory: 124.96 MB | time: 5.09s
----------------------------------------------------------------------------------------------------





Next, we train the deep model.

In [7]:
deep_model = DeepNet().to(device)
deep_optimizer = optim.Adam(deep_model.parameters(), lr=lr)

print(f"> training deep model on {device}\n")
for epoch in range(epochs):
  start_time = time.time()
  train_losses = train(deep_model, train_loader, deep_optimizer, device)
  test_losses = test(deep_model, test_loader, device)
  total_time = time.time() - start_time
  print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']:.2f} MB | time: {total_time:.2f}s")
  print('-' * 100)

del train_loader, test_loader # cleanup resources

> training deep model on cuda



100%|██████████| 234/234 [00:05<00:00, 42.26it/s]
100%|██████████| 10/10 [00:00<00:00, 11.49it/s]


epoch: 0 | train loss: 0.2750 | test loss: 0.0673 | test acc: 9809/10000 (0.9809) | memory: 1283.19 MB | time: 6.41s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.91it/s]
100%|██████████| 10/10 [00:00<00:00, 11.43it/s]


epoch: 1 | train loss: 0.0632 | test loss: 0.0515 | test acc: 9856/10000 (0.9856) | memory: 1283.19 MB | time: 6.33s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.89it/s]
100%|██████████| 10/10 [00:00<00:00, 11.43it/s]


epoch: 2 | train loss: 0.0426 | test loss: 0.0542 | test acc: 9856/10000 (0.9856) | memory: 1283.19 MB | time: 6.33s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.72it/s]
100%|██████████| 10/10 [00:00<00:00, 14.31it/s]


epoch: 3 | train loss: 0.0389 | test loss: 0.0664 | test acc: 9832/10000 (0.9832) | memory: 1283.19 MB | time: 6.18s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.67it/s]
100%|██████████| 10/10 [00:00<00:00, 14.35it/s]


epoch: 4 | train loss: 0.0316 | test loss: 0.0538 | test acc: 9878/10000 (0.9878) | memory: 1283.19 MB | time: 6.18s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.69it/s]
100%|██████████| 10/10 [00:00<00:00, 14.12it/s]


epoch: 5 | train loss: 0.0291 | test loss: 0.0595 | test acc: 9856/10000 (0.9856) | memory: 1283.19 MB | time: 6.19s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.78it/s]
100%|██████████| 10/10 [00:00<00:00, 11.06it/s]


epoch: 6 | train loss: 0.0288 | test loss: 0.0572 | test acc: 9875/10000 (0.9875) | memory: 1283.19 MB | time: 6.38s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.79it/s]
100%|██████████| 10/10 [00:00<00:00, 14.04it/s]


epoch: 7 | train loss: 0.2260 | test loss: 0.2105 | test acc: 9449/10000 (0.9449) | memory: 1283.19 MB | time: 6.18s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.81it/s]
100%|██████████| 10/10 [00:00<00:00, 14.34it/s]


epoch: 8 | train loss: 0.1073 | test loss: 0.0414 | test acc: 9884/10000 (0.9884) | memory: 1283.19 MB | time: 6.17s
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:05<00:00, 42.78it/s]
100%|██████████| 10/10 [00:00<00:00, 14.27it/s]

epoch: 9 | train loss: 0.0417 | test loss: 0.0514 | test acc: 9864/10000 (0.9864) | memory: 1283.19 MB | time: 6.17s
----------------------------------------------------------------------------------------------------
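
The `memory` column in the logs above can be sanity-checked with a back-of-the-envelope estimate. The sketch below assumes fp32 weights and that Adam keeps two fp32 buffers per parameter (`exp_avg` and `exp_avg_sq`); the function `static_state_mb` and the `world_size` argument are hypothetical, and the sharded figure is an idealized even split rather than a measurement.

```python
# Rough estimate of "static" training memory: weights plus Adam state.
# This is an approximation under the stated assumptions, not a profiler.

BYTES_PER_FP32 = 4
deep_params = 94_104_234     # from count_params(DeepNet())
shallow_params = 1_199_882   # from count_params(ShallowNet())

def static_state_mb(num_params, world_size=1):
    # 1 copy of the weights + 2 Adam moment buffers per parameter,
    # divided evenly across ranks in the idealized fully sharded case
    copies = 3
    return num_params * BYTES_PER_FP32 * copies / world_size / 1e6

deep_local = static_state_mb(deep_params)
shallow_local = static_state_mb(shallow_params)
deep_sharded = static_state_mb(deep_params, world_size=4)  # hypothetical 4 ranks

print(f"deep, 1 GPU:    {deep_local:.2f} MB")     # 1129.25 MB
print(f"shallow, 1 GPU: {shallow_local:.2f} MB")  # 14.40 MB
print(f"deep, 4 ranks:  {deep_sharded:.2f} MB")   # 282.31 MB
```

For the deep model the estimate accounts for most of the logged 1283.19 MB; the remainder is presumably activations and the input batch. Gradients are not counted because `train()` reads the memory right after `opt.zero_grad()`, which by default releases them in PyTorch 2.x. For the shallow model the static state is tiny, so activations likely dominate its logged memory.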





## Distributed training

We now move on to distributed training with FSDP. Note that this code is currently set up for training on multiple NVIDIA GPUs. Please make sure you have multiple NVIDIA GPUs available on your device (try `nvidia-smi` in the shell) before proceeding.

In [8]:
!nvidia-smi 

Wed Aug 21 11:03:56 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  |   00000000:61:00.0 Off |                    0 |
| N/A   49C    P0             61W /  300W |    2676MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  |   00

First, we have our additional imports for FSDP.

In [9]:
import socket
import os
from functools import partial

import torch.distributed as dist # for distributed communication
from torch.utils.data.distributed import DistributedSampler # for distributed data across GPUs
import torch.multiprocessing as mp # for spawning processes on each GPU
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP, # FSDP constructor for sharding model parameters
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy # for FSDP configuration

assert torch.cuda.is_available(), "CUDA is not available. You must have CUDA enabled to use distributed training."

Before we proceed, some basic vocabulary:

`world_size` is the number of processes you will spawn for distributed training. This should equal the number of GPUs on your device. So if you have 4 GPUs on your node/machine, you want to set `world_size` to 4.

Then, `rank` is the ID of the current process. In order to use distributed rather than local training, we will spawn a separate Python process on each GPU using PyTorch's multiprocessing API. Then, we number processes by the GPU they are spawned on. For example, if you have 4 GPUs, one can think of their GPUs as `[GPU0, GPU1, GPU2, GPU3]`. Then, on each of `GPU0`, `GPU1`, `GPU2`, `GPU3`, one spawns a separate process. These processes are ID'd by their `rank`. So the process on GPU0 is called "rank 0", the process on GPU1 is called "rank 1", etc.

We will use "world size", "`world_size`", and "number of GPUs" interchangeably in this tutorial. Additionally, we will use "`rank`", "rank", "process" and "spawned process" interchangeably.

After initializing PyTorch's distributed backend using `torch.distributed.init_process_group()`, one can always compute the rank of the current process via `torch.distributed.get_rank()`. Similarly, one can compute the world size (that is, the number of processes spawned, which should be the number of GPUs available on one's device) via `torch.distributed.get_world_size()`. The world size is the same across ranks. 

Finally, before the program exits, one should call `torch.distributed.destroy_process_group()` in order to clean up any remaining resources. This is simply what PyTorch requires us to go through in order to use distributed training with `torch.distributed`.

Unfortunately, distributed training in PyTorch requires a nontrivial amount of ceremony. The next cell contains some utility functions to initialize PyTorch's distributed backend. There is more than one way to initialize the distributed backend for distributed training in PyTorch. We chose this implementation for simplicity. The most important parts are: setting the necessary environment variables (`MASTER_ADDR` and `MASTER_PORT`) before initializing a process group for distributed training, calling `torch.distributed.init_process_group()` to initialize distributed training after spawning multiple processes, and calling `torch.distributed.destroy_process_group()` after distributed training to clean up resources.

Notice that we have some additional code to handle the case where our GPU is an NVIDIA A100. When testing FSDP on some SLURM HPC clusters, we found we needed to set the environment variable `NCCL_P2P_LVL` to `NVL` in order for PyTorch's distributed backend to not time out when using NVIDIA A100s. When using A100s on other machines and cloud GPU providers (such as Lambda Labs), setting the `NCCL_P2P_LVL` environment variable was not necessary. Additionally, setting this environment variable was not necessary when using other NVIDIA GPUs (V100, TitanXP, etc.) on SLURM HPC clusters.

In [10]:
# utility functions to initialize distributed training

def get_free_addr():
  return socket.gethostbyname_ex(socket.gethostname())[2][0]
    
def get_free_port(addr):
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind((addr, 0))
    s.listen(1)
    port = s.getsockname()[1]
  return port

# initialize PyTorch's distributed backend
def setup_distributed(rank, world_size, addr, port):
  os.environ['MASTER_ADDR'] = str(addr) # address that ranks will use to communicate
  os.environ['MASTER_PORT'] = str(port) # port that ranks will use to communicate
  if 'a100' in torch.cuda.get_device_name().lower(): # needed for training with A100's on some SLURM clusters
    os.environ['NCCL_P2P_LVL'] = 'NVL'
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

# cleanup PyTorch's distributed backend
def cleanup_distributed():
  dist.destroy_process_group()

# we only want to print on the master rank (rank 0) to avoid duplicate logging. This is a convention commonly followed
# when doing distributed training to reduce logging clutter
def is_master():
  return dist.get_rank() == 0

We then modify our `train()` and `test()` functions. Our modified `train()` function is called `train_dist()`. Similarly, `test()` is now `test_dist()`.

We now store our losses as `torch.Tensor` objects on each rank. This lets us sum the losses across ranks using `torch.distributed.all_reduce()` and obtain the average per-sample loss over all samples on all ranks. With `op=dist.ReduceOp.SUM` (the default), `all_reduce()` computes an element-wise sum of the tensors across ranks and writes the result back to every rank. For example, suppose you have 4 ranks (rank 0, rank 1, rank 2, rank 3) and a tensor `t` on each rank. If `t` is `torch.tensor([2])` on rank 0, `torch.tensor([3])` on rank 1, `torch.tensor([9])` on rank 2, and `torch.tensor([1])` on rank 3, then after calling `torch.distributed.all_reduce(t)`, `t` will be `torch.tensor([15])` (since 2 + 3 + 9 + 1 = 15) on every rank.
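
To make the semantics concrete without needing multiple GPUs, here is a small pure-Python simulation of a sum all-reduce. It creates no process group and does not call `torch.distributed`; the function `simulated_all_reduce_sum` and the numbers in the second half are hypothetical and exist only for illustration.

```python
# Simulates what all_reduce(op=SUM) leaves on each rank.
# One list entry stands in for the tensor held by one rank.

def simulated_all_reduce_sum(per_rank_values):
    # every rank contributes its value, and every rank receives the same total
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]

t_per_rank = [2, 3, 9, 1]  # value of `t` on ranks 0..3
reduced = simulated_all_reduce_sum(t_per_rank)
print(reduced)  # [15, 15, 15, 15]

# Same pattern as in train_dist(): reduce the loss sums and batch counts,
# then divide once to get the average loss per sample.
batch_size = 4                # hypothetical
loss_sums = [10.0, 12.0]      # hypothetical summed losses on 2 ranks
batch_counts = [2, 2]         # batches processed on each rank
total_loss = simulated_all_reduce_sum(loss_sums)[0]
total_batches = simulated_all_reduce_sum(batch_counts)[0]
avg_loss = total_loss / (total_batches * batch_size)
print(avg_loss)  # 1.375
```

Because every rank ends up with the same reduced value, every rank computes the same average, even though only rank 0 prints it.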

Additionally, we now only display the progress bar on rank 0 to reduce clutter.

You may notice this is not the most concise code possible for distributed training. Our goal here was readability. You are more than welcome to alter this implementation as you see fit for your own distributed training.

In [11]:
def train_dist(model, train_loader, opt):
  rank = dist.get_rank() # new
  
  total_loss = torch.zeros(1).to(rank) # updated
  num_batches = torch.zeros(1).to(rank) # updated

  model.train()
  bar = tqdm(train_loader, total=len(train_loader)) if is_master() else train_loader # updated: we only want a progress bar on the master rank
  for (data, target) in bar: # updated
    data, target = data.to(rank), target.to(rank) # updated
    opt.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target, reduction='sum')
    memory = torch.cuda.memory_allocated(rank) / 1e6 if torch.cuda.is_available() else 'N/A'
    loss.backward()
    opt.step()
    total_loss += loss
    num_batches += 1
  
  dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) # new
  dist.all_reduce(num_batches, op=dist.ReduceOp.SUM) # new
  avg_loss = total_loss.item() / (num_batches.item() * train_loader.batch_size) # updated
  return {
    'avg_loss': avg_loss,
    'memory': memory
  }

def test_dist(model, test_loader):
  rank = dist.get_rank() # new

  total_loss = torch.zeros(1).to(rank) # updated
  num_batches = torch.zeros(1).to(rank) # updated
  total_correct = torch.zeros(1).to(rank) # updated
  num_datapoints = torch.zeros(1).to(rank) # updated

  model.eval()
  with torch.no_grad():
    bar = tqdm(test_loader, total=len(test_loader)) if is_master() else test_loader
    for (data, target) in bar: # updated: we only want a progress bar on the master rank
      data, target = data.to(rank), target.to(rank)
      output = model(data)
      total_loss += F.nll_loss(output, target, reduction='sum')
      pred = output.argmax(dim=1, keepdim=True)
      total_correct += pred.eq(target.view_as(pred)).sum()
      num_datapoints += len(data)
      num_batches += 1

  dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  dist.all_reduce(num_batches, op=dist.ReduceOp.SUM)
  dist.all_reduce(total_correct, op=dist.ReduceOp.SUM)
  dist.all_reduce(num_datapoints, op=dist.ReduceOp.SUM)
  avg_loss = total_loss.item() / (num_batches.item() * test_loader.batch_size)
  accuracy = total_correct.item() / num_datapoints.item()
  return {
    'avg_loss': avg_loss,
    'total_correct': total_correct.item(),
    'num_datapoints': num_datapoints.item(),
    'accuracy': accuracy
  }

Next, we add a `main()` function. To do distributed training in PyTorch, we must encapsulate all of our training logic in a single function, which serves as the entry point for each spawned process/rank.

The most significant changes in `main()` from the training code earlier in the notebook are:
- initializing and cleaning up the distributed process group (line 13, line 78)
- constructing a `DistributedSampler` object to distribute data across ranks for data-parallelism (lines 24-26, lines 36-37)
- sharding the model across devices with the `FSDP` constructor (line 42, line 62)

Using a `DistributedSampler` allows us to incorporate data-parallelism in our training loop, in addition to sharding model parameters across ranks. Data-parallelism speeds up training, similar to using `DistributedDataParallel`. Sharding model parameters across ranks decreases memory consumption.
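
The effect of the sampler on how much data each rank sees can be sketched in plain Python. This is a simplified model of `DistributedSampler` with shuffling turned off, not its actual source; the helper names `shard_indices` and `batches_per_rank` are hypothetical, and `world_size = 4` is an assumption chosen because it is consistent with the batch counts in the run logged in this notebook.

```python
import math

def shard_indices(dataset_len, world_size, rank):
    # Simplified, unshuffled model: pad the index list to a multiple of
    # world_size, then give rank r every world_size-th index starting at r.
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices = list(range(dataset_len))
    indices += indices[: total - dataset_len]  # wrap-around padding
    return indices[rank:total:world_size]

def batches_per_rank(dataset_len, world_size, batch_size):
    # drop_last=True in the DataLoader discards the final partial batch
    per_rank = math.ceil(dataset_len / world_size)
    return per_rank // batch_size

world_size = 4  # assumption, see the note above

toy_shard = shard_indices(10, world_size, rank=1)
print(toy_shard)  # [1, 5, 9]

train_batches = batches_per_rank(60000, world_size, 256)
test_batches = batches_per_rank(10000, world_size, 1000)
test_points_seen = test_batches * 1000 * world_size

print(train_batches)     # 58 (versus 234 batches in local training)
print(test_points_seen)  # 8000
```

Under these assumptions, each rank processes roughly a quarter of the training batches per epoch, which is where the data-parallel speedup comes from. It also shows a side effect of `drop_last=True`: each rank holds 2500 test images, keeps only 2 full batches of 1000, and so 8000 rather than 10000 test points are evaluated.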

Pay attention to lines 9 and 10. We must tell PyTorch how to shard our model parameters across ranks via a "wrapping policy." In addition to this wrapping policy, the `FullyShardedDataParallel` implementation in PyTorch has *many* hyperparameters you can tune to optimize the performance of the FSDP algorithm for your model's particular architecture and training loop. In addition, `FullyShardedDataParallel` in PyTorch allows you to combine FSDP with other large-scale model training techniques, such as mixed-precision training, multi-dimensional parallelism (that is, combining it with training strategies such as *tensor parallelism* or *pipeline parallelism*), and JIT compilation with `torch.compile()`. We do not go into any of these advanced techniques in this tutorial. If you have questions on using these techniques, you're more than welcome to reach out to me (email at the end of this notebook).
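
To get a feel for what `min_num_params = 20000` means for these two models, the sketch below compares each layer's parameter count against the threshold. It is a rough illustration of the size check only: the real `size_based_auto_wrap_policy` is applied recursively over the module tree and has additional rules, and the function `layers_over_threshold` is hypothetical.

```python
# Which layers are individually large enough to meet the size threshold?
# Counts are weights + biases; dropout layers have no parameters.

min_num_params = 20_000  # same threshold as in main()

shallow_layers = {
    "conv1": 1 * 32 * 3 * 3 + 32,
    "conv2": 32 * 64 * 3 * 3 + 64,
    "fc1": 9216 * 128 + 128,
    "fc2": 128 * 10 + 10,
}

deep_layers = {
    "conv1": 1 * 32 * 3 * 3 + 32,
    "conv2": 32 * 64 * 3 * 3 + 64,
    "fc1": 9216 * 9000 + 9000,
    "fc1a": 9000 * 1000 + 1000,
    "fc1b": 1000 * 1000 + 1000,
    "fc1c": 1000 * 1000 + 1000,
    "fc1d": 1000 * 128 + 128,
    "fc2": 128 * 10 + 10,
}

def layers_over_threshold(layer_params, threshold):
    # keep the names of layers whose own parameter count meets the threshold
    return [name for name, count in layer_params.items() if count >= threshold]

shallow_large = layers_over_threshold(shallow_layers, min_num_params)
deep_large = layers_over_threshold(deep_layers, min_num_params)

print(shallow_large)  # ['fc1']
print(deep_large)     # ['fc1', 'fc1a', 'fc1b', 'fc1c', 'fc1d']
```

As far as the size check is concerned, the large linear layers are the candidates for their own FSDP units, while the small convolutional layers and `fc2` stay with the outermost wrapper. Raising or lowering `min_num_params` changes how finely the model is split.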

Additionally, there are some other small code changes in this block, not mentioned in this explanation. Please pay close attention to anything labeled `new` or `updated`.

In [12]:
def main(rank, world_size, addr, port):
  # training config, same as before
  train_batch_size = 256
  test_batch_size = 1000
  lr = 0.001
  epochs = 10
  
  # new: some config for FSDP
  min_num_params = 20000
  auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
  
  
  setup_distributed(rank, world_size, addr, port) # new: initialize PyTorch's distributed backend

  torch.cuda.set_device(rank) # new: set the device to the current rank

  _transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
  ])
  train_data = datasets.MNIST('../data', train=True, download=True, transform=_transform)
  test_data = datasets.MNIST('../data', train=False, transform=_transform)

  train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank) # new: use DistributedSampler to distribute data
                                                                                     #      across all ranks
  test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank)

  loader_kwargs = {
      'num_workers': 2,
      'pin_memory': True,
      'shuffle': False,
      'drop_last': True,
      'persistent_workers': True # warning: on Allen HPC, disabling this massively slows down training
  }

  train_loader_dist = DataLoader(train_data, batch_size=train_batch_size, sampler=train_sampler, **loader_kwargs) # updated: make sure to pass in the sampler
  test_loader_dist = DataLoader(test_data, batch_size=test_batch_size, sampler=test_sampler, **loader_kwargs)

  # First, we train the shallow model

  shallow_model = ShallowNet().to(rank)
  shallow_model = FSDP(shallow_model, auto_wrap_policy=auto_wrap_policy) # new: wrap the model with FSDP. This will shard the model's parameters across ranks during training
  shallow_optimizer = optim.Adam(shallow_model.parameters(), lr=lr) # make sure to construct the optimizer AFTER wrapping the model with FSDP, as we want to update the
                                                                    # parameters of the sharded model, not the original model
  if is_master():
    print("> training shallow model\n") # updated: only print on the master rank
  for epoch in range(epochs):
    start_time = time.time()
    train_sampler.set_epoch(epoch) # new: reseed the sampler's shuffle so each epoch uses a different ordering (DistributedSampler shuffles by default)
    train_losses = train_dist(shallow_model, train_loader_dist, shallow_optimizer)

    test_sampler.set_epoch(epoch) # new: same as above, for the test sampler
    test_losses = test_dist(shallow_model, test_loader_dist)
    total_time = time.time() - start_time
    if is_master(): # updated: only print on the master rank
      print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']:.2f} MB | time: {total_time:.2f}s")
      print('-' * 100)
    
  # Next, we train the deep model. Most of what applied to the shallow model applies here, too.

  deep_model = DeepNet().to(rank)
  deep_model = FSDP(deep_model, auto_wrap_policy=auto_wrap_policy)
  deep_optimizer = optim.Adam(deep_model.parameters(), lr=lr)
  if is_master():
    print("> training deep model\n")
  for epoch in range(epochs):
    start_time = time.time()
    train_sampler.set_epoch(epoch)
    train_losses = train_dist(deep_model, train_loader_dist, deep_optimizer)

    test_sampler.set_epoch(epoch)
    test_losses = test_dist(deep_model, test_loader_dist)
    total_time = time.time() - start_time
    if is_master():
      print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']:.2f} MB | time: {total_time:.2f}s")
      print('-' * 100)

  cleanup_distributed() # new: cleanup PyTorch's distributed backend to release resources

Finally, we spawn a separate process for each GPU using PyTorch's multiprocessing API.

We define our master address and master port once, and pass the same address and port to each process as an argument to `main()`.

We call `fsdp_main()` from the `fsdp_tutorial.py` (notice the extension) Python script rather than the `main()` function in this notebook as a workaround to use `multiprocessing` in a Jupyter notebook. The function we want to attach processes to with `multiprocessing.spawn()` must be defined in a separate file, which we then import into this notebook. If we try calling `mp.spawn(main, ...)`, you will notice that we get an error. Worry not, the code in `fsdp_tutorial.py` is exactly the same as in this notebook.

Additionally, on some SLURM HPC clusters, the first epoch of distributed training may be very slow compared to the other epochs. This is due to the overhead of spawning multiple workers in the `DataLoader`. It will not affect the training time of the remaining epochs, though, which should train much more quickly than local training.

In [13]:
from fsdp_tutorial import main as fsdp_main

world_size = torch.cuda.device_count() # new: number of GPUs
addr = get_free_addr()
port = get_free_port(addr)
mp.spawn(fsdp_main, args=(world_size, addr, port), nprocs=world_size, join=True) # new: spawn a process on each GPU

> training shallow model



100%|██████████| 58/58 [00:17<00:00,  3.40it/s]
100%|██████████| 2/2 [00:11<00:00,  5.68s/it]
  2%|▏         | 1/58 [00:00<00:06,  8.89it/s]

epoch: 0 | train loss: 0.5351 | test loss: 0.1191 | test acc: 7712.0/8000.0 (0.9640) | memory: 115.16 MB | time: 28.41s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 40.93it/s]
100%|██████████| 2/2 [00:00<00:00, 11.80it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 1 | train loss: 0.1504 | test loss: 0.0589 | test acc: 7846.0/8000.0 (0.9808) | memory: 115.16 MB | time: 1.59s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 36.50it/s]
100%|██████████| 2/2 [00:00<00:00,  9.75it/s]
  2%|▏         | 1/58 [00:00<00:08,  6.73it/s]

epoch: 2 | train loss: 0.0989 | test loss: 0.0475 | test acc: 7877.0/8000.0 (0.9846) | memory: 115.16 MB | time: 1.80s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 38.82it/s]
100%|██████████| 2/2 [00:00<00:00, 11.42it/s]
  2%|▏         | 1/58 [00:00<00:07,  7.42it/s]

epoch: 3 | train loss: 0.0775 | test loss: 0.0388 | test acc: 7884.0/8000.0 (0.9855) | memory: 115.16 MB | time: 1.67s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 42.71it/s]
100%|██████████| 2/2 [00:00<00:00, 11.38it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 4 | train loss: 0.0671 | test loss: 0.0363 | test acc: 7901.0/8000.0 (0.9876) | memory: 115.16 MB | time: 1.55s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 39.54it/s]
100%|██████████| 2/2 [00:00<00:00,  9.81it/s]
  2%|▏         | 1/58 [00:00<00:07,  7.19it/s]

epoch: 5 | train loss: 0.0574 | test loss: 0.0374 | test acc: 7906.0/8000.0 (0.9882) | memory: 115.16 MB | time: 1.68s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 42.19it/s]
100%|██████████| 2/2 [00:00<00:00,  9.88it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 6 | train loss: 0.0521 | test loss: 0.0279 | test acc: 7920.0/8000.0 (0.9900) | memory: 115.16 MB | time: 1.58s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 33.41it/s]
100%|██████████| 2/2 [00:00<00:00, 12.76it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 7 | train loss: 0.0459 | test loss: 0.0319 | test acc: 7916.0/8000.0 (0.9895) | memory: 115.16 MB | time: 1.91s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 41.43it/s]
100%|██████████| 2/2 [00:00<00:00, 12.57it/s]
  2%|▏         | 1/58 [00:00<00:08,  7.06it/s]

epoch: 8 | train loss: 0.0427 | test loss: 0.0299 | test acc: 7922.0/8000.0 (0.9902) | memory: 115.16 MB | time: 1.60s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:01<00:00, 39.11it/s]
100%|██████████| 2/2 [00:00<00:00, 11.83it/s]


epoch: 9 | train loss: 0.0397 | test loss: 0.0293 | test acc: 7928.0/8000.0 (0.9910) | memory: 115.16 MB | time: 1.69s
----------------------------------------------------------------------------------------------------
> training deep model



100%|██████████| 58/58 [00:02<00:00, 24.78it/s]
100%|██████████| 2/2 [00:00<00:00, 10.12it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 0 | train loss: 0.9412 | test loss: 0.0637 | test acc: 7848.0/8000.0 (0.9810) | memory: 422.25 MB | time: 2.56s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.36it/s]
100%|██████████| 2/2 [00:00<00:00, 10.21it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 1 | train loss: 0.0641 | test loss: 0.0376 | test acc: 7902.0/8000.0 (0.9878) | memory: 422.25 MB | time: 2.50s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.13it/s]
100%|██████████| 2/2 [00:00<00:00, 10.46it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 2 | train loss: 0.0345 | test loss: 0.0372 | test acc: 7901.0/8000.0 (0.9876) | memory: 422.25 MB | time: 2.52s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.00it/s]
100%|██████████| 2/2 [00:00<00:00, 10.39it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 3 | train loss: 0.0226 | test loss: 0.0331 | test acc: 7906.0/8000.0 (0.9882) | memory: 422.25 MB | time: 2.53s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.02it/s]
100%|██████████| 2/2 [00:00<00:00,  9.88it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 4 | train loss: 0.0149 | test loss: 0.0368 | test acc: 7915.0/8000.0 (0.9894) | memory: 422.25 MB | time: 2.54s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 24.99it/s]
100%|██████████| 2/2 [00:00<00:00, 10.20it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 5 | train loss: 0.0137 | test loss: 0.0329 | test acc: 7927.0/8000.0 (0.9909) | memory: 422.25 MB | time: 2.54s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.17it/s]
100%|██████████| 2/2 [00:00<00:00, 10.46it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 6 | train loss: 0.0106 | test loss: 0.0357 | test acc: 7923.0/8000.0 (0.9904) | memory: 422.25 MB | time: 2.52s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.20it/s]
100%|██████████| 2/2 [00:00<00:00, 10.62it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 7 | train loss: 0.0101 | test loss: 0.0355 | test acc: 7921.0/8000.0 (0.9901) | memory: 422.25 MB | time: 2.51s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.16it/s]
100%|██████████| 2/2 [00:00<00:00, 10.08it/s]
  0%|          | 0/58 [00:00<?, ?it/s]

epoch: 8 | train loss: 0.0072 | test loss: 0.0360 | test acc: 7925.0/8000.0 (0.9906) | memory: 422.25 MB | time: 2.52s
----------------------------------------------------------------------------------------------------


100%|██████████| 58/58 [00:02<00:00, 25.05it/s]
100%|██████████| 2/2 [00:00<00:00, 10.56it/s]


epoch: 9 | train loss: 0.0090 | test loss: 0.0355 | test acc: 7923.0/8000.0 (0.9904) | memory: 422.25 MB | time: 2.53s
----------------------------------------------------------------------------------------------------


## Conclusion

And we're done! Hopefully you've now learned how to use FSDP training with PyTorch, including how to set up PyTorch's distributed backend and how to shard your model's parameters across GPU devices with PyTorch's FSDP API.

If you want to learn more, PyTorch's FSDP tutorials are pretty good:
- https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
- https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html

The FSDP whitepaper is also pretty good. It includes a lot of implementation details you can use to speed up your training:
- https://arxiv.org/pdf/2304.11277

The documentation for FSDP can answer a lot of your questions:
- https://pytorch.org/docs/stable/fsdp.html

For anything the above resources can't answer, the FSDP source is publicly available:
- https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py/

Note that in PyTorch 2.4, the PyTorch team is updating FSDP to "FSDP 2.0", so everything shown in this tutorial (as well as all the above resources) is for "FSDP 1.0". As of August 2024, the FSDP 2.0 API is experimental. However, if you want to learn more about it, you may find these resources useful:
- https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md (FSDP 2.0 overview)
- https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fully_shard.py (FSDP 2.0 source)

If you have any questions, please email me at `hilal.mufti@alleninstitute.org` or `hmufti@cs.washington.edu`. I'm more than happy to help with any distributed training questions you may still have.

### References

This tutorial expands on the ideas and code in this official PyTorch tutorial: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html.