Simplify enabling CPU offload in FSDP (#15832)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
awaelchli and Borda committed Dec 7, 2022
1 parent 852089e commit 2debd1c
Showing 6 changed files with 58 additions and 25 deletions.
3 changes: 1 addition & 2 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -424,10 +424,9 @@ You can customize the strategy configuration by adjusting the arguments of :clas
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
- from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
- native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
+ native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True)
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)
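To summarize the documentation change above: enabling CPU offload no longer requires constructing the ``CPUOffload`` dataclass, although that form is still accepted. A minimal sketch mirroring the docs snippet (the 4-GPU Trainer setup and variable names are illustrative):

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

# Before this commit: offloading required constructing the CPUOffload dataclass.
fsdp_old_style = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))

# After this commit: a plain bool is enough; the config form remains supported.
fsdp_new_style = DDPFullyShardedNativeStrategy(cpu_offload=True)

trainer = pl.Trainer(strategy=fsdp_new_style, accelerator="gpu", devices=4)
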
20 changes: 12 additions & 8 deletions src/lightning_lite/strategies/fsdp.py
@@ -69,11 +69,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
`this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.
Arguments:
- cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It
- can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
+ cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
+ You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
- to work with the optimizer. This API is subject to change. Default is ``None`` in which case there
- will be no offloading.
+ to work with the optimizer. This API is subject to change. Default: no offloading.
backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
users to enable two different backward prefetching algorithms to help backward communication and
computation overlapping. The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
@@ -96,7 +95,7 @@ def __init__(
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
- cpu_offload: Optional["CPUOffload"] = None,
+ cpu_offload: Union[bool, "CPUOffload", None] = None,
backward_prefetch: Optional["BackwardPrefetch"] = None,
mixed_precision: Optional["MixedPrecision"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
@@ -125,7 +124,7 @@ def __init__(
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
)

- self.cpu_offload = cpu_offload
+ self.cpu_offload = _init_cpu_offload(cpu_offload)
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision

@@ -276,7 +275,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
def register_strategies(cls, strategy_registry: Dict) -> None:
if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():
return
- from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

strategy_registry.register(
"fsdp",
@@ -287,7 +285,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
"fsdp_full_shard_offload",
cls,
description="Native FSDP with Full Sharding and CPU Offloading",
- cpu_offload=CPUOffload(offload_params=True),
+ cpu_offload=True,
)

def _setup_distributed(self) -> None:
@@ -341,6 +339,12 @@ def no_backward_sync(self, module: Module) -> Generator:
yield


+ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload":
+     from torch.distributed.fsdp import CPUOffload
+
+     return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload))


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
from torch.distributed.fsdp import FlatParameter

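For reference, the new `_init_cpu_offload` helper added above normalizes every accepted value into a ``CPUOffload`` config. A minimal sketch of the expected mapping, shown only to illustrate the behavior (the helper is internal; assumes PyTorch >= 1.12 as the strategy already requires):

from torch.distributed.fsdp import CPUOffload

from lightning_lite.strategies.fsdp import _init_cpu_offload

# bool or None inputs are converted into an equivalent CPUOffload dataclass
assert _init_cpu_offload(True) == CPUOffload(offload_params=True)
assert _init_cpu_offload(False) == CPUOffload(offload_params=False)
assert _init_cpu_offload(None) == CPUOffload(offload_params=False)

# an existing CPUOffload config is passed through unchanged
config = CPUOffload(offload_params=True)
assert _init_cpu_offload(config) is config
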
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -33,6 +33,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))



- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configuration object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))


### Changed

- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
24 changes: 12 additions & 12 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -21,7 +21,11 @@

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
- from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
+ from lightning_lite.strategies.fsdp import (
+     _init_cpu_offload,
+     _optimizer_has_flat_params,
+     _setup_activation_checkpointing,
+ )
from lightning_lite.utilities.distributed import (
_get_default_process_group_backend_for_device,
_init_dist_connection,
@@ -84,14 +88,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
`this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.
Arguments:
- cpu_offload:
- CPU offloading config. Currently, only parameter and gradient CPU
- offload is supported. It can be enabled via passing in
- ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
- currently implicitly enables gradient offloading to CPU in order for
- params and grads to be on same device to work with optimizer. This
- API is subject to change. Default is ``None`` in which case there
- will be no offloading.
+ cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
+ You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
+ implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
+ to work with the optimizer. This API is subject to change. Default: no offloading.
backward_prefetch:
This is an experimental feature that is subject to change in the
near future. It allows users to enable two different backward_prefetch
@@ -120,7 +120,7 @@ def __init__(
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
process_group_backend: Optional[str] = None,
- cpu_offload: Optional[CPUOffload] = None,
+ cpu_offload: Union[bool, "CPUOffload", None] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
@@ -141,7 +141,7 @@ def __init__(
self._process_group = None
self.num_nodes = 1
self._process_group_backend = process_group_backend
- self.cpu_offload = cpu_offload
+ self.cpu_offload = _init_cpu_offload(cpu_offload)
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision
self._rank_0_will_call_children_scripts: bool = False
@@ -403,6 +403,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
"fsdp_native_full_shard_offload",
cls,
description="Native FSDP with Full Sharding and CPU Offloading",
- cpu_offload=CPUOffload(offload_params=True),
+ cpu_offload=True,
)
cls._registered_strategies.append("fsdp_native_full_shard_offload")
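With the registry entry above, CPU offloading can be enabled either by constructing the strategy with the new bool flag or by selecting the registered shorthand. A minimal sketch (the accelerator/devices arguments are illustrative):

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# Explicit strategy object using the new bool flag.
trainer = pl.Trainer(
    strategy=DDPFullyShardedNativeStrategy(cpu_offload=True),
    accelerator="gpu",
    devices=4,
)

# Equivalent shorthand: the registered alias sets cpu_offload=True internally.
trainer = pl.Trainer(strategy="fsdp_native_full_shard_offload", accelerator="gpu", devices=4)
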
17 changes: 15 additions & 2 deletions tests/tests_lite/strategies/test_fsdp.py
@@ -26,7 +26,7 @@
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
- from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
+ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision


@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False)
@@ -36,13 +36,26 @@ def test_fsdp_support(*_):


@RunIf(min_torch="1.12")
- def test_fsdp_custom_mixed_precision(*_):
+ def test_fsdp_custom_mixed_precision():
"""Test that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = FSDPStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config


+ @RunIf(min_torch="1.12")
+ def test_fsdp_cpu_offload():
+     """Test the different ways cpu offloading can be enabled."""
+     # bool
+     strategy = FSDPStrategy(cpu_offload=True)
+     assert strategy.cpu_offload == CPUOffload(offload_params=True)
+
+     # dataclass
+     config = CPUOffload()
+     strategy = FSDPStrategy(cpu_offload=config)
+     assert strategy.cpu_offload == config


@RunIf(min_torch="1.12")
def test_fsdp_setup_optimizer_validation():
"""Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""
15 changes: 14 additions & 1 deletion tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
@@ -17,7 +17,7 @@
from tests_pytorch.helpers.runif import RunIf

if _TORCH_GREATER_EQUAL_1_12:
- from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
+ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import wrap


@@ -306,3 +306,16 @@ def __init__(self):
) as ckpt_mock:
strategy._setup_model(model)
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)


+ @RunIf(min_torch="1.12")
+ def test_fully_sharded_native_strategy_cpu_offload():
+     """Test the different ways cpu offloading can be enabled."""
+     # bool
+     strategy = DDPFullyShardedNativeStrategy(cpu_offload=True)
+     assert strategy.cpu_offload == CPUOffload(offload_params=True)
+
+     # dataclass
+     config = CPUOffload()
+     strategy = DDPFullyShardedNativeStrategy(cpu_offload=config)
+     assert strategy.cpu_offload == config
