Fix support for dataclasses with ClassVar/InitVar in apply_to_collection #9702

Merged
15 commits merged on Nov 10, 2021
Changes from 2 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -368,6 +368,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9208](https://github.com/PyTorchLightning/pytorch-lightning/issues/9208))

- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8685](https://github.com/PyTorchLightning/pytorch-lightning/pull/8685))

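For readers of the changelog entry above, a minimal standalone sketch of the underlying problem may help. The `Example` dataclass below is purely illustrative and not part of this PR: `__dataclass_fields__` also lists `ClassVar`/`InitVar` pseudo-fields, so rebuilding an instance by passing all of them to the constructor fails, whereas `dataclasses.fields()` yields only the regular fields that the fixed code iterates.

```python
import dataclasses
from dataclasses import InitVar, dataclass
from typing import Any, ClassVar, Optional


@dataclass
class Example:  # illustrative only, not part of the PR
    class_var: ClassVar[int] = 0  # class-level attribute, not an __init__ parameter
    value: Any = None  # regular field
    init_only: InitVar[Optional[Any]] = None  # consumed by __post_init__, never stored

    def __post_init__(self, init_only: Optional[Any]):
        if init_only is not None:
            self.value = init_only


# __dataclass_fields__ also contains the ClassVar/InitVar pseudo-fields ...
print(list(Example.__dataclass_fields__))  # ['class_var', 'value', 'init_only']
# ... while dataclasses.fields() yields only the regular fields.
print([f.name for f in dataclasses.fields(Example)])  # ['value']

# Passing a class variable to the constructor, as happened before this fix,
# raises a TypeError:
try:
    Example(class_var=0, value=1)
except TypeError as err:
    print(err)  # e.g. "__init__() got an unexpected keyword argument 'class_var'"
```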
11 changes: 6 additions & 5 deletions pytorch_lightning/utilities/apply_func.py
@@ -117,10 +117,10 @@ def apply_to_collection(
return elem_type(*out) if is_namedtuple else elem_type(out)

if _is_dataclass_instance(data):
-        out_dict = {}
-        for field in data.__dataclass_fields__:
+        result = copy(data)
+        for field in dataclasses.fields(data):
            v = apply_to_collection(
-                getattr(data, field),
+                getattr(data, field.name),
dtype,
function,
*args,
@@ -129,8 +129,9 @@
**kwargs,
)
if include_none or v is not None:
-                out_dict[field] = v
-        return elem_type(**out_dict)
+                setattr(result, field.name, v)
+            # else retain old field value
+        return result

# data is neither of dtype, nor a collection
return data
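A hedged usage sketch of the patched behaviour (the `Batch` dataclass is made up for illustration and is not part of this PR): because the new code works on a shallow copy and writes transformed values back with `setattr`, the dataclass constructor and `__post_init__` are never re-invoked, so class variables and init-only variables are left untouched.

```python
from dataclasses import InitVar, dataclass, field
from typing import ClassVar, Optional

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


@dataclass
class Batch:  # illustrative only, not part of the PR
    scale: ClassVar[int] = 1  # class variable: excluded from dataclasses.fields()
    features: torch.Tensor = field(default_factory=lambda: torch.tensor([1.0, 2.0]))
    override: InitVar[Optional[torch.Tensor]] = None  # init-only: consumed by __post_init__

    def __post_init__(self, override: Optional[torch.Tensor]):
        if override is not None:
            self.features = override


batch = Batch()
doubled = apply_to_collection(batch, torch.Tensor, lambda t: t * 2)

print(doubled.features)  # tensor([2., 4.]) - only the regular field is transformed
print(Batch.scale)       # 1 - the class variable is unchanged
print(batch.features)    # tensor([1., 2.]) - the original instance is not mutated
```

Previously, the rebuild via `elem_type(**out_dict)` handed these pseudo-fields to the constructor and failed, which is the behaviour reported in #9208.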
121 changes: 93 additions & 28 deletions tests/utilities/test_apply_func.py
@@ -14,7 +14,8 @@
import dataclasses
import numbers
from collections import namedtuple, OrderedDict
-from typing import List
+from dataclasses import InitVar
+from typing import Any, ClassVar, List, Optional

import numpy as np
import pytest
@@ -37,6 +38,36 @@ class ModelExample:
feature: Feature
label: torch.Tensor

@dataclasses.dataclass
class WithClassVar:
class_var: ClassVar[int] = 0
dummy: Any

@dataclasses.dataclass
class WithInitVar:
dummy: Any
override: InitVar[Optional[Any]] = None

def __post_init__(self, override: Optional[Any]):
if override is not None:
self.dummy = override

@dataclasses.dataclass
class WithClassAndInitVar:
class_var: ClassVar[torch.Tensor] = torch.tensor(0)
dummy: Any
override: InitVar[Optional[Any]] = torch.tensor(1)

def __post_init__(self, override: Optional[Any]):
if override is not None:
self.dummy = override

model_example = ModelExample(
example_ids=["i-1", "i-2", "i-3"],
feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
label=torch.tensor([7.0, 8.0, 9.0]),
)

to_reduce = {
"a": torch.tensor([1.0]), # Tensor
"b": [torch.tensor([2.0])], # list
@@ -46,13 +77,18 @@ class ModelExample:
        "f": "this_is_a_dummy_str",  # string
        "g": 12.0,  # number
        "h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),  # dataclass
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
-            label=torch.tensor([7.0, 8.0, 9.0]),
-        ),  # nested dataclass
+        "i": model_example,  # nested dataclass
+        "j": WithClassVar(torch.arange(3)),  # dataclass with class variable
+        "k": WithInitVar("this_gets_overridden", torch.tensor([2.0])),  # dataclass with init-only variable
+        "l": WithClassAndInitVar(model_example, None),  # nested dataclass with class and init-only variables
}

+    model_example_result = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
+        label=torch.tensor([14.0, 16.0, 18.0]),
+    )

expected_result = {
"a": torch.tensor([2.0]),
"b": [torch.tensor([4.0])],
@@ -62,16 +98,15 @@ class ModelExample:
        "f": "this_is_a_dummy_str",
        "g": 24.0,
        "h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-            label=torch.tensor([14.0, 16.0, 18.0]),
-        ),
+        "i": model_example_result,
+        "j": WithClassVar(torch.arange(0, 6, 2)),
+        "k": WithInitVar(torch.tensor([4.0])),
+        "l": WithClassAndInitVar(model_example_result, None),
}

reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)

-    assert isinstance(reduced, dict), " Type Consistency of dict not preserved"
+    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
assert all(
isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
@@ -115,24 +150,54 @@ class ModelExample:
reduced["h"].segment_ids, expected_result["h"].segment_ids
), "Reduction of a dataclass did not yield the desired result"

-    assert dataclasses.is_dataclass(reduced["i"]) and not isinstance(
-        reduced["i"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert dataclasses.is_dataclass(reduced["i"].feature) and not isinstance(
-        reduced["i"].feature, type
-    ), "Reduction of a nested dataclass should result in a nested dataclass"
-    assert (
-        reduced["i"].example_ids == expected_result["i"].example_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
+    def _assert_nested_dataclass_reduction(actual: ModelExample, expected: ModelExample):
+
+        assert dataclasses.is_dataclass(actual) and not isinstance(
+            actual, type
+        ), "Reduction of a dataclass should result in a dataclass"
+        assert dataclasses.is_dataclass(actual.feature) and not isinstance(
+            actual.feature, type
+        ), "Reduction of a nested dataclass should result in a nested dataclass"
+        assert (
+            actual.example_ids == expected.example_ids
+        ), "Reduction of a nested dataclass did not yield the desired result"
+        assert torch.allclose(
+            actual.label, expected.label
+        ), "Reduction of a nested dataclass did not yield the desired result"
+        assert torch.allclose(
+            actual.feature.input_ids, expected.feature.input_ids
+        ), "Reduction of a nested dataclass did not yield the desired result"
+        assert np.allclose(
+            actual.feature.segment_ids, expected.feature.segment_ids
+        ), "Reduction of a nested dataclass did not yield the desired result"

+    _assert_nested_dataclass_reduction(reduced["i"], expected_result["i"])
+
+    assert dataclasses.is_dataclass(reduced["j"]) and not isinstance(
+        reduced["j"], type
+    ), "Reduction of a dataclass with a class var should result in a dataclass"
+    assert WithClassVar.class_var == 0, "Reduction of a dataclass with a class var should not change the class var"
assert torch.allclose(
-        reduced["i"].label, expected_result["i"].label
-    ), "Reduction of a nested dataclass did not yield the desired result"
+        reduced["j"].dummy, expected_result["j"].dummy
+    ), "Reduction of a dataclass with a class var did not yield the desired result"
+
+    assert dataclasses.is_dataclass(reduced["k"]) and not isinstance(
+        reduced["k"], type
+    ), "Reduction of a dataclass with an init-only var should result in a dataclass"
assert torch.allclose(
-        reduced["i"].feature.input_ids, expected_result["i"].feature.input_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["i"].feature.segment_ids, expected_result["i"].feature.segment_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
+        reduced["k"].dummy, expected_result["k"].dummy
+    ), "Reduction of a dataclass with an init-only var did not yield the desired result"
+
+    assert dataclasses.is_dataclass(reduced["l"]) and not isinstance(
+        reduced["l"], type
+    ), "Reduction of a dataclass with class and init-only vars should result in a dataclass"
+    assert torch.equal(
+        WithClassAndInitVar.class_var, torch.tensor(0)
+    ), "Reduction of a dataclass with class and init-only vars should not change the class var"
+    try:
+        _assert_nested_dataclass_reduction(reduced["l"].dummy, expected_result["l"].dummy)
+    except AssertionError:
+        raise AssertionError("Reduction of a dataclass with class and init-only vars did not yield the desired result")

# mapping support
reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))