diff --git a/CHANGELOG.md b/CHANGELOG.md index daee5ae803144..e5d266c77dfd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -154,6 +154,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783)) +- Use only a single instance of `rich.console.Console` throughout codebase ([#12886](https://github.com/PyTorchLightning/pytorch-lightning/pull/12886)) + + ## [1.6.1] - 2022-04-13 ### Changed diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 741d4b85d9214..131cec031be2a 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -22,7 +22,8 @@ Task, Style = None, None if _RICH_AVAILABLE: - from rich.console import Console, RenderableType + from rich import get_console, reconfigure + from rich.console import RenderableType from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar from rich.style import Style @@ -278,7 +279,8 @@ def enable(self) -> None: def _init_progress(self, trainer): if self.is_enabled and (self.progress is None or self._progress_stopped): self._reset_progress_bar_ids() - self._console = Console(**self._console_kwargs) + reconfigure(**self._console_kwargs) + self._console = get_console() self._console.clear_live() self._metric_component = MetricsTextColumn(trainer, self.theme.metrics) self.progress = CustomProgress( diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 830e865632a3e..c290ee764d8b4 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities.model_summary import get_human_readable_count if _RICH_AVAILABLE: - from rich.console import Console + from rich import get_console from rich.table import Table @@ -73,7 +73,7 @@ def summarize( model_size: float, ) -> None: - console = Console() + console = get_console() table = Table(header_style="bold magenta") table.add_column(" ", style="dim") diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index d8b511e29a5c8..26e5098e99942 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -16,7 +16,7 @@ import sys from collections import ChainMap, OrderedDict from functools import partial -from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union +from typing import Any, Iterable, List, Optional, Sequence, Type, Union import torch from deprecate.utils import void @@ -42,7 +42,7 @@ from pytorch_lightning.utilities.types import EPOCH_OUTPUT if _RICH_AVAILABLE: - from rich.console import Console + from rich import get_console from rich.table import Column, Table @@ -319,11 +319,7 @@ def _find_value(data: dict, target: str) -> Iterable[Any]: yield from EvaluationLoop._find_value(v, target) @staticmethod - def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None: - # print to stdout by default - if file is None: - file = sys.stdout - + def _print_results(results: List[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys}) @@ -358,8 +354,6 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] table_headers.insert(0, f"{stage} Metric".capitalize()) if _RICH_AVAILABLE: - console = Console(file=file) - columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers] columns[0].style = "cyan" @@ -367,6 +361,8 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] for metric, row in zip(metrics, table_rows): row.insert(0, metric) table.add_row(*row) + + console = get_console() console.print(table) else: row_format = f"{{:^{max_length}}}" * len(table_headers) @@ -374,8 +370,8 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] try: # some terminals do not support this character - if hasattr(file, "encoding") and file.encoding is not None: - "─".encode(file.encoding) + if sys.stdout.encoding is not None: + "─".encode(sys.stdout.encoding) except UnicodeEncodeError: bar_character = "-" else: @@ -394,7 +390,7 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] else: lines.append(row_format.format(metric, *row).rstrip()) lines.append(bar) - print(os.linesep.join(lines), file=file) + print(os.linesep.join(lines)) def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py index c596557eed0dc..d9e4ec55902ca 100644 --- a/tests/callbacks/test_rich_model_summary.py +++ b/tests/callbacks/test_rich_model_summary.py @@ -41,8 +41,8 @@ def test_rich_progress_bar_import_error(monkeypatch): @RunIf(rich=True) -@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Console.print", autospec=True) -@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Table.add_row", autospec=True) +@mock.patch("rich.console.Console.print", autospec=True) +@mock.patch("rich.table.Table.add_row", autospec=True) def test_rich_summary_tuples(mock_table_add_row, mock_console): """Ensure that tuples are converted into string, and print is called correctly.""" model_summary = RichModelSummary() diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 0ca9bf3107b9c..c9f5632e4aee4 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -15,6 +15,7 @@ import collections import itertools import os +from contextlib import redirect_stdout from io import StringIO from unittest import mock from unittest.mock import call @@ -28,10 +29,13 @@ from pytorch_lightning.loops.dataloader import EvaluationLoop from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 +from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _RICH_AVAILABLE from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf +if _RICH_AVAILABLE: + from rich import get_console + def test__validation_step__log(tmpdir): """Tests that validation_step can log.""" @@ -864,8 +868,9 @@ def test_native_print_results(monkeypatch, inputs, expected): import pytorch_lightning.loops.dataloader.evaluation_loop as imports monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) - out = StringIO() - EvaluationLoop._print_results(*inputs, file=out) + + with redirect_stdout(StringIO()) as out: + EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip() @@ -878,7 +883,8 @@ def test_native_print_results_encodings(monkeypatch, encoding): out = mock.Mock() out.encoding = encoding - EvaluationLoop._print_results(*inputs0, file=out) + with redirect_stdout(out) as out: + EvaluationLoop._print_results(*inputs0) # Attempt to encode everything the file is told to write with the given encoding for call_ in out.method_calls: @@ -950,7 +956,8 @@ def test_native_print_results_encodings(monkeypatch, encoding): ) @RunIf(skip_windows=True, rich=True) def test_rich_print_results(inputs, expected): - out = StringIO() - EvaluationLoop._print_results(*inputs, file=out) + console = get_console() + with console.capture() as capture: + EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string - assert out.getvalue() == expected.lstrip() + assert capture.get() == expected.lstrip()