Commit
Allow user to disable the automatic formatting of checkpoint file names. (#6277)

* cleaning SWA (#6259)

* rename

* if

* test

* chlog

* Remove opt from manual_backward in docs (#6267)

* switch agents pool (#6270)

* Allow user to disable the automatic formatting of checkpoint file names.

* Added changelog entry.

* Made flake8 happy.

* Applied review suggestion: quotes for special characters in docstring

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Fixed example in docstring.

* Fixed syntax error in docstring.

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
5 people committed Mar 11, 2021
1 parent f4cc745 commit 2ecda5d
Showing 3 changed files with 41 additions and 2 deletions.
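The new `auto_insert_metric_name` flag lets the user keep full control over the checkpoint file name. Here is a minimal usage sketch adapted from the docstring example added in this commit (the monitored metric, paths, and format string are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Format the checkpoint name yourself instead of letting ModelCheckpoint
# insert "metric_name=" before every value; this avoids '=' and '/' in file
# names, which can confuse TensorBoard or Neptune.
# Saves files like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val/loss',
    dirpath='my/path/',
    filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
    auto_insert_metric_name=False,
)

trainer = Trainer(callbacks=[checkpoint_callback])
```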
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


31 changes: 29 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -131,6 +131,16 @@ class ModelCheckpoint(Callback):
... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )
# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
... monitor='val/loss',
... dirpath='my/path/',
... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
... auto_insert_metric_name=False
... )
# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
@@ -156,6 +166,7 @@ def __init__(
save_weights_only: bool = False,
mode: str = "min",
period: int = 1,
auto_insert_metric_name: bool = True
):
super().__init__()
self.monitor = monitor
@@ -164,6 +175,7 @@ def __init__(
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.auto_insert_metric_name = auto_insert_metric_name
self._last_global_step_saved = -1
self.current_score = None
self.best_k_models = {}
@@ -356,6 +368,7 @@ def _format_checkpoint_name(
step: int,
metrics: Dict[str, Any],
prefix: str = "",
auto_insert_metric_name: bool = True
) -> str:
if not filename:
# filename is not set, use default name
@@ -367,7 +380,10 @@
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)

if auto_insert_metric_name:
filename = filename.replace(group, name + "={" + name)

if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
@@ -392,6 +408,11 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
@@ -400,7 +421,13 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
'step=0.ckpt'
"""
filename = self._format_checkpoint_name(self.filename, epoch, step, metrics)
filename = self._format_checkpoint_name(
self.filename,
epoch,
step,
metrics,
auto_insert_metric_name=self.auto_insert_metric_name)

if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))

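As a quick check of what the flag changes, a small sketch based on the doctests above (values are illustrative; in this version of the API, `format_checkpoint_name` takes `epoch`, `step`, and a metrics dict):

```python
import os

from pytorch_lightning.callbacks import ModelCheckpoint

metrics = {'val_loss': 0.123456}

# Default: the metric name is inserted automatically before each value.
auto = ModelCheckpoint(dirpath='my/path', filename='{epoch}-{val_loss:.2f}')
print(os.path.basename(auto.format_checkpoint_name(2, 3, metrics=metrics)))
# -> 'epoch=2-val_loss=0.12.ckpt'

# Disabled: the filename template is rendered verbatim.
manual = ModelCheckpoint(
    dirpath='my/path',
    filename='epoch={epoch}-validation_loss={val_loss:.2f}',
    auto_insert_metric_name=False,
)
print(os.path.basename(manual.format_checkpoint_name(2, 3, metrics=metrics)))
# -> 'epoch=2-validation_loss=0.12.ckpt'
```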
9 changes: 9 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -432,6 +432,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

# auto_insert_metric_name=False
ckpt_name = ModelCheckpoint._format_checkpoint_name(
'epoch={epoch:03d}-val_acc={val/acc}',
3,
2,
{'val/acc': 0.03},
auto_insert_metric_name=False)
assert ckpt_name == 'epoch=003-val_acc=0.03'


class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = '.tpkc'
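The motivating case from the docstring is a slash-containing metric: with the default behaviour the inserted name drags '=' and '/' into the checkpoint name, while disabling the insertion keeps the name clean. A sketch using the internal helper exercised by the new test (template and value are illustrative):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

metrics = {'val/loss': 0.321}

# Default behaviour: '{val/loss:.2f}' becomes 'val/loss=0.32', so '/' and '='
# end up in the checkpoint name (and the '/' would create a subdirectory).
print(ModelCheckpoint._format_checkpoint_name('sample-{val/loss:.2f}', 2, 3, metrics))
# -> 'sample-val/loss=0.32'

# With auto_insert_metric_name=False the template is used as written.
print(ModelCheckpoint._format_checkpoint_name(
    'sample-val_loss{val/loss:.2f}', 2, 3, metrics,
    auto_insert_metric_name=False,
))
# -> 'sample-val_loss0.32'
```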
