This tutorial shows how to use PyTorch's Fully Sharded Data Parallel (FSDP) implementation. We first train two models locally: (i) a shallow model, and (ii) a deep model. Then we train the same two models with FSDP to show how it speeds up training and reduces memory usage.

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms # for datasets
from tqdm import tqdm # for progress bar

In [35]:
def count_params(model):
  """Return the total number of trainable scalar values in `model`.

  Only parameters whose `requires_grad` flag is set contribute to the count.
  """
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total

## Local training

In [36]:
class ShallowNet(nn.Module):
  """Small convolutional classifier: two conv layers followed by two linear layers.

  `fc1` takes 9216 flattened features, which is what the conv/pool stack
  produces for single-channel 28x28 inputs (64 channels x 12 x 12).
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
    self.dropout1 = nn.Dropout(p=0.25)
    self.dropout2 = nn.Dropout(p=0.5)
    self.fc1 = nn.Linear(in_features=9216, out_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=10)

  def forward(self, x):
    """Map a batch of images to per-class log-probabilities (log-softmax over dim 1)."""
    # Convolutional feature extractor: conv -> relu -> conv -> relu -> 2x2 max-pool.
    features = F.relu(self.conv1(x))
    features = F.relu(self.conv2(features))
    features = self.dropout1(F.max_pool2d(features, 2))

    # Classifier head on the flattened features (batch dimension is kept).
    hidden = F.relu(self.fc1(torch.flatten(features, 1)))
    logits = self.fc2(self.dropout2(hidden))
    return F.log_softmax(logits, dim=1)
  
class DeepNet(nn.Module):
  """Larger variant of the convolutional classifier with a much bigger linear head.

  The conv/pool front end is the same shape as in `ShallowNet`; the head is a
  chain of six linear layers (9216 -> 9000 -> 1000 -> 1000 -> 1000 -> 128 -> 10).
  Note that only `fc1` is followed by ReLU and dropout: `fc1a` through `fc2`
  are applied back to back with no activation in between.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
    self.dropout1 = nn.Dropout(p=0.25)
    self.dropout2 = nn.Dropout(p=0.5)
    self.fc1 = nn.Linear(in_features=9216, out_features=9000)
    self.fc1a = nn.Linear(in_features=9000, out_features=1000)
    self.fc1b = nn.Linear(in_features=1000, out_features=1000)
    self.fc1c = nn.Linear(in_features=1000, out_features=1000)
    self.fc1d = nn.Linear(in_features=1000, out_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=10)

  def forward(self, x):
    """Map a batch of images to per-class log-probabilities (log-softmax over dim 1)."""
    # Convolutional feature extractor: conv -> relu -> conv -> relu -> 2x2 max-pool.
    features = F.relu(self.conv1(x))
    features = F.relu(self.conv2(features))
    features = self.dropout1(F.max_pool2d(features, 2))

    # First head layer gets ReLU + dropout ...
    hidden = F.relu(self.fc1(torch.flatten(features, 1)))
    hidden = self.dropout2(hidden)

    # ... the remaining linear layers are chained directly.
    for layer in (self.fc1a, self.fc1b, self.fc1c, self.fc1d, self.fc2):
      hidden = layer(hidden)
    return F.log_softmax(hidden, dim=1)
  
# Report the trainable-parameter count of a freshly constructed instance of
# each network, one line per model.
print(
    "# params:",
    f"  shallow: {count_params(ShallowNet())}",
    f"  deep: {count_params(DeepNet())}",
    sep="\n",
)

# params:
  shallow: 1199882
  deep: 94104234


In [37]:
# Preprocessing applied to every image: convert it to a tensor, then normalise
# with mean 0.1307 and std 0.3081 (presumably the MNIST pixel statistics —
# TODO confirm).
_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
  ])

# Both splits live under '../data'; only the training split is constructed
# with download=True.
train_data = datasets.MNIST('../data', train=True, download=True, transform=_transform)
test_data = datasets.MNIST('../data', train=False, transform=_transform)

# Bare expression so the notebook displays both dataset summaries.
train_data, test_data 

(Dataset MNIST
     Number of datapoints: 60000
     Root location: ../data
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
                Normalize(mean=(0.1307,), std=(0.3081,))
            ),
 Dataset MNIST
     Number of datapoints: 10000
     Root location: ../data
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
                Normalize(mean=(0.1307,), std=(0.3081,))
            ))

In [49]:
train_batch_size = 256
test_batch_size = 1000


# Settings shared by both loaders. 'shuffle' stays False here because it is the
# right value for evaluation; the training loader overrides it below.
loader_kwargs = {
    'num_workers': 2,
    'pin_memory': True,
    'shuffle': False,
    'drop_last': True,
    'persistent_workers': True, # warning: on Allen HPC, disabling this massively slows down training
}

# Reshuffle the training set every epoch. Without this the model sees the
# samples in the same fixed order in every epoch, unlike the distributed run
# below, where DistributedSampler reshuffles per epoch via set_epoch().
train_loader = DataLoader(train_data, batch_size=train_batch_size, **{**loader_kwargs, 'shuffle': True})
test_loader = DataLoader(test_data, batch_size=test_batch_size, **loader_kwargs)

# Bare expression so the notebook displays one batch from each loader.
next(iter(train_loader)), next(iter(test_loader))

([tensor([[[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            ...,
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],
  
  
          [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            ...,
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],
  
  
          [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242

In [101]:
def train(model, train_loader, opt, device):
  """Train `model` for one pass over `train_loader`.

  Args:
    model: module whose output is fed to `F.nll_loss` (i.e. log-probabilities).
    train_loader: iterable of `(data, target)` batches that supports `len()`.
    opt: optimizer updating the parameters of `model`.
    device: device the batches are moved to before the forward pass.

  Returns:
    A dict with
      'avg_loss': summed NLL loss divided by the number of samples actually
        processed (NaN if the loader yielded no batches), and
      'memory': CUDA memory allocated on `device` in MB, sampled after the
        forward pass of the last batch, or 'N/A' when `device` is not a CUDA
        device or no batch was processed.
  """
  total_loss = 0.0
  num_datapoints = 0
  # Defined up front so the return value is valid even for an empty loader.
  memory = 'N/A'
  # torch.cuda.memory_allocated only makes sense for CUDA devices.
  on_cuda = torch.cuda.is_available() and torch.device(device).type == 'cuda'

  model.train()
  for (data, target) in tqdm(train_loader, total=len(train_loader)):
    data, target = data.to(device), target.to(device)
    opt.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target, reduction='sum')
    if on_cuda:
      # Sampled between forward and backward, while activations are still alive.
      memory = torch.cuda.memory_allocated(device) / 1e6 # memory in MB
    loss.backward()
    opt.step()
    total_loss += loss.item()
    # Count real samples rather than assuming every batch is full, so the
    # average stays correct when the last batch is smaller.
    num_datapoints += len(data)

  avg_loss = total_loss / num_datapoints if num_datapoints else float('nan') # average loss per datapoint
  return {
    'avg_loss': avg_loss,
    'memory': memory
  }

def test(model, test_loader, device):
  """Evaluate `model` on every batch of `test_loader` without tracking gradients.

  Args:
    model: module whose output is fed to `F.nll_loss` (i.e. log-probabilities).
    test_loader: iterable of `(data, target)` batches that supports `len()`.
    device: device the batches are moved to before the forward pass.

  Returns:
    A dict with
      'avg_loss': summed NLL loss divided by the number of evaluated samples,
      'total_correct': number of samples whose arg-max prediction equals the target,
      'num_datapoints': number of evaluated samples,
      'accuracy': total_correct / num_datapoints.
    'avg_loss' and 'accuracy' are NaN if the loader yielded no batches.
  """
  total_loss = 0.0
  total_correct = 0
  num_datapoints = 0

  model.eval()
  with torch.no_grad():
    for (data, target) in tqdm(test_loader, total=len(test_loader)):
      data, target = data.to(device), target.to(device)
      output = model(data)
      total_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.argmax(dim=1, keepdim=True)
      total_correct += pred.eq(target.view_as(pred)).sum().item()
      num_datapoints += len(data)

  if num_datapoints:
    # Divide by the samples actually seen; assuming full batches would skew
    # the loss whenever the last batch is smaller than batch_size.
    avg_loss = total_loss / num_datapoints # average loss per datapoint
    accuracy = total_correct / num_datapoints
  else:
    avg_loss = accuracy = float('nan')
  return {
    'avg_loss': avg_loss, 
    'total_correct': total_correct, 
    'num_datapoints': num_datapoints, 
    'accuracy': accuracy
    }

In [102]:
# Hyper-parameters shared by the local training runs.
epochs = 10
# Prefer CUDA, then MPS, and fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
lr = 0.001

shallow_model = ShallowNet().to(device)
shallow_optimizer = optim.Adam(shallow_model.parameters(), lr=lr)

print("> training shallow model\n")
for epoch in range(epochs):
  # One pass over the training set, then a full evaluation on the test set.
  train_losses = train(shallow_model, train_loader, shallow_optimizer, device)
  test_losses = test(shallow_model, test_loader, device)
  print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']} MB")
  print('-' * 100)

> training shallow model



100%|██████████| 234/234 [00:02<00:00, 86.32it/s]
100%|██████████| 10/10 [00:00<00:00, 32.07it/s]


epoch: 0 | train loss: 0.2456 | test loss: 0.0529 | test acc: 9833/10000 (0.9833) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 88.83it/s]
100%|██████████| 10/10 [00:00<00:00, 57.27it/s]


epoch: 1 | train loss: 0.0810 | test loss: 0.0384 | test acc: 9866/10000 (0.9866) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 89.39it/s]
100%|██████████| 10/10 [00:00<00:00, 56.68it/s]


epoch: 2 | train loss: 0.0628 | test loss: 0.0316 | test acc: 9896/10000 (0.9896) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 89.03it/s]
100%|██████████| 10/10 [00:00<00:00, 55.99it/s]


epoch: 3 | train loss: 0.0501 | test loss: 0.0293 | test acc: 9895/10000 (0.9895) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 85.63it/s]
100%|██████████| 10/10 [00:00<00:00, 54.99it/s]


epoch: 4 | train loss: 0.0439 | test loss: 0.0309 | test acc: 9898/10000 (0.9898) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 88.13it/s]
100%|██████████| 10/10 [00:00<00:00, 55.85it/s]


