
Commit: carmocca review
awaelchli committed Jul 7, 2023
1 parent a1b80c8 commit 9ad3e57
Showing 2 changed files with 24 additions and 0 deletions.
src/lightning/pytorch/strategies/fsdp.py (7 additions, 0 deletions)
@@ -409,6 +409,13 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         if optimizer_states is None:
             return
 
+        if len(self.optimizers) != len(optimizer_states):
+            raise RuntimeError(
+                f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
+                f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
+                " of optimizers or edit the checkpoint manually to remove states."
+            )
+
         assert isinstance(self.model, FullyShardedDataParallel)
 
         # rank0_only should be false because we need to load the optimizer state on all ranks
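
In isolation, the guard added above is just a length comparison that runs before any FSDP-specific work. A minimal, dependency-free sketch of the same logic (the helper name check_optimizer_state_count is ours for illustration, not Lightning API):

from typing import Any, List, Mapping

def check_optimizer_state_count(optimizers: List[Any], checkpoint: Mapping[str, Any]) -> None:
    """Fail fast when the checkpoint holds a different number of optimizer states."""
    optimizer_states = checkpoint.get("optimizer_states")
    if optimizer_states is None:
        return  # nothing to restore, e.g. the checkpoint was saved without optimizer state
    if len(optimizers) != len(optimizer_states):
        raise RuntimeError(
            f"You have configured {len(optimizers)} optimizers but the checkpoint contains"
            f" {len(optimizer_states)} optimizers to load."
        )

# One optimizer configured, two states in the checkpoint -> RuntimeError
check_optimizer_state_count([object()], {"optimizer_states": [{}, {}]})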
tests/tests_pytorch/strategies/test_fsdp.py (17 additions, 0 deletions)
@@ -497,6 +497,23 @@ def test_set_timeout(init_process_group_mock):
     )
 
 
+@RunIf(min_torch="1.12")
+def test_fsdp_strategy_load_optimizer_states_multiple():
+    strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
+
+    # More states than optimizers configured
+    strategy.optimizers = [Mock()]
+    checkpoint = {"optimizer_states": [Mock(), Mock()]}
+    with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"):
+        strategy.load_optimizer_state_dict(checkpoint)
+
+    # Fewer states than optimizers configured
+    strategy.optimizers = [Mock(), Mock()]
+    checkpoint = {"optimizer_states": [Mock()]}
+    with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"):
+        strategy.load_optimizer_state_dict(checkpoint)
+
+
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
 def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
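
A side note on the assertion style used in the new test: pytest.raises(match=...) applies a regular-expression search (re.search) to the string of the raised exception, so matching a distinctive substring of the message is sufficient. A self-contained sketch of that behavior:

import pytest

def boom() -> None:
    raise RuntimeError(
        "You have configured 1 optimizers but the checkpoint contains 2 optimizers to load."
    )

def test_match_is_a_regex_search() -> None:
    # `match` is searched, not fully matched, against str(exc), so a
    # substring of the error message is enough, as in the tests above.
    with pytest.raises(RuntimeError, match="checkpoint contains 2 optimizers"):
        boom()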
