Add trainer.init_module and trainer.init_tensor
carmocca committed Jul 6, 2023
1 parent f4240ca commit 7ffddf2
Showing 15 changed files with 248 additions and 94 deletions.
4 changes: 2 additions & 2 deletions docs/source-pytorch/accelerators/gpu_intermediate.rst
@@ -224,9 +224,9 @@ DDP can also be used with 1 GPU, but there's no reason to do so other than debug

Implement Your Own Distributed (DDP) training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you need your own way to init PyTorch DDP you can override :meth:`lightning.pytorch.strategies.ddp.DDPStrategy.setup_distributed`.
If you need your own way to init PyTorch DDP you can override :meth:`~lightning.pytorch.strategies.ddp.DDPStrategy.setup_distributed`.

If you also need to use your own DDP implementation, override :meth:`lightning.pytorch.strategies.ddp.DDPStrategy.configure_ddp`.
If you also need to use your own DDP implementation, override :meth:`~lightning.pytorch.strategies.ddp.DDPStrategy.configure_ddp`.

----------

103 changes: 84 additions & 19 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -67,6 +67,39 @@ For training on unreliable mixed GPUs across the internet check out the :doc:`Hi
----


************************
Efficient initialization
************************

Instantiating a ``nn.Module`` in PyTorch creates all parameters on CPU in float32 precision by default.
To speed up initialization, you can force PyTorch to create the model directly on the target device and with the desired precision without changing your model code.

.. code-block:: python

    trainer = Trainer(accelerator="cuda", precision="16-true")

    with trainer.init_module():
        # models created here will be on GPU and in float16
        model = MyModel()

    trainer.fit(model)

This eliminates the waiting time to transfer the model parameters from the CPU to the device.
For strategies that handle large sharded models (FSDP, DeepSpeed), the :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` method will allocate the model parameters on the meta device first before sharding.
This makes it possible to work with models that are larger than the memory of a single device.
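
For illustration, a minimal sketch of the sharded case (``MyHugeModel`` and the device count are placeholders, not part of this change); the usage is the same as above:

.. code-block:: python

    trainer = Trainer(accelerator="cuda", devices=2, strategy="fsdp")

    with trainer.init_module():
        # parameters get allocated on the meta device here and are materialized during sharding
        model = MyHugeModel()

    trainer.fit(model)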

When loading a model from a checkpoint, for example when fine-tuning, set ``empty_init=True`` to avoid expensive
and redundant memory initialization:

.. code-block:: python

    with trainer.init_module(empty_init=True):
        # creation of the model is very fast
        model = MyModel.load_from_checkpoint("my/checkpoint/path.ckpt")

    trainer.fit(model)

.. _fully-sharded-training:

**********************
@@ -79,7 +112,6 @@ Lightning supports.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.


Auto Wrapping
=============

@@ -92,13 +124,30 @@ have to ``wrap`` layers manually as in the case of manual wrapping.
PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your
``lightning_module.parameters()`` will return a generator with no params.
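
A hedged sketch of the usual workaround for this (``BoringModel``, the optimizer choice, and the attribute path are illustrative assumptions, not shown in this diff): go through the model that the strategy has set up rather than ``self``.

.. code-block:: python

    import torch
    from lightning.pytorch.demos.boring_classes import BoringModel


    class MyAutoWrappedModel(BoringModel):
        def configure_optimizers(self):
            # with auto-wrap, the sharded parameters live on the wrapped model, not on `self`
            return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-3)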


.. code-block:: python

    model = BoringModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
    trainer.fit(model)

You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and pass that to the ``strategy`` argument inside the ``Trainer``.

.. code-block:: python

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import FSDPStrategy

    # equivalent to passing `"fsdp_cpu_offload"`
    fsdp = FSDPStrategy(cpu_offload=True)
    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)

    # configure the wrapping condition
    import functools

    from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

    my_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda module: isinstance(module, torch.nn.Linear))
    fsdp = FSDPStrategy(auto_wrap_policy=my_policy)
    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)

Read more `here <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/#auto-wrapping>`__.

@@ -107,9 +156,9 @@ Manual Wrapping
===============

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` hook to make sure the ``wrap`` parameters are passed correctly.

When not using Fully Sharded these wrap functions are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.
When not using Fully Sharded, these ``wrap`` calls are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.

``wrap`` simply wraps the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.

@@ -152,19 +201,10 @@ Here's an example that uses ``wrap`` to create your model:
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
    trainer.fit(model)

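
The diff view hides the body of that example; as a rough sketch (layer sizes and class name are illustrative, not taken from the commit), a manually wrapped model could look like this:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp.wrap import wrap

    from lightning.pytorch.demos.boring_classes import BoringModel


    class ManuallyWrappedModel(BoringModel):
        def configure_sharded_model(self):
            # Lightning runs this hook inside the FSDP wrapping context,
            # so each `wrap` call shards only the layer it is applied to
            self.layer = nn.Sequential(
                wrap(nn.Linear(32, 32)),
                nn.ReLU(),
                wrap(nn.Linear(32, 2)),
            )

        def configure_optimizers(self):
            return torch.optim.AdamW(self.layer.parameters(), lr=1e-3)
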
In this case, Lightning will not re-wrap your model, so you don't need to set ``FSDPStrategy(auto_wrap_policy=...)``.

You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and pass that to the ``strategy`` argument inside the ``Trainer``.

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import FSDPStrategy

    fsdp = FSDPStrategy(cpu_offload=True)
    # equivalent to passing `"fsdp_cpu_offload"`
    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)

You could achieve the same goal by using the :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context
manager to initialize your model, instead of overriding :meth:`~lightning.pytorch.core.module.LightningModule.configure_sharded_model`.

Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about it.

@@ -184,12 +224,37 @@ Enable checkpointing on large layers (like Transformers) by providing the layer
    from lightning.pytorch.strategies import FSDPStrategy

    fsdp = FSDPStrategy(
        activation_checkpointing=MyTransformerBlock,  # or pass a list with multiple types
    )
    fsdp = FSDPStrategy(activation_checkpointing=MyTransformerBlock)  # or pass a list with multiple types
    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)

You could also configure activation checkpointing manually inside the ``configure_sharded_model`` hook:

.. code-block:: python

    import torch
    import lightning.pytorch as pl
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )


    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
            # Same code as in the "Manual wrapping" snippet above
            ...
            checkpoint_policy = lambda submodule: isinstance(submodule, torch.nn.Linear)
            apply_activation_checkpointing(
                self.model,
                checkpoint_wrapper_fn=checkpoint_wrapper,
                check_fn=checkpoint_policy,
            )

In this case, Lightning will not re-configure activation checkpointing, so you don't need to set ``FSDPStrategy(activation_checkpointing=...)``.

----


2 changes: 1 addition & 1 deletion docs/source-pytorch/common/trainer.rst
@@ -118,7 +118,7 @@ So you can run it like so:
If you want to stop a training run early, you can press "Ctrl + C" on your keyboard.
The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown. The trainer object will also set
an attribute ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute
resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs by overriding :meth:`lightning.pytorch.Callback.on_exception`.
resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs by overriding :meth:`~lightning.pytorch.Callback.on_exception`.
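
A rough sketch of such a callback (the class and the teardown call are illustrative, not part of Lightning):

.. code-block:: python

    import lightning.pytorch as pl


    class ComputeShutdown(pl.Callback):
        def on_exception(self, trainer, pl_module, exception):
            if isinstance(exception, KeyboardInterrupt):
                # keep the resources alive for interrupted runs
                return
            release_compute_resources()  # placeholder for your own shutdown logic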

------------

8 changes: 4 additions & 4 deletions examples/fabric/build_your_own_trainer/trainer.py
@@ -206,7 +206,7 @@ def train_loop(
limit_batches: Limits the batches during this training epoch.
If greater than the number of batches in the ``train_loader``, this has no effect.
scheduler_cfg: The learning rate scheduler configuration.
Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values.
Have a look at :meth:`~lightning.pytorch.LightningModule.configure_optimizers` for supported values.
"""
self.fabric.call("on_train_epoch_start")
iterable = self.progbar_wrapper(
@@ -344,7 +344,7 @@ def step_scheduler(
Args:
model: The LightningModule to train
scheduler_cfg: The learning rate scheduler configuration.
Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values.
Have a look at :meth:`~lightning.pytorch.LightningModule.configure_optimizers` for supported values.
level: whether we are trying to step on epoch- or step-level
current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``
"""
@@ -453,11 +453,11 @@ def _parse_optimizers_schedulers(
Optional[L.fabric.utilities.types.Optimizable],
Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
]:
"""Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`.
"""Recursively parses the output of :meth:`~lightning.pytorch.LightningModule.configure_optimizers`.
Args:
configure_optim_output: The output of ``configure_optimizers``.
For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`.
For supported values, please refer to :meth:`~lightning.pytorch.LightningModule.configure_optimizers`.
"""
_lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"}

48 changes: 32 additions & 16 deletions src/lightning/fabric/strategies/fsdp.py
@@ -222,24 +222,27 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

if "auto_wrap_policy" in self._fsdp_kwargs and any(
isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
):
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
del self._fsdp_kwargs["auto_wrap_policy"]
wrapped_module = FullyShardedDataParallel(
module=module,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self._fsdp_kwargs,
)
if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
# the user wrapped at least one layer in `configure_sharded_model` already
if "auto_wrap_policy" in self._fsdp_kwargs:
rank_zero_warn(
"A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
)
del self._fsdp_kwargs["auto_wrap_policy"]
else:
module = FullyShardedDataParallel(
module=module,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self._fsdp_kwargs,
)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
if _TORCH_GREATER_EQUAL_1_13:
_setup_activation_checkpointing(module=module, layers=self._activation_checkpointing)

return wrapped_module
return module

def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Set up an optimizer for a model wrapped with FSDP.
@@ -594,13 +597,26 @@ def _set_world_ranks(self) -> None:
rank_zero_only.rank = self.global_rank


def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
def _setup_activation_checkpointing(module: Module, layers: List[Type[Module]]) -> None:
if not layers:
return

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
CheckpointWrapper,
)

if any(isinstance(mod, CheckpointWrapper) for mod in module.modules()):
if layers:
rank_zero_warn(
f"FSDP checkpointing for the layers {layers} is configured, but the model already contains checkpointed"
" layers. Checkpointing will be ignored."
)
# the module is already wrapped with activation checkpointing, avoid wrapping twice
return

check_fn = lambda submodule: isinstance(submodule, tuple(layers))
wrapper = functools.partial(
checkpoint_wrapper,
2 changes: 1 addition & 1 deletion src/lightning/fabric/wrappers.py
@@ -94,7 +94,7 @@ def __init__(
forward_module: The module to wrap the ``forward`` method on.
precision: Reference to the precision plugin for handling precision context
original_module: The original, unmodified module as passed into the
:meth:`lightning.fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
:meth:`~lightning.fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
on this wrapper should pass through to the original module.
"""
super().__init__()
6 changes: 3 additions & 3 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -46,9 +46,9 @@
class ModelCheckpoint(Checkpoint):
r"""
Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~lightning.pytorch.core.module.log` or :meth:`~lightning.pytorch.core.module.log_dict` in
LightningModule is a candidate for the monitor key. For more information, see
:ref:`checkpointing`.
:meth:`~lightning.pytorch.core.module.LightningModule.log` or
:meth:`~lightning.pytorch.core.module.LightningModule.log_dict` is a candidate for the monitor key.
For more information, see :ref:`checkpointing`.
After training finishes, use :attr:`best_model_path` to retrieve the path to the
best checkpoint file and :attr:`best_model_score` to retrieve its score.
11 changes: 8 additions & 3 deletions src/lightning/pytorch/core/hooks.py
@@ -271,12 +271,17 @@ def on_before_optimizer_step(self, optimizer):
"""

def configure_sharded_model(self) -> None:
"""Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
where we'd like to shard the model instantly, which is useful for extremely large models which can save
memory and initialization time.
"""Hook to create modules in a strategy and precision aware context. This is particularly useful for when
using sharded plugins, where we'd like to shard the model instantly to save memory and initialization time.
This is recommended also for non-sharded models.
This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
implementation of this hook is idempotent.
.. note ::
This uses the :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager. So you don't
need to use this hook if you already instantiated your module under it.
"""


2 changes: 1 addition & 1 deletion src/lightning/pytorch/loops/optimization/automatic.py
@@ -93,7 +93,7 @@ class Closure(AbstractClosure[ClosureResult]):
do something with the output.
Args:
step_fn: This is typically the :meth:`lightning.pytorch.core.module.LightningModule.training_step
step_fn: This is typically the :meth:`~lightning.pytorch.core.module.LightningModule.training_step
wrapped with processing for its outputs
backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
Can be set to ``None`` to skip the backward operation.
19 changes: 17 additions & 2 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import json
import logging
import os
import platform
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

@@ -31,6 +31,8 @@
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE, _validate_device_index_selection
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
@@ -498,7 +500,20 @@ def _initialize_deepspeed_train(self, model: Module) -> None:
self.lr_scheduler_configs = [lr_scheduler]
self.model = model

@contextlib.contextmanager
@contextmanager
def module_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
if self.zero_stage_3 and empty_init is False:
raise NotImplementedError(
f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
)
empty_init = empty_init and not self.zero_stage_3
empty_init_context = (
_EmptyInit(enabled=empty_init) if _TORCH_GREATER_EQUAL_1_13 and not self.zero_stage_3 else nullcontext()
)
with empty_init_context, self.tensor_init_context(), self.model_sharded_context():
yield

@contextmanager
def model_sharded_context(self) -> Generator[None, None, None]:
import deepspeed