epoch: 5 | train loss: 0.0369 | test loss: 0.0314 | test acc: 9900/10000 (0.9900) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 86.92it/s]
100%|██████████| 10/10 [00:00<00:00, 56.14it/s]


epoch: 6 | train loss: 0.0310 | test loss: 0.0297 | test acc: 9896/10000 (0.9896) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 87.90it/s]
100%|██████████| 10/10 [00:00<00:00, 53.66it/s]


epoch: 7 | train loss: 0.0298 | test loss: 0.0401 | test acc: 9878/10000 (0.9878) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 83.84it/s]
100%|██████████| 10/10 [00:00<00:00, 53.43it/s]


epoch: 8 | train loss: 0.0270 | test loss: 0.0320 | test acc: 9907/10000 (0.9907) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 82.25it/s]
100%|██████████| 10/10 [00:00<00:00, 37.57it/s]

epoch: 9 | train loss: 0.0260 | test loss: 0.0325 | test acc: 9895/10000 (0.9895) | memory: N/A MB
----------------------------------------------------------------------------------------------------





In [103]:
# Same procedure as for the shallow model, reusing `device`, `lr`, `epochs`
# and the data loaders defined in earlier cells.
deep_model = DeepNet().to(device)
deep_optimizer = optim.Adam(deep_model.parameters(), lr=lr)

print("> training deep model\n")
for epoch in range(epochs):
  # One pass over the training set, then a full evaluation on the test set.
  train_losses = train(deep_model, train_loader, deep_optimizer, device)
  test_losses = test(deep_model, test_loader, device)
  print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']} MB")
  print('-' * 100)

> training deep model



100%|██████████| 234/234 [00:23<00:00, 10.12it/s]
100%|██████████| 10/10 [00:00<00:00, 13.82it/s]


epoch: 0 | train loss: 0.3304 | test loss: 0.0587 | test acc: 9798/10000 (0.9798) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.17it/s]
100%|██████████| 10/10 [00:00<00:00, 13.89it/s]


epoch: 1 | train loss: 0.0649 | test loss: 0.0628 | test acc: 9814/10000 (0.9814) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.04it/s]
100%|██████████| 10/10 [00:00<00:00, 14.63it/s]


