Skip to content

Commit

Permalink
Fixed the __init__ of dataclassess with multiple inheritance (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord authored and Pierre-Sassoulas committed Sep 7, 2022
1 parent d154666 commit 449a95b
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 20 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Expand Up @@ -12,6 +12,10 @@ What's New in astroid 2.12.9?
=============================
Release date: TBA

* Fixed creation of the ``__init__`` of ``dataclassess`` with multiple inheritance.

Closes PyCQA/pylint#7427

* Fixed a crash on ``namedtuples`` that use ``typename`` to specify their name.

Closes PyCQA/pylint#7429
Expand Down
62 changes: 42 additions & 20 deletions astroid/brain/brain_dataclasses.py
Expand Up @@ -177,6 +177,45 @@ def _check_generate_dataclass_init(node: nodes.ClassDef) -> bool:
)


def _find_arguments_from_base_classes(
node: nodes.ClassDef, skippable_names: set[str]
) -> tuple[str, str]:
"""Iterate through all bases and add them to the list of arguments to add to the init."""
prev_pos_only = ""
prev_kw_only = ""
for base in node.mro():
if not base.is_dataclass:
continue
try:
base_init: nodes.FunctionDef = base.locals["__init__"][0]
except KeyError:
continue

# Skip the self argument and check for duplicate arguments
arguments = base_init.args.format_args(skippable_names=skippable_names)
try:
new_prev_pos_only, new_prev_kw_only = arguments.split("*, ")
except ValueError:
new_prev_pos_only, new_prev_kw_only = arguments, ""

if new_prev_pos_only:
# The split on '*, ' can crete a pos_only string that consists only of a comma
if new_prev_pos_only == ", ":
new_prev_pos_only = ""
elif not new_prev_pos_only.endswith(", "):
new_prev_pos_only += ", "

# Dataclasses put last seen arguments at the front of the init
prev_pos_only = new_prev_pos_only + prev_pos_only
prev_kw_only = new_prev_kw_only + prev_kw_only

# Add arguments to skippable arguments
skippable_names.update(arg.name for arg in base_init.args.args)
skippable_names.update(arg.name for arg in base_init.args.kwonlyargs)

return prev_pos_only, prev_kw_only


