3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))


- Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -88,6 +88,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
hooks."""
void(*args, **kwargs)

# hook
self._on_evaluation_model_eval()
self.trainer.lightning_module.zero_grad()
@@ -199,7 +200,7 @@ def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
self.trainer.call_hook("on_validation_end", *args, **kwargs)

# reset any `torchmetrics.Metric` and the logger connector state
self.trainer.logger_connector.reset(metrics=True)
self.trainer.logger_connector.reset_results(metrics=True)

def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -78,7 +78,7 @@ def reset(self) -> None:
self.epoch_batch_indices = []

def on_run_start(self) -> None:
"""Calls ``on_predict_start`` hook."""
"""Calls ``_on_predict_start`` hook."""
self._on_predict_start()

def advance(self, *args: Any, **kwargs: Any) -> None:
@@ -286,14 +286,15 @@ def should_reset_tensors(self, fx: str) -> bool:
is_first_batch = bool(self._batch_idx) + self._split_idx == 0
return is_different_fx and is_first_batch

def reset(self, metrics: Optional[bool] = None) -> None:
if self.trainer.sanity_checking:
# reset metrics
self._progress_bar_metrics = {}
self._logged_metrics = {}
self._callback_metrics = {}
assert self.trainer._results is not None
self.trainer._results.reset(metrics=metrics)
def reset_metrics(self) -> None:
self._progress_bar_metrics = {}
self._logged_metrics = {}
self._callback_metrics = {}

def reset_results(self, metrics: Optional[bool] = None) -> None:
if self.trainer._results is not None:
self.trainer._results.reset(metrics=metrics)

self._batch_idx = None
self._split_idx = None
self._current_fx = None
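This hunk splits the old `LoggerConnector.reset(metrics=...)` into two methods: `reset_metrics()`, which empties the progress-bar, logged and callback metric dictionaries, and `reset_results(metrics=...)`, which resets the trainer's `_results` collection and the batch/fx bookkeeping. A minimal sketch of how a call site combines the two after this change, mirroring the `trainer.py` hunks below (the helper name is illustrative, not part of the PR):

```python
# Illustrative only: where code previously called `logger_connector.reset(metrics=True)`,
# it now calls both of the new methods explicitly.
def _reset_logger_connector(trainer) -> None:
    # drop buffered results (including any torchmetrics.Metric state held there)
    trainer.logger_connector.reset_results(metrics=True)
    # clear the progress-bar, logged and callback metric dictionaries
    trainer.logger_connector.reset_metrics()
```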
14 changes: 12 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -1022,6 +1022,11 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
# ----------------------------
# TRAIN
# ----------------------------

# reset logger connector
self.logger_connector.reset_results()
self.logger_connector.reset_metrics()

# hook
if self.state.fn == TrainerFn.FITTING:
self.call_hook("on_fit_start")
@@ -1206,6 +1211,10 @@ def _run_sanity_check(self, ref_model):
stage = self.state.stage
self.sanity_checking = True

# reset logger connector
self.logger_connector.reset_results()
self.logger_connector.reset_metrics()

self.call_hook("on_sanity_check_start")

# reload dataloaders
@@ -1217,8 +1226,9 @@

self.call_hook("on_sanity_check_end")

# reset validation metrics
self.logger_connector.reset()
# reset logger connector
self.logger_connector.reset_results()
self.logger_connector.reset_metrics()

# reset the seed to what it was before sanity check
# prevents sanity check to affect random sampling in training
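Taken together, these hunks make the trainer clear the logger connector when `_run` starts (i.e. whenever a new fit/validate/test/predict task begins) and again around the sanity check. A rough sketch of the user-visible effect, using a hypothetical `TinyModel` and metric names that are not part of this PR:

```python
# Illustrative only: metrics logged during one task no longer leak into the next.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # log at epoch level so the value shows up in `trainer.callback_metrics`
        self.log("train/metric", loss, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        self.log("test/metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loader():
    return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)


model = TinyModel()
trainer = Trainer(fast_dev_run=2)

trainer.fit(model, loader())
assert "train/metric" in trainer.callback_metrics  # still present after fit

trainer.test(model, loader())
# the fit metrics were reset when the test task started, so only the
# metrics logged during testing remain
assert "train/metric" not in trainer.callback_metrics
assert "test/metric" in trainer.callback_metrics
```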
10 changes: 6 additions & 4 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -536,6 +536,12 @@ def test_step(self, batch, batch_idx):
# hp_metric + 2 steps + epoch + 2 steps + epoch
expected_num_calls = 1 + 2 + 1 + 2 + 1

assert set(trainer.callback_metrics) == {
"train_loss",
"valid_loss_0_epoch",
"valid_loss_0",
"valid_loss_1",
}
assert len(mock_log_metrics.mock_calls) == expected_num_calls
assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)

@@ -569,10 +575,6 @@ def get_metrics_at_idx(idx):

results = trainer.test(model)
assert set(trainer.callback_metrics) == {
"train_loss",
"valid_loss_0_epoch",
"valid_loss_0",
"valid_loss_1",
"test_loss",
}
assert set(results[0]) == {"test_loss"}
46 changes: 46 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1950,3 +1950,49 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
) as exception_hook:
trainer.predict(model, model.val_dataloader(), return_predictions=False)
exception_hook.assert_called()


def test_trainer_metrics_reset_before_each_task(tmpdir):
"""Test that callback, logged and progress bar metrics are reset before each task starts."""

class TestMetricRestartCallback(Callback):
def _make_assertions(self, trainer):
assert trainer.callback_metrics == {}
assert trainer.progress_bar_metrics == {}
assert trainer.logged_metrics == {}

def on_train_start(self, trainer, *args, **kwargs):
self._make_assertions(trainer)

def on_validation_start(self, trainer, *args, **kwargs):
if trainer.state.fn == TrainerFn.VALIDATING:
self._make_assertions(trainer)

def on_test_start(self, trainer, *args, **kwargs):
self._make_assertions(trainer)

def on_predict_start(self, trainer, *args, **kwargs):
self._make_assertions(trainer)

class CustomBoringModel(BoringModel):
def __init__(self):
super().__init__()

def training_step(self, *args, **kwargs):
self.log("train/metric", 7.0)
return super().training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs):
self.log("val/metric", 14.0)
return super().validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
self.log("test/metric", 21.0)
return super().test_step(*args, **kwargs)

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=4, callbacks=[TestMetricRestartCallback()])
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)