Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify enabling CPU offload in FSDP #15832

Merged
merged 11 commits into from
Dec 7, 2022
3 changes: 1 addition & 2 deletions docs/source-pytorch/advanced/model_parallel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,9 @@ You can customize the strategy configuration by adjusting the arguments of :clas

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload


native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True)
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4)


Expand Down
20 changes: 12 additions & 8 deletions src/lightning_lite/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
`this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

Arguments:
cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It
can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
to work with the optimizer. This API is subject to change. Default is ``None`` in which case there
will be no offloading.
to work with the optimizer. This API is subject to change. Default: no offoading
backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
users to enable two different backward prefetching algorithms to help backward communication and
computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
Expand All @@ -91,7 +90,7 @@ def __init__(
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
cpu_offload: Optional["CPUOffload"] = None,
cpu_offload: Union[bool, "CPUOffload", None] = None,
backward_prefetch: Optional["BackwardPrefetch"] = None,
mixed_precision: Optional["MixedPrecision"] = None,
**kwargs: Any,
Expand All @@ -112,7 +111,7 @@ def __init__(
self._backward_sync_control = _FSDPBackwardSyncControl()
self._ddp_kwargs = kwargs

self.cpu_offload = cpu_offload
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision

Expand Down Expand Up @@ -258,7 +257,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
def register_strategies(cls, strategy_registry: Dict) -> None:
if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():
return
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

strategy_registry.register(
"fsdp",
Expand All @@ -269,7 +267,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
"fsdp_full_shard_offload",
cls,
description="Native FSDP with Full Sharding and CPU Offloading",
cpu_offload=CPUOffload(offload_params=True),
cpu_offload=True,
)

def _setup_distributed(self) -> None:
Expand Down Expand Up @@ -308,6 +306,12 @@ def no_backward_sync(self, module: Module) -> Generator:
yield


def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload":
from torch.distributed.fsdp import CPUOffload

return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload))


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
from torch.distributed.fsdp import FlatParameter

Expand Down
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))


- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))


### Changed

- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
Expand Down
20 changes: 8 additions & 12 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fsdp import _optimizer_has_flat_params
from lightning_lite.strategies.fsdp import _init_cpu_offload, _optimizer_has_flat_params
from lightning_lite.utilities.distributed import (
_get_default_process_group_backend_for_device,
_init_dist_connection,
Expand Down Expand Up @@ -83,14 +83,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
`this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

Arguments:
cpu_offload:
CPU offloading config. Currently, only parameter and gradient CPU
offload is supported. It can be enabled via passing in
``cpu_offload=CPUOffload(offload_params=True)``. Note that this
currently implicitly enables gradient offloading to CPU in order for
params and grads to be on same device to work with optimizer. This
API is subject to change. Default is ``None`` in which case there
will be no offloading.
cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
to work with the optimizer. This API is subject to change. Default: no offoading
backward_prefetch:
This is an experimental feature that is subject to change in the
the near future. It allows users to enable two different backward_prefetch
Expand All @@ -115,7 +111,7 @@ def __init__(
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
process_group_backend: Optional[str] = None,
cpu_offload: Optional[CPUOffload] = None,
cpu_offload: Union[bool, "CPUOffload", None] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
**kwargs: Any,
Expand All @@ -135,7 +131,7 @@ def __init__(
self._process_group = None
self.num_nodes = 1
self._process_group_backend = process_group_backend
self.cpu_offload = cpu_offload
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision
self._rank_0_will_call_children_scripts: bool = False
Expand Down Expand Up @@ -386,6 +382,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
"fsdp_native_full_shard_offload",
cls,
description="Native FSDP with Full Sharding and CPU Offloading",
cpu_offload=CPUOffload(offload_params=True),
cpu_offload=True,
)
cls._registered_strategies.append("fsdp_native_full_shard_offload")
17 changes: 15 additions & 2 deletions tests/tests_lite/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision


@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False)
Expand All @@ -36,13 +36,26 @@ def test_fsdp_support(*_):


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(*_):
def test_fsdp_custom_mixed_precision():
"""Test that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = FSDPStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config


@RunIf(min_torch="1.12")
def test_fsdp_cpu_offload():
"""Test the different ways cpu offloading can be enabled."""
# bool
strategy = FSDPStrategy(cpu_offload=True)
assert strategy.cpu_offload == CPUOffload(offload_params=True)

# dataclass
config = CPUOffload()
strategy = FSDPStrategy(cpu_offload=config)
assert strategy.cpu_offload == config


@RunIf(min_torch="1.12")
def test_fsdp_setup_optimizer_validation():
"""Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tests_pytorch.helpers.runif import RunIf

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import wrap


Expand Down Expand Up @@ -259,3 +259,16 @@ def configure_optimizers(self):
model = NoFlatParametersModel()
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
trainer.fit(model)


@RunIf(min_torch="1.12")
def test_fully_sharded_native_strategy_cpu_offload():
"""Test the different ways cpu offloading can be enabled."""
# bool
strategy = DDPFullyShardedNativeStrategy(cpu_offload=True)
assert strategy.cpu_offload == CPUOffload(offload_params=True)

# dataclass
config = CPUOffload()
strategy = DDPFullyShardedNativeStrategy(cpu_offload=config)
assert strategy.cpu_offload == config