Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to disable the automatic formatting of checkpoint file names. #6277

Merged
merged 13 commits into from Mar 11, 2021
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


Expand Down
31 changes: 29 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
Expand Up @@ -131,6 +131,16 @@ class ModelCheckpoint(Callback):
... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
... monitor='val/loss',
... dirpath='my/path/',
... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
... auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
Expand All @@ -156,6 +166,7 @@ def __init__(
save_weights_only: bool = False,
mode: str = "min",
period: int = 1,
auto_insert_metric_name: bool = True
):
super().__init__()
self.monitor = monitor
Expand All @@ -164,6 +175,7 @@ def __init__(
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.auto_insert_metric_name = auto_insert_metric_name
self._last_global_step_saved = -1
self.current_score = None
self.best_k_models = {}
Expand Down Expand Up @@ -356,6 +368,7 @@ def _format_checkpoint_name(
step: int,
metrics: Dict[str, Any],
prefix: str = "",
auto_insert_metric_name: bool = True
) -> str:
if not filename:
# filename is not set, use default name
Expand All @@ -367,7 +380,10 @@ def _format_checkpoint_name(
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)

if auto_insert_metric_name:
filename = filename.replace(group, name + "={" + name)
Copy link
Contributor

@talregev talregev Apr 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filename = filename.replace('/', '_')

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @talregev
Thanks for the comment.
Could you open an issue about it, describing the need for this?


if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
Expand All @@ -392,6 +408,11 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
Expand All @@ -400,7 +421,13 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
'step=0.ckpt'

"""
filename = self._format_checkpoint_name(self.filename, epoch, step, metrics)
filename = self._format_checkpoint_name(
self.filename,
epoch,
step,
metrics,
auto_insert_metric_name=self.auto_insert_metric_name)

if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))

Expand Down
9 changes: 9 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Expand Up @@ -425,6 +425,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

# auto_insert_metric_name=False
ckpt_name = ModelCheckpoint._format_checkpoint_name(
'epoch={epoch:03d}-val_acc={val/acc}',
3,
2,
{'val/acc': 0.03},
auto_insert_metric_name=False)
assert ckpt_name == 'epoch=003-val_acc=0.03'


class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = '.tpkc'
Expand Down