diff --git a/CHANGELOG.md b/CHANGELOG.md
index 209a8a4671028..94057eac37d6c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added Rich progress bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929), [#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559))
     * Added Support for iterable datasets ([#9734](https://github.com/PyTorchLightning/pytorch-lightning/pull/9734))
     * Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))
+    * Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
+    * Added `leave` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301))
 - Added input validation logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
 - Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
 - Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))
@@ -128,7 +130,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264))
 - Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818))
 - Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246))
-- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
 - Added a warning if multiple batch sizes are found from ambiguous batch ([#10247](https://github.com/PyTorchLightning/pytorch-lightning/pull/10247))
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py
index d684f8a7e38ed..f6f862704f599 100644
--- a/pytorch_lightning/callbacks/progress/rich_progress.py
+++ b/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -195,6 +195,7 @@ class RichProgressBar(ProgressBarBase):
     Args:
         refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
+        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
         theme: Contains styles used to stylize the progress bar.

     Raises:
@@ -205,6 +206,7 @@ class RichProgressBar(ProgressBarBase):
     def __init__(
         self,
         refresh_rate_per_second: int = 10,
+        leave: bool = False,
         theme: RichProgressBarTheme = RichProgressBarTheme(),
     ) -> None:
         if not _RICH_AVAILABLE:
@@ -213,6 +215,7 @@ def __init__(
             )
         super().__init__()
         self._refresh_rate_per_second: int = refresh_rate_per_second
+        self._leave: bool = leave
         self._enabled: bool = True
         self.progress: Optional[Progress] = None
         self.val_sanity_progress_bar_id: Optional[int] = None
@@ -323,9 +326,15 @@ def on_train_epoch_start(self, trainer, pl_module):
         total_batches = total_train_batches + total_val_batches

         train_description = self._get_train_description(trainer.current_epoch)
+        if self.main_progress_bar_id is not None and self._leave:
+            self._stop_progress()
+            self._init_progress(trainer, pl_module)
         if self.main_progress_bar_id is None:
             self.main_progress_bar_id = self._add_task(total_batches, train_description)
-        self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description)
+        else:
+            self.progress.reset(
+                self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
+            )

     def on_validation_epoch_start(self, trainer, pl_module):
         super().on_validation_epoch_start(trainer, pl_module)
diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py
index ab0852c4729af..6c0a201c794c3 100644
--- a/tests/callbacks/test_rich_progress_bar.py
+++ b/tests/callbacks/test_rich_progress_bar.py
@@ -144,7 +144,7 @@ def on_train_start(self) -> None:

 @RunIf(rich=True)
-def test_rich_progress_bar_configure_columns(tmpdir):
+def test_rich_progress_bar_configure_columns():
     from rich.progress import TextColumn

     custom_column = TextColumn("[progress.description]Testing Rich!")
@@ -159,3 +159,24 @@ def configure_columns(self, trainer, pl_module):

     assert progress_bar.progress.columns[0] == custom_column
     assert len(progress_bar.progress.columns) == 1
+
+
+@RunIf(rich=True)
+@pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 5)]))
+def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
+    # Calling `reset` means continuing on the same progress bar.
+    model = BoringModel()
+
+    with mock.patch(
+        "pytorch_lightning.callbacks.progress.rich_progress.Progress.reset", autospec=True
+    ) as mock_progress_reset:
+        progress_bar = RichProgressBar(leave=leave)
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            num_sanity_val_steps=0,
+            limit_train_batches=1,
+            max_epochs=6,
+            callbacks=progress_bar,
+        )
+        trainer.fit(model)
+    assert mock_progress_reset.call_count == reset_call_count
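For reference, a minimal usage sketch of the new `leave` flag (not part of the patch): it passes `RichProgressBar(leave=True)` to a `Trainer` so that each epoch's finished bar stays in the terminal instead of a single bar being reset in place, mirroring what the `test_rich_progress_bar_leave` test exercises with mocks. The `TinyModel` module and random dataset below are illustrative stand-ins, not code from this repository.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar


class TinyModel(pl.LightningModule):
    # Illustrative stand-in for any user LightningModule.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        # Any scalar tensor works as a loss for this sketch.
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    # leave=True keeps every finished epoch's bar in the terminal;
    # the default (leave=False) reuses and resets one bar across epochs.
    trainer = pl.Trainer(max_epochs=3, callbacks=RichProgressBar(leave=True))
    trainer.fit(TinyModel(), train_loader)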