Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate on_batch_start/on_batch_end callback hooks #11577

Merged
merged 11 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))


- Deprecated `on_batch_start` and `on_batch_end` callback hooks in favor of `on_train_batch_start` and `on_train_batch_end` ([#11577](https://github.com/PyTorchLightning/pytorch-lightning/pull/11577))


- Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))


Expand Down
4 changes: 2 additions & 2 deletions docs/source/advanced/profiler.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ PyTorch Lightning supports profiling standard actions in the training loop out o

- on_epoch_start
- on_epoch_end
- on_batch_start
- on_train_batch_start
- model_forward
- model_backward
- on_after_backward
- optimizer_step
- on_batch_end
- on_train_batch_end
- training_step_end
- on_training_end
- etc...
Expand Down
12 changes: 0 additions & 12 deletions docs/source/extensions/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -342,18 +342,6 @@ on_epoch_end
.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end
:noindex:

on_batch_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_start
:noindex:

on_batch_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.on_batch_end
:noindex:

on_validation_batch_start
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
16 changes: 14 additions & 2 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,22 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
"""Called when either of train/val/test epoch ends."""

def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the training batch begins."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_train_batch_start`` instead.

Called when the training batch begins.
"""

def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the training batch ends."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_train_batch_end`` instead.

Called when the training batch ends.
"""

def on_validation_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,7 @@ def tbptt_split_batch(self, batch, split_size):

Note:
Called in the training loop after
:meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
:meth:`~pytorch_lightning.callbacks.base.Callback.on_train_batch_start`
if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
Each returned batch split is passed separately to :meth:`training_step`.
"""
Expand Down
9 changes: 3 additions & 6 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,9 @@ def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
self.trainer.logger_connector.on_batch_start(**kwargs)

kwargs.setdefault("dataloader_idx", 0) # TODO: the argument should be keyword for these
if self.trainer.testing:
self.trainer._call_callback_hooks("on_test_batch_start", *kwargs.values())
self.trainer._call_lightning_module_hook("on_test_batch_start", *kwargs.values())
else:
self.trainer._call_callback_hooks("on_validation_batch_start", *kwargs.values())
self.trainer._call_lightning_module_hook("on_validation_batch_start", *kwargs.values())
hook_name = "on_test_batch_start" if self.trainer.testing else "on_validation_batch_start"
self.trainer._call_callback_hooks(hook_name, *kwargs.values())
self.trainer._call_lightning_module_hook(hook_name, *kwargs.values())

def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
"""The ``on_{validation/test}_batch_end`` hook.
Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
_check_on_init_start_end(trainer)
# TODO: Delete _check_on_hpc_hooks in v1.8
_check_on_hpc_hooks(model)
# TODO: Delete on_batch_start/on_batch_end hooks in v1.8
_check_on_batch_start_end(trainer, model)
# TODO: Remove this in v1.8
_check_on_configure_sharded_model(trainer)

Expand Down Expand Up @@ -326,6 +328,19 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
)


# TODO: Remove on_batch_start/on_batch_end hooks in v1.8
def _check_on_batch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
hooks = (["on_batch_start", "on_train_batch_start"], ["on_batch_end", "on_train_batch_end"])

for hook, alternative_hook in hooks:
for callback in trainer.callbacks:
if is_overridden(method_name=hook, instance=callback):
rank_zero_deprecation(
f"The `Callback.{hook}` hook was deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
)


def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None:
for callback in trainer.callbacks:
if is_overridden(method_name="on_configure_sharded_model", instance=callback):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def __init__(
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar = None

def on_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
"""Called before each training batch, logs the lr that will be used."""
if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return
Expand Down
31 changes: 31 additions & 0 deletions tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,37 @@ def test_v1_8_0_deprecated_lightning_optimizers():
assert trainer.lightning_optimizers == {}


def test_v1_8_0_remove_on_batch_start_end(tmpdir):
class TestCallback(Callback):
def on_batch_start(self, *args, **kwargs):
print("on_batch_start")

model = BoringModel()
trainer = Trainer(
callbacks=[TestCallback()],
fast_dev_run=True,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `Callback.on_batch_start` hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.fit(model)

class TestCallback(Callback):
def on_batch_end(self, *args, **kwargs):
print("on_batch_end")

trainer = Trainer(
callbacks=[TestCallback()],
fast_dev_run=True,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `Callback.on_batch_end` hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.fit(model)


def test_v1_8_0_on_configure_sharded_model(tmpdir):
class TestCallback(Callback):
def on_configure_sharded_model(self, trainer, model):
Expand Down
24 changes: 4 additions & 20 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,6 @@ def on_train_epoch_start(self, _, pl_module):
pl_module, "on_train_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices
)

def on_batch_start(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_batch_end(self, _, pl_module):
self.make_logging(
pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_train_batch_start(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
Expand Down Expand Up @@ -347,8 +337,6 @@ def training_step(self, batch, batch_idx):
"on_train_epoch_start": 1,
"on_train_batch_start": 2,
"on_train_batch_end": 2,
"on_batch_start": 2,
"on_batch_end": 2,
"on_train_epoch_end": 1,
"on_epoch_end": 1,
}
Expand Down Expand Up @@ -530,14 +518,11 @@ def on_train_epoch_start(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.log("on_train_batch_end", 3)

def on_batch_end(self, trainer, pl_module):
self.log("on_batch_end", 4)

def on_epoch_end(self, trainer, pl_module):
self.log("on_epoch_end", 5)
self.log("on_epoch_end", 4)

def on_train_epoch_end(self, trainer, pl_module):
self.log("on_train_epoch_end", 6)
self.log("on_train_epoch_end", 5)

model = BoringModel()
trainer = Trainer(
Expand All @@ -554,9 +539,8 @@ def on_train_epoch_end(self, trainer, pl_module):
"on_train_start": 1,
"on_train_epoch_start": 2,
"on_train_batch_end": 3,
"on_batch_end": 4,
"on_epoch_end": 5,
"on_train_epoch_end": 6,
"on_epoch_end": 4,
"on_train_epoch_end": 5,
}
assert trainer.callback_metrics == expected

Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
model = BoringModel()

class InterruptCallback(Callback):
def on_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
raise KeyboardInterrupt

trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmpdir, **extra_params)
Expand Down