diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 0285e3ea89d7b..9274696854533 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -409,6 +409,13 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         if optimizer_states is None:
             return
 
+        if len(self.optimizers) != len(optimizer_states):
+            raise RuntimeError(
+                f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
+                f" {len(optimizer_states)} optimizers to load. Please resume training with the same number"
+                " of optimizers or edit the checkpoint manually to remove states."
+            )
+
         assert isinstance(self.model, FullyShardedDataParallel)
 
         # rank0_only should be false because we need to load the optimizer state on all ranks
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 8e8e049427a81..839b69b150bcc 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -497,6 +497,23 @@ def test_set_timeout(init_process_group_mock):
     )
 
 
+@RunIf(min_torch="1.12")
+def test_fsdp_strategy_load_optimizer_states_multiple():
+    strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
+
+    # More states than optimizers configured
+    strategy.optimizers = [Mock()]
+    checkpoint = {"optimizer_states": [Mock(), Mock()]}
+    with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"):
+        strategy.load_optimizer_state_dict(checkpoint)
+
+    # Fewer states than optimizers configured
+    strategy.optimizers = [Mock(), Mock()]
+    checkpoint = {"optimizer_states": [Mock()]}
+    with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"):
+        strategy.load_optimizer_state_dict(checkpoint)
+
+
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
 def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):