Fix trainer.save_checkpoint after trainer.test with FSDP (#18992)
(cherry picked from commit 3acea8d)
awaelchli authored and lantiga committed Nov 15, 2023
1 parent deddb0a commit 97ddcb1
Showing 4 changed files with 17 additions and 4 deletions.
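
For context, here is a minimal sketch of the call sequence this commit fixes: running trainer.test after trainer.fit with the FSDP strategy and then saving a checkpoint. The toy module, dataset, and checkpoint path are illustrative assumptions and not part of the commit; the two-device, strategy="fsdp" setup mirrors the updated test_fsdp_strategy_checkpoint further down.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    import lightning.pytorch as pl


    class ToyModule(pl.LightningModule):  # hypothetical stand-in for a real model
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

        def test_step(self, batch, batch_idx):
            x, y = batch
            self.log("test_loss", nn.functional.cross_entropy(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
        trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_epochs=1)
        model = ToyModule()
        trainer.fit(model, data)
        # test() sets the model up again for evaluation, re-wrapping it with FSDP
        trainer.test(model, data)
        # Before this fix, the optimizers kept from fit() still referenced the old wrapping,
        # so this save could fail; the fix discards them when setting up for evaluation.
        trainer.save_checkpoint("after-test.ckpt")
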
1 change: 0 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -16,7 +16,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

--
 
 
 ## [2.1.1] - 2023-11-06
 
 ### Fixed
5 changes: 5 additions & 0 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -329,6 +329,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_precision_plugin()
 
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        # If we're setting up for evaluation after fitting, we need to discard the optimizers
+        # since we're rewrapping the model, otherwise optimizer param references are no longer valid
+        # and subsequent checkpoint saving can fail
+        self._reset_optimizers_and_schedulers()
+
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
 
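
To illustrate the reasoning in the new comment above (why optimizer parameter references go stale once the model is wrapped again), here is a minimal plain-PyTorch sketch; it is an analogy added for this write-up, not code from the commit and not FSDP-specific:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Re-wrapping replaces parameters with new tensor objects, loosely analogous to
    # FSDP flattening and sharding the module's parameters again during setup.
    model.weight = nn.Parameter(model.weight.detach().clone())

    # The optimizer still holds the old parameter objects, so anything that maps its
    # state back onto the model's current parameters (as checkpoint saving must) can break.
    current = {id(p) for p in model.parameters()}
    stale = [p for g in optimizer.param_groups for p in g["params"] if id(p) not in current]
    print(f"{len(stale)} stale optimizer parameter reference(s)")  # prints 1
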
5 changes: 5 additions & 0 deletions src/lightning/pytorch/strategies/strategy.py
@@ -575,6 +575,11 @@ def on_exception(self, exception: BaseException) -> None:
"""Called when the trainer execution is interrupted by an exception."""
pass

def _reset_optimizers_and_schedulers(self) -> None:
self._optimizers = []
self._lightning_optimizers = []
self.lr_scheduler_configs = []

def __getstate__(self) -> Dict:
# `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
state = dict(vars(self)) # copy
10 changes: 7 additions & 3 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -173,9 +173,13 @@ def _assert_layer_fsdp_instance(self) -> None:

 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.fit(model)
+    trainer.test(model)
+
     model_path = trainer.strategy.broadcast(model_path)
-    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
+    model_path = Path(model_path if model_path else trainer.checkpoint_callback.last_model_path)
 
+    # Save another checkpoint after testing, without optimizer states
+    trainer.save_checkpoint(model_path.with_name("after-test"))
     trainer.save_checkpoint(model_path, weights_only=True)
 
     _assert_save_equality(trainer, model_path, cls=model.__class__)
@@ -270,13 +274,13 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
 
-@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 def test_fsdp_strategy_checkpoint(tmpdir, precision):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
     model = TestFSDPModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp", precision=precision, max_epochs=1
+        default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="fsdp", precision=precision, max_epochs=1
     )
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 
