Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))


- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
- Added Rich Progress Bar:
* Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
* Improvements for rich progress bar ([#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559))


- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
Expand Down
89 changes: 59 additions & 30 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, Optional
from typing import Optional, Union

from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities import _RICH_AVAILABLE

Style = None
if _RICH_AVAILABLE:
from rich.console import Console, RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
from rich.style import Style
from rich.text import Text

class CustomTimeColumn(ProgressColumn):

# Only refresh twice a second to prevent jitter
max_refresh = 0.5

def __init__(self, style: Union[str, Style]) -> None:
self.style = style
super().__init__()

def render(self, task) -> Text:
elapsed = task.finished_time if task.finished else task.elapsed
remaining = task.time_remaining
elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}")
return Text(f"{elapsed_delta} {remaining_delta}", style=self.style)

class BatchesProcessedColumn(ProgressColumn):
def __init__(self, style: Union[str, Style]):
self.style = style
super().__init__()

def render(self, task) -> RenderableType:
return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}")
return Text(f"{int(task.completed)}/{task.total}", style=self.style)

class ProcessingSpeedColumn(ProgressColumn):
def __init__(self, style: Union[str, Style]):
self.style = style
super().__init__()

def render(self, task) -> RenderableType:
task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00"
return Text.from_markup(f"[progress.data.speed] {task_speed}it/s")
return Text(f"{task_speed}it/s", style=self.style)

class MetricsTextColumn(ProgressColumn):
"""A column containing text."""
Expand Down Expand Up @@ -71,19 +86,26 @@ def render(self, task) -> Text:
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
else:
metrics = self._trainer.progress_bar_metrics

for k, v in metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
text = Text.from_markup(_text, style=None, justify="left")
return text


STYLES: Dict[str, str] = {
"train": "red",
"sanity_check": "yellow",
"validate": "yellow",
"test": "yellow",
"predict": "yellow",
}
@dataclass
class RichProgressBarTheme:
"""Styles to associate to different base components.

https://rich.readthedocs.io/en/stable/style.html
"""

text_color: str = "white"
progress_bar_complete: Union[str, Style] = "#6206E0"
progress_bar_finished: Union[str, Style] = "#6206E0"
batch_process: str = "white"
time: str = "grey54"
processing_speed: str = "grey70"
Comment on lines -80 to +108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SeanNaren @kaushikb11

noob question: Why did we introduce this new class RichProgressBarTheme and pass an instance to RichProgressBar? I'm asking because if we didn't have this class and passed colours (or a dictionary of colours) directly to RichProgressBar, colour customisation would be easier for users in my opinion. (One less thing to remember :])

from pytorch_lightning.callbacks import RichProgressBar
-from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

progress_bar = RichProgressBar(
-    theme=RichProgressBarTheme(
        description="green_yellow",
        progress_bar="green1",
        progress_bar_finished="green1",
        batch_progress="green_yellow",
        time="grey82",
        processing_speed="grey82",
        metrics="grey82",
-    )
)

FYI, I'm looking at the example in the blog post: https://devblog.pytorchlightning.ai/super-charged-progress-bars-with-rich-lightning-669653d6ab97



class RichProgressBar(ProgressBarBase):
Expand All @@ -104,13 +126,18 @@ class RichProgressBar(ProgressBarBase):

Args:
refresh_rate: the number of updates per second, must be strictly positive
theme: Contains styles used to stylize the progress bar.

