Add trainer.init_module and trainer.init_tensor
carmocca committed Jul 6, 2023
1 parent f4240ca commit f0939ea
Showing 12 changed files with 239 additions and 111 deletions.
103 changes: 84 additions & 19 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -67,6 +67,39 @@ For training on unreliable mixed GPUs across the internet check out the :doc:`Hi
----


************************
Efficient initialization
************************

Instantiating a ``nn.Module`` in PyTorch creates all parameters on CPU in float32 precision by default.
To speed up initialization, you can force PyTorch to create the model directly on the target device and with the desired precision without changing your model code.

.. code-block:: python

    trainer = Trainer(accelerator="cuda", precision="16-true")

    with trainer.init_module():
        # models created here will be on GPU and in float16
        model = MyModel()

    trainer.fit(model)

This eliminates the waiting time to transfer the model parameters from the CPU to the device.
For strategies that handle large sharded models (FSDP, DeepSpeed), the :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` method will allocate the model parameters on the meta device first before sharding.
This makes it possible to work with models that are larger than the memory of a single device.
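
A minimal sketch of the same pattern with a sharded strategy (``MyModel`` and the device count are placeholders):

.. code-block:: python

    trainer = Trainer(accelerator="cuda", devices=2, strategy="fsdp", precision="16-mixed")

    with trainer.init_module():
        # with FSDP, parameters land on the meta device here and get materialized during sharding
        model = MyModel()

    trainer.fit(model)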

When loading a model from a checkpoint, for example when fine-tuning, set ``empty_init=True`` to avoid expensive
and redundant memory initialization:

.. code-block:: python

    with trainer.init_module(empty_init=True):
        # creation of the model is very fast
        model = MyModel.load_from_checkpoint("my/checkpoint/path.ckpt")

    trainer.fit(model)

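This commit also adds ``trainer.init_tensor`` (per the commit title above), which the visible docs don't cover. The snippet below is a hedged sketch that assumes it mirrors Fabric's ``init_tensor``, i.e. tensors created under the context manager are placed on the trainer's device with the configured precision:

.. code-block:: python

    import torch

    trainer = Trainer(accelerator="cuda", precision="16-true")

    with trainer.init_tensor():
        # assumption: tensors created here end up on the GPU in float16
        mask = torch.ones(10, 10)
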
.. _fully-sharded-training:

**********************
@@ -79,7 +112,6 @@ Lightning supports.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.


Auto Wrapping
=============

@@ -92,13 +124,30 @@ have to ``wrap`` layers manually as in the case of manual wrapping.
PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your
``lightning_module.parameters()`` will return a generator with no params.


.. code-block:: python

    model = BoringModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
    trainer.fit(model)

You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and passing it to the ``strategy`` argument of the ``Trainer``.

.. code-block:: python

    import functools

    import torch
    from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import FSDPStrategy

    # equivalent to passing `"fsdp_cpu_offload"`
    fsdp = FSDPStrategy(cpu_offload=True)
    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)

    # configure the wrapping condition
    my_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda module: isinstance(module, torch.nn.Linear))
    fsdp = FSDPStrategy(auto_wrap_policy=my_policy)
    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)

Read more `here <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/#auto-wrapping>`__.

@@ -107,9 +156,9 @@ Manual Wrapping
===============

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` hook to make sure the ``wrap`` parameters are passed correctly.

When not using Fully Sharded these wrap functions are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.
When not using Fully Sharded, these ``wrap`` calls are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.

``wrap`` simply wraps the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.
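
The doc's full example is collapsed in this diff; below is a minimal hedged sketch of the pattern (layer names and sizes are illustrative):

.. code-block:: python

    import torch
    import lightning.pytorch as pl
    from torch.distributed.fsdp.wrap import wrap


    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.linear_layer = torch.nn.Linear(32, 32)

        def configure_sharded_model(self):
            # inside this hook, `wrap` shards the submodule; without FSDP it is a no-op
            self.linear_layer = wrap(self.linear_layer)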

@@ -152,19 +201,10 @@ Here's an example that uses ``wrap`` to create your model:
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
    trainer.fit(model)

In this case, Lightning will not re-wrap your model, so you don't need to set ``FSDPStrategy(auto_wrap_policy=...)``.

You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and pass that to the ``strategy`` argument inside the ``Trainer``.

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import FSDPStrategy

    fsdp = FSDPStrategy(cpu_offload=True)
    # equivalent to passing `"fsdp_cpu_offload"`
    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)

You could achieve the same goal by using the :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context
manager to initialize your model, instead of overriding :meth:`~lightning.pytorch.core.module.LightningModule.configure_sharded_model`.
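
A minimal sketch of that alternative (``MyModel`` is a placeholder):

.. code-block:: python

    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision="16-mixed")

    with trainer.init_module():
        # no `configure_sharded_model` override needed; the strategy wraps and shards the model in `fit`
        model = MyModel()

    trainer.fit(model)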

Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about it.

@@ -184,12 +224,37 @@ Enable checkpointing on large layers (like Transformers) by providing the layer
from lightning.pytorch.strategies import FSDPStrategy
fsdp = FSDPStrategy(
activation_checkpointing=MyTransformerBlock, # or pass a list with multiple types
)
fsdp = FSDPStrategy(activation_checkpointing=MyTransformerBlock) # or pass a list with multiple types
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
You could also configure activation checkpointing manually inside the ``configure_sharded_model`` hook:

.. code-block:: python

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )


    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
            # Same code as in the "Manual wrapping" snippet above
            ...
            checkpoint_policy = lambda submodule: isinstance(submodule, torch.nn.Linear)
            apply_activation_checkpointing(
                self.model,
                checkpoint_wrapper_fn=checkpoint_wrapper,
                check_fn=checkpoint_policy,
            )

In this case, Lightning will not re-configure activation checkpointing, so you don't need to set ``FSDPStrategy(activation_checkpointing=...)``.

----


2 changes: 1 addition & 1 deletion examples/fabric/build_your_own_trainer/trainer.py
@@ -344,7 +344,7 @@ def step_scheduler(
Args:
model: The LightningModule to train
scheduler_cfg: The learning rate scheduler configuration.
Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values.
Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
level: whether we are trying to step on epoch- or step-level
current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``
"""
43 changes: 28 additions & 15 deletions src/lightning/fabric/strategies/fsdp.py
@@ -222,24 +222,27 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

if "auto_wrap_policy" in self._fsdp_kwargs and any(
isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
):
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
del self._fsdp_kwargs["auto_wrap_policy"]
wrapped_module = FullyShardedDataParallel(
module=module,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self._fsdp_kwargs,
)
if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
# the user wrapped at least one layer in `configure_sharded_model` already
if "auto_wrap_policy" in self._fsdp_kwargs:
rank_zero_warn(
"A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
)
del self._fsdp_kwargs["auto_wrap_policy"]
else:
module = FullyShardedDataParallel(
module=module,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self._fsdp_kwargs,
)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
_setup_activation_checkpointing(module=module, layers=self._activation_checkpointing)

return wrapped_module
return module

def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Set up an optimizer for a model wrapped with FSDP.
@@ -594,13 +597,23 @@ def _set_world_ranks(self) -> None:
rank_zero_only.rank = self.global_rank


def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
def _setup_activation_checkpointing(module: Module, layers: List[Type[Module]]) -> None:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
CheckpointWrapper,
)

if any(isinstance(mod, CheckpointWrapper) for mod in module.modules()):
if layers:
rank_zero_warn(
f"FSDP checkpointing for the layers {layers} is configured, but the model already contains checkpointed"
" layers. Checkpointing will be ignored."
)
# the module is already wrapped with activation checkpointing, avoid wrapping twice
return

check_fn = lambda submodule: isinstance(submodule, tuple(layers))
wrapper = functools.partial(
checkpoint_wrapper,
6 changes: 3 additions & 3 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -46,9 +46,9 @@
class ModelCheckpoint(Checkpoint):
r"""
Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~lightning.pytorch.core.module.log` or :meth:`~lightning.pytorch.core.module.log_dict` in
LightningModule is a candidate for the monitor key. For more information, see
:ref:`checkpointing`.
:meth:`~lightning.pytorch.core.module.LightningModule.log` or
:meth:`~lightning.pytorch.core.module.LightningModule.log_dict` is a candidate for the monitor key.
For more information, see :ref:`checkpointing`.
After training finishes, use :attr:`best_model_path` to retrieve the path to the
best checkpoint file and :attr:`best_model_score` to retrieve its score.
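
A minimal usage sketch (the monitored metric name is an example, and ``model`` stands in for your LightningModule instance):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# "val_loss" must be logged via `self.log("val_loss", ...)` in the LightningModule
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3)
trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)
print(checkpoint_callback.best_model_path, checkpoint_callback.best_model_score)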
11 changes: 8 additions & 3 deletions src/lightning/pytorch/core/hooks.py
@@ -271,12 +271,17 @@ def on_before_optimizer_step(self, optimizer):
"""

def configure_sharded_model(self) -> None:
"""Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
where we'd like to shard the model instantly, which is useful for extremely large models which can save
memory and initialization time.
"""Hook to create modules in a strategy and precision aware context. This is particularly useful for when
using sharded plugins, where we'd like to shard the model instantly to save memory and initialization time.
This is recommended also for non-sharded models.
This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
implementation of this hook is idempotent.
.. note ::
This uses the :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager. So you don't
need to use this hook if you already instantiated your module under it.
"""


7 changes: 3 additions & 4 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -45,10 +45,9 @@ def __init__(

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
# for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
# however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
# trace back the root FSDP. Now we only support clip by value.
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
# To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
# to the root module
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)
19 changes: 17 additions & 2 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import json
import logging
import os
import platform
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

@@ -31,6 +31,8 @@
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE, _validate_device_index_selection
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
@@ -498,7 +500,20 @@ def _initialize_deepspeed_train(self, model: Module) -> None:
self.lr_scheduler_configs = [lr_scheduler]
self.model = model

@contextlib.contextmanager
@contextmanager
def module_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
if self.zero_stage_3 and empty_init is False:
raise NotImplementedError(
f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
)
empty_init = empty_init and not self.zero_stage_3
empty_init_context = (
_EmptyInit(enabled=empty_init) if _TORCH_GREATER_EQUAL_1_13 and not self.zero_stage_3 else nullcontext()
)
with empty_init_context, self.tensor_init_context(), self.model_sharded_context():
yield

@contextmanager
def model_sharded_context(self) -> Generator[None, None, None]:
import deepspeed

49 changes: 21 additions & 28 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -50,8 +50,7 @@
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn

_distributed_available = torch.distributed.is_available()
_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
@@ -226,29 +225,29 @@ def _configure_launcher(self) -> None:
def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
"""Wraps the model into a
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
assert self.lightning_module is not None
if "auto_wrap_policy" in self.kwargs and any(
isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()
):
del self.kwargs["auto_wrap_policy"]

log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

wrapped_module = FullyShardedDataParallel(
module=model,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self.kwargs,
)
if any(isinstance(mod, FullyShardedDataParallel) for mod in model.modules()):
# the user wrapped at least one layer in `configure_sharded_model` already
if "auto_wrap_policy" in self.kwargs:
rank_zero_warn(
"A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
)
del self.kwargs["auto_wrap_policy"]
else:
log.debug(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")
model = FullyShardedDataParallel(
module=model,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self.kwargs,
)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
_setup_activation_checkpointing(module=model, layers=self._activation_checkpointing)

return wrapped_module
return model

def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator is not None
@@ -262,13 +261,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
assert self.lightning_module is not None
self.lightning_module._device = self.root_device

if is_overridden("configure_sharded_model", self.lightning_module):
rank_zero_info(
"You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
" are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
)
else:
self.model = self._setup_model(self.model)
self.model = self._setup_model(self.model)
self.barrier()

self.setup_optimizers(trainer)
