Add support for reloading the last checkpoint saved by passing ckpt_path="last" #12816

Merged: 28 commits from feature/last_checkpoint into master on May 5, 2022
Changes shown are from 18 of the 28 commits.

Commits
2dc6d11 - last checkpoint WIP (Apr 19, 2022)
69c64b5 - Merge branch 'master' into feature/last_checkpoint (Apr 20, 2022)
ac24724 - check timestamps of all available callbacks (Apr 20, 2022)
6846fb6 - fix unbound error (Apr 20, 2022)
e3ac8e5 - Merge branch 'master' into feature/last_checkpoint (Apr 20, 2022)
d40ea7c - change the logic and pass old tests (Apr 20, 2022)
b4eade3 - simple test to check last model is loaded (Apr 21, 2022)
3c9475c - Merge branch 'master' into feature/last_checkpoint (Apr 21, 2022)
dc59d60 - Merge branch 'master' into feature/last_checkpoint (Apr 21, 2022)
ec62fcc - changelog + docs (Apr 21, 2022)
0923aa2 - merge (Apr 21, 2022)
37cce26 - add last checkpoint as a parameter to an existing test (Apr 21, 2022)
5e66d4b - Remove unused getattr (carmocca, Apr 21, 2022)
e4ab941 - new test (Apr 22, 2022)
bb10f0a - Merge branch 'master' into feature/last_checkpoint (Apr 22, 2022)
46c6c2d - Merge branch 'master' into feature/last_checkpoint (Apr 25, 2022)
f62982f - Apply Adrian's suggestions and split a test (Apr 25, 2022)
f111aba - merge (Apr 27, 2022)
645cae2 - Minor change (carmocca, Apr 28, 2022)
f78ef81 - Delay ft checkpoint exists check (carmocca, Apr 28, 2022)
40b78bc - Simplify and speed up test (carmocca, Apr 28, 2022)
7be3b9e - Simplify test (carmocca, Apr 28, 2022)
70b38ab - Update CHANGELOG.md (otaj, Apr 28, 2022)
901e990 - Merge branch 'master' into feature/last_checkpoint (Apr 28, 2022)
018058d - fix fault tolerant test (Apr 28, 2022)
d4f5693 - Merge branch 'master' into feature/last_checkpoint (May 2, 2022)
05df471 - Merge branch 'master' into feature/last_checkpoint (carmocca, May 4, 2022)
43b505e - Merge branch 'master' into feature/last_checkpoint (carmocca, May 4, 2022)
CHANGELOG.md (2 additions, 0 deletions)

@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Allow to load last checkpoint using `ckpt_path="last"` ([#12816](https://github.com/PyTorchLightning/pytorch-lightning/pull/12816))


- Added a friendly error message when attempting to call `Trainer.save_checkpoint()` without a model attached ([#12772](https://github.com/PyTorchLightning/pytorch-lightning/pull/12772))

docs/source/common/evaluation_intermediate.rst (5 additions, 2 deletions)

@@ -44,10 +44,13 @@ To run the test set after training completes, use this method.
# (1) load the best checkpoint automatically (lightning tracks this for you)
trainer.test(ckpt_path="best")

# (2) test using a specific checkpoint
# (2) load the last available checkpoint
trainer.test(ckpt_path="last")

# (3) test using a specific checkpoint
trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")

# (3) test with an explicit model (will use this model and not load a checkpoint)
# (4) test with an explicit model (will use this model and not load a checkpoint)
trainer.test(model)

.. warning::
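To make the documented option concrete, here is a minimal, self-contained sketch of the workflow this PR enables: resuming `fit` from the most recent checkpoint. `TinyModel`, the random dataset, and all hyperparameters are illustrative assumptions rather than part of this diff, and it assumes a Lightning version that includes this change.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    """A stand-in model; any LightningModule works the same way."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

# save_last=True makes ModelCheckpoint also maintain a `last.ckpt`,
# which is what `ckpt_path="last"` resolves to.
mc = ModelCheckpoint(save_last=True)
pl.Trainer(max_epochs=1, callbacks=[mc]).fit(TinyModel(), train_loader)

# Reusing the same callback instance keeps its `last_model_path` populated,
# so the new trainer can resolve "last" without a hard-coded path.
pl.Trainer(max_epochs=2, callbacks=[mc]).fit(TinyModel(), train_loader, ckpt_path="last")
```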
pytorch_lightning/trainer/trainer.py (55 additions, 13 deletions)

@@ -15,13 +15,15 @@
import inspect
import logging
import math
import operator
import os
import traceback
import warnings
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from copy import deepcopy
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
from weakref import proxy
@@ -1386,28 +1388,53 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint

ft_checkpoints = [cb for cb in self.callbacks if isinstance(cb, _FaultToleranceCheckpoint)]
ft_ckpt_path = None
if ft_checkpoints:
ft_ckpt_path = ft_checkpoints[0].ckpt_path
fs = get_filesystem(ft_ckpt_path)
if fs.exists(ft_ckpt_path):
return ft_ckpt_path

if model_provided and ckpt_path is None:
# use passed model to function without loading weights
return
tmp_ft_ckpt_path = ft_checkpoints[0].ckpt_path
fs = get_filesystem(tmp_ft_ckpt_path)
if fs.exists(tmp_ft_ckpt_path):
ft_ckpt_path = tmp_ft_ckpt_path

fn = self.state.fn.value

if model_connected and ckpt_path is None:
if ckpt_path is None and ft_ckpt_path is not None and self.state.fn == TrainerFn.FITTING:
ckpt_path = "last"
rank_zero_warn(
f"`.{fn}(ckpt_path=None)` was called without a model."
" The best model of the previous `fit` call will be used."
f" You can pass `{fn}(ckpt_path='best')` to use and best model"
" checkpoint and avoid this warning or"
" `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model."
" Because fault tolerance is enabled, the last model of the previous `fit` call will be used."
f" You can pass `{fn}(ckpt_path='best')` to use the best model or"
f" `{fn}(ckpt_path='last')` to use the last model."
" If you pass a value, this warning will be silenced."
)

if model_provided and ckpt_path is None:
# use passed model to function without loading weights
return

if model_connected and ckpt_path is None:
if ft_ckpt_path:
full_msg = (
f"`.{fn}(ckpt_path=None)` was called without a model."
" The best model of the previous `fit` call will be used."
" There is also a fault-tolerant checkpoint available,"
" however it is default only when fitting."
f" You can pass `{fn}(ckpt_path='best')` to use the best model or"
f" `{fn}(ckpt_path='last')` to use the last model."
" If you pass a value, this warning will be silenced."
)
else:
full_msg = (
f"`.{fn}(ckpt_path=None)` was called without a model."
" The best model of the previous `fit` call will be used."
f" You can pass `{fn}(ckpt_path='best')` to use the best model or"
f" `{fn}(ckpt_path='last')` to use the last model."
" If you pass a value, this warning will be silenced."
)

ckpt_path = "best"

rank_zero_warn(full_msg)

if ckpt_path == "best":
if len(self.checkpoint_callbacks) > 1:
rank_zero_warn(
@@ -1432,6 +1459,21 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
# load best weights
ckpt_path = self.checkpoint_callback.best_model_path

if ckpt_path == "last":
candidates = [ft.ckpt_path for ft in ft_checkpoints] + [
cb.last_model_path for cb in self.checkpoint_callbacks
]
candidates_fs = {path: get_filesystem(path) for path in candidates if path}
candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
if not candidates_ts:
rank_zero_warn(
f'.{fn}(ckpt_path="last") is set, but there is no fault tolerant'
" or last checkpoint available. No checkpoint will be loaded."
)
return

ckpt_path = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))

if not ckpt_path:
raise MisconfigurationException(
f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
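Summarizing the new `ckpt_path == "last"` branch above: collect candidate paths (any fault-tolerant checkpoint plus each checkpoint callback's `last_model_path`), drop the ones that do not exist, and pick the newest by filesystem modification time. Below is a simplified standalone sketch of that selection, using `os.path` in place of the fsspec-backed `get_filesystem` abstraction the real code goes through:

```python
import operator
import os
from functools import partial
from typing import List, Optional


def resolve_last(candidates: List[str]) -> Optional[str]:
    """Pick the most recently modified existing checkpoint, or None."""
    # Skip empty strings (callbacks that never saved) and missing files.
    timestamps = {p: os.path.getmtime(p) for p in candidates if p and os.path.exists(p)}
    if not timestamps:
        # The real code emits a rank-zero warning and loads nothing here.
        return None
    # Same selection as `max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))`.
    return max(timestamps.keys(), key=partial(operator.getitem, timestamps))
```

A call like `resolve_last([ft_cb.ckpt_path] + [cb.last_model_path for cb in checkpoint_callbacks])` (names illustrative) mirrors how the diff assembles its candidate list.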
tests/trainer/test_trainer.py (136 additions, 10 deletions)

@@ -18,6 +18,7 @@
import pickle
import sys
from argparse import Namespace
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from unittest import mock
@@ -664,6 +665,131 @@ def test_benchmark_option(benchmark_, deterministic, expected):
torch.backends.cudnn.benchmark = original_val


@pytest.mark.parametrize("ckpt_path", (None, "last"))
@pytest.mark.parametrize("fn", ("fit", "validate"))
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_checkpoint_path_input_last_fault_tolerant(tmpdir, ckpt_path, fn):
should_signal = True

class ExitGracefullyException(Exception):
pass

class TestModel(BoringModel):
def validation_step(self, batch, batch_idx):
self.log("foo", -batch_idx)
if should_signal and batch_idx == 1:
raise ExitGracefullyException
return super().validation_step(batch, batch_idx)

def training_step(self, batch, batch_idx):
if should_signal and batch_idx == 1:
raise ExitGracefullyException
return super().training_step(batch, batch_idx)

model = TestModel()
model.test_epoch_end = None
mc = ModelCheckpoint(monitor="foo")
trainer = Trainer(
max_epochs=2,
limit_val_batches=3,
enable_progress_bar=False,
default_root_dir=tmpdir,
callbacks=[mc],
)
assert trainer.ckpt_path is None
trainer_fn = getattr(trainer, fn)

from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint

ft_checkpoints = [cb for cb in trainer.callbacks if isinstance(cb, _FaultToleranceCheckpoint)]
ft_ckpt_path = ft_checkpoints[0].ckpt_path

if fn == "validate":
should_signal = False
trainer.fit(model)
should_signal = True

with pytest.raises(ExitGracefullyException):
trainer_fn(model)

should_signal = False

if ckpt_path == "last":
ctxt = nullcontext()
final_path = ft_ckpt_path

elif fn == "fit": # and ckpt_path == best
ctxt = pytest.warns(UserWarning, match="Because fault tolerance is enabled")
final_path = ft_ckpt_path
else: # ckpt_path == best and fn == validate
ctxt = pytest.warns(UserWarning, match="There is also a fault-tolerant checkpoint available")
final_path = mc.best_model_path

with ctxt:
if fn == "fit":
trainer_fn(model, ckpt_path=ckpt_path)
else:
trainer_fn(ckpt_path=ckpt_path)
assert trainer.ckpt_path == final_path


@pytest.mark.parametrize("ckpt_path", (None, "last"))
@pytest.mark.parametrize("save_last", (True, False))
@pytest.mark.parametrize("fn", ("fit", "validate"))
def test_checkpoint_path_input_last(tmpdir, ckpt_path, save_last, fn):
class TestModel(BoringModel):
def validation_step(self, batch, batch_idx):
self.log("foo", -batch_idx)
return super().validation_step(batch, batch_idx)

def training_step(self, batch, batch_idx):
return super().training_step(batch, batch_idx)

model = TestModel()
model.test_epoch_end = None
mc = ModelCheckpoint(monitor="foo", save_last=save_last)
trainer = Trainer(
max_epochs=2,
limit_val_batches=3,
enable_progress_bar=False,
default_root_dir=tmpdir,
callbacks=[mc],
)
assert trainer.ckpt_path is None
trainer_fn = getattr(trainer, fn)

if fn == "fit":
if ckpt_path is None:
ctxt = nullcontext()
else:
ctxt = pytest.warns(UserWarning, match="No checkpoint will be loaded")

with ctxt:
trainer_fn(model, ckpt_path=ckpt_path)

assert trainer.ckpt_path is None
else:
trainer.fit(model)
if ckpt_path is None:
ctxt = pytest.warns(
UserWarning,
match=r"(?!.*however it is default only when fitting)^"
r".*The best model of the previous `fit` call will be used",
)
final_path = mc.best_model_path
else:
if save_last:
ctxt = nullcontext()
final_path = mc.last_model_path
else:
ctxt = pytest.warns(UserWarning, match="No checkpoint will be loaded")
final_path = None

with ctxt:
trainer_fn(ckpt_path=ckpt_path)
assert trainer.ckpt_path == final_path


@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
@@ -693,7 +819,7 @@ def predict_step(self, batch, *_):
trainer.fit(model)

trainer_fn = getattr(trainer, fn)
assert getattr(trainer, "ckpt_path") is None
assert trainer.ckpt_path is None

if ckpt_path == "best":
# ckpt_path is 'best', meaning we load the best weights
@@ -704,20 +830,20 @@ def predict_step(self, batch, *_):
trainer_fn(model, ckpt_path=ckpt_path)
else:
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
assert trainer.ckpt_path == trainer.checkpoint_callback.best_model_path

trainer_fn(model, ckpt_path=ckpt_path)
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
assert trainer.ckpt_path == trainer.checkpoint_callback.best_model_path
elif ckpt_path is None:
# ckpt_path is None, meaning we don't load any checkpoints and use the provided model
trainer_fn(model, ckpt_path=ckpt_path)
assert getattr(trainer, "ckpt_path") is None
assert trainer.ckpt_path is None

if save_top_k > 0:
# ckpt_path is None with no model provided means load the best weights
with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"):
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
assert trainer.ckpt_path == trainer.checkpoint_callback.best_model_path
else:
# specific checkpoint, pick one from saved ones
if save_top_k == 0:
@@ -730,10 +856,10 @@
].absolute()
)
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, "ckpt_path") == ckpt_path
assert trainer.ckpt_path == ckpt_path

trainer_fn(model, ckpt_path=ckpt_path)
assert getattr(trainer, "ckpt_path") == ckpt_path
assert trainer.ckpt_path == ckpt_path


@pytest.mark.parametrize("enable_checkpointing", (False, True))
@@ -764,14 +890,14 @@ def predict_step(self, batch, *_):
trainer.fit(model)

trainer_fn = getattr(trainer, fn)
assert getattr(trainer, "ckpt_path") is None
assert trainer.ckpt_path is None

if enable_checkpointing:
trainer_fn(ckpt_path="best")
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
assert trainer.ckpt_path == trainer.checkpoint_callback.best_model_path

trainer_fn(model, ckpt_path="best")
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
assert trainer.ckpt_path == trainer.checkpoint_callback.best_model_path
else:
with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
trainer_fn(ckpt_path="best")
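One pattern in the tests above is worth calling out: the parametrization chooses either `nullcontext()` (no warning expected) or `pytest.warns(...)` (a specific warning expected), and the call under test then runs inside whichever context was picked. A minimal standalone sketch of that pattern, with a made-up test name and warning text for illustration:

```python
import warnings
from contextlib import nullcontext

import pytest


@pytest.mark.parametrize("should_warn", (False, True))
def test_warning_expectation_pattern(should_warn):
    # Choose the expectation up front, then run the code once under it.
    ctxt = pytest.warns(UserWarning, match="boom") if should_warn else nullcontext()
    with ctxt:
        if should_warn:
            warnings.warn("boom", UserWarning)
```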