3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue to keep downscaling the batch size in case there hasn't been even a single successful optimal batch size with `mode="power"` ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))
 
 
+- Squeezed tensor values when logging with `LightningModule.log` ([#14489](https://github.com/Lightning-AI/lightning/pull/14489))
+
+
 - Fixed `WandbLogger` `save_dir` is not set after creation ([#14326](https://github.com/Lightning-AI/lightning/pull/14326))
 
 
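The new changelog entry means `self.log` now accepts a single-element tensor of any shape and squeezes it to a 0-dim scalar before storing it. A minimal sketch of the intended semantics in plain `torch` (illustrative only, not code from this PR):

```python
import torch

# Any tensor with exactly one element is accepted and squeezed to a 0-dim scalar.
value = torch.tensor([[0.25]])    # shape (1, 1), numel == 1
assert torch.numel(value) == 1
assert value.squeeze().ndim == 0  # the shape self.log now stores

# Multi-element tensors are still rejected; reduce them first, e.g. with .mean().
batch_losses = torch.tensor([0.1, 0.2, 0.3])
assert torch.numel(batch_losses) != 1  # self.log would raise ValueError here
assert batch_losses.mean().ndim == 0   # safe to log
```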
12 changes: 5 additions & 7 deletions src/pytorch_lightning/core/module.py
@@ -423,8 +423,7 @@ def log(
" but it should not contain information about `dataloader_idx`"
)

value = apply_to_collection(value, numbers.Number, self.__to_tensor)
apply_to_collection(value, torch.Tensor, self.__check_numel_1, name)
value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)

if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
@@ -556,16 +555,15 @@ def __check_not_nested(value: dict, name: str) -> None:
     def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
 
-    def __to_tensor(self, value: numbers.Number) -> Tensor:
-        return torch.tensor(value, device=self.device)
-
-    @staticmethod
-    def __check_numel_1(value: Tensor, name: str) -> None:
+    def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
+        value = torch.tensor(value, device=self.device)
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
                 f" You can try doing `self.log({name}, {value}.mean())`"
             )
+        value = value.squeeze()
+        return value
 
     def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
         """Override this method to change the default behaviour of ``log_grad_norm``.
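This refactor folds the old two-pass flow (convert numbers to tensors, then validate tensors) into a single `apply_to_collection` pass whose callback converts, validates, and squeezes. A standalone sketch of the pattern, assuming the `apply_to_collection` helper from `pytorch_lightning.utilities.apply_func` (its extra positional arguments, here `name`, are forwarded to the callback for every matching element):

```python
import numbers

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def to_tensor(value, name):
    # Mirrors the merged __to_tensor: convert, enforce a single element, squeeze.
    # (torch.tensor warns when re-wrapping an existing tensor, but copies it.)
    value = torch.tensor(value)
    if not torch.numel(value) == 1:
        raise ValueError(f"`self.log({name}, {value})` needs a single-element value")
    return value.squeeze()


# One pass now handles plain numbers and tensors anywhere in the collection.
metrics = {"loss": 0.5, "acc": torch.tensor([0.9])}
converted = apply_to_collection(metrics, (torch.Tensor, numbers.Number), to_tensor, "metrics")
assert all(v.ndim == 0 for v in converted.values())
```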
11 changes: 11 additions & 0 deletions tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -29,6 +29,7 @@
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.runif import RunIf

@@ -836,3 +837,13 @@ def on_train_start(self):

     assert mock_log_metrics.mock_calls == [call(metrics={"foo": 123.0, "epoch": 0}, step=0)]
     assert trainer.max_epochs > 1
+
+
+def test_unsqueezed_tensor_logging():
+    model = BoringModel()
+    trainer = Trainer()
+    trainer.state.stage = RunningStage.TRAINING
+    model._current_fx_name = "training_step"
+    model.trainer = trainer
+    model.log("foo", torch.Tensor([1.2]))
+    assert trainer.callback_metrics["foo"].ndim == 0
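The new test can be run on its own with `pytest tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py -k test_unsqueezed_tensor_logging`. For completeness, the single-element guard kept by the refactor can be exercised with the same harness; this companion check is a sketch, not part of the PR:

```python
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.states import RunningStage


def test_multielement_tensor_still_rejected():
    # Same setup as test_unsqueezed_tensor_logging, but logging two elements
    # should still raise the single-element ValueError from __to_tensor.
    model = BoringModel()
    trainer = Trainer()
    trainer.state.stage = RunningStage.TRAINING
    model._current_fx_name = "training_step"
    model.trainer = trainer
    with pytest.raises(ValueError, match="must have a single element"):
        model.log("bar", torch.tensor([1.0, 2.0]))
```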