14 changes: 11 additions & 3 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -18,6 +18,8 @@
 from typing_extensions import override

 import lightning.pytorch as pl
+from lightning.fabric.utilities.types import _Stateful
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import loops  # import as loops to avoid circular imports
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization
@@ -152,10 +154,16 @@ def reset(self) -> None:
             trainer = self.trainer
             if trainer.num_training_batches != float("inf"):
                 expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
-                if self.global_step % expected_steps != 0:
+                loader = trainer.fit_loop._combined_loader
+                assert loader is not None
+                is_resumable_loader = all(isinstance(loader, _Stateful) for loader in loader.flattened)
+                if self.global_step % expected_steps != 0 and not is_resumable_loader:
                     rank_zero_warn(
-                        "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
-                        " results if further training is done. Consider using an end-of-epoch checkpoint"
+                        "You're resuming from a checkpoint that ended before the epoch ended and your dataloader is"
+                        " not resumable. This can cause unreliable results if further training is done."
+                        " Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing"
+                        " the `state_dict` / `load_state_dict` interface.",
+                        category=PossibleUserWarning,
                     )
         else:
             self.batch_progress.reset_on_run()
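Note added for context (not part of the diff above): the new check treats a run as resumable only when every dataloader in the CombinedLoader's `flattened` list satisfies the `_Stateful` protocol, i.e. implements both `state_dict` and `load_state_dict`. A minimal sketch of an iterable dataloader that would pass this check — the class name, batch shape, and length are illustrative only and mirror the `StatefulIterable` helper in the test below:

import torch


class ResumableBatches:
    """Hypothetical iterable dataloader implementing the `state_dict` / `load_state_dict` pair."""

    def __init__(self, num_batches=3):
        self.index = 0
        self.num_batches = num_batches

    def __iter__(self):
        # Continue from the last recorded position instead of restarting at batch 0.
        for i in range(self.index, self.num_batches):
            self.index = i + 1
            yield torch.ones(2, 32) * i

    def __len__(self):
        return self.num_batches

    # This pair is what the new warning asks for: with it, the loader matches the
    # `_Stateful` protocol check above and the mid-epoch resume warning is not emitted.
    def state_dict(self):
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        self.index = state_dict["index"]

Such a loader would be passed directly to `trainer.fit(model, ResumableBatches())`, as the new test below does with its own helpers.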
64 changes: 62 additions & 2 deletions tests/tests_pytorch/loops/test_training_epoch_loop.py
@@ -15,11 +15,15 @@
 from unittest.mock import Mock, patch

 import pytest
+import torch
+from lightning.fabric.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.trainer.trainer import Trainer
+from lightning_utilities.test.warning import no_warning_call


-def test_no_val_on_train_epoch_loop_restart(tmpdir):
+def test_no_val_on_train_epoch_loop_restart(tmp_path):
     """Test that training validation loop doesn't get triggered at the beginning of a restart."""
     trainer_kwargs = {
         "max_epochs": 1,
@@ -31,7 +35,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir):
     trainer = Trainer(**trainer_kwargs)
     model = BoringModel()
     trainer.fit(model)
-    ckpt_path = str(tmpdir / "last.ckpt")
+    ckpt_path = str(tmp_path / "last.ckpt")
     trainer.save_checkpoint(ckpt_path)

     trainer_kwargs["max_epochs"] = 2
@@ -157,3 +161,59 @@ def optimizer_step(self, epoch, batch_idx, *args, **kwargs):
     model = MyModel()
     trainer.fit(model)
     assert model.last_batch_idx == 3
+
+
+def test_resume_mid_epoch_warning(tmp_path):
+    """Test that resuming from a mid-epoch checkpoint raises a warning unless the dataloader is stateful."""
+
+    class NotStatefulIterable:
+        def __init__(self):
+            self.index = 0
+
+        def __iter__(self):
+            for i in range(self.index, len(self)):
+                yield torch.ones(2, 32) * i
+
+        def __len__(self):
+            return 3
+
+    class StatefulIterable(NotStatefulIterable):
+        def state_dict(self):
+            return {"index": self.index}
+
+        def load_state_dict(self, state_dict):
+            self.index = state_dict["index"]
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "max_epochs": 1,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+        "logger": False,
+    }
+
+    def train_and_resume(dataloader, resume_step, expected_warning):
+        # Initial training
+        checkpoint_dir = tmp_path / "checkpoints"
+        trainer = Trainer(
+            **trainer_kwargs,
+            callbacks=ModelCheckpoint(dirpath=checkpoint_dir, every_n_train_steps=1, save_top_k=-1),
+        )
+        trainer.fit(BoringModel(), dataloader)
+
+        # Resume
+        trainer = Trainer(**trainer_kwargs, enable_checkpointing=False)
+        resume_from = checkpoint_dir / f"epoch=0-step={resume_step}.ckpt"
+        warn_assert = pytest.warns if expected_warning else no_warning_call
+        with warn_assert(PossibleUserWarning, match="resuming from a checkpoint that ended before"):
+            trainer.fit(BoringModel(), dataloader, ckpt_path=resume_from)
+
+    # Resume mid-epoch, no stateful dataloader -> warning
+    train_and_resume(dataloader=NotStatefulIterable(), resume_step=1, expected_warning=True)
+
+    # Resume end-of-epoch, no stateful dataloader -> no warning
+    train_and_resume(dataloader=NotStatefulIterable(), resume_step=3, expected_warning=False)
+
+    # Resume mid-epoch, stateful dataloader -> no warning
+    train_and_resume(dataloader=StatefulIterable(), resume_step=1, expected_warning=False)
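Side note (added for context, not part of the diff): the `isinstance` check in `reset()` works structurally — `_Stateful` in `lightning.fabric.utilities.types` has to be a runtime-checkable `Protocol` for that call to be legal, so any object providing both methods passes regardless of its base classes. A small self-contained sketch of that behaviour, assuming only the import shown in the source-file diff above:

from lightning.fabric.utilities.types import _Stateful


class WithState:
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass


# Structural check: no inheritance from _Stateful is required, which is why the
# test's StatefulIterable passes the check while NotStatefulIterable does not.
assert isinstance(WithState(), _Stateful)
assert not isinstance(object(), _Stateful)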