def _generate_dataclass_init(
node: nodes.ClassDef, assigns: list[nodes.AnnAssign], kw_only_decorated: bool
) -> str:
Expand Down Expand Up @@ -228,26 +267,9 @@ def _generate_dataclass_init(
if not init_var:
assignments.append(assignment_str)

try:
base = next(next(iter(node.bases)).infer())
if not isinstance(base, nodes.ClassDef):
raise InferenceError
base_init: nodes.FunctionDef | None = base.locals["__init__"][0]
except (StopIteration, InferenceError, KeyError):
base_init = None

prev_pos_only = ""
prev_kw_only = ""
if base_init and base.is_dataclass:
# Skip the self argument and check for duplicate arguments
arguments = base_init.args.format_args(skippable_names=assign_names)[6:]
try:
prev_pos_only, prev_kw_only = arguments.split("*, ")
except ValueError:
prev_pos_only, prev_kw_only = arguments, ""

if prev_pos_only and not prev_pos_only.endswith(", "):
prev_pos_only += ", "
prev_pos_only, prev_kw_only = _find_arguments_from_base_classes(
node, set(assign_names + ["self"])
)

# Construct the new init method paramter string
params_string = "self, "
Expand Down
132 changes: 132 additions & 0 deletions tests/unittest_brain_dataclasses.py
Expand Up @@ -912,3 +912,135 @@ class GoodExampleClass(GoodExampleParentClass):
good_init: bases.UnboundMethod = next(good_node.infer())
assert bad_init.args.defaults
assert [a.name for a in good_init.args.args] == ["self", "xyz"]


def test_dataclass_with_multiple_inheritance() -> None:
"""Regression test for dataclasses with multiple inheritance.
Reported in https://github.com/PyCQA/pylint/issues/7427
"""
first, second, overwritten, overwriting, mixed = astroid.extract_node(
"""
from dataclasses import dataclass
@dataclass
class BaseParent:
_abc: int = 1
@dataclass
class AnotherParent:
ef: int = 2
@dataclass
class FirstChild(BaseParent, AnotherParent):
ghi: int = 3
@dataclass
class ConvolutedParent(AnotherParent):
'''Convoluted Parent'''
@dataclass
class SecondChild(BaseParent, ConvolutedParent):
jkl: int = 4
@dataclass
class OverwritingParent:
ef: str = "2"
@dataclass
class OverwrittenChild(OverwritingParent, AnotherParent):
'''Overwritten Child'''
@dataclass
class OverwritingChild(BaseParent, AnotherParent):
_abc: float = 1.0
ef: float = 2.0
class NotADataclassParent:
ef: int = 2
@dataclass
class ChildWithMixedParents(BaseParent, NotADataclassParent):
ghi: int = 3
FirstChild.__init__ #@
SecondChild.__init__ #@
OverwrittenChild.__init__ #@
OverwritingChild.__init__ #@
ChildWithMixedParents.__init__ #@
"""
)

first_init: bases.UnboundMethod = next(first.infer())
assert [a.name for a in first_init.args.args] == ["self", "ef", "_abc", "ghi"]
assert [a.value for a in first_init.args.defaults] == [2, 1, 3]

second_init: bases.UnboundMethod = next(second.infer())
assert [a.name for a in second_init.args.args] == ["self", "ef", "_abc", "jkl"]
assert [a.value for a in second_init.args.defaults] == [2, 1, 4]

overwritten_init: bases.UnboundMethod = next(overwritten.infer())
assert [a.name for a in overwritten_init.args.args] == ["self", "ef"]
assert [a.value for a in overwritten_init.args.defaults] == ["2"]

overwriting_init: bases.UnboundMethod = next(overwriting.infer())
assert [a.name for a in overwriting_init.args.args] == ["self", "_abc", "ef"]
assert [a.value for a in overwriting_init.args.defaults] == [1.0, 2.0]

mixed_init: bases.UnboundMethod = next(mixed.infer())
assert [a.name for a in mixed_init.args.args] == ["self", "_abc", "ghi"]
assert [a.value for a in mixed_init.args.defaults] == [1, 3]


def test_dataclass_inits_of_non_dataclasses() -> None:
"""Regression test for __init__ mangling for non dataclasses.
Regression test against changes tested in test_dataclass_with_multiple_inheritance
"""
first, second, third = astroid.extract_node(
"""
from dataclasses import dataclass
@dataclass
class DataclassParent:
_abc: int = 1
class NotADataclassParent:
ef: int = 2
class FirstChild(DataclassParent, NotADataclassParent):
ghi: int = 3
class SecondChild(DataclassParent, NotADataclassParent):
ghi: int = 3
def __init__(self, ef: int = 3):
self.ef = ef
class ThirdChild(NotADataclassParent, DataclassParent):
ghi: int = 3
def __init__(self, ef: int = 3):
self.ef = ef
FirstChild.__init__ #@
SecondChild.__init__ #@
ThirdChild.__init__ #@
"""
)

first_init: bases.UnboundMethod = next(first.infer())
assert [a.name for a in first_init.args.args] == ["self", "_abc"]
assert [a.value for a in first_init.args.defaults] == [1]

second_init: bases.UnboundMethod = next(second.infer())
assert [a.name for a in second_init.args.args] == ["self", "ef"]
assert [a.value for a in second_init.args.defaults] == [3]

third_init: bases.UnboundMethod = next(third.infer())
assert [a.name for a in third_init.args.args] == ["self", "ef"]
assert [a.value for a in third_init.args.defaults] == [3]

0 comments on commit 449a95b

Please sign in to comment.