From 6efdd5c076ea9601f885af386f6ea7669d79d365 Mon Sep 17 00:00:00 2001
From: Thomas Chaton
Date: Mon, 27 Sep 2021 08:33:49 -0400
Subject: [PATCH 1/5] add test

---
 .../loops/dataloader/evaluation_loop.py       |  2 +-
 .../logging_/test_eval_loop_logging.py        | 47 +++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index b3a8ef764efc6..d1d1248d0b626 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -214,7 +214,7 @@ def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)
 
         # reset any `torchmetrics.Metric` and the logger connector state
-        self.trainer.logger_connector.reset_results(metrics=True)
+        self.trainer.logger_connector.reset_results()
 
     def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index c063d500e5555..7b5109b3c75d9 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -602,3 +602,50 @@ def validation_step(self, batch, batch_idx):
     )
     trainer.fit(model)
+
+
+def test_multiple_dataloader_reset(tmpdir):
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx, dataloader_idx):
+            value = (1 + batch_idx) * (2 if dataloader_idx == 1 else 1)
+            if self.current_epoch != 0:
+                value *= 10
+            self.log("val_loss", value, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            return value
+
+        def validation_epoch_end(self, outputs):
+            if self.current_epoch == 0:
+                assert sum(outputs[0]) / 5 == 3
+                assert sum(outputs[1]) / 5 == 6
+            else:
+                assert sum(outputs[0]) / 5 == 30
+                assert sum(outputs[1]) / 5 == 60
+
+            tot_loss = torch.tensor(0.0)
+            for loss in outputs:
+                tot_loss += sum(loss) / len(loss)
+            tot_loss = tot_loss / len(outputs)
+            if self.current_epoch == 0:
+                assert tot_loss == (3 + 6) / 2
+            else:
+                assert tot_loss == (30 + 60) / 2
+            self.log("tot_val_loss", tot_loss, prog_bar=True, logger=True)
+            assert self.trainer._results["validation_step.val_loss.0"].cumulated_batch_size == 5
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+        def val_dataloader(self):
+            return [super().val_dataloader(), super().val_dataloader()]
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=5,
+        num_sanity_val_steps=0,
+        max_epochs=3,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)

From f8ce4ab5aa7ab27bca7865accbb45038baee2c4f Mon Sep 17 00:00:00 2001
From: Thomas Chaton
Date: Mon, 27 Sep 2021 08:56:27 -0400
Subject: [PATCH 2/5] improve test

---
 .../logging_/test_eval_loop_logging.py        | 23 +++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index 7b5109b3c75d9..dcf7d71b3255a 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -604,8 +604,25 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)
 
 
-def test_multiple_dataloader_reset(tmpdir):
+@pytest.mark.parametrize("val_check_interval", [0.5, 1.0])
+def test_multiple_dataloader_reset(val_check_interval, tmpdir):
     class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            out = super().training_step(batch, batch_idx)
+            value = 1 + batch_idx
+            if self.current_epoch != 0:
+                value *= 10
+            self.log("batch_idx", value, on_step=True, on_epoch=True, prog_bar=True)
+            return out
+
+        def training_epoch_end(self, outputs):
+            if val_check_interval == 1.0:
+                metrics = self.trainer.progress_bar_metrics
+                if self.current_epoch == 0:
+                    assert metrics["batch_idx_epoch"] == (15 / 5.0)
+                else:
+                    assert metrics["batch_idx_epoch"] == (150 / 5.0)
+
         def validation_step(self, batch, batch_idx, dataloader_idx):
             value = (1 + batch_idx) * (2 if dataloader_idx == 1 else 1)
             if self.current_epoch != 0:
@@ -631,6 +648,7 @@ def validation_epoch_end(self, outputs):
                 assert tot_loss == (30 + 60) / 2
             self.log("tot_val_loss", tot_loss, prog_bar=True, logger=True)
             assert self.trainer._results["validation_step.val_loss.0"].cumulated_batch_size == 5
+            assert self.trainer._results["validation_step.val_loss.1"].cumulated_batch_size == 5
 
         def configure_optimizers(self):
             return torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -641,9 +659,10 @@ def val_dataloader(self):
     model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        limit_train_batches=1,
+        limit_train_batches=5,
         limit_val_batches=5,
         num_sanity_val_steps=0,
+        val_check_interval=val_check_interval,
         max_epochs=3,
         log_every_n_steps=1,
         weights_summary=None,

From 404cd808a42af92f9bd527ad96731eec92705e77 Mon Sep 17 00:00:00 2001
From: Thomas Chaton
Date: Mon, 27 Sep 2021 08:57:35 -0400
Subject: [PATCH 3/5] add changelog

---
 CHANGELOG.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45389685645a9..00f12b33b1aae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -417,6 +417,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
 
 
+- Fixed `reset` of metrics on validation epoch end ([#9717](https://github.com/PyTorchLightning/pytorch-lightning/pull/9717))
+
+
+
 ## [1.4.8] - 2021-09-22
 
 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))

From 20047a0ceea4d7aed4ee439fcfd7db690aef412c Mon Sep 17 00:00:00 2001
From: Thomas Chaton
Date: Mon, 27 Sep 2021 10:51:49 -0400
Subject: [PATCH 4/5] update on comments

---
 .../loops/dataloader/evaluation_loop.py       |  2 +-
 .../logger_connector/logger_connector.py      |  4 ++--
 .../logging_/test_eval_loop_logging.py        | 19 +++++--------------
 3 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index d1d1248d0b626..430df2af94046 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -213,7 +213,7 @@ def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         else:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)
 
-        # reset any `torchmetrics.Metric` and the logger connector state
+        # reset the logger connector state
        self.trainer.logger_connector.reset_results()
 
     def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index ad6a84d64b537..85323e92dc7e5 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -291,9 +291,9 @@ def reset_metrics(self) -> None:
         self._logged_metrics = {}
         self._callback_metrics = {}
 
-    def reset_results(self, metrics: Optional[bool] = None) -> None:
+    def reset_results(self) -> None:
         if self.trainer._results is not None:
-            self.trainer._results.reset(metrics=metrics)
+            self.trainer._results.reset()
 
         self._batch_idx = None
         self._split_idx = None
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index dcf7d71b3255a..f082d828e6378 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -605,7 +605,7 @@
 
 
 @pytest.mark.parametrize("val_check_interval", [0.5, 1.0])
-def test_multiple_dataloader_reset(val_check_interval, tmpdir):
+def test_multiple_dataloaders_reset(val_check_interval, tmpdir):
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             out = super().training_step(batch, batch_idx)
@@ -618,13 +618,11 @@ def training_step(self, batch, batch_idx):
         def training_epoch_end(self, outputs):
             if val_check_interval == 1.0:
                 metrics = self.trainer.progress_bar_metrics
-                if self.current_epoch == 0:
-                    assert metrics["batch_idx_epoch"] == (15 / 5.0)
-                else:
-                    assert metrics["batch_idx_epoch"] == (150 / 5.0)
+                v = 15 if self.current_epoch == 0 else 150
+                assert metrics["batch_idx_epoch"] == (v / 5.0)
 
         def validation_step(self, batch, batch_idx, dataloader_idx):
-            value = (1 + batch_idx) * (2 if dataloader_idx == 1 else 1)
+            value = (1 + batch_idx) * (1 + dataloader_idx)
             if self.current_epoch != 0:
                 value *= 10
             self.log("val_loss", value, on_step=False, on_epoch=True, prog_bar=True, logger=True)
@@ -638,21 +636,14 @@ def validation_epoch_end(self, outputs):
                 assert sum(outputs[0]) / 5 == 30
                 assert sum(outputs[1]) / 5 == 60
 
-            tot_loss = torch.tensor(0.0)
-            for loss in outputs:
-                tot_loss += sum(loss) / len(loss)
-            tot_loss = tot_loss / len(outputs)
+            tot_loss = torch.mean(torch.tensor(outputs, dtype=torch.float))
             if self.current_epoch == 0:
                 assert tot_loss == (3 + 6) / 2
             else:
                 assert tot_loss == (30 + 60) / 2
-            self.log("tot_val_loss", tot_loss, prog_bar=True, logger=True)
             assert self.trainer._results["validation_step.val_loss.0"].cumulated_batch_size == 5
             assert self.trainer._results["validation_step.val_loss.1"].cumulated_batch_size == 5
 
-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
         def val_dataloader(self):
             return [super().val_dataloader(), super().val_dataloader()]
 

From d08e2d4375c2dbcbfa586332311f7e057cd1ead9 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 27 Sep 2021 17:22:53 +0200
Subject: [PATCH 5/5] Unnecessary if

---
 tests/trainer/logging_/test_eval_loop_logging.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index f082d828e6378..99145972f3561 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -616,10 +616,9 @@ def training_step(self, batch, batch_idx):
             return out
 
         def training_epoch_end(self, outputs):
-            if val_check_interval == 1.0:
-                metrics = self.trainer.progress_bar_metrics
-                v = 15 if self.current_epoch == 0 else 150
-                assert metrics["batch_idx_epoch"] == (v / 5.0)
+            metrics = self.trainer.progress_bar_metrics
+            v = 15 if self.current_epoch == 0 else 150
+            assert metrics["batch_idx_epoch"] == (v / 5.0)
 
         def validation_step(self, batch, batch_idx, dataloader_idx):
             value = (1 + batch_idx) * (1 + dataloader_idx)
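
Note on the behaviour this series enforces: after every validation run the logger connector now unconditionally resets the results held in `trainer._results`, so each epoch's per-dataloader aggregates start from zero. Below is a minimal standalone sketch of that invariant in plain Python; it is not Lightning's actual `ResultCollection` implementation, and the `RunningMean` class and the `validation_step.val_loss.<dataloader_idx>` keying are illustrative assumptions modelled on the test above.

# a minimal sketch, assuming one weighted-mean accumulator per logged key
class RunningMean:
    def __init__(self):
        self.cumulated_value = 0.0
        self.cumulated_batch_size = 0

    def update(self, value, batch_size=1):
        # accumulate a batch-size-weighted sum, as `on_epoch=True` logging does
        self.cumulated_value += value * batch_size
        self.cumulated_batch_size += batch_size

    def compute(self):
        return self.cumulated_value / self.cumulated_batch_size

    def reset(self):
        self.cumulated_value = 0.0
        self.cumulated_batch_size = 0


results = {f"validation_step.val_loss.{i}": RunningMean() for i in range(2)}

for epoch in range(2):
    for dataloader_idx in range(2):
        for batch_idx in range(5):
            value = (1 + batch_idx) * (1 + dataloader_idx)
            if epoch != 0:
                value *= 10
            results[f"validation_step.val_loss.{dataloader_idx}"].update(value)

    expected = [3.0, 6.0] if epoch == 0 else [30.0, 60.0]
    for i in range(2):
        metric = results[f"validation_step.val_loss.{i}"]
        assert metric.cumulated_batch_size == 5
        assert metric.compute() == expected[i]
        # without this reset, epoch 1 would see cumulated_batch_size == 10 and
        # a mean polluted by epoch 0 -- the symptom the series fixes
        metric.reset()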