Fix initialized weights resetting in Fabric.setup() when using FSDP (#19755)
awaelchli committed Apr 11, 2024
1 parent 316cc71 commit dcb91d5
Showing 3 changed files with 23 additions and 5 deletions.
src/lightning/fabric/CHANGELOG.md (3 changes: 2 additions & 1 deletion)
@@ -51,7 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped ([#19705](https://github.com/Lightning-AI/pytorch-lightning/pull/19705))
 
--
+- Fixed an issue causing weights to be reset in `Fabric.setup()` when using FSDP ([#19755](https://github.com/Lightning-AI/pytorch-lightning/pull/19755))
+
 
 
 ## [2.2.1] - 2024-03-04
src/lightning/fabric/utilities/device_dtype_mixin.py (6 changes: 2 additions & 4 deletions)
@@ -109,14 +109,12 @@ def half(self) -> Self:
 def _update_properties(
     root: torch.nn.Module, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
 ) -> None:
-    def apply_fn(module: Union[_DeviceDtypeModuleMixin, Module]) -> None:
+    for module in root.modules():
         if not isinstance(module, _DeviceDtypeModuleMixin):
-            return
+            continue
         # cannot use `module.to()` because we don't actually want to move the model in case there are multiple
         # devices types (such as partial meta parameters)
         if device is not None:
             module._device = device
         if dtype is not None:
             module._dtype = dtype
-
-    root.apply(apply_fn)
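
Why the fix drops `root.apply()`: `torch.nn.Module.apply` recurses by calling each child's own `apply`, so running `root.apply(apply_fn)` ends up invoking `FullyShardedDataParallel.apply` on every FSDP-wrapped submodule (the indirect call the regression test below guards against), which is what caused the freshly initialized weights to be reset; iterating `root.modules()` visits the same submodules without ever invoking `apply`. Below is a minimal, self-contained sketch of that dispatch difference; the `NoisyApply` wrapper is hypothetical and only stands in for FSDP, it is not part of Lightning or PyTorch.

from torch import nn


class NoisyApply(nn.Module):
    """Hypothetical wrapper standing in for FSDP: its `apply` override has side effects."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner

    def apply(self, fn):
        # FSDP's real override additionally gathers the sharded parameters before delegating;
        # here we only record that the override was reached.
        print("overridden apply() was called")
        return super().apply(fn)


root = nn.Sequential(NoisyApply(nn.Linear(2, 2)))

root.apply(lambda module: None)  # Module.apply recurses via each child's own apply() -> prints

for module in root.modules():  # plain traversal: the override is never invoked
    pass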
tests/tests_fabric/strategies/test_fsdp_integration.py (19 changes: 19 additions & 0 deletions)
@@ -668,3 +668,22 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
     model, optimizer = fabric.setup(model, optimizer)
     state = {"model": model, "optimizer": optimizer, "steps": 1}
     fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_no_call_to_apply(monkeypatch):
+    """Regression test to ensure we're not calling `FSDP.apply()` indirectly (see #19755)."""
+    monkeypatch.setattr(torch.distributed.fsdp.FullyShardedDataParallel, "apply", Mock())
+
+    fabric = Fabric(
+        accelerator="cuda",
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+        devices=2,
+    )
+    fabric.launch()
+
+    for setup_method in ("setup", "setup_module"):
+        model = BoringModel()
+        setup = getattr(fabric, setup_method)
+        model = setup(model)
+        model._forward_module.apply.assert_not_called()
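
How the regression test catches an indirect call: replacing `FullyShardedDataParallel.apply` with a `Mock` means any code path inside `setup()`/`setup_module()` that still reaches `apply` gets recorded, and `assert_not_called()` fails. The following is a minimal sketch of the same pytest pattern outside Lightning; `Greeter` and `run_without_greeting` are hypothetical names used only for illustration.

from unittest.mock import Mock


class Greeter:
    def greet(self) -> str:
        return "hello"


def run_without_greeting(greeter: Greeter) -> None:
    # The code under test is expected *not* to call `greet()`.
    _ = greeter.__class__.__name__


def test_greet_not_called(monkeypatch):
    # Swap the method for a Mock so any call is recorded.
    monkeypatch.setattr(Greeter, "greet", Mock())
    greeter = Greeter()
    run_without_greeting(greeter)
    greeter.greet.assert_not_called()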
