Add Trainer.init_module and LightningModule.configure_model #18004

Merged (14 commits) on Jul 14, 2023
64 changes: 55 additions & 9 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -67,6 +67,54 @@ For training on unreliable mixed GPUs across the internet check out the :doc:`Hi
----


************************
Efficient initialization
************************

Instantiating an ``nn.Module`` in PyTorch creates all parameters on the CPU in float32 precision by default.
To speed up initialization, you can force PyTorch to create the model directly on the target device and with the desired precision, without changing your model code.

.. code-block:: python

    trainer = Trainer(accelerator="cuda", precision="16-true")

    with trainer.init_module():
        # models created here will be on GPU and in float16
        model = MyModel()

    trainer.fit(model)

This eliminates the waiting time to transfer the model parameters from the CPU to the device.
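
For comparison, here is a rough sketch, for illustration only, of the conventional path that ``init_module`` avoids: parameters are first materialized on the CPU in float32 and only then converted and copied to the GPU.

.. code-block:: python

    import torch

    # conventional initialization: allocate all parameters on the CPU in float32,
    # then convert them to float16 and copy every parameter tensor to the GPU
    model = MyModel().to(device="cuda", dtype=torch.float16)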

When loading a model from a checkpoint, for example when fine-tuning, set ``empty_init=True`` to avoid expensive
and redundant memory initialization:

.. code-block:: python

    with trainer.init_module(empty_init=True):
        # creation of the model is very fast
        model = MyModel.load_from_checkpoint("my/checkpoint/path.ckpt")

    trainer.fit(model)

For strategies that handle large sharded models (FSDP, DeepSpeed), the :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module`
context manager should not be used. Instead, override the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook:

.. code-block:: python

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            # don't instantiate layers here
            # move the creation of layers to `configure_model`

        def configure_model(self):
            # create all your layers here
            self.layers = nn.Sequential(...)

This makes it possible to work with models that are larger than the memory of a single device.
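
As a minimal usage sketch (the accelerator, device count, strategy, and precision below are placeholder assumptions), the hook pairs with a sharded strategy like this, and the layers get created only when the strategy sets up the model:

.. code-block:: python

    # hypothetical configuration for illustration; adjust to your hardware
    trainer = Trainer(accelerator="cuda", devices=4, strategy="fsdp", precision="16-mixed")
    trainer.fit(MyModel())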


.. _fully-sharded-training:

**********************
@@ -79,7 +127,6 @@ Lightning supports.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.


Auto Wrapping
=============

@@ -92,7 +139,6 @@ have to ``wrap`` layers manually as in the case of manual wrapping.
PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your
``lightning_module.parameters()`` will return a generator with no params.


.. code-block:: python

    model = BoringModel()
@@ -107,9 +153,9 @@ Manual Wrapping
===============

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_model`` hook to make sure the ``wrap`` parameters are passed correctly.

When not using Fully Sharded these wrap functions are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.
When not using Fully Sharded, these ``wrap`` calls are a no-op. This means once the changes have been made, there is no need to remove the changes for other strategies.

``wrap`` simply wraps the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.

@@ -130,7 +176,7 @@ Here's an example that uses ``wrap`` to create your model:
            self.linear_layer = nn.Linear(32, 32)
            self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

        def configure_sharded_model(self):
        def configure_model(self):
            # modules are sharded across processes
            # as soon as they are wrapped with `wrap`.
            # During the forward/backward passes, weights get synced across processes
@@ -188,7 +234,7 @@ Enable checkpointing on large layers (like Transformers) by providing the layer
    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)


You could also configure activation checkpointing manually inside the ``configure_sharded_model`` hook:
You could also configure activation checkpointing manually inside the ``configure_model`` hook:

.. code-block:: python

@@ -198,7 +244,7 @@ You could also configure activation checkpointing manually inside the ``configur
    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
        def configure_model(self):
            # Same code as in the "Manual wrapping" snippet above
            ...
            apply_activation_checkpointing(self.model)
@@ -440,7 +486,7 @@ This reduces the time taken to initialize very large models, as well as ensure w
    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
        def configure_model(self):
            # Created within sharded model context, modules are instantly sharded across processes
            # as soon as they are made.
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
@@ -576,7 +622,7 @@ This saves memory when training larger models, however requires using a checkpoi
    class MyModel(pl.LightningModule):
        ...

        def configure_sharded_model(self):
        def configure_model(self):
            self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
            self.block_2 = torch.nn.Linear(32, 2)

4 changes: 2 additions & 2 deletions docs/source-pytorch/common/lightning_module.rst
@@ -1327,10 +1327,10 @@ on_validation_epoch_end
.. automethod:: lightning.pytorch.core.module.LightningModule.on_validation_epoch_end
    :noindex:

configure_sharded_model
configure_model
~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: lightning.pytorch.core.module.LightningModule.configure_sharded_model
.. automethod:: lightning.pytorch.core.module.LightningModule.configure_model
    :noindex:

on_validation_model_eval
8 changes: 4 additions & 4 deletions docs/source-pytorch/integrations/strategies/colossalai.rst
@@ -56,7 +56,7 @@ See a full example of a benchmark with a `GPT-2 model <https://github.com/hp
HybridAdam`` now. You can set ``adamw_mode`` to ``False`` to use normal Adam. Note that ``HybridAdam`` is highly optimized: it uses fused CUDA kernels and parallel CPU kernels.
It is recommended to use ``HybridAdam``, since it updates parameters on both GPU and CPU.

* Your model must be created using the :meth:`~lightning.pytorch.core.module.LightningModule.configure_sharded_model` method.
* Your model must be created using the :meth:`~lightning.pytorch.core.module.LightningModule.configure_model` method.

* ``ColossalaiStrategy`` doesn't support gradient accumulation as of now.

@@ -65,7 +65,7 @@ See a full example of a benchmark with a `GPT-2 model <https://github.com/hp
Model Definition
================

ColossalAI requires the layers of your model to be created in the special :meth:`~lightning.pytorch.core.module.LightningModule.configure_sharded_model` hook.
ColossalAI requires the layers of your model to be created in the special :meth:`~lightning.pytorch.core.module.LightningModule.configure_model` hook.
This allows the strategy to efficiently shard your model before materializing the weight tensors.

.. code-block:: python
@@ -74,9 +74,9 @@ This allows the strategy to efficiently shard your model before materializing th
        def __init__(self):
            super().__init__()
            # don't instantiate layers here
            # move the creation of layers to `configure_sharded_model`
            # move the creation of layers to `configure_model`

        def configure_sharded_model(self):
        def configure_model(self):
            # create all your layers here
            self.layers = nn.Sequential(...)

5 changes: 4 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -41,12 +41,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Handles initialization for FSDP models before wrapping and the ZeRO stage 3 initialization for DeepSpeed before sharding


- Added support for empty weight initialization with `Fabric.init_module(empty_init=True)` for efficient sharding and checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
- Added support for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))


- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))


- Added `lightning.fabric.strategies.Strategy.tensor_init_context()` context manager to instantiate tensors efficiently, directly on the target device and dtype ([#17607](https://github.com/Lightning-AI/lightning/pull/17607))


- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))


14 changes: 14 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -62,6 +62,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Trainer.print()` to print on local rank zero only ([#17980](https://github.com/Lightning-AI/lightning/pull/17980))


- Added `Trainer.init_module()` context manager to instantiate large models efficiently, directly on the target device and dtype ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))
* Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`) depending on the 'true' precision choice in `Trainer(precision='32-true'|'64-true')`


- Added the `LightningModule.configure_model()` hook to instantiate large models efficiently, directly on the target device and dtype, and with sharding support ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))
* Handles initialization for FSDP models before wrapping and the ZeRO stage 3 initialization for DeepSpeed before sharding


- Added `lightning.pytorch.plugins.PrecisionPlugin.init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))


- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))


@@ -128,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the `Strategy.post_training_step` method ([#17531](https://github.com/Lightning-AI/lightning/pull/17531))


- Deprecated the `LightningModule.configure_sharded_model` hook in favor of `LightningModule.configure_model` ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))


### Removed

- Removed the `XLAStrategy.is_distributed` property. It is always True ([#17381](https://github.com/Lightning-AI/lightning/pull/17381))
15 changes: 12 additions & 3 deletions src/lightning/pytorch/core/hooks.py
@@ -271,9 +271,18 @@ def on_before_optimizer_step(self, optimizer):
"""

def configure_sharded_model(self) -> None:
"""Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
where we'd like to shard the model instantly, which is useful for extremely large models which can save
memory and initialization time.
"""Deprecated.

Use :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` instead.
"""

def configure_model(self) -> None:
"""Hook to create modules in a strategy and precision aware context.

This is particularly useful when using sharded strategies (FSDP and DeepSpeed), where we'd like to shard
the model instantly to save memory and initialization time.
For non-sharded strategies, you can choose to override this hook or to initialize your model under the
:meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.

This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
implementation of this hook is idempotent.
11 changes: 11 additions & 0 deletions src/lightning/pytorch/plugins/precision/double.py
@@ -87,6 +87,17 @@ def connect(

return super().connect(model, optimizers, lr_schedulers)

@contextmanager
def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.

See: :meth:`torch.set_default_dtype`
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
yield
torch.set_default_dtype(default_dtype)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.
18 changes: 14 additions & 4 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -44,10 +44,9 @@

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
# for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
# however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
# trace back the root FSDP. Now we only support clip by value.
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
# To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
# to the root module
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)
@@ -75,6 +74,17 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
buffer_dtype=buffer_dtype,
)

@contextmanager
def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.

See: :meth:`torch.set_default_dtype`
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.mixed_precision_config.param_dtype)
yield
torch.set_default_dtype(default_dtype)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""For FSDP, this context manager is a no-op since conversion is already handled internally.
16 changes: 14 additions & 2 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import json
import logging
import os
import platform
from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

@@ -498,7 +498,19 @@ def _initialize_deepspeed_train(self, model: Module) -> None:
self.lr_scheduler_configs = [lr_scheduler]
self.model = model

@contextlib.contextmanager
@contextmanager
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
if self.zero_stage_3:
if empty_init is False:
raise NotImplementedError(
f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
)
yield
return
with super().tensor_init_context(empty_init=empty_init):
yield

@contextmanager
def model_sharded_context(self) -> Generator[None, None, None]:
import deepspeed

16 changes: 13 additions & 3 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from typing import Any, Dict, Generator, List, Mapping, Optional, Type, Union

@@ -42,6 +42,7 @@
_TORCH_GREATER_EQUAL_1_13,
_TORCH_GREATER_EQUAL_2_0,
)
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ProcessGroup, ReduceOp
@@ -268,6 +269,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.lightning_module._device = self.root_device

if is_overridden("configure_sharded_model", self.lightning_module):
# legacy: we don't skip setup with the `configure_model` alternative
rank_zero_info(
"You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
" are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
@@ -305,8 +307,16 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
def model_to_device(self) -> None:
pass

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
@contextmanager
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
# TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
# is resolved. For now, the module will get moved to the device in `setup_module`.
empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
with empty_init_context, self.precision_plugin.init_context():
yield

@contextmanager
def model_sharded_context(self) -> Generator[None, None, None]:
log.debug(f"{self.__class__.__name__}: entered model_sharded_context.")
with enable_wrap(
wrapper_cls=FullyShardedDataParallel,