diff --git a/CHANGELOG.md b/CHANGELOG.md index c9991b028e181..c2bf09068ff61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) +- Show a better error message when frozen dataclass is used as a batch ([#10927](https://github.com/PyTorchLightning/pytorch-lightning/issues/10927)) + + - Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784)) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index e4756f2632970..ae5896704e424 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -23,6 +23,7 @@ import numpy as np import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY from pytorch_lightning.utilities.warnings import rank_zero_deprecation @@ -147,7 +148,13 @@ def apply_to_collection( ) if not field_init or (not include_none and v is None): # retain old value v = getattr(data, field_name) - setattr(result, field_name, v) + try: + setattr(result, field_name, v) + except dataclasses.FrozenInstanceError as e: + raise MisconfigurationException( + "A frozen dataclass was passed to `apply_to_collection` but this is not allowed." + " HINT: is your batch a frozen dataclass?" + ) from e return result # data is neither of dtype, nor a collection diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 9b0fcbd643744..869a3bb619ad5 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -22,6 +22,7 @@ import torch from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device +from pytorch_lightning.utilities.exceptions import MisconfigurationException def test_recursive_application_to_collection(): @@ -302,6 +303,17 @@ def fn(a, b): assert reduced is None +def test_apply_to_collection_frozen_dataclass(): + @dataclasses.dataclass(frozen=True) + class Foo: + input: torch.Tensor + + foo = Foo(torch.tensor(0)) + + with pytest.raises(MisconfigurationException, match="frozen dataclass was passed"): + apply_to_collection(foo, torch.Tensor, lambda t: t.to(torch.int)) + + @pytest.mark.parametrize("should_return", [False, True]) def test_wrongly_implemented_transferable_data_type(should_return): class TensorObject: