diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 3838a6258b052..1f09770a3d533 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))
 
+- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the current device ([#19813](https://github.com/Lightning-AI/pytorch-lightning/issues/19813))
+
 -
 
 ### Changed
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 7e0ef433031bd..c887ef1befc71 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -403,26 +403,19 @@ def log(
 
         # register logged value if it doesn't exist
         if key not in self:
-            self.register_key(key, meta, value)
+            metric = _ResultMetric(meta, isinstance(value, Tensor))
+            self[key] = metric
 
         # check the stored metadata and the current one match
         elif meta != self[key].meta:
             raise MisconfigurationException(
                 f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
             )
 
+        self[key].to(value.device)
         batch_size = self._extract_batch_size(self[key], batch_size, meta)
         self.update_metrics(key, value, batch_size)
 
-    def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
-        """Create one _ResultMetric object per value.
-
-        Value can be provided as a nested collection
-
-        """
-        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
-        self[key] = metric
-
     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
         result_metric = self[key]
         # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index 32923d444d4bb..53b2e7dff2b4c 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -639,3 +639,23 @@ def test_result_collection_no_batch_size_extraction():
     assert results["training_step.epoch_log_val"].value == log_val * batch_size
     assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
     assert results["training_step.epoch_sum_log_val"].value == log_val
+
+
+def test_result_collection_changes_device():
+    results = _ResultCollection(training=True)
+    fx_name = "training_step"
+    log_val = torch.tensor(7.0)
+
+    # same device as the original tensor
+    results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    assert results["training_step.step_log_val"].cumulated_batch_size.device == log_val.device
+
+    # moved to cpu
+    cumulated_batch_size = results["training_step.step_log_val"].cumulated_batch_size = Mock(spec=torch.Tensor)
+    cumulated_batch_size.to.return_value = Mock(spec=torch.Tensor)
+    results.cpu()
+    cumulated_batch_size.to.assert_called_once_with(log_val.device)
+
+    # same device as the new tensor
+    results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    cumulated_batch_size.to.return_value.to.assert_called_once_with(log_val.device)
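
Why this matters, in miniature: before this patch, a `_ResultMetric` was moved to the logged value's device only once, at registration, so a collection moved off the GPU (e.g. by `results.cpu()` between epochs) would hit a device mismatch on the next `self.log(...)` call; the patch instead re-syncs the stored metric to `value.device` on every `log`. The sketch below illustrates the failure mode and the fix with simplified stand-ins; `ToyMetric` and `ToyCollection` are hypothetical names, not Lightning's internals.

```python
import torch


class ToyMetric:
    """Hypothetical stand-in for `_ResultMetric`: holds state tensors on a device."""

    def __init__(self, device: torch.device) -> None:
        self.value = torch.tensor(0.0, device=device)
        self.cumulated_batch_size = torch.tensor(0.0, device=device)

    def to(self, device: torch.device) -> "ToyMetric":
        self.value = self.value.to(device)
        self.cumulated_batch_size = self.cumulated_batch_size.to(device)
        return self

    def update(self, value: torch.Tensor, batch_size: int) -> None:
        # raises a device-mismatch RuntimeError if `self.value` was left on another device
        self.value += value.mean()
        self.cumulated_batch_size += batch_size


class ToyCollection(dict):
    """Hypothetical stand-in for `_ResultCollection`."""

    def log(self, key: str, value: torch.Tensor) -> None:
        if key not in self:
            self[key] = ToyMetric(value.device)
        # the patched behavior: always follow the device of the incoming value
        self[key].to(value.device)
        self[key].update(value, batch_size=1)

    def cpu(self) -> None:
        for metric in self.values():
            metric.to(torch.device("cpu"))


results = ToyCollection()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results.log("train/loss", torch.tensor(7.0, device=device))
results.cpu()  # e.g. results are moved off the accelerator between epochs
results.log("train/loss", torch.tensor(7.0, device=device))  # works with the re-sync
assert results["train/loss"].value.device.type == device.type
```

Calling `.to(value.device)` unconditionally on every `log` call is cheap in the common case, since `Tensor.to` returns the same tensor when it is already on the target device.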