every_n_val_epochs -> every_n_epochs #8383

Merged
3 commits merged on Jul 12, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -325,6 +325,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))


- Deprecated `ModelCheckpoint(every_n_val_epochs)` in favor of `ModelCheckpoint(every_n_epochs)` ([#8383](https://github.com/PyTorchLightning/pytorch-lightning/pull/8383))


- Deprecated `DDPPlugin.task_idx` in favor of `DDPPlugin.local_rank` ([#8203](https://github.com/PyTorchLightning/pytorch-lightning/pull/8203))


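For context, a minimal usage sketch of the rename recorded in the changelog entry above. This assumes a pytorch-lightning build that includes this PR's v1.4 deprecation shim; it is an illustration, not part of the diff.

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# New, preferred spelling: checkpoint every 2 epochs.
ckpt = ModelCheckpoint(every_n_epochs=2)

# Old spelling: still accepted in v1.4, but emits a deprecation warning and is
# internally forwarded to `every_n_epochs` (scheduled for removal in v1.6).
ckpt_legacy = ModelCheckpoint(every_n_val_epochs=2)
```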
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -61,7 +61,7 @@ class EarlyStopping(Callback):
stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
If this is ``False``, then the check runs at the end of the validation epoch.
If this is ``False``, then the check runs at the end of the validation.

Raises:
MisconfigurationException:
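A short, hedged usage sketch of the flag documented in the docstring above; the monitored metric name is a placeholder, not taken from this PR.

```python
from pytorch_lightning.callbacks import EarlyStopping

# Run the early-stopping check after validation rather than at the end of the
# training epoch; "val_loss" is a hypothetical logged metric name.
early_stop = EarlyStopping(monitor="val_loss", check_on_train_epoch_end=False)
```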
88 changes: 51 additions & 37 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -108,26 +108,33 @@ class ModelCheckpoint(Callback):
every_n_train_steps: Number of training steps between checkpoints.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``.
This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
train_time_interval: Checkpoints are monitored at the specified time interval.
For all practical purposes, this cannot be smaller than the amount
of time it takes to process a single training batch. This is not
guaranteed to execute at the exact time specified, but should be close.
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``.
every_n_val_epochs: Number of validation epochs between checkpoints.
If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
every_n_epochs: Number of epochs between checkpoints.
If ``every_n_epochs == None or every_n_epochs == 0``, we skip saving when the epoch ends.
To disable, set ``every_n_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
will only save checkpoints at epochs 0 < E <= N
where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
period: Interval (number of epochs) between checkpoints.

.. warning::
This argument has been deprecated in v1.3 and will be removed in v1.5.

Use ``every_n_val_epochs`` instead.
Use ``every_n_epochs`` instead.
every_n_val_epochs: Number of epochs between checkpoints.

.. warning::
This argument has been deprecated in v1.4 and will be removed in v1.6.

Use ``every_n_epochs`` instead.
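The ``every_n_epochs`` / ``check_val_every_n_epoch`` interaction documented above can be made concrete with a small sketch. This is only an illustration of the docstring's claim, using the argument names shown there; epoch indices E are 1-based.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# every_n_epochs=2 together with check_val_every_n_epoch=3 saves only at epochs E
# divisible by both values, i.e. E = 6 and E = 12 within max_epochs=12.
ckpt = ModelCheckpoint(every_n_epochs=2, save_on_train_epoch_end=False)
trainer = Trainer(max_epochs=12, check_val_every_n_epoch=3, callbacks=[ckpt])
```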


Note:
For extra customization, ModelCheckpoint includes the following attributes:
@@ -205,8 +212,9 @@ def __init__(
auto_insert_metric_name: bool = True,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
every_n_val_epochs: Optional[int] = None,
every_n_epochs: Optional[int] = None,
period: Optional[int] = None,
every_n_val_epochs: Optional[int] = None,
):
super().__init__()
self.monitor = monitor
@@ -224,9 +232,16 @@ def __init__(
self.best_model_path = ""
self.last_model_path = ""

if every_n_val_epochs is not None:
rank_zero_deprecation(
'`ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6.'
' Please use `every_n_epochs` instead.'
)
every_n_epochs = every_n_val_epochs

self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period)
self.__validate_init_configuration()
self._save_function = None
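A hedged sanity-check of the forwarding shim above, mirroring the new test added in tests/deprecated_api/test_remove_1-6.py. The private attribute access is only for illustration.

```python
import pytest
from pytorch_lightning.callbacks import ModelCheckpoint

def test_every_n_val_epochs_is_forwarded(tmpdir):
    with pytest.deprecated_call(match="use `every_n_epochs` instead"):
        ckpt = ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=3)
    # The deprecated value ends up on the new attribute.
    assert ckpt._every_n_epochs == 3
```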

@@ -274,11 +289,10 @@ def on_train_batch_end(

def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
""" Save a checkpoint at the end of the validation stage. """
skip = (
self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
)
if skip:
if (
self._should_skip_saving_checkpoint(trainer) or self._every_n_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_epochs != 0
):
return
self.save_checkpoint(trainer)
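A standalone sketch of the save condition above in plain Python, ignoring the trainer-state guard, to make the epoch arithmetic explicit.

```python
def should_save_on_validation_end(current_epoch: int, every_n_epochs: int) -> bool:
    # Mirrors the inlined check: skip when the trigger is disabled (< 1) or when
    # the just-finished epoch (0-indexed) is not a multiple of every_n_epochs.
    return every_n_epochs >= 1 and (current_epoch + 1) % every_n_epochs == 0

# With every_n_epochs=3, saving happens after epochs 2, 5, 8, ... (0-indexed).
assert [e for e in range(9) if should_save_on_validation_end(e, 3)] == [2, 5, 8]
```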

@@ -354,18 +368,16 @@ def __validate_init_configuration(self) -> None:
raise MisconfigurationException(
f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
)
if self._every_n_val_epochs < 0:
raise MisconfigurationException(
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
)
if self._every_n_epochs < 0:
raise MisconfigurationException(f'Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0')

every_n_train_steps_triggered = self._every_n_train_steps >= 1
every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
every_n_epochs_triggered = self._every_n_epochs >= 1
train_time_interval_triggered = self._train_time_interval is not None
if every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1:
if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
raise MisconfigurationException(
f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} "
f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
"should be mutually exclusive."
)
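The mutual-exclusivity rule enforced above reduces to counting active triggers. A minimal sketch under the same semantics, with the helper name chosen here for illustration only:

```python
from datetime import timedelta
from typing import Optional

def conflicting_triggers(every_n_train_steps: int, every_n_epochs: int,
                         train_time_interval: Optional[timedelta]) -> bool:
    # Sum of booleans: more than one active trigger is a misconfiguration.
    active = (
        (every_n_train_steps >= 1)
        + (every_n_epochs >= 1)
        + (train_time_interval is not None)
    )
    return active > 1

assert not conflicting_triggers(0, 3, None)                  # only epochs
assert conflicting_triggers(1, 2, None)                      # steps + epochs
assert conflicting_triggers(0, 2, timedelta(minutes=1))      # epochs + time
```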

@@ -412,47 +424,49 @@ def __init_monitor_mode(self, mode: str) -> None:
self.kth_value, self.mode = mode_dict[mode]

def __init_triggers(
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int],
train_time_interval: Optional[timedelta], period: Optional[int]
self,
every_n_train_steps: Optional[int],
every_n_epochs: Optional[int],
train_time_interval: Optional[timedelta],
period: Optional[int],
) -> None:

# Default to running once after each validation epoch if neither
# every_n_train_steps nor every_n_val_epochs is set
if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None:
every_n_val_epochs = 1
# every_n_train_steps nor every_n_epochs is set
if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None:
every_n_epochs = 1
every_n_train_steps = 0
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1")
else:
every_n_val_epochs = every_n_val_epochs or 0
every_n_epochs = every_n_epochs or 0
every_n_train_steps = every_n_train_steps or 0

self._train_time_interval: Optional[timedelta] = train_time_interval
self._every_n_val_epochs: int = every_n_val_epochs
self._every_n_epochs: int = every_n_epochs
self._every_n_train_steps: int = every_n_train_steps

# period takes precedence over every_n_val_epochs for backwards compatibility
# period takes precedence over every_n_epochs for backwards compatibility
if period is not None:
rank_zero_deprecation(
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.'
' Please use `every_n_epochs` instead.'
)
self._every_n_val_epochs = period

self._period = self._every_n_val_epochs
self._every_n_epochs = period
self._period = self._every_n_epochs

@property
def period(self) -> Optional[int]:
rank_zero_deprecation(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.'
' Please use `every_n_epochs` instead.'
)
return self._period

@period.setter
def period(self, value: Optional[int]) -> None:
rank_zero_deprecation(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.'
' Please use `every_n_epochs` instead.'
)
self._period = value
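Defaults and precedence, as implemented in `__init_triggers` above, in a hedged sketch. Private attributes are accessed only to illustrate the outcome, and the `period` path emits a deprecation warning rather than an error.

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# No trigger given: fall back to saving once per epoch.
ckpt = ModelCheckpoint()
assert ckpt._every_n_epochs == 1
assert ckpt._every_n_train_steps == 0

# Deprecated `period` still wins over `every_n_epochs` for backwards compatibility.
ckpt = ModelCheckpoint(every_n_epochs=4, period=2)
assert ckpt._every_n_epochs == 2
```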

@@ -60,6 +60,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval
max_epochs=epochs,
weights_summary=None,
val_check_interval=val_check_interval,
limit_val_batches=1,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
62 changes: 29 additions & 33 deletions tests/checkpointing/test_model_checkpoint.py
@@ -558,50 +558,48 @@ def test_none_monitor_save_last(tmpdir):
ModelCheckpoint(dirpath=tmpdir, save_last=False)


def test_invalid_every_n_val_epochs(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
def test_invalid_every_n_epochs(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=1)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2)


def test_invalid_every_n_train_steps(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
""" Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2)


def test_invalid_trigger_combination(tmpdir):
"""
Test that a MisconfigurationException is raised if more than one of
every_n_val_epochs, every_n_train_steps, and train_time_interval are enabled together.
every_n_epochs, every_n_train_steps, and train_time_interval are enabled together.
"""
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2)
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_val_epochs=2)
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_epochs=2)
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2)

# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)
ModelCheckpoint(
dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=timedelta(minutes=1)
)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=3)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=0, train_time_interval=timedelta(minutes=1))


def test_none_every_n_train_steps_val_epochs(tmpdir):
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
assert checkpoint_callback.period == 1
assert checkpoint_callback._every_n_val_epochs == 1
assert checkpoint_callback._every_n_epochs == 1
assert checkpoint_callback._every_n_train_steps == 0


@@ -659,12 +657,12 @@ def test_model_checkpoint_period(tmpdir, period: int):
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
@pytest.mark.parametrize("every_n_epochs", list(range(4)))
def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -677,22 +675,17 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
""" Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
@pytest.mark.parametrize("every_n_epochs", list(range(4)))
def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs):
""" Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. """
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
filename='{epoch}',
save_top_k=-1,
every_n_val_epochs=(2 * every_n_val_epochs),
period=every_n_val_epochs
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -705,8 +698,7 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)


@@ -719,7 +711,7 @@ def test_ckpt_every_n_train_steps(tmpdir):
epoch_length = 64
checkpoint_callback = ModelCheckpoint(
filename="{step}",
every_n_val_epochs=0,
every_n_epochs=0,
every_n_train_steps=every_n_train_steps,
dirpath=tmpdir,
save_top_k=-1,
@@ -892,6 +884,8 @@ def test_model_checkpoint_save_last_warning(
default_root_dir=tmpdir,
callbacks=[ckpt],
max_epochs=max_epochs,
limit_train_batches=1,
limit_val_batches=1,
)
with caplog.at_level(logging.INFO):
trainer.fit(model)
@@ -910,6 +904,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_epochs=num_epochs,
limit_train_batches=2,
limit_val_batches=2,
)
trainer.fit(model)

6 changes: 6 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -15,6 +15,7 @@
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
Expand Down Expand Up @@ -303,3 +304,8 @@ def test_v1_6_0_deprecated_disable_validation():
trainer = Trainer()
with pytest.deprecated_call(match="disable_validation` is deprecated in v1.4"):
_ = trainer.disable_validation


def test_v1_6_0_every_n_val_epochs():
with pytest.deprecated_call(match="use `every_n_epochs` instead"):
_ = ModelCheckpoint(every_n_val_epochs=1)
1 change: 0 additions & 1 deletion tests/loggers/test_wandb.py
@@ -213,7 +213,6 @@ def test_wandb_log_model(wandb, tmpdir):
'save_top_k': 1,
'save_weights_only': False,
'_every_n_train_steps': 0,
'_every_n_val_epochs': 1
}
}
)