Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716))


- Fixed encoding issues on terminals that do not support unicode characters ([#12828](https://github.com/PyTorchLightning/pytorch-lightning/pull/12828))


- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))


Expand Down
16 changes: 15 additions & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os
import shutil
import sys
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
Expand Down Expand Up @@ -336,6 +337,10 @@ def _find_value(data: dict, target: str) -> Iterable[Any]:

@staticmethod
def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None:
# print to stdout by default
if file is None:
file = sys.stdout

# remove the dl idx suffix
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys})
Expand Down Expand Up @@ -384,7 +389,16 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]]
row_format = f"{{:^{max_length}}}" * len(table_headers)
half_term_size = int(term_size / 2)

bar = "─" * term_size
try:
# some terminals do not support this character
if hasattr(file, "encoding") and file.encoding is not None:
"─".encode(file.encoding)
except UnicodeEncodeError:
bar_character = "-"
else:
bar_character = "─"
bar = bar_character * term_size

lines = [bar, row_format.format(*table_headers).rstrip(), bar]
for metric, row in zip(metrics, table_rows):
# deal with column overflow
Expand Down
17 changes: 17 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,23 @@ def test_native_print_results(monkeypatch, inputs, expected):
assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()


@pytest.mark.parametrize("encoding", ["latin-1", "utf-8"])
def test_native_print_results_encodings(monkeypatch, encoding):
import pytorch_lightning.loops.dataloader.evaluation_loop as imports

monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)

out = mock.Mock()
out.encoding = encoding
EvaluationLoop._print_results(*inputs0, file=out)

# Attempt to encode everything the file is told to write with the given encoding
for call_ in out.method_calls:
name, args, kwargs = call_
if name == "write":
args[0].encode(encoding)


expected0 = """
┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃ DataLoader 1 ┃
Expand Down