Start version suffixes at 1 (#5008)
* Rename original filepath to v0

* Clean-up

* Suggestions from code review

* Revert renaming. Start version number at 1

* Add ModelCheckpoint.STARTING_VERSION

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Add note about class attributes

* Update CHANGELOG

* Fix doc

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
carmocca and rohitgr7 committed Jan 26, 2021
1 parent dee5553 commit 9d165f6
Showing 3 changed files with 92 additions and 55 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))


- Changed `ModelCheckpoint` version suffixes to start at 1 ([#5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008))


- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


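To make the `ModelCheckpoint` entry above concrete, here is a small before/after sketch; the filenames mirror those used in the updated tests further down and the lists are illustrative only, not produced by this diff:

# Before #5008: a filename collision in dirpath produced suffixes starting at v0
old_behaviour = ["epoch=0.ckpt", "epoch=0-v0.ckpt", "epoch=0-v1.ckpt"]
# After #5008: the first duplicate is suffixed -v1, matching ModelCheckpoint.STARTING_VERSION
new_behaviour = ["epoch=0.ckpt", "epoch=0-v1.ckpt", "epoch=0-v2.ckpt"]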
53 changes: 31 additions & 22 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -80,10 +80,10 @@ class ModelCheckpoint(Callback):
the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
if ``save_top_k == -1``, all models are saved.
Please note that the monitors are checked every `period` epochs.
Please note that the monitors are checked every ``period`` epochs.
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
appended with a version count starting with ``v1``.
mode: one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
@@ -105,6 +105,17 @@ class ModelCheckpoint(Callback):
.. warning::
This argument has been deprecated in v1.1 and will be removed in v1.3
Note:
For extra customization, ModelCheckpoint includes the following attributes:
- ``CHECKPOINT_JOIN_CHAR = "-"``
- ``CHECKPOINT_NAME_LAST = "last"``
- ``FILE_EXTENSION = ".ckpt"``
- ``STARTING_VERSION = 1``
For example, you can change the default last checkpoint name by doing
``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
Example::
>>> from pytorch_lightning import Trainer
@@ -128,11 +139,13 @@ class ModelCheckpoint(Callback):
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
"""

CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1
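A brief usage sketch of the class-attribute customization described in the docstring note above; the `dirpath` and `monitor` values are placeholders, not part of this diff:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath="my/path/", monitor="val_loss")
# Class attributes can be overridden on the instance before training starts.
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"  # example from the docstring note
checkpoint_callback.STARTING_VERSION = 1  # the new default; duplicates get -v1, -v2, ...
trainer = Trainer(callbacks=[checkpoint_callback])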

def __init__(
self,
@@ -485,28 +498,24 @@ def _validate_monitor_key(self, trainer):

def _get_metric_interpolated_filepath_name(
self,
ckpt_name_metrics: Dict[str, Any],
monitor_candidates: Dict[str, Any],
epoch: int,
step: int,
del_filepath: Optional[str] = None
) -> str:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)

version_cnt = 0
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
version = self.STARTING_VERSION
while self._fs.exists(filepath) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
version_cnt += 1

filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
version += 1
return filepath
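A standalone sketch of the versioning loop above, using only the standard library: `os.path.exists` stands in for `self._fs.exists`, the `-v{n}` suffix mirrors `CHECKPOINT_JOIN_CHAR` plus the `ver` argument, and `versioned_filepath` is a hypothetical helper, not the callback's actual code path:

import os

def versioned_filepath(dirpath: str, base: str, ext: str = ".ckpt", starting_version: int = 1) -> str:
    # Try the plain name first, then base-v1, base-v2, ... until a free name is found.
    filepath = os.path.join(dirpath, f"{base}{ext}")
    version = starting_version
    while os.path.exists(filepath):
        filepath = os.path.join(dirpath, f"{base}-v{version}{ext}")
        version += 1
    return filepath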

def _monitor_candidates(self, trainer):
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics
monitor_candidates = deepcopy(trainer.logger_connector.callback_metrics)
monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
return monitor_candidates

def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@@ -517,13 +526,13 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
ckpt_name_metrics,
prefix=self.prefix
monitor_candidates,
prefix=self.prefix,
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
last_filepath = self._get_metric_interpolated_filepath_name(
ckpt_name_metrics, trainer.current_epoch, trainer.global_step
monitor_candidates, trainer.current_epoch, trainer.global_step
)

accelerator_backend = trainer.accelerator_backend
@@ -534,10 +543,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
else:
self._save_model(last_filepath, trainer, pl_module)
if (
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
):
self._del_model(self.last_model_path)
self.last_model_path = last_filepath
91 changes: 58 additions & 33 deletions tests/checkpointing/test_model_checkpoint.py
@@ -738,18 +738,20 @@ def test_val_check_interval_checkpoint_files(tmpdir):
save_top_k=-1,
monitor="val_acc",
mode="max",
verbose=True
)
trainer = Trainer(
default_root_dir=tmpdir,
val_check_interval=0.2,
max_epochs=1,
limit_train_batches=10,
callbacks=[model_checkpoint]
callbacks=[model_checkpoint],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")])
assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]]
files = {p.basename for p in tmpdir.listdir()}
assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]}


def test_current_score(tmpdir):
@@ -844,43 +846,66 @@ def __init__(self, hparams):
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


@pytest.mark.parametrize('max_epochs', [3, 4])
@pytest.mark.parametrize(
'save_top_k, expected',
[
(1, ['curr_epoch.ckpt']),
(2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
]
)
def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
"""
Test that version is added to filename if required and it already exists in dirpath.
Check that previous checkpoints are renamed to have the correct
version suffix when new trainer instances are used
"""
model_checkpoint = ModelCheckpoint(
dirpath=tmpdir,
filename='curr_epoch',
save_top_k=save_top_k,
monitor='epoch',
mode='max',
)
epochs = 2
for i in range(epochs):
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}")
trainer = Trainer(
max_epochs=epochs,
limit_train_batches=1,
limit_val_batches=1,
default_root_dir=tmpdir,
callbacks=[mc],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(BoringModel())

# check best_k_models state
expected = {"epoch=0-v1.ckpt", "epoch=1-v1.ckpt"} if i else {"epoch=0.ckpt", "epoch=1.ckpt"}
assert {Path(f).name for f in mc.best_k_models.keys()} == expected

# check created ckpts
assert set(f.basename for f in tmpdir.listdir()) == {
"epoch=0.ckpt",
"epoch=1.ckpt",
"epoch=0-v1.ckpt",
"epoch=1-v1.ckpt",
}


def test_ckpt_version_after_rerun_same_trainer(tmpdir):
"""
Check that previous checkpoints are renamed to have the correct
version suffix when the same trainer instance is used
"""
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test")
mc.STARTING_VERSION = 9
trainer = Trainer(
max_epochs=2,
limit_train_batches=1,
limit_val_batches=1,
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_epochs=max_epochs,
limit_train_batches=2,
limit_val_batches=2,
logger=None,
callbacks=[mc],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
trainer.fit(BoringModel())
trainer.max_epochs = 4
trainer.fit(BoringModel())

model = BoringModel()
trainer.fit(model)
ckpt_files = os.listdir(tmpdir)
assert set(ckpt_files) == set(expected)

epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION)
expected = {'test.ckpt', *[f"test-v{i}.ckpt" for i in ckpt_range]}
# check best_k_models state
assert {Path(f).name for f in mc.best_k_models.keys()} == expected
# check created ckpts
assert set(sorted(os.listdir(tmpdir))) == expected


def test_model_checkpoint_mode_options():
