Skip to content

Commit

Permalink
Update the DDP optimizations page (Lightning-AI#18344)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Aug 19, 2023
1 parent 2ca1571 commit 2eb6214
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 46 deletions.
102 changes: 58 additions & 44 deletions docs/source-pytorch/advanced/ddp_optimizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,70 @@

.. _ddp-optimizations:

*****************
#################
DDP Optimizations
*****************
#################

Tune settings specific to DDP training for increased speed and memory efficiency.


----

DDP Static Graph
================

`DDP static graph <https://pytorch.org/blog/pytorch-1.11-released/#stable-ddp-static-graph>`__ assumes that your model
employs the same set of used/unused parameters in every iteration, so that it can deterministically know the flow of
training and apply special optimizations during runtime.
***********************
Gradient as Bucket View
***********************

Enabling ``gradient_as_bucket_view=True`` in the ``DDPStrategy`` will make gradients views point to different offsets of the ``allreduce`` communication buckets.
See :class:`~torch.nn.parallel.DistributedDataParallel` for more information.
This can reduce peak memory usage and throughput as saved memory will be equal to the total gradient memory + removes the need to copy gradients to the ``allreduce`` communication buckets.

.. code-block:: python
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
model = MyModel()
trainer = L.Trainer(devices=4, strategy=DDPStrategy(gradient_as_bucket_view=True))
trainer.fit(model)
.. note::
DDP static graph support requires PyTorch>=1.11.0
When ``gradient_as_bucket_view=True`` you cannot call ``detach_()`` on gradients.


----


****************
DDP Static Graph
****************

`DDP static graph <https://pytorch.org/blog/pytorch-1.11-released/#stable-ddp-static-graph>`__ assumes that your model employs the same set of used/unused parameters in every iteration, so that it can deterministically know the flow of training and apply special optimizations during runtime.

.. code-block:: python
from lightning.pytorch import Trainer
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
trainer = Trainer(devices=4, strategy=DDPStrategy(static_graph=True))
trainer = L.Trainer(devices=4, strategy=DDPStrategy(static_graph=True))
----

When Using DDP on a Multi-node Cluster, Set NCCL Parameters
===========================================================

`NCCL <https://developer.nvidia.com/nccl>`__ is the NVIDIA Collective Communications Library that is used by PyTorch to handle communication across nodes and GPUs. There are reported benefits in terms of speedups when adjusting NCCL parameters as seen in this `issue <https://github.com/Lightning-AI/lightning/issues/7179>`__. In the issue, we see a 30% speed improvement when training the Transformer XLM-RoBERTa and a 15% improvement in training with Detectron2.
********************************************
On a Multi-Node Cluster, Set NCCL Parameters
********************************************

`NCCL <https://developer.nvidia.com/nccl>`__ is the NVIDIA Collective Communications Library that is used by PyTorch to handle communication across nodes and GPUs.
There are reported benefits in terms of speedups when adjusting NCCL parameters as seen in this `issue <https://github.com/Lightning-AI/lightning/issues/7179>`__.
In the issue, we see a 30% speed improvement when training the Transformer XLM-RoBERTa and a 15% improvement in training with Detectron2.
NCCL parameters can be adjusted via environment variables.

.. note::

AWS and GCP already set default values for these on their clusters. This is typically useful for custom cluster setups.
AWS and GCP already set default values for these on their clusters.
This is typically useful for custom cluster setups.

* `NCCL_NSOCKS_PERTHREAD <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nsocks-perthread>`__
* `NCCL_SOCKET_NTHREADS <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-socket-nthreads>`__
Expand All @@ -46,42 +77,25 @@ NCCL parameters can be adjusted via environment variables.
export NCCL_SOCKET_NTHREADS=2
Gradients as Bucket View
========================

Enabling ``gradient_as_bucket_view=True`` in the ``DDPStrategy`` will make gradients views point to different offsets of the ``allreduce`` communication buckets. See :class:`~torch.nn.parallel.DistributedDataParallel` for more information.

This can reduce peak memory usage and throughput as saved memory will be equal to the total gradient memory + removes the need to copy gradients to the ``allreduce`` communication buckets.

.. note::

When ``gradient_as_bucket_view=True`` you cannot call ``detach_()`` on gradients. If hitting such errors, please fix it by referring to the :meth:`~torch.optim.Optimizer.zero_grad` function in ``torch/optim/optimizer.py`` as a solution (`source <https://pytorch.org/docs/master/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel>`__).

.. code-block:: python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(gradient_as_bucket_view=True))
trainer.fit(model)
----


***********************
DDP Communication Hooks
=======================

DDP Communication hooks is an interface to control how gradients are communicated across workers, overriding the standard allreduce in DistributedDataParallel. This allows you to enable performance improving communication hooks when using multiple nodes.
***********************

DDP Communication hooks is an interface to control how gradients are communicated across workers, overriding the standard allreduce in :class:`~torch.nn.parallel.DistributedDataParallel`.
This allows you to enable performance improving communication hooks when using multiple nodes.
Enable `FP16 Compress Hook for multi-node throughput improvement <https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook>`__:

.. code-block:: python
from lightning.pytorch import Trainer
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
trainer = L.Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
trainer.fit(model)
Enable `PowerSGD for multi-node throughput improvement <https://pytorch.org/docs/stable/ddp_comm_hooks.html#powersgd-communication-hook>`__:
Expand All @@ -92,12 +106,12 @@ Enable `PowerSGD for multi-node throughput improvement <https://pytorch.org/docs

.. code-block:: python
from lightning.pytorch import Trainer
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
model = MyModel()
trainer = Trainer(
trainer = L.Trainer(
accelerator="gpu",
devices=4,
strategy=DDPStrategy(
Expand All @@ -116,15 +130,15 @@ Combine hooks for accumulated benefit:

.. code-block:: python
from lightning.pytorch import Trainer
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
)
model = MyModel()
trainer = Trainer(
trainer = L.Trainer(
accelerator="gpu",
devices=4,
strategy=DDPStrategy(
Expand All @@ -144,12 +158,12 @@ When using Post-localSGD, you must also pass ``model_averaging_period`` to allow

.. code-block:: python
from lightning.pytorch import Trainer
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
model = MyModel()
trainer = Trainer(
trainer = L.Trainer(
accelerator="gpu",
devices=4,
strategy=DDPStrategy(
Expand Down
4 changes: 2 additions & 2 deletions docs/source-pytorch/advanced/model_init.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

.. _model_init:

************************
########################
Efficient initialization
************************
########################

Instantiating a ``nn.Module`` in PyTorch creates all parameters on CPU in float32 precision by default.
To speed up initialization, you can force PyTorch to create the model directly on the target device and with the desired precision without changing your model code.
Expand Down

0 comments on commit 2eb6214

Please sign in to comment.