Skip to content

Commit

Permalink
Allowed __new__() to type check against its own class
Browse files Browse the repository at this point in the history
This is needed for enums where `__new__()` is overriden since the class doesn't exist yet in the module namespace so checkers have to check against `cls` instead of using the actual class name.
  • Loading branch information
agronholm committed Sep 10, 2023
1 parent 4d1768c commit 888a8c5
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ This library adheres to
argument, variable assignment or an ``if TYPE_CHECKING`` import)
(`#394 <https://github.com/agronholm/typeguard/issues/394>`_,
`#395 <https://github.com/agronholm/typeguard/issues/395>`_)
- Fixed type checking of class instances created in ``__new__()`` in cases such as enums
where this method is already invoked before the class has finished initializing
(`#398 <https://github.com/agronholm/typeguard/issues/398>`_)

**4.1.3** (2023-08-27)

Expand Down
22 changes: 20 additions & 2 deletions src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ def _use_memo(
self, node: ClassDef | FunctionDef | AsyncFunctionDef
) -> Generator[None, Any, None]:
new_memo = TransformMemo(node, self._memo, self._memo.path + (node.name,))
old_memo = self._memo
self._memo = new_memo

if isinstance(node, (FunctionDef, AsyncFunctionDef)):
new_memo.should_instrument = (
self._target_path is None or new_memo.path == self._target_path
Expand Down Expand Up @@ -580,8 +583,6 @@ def _use_memo(
if isinstance(node, AsyncFunctionDef):
new_memo.is_async = True

old_memo = self._memo
self._memo = new_memo
yield
self._memo = old_memo

Expand Down Expand Up @@ -921,6 +922,23 @@ def visit_FunctionDef(

self._memo.insert_imports(node)

# Special case the __new__() method to create a local alias from the
# class name to the first argument (usually "cls")
if (
isinstance(node, FunctionDef)
and node.args
and self._memo.parent is not None
and isinstance(self._memo.parent.node, ClassDef)
and node.name == "__new__"
and self._memo.parent.node.name in self._memo.local_names
):
first_args_expr = Name(node.args.args[0].arg, ctx=Load())
cls_name = Name(self._memo.parent.node.name, ctx=Store())
node.body.insert(
self._memo.code_inject_index,
Assign([cls_name], first_args_expr),
)

# Rmove any placeholder "pass" at the end
if isinstance(node.body[-1], Pass):
del node.body[-1]
Expand Down
37 changes: 36 additions & 1 deletion tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def foo(x: int) -> int:
)


def test_new() -> None:
def test_new_with_self() -> None:
node = parse(
dedent(
"""
Expand Down Expand Up @@ -653,6 +653,41 @@ def __new__(cls) -> Self:
)


def test_new_with_explicit_class_name() -> None:
# Regression test for #398
node = parse(
dedent(
"""
class A:
def __new__(cls) -> 'A':
obj: A = object.__new__(cls)
return obj
"""
)
)
TypeguardTransformer().visit(node)
assert (
unparse(node)
== dedent(
"""
from typeguard import TypeCheckMemo
from typeguard._functions import check_return_type, \
check_variable_assignment
class A:
def __new__(cls) -> 'A':
A = cls
memo = TypeCheckMemo(globals(), locals(), self_type=cls)
obj: A = check_variable_assignment(object.__new__(cls), 'obj', A, \
memo)
return check_return_type('A.__new__', obj, A, memo)
"""
).strip()
)


def test_local_function() -> None:
node = parse(
dedent(
Expand Down

0 comments on commit 888a8c5

Please sign in to comment.