cast to float32 or float64 tensor when passing scalar to self.log #18984

Closed
MF-FOOM opened this issue Nov 10, 2023 · 2 comments · Fixed by #19046
Labels
bug (Something isn't working), logging (Related to the `LoggerConnector` and `log()`), ver: 2.1.x

@MF-FOOM
Contributor

MF-FOOM commented Nov 10, 2023

Bug description

#18686 made it such that all ResultMetric values are stored at float32 or higher precision.

However, values passed as Python floats to self.log are still automatically cast to a low-precision dtype (if that is the default dtype) via __to_tensor. This means that even though the internal ResultMetric representation is precise, values can lose precision before they ever reach ResultMetric.
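To give a rough sense of how much a scalar can lose in a 16-bit float, here is a minimal sketch using only the standard library's struct module. It round-trips a value through IEEE 754 half and single precision. This is an analogy rather than the Lightning code path: the `round_trip` helper and the sample value are hypothetical, and bfloat16 (a common low-precision default) has a different bit layout from IEEE half, with even fewer significand bits.

```python
import struct


def round_trip(value: float, fmt: str) -> float:
    """Pack a Python float into the given struct format and unpack it again."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


# Hypothetical scalar, standing in for a loss value passed to self.log.
loss = 0.1234567

as_half = round_trip(loss, "e")    # IEEE 754 binary16 (half precision)
as_single = round_trip(loss, "f")  # IEEE 754 binary32 (single precision)

half_error = abs(as_half - loss)
single_error = abs(as_single - loss)

print(f"half:   {as_half!r} (abs error {half_error:.2e})")
print(f"single: {as_single!r} (abs error {single_error:.2e})")
```

The half-precision error is several orders of magnitude larger than the single-precision one. Storing the value at float32 afterwards cannot recover those digits once the first cast has happened.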

I'd suggest we modify __to_tensor to use a function like this (as introduced in #18686) to determine and set the dtype:

import torch


def _get_default_dtype() -> torch.dtype:
    """The default dtype for new tensors, but no lower than float32."""
    dtype = torch.get_default_dtype()
    return dtype if dtype in (torch.float32, torch.float64) else torch.float32
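As a rough illustration of the suggestion, the sketch below applies the helper when a scalar is converted to a tensor. `to_tensor_sketch` is a hypothetical stand-in for __to_tensor, not Lightning's actual implementation, and the sample value is made up. It assumes `torch.set_default_dtype` accepts `torch.bfloat16`, which may depend on the installed PyTorch version.

```python
import torch


def _get_default_dtype() -> torch.dtype:
    """The default dtype for new tensors, but no lower than float32."""
    dtype = torch.get_default_dtype()
    return dtype if dtype in (torch.float32, torch.float64) else torch.float32


def to_tensor_sketch(value, device=None):
    """Hypothetical stand-in for __to_tensor: convert scalars at float32 or higher."""
    if isinstance(value, torch.Tensor):
        return value
    return torch.tensor(value, dtype=_get_default_dtype(), device=device)


sample = 0.1234567  # hypothetical logged scalar

previous = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)  # simulate a low-precision default
try:
    naive = torch.tensor(sample)        # inherits the bfloat16 default
    floored = to_tensor_sketch(sample)  # floored at float32
finally:
    torch.set_default_dtype(previous)   # restore the global default

print(naive.dtype, abs(naive.item() - sample))
print(floored.dtype, abs(floored.item() - sample))
```

Restoring the previous default in a `finally` block keeps the global setting from leaking into other code.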

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

N/A

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @carmocca

MF-FOOM added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Nov 10, 2023
@MF-FOOM
Contributor Author

MF-FOOM commented Nov 17, 2023

Bump on this? I'm happy to implement it, I just want approval to go ahead with a PR.

@carmocca
Contributor

Please go ahead @MF-FOOM! Sorry for the wait

carmocca added the logging (Related to the `LoggerConnector` and `log()`) label and removed the needs triage (Waiting to be triaged by maintainers) label on Nov 21, 2023
carmocca added this to the 2.1.x milestone on Nov 21, 2023

2 participants