Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast to >=float32 tensor when passing scalar to self.log #19046

Merged
merged 7 commits into from
Nov 24, 2023
17 changes: 16 additions & 1 deletion src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def log(
" but it should not contain information about `dataloader_idx`"
)

value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)
MF-FOOM marked this conversation as resolved.
Show resolved Hide resolved
value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor_high_precision, name)

if trainer._logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
Expand Down Expand Up @@ -621,6 +621,21 @@ def __check_not_nested(value: dict, name: str) -> None:
def __check_allowed(v: Any, name: str, value: Any) -> None:
    """Reject a logged value whose type is not supported by `self.log`.

    Always raises: this helper is only invoked for disallowed value types.
    """
    message = f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged"
    raise ValueError(message)

def _get_default_high_precision_dtype() -> torch.dtype:
    """The default dtype for new tensors, but no lower than float32."""
    # Respect the user-configured default only when it is already float32/float64;
    # anything lower-precision (e.g. float16/bfloat16) is upgraded to float32.
    current = torch.get_default_dtype()
    if current in (torch.float32, torch.float64):
        return current
    return torch.float32

def __to_tensor_high_precision(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
    """Convert a logged value to a scalar tensor, upgrading plain numbers to at least float32.

    Tensors keep their original dtype; only plain Python numbers are materialized
    with a high-precision dtype on the module's device. Raises ``ValueError`` when
    the resulting tensor holds more than one element.
    """
    if isinstance(value, Tensor):
        # Detach from the autograd graph without touching dtype or device.
        value = value.clone().detach()
    else:
        value = torch.tensor(value, device=self.device, dtype=_get_default_high_precision_dtype())
    if torch.numel(value) != 1:
        raise ValueError(
            f"`self.log({name}, {value})` was called, but the tensor must have a single element."
            f" You can try doing `self.log({name}, {value}.mean())`"
        )
    return value.squeeze()

def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
if not torch.numel(value) == 1:
Expand Down