From b13a399152ee5ce6ae8972d40fe43e36e9cba38c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Dec 2021 03:48:13 +0100 Subject: [PATCH 1/4] Add test --- tests/callbacks/test_tqdm_progress_bar.py | 54 +++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 057d4dc7421bb..8e6d1be6e122f 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -14,6 +14,7 @@ import os import pickle import sys +from collections import defaultdict from typing import Union from unittest import mock from unittest.mock import ANY, call, Mock @@ -595,3 +596,56 @@ def test_tqdm_progress_bar_main_bar_resume(): # restarting mid validation epoch is not currently supported assert bar.val_progress_bar.n == 0 assert bar.val_progress_bar.total == 3 + + +def test_tqdm_progress_bar_correct_value_epoch_end(tmpdir): + class MockedProgressBar(TQDMProgressBar): + calls = defaultdict(list) + + def get_metrics(self, trainer, pl_module): + items = super().get_metrics(trainer, pl_module) + del items["v_num"] + del items["loss"] + # this is equivalent to mocking `set_postfix` as this method gets called every time + self.calls[trainer.state.fn].append( + (trainer.state.stage, trainer.current_epoch, trainer.global_step, items) + ) + return items + + class MyModel(BoringModel): + def training_step(self, batch, batch_idx): + self.log("a", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max) + return super().validation_step(batch, batch_idx) + + model = MyModel() + pbar = MockedProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + enable_model_summary=False, 
log_every_n_steps=1, + callbacks=pbar, + ) + + trainer.fit(model) + assert pbar.calls["fit"] == [ + ("sanity_check", 0, 0, {"b": 0}), + ("train", 0, 0, {}), + ("train", 0, 1, {}), + ("validate", 0, 1, {"b": 1}), # validation end + # epoch end over, `on_epoch=True` metrics are computed + ("train", 0, 2, {"a": 1, "b": 1}), # training epoch end + ("train", 1, 2, {"a": 1, "b": 1}), + ("train", 1, 3, {"a": 1, "b": 1}), + ("validate", 1, 3, {"a": 1, "b": 3}), # validation end + ("train", 1, 4, {"a": 3, "b": 3}), # training epoch end + ] + + trainer.validate(model, verbose=False) + assert pbar.calls["validate"] == [("validate", 1, 4, {"b": 4})] From 80d5d6a59e82d16fe9cb7fe4df096b90dae3159e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Dec 2021 03:50:07 +0100 Subject: [PATCH 2/4] Fix test --- .../callbacks/progress/tqdm_progress.py | 39 +++++++++++-------- pytorch_lightning/core/lightning.py | 2 +- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 9de6770c6c23d..21e462aa1ecb6 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -25,6 +25,7 @@ else: from tqdm import tqdm as _tqdm +import pytorch_lightning as pl from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities.distributed import rank_zero_debug @@ -207,12 +208,10 @@ def init_test_tqdm(self) -> Tqdm: return bar def on_sanity_check_start(self, trainer, pl_module): - super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() self.main_progress_bar = Tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): - super().on_sanity_check_end(trainer, pl_module) self.main_progress_bar.close() self.val_progress_bar.close() @@ -234,12 +233,18 @@ def on_train_epoch_start(self, trainer, pl_module): 
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) - total_batches = self.total_train_batches + self.total_val_batches - total_batches = convert_inf(total_batches) - if self._should_update(self.train_batch_idx, total_batches): + if self._should_update(self.train_batch_idx): self._update_bar(self.main_progress_bar) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.is_enabled: + self._update_bar(self.main_progress_bar) + self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.main_progress_bar.close() + def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) if trainer.sanity_checking: @@ -251,20 +256,19 @@ def on_validation_start(self, trainer, pl_module): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)): + if self._should_update(self.val_batch_idx): self._update_bar(self.val_progress_bar) self._update_bar(self.main_progress_bar) + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.is_enabled: + self._update_bar(self.val_progress_bar) + def on_validation_end(self, trainer, pl_module): - super().on_validation_end(trainer, pl_module) if self.main_progress_bar is not None: self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() - def on_train_end(self, trainer, pl_module): - super().on_train_end(trainer, pl_module) - self.main_progress_bar.close() - def 
on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) self.test_progress_bar = self.init_test_tqdm() @@ -272,11 +276,14 @@ def on_test_start(self, trainer, pl_module): def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.test_batch_idx, self.total_test_batches): + if self._should_update(self.test_batch_idx): + self._update_bar(self.test_progress_bar) + + def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.is_enabled: self._update_bar(self.test_progress_bar) def on_test_end(self, trainer, pl_module): - super().on_test_end(trainer, pl_module) self.test_progress_bar.close() def on_predict_epoch_start(self, trainer, pl_module): @@ -286,7 +293,7 @@ def on_predict_epoch_start(self, trainer, pl_module): def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.predict_batch_idx, self.total_predict_batches): + if self._should_update(self.predict_batch_idx): self._update_bar(self.predict_progress_bar) def on_predict_end(self, trainer, pl_module): @@ -310,8 +317,8 @@ def print( s = sep.join(map(str, args)) active_progress_bar.write(s, end=end, file=file, nolock=nolock) - def _should_update(self, current, total) -> bool: - return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + def _should_update(self, idx: int) -> bool: + return self.is_enabled and (idx % self.refresh_rate == 0) def _update_bar(self, bar: Optional[Tqdm]) -> None: """Updates the bar by the refresh rate without overshooting.""" diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 13e12a11f97aa..38010f7acf0a1 100644 --- 
a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1710,7 +1710,7 @@ def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" .. deprecated:: v1.5 This method was deprecated in v1.5 in favor of - `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7. + `pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7. Implement this to override the default items displayed in the progress bar. By default it includes the average loss value, split index of BPTT (if used) From ea82ad1de9272ed6f901935309119450cde78abf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Dec 2021 04:14:52 +0100 Subject: [PATCH 3/4] Fix bug --- CHANGELOG.md | 4 ++-- pytorch_lightning/callbacks/progress/tqdm_progress.py | 8 +++++--- tests/callbacks/test_tqdm_progress_bar.py | 10 +++++++++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96a24f3dd1ba8..461c0c7800e07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -270,10 +270,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991)) -- +- The TQDM progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069)) -- +- Fixed bug where the TQDM updated the training progress bar during `trainer.validate` ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069)) ## [1.5.5] - 2021-12-07 diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 21e462aa1ecb6..ea667c9ddf968 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -250,7 +250,8 @@ def on_validation_start(self, trainer, pl_module): if trainer.sanity_checking: reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx) else: - self._update_bar(self.main_progress_bar) # fill up remaining + if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: + self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx) @@ -258,14 +259,15 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.val_batch_idx): self._update_bar(self.val_progress_bar) - self._update_bar(self.main_progress_bar) + if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: + self._update_bar(self.main_progress_bar) def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.is_enabled: self._update_bar(self.val_progress_bar) def on_validation_end(self, trainer, pl_module): - if 
self.main_progress_bar is not None: + if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 8e6d1be6e122f..fc7bc7fa1cc34 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -621,12 +621,17 @@ def validation_step(self, batch, batch_idx): self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max) return super().validation_step(batch, batch_idx) + def test_step(self, batch, batch_idx): + self.log("c", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max) + return super().test_step(batch, batch_idx) + model = MyModel() pbar = MockedProgressBar() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, + limit_test_batches=2, max_epochs=2, enable_model_summary=False, log_every_n_steps=1, @@ -648,4 +653,7 @@ def validation_step(self, batch, batch_idx): ] trainer.validate(model, verbose=False) - assert pbar.calls["validate"] == [("validate", 1, 4, {"b": 4})] + assert pbar.calls["validate"] == [] + + trainer.test(model, verbose=False) + assert pbar.calls["test"] == [] From e3917d5131a5de95e0d41a0a5d54974652d5317d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 15 Dec 2021 16:33:27 +0100 Subject: [PATCH 4/4] Update tests/callbacks/test_tqdm_progress_bar.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- tests/callbacks/test_tqdm_progress_bar.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index fc7bc7fa1cc34..e22bb62126188 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ 
b/tests/callbacks/test_tqdm_progress_bar.py @@ -634,6 +634,7 @@ def test_step(self, batch, batch_idx): limit_test_batches=2, max_epochs=2, enable_model_summary=False, + enable_checkpointing=False, log_every_n_steps=1, callbacks=pbar, )