This tutorial shows how to use PyTorch's Fully Sharded Data Parallel (FSDP) implementation. We train two models — (i) a shallow model and (ii) a deep model — first with local (single-device) training and then with FSDP, and observe how FSDP decreases training time and memory consumption compared to local training.

In [107]:
# standard imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms # for datasets
from tqdm import tqdm # for progress bar

## Local training

First, we define our models. We will be using
- a shallow model, ~1.2M params
- a deep model, ~94M params

In [115]:
class ShallowNet(nn.Module):
  """Small MNIST classifier: two 3x3 convolutions followed by two fully connected layers.

  The input must be single-channel 28x28 images: 28 -> 26 -> 24 after the two
  convolutions, 12 after 2x2 max-pooling, and 64 * 12 * 12 = 9216 features feed fc1.
  forward() returns log-probabilities of shape (batch, 10).
  """

  def __init__(self):
    super().__init__()
    # Submodules are registered in this exact order; parameter order (and therefore
    # seeded initialization and optimizer state layout) depends on it.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
    self.dropout1 = nn.Dropout(p=0.25)
    self.dropout2 = nn.Dropout(p=0.5)
    self.fc1 = nn.Linear(in_features=9216, out_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=10)

  def forward(self, x):
    # convolutional feature extractor
    features = F.relu(self.conv2(F.relu(self.conv1(x))))
    features = self.dropout1(F.max_pool2d(features, 2))
    # classifier head
    hidden = self.dropout2(F.relu(self.fc1(torch.flatten(features, 1))))
    return F.log_softmax(self.fc2(hidden), dim=1)
  
class DeepNet(nn.Module):
  """Larger MNIST classifier (~94M parameters) with a long fully connected head.

  It uses the same convolutional front end as ShallowNet, so the input must be
  single-channel 28x28 images (64 * 12 * 12 = 9216 features reach fc1).
  forward() returns log-probabilities of shape (batch, 10).
  """

  def __init__(self):
    super().__init__()
    # Registration order is significant for parameter order / seeded initialization.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
    self.dropout1 = nn.Dropout(p=0.25)
    self.dropout2 = nn.Dropout(p=0.5)
    self.fc1 = nn.Linear(in_features=9216, out_features=9000)  # ~83M params: the bulk of the model
    self.fc1a = nn.Linear(in_features=9000, out_features=1000)
    self.fc1b = nn.Linear(in_features=1000, out_features=1000)
    self.fc1c = nn.Linear(in_features=1000, out_features=1000)
    self.fc1d = nn.Linear(in_features=1000, out_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=10)

  def forward(self, x):
    # convolutional feature extractor
    x = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
    x = torch.flatten(self.dropout1(x), 1)
    x = self.dropout2(F.relu(self.fc1(x)))
    # NOTE(review): there is no activation between fc1a..fc1d (and fc2), so these layers
    # compose to a single affine map — they add parameters but not expressiveness.
    # Insert F.relu between them if that was not intended.
    for layer in (self.fc1a, self.fc1b, self.fc1c, self.fc1d, self.fc2):
      x = layer(x)
    return F.log_softmax(x, dim=1)
  
# helper function
def count_params(model):
  """Return the total number of elements across `model`'s trainable parameters.

  Parameters with requires_grad=False (frozen) are not counted.
  """
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total
  
print("# params:")
print(f"  shallow: {count_params(ShallowNet())}")
print(f"  deep: {count_params(DeepNet())}")

# params:
  shallow: 1199882
  deep: 94104234


We will use the standard MNIST benchmark dataset.

In [110]:
# Convert PIL images to tensors, then normalize with mean 0.1307 / std 0.3081
# (presumably the MNIST training-set pixel statistics).
_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
  ])

# The training split is downloaded into ../data if missing; the test split is
# constructed from the same directory without download=True.
train_data = datasets.MNIST('../data', train=True, download=True, transform=_transform)
test_data = datasets.MNIST('../data', train=False, transform=_transform)

# Display both datasets as the cell output.
train_data, test_data 

(Dataset MNIST
     Number of datapoints: 60000
     Root location: ../data
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
                Normalize(mean=(0.1307,), std=(0.3081,))
            ),
 Dataset MNIST
     Number of datapoints: 10000
     Root location: ../data
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
                Normalize(mean=(0.1307,), std=(0.3081,))
            ))

Now we construct our dataloaders. Notice that we set `'persistent_workers': True`: on Allen HPC, setting it to `False` slows training down massively.

In [111]:
train_batch_size = 256
test_batch_size = 1000


# Options shared by both loaders; shuffling and drop_last are chosen per loader below.
loader_kwargs = {
    'num_workers': 2,
    'pin_memory': True,
    'persistent_workers': True, # warning: on Allen HPC, disabling this massively slows down training
}

# Reshuffle the training set every epoch so batch composition/order varies between epochs.
# drop_last=True keeps every training batch full-sized (60000 is not a multiple of 256).
train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True, drop_last=True, **loader_kwargs)
# Evaluation order is irrelevant; keep any final partial batch so every test sample is scored.
test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, **loader_kwargs)

# Peek at one batch from each loader (displayed as the cell output).
next(iter(train_loader)), next(iter(test_loader))

([tensor([[[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            ...,
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],
  
  
          [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            ...,
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
            [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],
  
  
          [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242

Finally, we define our `train()` and `test()` functions. Each returns the average per-sample loss over the epoch (the loss summed over all samples, divided by the number of samples). `train()` additionally reports the memory allocated when training on an NVIDIA GPU, and both functions wrap their data loader in a `tqdm` progress bar so we can monitor training speed (batches per second).

In [112]:
def train(model, train_loader, opt, device):
  """Run one training epoch and return loss / memory statistics.

  Args:
    model: module whose output is fed to F.nll_loss (i.e. log-probabilities).
    train_loader: DataLoader yielding (data, target) batches.
    opt: optimizer over model's parameters.
    device: device (torch.device or string) that batches are moved to.

  Returns:
    dict with
      'avg_loss': summed NLL loss divided by the number of samples actually seen
                  (NaN if the loader yielded no batches);
      'memory':   CUDA memory allocated on `device` in MB, sampled after the forward
                  pass of the last batch, or 'N/A' when `device` is not a CUDA device.
  """
  total_loss = 0.0
  num_samples = 0
  memory = 'N/A'  # defined up front so an empty loader cannot raise UnboundLocalError
  on_cuda = torch.device(device).type == 'cuda'  # memory_allocated only accepts CUDA devices

  model.train()
  for (data, target) in tqdm(train_loader, total=len(train_loader)):
    data, target = data.to(device), target.to(device)
    opt.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target, reduction='sum')
    if on_cuda:
      # sampled after forward, so it includes activations kept alive for backward
      memory = torch.cuda.memory_allocated(device) / 1e6 # memory in MB
    loss.backward()
    opt.step()
    total_loss += loss.item()
    num_samples += target.size(0)  # exact count, also correct for a partial last batch

  avg_loss = total_loss / num_samples if num_samples else float('nan') # average loss per datapoint
  return {
    'avg_loss': avg_loss,
    'memory': memory
  }

def test(model, test_loader, device):
  """Evaluate `model` on `test_loader` without gradient tracking.

  Args:
    model: module whose output is fed to F.nll_loss (i.e. log-probabilities).
    test_loader: DataLoader yielding (data, target) batches.
    device: device that batches are moved to.

  Returns:
    dict with 'avg_loss' (summed NLL divided by the number of evaluated samples),
    'total_correct', 'num_datapoints' and 'accuracy'. avg_loss and accuracy are NaN
    if the loader yielded no batches.
  """
  total_loss = 0
  total_correct = 0
  num_datapoints = 0

  model.eval()
  with torch.no_grad():
    for (data, target) in tqdm(test_loader, total=len(test_loader)):
      data, target = data.to(device), target.to(device)
      output = model(data)
      total_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.argmax(dim=1, keepdim=True)
      total_correct += pred.eq(target.view_as(pred)).sum().item()
      num_datapoints += len(data)

  # divide by the real sample count, not num_batches * batch_size, so a partial
  # final batch does not skew the average
  avg_loss = total_loss / num_datapoints if num_datapoints else float('nan') # average loss per datapoint
  accuracy = total_correct / num_datapoints if num_datapoints else float('nan')
  return {
    'avg_loss': avg_loss,
    'total_correct': total_correct,
    'num_datapoints': num_datapoints,
    'accuracy': accuracy
    }

First, we train the shallow model.

In [113]:
# Hyperparameters shared by the local (single-device) runs.
epochs = 10
# Prefer CUDA, then Apple MPS, then fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
lr = 0.001

shallow_model = ShallowNet().to(device)
shallow_optimizer = optim.Adam(shallow_model.parameters(), lr=lr)

print("> training shallow model\n")
for epoch in range(epochs):
  # one pass over the training set, then a full evaluation on the test set
  train_losses = train(shallow_model, train_loader, shallow_optimizer, device)
  test_losses = test(shallow_model, test_loader, device)
  # NOTE(review): off-CUDA, train() reports memory as 'N/A', which prints as "N/A MB".
  print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']} MB")
  print('-' * 100)

> training shallow model



100%|██████████| 234/234 [00:02<00:00, 79.76it/s]
100%|██████████| 10/10 [00:00<00:00, 29.93it/s]


epoch: 0 | train loss: 0.2820 | test loss: 0.0620 | test acc: 9807/10000 (0.9807) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 86.85it/s]
100%|██████████| 10/10 [00:00<00:00, 55.93it/s]


epoch: 1 | train loss: 0.0912 | test loss: 0.0404 | test acc: 9863/10000 (0.9863) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 85.15it/s]
100%|██████████| 10/10 [00:00<00:00, 54.31it/s]


epoch: 2 | train loss: 0.0705 | test loss: 0.0377 | test acc: 9870/10000 (0.9870) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 86.19it/s]
100%|██████████| 10/10 [00:00<00:00, 55.00it/s]


epoch: 3 | train loss: 0.0544 | test loss: 0.0347 | test acc: 9888/10000 (0.9888) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 85.64it/s]
100%|██████████| 10/10 [00:00<00:00, 55.42it/s]


epoch: 4 | train loss: 0.0466 | test loss: 0.0368 | test acc: 9873/10000 (0.9873) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 82.93it/s]
100%|██████████| 10/10 [00:00<00:00, 32.07it/s]


epoch: 5 | train loss: 0.0397 | test loss: 0.0369 | test acc: 9887/10000 (0.9887) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 84.63it/s]
100%|██████████| 10/10 [00:00<00:00, 47.92it/s]


epoch: 6 | train loss: 0.0378 | test loss: 0.0313 | test acc: 9902/10000 (0.9902) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 80.89it/s]
100%|██████████| 10/10 [00:00<00:00, 27.33it/s]


epoch: 7 | train loss: 0.0328 | test loss: 0.0317 | test acc: 9907/10000 (0.9907) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 87.61it/s]
100%|██████████| 10/10 [00:00<00:00, 57.06it/s]


epoch: 8 | train loss: 0.0292 | test loss: 0.0368 | test acc: 9889/10000 (0.9889) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:02<00:00, 88.96it/s]
100%|██████████| 10/10 [00:00<00:00, 56.79it/s]

epoch: 9 | train loss: 0.0279 | test loss: 0.0317 | test acc: 9899/10000 (0.9899) | memory: N/A MB
----------------------------------------------------------------------------------------------------





Then, we train the deep model.

In [114]:
# Same loop as the shallow run, reusing `device`, `lr` and `epochs` from the previous cell.
# NOTE(review): shallow_model / shallow_optimizer are still referenced from the previous
# cell, so on CUDA the reported memory_allocated presumably includes their tensors too.
deep_model = DeepNet().to(device)
deep_optimizer = optim.Adam(deep_model.parameters(), lr=lr)

print("> training deep model\n")
for epoch in range(epochs):
  train_losses = train(deep_model, train_loader, deep_optimizer, device)
  test_losses = test(deep_model, test_loader, device)
  print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']} MB")
  print('-' * 100)

> training deep model



100%|██████████| 234/234 [00:23<00:00, 10.10it/s]
100%|██████████| 10/10 [00:00<00:00, 13.87it/s]


epoch: 0 | train loss: 0.2725 | test loss: 0.0509 | test acc: 9850/10000 (0.9850) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.16it/s]
100%|██████████| 10/10 [00:00<00:00, 11.49it/s]


epoch: 1 | train loss: 0.0580 | test loss: 0.0563 | test acc: 9843/10000 (0.9843) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.11it/s]
100%|██████████| 10/10 [00:00<00:00, 14.57it/s]


epoch: 2 | train loss: 0.0448 | test loss: 0.0505 | test acc: 9867/10000 (0.9867) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.17it/s]
100%|██████████| 10/10 [00:00<00:00, 14.26it/s]


epoch: 3 | train loss: 0.0327 | test loss: 0.0420 | test acc: 9891/10000 (0.9891) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:23<00:00, 10.12it/s]
100%|██████████| 10/10 [00:00<00:00, 13.98it/s]


epoch: 4 | train loss: 0.0331 | test loss: 0.0507 | test acc: 9872/10000 (0.9872) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:22<00:00, 10.20it/s]
100%|██████████| 10/10 [00:00<00:00, 13.82it/s]


epoch: 5 | train loss: 0.0309 | test loss: 0.0476 | test acc: 9883/10000 (0.9883) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:22<00:00, 10.23it/s]
100%|██████████| 10/10 [00:00<00:00, 14.12it/s]


epoch: 6 | train loss: 0.0261 | test loss: 0.0429 | test acc: 9892/10000 (0.9892) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:22<00:00, 10.20it/s]
100%|██████████| 10/10 [00:00<00:00, 14.32it/s]


epoch: 7 | train loss: 0.0250 | test loss: 0.0446 | test acc: 9893/10000 (0.9893) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:22<00:00, 10.20it/s]
100%|██████████| 10/10 [00:00<00:00, 14.36it/s]


epoch: 8 | train loss: 0.0220 | test loss: 0.0707 | test acc: 9860/10000 (0.9860) | memory: N/A MB
----------------------------------------------------------------------------------------------------


100%|██████████| 234/234 [00:22<00:00, 10.22it/s]
100%|██████████| 10/10 [00:00<00:00, 14.20it/s]

epoch: 9 | train loss: 0.0233 | test loss: 0.0696 | test acc: 9866/10000 (0.9866) | memory: N/A MB
----------------------------------------------------------------------------------------------------





## Distributed training

We now move on to distributed training with FSDP. Note that this code is currently set up for training on multiple NVIDIA GPUs. Please make sure you have multiple NVIDIA GPUs available on your machine (try `nvidia-smi` in the shell) before proceeding.

In [119]:
!nvidia-smi 

zsh:1: command not found: nvidia-smi


First, we have our additional imports for FSDP.

In [105]:
import socket
import os
from functools import partial

import torch.distributed as dist # for distributed communication
from torch.utils.data.distributed import DistributedSampler # for distributed data across GPUs
import torch.multiprocessing as mp # for spawning processes on each GPU
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP, # FSDP constructor for sharding model parameters
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy # for FSDP configuration

# assert torch.cuda.is_available(), "CUDA is not available. You must have CUDA enabled to use distributed training."

Some basic vocabulary:

`world_size` is the number of processes you will spawn for distributed training, which should usually equal the number of GPUs you're training on. So if you have 4 GPUs on your node/machine, you want to set `world_size` to 4.

Then, `rank` is the ID of the current process. To use distributed rather than local training, we spawn a separate Python process for each GPU using PyTorch's multiprocessing API, and number the processes by the GPU they run on. For example, with 4 GPUs enumerated as [GPU0, GPU1, GPU2, GPU3], one process is spawned for each of GPU0, GPU1, GPU2 and GPU3. These processes are identified by their `rank`: the process on GPU0 is called "rank 0", the process on GPU1 is called "rank 1", and so on.

After initializing PyTorch's distributed backend using `torch.distributed.init_process_group()`, one can always compute the rank of the current process via `torch.distributed.get_rank()`. Similarly, one can compute the world size (that is, the number of processes spawned, which should be the number of GPUs available on one's device) via `torch.distributed.get_world_size()`. The world size is the same across ranks. 

Finally, before the program exits, one should call `torch.distributed.destroy_process_group()` in order to cleanup any remaining resources. This is simply the ceremony that PyTorch requires us to go through in order to use distributed training.

Unfortunately, distributed training in PyTorch requires a nontrivial amount of ceremony. These are some functions to initialize PyTorch's distributed backend. 

Notice that we have some additional code to handle the case where our GPU is an NVIDIA A100. When testing FSDP on Allen HPC, we found we needed to set the environment variable `NCCL_P2P_LVL` to `NVL` to keep PyTorch's distributed backend from timing out on NVIDIA A100s. When using A100s on other machines and cloud GPU providers (such as Lambda Labs), setting `NCCL_P2P_LVL` was not necessary. Setting this environment variable was also not necessary for other NVIDIA GPUs (V100, TitanXP, etc.) on Allen HPC.

In [117]:
# helper functions to initialize distributed training

def get_free_addr():
  """Return the first IPv4 address that this machine's hostname resolves to.

  Every process on the same host gets the same answer, which makes it usable as the
  shared MASTER_ADDR for single-node training.
  """
  _hostname, _aliases, ipv4_addresses = socket.gethostbyname_ex(socket.gethostname())
  return ipv4_addresses[0]
    
def get_free_port(addr):
  """Ask the OS for a currently unused TCP port on `addr` and return its number.

  The probe socket is closed before returning, so the port is only likely (not
  guaranteed) to still be free when it is used. Because the result differs on every
  call, it must be chosen once and shared, not computed separately by each rank.
  """
  probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  with probe:
    probe.bind((addr, 0))  # port 0: let the OS pick any free port
    probe.listen(1)
    _bound_host, chosen_port = probe.getsockname()
  return chosen_port

# initialize PyTorch's distributed backend
def setup_distributed(rank=None, world_size=None):
  """Initialize the default NCCL process group for the calling process.

  Args:
    rank: this process's rank. If None, init_process_group reads it from the RANK
      environment variable (env:// rendezvous).
    world_size: total number of processes. If None, it is read from WORLD_SIZE.

  MASTER_ADDR / MASTER_PORT are only filled in when they are not already set, because
  every rank must use identical values. A launcher should choose them once (e.g. via
  get_free_addr() / get_free_port()) before spawning the ranks. If each rank picked
  its own free port, the ranks would never meet and init_process_group would hang.
  """
  if 'MASTER_ADDR' not in os.environ:
    # deterministic per host, so all ranks on this node agree on it
    os.environ['MASTER_ADDR'] = str(get_free_addr()) # address that ranks will use to communicate
  if 'MASTER_PORT' not in os.environ:
    # Fixed fallback: it has to be the same on every rank, so it cannot come from a
    # per-rank get_free_port() call. NOTE(review): 12355 may already be in use.
    os.environ['MASTER_PORT'] = '12355' # port that ranks will use to communicate
  if torch.cuda.is_available() and 'a100' in torch.cuda.get_device_name().lower(): # needed for training with A100's on Allen HPC
    os.environ['NCCL_P2P_LVL'] = 'NVL'
  # dist.get_rank()/get_world_size() raise before a process group exists, so only pass
  # rank/world_size when given and otherwise let the env:// rendezvous read them.
  init_kwargs = {}
  if rank is not None:
    init_kwargs['rank'] = rank
  if world_size is not None:
    init_kwargs['world_size'] = world_size
  dist.init_process_group(backend='nccl', **init_kwargs)

# cleanup PyTorch's distributed backend
def cleanup_distributed():
  """Destroy the default process group if one exists; safe to call more than once."""
  if dist.is_available() and dist.is_initialized():
    dist.destroy_process_group()

# we only want to print on the master rank (rank 0) to avoid duplicate logging. This is a convention commonly followed
# when doing distributed training to reduce clutter in the logs.
def is_master():
  """Return True on rank 0 of the default process group.

  Outside an initialized process group (e.g. plain single-process code) there is only
  one process, so it is treated as the master instead of raising.
  """
  if not dist.is_available() or not dist.is_initialized():
    return True
  return dist.get_rank() == 0

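We now write distributed versions of `train()` and `test()`. Lines marked `# new` or `# updated` differ from the local versions above: each rank accumulates its statistics on its own GPU, and `torch.distributed.all_reduce` sums them across ranks at the end of the epoch.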

In [None]:
def train_dist(model, train_loader, opt):
  """Run one distributed training epoch on this rank's shard and aggregate over ranks.

  Must be called by every rank, since it ends with collective all_reduce calls. Batches
  go to the current CUDA device (the one selected with torch.cuda.set_device).

  Returns:
    dict with
      'avg_loss': summed NLL over all ranks divided by the total number of samples
                  seen (NaN if no rank saw any);
      'memory':   this rank's CUDA memory allocated in MB, sampled after the last
                  forward pass ('N/A' if the loader was empty).
  """
  device = torch.cuda.current_device() # new: the GPU chosen by torch.cuda.set_device(rank)

  total_loss = torch.zeros(1, device=device) # updated
  num_samples = torch.zeros(1, device=device) # updated: count samples, not batches
  memory = 'N/A'  # defined up front so an empty loader cannot raise UnboundLocalError

  model.train()
  bar = tqdm(train_loader, total=len(train_loader)) if is_master() else train_loader # updated: we only want a progress bar on the master rank
  for (data, target) in bar: # updated
    data, target = data.to(device), target.to(device) # updated
    opt.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target, reduction='sum')
    memory = torch.cuda.memory_allocated(device) / 1e6 # memory in MB
    loss.backward()
    opt.step()
    # detach: accumulating `loss` directly would make total_loss part of the autograd
    # graph and keep chaining a new node onto it every step
    total_loss += loss.detach()
    num_samples += target.size(0)

  # new: a single collective sums both statistics across ranks
  stats = torch.cat([total_loss, num_samples])
  dist.all_reduce(stats, op=dist.ReduceOp.SUM)
  loss_sum, sample_count = stats.tolist()
  avg_loss = loss_sum / sample_count if sample_count else float('nan') # updated
  return {
    'avg_loss': avg_loss,
    'memory': memory
  }

def test_dist(model, test_loader):
  """Evaluate `model` on this rank's shard of the test set and aggregate over ranks.

  Must be called by every rank, since it ends with a collective all_reduce. Batches go
  to the current CUDA device.

  Returns:
    dict with 'avg_loss' (summed NLL / number of evaluated samples), 'total_correct'
    and 'num_datapoints' (ints, summed over all ranks) and 'accuracy'. avg_loss and
    accuracy are NaN if no samples were evaluated.
  """
  device = torch.cuda.current_device() # new

  total_loss = torch.zeros(1, device=device) # updated
  total_correct = torch.zeros(1, device=device) # updated
  num_datapoints = torch.zeros(1, device=device) # updated

  model.eval()
  with torch.no_grad():
    bar = tqdm(test_loader, total=len(test_loader)) if is_master() else test_loader # updated: we only want a progress bar on the master rank
    for (data, target) in bar:
      data, target = data.to(device), target.to(device)
      output = model(data)
      total_loss += F.nll_loss(output, target, reduction='sum')
      pred = output.argmax(dim=1, keepdim=True)
      total_correct += pred.eq(target.view_as(pred)).sum()
      num_datapoints += len(data)

  # one collective for all three statistics instead of three separate calls
  stats = torch.cat([total_loss, total_correct, num_datapoints])
  dist.all_reduce(stats, op=dist.ReduceOp.SUM)
  loss_sum, correct, count = stats.tolist()
  correct, count = int(correct), int(count)  # report counts as ints, not 9807.0
  # divide by the real sample count so a partial final batch does not skew the average
  avg_loss = loss_sum / count if count else float('nan')
  accuracy = correct / count if count else float('nan')
  return {
    'avg_loss': avg_loss,
    'total_correct': correct,
    'num_datapoints': count,
    'accuracy': accuracy
  }


In [106]:
def _build_dist_loaders(rank, world_size, train_batch_size, test_batch_size):
  """Build this rank's MNIST train/test DataLoaders; returns (train_loader, test_loader, train_sampler).

  Only the master rank downloads MNIST; the other ranks wait at a barrier so they never
  read partially written files. Must be called after the process group is initialized.
  """
  _transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
  ])
  if is_master():
    datasets.MNIST('../data', train=True, download=True)
    datasets.MNIST('../data', train=False, download=True)
  dist.barrier()  # the other ranks wait here until rank 0 has finished downloading
  train_data = datasets.MNIST('../data', train=True, transform=_transform)
  test_data = datasets.MNIST('../data', train=False, transform=_transform)

  # DistributedSampler hands each rank its own slice of the dataset. The training
  # sampler shuffles (reseeded each epoch via set_epoch); the test sampler does not.
  train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank)
  test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False)

  loader_kwargs = {
      'num_workers': 2,
      'pin_memory': True,
      'shuffle': False,  # must be False when a sampler is passed; the sampler does the shuffling
      'persistent_workers': True # warning: on Allen HPC, disabling this massively slows down training
  }
  # make sure to pass in the sampler
  train_loader = DataLoader(train_data, batch_size=train_batch_size, sampler=train_sampler, drop_last=True, **loader_kwargs)
  # keep the final partial batch at test time so every test sample on this rank is evaluated
  test_loader = DataLoader(test_data, batch_size=test_batch_size, sampler=test_sampler, drop_last=False, **loader_kwargs)
  return train_loader, test_loader, train_sampler


def _fit_dist(name, model, train_loader, test_loader, train_sampler, epochs, lr, auto_wrap_policy):
  """Wrap `model` with FSDP, train it for `epochs` epochs and log progress on the master rank.

  The wrapped model and its optimizer are locals, so they are released when this returns
  and do not inflate the memory reported for the next model.
  """
  model = FSDP(model, auto_wrap_policy=auto_wrap_policy) # shard the model's parameters across ranks
  # construct the optimizer AFTER wrapping with FSDP, so it updates the sharded
  # parameters rather than those of the original module
  optimizer = optim.Adam(model.parameters(), lr=lr)
  if is_master():
    print(f"> training {name} model\n")
  for epoch in range(epochs):
    train_sampler.set_epoch(epoch) # reseed the sampler so each epoch uses a different shuffle
    train_losses = train_dist(model, train_loader, optimizer)
    test_losses = test_dist(model, test_loader)
    if is_master(): # only print on the master rank
      print(f"epoch: {epoch} | train loss: {train_losses['avg_loss']:.4f} | test loss: {test_losses['avg_loss']:.4f} | test acc: {test_losses['total_correct']}/{test_losses['num_datapoints']} ({test_losses['accuracy']:.4f}) | memory: {train_losses['memory']} MB")
      print('-' * 100)


def main(rank, world_size):
  """Per-process entry point for mp.spawn: train the shallow, then the deep, model with FSDP.

  Args:
    rank: index of this process (mp.spawn passes 0..nprocs-1); also used as the CUDA
      device index, which assumes a single node.
    world_size: total number of processes.
  """
  # training config, same as before
  train_batch_size = 256
  test_batch_size = 1000
  lr = 0.001
  epochs = 10

  # new: config for FSDP — submodules whose (not yet wrapped) parameter count reaches
  # min_num_params are wrapped as their own FSDP unit
  min_num_params = 20000
  auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=min_num_params)

  # new: expose rank / world size through the standard env:// variables so that
  # init_process_group can read them (mp.spawn does not set these)
  os.environ['RANK'] = str(rank)
  os.environ['WORLD_SIZE'] = str(world_size)
  os.environ['LOCAL_RANK'] = str(rank)

  torch.cuda.set_device(rank) # new: bind this process to its GPU before NCCL initializes
  setup_distributed() # new: initialize PyTorch's distributed backend
  try:
    train_loader_dist, test_loader_dist, train_sampler = _build_dist_loaders(
        rank, world_size, train_batch_size, test_batch_size)

    # First, we train the shallow model; then the deep model
    _fit_dist('shallow', ShallowNet().to(rank), train_loader_dist, test_loader_dist,
              train_sampler, epochs, lr, auto_wrap_policy)
    _fit_dist('deep', DeepNet().to(rank), train_loader_dist, test_loader_dist,
              train_sampler, epochs, lr, auto_wrap_policy)
  finally:
    cleanup_distributed() # new: release the backend's resources even if training fails
  

In [None]:
# Launch one training process per visible GPU.
# NOTE(review): with the "spawn" start method, child processes must be able to import
# `main` by name; when `main` is defined inside a notebook this typically fails, so
# moving these functions into a .py module may be required.
world_size = torch.cuda.device_count()
if world_size < 1:
  raise RuntimeError("FSDP with the NCCL backend needs at least one CUDA GPU, but none are visible.")
# Choose the rendezvous address/port ONCE here; spawned children inherit os.environ,
# so every rank agrees on them.
os.environ.setdefault('MASTER_ADDR', str(get_free_addr()))
os.environ.setdefault('MASTER_PORT', str(get_free_port(os.environ['MASTER_ADDR'])))
mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)