epoch: 2 | train loss: 0.0406 | test loss: 0.0481 | test acc: 9864/10000 (0.9864) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.10it/s]
100%|██████████| 10/10 [00:00<00:00, 14.36it/s]


epoch: 3 | train loss: 0.0325 | test loss: 0.0480 | test acc: 9869/10000 (0.9869) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00,  9.90it/s]
100%|██████████| 10/10 [00:00<00:00, 12.91it/s]


epoch: 4 | train loss: 0.0303 | test loss: 0.0563 | test acc: 9843/10000 (0.9843) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.09it/s]
100%|██████████| 10/10 [00:00<00:00, 14.38it/s]


epoch: 5 | train loss: 0.0233 | test loss: 0.0478 | test acc: 9895/10000 (0.9895) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.17it/s]
100%|██████████| 10/10 [00:00<00:00, 14.52it/s]


epoch: 6 | train loss: 0.0253 | test loss: 0.0436 | test acc: 9894/10000 (0.9894) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.10it/s]
100%|██████████| 10/10 [00:00<00:00, 14.18it/s]


epoch: 7 | train loss: 0.0237 | test loss: 0.0385 | test acc: 9903/10000 (0.9903) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:22<00:00, 10.19it/s]
100%|██████████| 10/10 [00:00<00:00, 14.20it/s]


epoch: 8 | train loss: 0.0213 | test loss: 0.0410 | test acc: 9901/10000 (0.9901) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:22<00:00, 10.29it/s]
100%|██████████| 10/10 [00:00<00:00, 14.31it/s]

epoch: 9 | train loss: 0.0228 | test loss: 0.0451 | test acc: 9914/10000 (0.9914) | memory: N/A MB
----------------------------------------------------------------------------------------------------





## Distributed training

In [105]:
import socket
import os
from functools import partial

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# assert torch.cuda.is_available(), "CUDA is not available. You must have CUDA enabled to use distributed training."

Unfortunately, distributed training in PyTorch requires a nontrivial amount of ceremony.

In [90]:
def get_free_addr():
  """Return the first IPv4 address the resolver reports for this machine's hostname."""
  hostname = socket.gethostname()
  _canonical_name, _aliases, ip_addresses = socket.gethostbyname_ex(hostname)
  return ip_addresses[0]
    
def get_free_port(addr):
  """Return a TCP port number the OS reported as free on `addr`.

  A throwaway socket is bound to port 0 so the OS picks the port; the socket
  is closed again before returning, so the port is not reserved.
  """
  probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  with probe:
    probe.bind((addr, 0))
    probe.listen(1)
    _host, chosen_port = probe.getsockname()
    return chosen_port

def setup_distributed(rank=None, world_size=None):
  """Initialise the default NCCL process group for the calling process.

  The rank and world size cannot be queried from `torch.distributed` before
  the group exists, so they must come from the caller or from the environment.

  Args:
    rank: index of this process. Falls back to the RANK environment variable.
      It is also used as the CUDA device index, as everywhere else in this file.
    world_size: total number of processes. Falls back to WORLD_SIZE.

  Raises:
    RuntimeError: if the rank or world size is neither passed in nor present
      in the environment.
  """
  if rank is None:
    rank = os.environ.get('RANK')
  if world_size is None:
    world_size = os.environ.get('WORLD_SIZE')
  if rank is None or world_size is None:
    raise RuntimeError(
      "setup_distributed() needs the process rank and world size: pass them "
      "as arguments or set the RANK and WORLD_SIZE environment variables."
    )
  rank, world_size = int(rank), int(world_size)

  # Every rank has to agree on the rendezvous endpoint, so an endpoint that is
  # already present in the environment (e.g. set by the launcher) always wins.
  if 'MASTER_ADDR' not in os.environ:
    os.environ['MASTER_ADDR'] = str(get_free_addr())
  if 'MASTER_PORT' not in os.environ:
    if world_size == 1:
      port = get_free_port(os.environ['MASTER_ADDR'])
    else:
      # A port probed independently in each process would differ from rank to
      # rank; fall back to a fixed one (29500, the PyTorch launcher default).
      port = 29500
    os.environ['MASTER_PORT'] = str(port)

  # Bind this process to its own GPU before any NCCL communication happens.
  torch.cuda.set_device(rank)
  if 'a100' in torch.cuda.get_device_name(rank).lower(): # needed for training with A100's on Allen HPC
    os.environ['NCCL_P2P_LVL'] = 'NVL'
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

