diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0ec87f4448d93..41fbbb591c15a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
 
 
+- Fixed lost reference to `_Metadata` object in `ResultMetricCollection` ([#8932](https://github.com/PyTorchLightning/pytorch-lightning/pull/8932))
+
+
 - Fixed bug where data-loading functions were not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
 
 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 2b2e4613f2298..77079e6397f6f 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -289,9 +289,12 @@ class ResultMetricCollection(dict):
     with the same metadata.
     """
 
-    def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
+    def __init__(self, *args) -> None:
         super().__init__(*args)
-        self.meta = metadata
+
+    @property
+    def meta(self) -> _Metadata:
+        return list(self.values())[0].meta
 
     def __getstate__(self, drop_value: bool = False) -> dict:
         def getstate(item: ResultMetric) -> dict:
@@ -313,9 +316,6 @@ def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
         items = setstate(state["items"])
         self.update(items)
 
-        any_result_metric = next(iter(items.values()))
-        self.meta = any_result_metric.meta
-
     @classmethod
     def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetricCollection":
         rmc = cls()
@@ -480,7 +480,7 @@ def fn(v: _METRIC) -> ResultMetric:
 
         value = apply_to_collection(value, (torch.Tensor, Metric), fn)
         if isinstance(value, dict):
-            value = ResultMetricCollection(value, metadata=meta)
+            value = ResultMetricCollection(value)
         self[key] = value
 
     def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
@@ -591,7 +591,6 @@ def extract_batch_size(self, batch: Any) -> None:
 
     def to(self, *args, **kwargs) -> "ResultCollection":
         """Move all data to the given device."""
-
         self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))
 
         if self.minimize is not None:
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index b46641849c6b6..8579bc044734a 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -590,3 +590,29 @@ def get_metrics_at_idx(idx):
         "test_loss",
     }
     assert set(results[0]) == {"test_loss"}
+
+
+def test_logging_dict_on_validation_step(tmpdir):
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            loss = super().validation_step(batch, batch_idx)
+            loss = loss["x"]
+            metrics = {
+                "loss": loss,
+                "loss_1": loss,
+            }
+            self.log("val_metrics", metrics)
+
+        validation_epoch_end = None
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        progress_bar_refresh_rate=1,
+    )
+
+    trainer.fit(model)
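
In effect, the patch replaces a `metadata` attribute cached at construction time with a `meta` property that reads the metadata off a contained `ResultMetric`, so the collection can never hold a stale or detached reference after it is rebuilt (e.g. via `__setstate__`). A minimal, self-contained sketch of that idea, where `Meta` and `FakeMetric` are hypothetical stand-ins for `_Metadata` and `ResultMetric`, not the Lightning source:

```python
from dataclasses import dataclass


@dataclass
class Meta:
    """Hypothetical stand-in for Lightning's ``_Metadata``."""

    name: str


class FakeMetric:
    """Hypothetical stand-in for ``ResultMetric``: each item carries its meta."""

    def __init__(self, meta: Meta) -> None:
        self.meta = meta


class Collection(dict):
    """Dict subclass mirroring the fixed ``ResultMetricCollection``."""

    @property
    def meta(self) -> Meta:
        # Derive the metadata from a contained item on every access,
        # instead of caching a reference at construction time.
        return list(self.values())[0].meta


items = {"loss": FakeMetric(Meta("val_metrics.loss"))}
coll = Collection(items)
# The collection and its items always agree on the metadata object,
# even after the collection's contents are replaced or restored:
assert coll.meta is items["loss"].meta
```

Because `meta` is now derived on access, `__setstate__` no longer needs to re-assign it, which is why both the `metadata=` constructor argument and the two re-assignment lines in `__setstate__` are removed above.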