# Distributed Training with PyTorch Lightning

## Table of Contents

Overview of Multi-device Plugins in Lightning

1. DP: Data-Parallel
2. DDP: Distributed Data-Parallel
3. DDP-spawn: Distributed Data-Parallel Spawn
4. DDP2: DP and DDP Mixed
5. SDP: Sharded Data-Parallel 
6. FSDP: Fully Sharded Data-Parallel

Best practices

## Introduction


This tutorial covers several plugins for multi-GPU and distributed training. It walks through the differences between them, when to use one over another, and best practices for writing hardware- and plugin-agnostic code.
We will make use of the following terminology:

- **Accelerator:** The hardware type we are running on. This can be a CPU, GPU, TPU or IPU. This tutorial will focus only on the GPU accelerator because it is compatible with a large selection of different plugins.
- **Plugin:** Sometimes also referred to as "training type plugin", "backend" or "distributed backend". A plugin determines how the model and data are split across multiple devices and defines the communication and synchronization between devices and processes.

**IMPORTANT NOTE:** This notebook is not meant to be executed in full. Some cells produce output, but many of the backends presented here will NOT run in a Jupyter environment.

Throughout the next sections, we will re-use the following templates for the model and data module:

In [35]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from torchmetrics import Accuracy
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything


class MNISTDataModule(LightningDataModule):

    def __init__(self, data_dir: str = "./", batch_size: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # only downloads the data once
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
    
    
class TutorialModule(LightningModule):

    def __init__(
        self,
        hidden_dim: int = 128,
        learning_rate: float = 0.0001,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
        self.val_accuracy = Accuracy(num_classes=10)
        self.test_accuracy = Accuracy(num_classes=10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = self.l2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return {"loss": loss, "y_hat": y_hat.detach()}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        prob = F.softmax(self(x), dim=1)
        pred = torch.argmax(prob, dim=1)
        self.log("val_acc", self.val_accuracy(pred, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        prob = F.softmax(self(x), dim=1)
        pred = torch.argmax(prob, dim=1)
        self.log("test_acc", self.test_accuracy(pred, y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

We perform a quick test run to check that the template code works. It should achieve a test accuracy of ~92% after one epoch of training.

In [34]:
seed_everything(1)

model = TutorialModule()
datamodule = MNISTDataModule()
trainer = Trainer(max_epochs=1)

trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)

Global seed set to 1
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name          | Type     | Params
-------------------------------------------
0 | l1            | Linear   | 100 K 
1 | l2            | Linear   | 1.3 K 
2 | val_accuracy  | Accuracy | 0     
3 | test_accuracy | Accuracy | 0     
-------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.926800012588501}
--------------------------------------------------------------------------------


[{'test_acc': 0.926800012588501}]

## Overview of Multi-device Plugins in Lightning

Lightning supports a variety of plugins for multi-GPU and distributed training (sometimes also called training type plugins or distributed backends). A plugin determines how the model and data are split across multiple devices and defines the communication and synchronization between devices and processes. Below we list the plugins, each with recommendations on when to use it and when not to, example code, and important details to consider when writing device-agnostic and performant code.

### DP: DataParallel

**Use when:** 
- you want to port an existing PyTorch model written with DataParallel and want to maintain 100% parity;
- your optimization needs the full aggregated batch of outputs/losses from all GPUs;
- none of the other backends presented here are suitable due to their hardware and runtime requirements.

**Do not use when:**
- you are looking for the most performant multi-GPU code;
- you have custom batch structures that cannot be converted to primitive containers like tuples, lists, dicts, etc.;
- you rely heavily on torchmetrics.

**How to activate:**

In [3]:
# data-parallel with 2 GPUs
trainer = Trainer(gpus=2, accelerator="dp")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Data-Parallel initially moves all model parameters, buffers and data tensors to the root GPU. In Lightning, this is GPU 0. The following steps take place in _every_ training step:
1. The model gets replicated to every device, i.e., parameters and buffers get copied from the root device to all other devices. 
2. The data batch that initially resides on GPU 0 gets split into N sub-batches along dimension 0 (batch dimension). Each GPU receives one of these batch splits and they are passed to the ``training_step`` hook.
3. The output of ``training_step`` on each device is transferred back to the root device and, by default, averaged.
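
The three steps above can be imitated on a single CPU to make the data flow concrete. This is only a sketch: `copy.deepcopy` and `torch.chunk` stand in for the actual device-to-device replication and scattering, and the toy `Linear` model and the names `root_model`, `replicas` and `gathered` are illustrative assumptions, not part of Lightning.

```python
import copy

import torch

torch.manual_seed(0)

# Stand-in for the model that lives on the root device.
root_model = torch.nn.Linear(4, 3)
batch = torch.randn(16, 4)
num_devices = 2  # simulated devices; everything below stays on the CPU

# Step 1: replicate the model (real DP copies parameters and buffers to each GPU).
replicas = [copy.deepcopy(root_model) for _ in range(num_devices)]

# Step 2: split the batch along dimension 0, one sub-batch per device.
sub_batches = torch.chunk(batch, num_devices, dim=0)

# Each replica only ever sees its own sub-batch.
outputs = [replica(x) for replica, x in zip(replicas, sub_batches)]

# Step 3: collect the per-device outputs on the root device.
gathered = torch.cat(outputs, dim=0)

print([tuple(x.shape) for x in sub_batches])  # [(8, 4), (8, 4)]
print(tuple(gathered.shape))  # (16, 3)
```

Because every replica starts from the same weights, the gathered output matches what the root model would produce for the full batch; the cost is the replication that has to be repeated in every step.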

Because the module is replicated anew in every training step, this is the least efficient plugin for multi-GPU training. An additional caveat is that state changes made to the module during ``training_step`` happen on a replica and are lost afterwards, which is a common source of bugs. It is also the reason why torchmetrics is not recommended together with this plugin.
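
The lost-state caveat can be illustrated with a CPU-only simulation. The sketch below assumes that a fresh copy of the module is made for each device in every step (imitated here with `copy.deepcopy`, which is a simplification of what DataParallel does); `CountingModule` and `num_samples_seen` are hypothetical names used only for this illustration.

```python
import copy

import torch


class CountingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        self.num_samples_seen = 0  # plain Python state, not a parameter or buffer

    def forward(self, x):
        # This update happens on whichever copy of the module runs forward().
        self.num_samples_seen += x.shape[0]
        return self.layer(x)


root = CountingModule()
batch = torch.randn(16, 4)

for _ in range(3):  # three simulated training steps
    # Fresh replicas are created from the root module in every step ...
    replicas = [copy.deepcopy(root) for _ in range(2)]
    for replica, x in zip(replicas, torch.chunk(batch, 2, dim=0)):
        replica(x)
    # ... and discarded afterwards, together with any state they accumulated.

print(root.num_samples_seen)  # 0
```

The counter on the root module never changes, even though 48 samples were processed in total. State that must survive a step is therefore better updated in a hook that runs on the root device.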

**Example 1:**

In [36]:
class DPModule(TutorialModule):
    
    # *_step() happens on the replica of the model (each GPU runs this)
    def validation_step(self, batch, batch_idx):
        x, y = batch
        # total batch size = 16, 2 GPUs -> each GPU sees batch of size 8
        # the last batch may still be smaller because the dataset size may not be evenly divisible by the batch size
        assert x.shape[0] <= 8
        prob = F.softmax(self(x), dim=1)
        pred = torch.argmax(prob, dim=1)
        return pred, y

    def validation_step_end(self, outputs):
        # torchmetrics does not support syncing metric state from the replicas,
        # so all torchmetrics computations need to be performed in *_step_end(),
        # which happens on the root device
        pred, y = outputs
        self.log("val_acc", self.val_accuracy(pred, y), prog_bar=True)
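
The `<= 8` in the assertion above follows from simple arithmetic on the validation split. A small sketch, assuming the batch is split across GPUs the way `torch.chunk` splits a tensor along dimension 0 (the variable names are illustrative):

```python
import torch

num_gpus = 2
batch_size = 16
num_val_samples = 5000  # size of the validation split in MNISTDataModule

# 5000 samples do not divide evenly into batches of 16.
full_batches, remainder = divmod(num_val_samples, batch_size)
print(full_batches, remainder)  # 312 8

# A full batch and the smaller last batch, with MNIST-shaped dummy inputs.
full = torch.zeros(batch_size, 1, 28, 28)
last = torch.zeros(remainder, 1, 28, 28)

print([c.shape[0] for c in torch.chunk(full, num_gpus, dim=0)])  # [8, 8]
print([c.shape[0] for c in torch.chunk(last, num_gpus, dim=0)])  # [4, 4]
```

Under these assumptions, each GPU sees 8 samples for the 312 full batches and only 4 samples for the final batch.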

In [38]:
seed_everything(1)

model = DPModule()

datamodule = MNISTDataModule()
trainer = Trainer(gpus=2, accelerator="dp", max_epochs=1)

trainer.fit(model, datamodule=datamodule)

Global seed set to 1
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name          | Type     | Params
-------------------------------------------
0 | l1            | Linear   | 100 K 
1 | l2            | Linear   | 1.3 K 
2 | val_accuracy  | Accuracy | 0     
3 | test_accuracy | Accuracy | 0     
-------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)



**Example 2: Custom reduction**

In [66]:
class DPModule(TutorialModule):
    
    def training_step_end(self, outputs):
        # outputs is a dict
        # it is the result of merging all dicts returned by training_step() on each device
        
        # the loss from each GPU, 2 GPUs are used here
        losses = outputs["loss"]
        assert losses.shape[0] == 2
        
        # each GPU returned 8 predictions
        y_hats = outputs["y_hat"]
        assert y_hats.shape[0] == 2 * 8
        
        probs = F.softmax(y_hats, dim=1)
        preds = torch.argmax(probs, dim=1)
        loss = torch.mean(losses)
        return {"loss": loss, "pred": preds}
    
    def training_epoch_end(self, outputs):
        # we can receive all outputs from all training steps and concatenate them
        all_predictions = torch.cat([out["pred"] for out in outputs])
        print(all_predictions)
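
The shapes asserted in `training_step_end` can be reproduced without GPUs. The sketch below builds hypothetical per-device result dicts and merges them by hand; it assumes that scalar losses end up as one vector entry per device and that batched tensors are concatenated along dimension 0, which is what the assertions above rely on.

```python
import torch

num_gpus, per_gpu_batch, num_classes = 2, 8, 10

# Hypothetical per-device results, shaped like the dict returned by training_step().
per_device = [
    {"loss": torch.rand(()), "y_hat": torch.randn(per_gpu_batch, num_classes)}
    for _ in range(num_gpus)
]

# Merge by key: scalars are stacked into a vector, batched tensors are concatenated.
merged = {
    "loss": torch.stack([out["loss"] for out in per_device]),
    "y_hat": torch.cat([out["y_hat"] for out in per_device], dim=0),
}

print(tuple(merged["loss"].shape))  # (2,)
print(tuple(merged["y_hat"].shape))  # (16, 10)

# The custom reduction from the example: mean loss, argmax predictions.
loss = merged["loss"].mean()
preds = merged["y_hat"].softmax(dim=1).argmax(dim=1)
```

After the reduction, `loss` is a single scalar again and `preds` holds one predicted class per sample of the full batch.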


In [67]:
model = DPModule()
datamodule = MNISTDataModule()

trainer = Trainer(
    gpus=2, 
    accelerator="dp", 
    max_steps=4,
    weights_summary=None,
    limit_val_batches=0,
)
trainer.fit(model, datamodule=datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

tensor([3, 2, 2, 2, 3, 5, 2, 7, 7, 5, 9, 2, 2, 1, 7, 2, 3, 2, 2, 2, 3, 5, 2, 7,
        7, 5, 9, 2, 2, 1, 7, 2, 3, 2, 2, 2, 3, 5, 2, 7, 7, 5, 9, 2, 2, 1, 7, 2,
        3, 2, 2, 2, 3, 5, 2, 7, 7, 5, 9, 2, 2, 1, 7, 2], device='cuda:0')
