From b8a9ef7d403bd09ac8febe200d6b15ac2f523bb6 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Apr 2022 11:48:49 +0200 Subject: [PATCH 1/7] use only one rich console throughout codebase --- pytorch_lightning/callbacks/progress/rich_progress.py | 6 ++++-- pytorch_lightning/callbacks/rich_model_summary.py | 4 ++-- pytorch_lightning/loops/dataloader/evaluation_loop.py | 10 ++++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 741d4b85d9214..131cec031be2a 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -22,7 +22,8 @@ Task, Style = None, None if _RICH_AVAILABLE: - from rich.console import Console, RenderableType + from rich import get_console, reconfigure + from rich.console import RenderableType from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar from rich.style import Style @@ -278,7 +279,8 @@ def enable(self) -> None: def _init_progress(self, trainer): if self.is_enabled and (self.progress is None or self._progress_stopped): self._reset_progress_bar_ids() - self._console = Console(**self._console_kwargs) + reconfigure(**self._console_kwargs) + self._console = get_console() self._console.clear_live() self._metric_component = MetricsTextColumn(trainer, self.theme.metrics) self.progress = CustomProgress( diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 830e865632a3e..c290ee764d8b4 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities.model_summary import get_human_readable_count if _RICH_AVAILABLE: - from rich.console import Console + from rich import get_console from rich.table import Table @@ -73,7 +73,7 @@ def summarize( model_size: float, ) -> None: - console = Console() + console = get_console() table = Table(header_style="bold magenta") table.add_column(" ", style="dim") diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index d8b511e29a5c8..3ac62367ec934 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -42,7 +42,7 @@ from pytorch_lightning.utilities.types import EPOCH_OUTPUT if _RICH_AVAILABLE: - from rich.console import Console + from rich import get_console from rich.table import Column, Table @@ -358,8 +358,6 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] table_headers.insert(0, f"{stage} Metric".capitalize()) if _RICH_AVAILABLE: - console = Console(file=file) - columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers] columns[0].style = "cyan" @@ -367,7 +365,11 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] for metric, row in zip(metrics, table_rows): row.insert(0, metric) table.add_row(*row) - console.print(table) + + console = get_console() + with console.capture() as capture: + console.print(table) + print(capture.get(), file=file) else: row_format = f"{{:^{max_length}}}" * len(table_headers) half_term_size = int(term_size / 2) From ae33078377a77da020d6167f0b513196b35d5ef0 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Apr 2022 13:39:28 +0200 Subject: [PATCH 2/7] add tests --- tests/callbacks/test_rich_model_summary.py | 4 +-- tests/callbacks/test_rich_progress_bar.py | 36 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py index c596557eed0dc..d9e4ec55902ca 100644 --- a/tests/callbacks/test_rich_model_summary.py +++ b/tests/callbacks/test_rich_model_summary.py @@ -41,8 +41,8 @@ def test_rich_progress_bar_import_error(monkeypatch): @RunIf(rich=True) -@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Console.print", autospec=True) -@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Table.add_row", autospec=True) +@mock.patch("rich.console.Console.print", autospec=True) +@mock.patch("rich.table.Table.add_row", autospec=True) def test_rich_summary_tuples(mock_table_add_row, mock_console): """Ensure that tuples are converted into string, and print is called correctly.""" model_summary = RichModelSummary() diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 8fdcb6c99e331..ad903fa17568e 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict +from contextlib import redirect_stdout +from io import StringIO from unittest import mock from unittest.mock import DEFAULT, Mock @@ -400,3 +402,37 @@ def test_step(self, batch, batch_idx): trainer.test(model, verbose=False) assert pbar.calls["test"] == [] + + +@RunIf(rich=True) +def test_rich_print_results_with_progress_bar(tmpdir): + """Test whether Rich table is rendered on its own line. + + Test to counter the issue #12824 + """ + + expected = "\n┏━━━━━━━" + + class MyModel(BoringModel): + def test_step(self, batch, batch_idx): + self.log("c", self.global_step) + return super().test_step(batch, batch_idx) + + with redirect_stdout(StringIO()) as out: + model = MyModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=2, + enable_model_summary=False, + enable_checkpointing=False, + log_every_n_steps=1, + callbacks=RichProgressBar(), + ) + + trainer.fit(model) + trainer.test(model) + + assert expected in out.getvalue() From 038a1fc16a280d48fe31fc1f30b223e78a159d21 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Apr 2022 13:50:57 +0200 Subject: [PATCH 3/7] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e0d8b51f38bc..15b74a6cf4717 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783)) +- Use only a single instance of `rich.console.Console` throughout codebase ([#12886](https://github.com/PyTorchLightning/pytorch-lightning/pull/12886)) + + ## [1.6.1] - 2022-04-13 ### Changed From ae414d526939c51d4cf3eb8653db6e2491945144 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Apr 2022 14:33:03 +0200 Subject: [PATCH 4/7] fixing failing tests --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 3ac62367ec934..b3a44089b5604 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -369,7 +369,7 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] console = get_console() with console.capture() as capture: console.print(table) - print(capture.get(), file=file) + print(capture.get(), end="", file=file) else: row_format = f"{{:^{max_length}}}" * len(table_headers) half_term_size = int(term_size / 2) From a7e10378610a5c7876019161be7301414d769743 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Apr 2022 15:00:00 +0200 Subject: [PATCH 5/7] windows test only + moved to correct location --- tests/callbacks/test_rich_progress_bar.py | 36 ------------------- .../logging_/test_eval_loop_logging.py | 36 +++++++++++++++++++ 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index ad903fa17568e..8fdcb6c99e331 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict -from contextlib import redirect_stdout -from io import StringIO from unittest import mock from unittest.mock import DEFAULT, Mock @@ -402,37 +400,3 @@ def test_step(self, batch, batch_idx): trainer.test(model, verbose=False) assert pbar.calls["test"] == [] - - -@RunIf(rich=True) -def test_rich_print_results_with_progress_bar(tmpdir): - """Test whether Rich table is rendered on its own line. - - Test to counter the issue #12824 - """ - - expected = "\n┏━━━━━━━" - - class MyModel(BoringModel): - def test_step(self, batch, batch_idx): - self.log("c", self.global_step) - return super().test_step(batch, batch_idx) - - with redirect_stdout(StringIO()) as out: - model = MyModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - max_epochs=2, - enable_model_summary=False, - enable_checkpointing=False, - log_every_n_steps=1, - callbacks=RichProgressBar(), - ) - - trainer.fit(model) - trainer.test(model) - - assert expected in out.getvalue() diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 0ca9bf3107b9c..2baccd5a4375b 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -15,6 +15,7 @@ import collections import itertools import os +from contextlib import redirect_stdout from io import StringIO from unittest import mock from unittest.mock import call @@ -24,6 +25,7 @@ import torch from pytorch_lightning import callbacks, Trainer +from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loops.dataloader import EvaluationLoop from pytorch_lightning.trainer.states import RunningStage @@ -954,3 +956,37 @@ def test_rich_print_results(inputs, expected): EvaluationLoop._print_results(*inputs, file=out) expected = expected[1:] # remove the initial line break from the """ string assert out.getvalue() == expected.lstrip() + + +@RunIf(rich=True, skip_windows=True) +def test_rich_print_results_with_progress_bar(tmpdir): + """Test whether Rich table is rendered on its own line. + + Test to counter the issue #12824 + """ + + expected = "\n┏━━━━━━━" + + class MyModel(BoringModel): + def test_step(self, batch, batch_idx): + self.log("c", self.global_step) + return super().test_step(batch, batch_idx) + + with redirect_stdout(StringIO()) as out: + model = MyModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=2, + enable_model_summary=False, + enable_checkpointing=False, + log_every_n_steps=1, + callbacks=RichProgressBar(), + ) + + trainer.fit(model) + trainer.test(model) + + assert expected in out.getvalue() From 8ed64e6a3654d326b9c3117459b2435250cb3dfb Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Apr 2022 18:23:14 +0200 Subject: [PATCH 6/7] remove file argument from _print_results --- .../loops/dataloader/evaluation_loop.py | 18 ++++++----------- .../logging_/test_eval_loop_logging.py | 20 ++++++++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index b3a44089b5604..26e5098e99942 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -16,7 +16,7 @@ import sys from collections import ChainMap, OrderedDict from functools import partial -from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union +from typing import Any, Iterable, List, Optional, Sequence, Type, Union import torch from deprecate.utils import void @@ -319,11 +319,7 @@ def _find_value(data: dict, target: str) -> Iterable[Any]: yield from EvaluationLoop._find_value(v, target) @staticmethod - def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None: - # print to stdout by default - if file is None: - file = sys.stdout - + def _print_results(results: List[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys}) @@ -367,17 +363,15 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] table.add_row(*row) console = get_console() - with console.capture() as capture: - console.print(table) - print(capture.get(), end="", file=file) + console.print(table) else: row_format = f"{{:^{max_length}}}" * len(table_headers) half_term_size = int(term_size / 2) try: # some terminals do not support this character - if hasattr(file, "encoding") and file.encoding is not None: - "─".encode(file.encoding) + if sys.stdout.encoding is not None: + "─".encode(sys.stdout.encoding) except UnicodeEncodeError: bar_character = "-" else: @@ -396,7 +390,7 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] else: lines.append(row_format.format(metric, *row).rstrip()) lines.append(bar) - print(os.linesep.join(lines), file=file) + print(os.linesep.join(lines)) def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 2baccd5a4375b..65db25a7189eb 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -30,10 +30,13 @@ from pytorch_lightning.loops.dataloader import EvaluationLoop from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 +from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _RICH_AVAILABLE from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf +if _RICH_AVAILABLE: + from rich import get_console + def test__validation_step__log(tmpdir): """Tests that validation_step can log.""" @@ -866,8 +869,9 @@ def test_native_print_results(monkeypatch, inputs, expected): import pytorch_lightning.loops.dataloader.evaluation_loop as imports monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) - out = StringIO() - EvaluationLoop._print_results(*inputs, file=out) + + with redirect_stdout(StringIO()) as out: + EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip() @@ -880,7 +884,8 @@ def test_native_print_results_encodings(monkeypatch, encoding): out = mock.Mock() out.encoding = encoding - EvaluationLoop._print_results(*inputs0, file=out) + with redirect_stdout(out) as out: + EvaluationLoop._print_results(*inputs0) # Attempt to encode everything the file is told to write with the given encoding for call_ in out.method_calls: @@ -952,10 +957,11 @@ def test_native_print_results_encodings(monkeypatch, encoding): ) @RunIf(skip_windows=True, rich=True) def test_rich_print_results(inputs, expected): - out = StringIO() - EvaluationLoop._print_results(*inputs, file=out) + console = get_console() + with console.capture() as capture: + EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string - assert out.getvalue() == expected.lstrip() + assert capture.get() == expected.lstrip() @RunIf(rich=True, skip_windows=True) From be9cc9a0906d5952ce5ec92c4067bdd98b6f0b12 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Apr 2022 19:17:14 +0200 Subject: [PATCH 7/7] removing too broad test --- .../logging_/test_eval_loop_logging.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 65db25a7189eb..c9f5632e4aee4 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -25,7 +25,6 @@ import torch from pytorch_lightning import callbacks, Trainer -from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loops.dataloader import EvaluationLoop from pytorch_lightning.trainer.states import RunningStage @@ -962,37 +961,3 @@ def test_rich_print_results(inputs, expected): EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string assert capture.get() == expected.lstrip() - - -@RunIf(rich=True, skip_windows=True) -def test_rich_print_results_with_progress_bar(tmpdir): - """Test whether Rich table is rendered on its own line. - - Test to counter the issue #12824 - """ - - expected = "\n┏━━━━━━━" - - class MyModel(BoringModel): - def test_step(self, batch, batch_idx): - self.log("c", self.global_step) - return super().test_step(batch, batch_idx) - - with redirect_stdout(StringIO()) as out: - model = MyModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - max_epochs=2, - enable_model_summary=False, - enable_checkpointing=False, - log_every_n_steps=1, - callbacks=RichProgressBar(), - ) - - trainer.fit(model) - trainer.test(model) - - assert expected in out.getvalue()