Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))


- Use only a single instance of `rich.console.Console` throughout codebase ([#12886](https://github.com/PyTorchLightning/pytorch-lightning/pull/12886))


## [1.6.1] - 2022-04-13

### Changed
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

Task, Style = None, None
if _RICH_AVAILABLE:
from rich.console import Console, RenderableType
from rich import get_console, reconfigure
from rich.console import RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
from rich.progress_bar import ProgressBar
from rich.style import Style
Expand Down Expand Up @@ -278,7 +279,8 @@ def enable(self) -> None:
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
self._console = Console(**self._console_kwargs)
reconfigure(**self._console_kwargs)
self._console = get_console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
self.progress = CustomProgress(
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pytorch_lightning.utilities.model_summary import get_human_readable_count

if _RICH_AVAILABLE:
from rich.console import Console
from rich import get_console
from rich.table import Table


Expand Down Expand Up @@ -73,7 +73,7 @@ def summarize(
model_size: float,
) -> None:

console = Console()
console = get_console()

table = Table(header_style="bold magenta")
table.add_column(" ", style="dim")
Expand Down
20 changes: 8 additions & 12 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sys
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
from typing import Any, Iterable, List, Optional, Sequence, Type, Union

import torch
from deprecate.utils import void
Expand All @@ -42,7 +42,7 @@
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

if _RICH_AVAILABLE:
from rich.console import Console
from rich import get_console
from rich.table import Column, Table


Expand Down Expand Up @@ -319,11 +319,7 @@ def _find_value(data: dict, target: str) -> Iterable[Any]:
yield from EvaluationLoop._find_value(v, target)

@staticmethod
def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None:
# print to stdout by default
if file is None:
file = sys.stdout

def _print_results(results: List[_OUT_DICT], stage: str) -> None:
# remove the dl idx suffix
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys})
Expand Down Expand Up @@ -358,24 +354,24 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]]
table_headers.insert(0, f"{stage} Metric".capitalize())

if _RICH_AVAILABLE:
console = Console(file=file)

columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers]
columns[0].style = "cyan"

table = Table(*columns)
for metric, row in zip(metrics, table_rows):
row.insert(0, metric)
table.add_row(*row)

console = get_console()
console.print(table)
else:
row_format = f"{{:^{max_length}}}" * len(table_headers)
half_term_size = int(term_size / 2)

try:
# some terminals do not support this character
if hasattr(file, "encoding") and file.encoding is not None:
"─".encode(file.encoding)
if sys.stdout.encoding is not None:
"─".encode(sys.stdout.encoding)
except UnicodeEncodeError:
bar_character = "-"
else:
Expand All @@ -394,7 +390,7 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]]
else:
lines.append(row_format.format(metric, *row).rstrip())
lines.append(bar)
print(os.linesep.join(lines), file=file)
print(os.linesep.join(lines))


def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/test_rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_rich_progress_bar_import_error(monkeypatch):


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Console.print", autospec=True)
@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Table.add_row", autospec=True)
@mock.patch("rich.console.Console.print", autospec=True)
@mock.patch("rich.table.Table.add_row", autospec=True)
def test_rich_summary_tuples(mock_table_add_row, mock_console):
"""Ensure that tuples are converted into string, and print is called correctly."""
model_summary = RichModelSummary()
Expand Down
21 changes: 14 additions & 7 deletions tests/trainer/logging_/test_eval_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import collections
import itertools
import os
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock
from unittest.mock import call
Expand All @@ -28,10 +29,13 @@
from pytorch_lightning.loops.dataloader import EvaluationLoop
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _RICH_AVAILABLE
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

if _RICH_AVAILABLE:
from rich import get_console


def test__validation_step__log(tmpdir):
"""Tests that validation_step can log."""
Expand Down Expand Up @@ -864,8 +868,9 @@ def test_native_print_results(monkeypatch, inputs, expected):
import pytorch_lightning.loops.dataloader.evaluation_loop as imports

monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
out = StringIO()
EvaluationLoop._print_results(*inputs, file=out)

with redirect_stdout(StringIO()) as out:
EvaluationLoop._print_results(*inputs)
expected = expected[1:] # remove the initial line break from the """ string
assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()

Expand All @@ -878,7 +883,8 @@ def test_native_print_results_encodings(monkeypatch, encoding):

out = mock.Mock()
out.encoding = encoding
EvaluationLoop._print_results(*inputs0, file=out)
with redirect_stdout(out) as out:
EvaluationLoop._print_results(*inputs0)

# Attempt to encode everything the file is told to write with the given encoding
for call_ in out.method_calls:
Expand Down Expand Up @@ -950,7 +956,8 @@ def test_native_print_results_encodings(monkeypatch, encoding):
)
@RunIf(skip_windows=True, rich=True)
def test_rich_print_results(inputs, expected):
out = StringIO()
EvaluationLoop._print_results(*inputs, file=out)
console = get_console()
with console.capture() as capture:
EvaluationLoop._print_results(*inputs)
expected = expected[1:] # remove the initial line break from the """ string
assert out.getvalue() == expected.lstrip()
assert capture.get() == expected.lstrip()