Raises:
ImportError:
If required `rich` package is not installed on the device.
"""

def __init__(self, refresh_rate: float = 1.0):
def __init__(
self,
refresh_rate: float = 1.0,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
if not _RICH_AVAILABLE:
raise ImportError(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
Expand All @@ -126,6 +153,7 @@ def __init__(self, refresh_rate: float = 1.0):
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None
self.console = Console(record=True)
self.theme = theme

@property
def refresh_rate(self) -> int:
Expand All @@ -147,39 +175,36 @@ def enable(self) -> None:

@property
def sanity_check_description(self) -> str:
return "[Validation Sanity Check]"
return "Validation Sanity Check"

@property
def validation_description(self) -> str:
return "[Validation]"
return "Validation"

@property
def test_description(self) -> str:
return "[Testing]"
return "Testing"

@property
def predict_description(self) -> str:
return "[Predicting]"
return "Predicting"

def setup(self, trainer, pl_module, stage):
self.progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
BatchesProcessedColumn(),
"[",
CustomTimeColumn(),
ProcessingSpeedColumn(),
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module, stage),
"]",
console=self.console,
refresh_per_second=self.refresh_rate,
).__enter__()

def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_sanity_progress_bar_id = self.progress.add_task(
f"[{STYLES['sanity_check']}]{self.sanity_check_description}",
f"[{self.theme.text_color}]{self.sanity_check_description}",
total=trainer.num_sanity_val_steps,
)

Expand All @@ -201,15 +226,15 @@ def on_train_epoch_start(self, trainer, pl_module):
train_description = self._get_train_description(trainer.current_epoch)

self.main_progress_bar_id = self.progress.add_task(
f"[{STYLES['train']}]{train_description}",
f"[{self.theme.text_color}]{train_description}",
total=total_batches,
)

def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
if self._total_val_batches > 0:
self.val_progress_bar_id = self.progress.add_task(
f"[{STYLES['validate']}]{self.validation_description}",
f"[{self.theme.text_color}]{self.validation_description}",
total=self._total_val_batches,
)

Expand All @@ -221,14 +246,14 @@ def on_validation_epoch_end(self, trainer, pl_module):
def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self.progress.add_task(
f"[{STYLES['test']}]{self.test_description}",
f"[{self.theme.text_color}]{self.test_description}",
total=self.total_test_batches,
)

def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self.progress.add_task(
f"[{STYLES['predict']}]{self.predict_description}",
f"[{self.theme.text_color}]{self.predict_description}",
total=self.total_predict_batches,
)

Expand Down Expand Up @@ -261,7 +286,7 @@ def _should_update(self, current, total) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def _get_train_description(self, current_epoch: int) -> str:
train_description = f"[Epoch {current_epoch}]"
train_description = f"Epoch {current_epoch}"
if len(self.validation_description) > len(train_description):
# Padding is required to avoid flickering due of uneven lengths of "Epoch X"
# and "Validation" Bar description
Expand All @@ -273,3 +298,7 @@ def _get_train_description(self, current_epoch: int) -> str:

def teardown(self, trainer, pl_module, stage):
self.progress.__exit__(None, None, None)

def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
if isinstance(exception, KeyboardInterrupt):
self.progress.stop()
60 changes: 57 additions & 3 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import DEFAULT

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


@RunIf(rich=True)
def test_rich_progress_bar_callback():

trainer = Trainer(callbacks=RichProgressBar())

progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
Expand All @@ -36,7 +37,6 @@ def test_rich_progress_bar_callback():
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
def test_rich_progress_bar(progress_update, tmpdir):

model = BoringModel()

trainer = Trainer(
Expand All @@ -58,7 +58,61 @@ def test_rich_progress_bar(progress_update, tmpdir):


def test_rich_progress_bar_import_error():

if not _RICH_AVAILABLE:
with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."):
Trainer(callbacks=RichProgressBar())


@RunIf(rich=True)
def test_rich_progress_bar_custom_theme(tmpdir):
"""Test to ensure that custom theme styles are used."""
with mock.patch.multiple(
"pytorch_lightning.callbacks.progress.rich_progress",
BarColumn=DEFAULT,
BatchesProcessedColumn=DEFAULT,
CustomTimeColumn=DEFAULT,
ProcessingSpeedColumn=DEFAULT,
) as mocks:

theme = RichProgressBarTheme()

progress_bar = RichProgressBar(theme=theme)
progress_bar.setup(Trainer(tmpdir), BoringModel(), stage=None)

assert progress_bar.theme == theme
args, kwargs = mocks["BarColumn"].call_args
assert kwargs["complete_style"] == theme.progress_bar_complete
assert kwargs["finished_style"] == theme.progress_bar_finished

args, kwargs = mocks["BatchesProcessedColumn"].call_args
assert kwargs["style"] == theme.batch_process

args, kwargs = mocks["CustomTimeColumn"].call_args
assert kwargs["style"] == theme.time

args, kwargs = mocks["ProcessingSpeedColumn"].call_args
assert kwargs["style"] == theme.processing_speed


@RunIf(rich=True)
def test_rich_progress_bar_keyboard_interrupt(tmpdir):
"""Test to ensure that when the user keyboard interrupts, we close the progress bar."""

class TestModel(BoringModel):
def on_train_start(self) -> None:
raise KeyboardInterrupt

model = TestModel()

with mock.patch(
"pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop:
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
callbacks=progress_bar,
)

trainer.fit(model)
mock_progress_stop.assert_called_once()