def cleanup_distributed():
  """Tear down the default `torch.distributed` process group of this process."""
  dist.destroy_process_group()
  return None

def is_master():
  """Return True when the calling process is rank 0 of the default process group."""
  current_rank = dist.get_rank()
  return current_rank == 0

In [None]:
def train_dist(model, train_loader, opt):
  """Train `model` for one pass over this rank's share of `train_loader`.

  Must be called by every rank of the default process group: the loss and
  sample count are summed across ranks with an all-reduce.

  Args:
    model: module whose output is fed to `F.nll_loss` (i.e. log-probabilities).
    train_loader: iterable of `(data, target)` batches that supports `len()`.
    opt: optimizer updating the parameters of `model`.

  Returns:
    A dict with
      'avg_loss': NLL loss summed over all ranks divided by the number of
        samples processed by all ranks (NaN if no rank processed a batch), and
      'memory': CUDA memory allocated on this rank's device in MB, sampled
        after the forward pass of the last batch, or 'N/A'. This value is
        local to the calling rank; it is not reduced.
  """
  rank = dist.get_rank()

  # [summed loss, number of samples], kept on the device so that a single
  # all-reduce at the end is enough.
  stats = torch.zeros(2, dtype=torch.float64, device=rank)
  # Defined up front so the return value is valid even for an empty loader.
  memory = 'N/A'
  track_memory = torch.cuda.is_available()

  model.train()
  bar = tqdm(train_loader, total=len(train_loader)) if is_master() else train_loader
  for (data, target) in bar:
    data, target = data.to(rank), target.to(rank)
    opt.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target, reduction='sum')
    if track_memory:
      # Sampled between forward and backward, while activations are still alive.
      memory = torch.cuda.memory_allocated(rank) / 1e6
    loss.backward()
    opt.step()
    # detach(): the running total must not become part of the autograd graph,
    # otherwise it keeps references to every iteration's graph.
    stats[0] += loss.detach()
    # Count real samples rather than assuming every batch is full.
    stats[1] += len(data)

  dist.all_reduce(stats, op=dist.ReduceOp.SUM)
  total_loss, num_datapoints = stats.tolist()
  avg_loss = total_loss / num_datapoints if num_datapoints else float('nan')
  return {
    'avg_loss': avg_loss,
    'memory': memory
  }

def test_dist(model, test_loader):
  """Evaluate `model` on this rank's share of `test_loader` and combine all ranks.

  Must be called by every rank of the default process group: loss, correct
  predictions and sample count are summed across ranks with an all-reduce.

  Args:
    model: module whose output is fed to `F.nll_loss` (i.e. log-probabilities).
    test_loader: iterable of `(data, target)` batches.

  Returns:
    A dict with
      'avg_loss': summed NLL loss divided by the number of evaluated samples,
      'total_correct': number of correctly classified samples (int),
      'num_datapoints': number of evaluated samples (int),
      'accuracy': total_correct / num_datapoints,
    all taken over every rank. 'avg_loss' and 'accuracy' are NaN if no rank
    evaluated a sample.
  """
  rank = dist.get_rank()

  # [summed loss, correct predictions, number of samples], kept on the device
  # so that a single all-reduce at the end is enough.
  stats = torch.zeros(3, dtype=torch.float64, device=rank)

  model.eval()
  with torch.no_grad():
    for (data, target) in test_loader:
      data, target = data.to(rank), target.to(rank)
      output = model(data)
      stats[0] += F.nll_loss(output, target, reduction='sum')
      pred = output.argmax(dim=1, keepdim=True)
      stats[1] += pred.eq(target.view_as(pred)).sum()
      stats[2] += len(data)

  dist.all_reduce(stats, op=dist.ReduceOp.SUM)
  total_loss, total_correct, num_datapoints = stats.tolist()
  # The counts are whole numbers; report them as ints so they print as
  # "9833/10000" rather than "9833.0/10000.0".
  total_correct = int(round(total_correct))
  num_datapoints = int(round(num_datapoints))

  if num_datapoints:
    # Divide by the samples actually seen instead of assuming full batches.
    avg_loss = total_loss / num_datapoints
    accuracy = total_correct / num_datapoints
  else:
    avg_loss = accuracy = float('nan')
  return {
    'avg_loss': avg_loss,
    'total_correct': total_correct,
    'num_datapoints': num_datapoints,
    'accuracy': accuracy
  }


In [106]:
def _run_fsdp_training(title, model_cls, rank, auto_wrap_policy, lr, epochs,
                       train_loader, test_loader, train_sampler, test_sampler):
  """Wrap a fresh `model_cls` instance in FSDP, train it and log each epoch on rank 0.

  The model and optimizer are local to this function, so they can be released
  once it returns, before the next model is built.
  """
  model = FSDP(model_cls().to(rank), auto_wrap_policy=auto_wrap_policy)
  optimizer = optim.Adam(model.parameters(), lr=lr)
  if is_master():
    print(f"> training {title} model\n")
  for epoch in range(epochs):
    # set_epoch() changes the seed the sampler shuffles with for this epoch.
    train_sampler.set_epoch(epoch)
    train_losses = train_dist(model, train_loader, optimizer)

    test_sampler.set_epoch(epoch)
    test_losses = test_dist(model, test_loader)
    if is_master():
      print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']} MB")
      print('-' * 100)


def main(rank, world_size):
  """Per-process entry point: train the shallow and the deep model with FSDP.

  Args:
    rank: index of this process; also used as its CUDA device index.
    world_size: total number of processes taking part in training.
  """
  train_batch_size = 256
  test_batch_size = 1000
  lr = 0.001
  epochs = 10
  min_num_params = 20000
  auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=min_num_params)

  # Rendezvous endpoint: every rank must use the same one, so values already
  # present in the (inherited) environment win and the defaults are fixed.
  os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
  os.environ.setdefault('MASTER_PORT', '29500')
  # Bind this process to its own GPU before any NCCL communication happens.
  torch.cuda.set_device(rank)
  if 'a100' in torch.cuda.get_device_name(rank).lower(): # needed for training with A100's on Allen HPC
    os.environ['NCCL_P2P_LVL'] = 'NVL'
  # rank / world_size come from the launcher; they cannot be read back from
  # torch.distributed before the process group exists.
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

  try:
    _transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # Let rank 0 fetch the dataset first so the ranks do not download into the
    # same directory at the same time; the others wait at the barrier.
    if rank == 0:
      datasets.MNIST('../data', train=True, download=True)
    dist.barrier()
    train_data = datasets.MNIST('../data', train=True, download=True, transform=_transform)
    test_data = datasets.MNIST('../data', train=False, transform=_transform)

    train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank)
    test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank)

    # 'shuffle' must stay False because a sampler is passed to the loaders.
    # NOTE(review): 'drop_last' also applies to the test loader, so test samples
    # are skipped whenever a rank's share is not a multiple of test_batch_size.
    loader_kwargs = {
        'num_workers': 2,
        'pin_memory': True,
        'shuffle': False,
        'drop_last': True,
        'persistent_workers': True # warning: on Allen HPC, disabling this massively slows down training
    }

    train_loader_dist = DataLoader(train_data, batch_size=train_batch_size, sampler=train_sampler, **loader_kwargs)
    test_loader_dist = DataLoader(test_data, batch_size=test_batch_size, sampler=test_sampler, **loader_kwargs)

    for title, model_cls in (("shallow", ShallowNet), ("deep", DeepNet)):
      _run_fsdp_training(
        title, model_cls, rank, auto_wrap_policy, lr, epochs,
        train_loader_dist, test_loader_dist, train_sampler, test_sampler,
      )
  finally:
    # Always release the process group, even if training raised.
    cleanup_distributed()
  

In [None]:
# Launch one training process per visible GPU. mp.spawn passes each process its
# index as the first argument of `main`; the GPU count is the second argument.
num_gpus = torch.cuda.device_count()
if num_gpus < 1:
  # With no GPU there would be zero processes to start; say so explicitly.
  raise RuntimeError("CUDA is not available. You must have CUDA enabled to use distributed training.")
mp.spawn(main, args=(num_gpus,), nprocs=num_gpus, join=True)