From 1f16ef844873e258ecdb9366836ba5b55a1feaf8 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 9 Sep 2021 23:17:51 +0530 Subject: [PATCH 1/9] reset metrics --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 5 +++++ pytorch_lightning/loops/dataloader/prediction_loop.py | 5 +++++ pytorch_lightning/loops/fit_loop.py | 5 ++++- .../connectors/logger_connector/logger_connector.py | 10 +++++++--- pytorch_lightning/trainer/trainer.py | 3 +++ 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 69033751ebe4e..a08043cfd4f05 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -94,6 +94,11 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: self._on_evaluation_start() self._on_evaluation_epoch_start() + def on_advance_start(self) -> None: + """Reset the metrics.""" + # reset metrics + self.trainer.logger_connector.reset_metrics() + def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader.""" void(*args, **kwargs) diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 212dc5c0e9e96..25ab3739d5e73 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -81,6 +81,11 @@ def on_run_start(self) -> None: """Calls ``on_predict_start`` hook.""" self._on_predict_start() + def on_advance_start(self) -> None: + """Reset the metrics.""" + # reset metrics + self.trainer.logger_connector.reset_metrics() + def advance(self, *args: Any, **kwargs: Any) -> None: """Predicts one entire dataloader.""" void(*args, **kwargs) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 9a4f7c510f303..3a3c8de22c015 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -183,10 +183,13 @@ def on_run_start(self) -> None: self.trainer.call_hook("on_train_start") def on_advance_start(self) -> None: - """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and + """Reset the metrics, prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``""" model = self.trainer.lightning_module + # reset metrics + self.trainer.logger_connector.reset_metrics() + # reset train dataloader if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch: self.trainer.reset_train_dataloader(model) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 2e6b6077842f0..eaefbe4323826 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -286,12 +286,16 @@ def should_reset_tensors(self, fx: str) -> bool: is_first_batch = bool(self._batch_idx) + self._split_idx == 0 return is_different_fx and is_first_batch + def reset_metrics(self) -> None: + self._progress_bar_metrics = {} + self._logged_metrics = {} + self._callback_metrics = {} + def reset(self, metrics: Optional[bool] = None) -> None: if self.trainer.sanity_checking: # reset metrics - self._progress_bar_metrics = {} - self._logged_metrics = {} - self._callback_metrics = {} + self.reset_metrics() + assert self.trainer._results is not None self.trainer._results.reset(metrics=metrics) self._batch_idx = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c5516fb4d4f10..2f9759b36eb38 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1200,6 +1200,9 @@ def _run_sanity_check(self, ref_model): stage = self.state.stage self.sanity_checking = True + # reset validation metrics + self.logger_connector.reset() + self.call_hook("on_sanity_check_start") # reload dataloaders From d6ffc49643e0a6f815218c43df37f9ae13be581c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 11 Sep 2021 21:06:28 +0530 Subject: [PATCH 2/9] update logic --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 10 +++++----- pytorch_lightning/loops/dataloader/prediction_loop.py | 5 ----- pytorch_lightning/loops/fit_loop.py | 8 +++----- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index a08043cfd4f05..23c5a3f072e18 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -19,6 +19,7 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -88,17 +89,16 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" void(*args, **kwargs) + + if self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING): + self.trainer.logger_connector.reset_metrics() + # hook self._on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() self._on_evaluation_start() self._on_evaluation_epoch_start() - def on_advance_start(self) -> None: - """Reset the metrics.""" - # reset metrics - self.trainer.logger_connector.reset_metrics() - def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader.""" void(*args, **kwargs) diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 25ab3739d5e73..212dc5c0e9e96 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -81,11 +81,6 @@ def on_run_start(self) -> None: """Calls ``on_predict_start`` hook.""" self._on_predict_start() - def on_advance_start(self) -> None: - """Reset the metrics.""" - # reset metrics - self.trainer.logger_connector.reset_metrics() - def advance(self, *args: Any, **kwargs: Any) -> None: """Predicts one entire dataloader.""" void(*args, **kwargs) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 3a3c8de22c015..038469a7a45f0 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -178,18 +178,16 @@ def reset(self) -> None: self.epoch_progress.current.reset_on_restart() def on_run_start(self) -> None: - """Calls the ``on_train_start`` hook.""" + """Reset the metrics and calls the ``on_train_start`` hook.""" + self.trainer.logger_connector.reset_metrics() self._results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") def on_advance_start(self) -> None: - """Reset the metrics, prepares the dataloader for training and calls the hooks ``on_epoch_start`` and + """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``""" model = self.trainer.lightning_module - # reset metrics - self.trainer.logger_connector.reset_metrics() - # reset train dataloader if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch: self.trainer.reset_train_dataloader(model) From 40339848843c5d4c68aac7e9e1539e6938898b23 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 11 Sep 2021 21:07:06 +0530 Subject: [PATCH 3/9] add test --- tests/trainer/test_trainer.py | 46 +++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 14d10157c8fa4..99dc0bb4c6a29 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1949,3 +1949,49 @@ def test_error_handling_all_stages(tmpdir, accelerator, num_processes): ) as exception_hook: trainer.predict(model, model.val_dataloader(), return_predictions=False) exception_hook.assert_called() + + +def test_trainer_metrics_reset_before_each_task(tmpdir): + """Test that callback, logged and progress bar metrics are reset before each task starts.""" + + class TestMetricRestartCallback(Callback): + def _make_assertions(self, trainer): + assert trainer.callback_metrics == {} + assert trainer.progress_bar_metrics == {} + assert trainer.logged_metrics == {} + + def on_train_start(self, trainer, *args, **kwargs): + self._make_assertions(trainer) + + def on_validation_start(self, trainer, *args, **kwargs): + if trainer.state.fn == TrainerFn.VALIDATING: + self._make_assertions(trainer) + + def on_test_start(self, trainer, *args, **kwargs): + self._make_assertions(trainer) + + def on_predict_start(self, trainer, *args, **kwargs): + self._make_assertions(trainer) + + class CustomBoringModel(BoringModel): + def __init__(self): + super().__init__() + + def training_step(self, *args, **kwargs): + self.log("train/metric", 7.0) + return super().training_step(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + self.log("val/metric", 14.0) + return super().validation_step(*args, **kwargs) + + def test_step(self, *args, **kwargs): + self.log("test/metric", 21.0) + return super().test_step(*args, **kwargs) + + model = CustomBoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=4, callbacks=[TestMetricRestartCallback()]) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model) From ce1ddc4d798797b2c29c94280ba12629630ef06f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 11 Sep 2021 21:10:49 +0530 Subject: [PATCH 4/9] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2462ffbdbb773..6bd47ef1a86df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -395,6 +395,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed missing deepspeed distributed call ([#9540](https://github.com/PyTorchLightning/pytorch-lightning/pull/9540)) +- Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410)) + + ## [1.4.5] - 2021-08-31 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) From 00c9580d066cc0e333f73472fbd6d5e4f65e92c5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 11 Sep 2021 22:26:42 +0530 Subject: [PATCH 5/9] update tests --- tests/trainer/logging_/test_eval_loop_logging.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index e8b398bee8872..7b94cfe970cf0 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -536,6 +536,12 @@ def test_step(self, batch, batch_idx): # hp_metric + 2 steps + epoch + 2 steps + epoch expected_num_calls = 1 + 2 + 1 + 2 + 1 + assert set(trainer.callback_metrics) == { + "train_loss", + "valid_loss_0_epoch", + "valid_loss_0", + "valid_loss_1", + } assert len(mock_log_metrics.mock_calls) == expected_num_calls assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0) @@ -569,10 +575,6 @@ def get_metrics_at_idx(idx): results = trainer.test(model) assert set(trainer.callback_metrics) == { - "train_loss", - "valid_loss_0_epoch", - "valid_loss_0", - "valid_loss_1", "test_loss", } assert set(results[0]) == {"test_loss"} From 1c75abc7c7d1c0a368c6de85d484db1085c544a3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 13 Sep 2021 20:11:12 +0530 Subject: [PATCH 6/9] move to run --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 4 ---- pytorch_lightning/loops/dataloader/prediction_loop.py | 2 +- pytorch_lightning/loops/fit_loop.py | 3 +-- .../connectors/logger_connector/logger_connector.py | 10 ++++++---- pytorch_lightning/trainer/trainer.py | 7 +++++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 23c5a3f072e18..6338244edb132 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -19,7 +19,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection -from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -90,9 +89,6 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: hooks.""" void(*args, **kwargs) - if self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING): - self.trainer.logger_connector.reset_metrics() - # hook self._on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 212dc5c0e9e96..d4a6ab6d29cef 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -78,7 +78,7 @@ def reset(self) -> None: self.epoch_batch_indices = [] def on_run_start(self) -> None: - """Calls ``on_predict_start`` hook.""" + """Calls ``_on_predict_start`` hook.""" self._on_predict_start() def advance(self, *args: Any, **kwargs: Any) -> None: diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 038469a7a45f0..9a4f7c510f303 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -178,8 +178,7 @@ def reset(self) -> None: self.epoch_progress.current.reset_on_restart() def on_run_start(self) -> None: - """Reset the metrics and calls the ``on_train_start`` hook.""" - self.trainer.logger_connector.reset_metrics() + """Calls the ``on_train_start`` hook.""" self._results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index eaefbe4323826..2b9ae0879c3ca 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -291,13 +291,15 @@ def reset_metrics(self) -> None: self._logged_metrics = {} self._callback_metrics = {} - def reset(self, metrics: Optional[bool] = None) -> None: - if self.trainer.sanity_checking: + def reset(self, metrics: Optional[bool] = None, trainer_metrics: bool = False) -> None: + if trainer_metrics: # reset metrics self.reset_metrics() - assert self.trainer._results is not None - self.trainer._results.reset(metrics=metrics) + if self.trainer.state.fn != TrainerFn.PREDICTING: + assert self.trainer._results is not None + self.trainer._results.reset(metrics=metrics) + self._batch_idx = None self._split_idx = None self._current_fx = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2f9759b36eb38..8f140322cfb45 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1016,6 +1016,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # ---------------------------- # TRAIN # ---------------------------- + # reset logger connector + self.logger_connector.reset(trainer_metrics=True) + # hook if self.state.fn == TrainerFn.FITTING: self.call_hook("on_fit_start") @@ -1201,7 +1204,7 @@ def _run_sanity_check(self, ref_model): self.sanity_checking = True # reset validation metrics - self.logger_connector.reset() + self.logger_connector.reset(trainer_metrics=True) self.call_hook("on_sanity_check_start") @@ -1215,7 +1218,7 @@ def _run_sanity_check(self, ref_model): self.call_hook("on_sanity_check_end") # reset validation metrics - self.logger_connector.reset() + self.logger_connector.reset(trainer_metrics=True) # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training From d0fc605084f525ed68d03a4d60a742a2b7a26b81 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 16 Sep 2021 00:33:25 +0530 Subject: [PATCH 7/9] review comments --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 2 +- .../connectors/logger_connector/logger_connector.py | 9 ++------- pytorch_lightning/trainer/trainer.py | 9 ++++++--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 6338244edb132..4f58889b42c4a 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -200,7 +200,7 @@ def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: self.trainer.call_hook("on_validation_end", *args, **kwargs) # reset any `torchmetrics.Metric` and the logger connector state - self.trainer.logger_connector.reset(metrics=True) + self.trainer.logger_connector.reset_results(metrics=True) def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks.""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 2b9ae0879c3ca..ad6a84d64b537 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -291,13 +291,8 @@ def reset_metrics(self) -> None: self._logged_metrics = {} self._callback_metrics = {} - def reset(self, metrics: Optional[bool] = None, trainer_metrics: bool = False) -> None: - if trainer_metrics: - # reset metrics - self.reset_metrics() - - if self.trainer.state.fn != TrainerFn.PREDICTING: - assert self.trainer._results is not None + def reset_results(self, metrics: Optional[bool] = None) -> None: + if self.trainer._results is not None: self.trainer._results.reset(metrics=metrics) self._batch_idx = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8f140322cfb45..eea5f3e279dfc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1017,7 +1017,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # TRAIN # ---------------------------- # reset logger connector - self.logger_connector.reset(trainer_metrics=True) + self.logger_connector.reset_results() + self.logger_connector.reset_metrics() # hook if self.state.fn == TrainerFn.FITTING: @@ -1204,7 +1205,8 @@ def _run_sanity_check(self, ref_model): self.sanity_checking = True # reset validation metrics - self.logger_connector.reset(trainer_metrics=True) + self.logger_connector.reset_results() + self.logger_connector.reset_metrics() self.call_hook("on_sanity_check_start") @@ -1218,7 +1220,8 @@ def _run_sanity_check(self, ref_model): self.call_hook("on_sanity_check_end") # reset validation metrics - self.logger_connector.reset(trainer_metrics=True) + self.logger_connector.reset_results() + self.logger_connector.reset_metrics() # reset the seed to what it was before sanity check # prevents sanity check to affect random sampling in training From 4254f3acf557601935085b79a5fefa131f764d5d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 16 Sep 2021 00:41:56 +0530 Subject: [PATCH 8/9] move to loop --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 5 +++++ pytorch_lightning/loops/fit_loop.py | 5 ++++- pytorch_lightning/trainer/trainer.py | 3 --- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 4f58889b42c4a..a20fd9ed29527 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -19,6 +19,7 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -89,6 +90,10 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: hooks.""" void(*args, **kwargs) + if self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING): + self.trainer.logger_connector.reset_results() + self.trainer.logger_connector.reset_metrics() + # hook self._on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 9a4f7c510f303..168f947748713 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -178,7 +178,10 @@ def reset(self) -> None: self.epoch_progress.current.reset_on_restart() def on_run_start(self) -> None: - """Calls the ``on_train_start`` hook.""" + """Reset results and metrics, calls the ``on_train_start`` hook.""" + self.trainer.logger_connector.reset_results() + self.trainer.logger_connector.reset_metrics() + self._results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index eea5f3e279dfc..ad474ddadcf21 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1016,9 +1016,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # ---------------------------- # TRAIN # ---------------------------- - # reset logger connector - self.logger_connector.reset_results() - self.logger_connector.reset_metrics() # hook if self.state.fn == TrainerFn.FITTING: From 5a7d6a3e4b3051d53d7b488126e1a23672df6d89 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 17 Sep 2021 18:58:47 +0530 Subject: [PATCH 9/9] revert back to trainer --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 5 ----- pytorch_lightning/loops/fit_loop.py | 5 +---- pytorch_lightning/trainer/trainer.py | 8 ++++++-- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index a20fd9ed29527..4f58889b42c4a 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -19,7 +19,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection -from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -90,10 +89,6 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: hooks.""" void(*args, **kwargs) - if self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING): - self.trainer.logger_connector.reset_results() - self.trainer.logger_connector.reset_metrics() - # hook self._on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 168f947748713..9a4f7c510f303 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -178,10 +178,7 @@ def reset(self) -> None: self.epoch_progress.current.reset_on_restart() def on_run_start(self) -> None: - """Reset results and metrics, calls the ``on_train_start`` hook.""" - self.trainer.logger_connector.reset_results() - self.trainer.logger_connector.reset_metrics() - + """Calls the ``on_train_start`` hook.""" self._results.to(device=self.trainer.lightning_module.device) self.trainer.call_hook("on_train_start") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ad474ddadcf21..6e97e51174dcf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1017,6 +1017,10 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # TRAIN # ---------------------------- + # reset logger connector + self.logger_connector.reset_results() + self.logger_connector.reset_metrics() + # hook if self.state.fn == TrainerFn.FITTING: self.call_hook("on_fit_start") @@ -1201,7 +1205,7 @@ def _run_sanity_check(self, ref_model): stage = self.state.stage self.sanity_checking = True - # reset validation metrics + # reset logger connector self.logger_connector.reset_results() self.logger_connector.reset_metrics() @@ -1216,7 +1220,7 @@ def _run_sanity_check(self, ref_model): self.call_hook("on_sanity_check_end") - # reset validation metrics + # reset logger connector self.logger_connector.reset_results() self.logger_connector.reset_metrics()