From 2ddf9e82b49c790596e12e58973463c038c73e1c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 16 Aug 2021 11:44:27 +0200 Subject: [PATCH 1/2] resolve bug --- .../connectors/logger_connector/result.py | 14 +++++----- .../logging_/test_eval_loop_logging.py | 26 +++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 2b2e4613f2298..3db7c25f5e560 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -289,9 +289,12 @@ class ResultMetricCollection(dict): with the same metadata. """ - def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: + def __init__(self, *args) -> None: super().__init__(*args) - self.meta = metadata + + @property + def meta(self) -> _Metadata: + return list(self.values())[0].meta def __getstate__(self, drop_value: bool = False) -> dict: def getstate(item: ResultMetric) -> dict: @@ -313,9 +316,6 @@ def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]: items = setstate(state["items"]) self.update(items) - any_result_metric = next(iter(items.values())) - self.meta = any_result_metric.meta - @classmethod def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetricCollection": rmc = cls() @@ -480,7 +480,7 @@ def fn(v: _METRIC) -> ResultMetric: value = apply_to_collection(value, (torch.Tensor, Metric), fn) if isinstance(value, dict): - value = ResultMetricCollection(value, metadata=meta) + value = ResultMetricCollection(value) self[key] = value def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: @@ -591,7 +591,7 @@ def extract_batch_size(self, batch: Any) -> None: def to(self, *args, **kwargs) -> "ResultCollection": """Move all data to the given device.""" - + # the meta reference is lost there for ``ResultMetricCollection``. self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) if self.minimize is not None: diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index b46641849c6b6..8579bc044734a 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -590,3 +590,29 @@ def get_metrics_at_idx(idx): "test_loss", } assert set(results[0]) == {"test_loss"} + + +def test_logging_dict_on_validation_step(tmpdir): + class TestModel(BoringModel): + def validation_step(self, batch, batch_idx): + loss = super().validation_step(batch, batch_idx) + loss = loss["x"] + metrics = { + "loss": loss, + "loss_1": loss, + } + self.log("val_metrics", metrics) + + validation_epoch_end = None + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + progress_bar_refresh_rate=1, + ) + + trainer.fit(model) From 833083c896fb11c56fc7fd5ff0f6643e19d91c54 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 16 Aug 2021 11:46:56 +0200 Subject: [PATCH 2/2] update changelog --- CHANGELOG.md | 4 ++++ .../trainer/connectors/logger_connector/result.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b51f0c2e67002..1053894d4dd9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -176,6 +176,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) +- Fixed lost reference to `_Metadata` object in `ResultMetricCollection` ([#8932](https://github.com/PyTorchLightning/pytorch-lightning/pull/8932)) + + + ## [1.4.0] - 2021-07-27 ### Added diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 3db7c25f5e560..77079e6397f6f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -591,7 +591,6 @@ def extract_batch_size(self, batch: Any) -> None: def to(self, *args, **kwargs) -> "ResultCollection": """Move all data to the given device.""" - # the meta reference is lost there for ``ResultMetricCollection``. self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) if self.minimize is not None: