
Cast to >=float32 tensor when passing scalar to self.log #19046

Merged
merged 7 commits into from Nov 24, 2023

Conversation

@MF-FOOM (Contributor) commented Nov 21, 2023

What does this PR do?

Ensures that scalar values passed to self.log are cast to tensors of at least float32 precision.

Fixes #18984
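
For context, here is a minimal sketch of the casting rule this PR guarantees. It is not the actual patch; the helper name `_to_log_tensor` is hypothetical, illustrating that scalars passed to `self.log` end up as tensors of at least float32 precision even when the default dtype is lower.

import torch

def _to_log_tensor(value):
    # Hypothetical helper sketching the PR's behavior: wrap scalars in a
    # tensor, then promote sub-float32 floating dtypes to float32.
    if not isinstance(value, torch.Tensor):
        value = torch.tensor(value)  # inherits the current default dtype
    if value.dtype in (torch.float16, torch.bfloat16):
        value = value.to(torch.float32)  # keep enough precision for metric aggregation
    return value

# e.g. under torch.set_default_dtype(torch.float16),
# _to_log_tensor(1.23).dtype == torch.float32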

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, follow this checklist:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--19046.org.readthedocs.build/en/19046/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Nov 21, 2023
@awaelchli awaelchli added the community This PR is from the community label Nov 22, 2023
@awaelchli awaelchli added this to the 2.1.x milestone Nov 22, 2023
@awaelchli awaelchli added bug Something isn't working logging Related to the `LoggerConnector` and `log()` labels Nov 22, 2023
codecov bot commented Nov 22, 2023

Codecov Report

Merging #19046 (3ee07db) into master (1c86011) will decrease coverage by 27%.
Report is 55 commits behind head on master.
The diff coverage is 100%.

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #19046     +/-   ##
=========================================
- Coverage      76%      49%    -27%     
=========================================
  Files         447      435     -12     
  Lines       36375    36419     +44     
=========================================
- Hits        27520    17816   -9704     
- Misses       8855    18603   +9748     

@carmocca (Member) left a comment

I tried writing a test for this but it's complicated since the metric will be moved to float32 after, so we would need to dynamically patch the call to result.log. It's not worth it.

# Imports added for completeness (not in the original comment); the
# `_DtypeContextManager` import path is assumed and may vary by version.
import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.fabric.plugins.precision.utils import _DtypeContextManager

def test_log_tensor_wrap_dtype():
    class MyModel(LightningModule):
        def test_step(self, *_):
            # log a plain Python float while the default dtype is float16
            self.log("foo", 1.23, batch_size=1)

    model = MyModel()
    trainer = Trainer(logger=False, enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False)
    # using the context manager class for cleanup, but users would use `set_default_dtype` directly
    with _DtypeContextManager(torch.float16):
        trainer.test(model, [0])
    assert trainer.callback_metrics["foo"].dtype == torch.float32

This attempt passes anyway, even without your PR checked out.

LGTM

Review thread on src/lightning/pytorch/core/module.py (outdated, resolved)
@mergify mergify bot added the ready PRs ready to be merged label Nov 23, 2023
@awaelchli awaelchli changed the title fix: cast to >=float32 tensor when passing scalar to self.log Cast to >=float32 tensor when passing scalar to self.log Nov 24, 2023
@awaelchli awaelchli merged commit 1fcb4ae into Lightning-AI:master Nov 24, 2023
93 checks passed
Borda pushed a commit that referenced this pull request Dec 19, 2023
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
(cherry picked from commit 1fcb4ae)
lantiga pushed a commit that referenced this pull request Dec 20, 2023
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
(cherry picked from commit 1fcb4ae)
Labels
  • bug: Something isn't working
  • community: This PR is from the community
  • logging: Related to the `LoggerConnector` and `log()`
  • pl: Generic label for PyTorch Lightning package
  • ready: PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

cast to float32 or float64 tensor when passing scalar to self.log (#18984)
3 participants