Start version suffixes at 1 #5008

Merged: 31 commits, Jan 26, 2021
Commits (31; diff below shows changes from 9 commits):

180cd54  Rename original filepath to v0 (carmocca, Dec 7, 2020)
8224a30  Clean-up (carmocca, Dec 7, 2020)
9486c03  Suggestions from code review (carmocca, Dec 10, 2020)
96780c6  Revert renaming. Start version number at 1 (carmocca, Dec 10, 2020)
6b1a475  Merge branch 'master' into 5000 (carmocca, Dec 10, 2020)
75586d3  Merge branch 'master' into 5000 (carmocca, Dec 11, 2020)
2a17a02  Merge branch 'master' into 5000 (carmocca, Dec 11, 2020)
e721483  Merge remote-tracking branch 'upstream/release/1.2-dev' into 5000 (carmocca, Jan 15, 2021)
df52229  Add ModelCheckpoint.STARTING_VERSION (carmocca, Jan 15, 2021)
365c549  Apply suggestions from code review (carmocca, Jan 15, 2021)
662351c  Add note about class attributes (carmocca, Jan 15, 2021)
3e6f535  Merge remote-tracking branch 'upstream/release/1.2-dev' into 5000 (carmocca, Jan 15, 2021)
7f70313  Update CHANGELOG (carmocca, Jan 15, 2021)
23516ca  Fix doc (carmocca, Jan 15, 2021)
7cfac05  Merge branch 'release/1.2-dev' into 5000 (rohitgr7, Jan 21, 2021)
44216da  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 25, 2021)
62abf1c  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 25, 2021)
f6b9a9a  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 25, 2021)
18f3a35  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 25, 2021)
84eccd1  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 25, 2021)
94a7109  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 25, 2021)
f03d5f7  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
74473d4  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
6b0cd94  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
be9c55c  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
e5d5530  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
7ac4bc2  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
d5c9ed2  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
2a7c0d7  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
eff2e81  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
b60982a  Merge branch 'release/1.2-dev' into 5000 (mergify[bot], Jan 26, 2021)
39 changes: 18 additions & 21 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -83,7 +83,7 @@ class ModelCheckpoint(Callback):
Please note that the monitors are checked every `period` epochs.
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
appended with a version count starting with `v1`.
mode: one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
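The docstring above now promises version suffixes starting at v1. For illustration only (not part of the diff), a minimal standalone sketch of that naming scheme, mirroring the suffix loop changed further down in this file; `_versioned_name` and the in-memory `existing` set are hypothetical stand-ins for ModelCheckpoint's real filesystem check, and the default "-" join character and ".ckpt" extension are assumed.

def _versioned_name(base: str, existing: set, starting_version: int = 1) -> str:
    # The first candidate carries no suffix; name clashes now get "-v1", "-v2", ...
    # (previously the count started at "-v0").
    name = f"{base}.ckpt"
    version = starting_version
    while name in existing:
        name = f"{base}-v{version}.ckpt"
        version += 1
    return name

existing = {"epoch=0.ckpt", "epoch=0-v1.ckpt"}
assert _versioned_name("epoch=0", existing) == "epoch=0-v2.ckpt"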
@@ -133,6 +133,7 @@ class ModelCheckpoint(Callback):
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1

def __init__(
self,
@@ -485,28 +486,24 @@ def _validate_monitor_key(self, trainer):

def _get_metric_interpolated_filepath_name(
self,
ckpt_name_metrics: Dict[str, Any],
monitor_candidates: Dict[str, Any],
epoch: int,
step: int,
del_filepath: Optional[str] = None
) -> str:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)

version_cnt = 0
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
version = self.STARTING_VERSION
while self._fs.exists(filepath) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
version_cnt += 1

filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
version += 1
return filepath

def _monitor_candidates(self, trainer):
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
Comment on lines -503 to -505 (Member): do we see somewhere that these are already contained?
Reply (carmocca, Contributor Author): @tchaton told me a long time ago (don't remember where). Can you comment?

ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics
monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics)
monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
return monitor_candidates

def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
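Regarding the resolved thread above: a rough sketch of the `_monitor_candidates` simplification, assuming (as the author states) that `callback_metrics` already aggregates the logged and progress-bar metrics; the metric names and values below are made up.

# Previously three dicts were merged; now callback_metrics alone is copied
# and step/epoch are layered on top.
callback_metrics = {"val_loss": 0.42, "val_acc": 0.91}  # hypothetical values
monitor_candidates = dict(callback_metrics)
monitor_candidates.update(step=123, epoch=4)
assert monitor_candidates == {"val_loss": 0.42, "val_acc": 0.91, "step": 123, "epoch": 4}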
@@ -517,13 +514,13 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
ckpt_name_metrics,
prefix=self.prefix
monitor_candidates,
prefix=self.prefix,
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
last_filepath = self._get_metric_interpolated_filepath_name(
ckpt_name_metrics, trainer.current_epoch, trainer.global_step
monitor_candidates, trainer.current_epoch, trainer.global_step
)

accelerator_backend = trainer.accelerator_backend
Expand All @@ -534,10 +531,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
else:
self._save_model(last_filepath, trainer, pl_module)
if (
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
):
self._del_model(self.last_model_path)
self.last_model_path = last_filepath
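Because STARTING_VERSION is introduced as a plain class attribute (see the "Add note about class attributes" commit), the suffix count's starting point can be adjusted per instance or per subclass, much like the test further below does with `mc.STARTING_VERSION = 9`. A hedged sketch, assuming the attribute keeps this name and semantics:

from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch only: start suffixes at v5 for this instance; the first name clash would
# then be written as e.g. "epoch=0-v5.ckpt" instead of "epoch=0-v1.ckpt".
mc = ModelCheckpoint(dirpath="checkpoints/", filename="{epoch}", monitor="epoch", save_top_k=-1)
mc.STARTING_VERSION = 5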
91 changes: 58 additions & 33 deletions tests/checkpointing/test_model_checkpoint.py
@@ -804,18 +804,20 @@ def test_val_check_interval_checkpoint_files(tmpdir):
save_top_k=-1,
monitor="val_acc",
mode="max",
verbose=True
)
trainer = Trainer(
default_root_dir=tmpdir,
val_check_interval=0.2,
max_epochs=1,
limit_train_batches=10,
callbacks=[model_checkpoint]
callbacks=[model_checkpoint],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")])
assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]]
files = {p.basename for p in tmpdir.listdir()}
assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]}


def test_current_score(tmpdir):
@@ -910,43 +912,66 @@ def __init__(self, hparams):
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


@pytest.mark.parametrize('max_epochs', [3, 4])
@pytest.mark.parametrize(
'save_top_k, expected',
[
(1, ['curr_epoch.ckpt']),
(2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
]
)
def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
"""
Test that version is added to filename if required and it already exists in dirpath.
Check that previous checkpoints are renamed to have the correct
version suffix when new trainer instances are used
"""
model_checkpoint = ModelCheckpoint(
dirpath=tmpdir,
filename='curr_epoch',
save_top_k=save_top_k,
monitor='epoch',
mode='max',
)
epochs = 2
for i in range(epochs):
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}")
trainer = Trainer(
max_epochs=epochs,
limit_train_batches=1,
limit_val_batches=1,
default_root_dir=tmpdir,
callbacks=[mc],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(BoringModel())

# check best_k_models state
expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} if i else {"epoch=0.ckpt", "epoch=1.ckpt"}
assert {Path(f).name for f in mc.best_k_models.keys()} == expected

# check created ckpts
assert set(f.basename for f in tmpdir.listdir()) == {
"epoch=0.ckpt",
"epoch=1.ckpt",
"epoch=0-v1.ckpt",
"epoch=1-v1.ckpt",
}


def test_ckpt_version_after_rerun_same_trainer(tmpdir):
"""
Check that previous checkpoints are renamed to have the correct
version suffix when the same trainer instance is used
"""
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test")
mc.STARTING_VERSION = 9
trainer = Trainer(
max_epochs=2,
limit_train_batches=1,
limit_val_batches=1,
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_epochs=max_epochs,
limit_train_batches=2,
limit_val_batches=2,
logger=None,
callbacks=[mc],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(BoringModel())
trainer.max_epochs = 4
trainer.fit(BoringModel())

model = BoringModel()
trainer.fit(model)
ckpt_files = os.listdir(tmpdir)
assert set(ckpt_files) == set(expected)

epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION)
expected = {'test.ckpt', *[f"test-v{i}.ckpt" for i in ckpt_range]}
# check best_k_models state
assert {Path(f).name for f in mc.best_k_models.keys()} == expected
# check created ckpts
assert set(sorted(os.listdir(tmpdir))) == expected


def test_model_checkpoint_mode_options():