Commit

Support PyTorch Lightning's FSDP optimizer states saving and loading (#17819)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
4 people committed Jul 7, 2023
1 parent 1b43aac commit 734a325
Showing 4 changed files with 215 additions and 12 deletions.
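
At a high level, this commit teaches the PyTorch Lightning `FSDPStrategy` to consolidate and save optimizer states into regular Trainer checkpoints and to restore them on resume, including when resuming under a different strategy such as DDP (which is what the new tests exercise). A minimal usage sketch, assuming a user-defined `MyModel` LightningModule, two CUDA devices, and a placeholder checkpoint path (none of these names appear in the diff itself):

```python
# Sketch only: `MyModel` is a hypothetical LightningModule; paths and device counts are placeholders.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy(), precision="16-mixed", max_epochs=1)
trainer.fit(model)

# With this commit, the checkpoint also contains the consolidated optimizer states.
trainer.save_checkpoint("fsdp.ckpt")

# Resuming restores both model weights and optimizer states,
# whether with FSDP again or with a non-sharded strategy such as DDP.
resume_trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=2)
resume_trainer.fit(MyModel(), ckpt_path="fsdp.ckpt")
```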
6 changes: 2 additions & 4 deletions src/lightning/fabric/strategies/fsdp.py
@@ -659,9 +659,7 @@ def _get_sharded_state_dict_context(module: "FullyShardedDataParallel") -> _GeneratorContextManager
return state_dict_type_context


-def _get_full_state_dict_context(
-    module: "FullyShardedDataParallel", rank0_only: bool = True
-) -> _GeneratorContextManager:
+def _get_full_state_dict_context(module: Module, rank0_only: bool = True) -> _GeneratorContextManager:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig, StateDictType

@@ -692,7 +690,7 @@ def _no_op() -> None:
@contextmanager
def _apply_optimizers_during_fsdp_backward(
optimizers: Union[Optimizer, Iterable[Optimizer]],
-    module: torch.nn.Module,
+    module: Module,
) -> Generator[None, None, None]:
"""Call `Optimizer.step` as gradients become available.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect parsing of arguments when augmenting exception messages in DDP ([#17948](https://github.com/Lightning-AI/lightning/pull/17948))


- Fixed the saving and loading of FSDP optimizer states ([#17819](https://github.com/Lightning-AI/lightning/pull/17819))


- Fixed an issue causing the `torch.set_float32_matmul_precision` info message to show multiple times ([#17960](https://github.com/Lightning-AI/lightning/pull/17960))


58 changes: 56 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -14,17 +14,19 @@
import contextlib
import logging
from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Type, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Type, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import lightning.pytorch as pl
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.fsdp import (
_get_full_state_dict_context,
_init_cpu_offload,
_optimizer_has_flat_params,
_setup_activation_checkpointing,
@@ -43,6 +45,7 @@
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ProcessGroup, ReduceOp
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.plugins.precision import PrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
@@ -61,11 +64,13 @@
FullStateDictConfig,
FullyShardedDataParallel,
MixedPrecision,
OptimStateKeyType,
StateDictType,
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
FullyShardedDataParallel = None # type: ignore[misc,assignment]
OptimStateKeyType = None # type: ignore[misc,assignment]
MixedPrecision = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]

@@ -223,7 +228,7 @@ def _configure_launcher(self) -> None:
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

-    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
+    def _setup_model(self, model: Module) -> FullyShardedDataParallel:
"""Wraps the model into a
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
@@ -396,3 +401,52 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
cpu_offload=True,
)
cls._registered_strategies.append("fsdp_cpu_offload")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
optimizer_states = checkpoint.get("optimizer_states")

# If the optimizer states are not present, we don't need to do anything (backward compatibility)
if optimizer_states is None:
return

if len(self.optimizers) != len(optimizer_states):
raise RuntimeError(
f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
" of optimizers or edit the checkpoint manually to remove states."
)

assert isinstance(self.model, FullyShardedDataParallel)

# rank0_only should be false because we need to load the optimizer state on all ranks
with _get_full_state_dict_context(self.model, rank0_only=False):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
# convert the optimizer state to the format expected by FSDP
opt_state = FullyShardedDataParallel.rekey_optim_state_dict(
opt_state, OptimStateKeyType.PARAM_NAME, self.model
)

opt_state = FullyShardedDataParallel.optim_state_dict_to_load(
optim_state_dict=opt_state,
model=self.model,
optim=optimizer,
)

optimizer.load_state_dict(opt_state)

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer

assert self.model is not None

with _get_full_state_dict_context(self.model, rank0_only=True):
state_dict = FullyShardedDataParallel.optim_state_dict(self.model, optimizer)

# Store the optimizer state dict in standard format
if self.global_rank == 0:
state_dict = FullyShardedDataParallel.rekey_optim_state_dict(
state_dict, OptimStateKeyType.PARAM_ID, self.model
)

return state_dict
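
For orientation (not part of the diff): `optimizer_state` above consolidates the sharded optimizer state via `FullyShardedDataParallel.optim_state_dict` and, on rank 0, rekeys it to parameter ids, so each entry stored under `optimizer_states` in the checkpoint looks like a standard `torch.optim` state dict. A hedged sketch of inspecting such a checkpoint, assuming a placeholder path `fsdp.ckpt`:

```python
# Sketch: inspect the optimizer states written by the strategy methods above.
# "fsdp.ckpt" is a placeholder for a checkpoint produced by trainer.save_checkpoint.
import torch

checkpoint = torch.load("fsdp.ckpt", map_location="cpu")
optimizer_states = checkpoint["optimizer_states"]  # one entry per configured optimizer
first = optimizer_states[0]
print(list(first.keys()))    # standard torch.optim format: ['state', 'param_groups']
print(len(first["state"]))   # number of parameters that carry optimizer state
```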
160 changes: 154 additions & 6 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -11,7 +11,11 @@
import torch.nn as nn

from lightning.fabric.plugins.environments import LightningEnvironment
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
+)
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -35,7 +39,7 @@
class TestFSDPModel(BoringModel):
def __init__(self):
super().__init__()
-        self.layer: Optional[torch.nn.Module] = None
+        self.layer: Optional[nn.Module] = None

def _init_model(self) -> None:
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
@@ -58,7 +62,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self._init_model()

def configure_optimizers(self):
-        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        # There is some issue with SGD optimizer state in FSDP
+        return torch.optim.AdamW(self.layer.parameters(), lr=0.1)

def on_train_batch_end(self, *_) -> None:
self._assert_layer_fsdp_instance()
@@ -100,17 +105,22 @@ def _assert_layer_fsdp_instance(self) -> None:
assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype


-class TestFSDPModelAutoWrapped(BoringModel):
+class TestBoringModel(BoringModel):
def __init__(self, wrap_min_params: int = 2):
super().__init__()

self.save_hyperparameters()
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params]

def configure_optimizers(self):
parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()
-        return torch.optim.SGD(parameters, lr=0.1)
+
+        # SGD's FSDP optimizer state is fixed in https://github.com/pytorch/pytorch/pull/99214
+        return torch.optim.AdamW(parameters, lr=0.1)


class TestFSDPModelAutoWrapped(TestBoringModel):
def on_train_batch_end(self, *_) -> None:
self._assert_layer_fsdp_instance()

@@ -295,7 +305,13 @@ def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params):

strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1
default_root_dir=tmpdir,
accelerator="gpu",
devices=2,
strategy=strategy,
precision="16-mixed",
max_epochs=1,
barebones=True,
)
trainer.fit(model)

@@ -479,3 +495,135 @@ def test_set_timeout(init_process_group_mock):
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
)


@RunIf(min_torch="1.12")
def test_fsdp_strategy_load_optimizer_states_multiple():
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])

# More states than optimizers configured
strategy.optimizers = [Mock()]
checkpoint = {"optimizer_states": [Mock(), Mock()]}
with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"):
strategy.load_optimizer_state_dict(checkpoint)

# Fewer states than optimizers configured
strategy.optimizers = [Mock(), Mock()]
checkpoint = {"optimizer_states": [Mock()]}
with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"):
strategy.load_optimizer_state_dict(checkpoint)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
"""Test to ensure that the full state dict and optimizer states is saved when using FSDP strategy.
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can
be restored to DDP, it means that the optimizer states were saved correctly.
"""
model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)

strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="gpu",
devices=2,
strategy=strategy,
precision="16-mixed",
max_epochs=1,
barebones=True,
)

trainer.fit(model)
model_path = os.path.join(tmpdir, "last.ckpt")
model_path = trainer.strategy.broadcast(model_path)
trainer.save_checkpoint(model_path)

model_state_dict = trainer.strategy.lightning_module_state_dict()
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

if trainer.global_rank != 0:
assert len(model_state_dict) == 0

if _TORCH_GREATER_EQUAL_2_1:
assert len(optimizer_state_dict) == 0

# restore model to ddp
model = TestBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)

# This step will restore the model and optimizer states
trainer.fit(model, ckpt_path=model_path)

# Get the model and optimizer states from the restored ddp model
restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

if trainer.global_rank == 0:
# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)

torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)

trainer.strategy.barrier()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
"""Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy.
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model
can be restored to FSDP, it means that the optimizer states were restored correctly.
"""

# restore model to ddp
model = TestBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)

# This step will restore the model and optimizer states
trainer.fit(model)
model_path = os.path.join(tmpdir, "last.ckpt")
model_path = trainer.strategy.broadcast(model_path)
trainer.save_checkpoint(model_path)

# Get the model and optimizer states from the restored ddp model
model_state_dict = trainer.strategy.lightning_module_state_dict()
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

# Build a new FSDP model
model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)

strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="gpu",
devices=2,
strategy=strategy,
precision="16-mixed",
max_epochs=1,
barebones=True,
)

trainer.fit(model, ckpt_path=model_path)

restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

if trainer.global_rank != 0:
assert len(restored_model_state_dict) == 0

if _TORCH_GREATER_EQUAL_2_1:
assert len(restored_optimizer_state_dict) == 0

if trainer.global_rank == 0:
# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)

trainer.strategy.barrier()
