diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index ce24f2da54feb..4705af7b21067 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -659,9 +659,7 @@ def _get_sharded_state_dict_context(module: "FullyShardedDataParallel") -> _Gene
     return state_dict_type_context


-def _get_full_state_dict_context(
-    module: "FullyShardedDataParallel", rank0_only: bool = True
-) -> _GeneratorContextManager:
+def _get_full_state_dict_context(module: Module, rank0_only: bool = True) -> _GeneratorContextManager:
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig, StateDictType

@@ -692,7 +690,7 @@ def _no_op() -> None:
 @contextmanager
 def _apply_optimizers_during_fsdp_backward(
     optimizers: Union[Optimizer, Iterable[Optimizer]],
-    module: torch.nn.Module,
+    module: Module,
 ) -> Generator[None, None, None]:
     """Call `Optimizer.step` as gradients become available.

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 6d39f88cc429e..e547294d7c980 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed incorrect parsing of arguments when augmenting exception messages in DDP ([#17948](https://github.com/Lightning-AI/lightning/pull/17948))

+- Fixed the saving and loading of FSDP optimizer states ([#17819](https://github.com/Lightning-AI/lightning/pull/17819))
+
+
 - Fixed an issue causing the `torch.set_float32_matmul_precision` info message to show multiple times ([#17960](https://github.com/Lightning-AI/lightning/pull/17960))

diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 7864d857e91cc..9274696854533 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -14,17 +14,19 @@
 import contextlib
 import logging
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Type, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Type, Union

 import torch
 from torch import Tensor
 from torch.nn import Module
+from torch.optim import Optimizer

 import lightning.pytorch as pl
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.fsdp import (
+    _get_full_state_dict_context,
     _init_cpu_offload,
     _optimizer_has_flat_params,
     _setup_activation_checkpointing,
@@ -43,6 +45,7 @@
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ProcessGroup, ReduceOp
+from lightning.pytorch.core.optimizer import LightningOptimizer
 from lightning.pytorch.plugins.precision import PrecisionPlugin
 from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
 from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
@@ -61,11 +64,13 @@
         FullStateDictConfig,
         FullyShardedDataParallel,
         MixedPrecision,
+        OptimStateKeyType,
         StateDictType,
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
     FullyShardedDataParallel = None  # type: ignore[misc,assignment]
+    OptimStateKeyType = None  # type: ignore[misc,assignment]
     MixedPrecision = None  # type: ignore[misc,assignment]
     CPUOffload = None  # type: ignore[misc,assignment]

@@ -223,7 +228,7 @@ def _configure_launcher(self) -> None:
         if not self.cluster_environment.creates_processes_externally:
             self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

-    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
+    def _setup_model(self, model: Module) -> FullyShardedDataParallel:
         """Wraps the model into a
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
@@ -396,3 +401,52 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
             cpu_offload=True,
         )
         cls._registered_strategies.append("fsdp_cpu_offload")
+
+    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        optimizer_states = checkpoint.get("optimizer_states")
+
+        # If the optimizer states are not present, we don't need to do anything (backward compatibility)
+        if optimizer_states is None:
+            return
+
+        if len(self.optimizers) != len(optimizer_states):
+            raise RuntimeError(
+                f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
+                f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
+                " of optimizers or edit the checkpoint manually to remove states."
+            )
+
+        assert isinstance(self.model, FullyShardedDataParallel)
+
+        # `rank0_only` must be False because the optimizer state needs to be loaded on all ranks
+        with _get_full_state_dict_context(self.model, rank0_only=False):
+            for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+                # Convert the optimizer state to the format expected by FSDP
+                opt_state = FullyShardedDataParallel.rekey_optim_state_dict(
+                    opt_state, OptimStateKeyType.PARAM_NAME, self.model
+                )
+
+                opt_state = FullyShardedDataParallel.optim_state_dict_to_load(
+                    optim_state_dict=opt_state,
+                    model=self.model,
+                    optim=optimizer,
+                )
+
+                optimizer.load_state_dict(opt_state)
+
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
+        if isinstance(optimizer, LightningOptimizer):
+            optimizer = optimizer._optimizer
+
+        assert self.model is not None
+
+        with _get_full_state_dict_context(self.model, rank0_only=True):
+            state_dict = FullyShardedDataParallel.optim_state_dict(self.model, optimizer)
+
+        # Store the optimizer state dict in the standard format (keyed by parameter id)
+        if self.global_rank == 0:
+            state_dict = FullyShardedDataParallel.rekey_optim_state_dict(
+                state_dict, OptimStateKeyType.PARAM_ID, self.model
+            )
+
+        return state_dict
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index d0abe9c61fc08..839b69b150bcc 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -11,7 +11,11 @@
 import torch.nn as nn

 from lightning.fabric.plugins.environments import LightningEnvironment
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
+)
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -35,7 +39,7 @@
 class TestFSDPModel(BoringModel):
     def __init__(self):
         super().__init__()
-        self.layer: Optional[torch.nn.Module] = None
+        self.layer: Optional[nn.Module] = None

     def _init_model(self) -> None:
         self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
@@ -58,7 +62,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         self._init_model()

     def configure_optimizers(self):
-        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        # Use AdamW: FSDP mishandled SGD optimizer state before pytorch/pytorch#99214
+        return torch.optim.AdamW(self.layer.parameters(), lr=0.1)

     def on_train_batch_end(self, *_) -> None:
         self._assert_layer_fsdp_instance()
@@ -100,17 +105,22 @@ def _assert_layer_fsdp_instance(self) -> None:
             assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype


-class TestFSDPModelAutoWrapped(BoringModel):
+class TestBoringModel(BoringModel):
     def __init__(self, wrap_min_params: int = 2):
         super().__init__()
+        self.save_hyperparameters()
         self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
         self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params]

     def configure_optimizers(self):
         parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()
-        return torch.optim.SGD(parameters, lr=0.1)
+        # SGD's FSDP optimizer state handling was only fixed in https://github.com/pytorch/pytorch/pull/99214, hence AdamW
+        return torch.optim.AdamW(parameters, lr=0.1)
+
+
+class TestFSDPModelAutoWrapped(TestBoringModel):

     def on_train_batch_end(self, *_) -> None:
         self._assert_layer_fsdp_instance()
@@ -295,7 +305,13 @@ def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params):

     strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
     trainer = Trainer(
-        default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1
+        default_root_dir=tmpdir,
+        accelerator="gpu",
+        devices=2,
+        strategy=strategy,
+        precision="16-mixed",
+        max_epochs=1,
+        barebones=True,
     )
     trainer.fit(model)

@@ -479,3 +495,135 @@ def test_set_timeout(init_process_group_mock):
     init_process_group_mock.assert_called_with(
         process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
     )
+
+
+@RunIf(min_torch="1.12")
+def test_fsdp_strategy_load_optimizer_states_multiple():
+    strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
+
+    # More states than optimizers configured
+    strategy.optimizers = [Mock()]
+    checkpoint = {"optimizer_states": [Mock(), Mock()]}
+    with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"):
+        strategy.load_optimizer_state_dict(checkpoint)
+
+    # Fewer states than optimizers configured
+    strategy.optimizers = [Mock(), Mock()]
+    checkpoint = {"optimizer_states": [Mock()]}
+    with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"):
+        strategy.load_optimizer_state_dict(checkpoint)
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
+def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
+    """Test to ensure that the full state dict and optimizer states are saved when using FSDP strategy.
+
+    Based on `wrap_min_params`, the model will be fully wrapped, partially wrapped, or not wrapped at all. If the
+    resulting checkpoint can be restored into a DDP model, the optimizer states were saved correctly.
+    """
+    model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)
+
+    strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        accelerator="gpu",
+        devices=2,
+        strategy=strategy,
+        precision="16-mixed",
+        max_epochs=1,
+        barebones=True,
+    )
+
+    trainer.fit(model)
+    model_path = os.path.join(tmpdir, "last.ckpt")
+    model_path = trainer.strategy.broadcast(model_path)
+    trainer.save_checkpoint(model_path)
+
+    model_state_dict = trainer.strategy.lightning_module_state_dict()
+    optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+    if trainer.global_rank != 0:
+        assert len(model_state_dict) == 0
+
+        if _TORCH_GREATER_EQUAL_2_1:
+            assert len(optimizer_state_dict) == 0
+
+    # Restore the checkpoint into a plain DDP model
+    model = TestBoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
+
+    # This step restores the model and optimizer states from the FSDP checkpoint
+    trainer.fit(model, ckpt_path=model_path)
+
+    # Get the model and optimizer states from the restored DDP model
+    restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
+    restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+    if trainer.global_rank == 0:
+        # Assert that the original and restored states match
+        assert len(model_state_dict) == len(restored_model_state_dict)
+        assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
+
+        torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
+        torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)
+
+    trainer.strategy.barrier()
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
+def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
+    """Test to ensure that the full state dict and optimizer states can be loaded when using FSDP strategy.
+
+    Based on `wrap_min_params`, the model will be fully wrapped, partially wrapped, or not wrapped at all. If the DDP
+    checkpoint can be restored into the FSDP model, the optimizer states were restored correctly.
+    """
+
+    # Train a plain DDP model to produce a reference checkpoint
+    model = TestBoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
+
+    # This step trains the model and populates the optimizer states to be checkpointed below
+    trainer.fit(model)
+    model_path = os.path.join(tmpdir, "last.ckpt")
+    model_path = trainer.strategy.broadcast(model_path)
+    trainer.save_checkpoint(model_path)
+
+    # Get the model and optimizer states from the trained DDP model
+    model_state_dict = trainer.strategy.lightning_module_state_dict()
+    optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+    # Build a new FSDP model
+    model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)
+
+    strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        accelerator="gpu",
+        devices=2,
+        strategy=strategy,
+        precision="16-mixed",
+        max_epochs=1,
+        barebones=True,
+    )
+
+    trainer.fit(model, ckpt_path=model_path)
+
+    restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
+    restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+    if trainer.global_rank != 0:
+        assert len(restored_model_state_dict) == 0
+
+        if _TORCH_GREATER_EQUAL_2_1:
+            assert len(restored_optimizer_state_dict) == 0
+
+    if trainer.global_rank == 0:
+        # Assert that the DDP and restored FSDP states match
+        assert len(model_state_dict) == len(restored_model_state_dict)
+        assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
+        torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
+        torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)
+
+    trainer.strategy.barrier()
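Note: the snippet below is a minimal usage sketch of the end-to-end flow this patch fixes, not part of the diff. It assumes a machine with 2 CUDA GPUs; the DemoModel class, checkpoint path, and hyperparameters are illustrative, while the APIs used (Trainer, FSDPStrategy, BoringModel, save_checkpoint, ckpt_path) are the same ones exercised by the tests above. Saving goes through FSDPStrategy.optimizer_state, which gathers the full optimizer state dict on rank 0, and resuming goes through the new load_optimizer_state_dict, which rekeys the states and loads them on every rank.

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import FSDPStrategy


class DemoModel(BoringModel):
    def configure_optimizers(self):
        # AdamW mirrors the tests above; FSDP's SGD state handling was only fixed in pytorch/pytorch#99214
        return torch.optim.AdamW(self.parameters(), lr=0.1)


def run():
    # Train with FSDP and save a checkpoint; the full optimizer state is now included
    trainer = Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy(), precision="16-mixed", max_epochs=1)
    trainer.fit(DemoModel())
    ckpt_path = "fsdp_demo.ckpt"  # illustrative path
    trainer.save_checkpoint(ckpt_path)

    # Resume from the checkpoint; the optimizer states are rekeyed and restored on every rank
    trainer = Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy(), precision="16-mixed", max_epochs=2)
    trainer.fit(DemoModel(), ckpt_path=ckpt_path)


if __name__ == "__main__":
    run()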