From 76676bc7facd3797173596027cfddb1ceb4e9eb1 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 2 Nov 2021 12:46:27 +0530 Subject: [PATCH 1/7] Add display_every_n_epochs argument to RichProgressBar --- .../callbacks/progress/rich_progress.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index d684f8a7e38ed..3964c70b218d6 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -18,6 +18,7 @@ from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities import _RICH_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException Task, Style = None, None if _RICH_AVAILABLE: @@ -195,6 +196,7 @@ class RichProgressBar(ProgressBarBase): Args: refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. + display_every_n_epochs: Set to a non-negative integer to display progress bar every n epochs. theme: Contains styles used to stylize the progress bar. Raises: @@ -205,6 +207,7 @@ class RichProgressBar(ProgressBarBase): def __init__( self, refresh_rate_per_second: int = 10, + display_every_n_epochs: Optional[int] = None, theme: RichProgressBarTheme = RichProgressBarTheme(), ) -> None: if not _RICH_AVAILABLE: @@ -213,6 +216,11 @@ def __init__( ) super().__init__() self._refresh_rate_per_second: int = refresh_rate_per_second + if display_every_n_epochs and (not isinstance(display_every_n_epochs, int) or (display_every_n_epochs < 0)): + raise MisconfigurationException( + f"`display_every_n_epochs` should be an int >= 0, got {display_every_n_epochs}." + ) + self._display_every_n_epochs: Optional[int] = display_every_n_epochs self._enabled: bool = True self.progress: Optional[Progress] = None self.val_sanity_progress_bar_id: Optional[int] = None @@ -323,9 +331,19 @@ def on_train_epoch_start(self, trainer, pl_module): total_batches = total_train_batches + total_val_batches train_description = self._get_train_description(trainer.current_epoch) + if ( + self.main_progress_bar_id is not None + and self._display_every_n_epochs + and (not trainer.current_epoch % self._display_every_n_epochs) + ): + self._stop_progress() + self._init_progress(trainer, pl_module) if self.main_progress_bar_id is None: self.main_progress_bar_id = self._add_task(total_batches, train_description) - self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description) + else: + self.progress.reset( + self.main_progress_bar_id, total=total_batches, description=train_description, visible=True + ) def on_validation_epoch_start(self, trainer, pl_module): super().on_validation_epoch_start(trainer, pl_module) From 97c27ac692e321e7b11895c6a72eae948dabbc49 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 2 Nov 2021 13:21:56 +0530 Subject: [PATCH 2/7] Add tests --- CHANGELOG.md | 1 + tests/callbacks/test_rich_progress_bar.py | 29 ++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 209a8a4671028..593b2b607d7c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -130,6 +130,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246)) - Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288)) - Added a warning if multiple batch sizes are found from ambiguous batch ([#10247](https://github.com/PyTorchLightning/pytorch-lightning/pull/10247)) +- Added `display_every_n_epochs` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301)) ### Changed diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index ab0852c4729af..db3304cb933b2 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -20,6 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.runif import RunIf @@ -144,7 +145,7 @@ def on_train_start(self) -> None: @RunIf(rich=True) -def test_rich_progress_bar_configure_columns(tmpdir): +def test_rich_progress_bar_configure_columns(): from rich.progress import TextColumn custom_column = TextColumn("[progress.description]Testing Rich!") @@ -159,3 +160,29 @@ def configure_columns(self, trainer, pl_module): assert progress_bar.progress.columns[0] == custom_column assert len(progress_bar.progress.columns) == 1 + + +@RunIf(rich=True) +def test_rich_progress_bar_display_every_n_epochs(tmpdir): + + model = BoringModel() + + with mock.patch( + "pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True + ) as mock_progress_stop: + progress_bar = RichProgressBar(display_every_n_epochs=2) + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=1, + max_epochs=6, + callbacks=progress_bar, + ) + trainer.fit(model) + assert mock_progress_stop.call_count == 3 + + +@RunIf(rich=True) +def test_rich_progress_bar_display_every_n_steps_misconfiguration(): + with pytest.raises(MisconfigurationException, match=r"`display_every_n_epochs` should be an int >= 0"): + Trainer(callbacks=RichProgressBar(display_every_n_epochs="test")) From 3a9a5d1fbb0ab394fd5f9698ffe02c1cdf9a83cc Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 2 Nov 2021 14:45:39 +0530 Subject: [PATCH 3/7] Update test --- tests/callbacks/test_rich_progress_bar.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index db3304cb933b2..a3e5663a090de 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -163,14 +163,15 @@ def configure_columns(self, trainer, pl_module): @RunIf(rich=True) -def test_rich_progress_bar_display_every_n_epochs(tmpdir): +@pytest.mark.parametrize(("display_every_n_epochs", "reset_call_count"), ([(None, 5), (1, 0), (2, 3), (3, 4)])) +def test_rich_progress_bar_display_every_n_epochs(tmpdir, display_every_n_epochs, reset_call_count): model = BoringModel() with mock.patch( - "pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True - ) as mock_progress_stop: - progress_bar = RichProgressBar(display_every_n_epochs=2) + "pytorch_lightning.callbacks.progress.rich_progress.Progress.reset", autospec=True + ) as mock_progress_reset: + progress_bar = RichProgressBar(display_every_n_epochs=display_every_n_epochs) trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -179,7 +180,7 @@ def test_rich_progress_bar_display_every_n_epochs(tmpdir): callbacks=progress_bar, ) trainer.fit(model) - assert mock_progress_stop.call_count == 3 + assert mock_progress_reset.call_count == reset_call_count @RunIf(rich=True) From b454fe2ed45b94477330af780634e8fc7263f43d Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 2 Nov 2021 14:47:58 +0530 Subject: [PATCH 4/7] Update test --- tests/callbacks/test_rich_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index a3e5663a090de..9c167a71835bf 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -165,7 +165,7 @@ def configure_columns(self, trainer, pl_module): @RunIf(rich=True) @pytest.mark.parametrize(("display_every_n_epochs", "reset_call_count"), ([(None, 5), (1, 0), (2, 3), (3, 4)])) def test_rich_progress_bar_display_every_n_epochs(tmpdir, display_every_n_epochs, reset_call_count): - + # Calling `reset` means continuing on the same progress bar. model = BoringModel() with mock.patch( From 428b5e04b503bae4f89b30561567b36870db600a Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 2 Nov 2021 18:41:51 +0530 Subject: [PATCH 5/7] Update changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 593b2b607d7c0..5c0ca7b4f5f5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added Rich progress bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929), [#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559)) * Added Support for iterable datasets ([#9734](https://github.com/PyTorchLightning/pytorch-lightning/pull/9734)) * Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546)) + * Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288)) + * Added `display_every_n_epochs` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301)) - Added input validation logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080)) - Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) @@ -128,9 +130,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264)) - Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818)) - Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246)) -- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288)) - Added a warning if multiple batch sizes are found from ambiguous batch ([#10247](https://github.com/PyTorchLightning/pytorch-lightning/pull/10247)) -- Added `display_every_n_epochs` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301)) ### Changed From 437e33ab8573b25d78eaa616d45df83f78e29191 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 2 Nov 2021 21:38:21 +0530 Subject: [PATCH 6/7] use leave argument instead --- CHANGELOG.md | 2 +- .../callbacks/progress/rich_progress.py | 17 ++++------------- tests/callbacks/test_rich_progress_bar.py | 13 +++---------- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c0ca7b4f5f5f..94057eac37d6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,7 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added Support for iterable datasets ([#9734](https://github.com/PyTorchLightning/pytorch-lightning/pull/9734)) * Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546)) * Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288)) - * Added `display_every_n_epochs` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301)) + * Added `leave` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301)) - Added input validation logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080)) - Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 3964c70b218d6..66c4161ba68f4 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -18,7 +18,6 @@ from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities import _RICH_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException Task, Style = None, None if _RICH_AVAILABLE: @@ -196,7 +195,7 @@ class RichProgressBar(ProgressBarBase): Args: refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. - display_every_n_epochs: Set to a non-negative integer to display progress bar every n epochs. + leave: Displays progress bar per epoch. Default: False theme: Contains styles used to stylize the progress bar. Raises: @@ -207,7 +206,7 @@ class RichProgressBar(ProgressBarBase): def __init__( self, refresh_rate_per_second: int = 10, - display_every_n_epochs: Optional[int] = None, + leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), ) -> None: if not _RICH_AVAILABLE: @@ -216,11 +215,7 @@ def __init__( ) super().__init__() self._refresh_rate_per_second: int = refresh_rate_per_second - if display_every_n_epochs and (not isinstance(display_every_n_epochs, int) or (display_every_n_epochs < 0)): - raise MisconfigurationException( - f"`display_every_n_epochs` should be an int >= 0, got {display_every_n_epochs}." - ) - self._display_every_n_epochs: Optional[int] = display_every_n_epochs + self._leave: bool = leave self._enabled: bool = True self.progress: Optional[Progress] = None self.val_sanity_progress_bar_id: Optional[int] = None @@ -331,11 +326,7 @@ def on_train_epoch_start(self, trainer, pl_module): total_batches = total_train_batches + total_val_batches train_description = self._get_train_description(trainer.current_epoch) - if ( - self.main_progress_bar_id is not None - and self._display_every_n_epochs - and (not trainer.current_epoch % self._display_every_n_epochs) - ): + if self.main_progress_bar_id is not None and self._leave: self._stop_progress() self._init_progress(trainer, pl_module) if self.main_progress_bar_id is None: diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 9c167a71835bf..6c0a201c794c3 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -20,7 +20,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.runif import RunIf @@ -163,15 +162,15 @@ def configure_columns(self, trainer, pl_module): @RunIf(rich=True) -@pytest.mark.parametrize(("display_every_n_epochs", "reset_call_count"), ([(None, 5), (1, 0), (2, 3), (3, 4)])) -def test_rich_progress_bar_display_every_n_epochs(tmpdir, display_every_n_epochs, reset_call_count): +@pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 5)])) +def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count): # Calling `reset` means continuing on the same progress bar. model = BoringModel() with mock.patch( "pytorch_lightning.callbacks.progress.rich_progress.Progress.reset", autospec=True ) as mock_progress_reset: - progress_bar = RichProgressBar(display_every_n_epochs=display_every_n_epochs) + progress_bar = RichProgressBar(leave=leave) trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -181,9 +180,3 @@ def test_rich_progress_bar_display_every_n_epochs(tmpdir, display_every_n_epochs ) trainer.fit(model) assert mock_progress_reset.call_count == reset_call_count - - -@RunIf(rich=True) -def test_rich_progress_bar_display_every_n_steps_misconfiguration(): - with pytest.raises(MisconfigurationException, match=r"`display_every_n_epochs` should be an int >= 0"): - Trainer(callbacks=RichProgressBar(display_every_n_epochs="test")) From 7ef1b951b849bf8da35fb65aa6d78950b66466b6 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 2 Nov 2021 17:02:34 +0000 Subject: [PATCH 7/7] Update pytorch_lightning/callbacks/progress/rich_progress.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 66c4161ba68f4..f6f862704f599 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -195,7 +195,7 @@ class RichProgressBar(ProgressBarBase): Args: refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. - leave: Displays progress bar per epoch. Default: False + leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False theme: Contains styles used to stylize the progress bar. Raises: