From 75af55ce9c3aadab7168f6d94e0cb44e24ab512d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 2 Sep 2022 03:45:38 +0530 Subject: [PATCH 1/4] Squeeze tensor while logging --- src/pytorch_lightning/core/module.py | 1 + .../trainer/logging_/test_train_loop_logging.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index a479beadc7931..47d75e2c637b5 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -424,6 +424,7 @@ def log( ) value = apply_to_collection(value, numbers.Number, self.__to_tensor) + value = apply_to_collection(value, torch.Tensor, torch.squeeze) apply_to_collection(value, torch.Tensor, self.__check_numel_1, name) if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name): diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index 85ed3d8e3471d..cd7f83ddc7bfe 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -29,6 +29,7 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.runif import RunIf @@ -836,3 +837,13 @@ def on_train_start(self): assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)] assert trainer.max_epochs > 1 + + +def test_unsqueezed_tensor_logging(): + model = BoringModel() + trainer = Trainer() + trainer.state.stage = RunningStage.TRAINING + model._current_fx_name = "training_step" + model.trainer = trainer + model.log("foo", torch.Tensor([1.2])) + assert trainer.callback_metrics["foo"].ndim == 0 From 8fd086e6ee1704fa83274feecc74396a7161d4e5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 2 Sep 2022 03:47:36 +0530 Subject: [PATCH 2/4] chlog --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 9d4323548cb7e..f52c320d66179 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372)) +- Squeezed tensor values while logging ([#14489](https://github.com/Lightning-AI/lightning/pull/14489)) + + ## [1.7.4] - 2022-08-31 From f3a3518c2c40b3b5af5806d93e05644292e12d30 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 2 Sep 2022 16:15:20 +0530 Subject: [PATCH 3/4] Update src/pytorch_lightning/CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index f52c320d66179..8e5f930e838e7 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -136,7 +136,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372)) -- Squeezed tensor values while logging ([#14489](https://github.com/Lightning-AI/lightning/pull/14489)) +- Squeezed tensor values when logging with `LightningModule.log` ([#14489](https://github.com/Lightning-AI/lightning/pull/14489)) From abc917650337d6dbae07b5545095e480a04ac6e5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 5 Sep 2022 16:01:04 +0530 Subject: [PATCH 4/4] update --- src/pytorch_lightning/core/module.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 47d75e2c637b5..a8fea8c210959 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -423,9 +423,7 @@ def log( " but it should not contain information about `dataloader_idx`" ) - value = apply_to_collection(value, numbers.Number, self.__to_tensor) - value = apply_to_collection(value, torch.Tensor, torch.squeeze) - apply_to_collection(value, torch.Tensor, self.__check_numel_1, name) + value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name) if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name): # if we started a new epoch (running its first batch) the hook name has changed @@ -557,16 +555,15 @@ def __check_not_nested(value: dict, name: str) -> None: def __check_allowed(v: Any, name: str, value: Any) -> None: raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged") - def __to_tensor(self, value: numbers.Number) -> Tensor: - return torch.tensor(value, device=self.device) - - @staticmethod - def __check_numel_1(value: Tensor, name: str) -> None: + def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor: + value = torch.tensor(value, device=self.device) if not torch.numel(value) == 1: raise ValueError( f"`self.log({name}, {value})` was called, but the tensor must have a single element." f" You can try doing `self.log({name}, {value}.mean())`" ) + value = value.squeeze() + return value def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None: """Override this method to change the default behaviour of ``log_grad_norm``.