
Update strategy flag in docs #10000

Merged: 14 commits, Oct 20, 2021
46 changes: 23 additions & 23 deletions docs/source/advanced/advanced_gpu.rst
@@ -71,9 +71,9 @@ To use Sharded Training, you need to first install FairScale using the command b
.. code-block:: python

# train using Sharded DDP
trainer = Trainer(plugins="ddp_sharded")
trainer = Trainer(strategy="ddp_sharded")

Sharded Training can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag.
Sharded Training can work across all DDP variants by adding the additional ``--strategy ddp_sharded`` flag.

Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
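
As a rough sketch (not part of this diff; ``MyModel`` below is a hypothetical minimal module), the optimizer is defined exactly as for single-process training and only the ``strategy`` flag changes:

.. code-block:: python

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def train_dataloader(self):
            return DataLoader(torch.randn(64, 32), batch_size=8)

        def configure_optimizers(self):
            # written exactly as for single-process training; the sharded
            # plugin partitions this optimizer's state across the 8 ranks
            return torch.optim.Adam(self.parameters(), lr=1e-3)


    trainer = pl.Trainer(gpus=8, strategy="ddp_sharded")
    trainer.fit(MyModel())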

@@ -156,7 +156,7 @@ Below is an example of using both ``wrap`` and ``auto_wrap`` to create your mode


model = MyModel()
trainer = Trainer(gpus=4, plugins="fsdp", precision=16)
trainer = Trainer(gpus=4, strategy="fsdp", precision=16)
trainer.fit(model)

trainer.test()
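
The model definition this example refers to is collapsed out of the hunk above. As a hedged reconstruction (a sketch, not the PR's text), ``wrap`` and ``auto_wrap`` are typically called from the LightningModule's ``configure_sharded_model`` hook so layers are sharded as soon as they are created:

.. code-block:: python

    import torch
    from torch import nn
    import pytorch_lightning as pl
    from fairscale.nn import auto_wrap, wrap


    class MyModel(pl.LightningModule):
        def configure_sharded_model(self):
            # ``wrap`` shards the given module directly; ``auto_wrap`` walks the
            # module tree and wraps children according to its size policy
            linear = wrap(nn.Linear(32, 32))
            block = auto_wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
            self.model = nn.Sequential(linear, nn.ReLU(), block)

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())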
@@ -248,7 +248,7 @@ It is recommended to skip Stage 1 and use Stage 2, which comes with larger memor
from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_1", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_1", precision=16)
trainer.fit(model)


@@ -265,7 +265,7 @@ As a result, benefits can also be seen on a single GPU. Do note that the default
from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_2", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_2", precision=16)
trainer.fit(model)

.. code-block:: bash
@@ -286,7 +286,7 @@ Below we show an example of running `ZeRO-Offload <https://www.deepspeed.ai/tuto
from pytorch_lightning.plugins import DeepSpeedPlugin

model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_2_offload", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_2_offload", precision=16)
trainer.fit(model)


@@ -307,7 +307,7 @@ You can also modify the ZeRO-Offload parameters via the plugin as below.
model = MyModel()
trainer = Trainer(
gpus=4,
plugins=DeepSpeedPlugin(offload_optimizer=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8),
strategy=DeepSpeedPlugin(offload_optimizer=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8),
precision=16,
)
trainer.fit(model)
@@ -340,7 +340,7 @@ For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM c


model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_2_offload", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_2_offload", precision=16)
trainer.fit(model)
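
The CPU-optimized optimizer mentioned in the hunk header above is swapped in via ``configure_optimizers``. A hedged sketch (``MyModel`` is the hypothetical module used throughout these examples):

.. code-block:: python

    import pytorch_lightning as pl
    from deepspeed.ops.adam import DeepSpeedCPUAdam


    class MyModel(pl.LightningModule):
        # rest of the module omitted for brevity
        def configure_optimizers(self):
            # DeepSpeedCPUAdam runs the optimizer step on the CPU, matching the
            # optimizer state that ZeRO-Offload has moved to host memory
            return DeepSpeedCPUAdam(self.parameters())


    trainer = pl.Trainer(gpus=4, strategy="deepspeed_stage_2_offload", precision=16)
    trainer.fit(MyModel())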


@@ -383,7 +383,7 @@ Also please have a look at our :ref:`deepspeed-zero-stage-3-tips` which contains


model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_3", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_3", precision=16)
trainer.fit(model)

trainer.test()
@@ -403,7 +403,7 @@ You can also use the Lightning Trainer to run predict or evaluate with DeepSpeed


model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_3", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_3", precision=16)
trainer.test(ckpt_path="my_saved_deepspeed_checkpoint.ckpt")
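
A corresponding ``predict`` call looks the same; as a sketch (``predict_dataloader`` here is a hypothetical DataLoader, not something defined in this diff):

.. code-block:: python

    trainer = Trainer(gpus=4, strategy="deepspeed_stage_3", precision=16)
    # predict_dataloader is assumed to be an ordinary PyTorch DataLoader
    predictions = trainer.predict(
        model, dataloaders=predict_dataloader, ckpt_path="my_saved_deepspeed_checkpoint.ckpt"
    )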


@@ -438,7 +438,7 @@ This reduces the time taken to initialize very large models, as well as ensure w


model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_3", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_3", precision=16)
trainer.fit(model)

trainer.test()
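
The hunk header above refers to deferring model instantiation so weights are sharded as they are created. A hedged sketch of how that is typically done, by building layers inside ``configure_sharded_model`` instead of ``__init__``:

.. code-block:: python

    from torch import nn
    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        def configure_sharded_model(self):
            # layers created here are instantiated within the sharded context,
            # so each rank only materializes its own partition of the weights
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))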
@@ -463,14 +463,14 @@ DeepSpeed ZeRO Stage 3 Offloads optimizer state, gradients to the host CPU to re

# Enable CPU Offloading
model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_3_offload", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_3_offload", precision=16)
trainer.fit(model)

# Enable CPU Offloading, and offload parameters to CPU
model = MyModel()
trainer = Trainer(
gpus=4,
plugins=DeepSpeedPlugin(
strategy=DeepSpeedPlugin(
stage=3,
offload_optimizer=True,
offload_parameters=True,
@@ -492,14 +492,14 @@ Additionally, DeepSpeed supports offloading to NVMe drives for even larger model

# Enable CPU Offloading
model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_3_offload", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_3_offload", precision=16)
trainer.fit(model)

# Enable CPU Offloading, and offload parameters to CPU
model = MyModel()
trainer = Trainer(
gpus=4,
plugins=DeepSpeedPlugin(
strategy=DeepSpeedPlugin(
stage=3,
offload_optimizer=True,
offload_parameters=True,
@@ -576,12 +576,12 @@ This saves memory when training larger models, however requires using a checkpoi
model = MyModel()


trainer = Trainer(gpus=4, plugins="deepspeed_stage_3_offload", precision=16)
trainer = Trainer(gpus=4, strategy="deepspeed_stage_3_offload", precision=16)

# Enable CPU Activation Checkpointing
trainer = Trainer(
gpus=4,
plugins=DeepSpeedPlugin(
strategy=DeepSpeedPlugin(
stage=3,
offload_optimizer=True, # Enable CPU Offloading
cpu_checkpointing=True, # (Optional) offload activations to CPU
@@ -670,7 +670,7 @@ In some cases you may want to define your own DeepSpeed Config, to access all pa
}

model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16)
trainer = Trainer(gpus=4, strategy=DeepSpeedPlugin(deepspeed_config), precision=16)
trainer.fit(model)
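
The body of ``deepspeed_config`` is collapsed out of this hunk. A minimal, hedged sketch of such a hand-written config (the keys follow DeepSpeed's JSON configuration schema; the values are illustrative only):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # ZeRO stage 2 with optimizer-state offload to the CPU
    deepspeed_config = {
        "zero_allow_untested_optimizer": True,
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {"device": "cpu"},
            "allgather_bucket_size": 2e8,
            "reduce_bucket_size": 2e8,
        },
    }

    model = MyModel()
    trainer = Trainer(gpus=4, strategy=DeepSpeedPlugin(deepspeed_config), precision=16)
    trainer.fit(model)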


@@ -682,7 +682,7 @@ We support taking the config as a json formatted file:
from pytorch_lightning.plugins import DeepSpeedPlugin

model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16)
trainer = Trainer(gpus=4, strategy=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16)
trainer.fit(model)


@@ -717,7 +717,7 @@ This can reduce peak memory usage and throughput as saved memory will be equal t
from pytorch_lightning.plugins import DDPPlugin

model = MyModel()
trainer = Trainer(gpus=4, plugins=DDPPlugin(gradient_as_bucket_view=True))
trainer = Trainer(gpus=4, strategy=DDPPlugin(gradient_as_bucket_view=True))
trainer.fit(model)

DDP Communication Hooks
@@ -740,7 +740,7 @@ Enable `FP16 Compress Hook for multi-node throughput improvement <https://pytorc
)

model = MyModel()
trainer = Trainer(gpus=4, plugins=DDPPlugin(ddp_comm_hook=default.fp16_compress_hook))
trainer = Trainer(gpus=4, strategy=DDPPlugin(ddp_comm_hook=default.fp16_compress_hook))
trainer.fit(model)
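
The import this example relies on is truncated by the hunk boundary above; for reference, a hedged self-contained version of the same snippet:

.. code-block:: python

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin

    model = MyModel()  # hypothetical module from the examples above
    trainer = Trainer(gpus=4, strategy=DDPPlugin(ddp_comm_hook=default.fp16_compress_hook))
    trainer.fit(model)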

Enable `PowerSGD for multi-node throughput improvement <https://pytorch.org/docs/stable/ddp_comm_hooks.html#powersgd-communication-hook>`__:
@@ -758,7 +758,7 @@ Enable `PowerSGD for multi-node throughput improvement <https://pytorch.org/docs
model = MyModel()
trainer = Trainer(
gpus=4,
plugins=DDPPlugin(
strategy=DDPPlugin(
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
@@ -787,7 +787,7 @@ Combine hooks for accumulated benefit:
model = MyModel()
trainer = Trainer(
gpus=4,
plugins=DDPPlugin(
strategy=DDPPlugin(
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
10 changes: 5 additions & 5 deletions docs/source/advanced/ipu.rst
@@ -83,7 +83,7 @@ IPUs provide further optimizations to speed up training. By using the ``IPUPlugi
from pytorch_lightning.plugins import IPUPlugin

model = MyLightningModule()
trainer = pl.Trainer(ipus=8, plugins=IPUPlugin(device_iterations=32))
trainer = pl.Trainer(ipus=8, strategy=IPUPlugin(device_iterations=32))
trainer.fit(model)

Note that by default we return the last device iteration loss. You can override this by passing in your own ``poptorch.Options`` and setting the AnchorMode as described in the `PopTorch documentation <https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/reference.html#poptorch.Options.anchorMode>`__.
@@ -102,7 +102,7 @@ Note that by default we return the last device iteration loss. You can override
training_opts.anchorMode(poptorch.AnchorMode.All)
training_opts.deviceIterations(32)

trainer = Trainer(ipus=8, plugins=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts))
trainer = Trainer(ipus=8, strategy=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts))
trainer.fit(model)
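
The creation of ``inference_opts`` is cut off by the hunk above; a hedged, self-contained sketch of the same pattern:

.. code-block:: python

    import poptorch
    import pytorch_lightning as pl
    from pytorch_lightning.plugins import IPUPlugin

    model = MyLightningModule()

    inference_opts = poptorch.Options()
    inference_opts.deviceIterations(32)

    training_opts = poptorch.Options()
    training_opts.anchorMode(poptorch.AnchorMode.All)
    training_opts.deviceIterations(32)

    trainer = pl.Trainer(ipus=8, strategy=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts))
    trainer.fit(model)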

You can also override all options by passing the ``poptorch.Options`` to the plugin. See `PopTorch options documentation <https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/batching.html>`__ for more information.
@@ -124,7 +124,7 @@ Lightning supports dumping all reports to a directory to open using the tool.
from pytorch_lightning.plugins import IPUPlugin

model = MyLightningModule()
trainer = pl.Trainer(ipus=8, plugins=IPUPlugin(autoreport_dir="report_dir/"))
trainer = pl.Trainer(ipus=8, strategy=IPUPlugin(autoreport_dir="report_dir/"))
trainer.fit(model)

This will dump all reports to ``report_dir/`` which can then be opened using the Graph Analyser Tool, see `Opening Reports <https://docs.graphcore.ai/projects/graphcore-popvision-user-guide/en/latest/graph/graph.html#opening-reports>`__.
@@ -174,7 +174,7 @@ Below is an example using the block annotation in a LightningModule.


model = MyLightningModule()
trainer = pl.Trainer(ipus=8, plugins=IPUPlugin(device_iterations=20))
trainer = pl.Trainer(ipus=8, strategy=IPUPlugin(device_iterations=20))
trainer.fit(model)
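
The block-annotated module this example refers to is collapsed out of the hunk. A hedged sketch of the usual pattern, where ``poptorch.BeginBlock`` marks which IPU each group of layers runs on:

.. code-block:: python

    import poptorch
    import torch
    import pytorch_lightning as pl


    class MyLightningModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # everything from each BeginBlock up to the next one is placed on that IPU
            self.layer1 = poptorch.BeginBlock(torch.nn.Linear(5, 10), ipu_id=0)
            self.layer2 = poptorch.BeginBlock(torch.nn.Linear(10, 5), ipu_id=1)

        def forward(self, x):
            return self.layer2(self.layer1(x))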


@@ -217,7 +217,7 @@ You can also use the block context manager within the forward function, or any o


model = MyLightningModule()
trainer = pl.Trainer(ipus=8, plugins=IPUPlugin(device_iterations=20))
trainer = pl.Trainer(ipus=8, strategy=IPUPlugin(device_iterations=20))
trainer.fit(model)
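
Similarly, the forward pass this example refers to is elided; a hedged sketch using the ``poptorch.Block`` context manager to pin parts of the computation to specific IPUs:

.. code-block:: python

    import poptorch
    import torch
    import pytorch_lightning as pl


    class MyLightningModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(5, 10)
            self.layer2 = torch.nn.Linear(10, 5)

        def forward(self, x):
            # code executed inside each context manager runs on the given IPU
            with poptorch.Block(ipu_id=0):
                x = self.layer1(x)
            with poptorch.Block(ipu_id=1):
                x = self.layer2(x)
            return x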


56 changes: 28 additions & 28 deletions docs/source/advanced/multi_gpu.rst
@@ -253,11 +253,11 @@ Distributed modes
-----------------
Lightning allows multiple ways of training

- Data Parallel (``accelerator='dp'``) (multiple-gpus, 1 machine)
- DistributedDataParallel (``accelerator='ddp'``) (multiple-gpus across many machines (python script based)).
- DistributedDataParallel (``accelerator='ddp_spawn'``) (multiple-gpus across many machines (spawn based)).
- DistributedDataParallel 2 (``accelerator='ddp2'``) (DP in a machine, DDP across machines).
- Horovod (``accelerator='horovod'``) (multi-machine, multi-gpu, configured at runtime)
- Data Parallel (``strategy='dp'``) (multiple-gpus, 1 machine)
- DistributedDataParallel (``strategy='ddp'``) (multiple-gpus across many machines (python script based)).
- DistributedDataParallel (``strategy='ddp_spawn'``) (multiple-gpus across many machines (spawn based)).
- DistributedDataParallel 2 (``strategy='ddp2'``) (DP in a machine, DDP across machines).
- Horovod (``strategy='horovod'``) (multi-machine, multi-gpu, configured at runtime)
- TPUs (``tpu_cores=8|x``) (tpu or TPU pod)

.. note::
@@ -287,7 +287,7 @@ after which the root node will aggregate the results.
:skipif: torch.cuda.device_count() < 2

# train on 2 GPUs (using DP mode)
trainer = Trainer(gpus=2, accelerator="dp")
trainer = Trainer(gpus=2, strategy="dp")

Distributed Data Parallel
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -308,10 +308,10 @@ Distributed Data Parallel
.. code-block:: python

# train on 8 GPUs (same machine (ie: node))
trainer = Trainer(gpus=8, accelerator="ddp")
trainer = Trainer(gpus=8, strategy="ddp")

# train on 32 GPUs (4 nodes)
trainer = Trainer(gpus=8, accelerator="ddp", num_nodes=4)
trainer = Trainer(gpus=8, strategy="ddp", num_nodes=4)

This Lightning implementation of DDP calls your script under the hood multiple times with the correct environment
variables:
@@ -356,7 +356,7 @@ In this case, we can use DDP2 which behaves like DP in a machine and DDP across
.. code-block:: python

# train on 32 GPUs (4 nodes)
trainer = Trainer(gpus=8, accelerator="ddp2", num_nodes=4)
trainer = Trainer(gpus=8, strategy="ddp2", num_nodes=4)

Distributed Data Parallel Spawn
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -374,7 +374,7 @@ project module) you can use the following method:
.. code-block:: python

# train on 8 GPUs (same machine (ie: node))
trainer = Trainer(gpus=8, accelerator="ddp_spawn")
trainer = Trainer(gpus=8, strategy="ddp_spawn")

We STRONGLY discourage this use because it has limitations (due to Python and PyTorch):

@@ -446,10 +446,10 @@ Horovod can be configured in the training script to run with any number of GPUs
.. code-block:: python

# train Horovod on GPU (number of GPUs / machines provided on command-line)
trainer = Trainer(accelerator="horovod", gpus=1)
trainer = Trainer(strategy="horovod", gpus=1)

# train Horovod on CPU (number of processes / machines provided on command-line)
trainer = Trainer(accelerator="horovod")
trainer = Trainer(strategy="horovod")

When starting the training job, the driver application will then be used to specify the total
number of worker processes:
@@ -583,11 +583,11 @@ Below are the possible configurations we support.
+-------+---------+----+-----+--------+------------------------------------------------------------+
| Y | | | | Y | `Trainer(gpus=1, precision=16)` |
+-------+---------+----+-----+--------+------------------------------------------------------------+
| | Y | Y | | | `Trainer(gpus=k, accelerator='dp')` |
| | Y | Y | | | `Trainer(gpus=k, strategy='dp')` |
+-------+---------+----+-----+--------+------------------------------------------------------------+
| | Y | | Y | | `Trainer(gpus=k, accelerator='ddp')` |
| | Y | | Y | | `Trainer(gpus=k, strategy='ddp')` |
+-------+---------+----+-----+--------+------------------------------------------------------------+
| | Y | | Y | Y | `Trainer(gpus=k, accelerator='ddp', precision=16)` |
| | Y | | Y | Y | `Trainer(gpus=k, strategy='ddp', precision=16)` |
+-------+---------+----+-----+--------+------------------------------------------------------------+


@@ -616,29 +616,29 @@ In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED, or Horovod your effective batch size
.. code-block:: python

# effective batch size = 7 * 8
Trainer(gpus=8, accelerator="ddp")
Trainer(gpus=8, accelerator="ddp_spawn")
Trainer(gpus=8, accelerator="ddp_sharded")
Trainer(gpus=8, accelerator="horovod")
Trainer(gpus=8, strategy="ddp")
Trainer(gpus=8, strategy="ddp_spawn")
Trainer(gpus=8, strategy="ddp_sharded")
Trainer(gpus=8, strategy="horovod")

# effective batch size = 7 * 8 * 10
Trainer(gpus=8, num_nodes=10, accelerator="ddp")
Trainer(gpus=8, num_nodes=10, accelerator="ddp_spawn")
Trainer(gpus=8, num_nodes=10, accelerator="ddp_sharded")
Trainer(gpus=8, num_nodes=10, accelerator="horovod")
Trainer(gpus=8, num_nodes=10, strategy="ddp")
Trainer(gpus=8, num_nodes=10, strategy="ddp_spawn")
Trainer(gpus=8, num_nodes=10, strategy="ddp_sharded")
Trainer(gpus=8, num_nodes=10, strategy="horovod")

In DDP2 or DP, your effective batch size will be 7 * num_nodes.
The reason is that the full batch is visible to all GPUs on the node when using DDP2.

.. code-block:: python

# effective batch size = 7
Trainer(gpus=8, accelerator="ddp2")
Trainer(gpus=8, accelerator="dp")
Trainer(gpus=8, strategy="ddp2")
Trainer(gpus=8, strategy="dp")

# effective batch size = 7 * 10
Trainer(gpus=8, num_nodes=10, accelerator="ddp2")
Trainer(gpus=8, accelerator="dp")
Trainer(gpus=8, num_nodes=10, strategy="ddp2")
Trainer(gpus=8, strategy="dp")


.. note:: Huge batch sizes are actually really bad for convergence. Check out:
Expand All @@ -652,7 +652,7 @@ Lightning supports the use of Torch Distributed Elastic to enable fault-tolerant

.. code-block:: python

Trainer(gpus=8, accelerator="ddp")
Trainer(gpus=8, strategy="ddp")

To launch a fault-tolerant job, run the following on all nodes.

4 changes: 2 additions & 2 deletions docs/source/advanced/tpu.rst
@@ -349,14 +349,14 @@ Don't use ``xm.xla_device()`` while working on Lightning + TPUs!

PyTorch XLA only supports Tensor objects for CPU-to-TPU data transfer. Issues can arise if you try to send non-tensor objects through the DataLoader or while saving state.

- **Using `tpu_spawn_debug` Plugin**
- **Using `tpu_spawn_debug` Plugin alias**

.. code-block:: python

import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(tpu_cores=8, plugins="tpu_spawn_debug")
trainer = pl.Trainer(tpu_cores=8, strategy="tpu_spawn_debug")
trainer.fit(my_model)

Example Metrics report: