Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added Rich progress bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929), [#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559))
* Added Support for iterable datasets ([#9734](https://github.com/PyTorchLightning/pytorch-lightning/pull/9734))
* Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))
* Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
* Added `leave` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301))
- Added input validation logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
- Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
- Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))
Expand Down Expand Up @@ -128,7 +130,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264))
- Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818))
- Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246))
- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
- Added a warning if multiple batch sizes are found from ambiguous batch ([#10247](https://github.com/PyTorchLightning/pytorch-lightning/pull/10247))


Expand Down
11 changes: 10 additions & 1 deletion pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class RichProgressBar(ProgressBarBase):

Args:
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
theme: Contains styles used to stylize the progress bar.

Raises:
Expand All @@ -205,6 +206,7 @@ class RichProgressBar(ProgressBarBase):
def __init__(
self,
refresh_rate_per_second: int = 10,
leave: bool = False,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
if not _RICH_AVAILABLE:
Expand All @@ -213,6 +215,7 @@ def __init__(
)
super().__init__()
self._refresh_rate_per_second: int = refresh_rate_per_second
self._leave: bool = leave
self._enabled: bool = True
self.progress: Optional[Progress] = None
self.val_sanity_progress_bar_id: Optional[int] = None
Expand Down Expand Up @@ -323,9 +326,15 @@ def on_train_epoch_start(self, trainer, pl_module):
total_batches = total_train_batches + total_val_batches

train_description = self._get_train_description(trainer.current_epoch)
if self.main_progress_bar_id is not None and self._leave:
self._stop_progress()
self._init_progress(trainer, pl_module)
if self.main_progress_bar_id is None:
self.main_progress_bar_id = self._add_task(total_batches, train_description)
self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description)
else:
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)

def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
Expand Down
23 changes: 22 additions & 1 deletion tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def on_train_start(self) -> None:


@RunIf(rich=True)
def test_rich_progress_bar_configure_columns(tmpdir):
def test_rich_progress_bar_configure_columns():
from rich.progress import TextColumn

custom_column = TextColumn("[progress.description]Testing Rich!")
Expand All @@ -159,3 +159,24 @@ def configure_columns(self, trainer, pl_module):

assert progress_bar.progress.columns[0] == custom_column
assert len(progress_bar.progress.columns) == 1


@RunIf(rich=True)
@pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 5)]))
def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
# Calling `reset` means continuing on the same progress bar.
model = BoringModel()

with mock.patch(
"pytorch_lightning.callbacks.progress.rich_progress.Progress.reset", autospec=True
) as mock_progress_reset:
progress_bar = RichProgressBar(leave=leave)
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=1,
max_epochs=6,
callbacks=progress_bar,
)
trainer.fit(model)
assert mock_progress_reset.call_count == reset_call_count