diff --git a/CHANGELOG.md b/CHANGELOG.md index 96a24f3dd1ba8..461c0c7800e07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -270,10 +270,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991)) -- +- The TQDM progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069)) -- +- Fixed bug where the TQDM updated the training progress bar during `trainer.validate` ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069)) ## [1.5.5] - 2021-12-07 diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 9de6770c6c23d..ea667c9ddf968 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -25,6 +25,7 @@ else: from tqdm import tqdm as _tqdm +import pytorch_lightning as pl from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities.distributed import rank_zero_debug @@ -207,12 +208,10 @@ def init_test_tqdm(self) -> Tqdm: return bar def on_sanity_check_start(self, trainer, pl_module): - super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() self.main_progress_bar = Tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): - super().on_sanity_check_end(trainer, pl_module) self.main_progress_bar.close() self.val_progress_bar.close() @@ -234,37 +233,44 @@ def on_train_epoch_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) - total_batches = self.total_train_batches + self.total_val_batches - total_batches = convert_inf(total_batches) - if self._should_update(self.train_batch_idx, total_batches): + if self._should_update(self.train_batch_idx): self._update_bar(self.main_progress_bar) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.is_enabled: + self._update_bar(self.main_progress_bar) + self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.main_progress_bar.close() + def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) if trainer.sanity_checking: reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx) else: - self._update_bar(self.main_progress_bar) # fill up remaining + if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: + self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)): + if self._should_update(self.val_batch_idx): + self._update_bar(self.val_progress_bar) + if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: + self._update_bar(self.main_progress_bar) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.is_enabled: self._update_bar(self.val_progress_bar) - self._update_bar(self.main_progress_bar) def on_validation_end(self, trainer, pl_module): - super().on_validation_end(trainer, pl_module) - if self.main_progress_bar is not None: + if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() - def on_train_end(self, trainer, pl_module): - super().on_train_end(trainer, pl_module) - self.main_progress_bar.close() - def on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) self.test_progress_bar = self.init_test_tqdm() @@ -272,11 +278,14 @@ def on_test_start(self, trainer, pl_module): def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.test_batch_idx, self.total_test_batches): + if self._should_update(self.test_batch_idx): + self._update_bar(self.test_progress_bar) + + def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.is_enabled: self._update_bar(self.test_progress_bar) def on_test_end(self, trainer, pl_module): - super().on_test_end(trainer, pl_module) self.test_progress_bar.close() def on_predict_epoch_start(self, trainer, pl_module): @@ -286,7 +295,7 @@ def on_predict_epoch_start(self, trainer, pl_module): def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.predict_batch_idx, self.total_predict_batches): + if self._should_update(self.predict_batch_idx): self._update_bar(self.predict_progress_bar) def on_predict_end(self, trainer, pl_module): @@ -310,8 +319,8 @@ def print( s = sep.join(map(str, args)) active_progress_bar.write(s, end=end, file=file, nolock=nolock) - def _should_update(self, current, total) -> bool: - return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + def _should_update(self, idx: int) -> bool: + return self.is_enabled and (idx % self.refresh_rate == 0) def _update_bar(self, bar: Optional[Tqdm]) -> None: """Updates the bar by the refresh rate without overshooting.""" diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 13e12a11f97aa..38010f7acf0a1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1710,7 +1710,7 @@ def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" .. deprecated:: v1.5 This method was deprecated in v1.5 in favor of - `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7. + `pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7. Implement this to override the default items displayed in the progress bar. By default it includes the average loss value, split index of BPTT (if used) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 057d4dc7421bb..e22bb62126188 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -14,6 +14,7 @@ import os import pickle import sys +from collections import defaultdict from typing import Union from unittest import mock from unittest.mock import ANY, call, Mock @@ -595,3 +596,65 @@ def test_tqdm_progress_bar_main_bar_resume(): # restarting mid validation epoch is not currently supported assert bar.val_progress_bar.n == 0 assert bar.val_progress_bar.total == 3 + + +def test_tqdm_progress_bar_correct_value_epoch_end(tmpdir): + class MockedProgressBar(TQDMProgressBar): + calls = defaultdict(list) + + def get_metrics(self, trainer, pl_module): + items = super().get_metrics(trainer, model) + del items["v_num"] + del items["loss"] + # this is equivalent to mocking `set_postfix` as this method gets called every time + self.calls[trainer.state.fn].append( + (trainer.state.stage, trainer.current_epoch, trainer.global_step, items) + ) + return items + + class MyModel(BoringModel): + def training_step(self, batch, batch_idx): + self.log("a", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + self.log("c", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max) + return super().test_step(batch, batch_idx) + + model = MyModel() + pbar = MockedProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=2, + enable_model_summary=False, + enable_checkpointing=False, + log_every_n_steps=1, + callbacks=pbar, + ) + + trainer.fit(model) + assert pbar.calls["fit"] == [ + ("sanity_check", 0, 0, {"b": 0}), + ("train", 0, 0, {}), + ("train", 0, 1, {}), + ("validate", 0, 1, {"b": 1}), # validation end + # epoch end over, `on_epoch=True` metrics are computed + ("train", 0, 2, {"a": 1, "b": 1}), # training epoch end + ("train", 1, 2, {"a": 1, "b": 1}), + ("train", 1, 3, {"a": 1, "b": 1}), + ("validate", 1, 3, {"a": 1, "b": 3}), # validation end + ("train", 1, 4, {"a": 3, "b": 3}), # training epoch end + ] + + trainer.validate(model, verbose=False) + assert pbar.calls["validate"] == [] + + trainer.test(model, verbose=False) + assert pbar.calls["test"] == []