diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index ad98c653007c6..9e36ee65176c8 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -18,6 +18,8 @@ from typing_extensions import override import lightning.pytorch as pl +from lightning.fabric.utilities.types import _Stateful +from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import loops # import as loops to avoid circular imports from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization @@ -152,10 +154,16 @@ def reset(self) -> None: trainer = self.trainer if trainer.num_training_batches != float("inf"): expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches) - if self.global_step % expected_steps != 0: + loader = trainer.fit_loop._combined_loader + assert loader is not None + is_resumable_loader = all(isinstance(loader, _Stateful) for loader in loader.flattened) + if self.global_step % expected_steps != 0 and not is_resumable_loader: rank_zero_warn( - "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable" - " results if further training is done. Consider using an end-of-epoch checkpoint" + "You're resuming from a checkpoint that ended before the epoch ended and your dataloader is" + " not resumable. This can cause unreliable results if further training is done." + " Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing" + " the `state_dict` / `load_state_dict` interface.", + category=PossibleUserWarning, ) else: self.batch_progress.reset_on_run() diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index 06f27ab322530..16ed3842e3a96 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -15,11 +15,15 @@ from unittest.mock import Mock, patch import pytest +import torch +from lightning.fabric.utilities.warnings import PossibleUserWarning +from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer +from lightning_utilities.test.warning import no_warning_call -def test_no_val_on_train_epoch_loop_restart(tmpdir): +def test_no_val_on_train_epoch_loop_restart(tmp_path): """Test that training validation loop doesn't get triggered at the beginning of a restart.""" trainer_kwargs = { "max_epochs": 1, @@ -31,7 +35,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir): trainer = Trainer(**trainer_kwargs) model = BoringModel() trainer.fit(model) - ckpt_path = str(tmpdir / "last.ckpt") + ckpt_path = str(tmp_path / "last.ckpt") trainer.save_checkpoint(ckpt_path) trainer_kwargs["max_epochs"] = 2 @@ -157,3 +161,59 @@ def optimizer_step(self, epoch, batch_idx, *args, **kwargs): model = MyModel() trainer.fit(model) assert model.last_batch_idx == 3 + + +def test_resume_mid_epoch_warning(tmp_path): + """Test that resuming from a mid-epoch checkpoint raises a warning unless the dataloader is stateful.""" + + class NotStatefulIterable: + def __init__(self): + self.index = 0 + + def __iter__(self): + for i in range(self.index, len(self)): + yield torch.ones(2, 32) * i + + def __len__(self): + return 3 + + class StatefulIterable(NotStatefulIterable): + def state_dict(self): + return {"index": self.index} + + def load_state_dict(self, state_dict): + self.index = state_dict["index"] + + trainer_kwargs = { + "default_root_dir": tmp_path, + "accelerator": "cpu", + "max_epochs": 1, + "enable_model_summary": False, + "enable_progress_bar": False, + "logger": False, + } + + def train_and_resume(dataloader, resume_step, expected_warning): + # Initial training + checkpoint_dir = tmp_path / "checkpoints" + trainer = Trainer( + **trainer_kwargs, + callbacks=ModelCheckpoint(dirpath=checkpoint_dir, every_n_train_steps=1, save_top_k=-1), + ) + trainer.fit(BoringModel(), dataloader) + + # Resume + trainer = Trainer(**trainer_kwargs, enable_checkpointing=False) + resume_from = checkpoint_dir / f"epoch=0-step={resume_step}.ckpt" + warn_assert = pytest.warns if expected_warning else no_warning_call + with warn_assert(PossibleUserWarning, match="resuming from a checkpoint that ended before"): + trainer.fit(BoringModel(), dataloader, ckpt_path=resume_from) + + # Resume mid-epoch, no stateful dataloader -> warning + train_and_resume(dataloader=NotStatefulIterable(), resume_step=1, expected_warning=True) + + # Resume end-of-epoch, no stateful dataloader -> no warning + train_and_resume(dataloader=NotStatefulIterable(), resume_step=3, expected_warning=False) + + # Resume mid-epoch, stateful dataloader -> no warning + train_and_resume(dataloader=StatefulIterable(), resume_step=1, expected_warning=False)