From bf021fe5a7e1cb68573888942327a65ecbb008d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Mon, 19 Feb 2024 20:00:33 +0100
Subject: [PATCH] WIP

---
 src/lightning/fabric/strategies/fsdp.py    | 67 ++++++++++------------
 tests/tests_fabric/strategies/test_fsdp.py |  2 +-
 2 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 02888c58b49be..d0f3e96d24163 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -439,6 +439,7 @@ def save_checkpoint(
             )
         if filter is not None and self._state_dict_type == "sharded":
             # https://github.com/pytorch/pytorch/issues/105379
+            # FIXME: revisit support with new APIs
             raise NotImplementedError(
                 "FSDP doesn't support loading sharded filtered checkpoints, so saving them is disabled."
             )
@@ -468,9 +469,8 @@ def save_checkpoint(
                 path.unlink()
             path.mkdir(parents=True, exist_ok=True)
 
-            converted_state, metadata = _save_state_dict(
-                state, module, path, filter, self._state_dict_type, self.world_size
-            )
+            converted_state, metadata = _get_state_dict(state, module, filter, self._state_dict_type, self.world_size)
+            _distributed_checkpoint_save(converted_state, path)
 
             if self.global_rank == 0:
                 torch.save(metadata, path / _METADATA_FILENAME)
@@ -478,9 +478,7 @@ def save_checkpoint(
             if _is_sharded_checkpoint(path):
                 shutil.rmtree(path)
 
-            converted_state, metadata = _save_state_dict(
-                state, module, path, filter, self._state_dict_type, self.world_size
-            )
+            converted_state, metadata = _get_state_dict(state, module, filter, self._state_dict_type, self.world_size)
             converted_state.update(metadata)
             if self.global_rank == 0:
                 torch.save(converted_state, path)
@@ -540,7 +538,7 @@ def load_checkpoint(
         module_key, module = list(modules.items())[0]
 
         if _is_sharded_checkpoint(path):
-            _load_state_dict(module, module_key, optimizers, path, "sharded", strict, self.world_size)
+            _set_state_dict(module, module_key, optimizers, path, "sharded", strict, self.world_size)
 
             # Load metadata (anything not a module or optimizer)
             metadata = torch.load(path / _METADATA_FILENAME)
@@ -554,7 +552,7 @@ def load_checkpoint(
             return metadata
 
         if _is_full_checkpoint(path):
-            checkpoint = _load_state_dict(module, module_key, optimizers, path, "full", strict, self.world_size)
+            checkpoint = _set_state_dict(module, module_key, optimizers, path, "full", strict, self.world_size)
             assert checkpoint is not None
 
             requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
@@ -862,10 +860,9 @@ def _distributed_checkpoint_save(state_dict: Dict[str, Any], path: Path) -> None
         save(state_dict, writer)
 
 
-def _save_state_dict(
+def _get_state_dict(
     state: Dict[str, Any],
     module: Module,
-    path: Path,
     filter: Optional[Dict[str, Callable[[str, Any], bool]]],
     state_dict_type: Literal["sharded", "full"],
     world_size: int,
@@ -910,12 +907,28 @@ def _save_state_dict(
             target_dict = metadata
         _apply_filter(key, filter or {}, converted, target_dict)
 
-    _distributed_checkpoint_save(converted_state, path)
-
     return converted_state, metadata
 
 
-def _load_state_dict(
+def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
+    if _TORCH_GREATER_EQUAL_2_3:
+        from torch.distributed.checkpoint import load
+
+        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
+        # https://github.com/pytorch/pytorch/issues/118036
+        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
+    else:  # deprecated
+        from torch.distributed.checkpoint import FileSystemReader
+
+        if _TORCH_GREATER_EQUAL_2_2:
+            from torch.distributed.checkpoint import load
+        else:
+            from torch.distributed.checkpoint import load_state_dict as load
+        reader = FileSystemReader(path=path)
+        load(module_state, reader)
+
+
+def _set_state_dict(
     module: Module,
     module_key: str,
     optimizers: Dict[str, torch.optim.Optimizer],
@@ -931,14 +944,14 @@ def _load_state_dict(
             set_optimizer_state_dict,
         )
 
-        options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False)
+        options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False, strict=strict)
         module_state = {module_key: module.state_dict()}
         _distributed_checkpoint_load(module_state, path)
-        set_model_state_dict(module, module_state, options=options)  # type: ignore[arg-type]
-        for key, optimizer in optimizers.items():
-            optimizer_state = {key: optimizer.state_dict()}
-            _distributed_checkpoint_load(optimizer_state, path)
-            set_optimizer_state_dict(module, optimizer, optim_state_dict=optimizer_state, options=options)
+        set_model_state_dict(module, module_state[module_key], options=options)  # type: ignore[arg-type]
+        for optim_key, optim in optimizers.items():
+            optimizer_state = {optim_key: optim.state_dict()}
+            _distributed_checkpoint_load(optimizer_state, path)
+            set_optimizer_state_dict(module, optim, optim_state_dict=optimizer_state[optim_key], options=options)
     else:
         if state_dict_type == "sharded":
             state_dict_ctx = _get_sharded_state_dict_context(module)
@@ -998,21 +1011,3 @@ def _load_state_dict(
                     optim.load_state_dict(optim_state_dict)
 
     return checkpoint
-
-
-def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
-    if _TORCH_GREATER_EQUAL_2_3:
-        from torch.distributed.checkpoint import load
-
-        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
-        # https://github.com/pytorch/pytorch/issues/118036
-        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
-    else:  # deprecated
-        from torch.distributed.checkpoint import FileSystemReader
-
-        if _TORCH_GREATER_EQUAL_2_2:
-            from torch.distributed.checkpoint import load
-        else:
-            from torch.distributed.checkpoint import load_state_dict as load
-        reader = FileSystemReader(path=path)
-        load(module_state, reader)
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index c37ce6c33d63e..d5ac651c60645 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -244,7 +244,7 @@ def test_fsdp_save_checkpoint_storage_options(tmp_path):
 @mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
 @mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context")
 @mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context")
-@mock.patch("lightning.fabric.strategies.fsdp._save_state_dict", return_value=({}, {}))
+@mock.patch("lightning.fabric.strategies.fsdp._get_state_dict", return_value=({}, {}))
 @mock.patch("lightning.fabric.strategies.fsdp.torch.save")
 @mock.patch("lightning.fabric.strategies.fsdp.shutil")
 def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
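
Note for reviewers (not part of the patch): the sketch below is a minimal, hypothetical illustration of the checkpoint flow this refactor moves toward, built only on the public `torch.distributed.checkpoint` ("DCP") APIs that the diff itself imports (`StateDictOptions`, `set_model_state_dict`, `set_optimizer_state_dict`, and `save`/`load` with `checkpoint_id`). The helper names `save_sharded` and `load_sharded` are made up for illustration; the code assumes torch >= 2.3, an initialized process group, and a module already wrapped with FSDP. It is not Lightning's implementation.

    # Hypothetical sketch (not Lightning's API): round-trip a sharded checkpoint with DCP.
    # Assumes torch >= 2.3, torch.distributed already initialized, and `module` wrapped with FSDP.
    from pathlib import Path
    from typing import Any, Dict

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
        set_model_state_dict,
        set_optimizer_state_dict,
    )


    def save_sharded(module: torch.nn.Module, optimizer: torch.optim.Optimizer, path: Path) -> None:
        # Gather sharded (per-rank) state dicts, then let DCP infer a writer from the path,
        # mirroring the split between `_get_state_dict` and `_distributed_checkpoint_save` in the patch.
        options = StateDictOptions(full_state_dict=False, cpu_offload=False)
        state: Dict[str, Any] = {
            "model": get_model_state_dict(module, options=options),
            "optimizer": get_optimizer_state_dict(module, optimizer, options=options),
        }
        dcp.save(state, checkpoint_id=path)


    def load_sharded(module: torch.nn.Module, optimizer: torch.optim.Optimizer, path: Path) -> None:
        # Pre-populate the container with the current state so DCP knows each tensor's layout,
        # load in place, then push the values back, as `_set_state_dict` does on the torch >= 2.3 path.
        options = StateDictOptions(full_state_dict=False, cpu_offload=False)
        state: Dict[str, Any] = {
            "model": get_model_state_dict(module, options=options),
            "optimizer": get_optimizer_state_dict(module, optimizer, options=options),
        }
        dcp.load(state, checkpoint_id=path)
        set_model_state_dict(module, state["model"], options=options)
        set_optimizer_state_dict(module, optimizer, optim_state_dict=state["optimizer"], options=options)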