Avoid warning when resuming mid-epoch checkpoint and using stateful dataloader (#19475)
awaelchli committed Feb 15, 2024
1 parent 120c87f commit 0e25b1d
Showing 2 changed files with 73 additions and 5 deletions.
14 changes: 11 additions & 3 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -18,6 +18,8 @@
 from typing_extensions import override

 import lightning.pytorch as pl
+from lightning.fabric.utilities.types import _Stateful
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import loops  # import as loops to avoid circular imports
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization
@@ -152,10 +154,16 @@ def reset(self) -> None:
             trainer = self.trainer
             if trainer.num_training_batches != float("inf"):
                 expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
-                if self.global_step % expected_steps != 0:
+                loader = trainer.fit_loop._combined_loader
+                assert loader is not None
+                is_resumable_loader = all(isinstance(loader, _Stateful) for loader in loader.flattened)
+                if self.global_step % expected_steps != 0 and not is_resumable_loader:
                     rank_zero_warn(
-                        "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
-                        " results if further training is done. Consider using an end-of-epoch checkpoint"
+                        "You're resuming from a checkpoint that ended before the epoch ended and your dataloader is"
+                        " not resumable. This can cause unreliable results if further training is done."
+                        " Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing"
+                        " the `state_dict` / `load_state_dict` interface.",
+                        category=PossibleUserWarning,
                     )
         else:
             self.batch_progress.reset_on_run()
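For reference, the `_Stateful` check above passes for any loader that structurally exposes `state_dict` and `load_state_dict`; no subclassing is required. A minimal sketch, assuming `_Stateful` is a runtime-checkable protocol requiring exactly those two methods (the `ResumableRange` class below is a hypothetical illustration, not part of this commit):

# Hypothetical illustration of a loader that satisfies the `_Stateful` check.
from lightning.fabric.utilities.types import _Stateful


class ResumableRange:
    """Iterable that remembers its position so a restart can pick up mid-epoch."""

    def __init__(self, n):
        self.n = n
        self.index = 0

    def __iter__(self):
        while self.index < self.n:
            yield self.index
            self.index += 1

    def state_dict(self):
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        self.index = state_dict["index"]


# Structural check, same as the one used in `reset()` above.
assert isinstance(ResumableRange(3), _Stateful)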
64 changes: 62 additions & 2 deletions tests/tests_pytorch/loops/test_training_epoch_loop.py
@@ -15,11 +15,15 @@
 from unittest.mock import Mock, patch

 import pytest
+import torch
+from lightning.fabric.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.trainer.trainer import Trainer
+from lightning_utilities.test.warning import no_warning_call


-def test_no_val_on_train_epoch_loop_restart(tmpdir):
+def test_no_val_on_train_epoch_loop_restart(tmp_path):
     """Test that training validation loop doesn't get triggered at the beginning of a restart."""
     trainer_kwargs = {
         "max_epochs": 1,
@@ -31,7 +35,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir):
     trainer = Trainer(**trainer_kwargs)
     model = BoringModel()
     trainer.fit(model)
-    ckpt_path = str(tmpdir / "last.ckpt")
+    ckpt_path = str(tmp_path / "last.ckpt")
     trainer.save_checkpoint(ckpt_path)

     trainer_kwargs["max_epochs"] = 2
@@ -157,3 +161,59 @@ def optimizer_step(self, epoch, batch_idx, *args, **kwargs):
     model = MyModel()
     trainer.fit(model)
     assert model.last_batch_idx == 3
+
+
+def test_resume_mid_epoch_warning(tmp_path):
+    """Test that resuming from a mid-epoch checkpoint raises a warning unless the dataloader is stateful."""
+
+    class NotStatefulIterable:
+        def __init__(self):
+            self.index = 0
+
+        def __iter__(self):
+            for i in range(self.index, len(self)):
+                yield torch.ones(2, 32) * i
+
+        def __len__(self):
+            return 3
+
+    class StatefulIterable(NotStatefulIterable):
+        def state_dict(self):
+            return {"index": self.index}
+
+        def load_state_dict(self, state_dict):
+            self.index = state_dict["index"]
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "max_epochs": 1,
+        "enable_model_summary": False,
+        "enable_progress_bar": False,
+        "logger": False,
+    }
+
+    def train_and_resume(dataloader, resume_step, expected_warning):
+        # Initial training
+        checkpoint_dir = tmp_path / "checkpoints"
+        trainer = Trainer(
+            **trainer_kwargs,
+            callbacks=ModelCheckpoint(dirpath=checkpoint_dir, every_n_train_steps=1, save_top_k=-1),
+        )
+        trainer.fit(BoringModel(), dataloader)
+
+        # Resume
+        trainer = Trainer(**trainer_kwargs, enable_checkpointing=False)
+        resume_from = checkpoint_dir / f"epoch=0-step={resume_step}.ckpt"
+        warn_assert = pytest.warns if expected_warning else no_warning_call
+        with warn_assert(PossibleUserWarning, match="resuming from a checkpoint that ended before"):
+            trainer.fit(BoringModel(), dataloader, ckpt_path=resume_from)
+
+    # Resume mid-epoch, no stateful dataloader -> warning
+    train_and_resume(dataloader=NotStatefulIterable(), resume_step=1, expected_warning=True)
+
+    # Resume end-of-epoch, no stateful dataloader -> no warning
+    train_and_resume(dataloader=NotStatefulIterable(), resume_step=3, expected_warning=False)
+
+    # Resume mid-epoch, stateful dataloader -> no warning
+    train_and_resume(dataloader=StatefulIterable(), resume_step=1, expected_warning=False)
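With a loader like `StatefulIterable`, resuming mid-epoch no longer emits `PossibleUserWarning`. A rough user-facing sketch, assuming `StatefulIterable` from the test above is in scope and that a mid-epoch checkpoint already exists (the checkpoint path and step count are illustrative):

# Rough usage sketch; path and step are hypothetical.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(max_epochs=1, accelerator="cpu", logger=False, enable_checkpointing=False)
# The dataloader implements state_dict/load_state_dict, so resuming from a
# mid-epoch checkpoint does not emit PossibleUserWarning.
trainer.fit(BoringModel(), StatefulIterable(), ckpt_path="checkpoints/epoch=0-step=1.ckpt")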
