Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to disable the automatic formatting of checkpoint file names. #6277

Merged
merged 13 commits into from Mar 11, 2021
Merged
6 changes: 6 additions & 0 deletions CHANGELOG.md
Expand Up @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


Expand All @@ -41,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))


carmocca marked this conversation as resolved.
Show resolved Hide resolved
### Deprecated


Expand Down
30 changes: 28 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
Expand Up @@ -131,6 +131,15 @@ class ModelCheckpoint(Callback):
... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like = or /)
maxfrei750 marked this conversation as resolved.
Show resolved Hide resolved
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
... monitor='val/loss',
... dirpath='my/path/',
... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}'
... )
maxfrei750 marked this conversation as resolved.
Show resolved Hide resolved

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
Expand All @@ -156,6 +165,7 @@ def __init__(
save_weights_only: bool = False,
mode: str = "min",
period: int = 1,
auto_insert_metric_name: bool = True
):
super().__init__()
self.monitor = monitor
Expand All @@ -164,6 +174,7 @@ def __init__(
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.auto_insert_metric_name = auto_insert_metric_name
self._last_global_step_saved = -1
self.current_score = None
self.best_k_models = {}
Expand Down Expand Up @@ -356,6 +367,7 @@ def _format_checkpoint_name(
step: int,
metrics: Dict[str, Any],
prefix: str = "",
auto_insert_metric_name: bool = True
) -> str:
if not filename:
# filename is not set, use default name
Expand All @@ -367,7 +379,10 @@ def _format_checkpoint_name(
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)

if auto_insert_metric_name:
filename = filename.replace(group, name + "={" + name)
Copy link
Contributor

@talregev talregev Apr 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filename = filename.replace('/', '_')

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @talregev
Thanks for the comment.
Could you open an issue about it, describing the need for this?


if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
Expand All @@ -392,6 +407,11 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
Expand All @@ -400,7 +420,13 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
'step=0.ckpt'

"""
filename = self._format_checkpoint_name(self.filename, epoch, step, metrics)
filename = self._format_checkpoint_name(
self.filename,
epoch,
step,
metrics,
auto_insert_metric_name=self.auto_insert_metric_name)

if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))

Expand Down
9 changes: 9 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Expand Up @@ -425,6 +425,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

# auto_insert_metric_name=False
ckpt_name = ModelCheckpoint._format_checkpoint_name(
'epoch={epoch:03d}-val_acc={val/acc}',
3,
2,
{'val/acc': 0.03},
auto_insert_metric_name=False)
assert ckpt_name == 'epoch=003-val_acc=0.03'


class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = '.tpkc'
Expand Down