Skip to content

Commit

Permalink
Modify infernce tip for typing.Generic and typing.Annotated with ``__…
Browse files Browse the repository at this point in the history
…class_getitem__`` (#931)

* Modify infernce tip for typing.Generic and typing.Annotated with __class_getitem__
* Fix issue with slots caching
* Clean typing.Generic from mro
  • Loading branch information
cdce8p authored Apr 10, 2021
1 parent 31a731a commit cd0f896
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 9 deletions.
6 changes: 6 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ Release Date: TBA

* Use ``inference_tip`` for ``typing.TypedDict`` brain.

* Fix mro for classes that inherit from typing.Generic

* Add inference tip for typing.Generic and typing.Annotated with ``__class_getitem__``

Closes PyCQA/pylint#2822


What's New in astroid 2.5.2?
============================
Expand Down
42 changes: 33 additions & 9 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class {0}(metaclass=Meta):
)
)

CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
return cls
"""


def looks_like_typing_typevar_or_newtype(node):
func = node.func
Expand Down Expand Up @@ -126,7 +132,9 @@ def _looks_like_typing_subscript(node):
return False


def infer_typing_attr(node, context=None):
def infer_typing_attr(
node: nodes.Subscript, ctx: context.InferenceContext = None
) -> typing.Iterator[nodes.ClassDef]:
"""Infer a typing.X[...] subscript"""
try:
value = next(node.value.infer())
Expand All @@ -142,8 +150,31 @@ def infer_typing_attr(node, context=None):
# (PY37+) handle it separately.
raise UseInferenceDefault

if (
PY37
and isinstance(value, nodes.ClassDef)
and value.qname()
in ("typing.Generic", "typing.Annotated", "typing_extensions.Annotated")
):
# With PY37+ typing.Generic and typing.Annotated (PY39) are subscriptable
# through __class_getitem__. Since astroid can't easily
# infer the native methods, replace them for an easy inference tip
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
value.locals["__class_getitem__"] = [func_to_add]
if (
isinstance(node.parent, nodes.ClassDef)
and node in node.parent.bases
and getattr(node.parent, "__cache", None)
):
# node.parent.slots is evaluated and cached before the inference tip
# is first applied. Remove the last result to allow a recalculation of slots
cache = getattr(node.parent, "__cache")
if cache.get(node.parent.slots) is not None:
del cache[node.parent.slots]
return iter([value])

node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
return node.infer(context=context)
return node.infer(context=ctx)


def _looks_like_typedDict( # pylint: disable=invalid-name
Expand All @@ -166,13 +197,6 @@ def infer_typedDict( # pylint: disable=invalid-name
return iter([class_def])


CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
return cls
"""


def _looks_like_typing_alias(node: nodes.Call) -> bool:
"""
Returns True if the node corresponds to a call to _alias function.
Expand Down
37 changes: 37 additions & 0 deletions astroid/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,42 @@ def _c3_merge(sequences, cls, context):
return None


def clean_typing_generic_mro(sequences: List[List["ClassDef"]]) -> None:
"""A class can inherit from typing.Generic directly, as base,
and as base of bases. The merged MRO must however only contain the last entry.
To prepare for _c3_merge, remove some typing.Generic entries from
sequences if multiple are present.
This method will check if Generic is in inferred_bases and also
part of bases_mro. If true, remove it from inferred_bases
as well as its entry the bases_mro.
Format sequences: [[self]] + bases_mro + [inferred_bases]
"""
bases_mro = sequences[1:-1]
inferred_bases = sequences[-1]
# Check if Generic is part of inferred_bases
for i, base in enumerate(inferred_bases):
if base.qname() == "typing.Generic":
position_in_inferred_bases = i
break
else:
return
# Check if also part of bases_mro
# Ignore entry for typing.Generic
for i, seq in enumerate(bases_mro):
if i == position_in_inferred_bases:
continue
if any(base.qname() == "typing.Generic" for base in seq):
break
else:
return
# Found multiple Generics in mro, remove entry from inferred_bases
# and the corresponding one from bases_mro
inferred_bases.pop(position_in_inferred_bases)
bases_mro.pop(position_in_inferred_bases)


def clean_duplicates_mro(sequences, cls, context):
for sequence in sequences:
names = [
Expand Down Expand Up @@ -2924,6 +2960,7 @@ def _compute_mro(self, context=None):

unmerged_mro = [[self]] + bases_mro + [inferred_bases]
unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context))
clean_typing_generic_mro(unmerged_mro)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context=None) -> List["ClassDef"]:
Expand Down
50 changes: 50 additions & 0 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,56 @@ def test_typing_types(self):
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())

@test_utils.require_version(minver="3.7")
def test_typing_generic_subscriptable(self):
"""Test typing.Generic is subscriptable with __class_getitem__ (added in PY37)"""
node = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
Generic[T]
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.9")
def test_typing_annotated_subscriptable(self):
"""Test typing.Annotated is subscriptable with __class_getitem__"""
node = builder.extract_node(
"""
import typing
typing.Annotated[str, "data"]
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.7")
def test_typing_generic_slots(self):
"""Test cache reset for slots if Generic subscript is inferred."""
node = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]):
__slots__ = ['value']
def __init__(self, value):
self.value = value
"""
)
inferred = next(node.infer())
assert len(inferred.slots()) == 0
# Only after the subscript base is inferred and the inference tip applied,
# will slots contain the correct value
next(node.bases[0].infer())
slots = inferred.slots()
assert len(slots) == 1
assert isinstance(slots[0], nodes.Const)
assert slots[0].value == "value"

def test_has_dunder_args(self):
ast_node = builder.extract_node(
"""
Expand Down
139 changes: 139 additions & 0 deletions tests/unittest_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,9 @@ class NodeBase(object):
def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)

def assertEqualMroQName(self, klass, expected_mro):
self.assertEqual([member.qname() for member in klass.mro()], expected_mro)

@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
astroid = builder.parse(
Expand Down Expand Up @@ -1438,6 +1441,142 @@ class C(scope.A, scope.B):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])

@test_utils.require_version(minver="3.7")
def test_mro_generic_1(self):
cls = builder.extract_node(
"""
import typing
T = typing.TypeVar('T')
class A(typing.Generic[T]): ...
class B: ...
class C(A[T], B): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(Generic[T], A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_3(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(A, Generic[T]): ...
class C(Generic[T]): ...
class D(B[T], C[T], Generic[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_4(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(A, Generic[T], B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_5(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1]): ...
class B(Generic[T2]): ...
class C(A[T1], B[T2]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_6(self):
cls = builder.extract_node(
"""
from typing import Generic as TGeneric, TypeVar
T = TypeVar('T')
class Generic: ...
class A(Generic): ...
class B(TGeneric[T]): ...
class C(A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_7(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(): ...
class B(Generic[T]): ...
class C(A, B[T]): ...
class D: ...
class E(C[str], D): ...
"""
)
self.assertEqualMroQName(
cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_1(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1], Generic[T2]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]): ...
class B(A[T], A[T]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""
Expand Down

0 comments on commit cd0f896

Please sign in to comment.