Support PyTorch Lightning's FSDP optimizer states saving and loading #17819

Merged
Changes from 7 commits
6 changes: 2 additions & 4 deletions src/lightning/fabric/strategies/fsdp.py
@@ -661,9 +661,7 @@ def _get_sharded_state_dict_context(module: "FullyShardedDataParallel") -> _Gene
return state_dict_type_context


def _get_full_state_dict_context(
module: "FullyShardedDataParallel", rank0_only: bool = True
) -> _GeneratorContextManager:
def _get_full_state_dict_context(module: Module, rank0_only: bool = True) -> _GeneratorContextManager:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig, StateDictType

@@ -694,7 +692,7 @@ def _no_op() -> None:
@contextmanager
def _apply_optimizers_during_fsdp_backward(
optimizers: Union[Optimizer, Iterable[Optimizer]],
module: torch.nn.Module,
module: Module,
) -> Generator[None, None, None]:
"""Call `Optimizer.step` as gradients become available.

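For context on the signature change above: a full-state-dict helper like `_get_full_state_dict_context` is typically a thin wrapper around `FSDP.state_dict_type`. The sketch below illustrates that pattern, assuming the torch >= 2.0 FSDP APIs imported in the hunk; the helper name is made up and this is not the exact Lightning implementation.

```python
from contextlib import _GeneratorContextManager

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig, StateDictType
from torch.nn import Module


def _full_state_dict_context_sketch(module: Module, rank0_only: bool = True) -> _GeneratorContextManager:
    # Gather full (unsharded) model and optimizer state dicts, offloaded to CPU.
    # With rank0_only=True, only rank 0 materializes the dicts; other ranks get empty ones.
    state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
    optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
    return FSDP.state_dict_type(
        module,
        StateDictType.FULL_STATE_DICT,
        state_dict_config,
        optim_state_dict_config,
    )
```

Accepting any `Module` (rather than only `FullyShardedDataParallel`) lets the same helper be used whether or not the root module itself is wrapped.
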
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -133,6 +133,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed delayed creation of experiment metadata and checkpoint/log dir name when using `WandbLogger` ([#17818](https://github.com/Lightning-AI/lightning/pull/17818))


- Fixed the saving and loading of FSDP optimizer states ([#17819](https://github.com/Lightning-AI/lightning/pull/17819))


## [2.0.3] - 2023-06-07

### Changed
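In end-user terms, the changelog entry above means a plain Trainer run with the FSDP strategy can now round-trip optimizer state through a checkpoint. A minimal sketch of such a run, assuming two GPUs are available; the toy module, checkpoint path, and Trainer arguments are illustrative and not taken from this PR:

```python
import torch
import lightning.pytorch as pl
from lightning.pytorch.strategies import FSDPStrategy


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=4)

    def configure_optimizers(self):
        # AdamW, mirroring the tests in this PR (SGD state handling under FSDP has known issues)
        return torch.optim.AdamW(self.layer.parameters(), lr=0.1)


model = ToyModule()
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy(), max_epochs=1)
trainer.fit(model)
trainer.save_checkpoint("fsdp.ckpt")  # with this fix, the checkpoint includes full optimizer states

# Resuming now also restores the optimizer states on every rank.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy(), max_epochs=2)
trainer.fit(model, ckpt_path="fsdp.ckpt")
```
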
43 changes: 41 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -14,17 +14,19 @@
import contextlib
import logging
from datetime import timedelta
from typing import Any, Dict, Generator, List, Optional, Type, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Type, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import lightning.pytorch as pl
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.fsdp import (
_get_full_state_dict_context,
_init_cpu_offload,
_optimizer_has_flat_params,
_setup_activation_checkpointing,
@@ -43,6 +45,7 @@
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ProcessGroup, ReduceOp
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.plugins.precision import PrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
@@ -61,11 +64,13 @@
FullStateDictConfig,
FullyShardedDataParallel,
MixedPrecision,
OptimStateKeyType,
StateDictType,
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
FullyShardedDataParallel = None # type: ignore[misc,assignment]
OptimStateKeyType = None # type: ignore[misc,assignment]
MixedPrecision = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]

@@ -223,7 +228,7 @@ def _configure_launcher(self) -> None:
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
def _setup_model(self, model: Module) -> FullyShardedDataParallel:
"""Wraps the model into a
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
@@ -396,3 +401,37 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
cpu_offload=True,
)
cls._registered_strategies.append("fsdp_cpu_offload")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
optimizer_states = checkpoint["optimizer_states"]

assert self.model is not None
assert isinstance(self.model, FullyShardedDataParallel)

# rank0_only should be false because we need to load the optimizer state on all ranks
with _get_full_state_dict_context(self.model, rank0_only=False):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
# convert the optimizer state to the format expected by FSDP
opt_state = FullyShardedDataParallel.rekey_optim_state_dict(
opt_state, OptimStateKeyType.PARAM_NAME, self.model
)

opt_state = FullyShardedDataParallel.optim_state_dict_to_load(
optim_state_dict=opt_state,
model=self.model,
optim=optimizer,
)

optimizer.load_state_dict(opt_state)

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer

assert self.model is not None

with _get_full_state_dict_context(self.model, rank0_only=True):
state_dict = FullyShardedDataParallel.optim_state_dict(self.model, optimizer)

# Return the optimizer state dict in the standard (param-id keyed) format
return FullyShardedDataParallel.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
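
Outside of Lightning, the save/load round-trip that `optimizer_state` and `load_optimizer_state_dict` implement above boils down to the FSDP calls sketched below. This is a hedged illustration using the same torch 2.0 APIs as the hunk; the function names are made up, the model is assumed to be already FSDP-wrapped, and loading must run on all ranks under a full-state-dict context with `rank0_only=False` (which is exactly why `load_optimizer_state_dict` wraps its loop in `_get_full_state_dict_context(self.model, rank0_only=False)`).

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, OptimStateKeyType
from torch.optim import Optimizer


def save_full_optim_state(fsdp_model: FSDP, optimizer: Optimizer) -> dict:
    # Gather the full optimizer state dict, then rekey it by parameter id so the
    # checkpoint looks like a regular (non-FSDP) optimizer state dict.
    full_osd = FSDP.optim_state_dict(fsdp_model, optimizer)
    return FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, fsdp_model)


def load_full_optim_state(fsdp_model: FSDP, optimizer: Optimizer, full_osd: dict) -> None:
    # Rekey back to parameter names, convert to the layout expected by the current
    # FSDP wrapping, and hand the result to the optimizer.
    full_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_NAME, fsdp_model)
    load_osd = FSDP.optim_state_dict_to_load(optim_state_dict=full_osd, model=fsdp_model, optim=optimizer)
    optimizer.load_state_dict(load_osd)
```
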
118 changes: 113 additions & 5 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -35,7 +35,7 @@
class TestFSDPModel(BoringModel):
def __init__(self):
super().__init__()
self.layer: Optional[torch.nn.Module] = None
self.layer: Optional[nn.Module] = None

def _init_model(self) -> None:
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
@@ -58,7 +58,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self._init_model()

def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
# There is a known issue with SGD optimizer state under FSDP, so use AdamW instead
return torch.optim.AdamW(self.layer.parameters(), lr=0.1)

def on_train_batch_end(self, *_) -> None:
self._assert_layer_fsdp_instance()
@@ -100,17 +101,23 @@ def _assert_layer_fsdp_instance(self) -> None:
assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype


class TestFSDPModelAutoWrapped(BoringModel):
def __init__(self, wrap_min_params: int = 2):
class TestBoringModel(BoringModel):
def __init__(self, wrap_min_params: int = 2, automatic_optimization: bool = True):
super().__init__()

self.save_hyperparameters()
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params]
self.automatic_optimization = automatic_optimization

def configure_optimizers(self):
parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()
return torch.optim.SGD(parameters, lr=0.1)

# SGD's FSDP optimizer state is only fixed in https://github.com/pytorch/pytorch/pull/99214, so use AdamW here
return torch.optim.AdamW(parameters, lr=0.1)


class TestFSDPModelAutoWrapped(TestBoringModel):
def on_train_batch_end(self, *_) -> None:
self._assert_layer_fsdp_instance()

@@ -479,3 +486,104 @@ def test_set_timeout(init_process_group_mock):
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
"""Test to ensure that the full state dict and optimizer states is saved when using FSDP strategy.

Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, or not wrapped at all. If the model can
be restored to DDP, it means that the optimizer states were saved correctly.
"""

model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)

strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1
)

trainer.fit(model)
model_path = os.path.join(tmpdir, "last.ckpt")
model_path = trainer.strategy.broadcast(model_path)
trainer.save_checkpoint(model_path)

model_state_dict = trainer.strategy.lightning_module_state_dict()
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

if trainer.global_rank != 0:
assert len(model_state_dict) == 0

# restore model to ddp, disable automatic_optimization to avoid optimizer state / model state mismatch
model = TestBoringModel(automatic_optimization=False)
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", precision="16-mixed", max_epochs=1
)

# This step will restore the model and optimizer states
trainer.fit(model, ckpt_path=model_path)

# Get the model and optimizer states from the restored ddp model
restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

if trainer.global_rank != 0:
return

# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)

torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
"""Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy.

Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, or not wrapped at all. If the DDP model
can be restored to FSDP, it means that the optimizer states were restored correctly.
"""

# Train a reference model with DDP first; its checkpoint provides the model and optimizer states
model = TestBoringModel()
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", precision="16-mixed", max_epochs=1
)

# Train the DDP model from scratch (nothing is restored here)
trainer.fit(model)
model_path = os.path.join(tmpdir, "last.ckpt")
model_path = trainer.strategy.broadcast(model_path)
trainer.save_checkpoint(model_path)

# Get the model and optimizer states from the trained ddp model
model_state_dict = trainer.strategy.lightning_module_state_dict()
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

# Build a new FSDP model, without automatic_optimization
model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params, automatic_optimization=False)

strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1
)

trainer.fit(model, ckpt_path=model_path)

restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

if trainer.global_rank != 0:
assert len(restored_model_state_dict) == 0
return

# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)

torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)