Skip to content

Commit

Permalink
Fixed compilation error when a type annotation contains a type guarde…
Browse files Browse the repository at this point in the history
…d by `if TYPE_CHECKING:`

Fixes #331.
  • Loading branch information
agronholm committed Apr 8, 2023
1 parent f5cf181 commit 903b3f5
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v
generated by older versions
- Fixed typed variable positional and keyword arguments causing compilation errors on
Python 3.7 and 3.8
- Fixed compilation error when a type annotation contains a type guarded by
``if TYPE_CHECKING:``

**4.0.0rc1** (2023-04-02)

Expand Down
18 changes: 16 additions & 2 deletions src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
import builtins
import sys
from _ast import AST, Expression
from ast import (
Add,
AnnAssign,
Expand Down Expand Up @@ -330,6 +331,13 @@ def __init__(self, transformer: TypeguardTransformer):
self._memo = transformer._memo
self.level = 0

def visit(self, node: AST) -> Any:
new_node = super().visit(node)
if isinstance(new_node, Expression) and not hasattr(new_node, "body"):
return None

return new_node

def visit_BinOp(self, node: BinOp) -> Any:
self.level += 1
self.generic_visit(node)
Expand Down Expand Up @@ -385,15 +393,21 @@ def visit_Constant(self, node: Constant) -> Any:
if isinstance(node.value, str):
expression = ast.parse(node.value, mode="eval")
new_node = self.visit(expression)
return copy_location(new_node.body, node)
if new_node:
return copy_location(new_node.body, node)
else:
return None

return node

def visit_Str(self, node: Str) -> Any:
# Only used on Python 3.7
expression = ast.parse(node.s, mode="eval")
new_node = self.visit(expression)
return copy_location(new_node.body, node)
if new_node:
return copy_location(new_node.body, node)
else:
return None


class TypeguardTransformer(NodeTransformer):
Expand Down
10 changes: 10 additions & 0 deletions tests/dummymodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Expand Down Expand Up @@ -33,6 +34,9 @@
typeguard_ignore,
)

if TYPE_CHECKING:
from nonexistent import Imaginary

P = ParamSpec("P")


Expand Down Expand Up @@ -306,3 +310,9 @@ def typed_variable_args(
*args: str, **kwargs: int
) -> Tuple[Tuple[str, ...], Dict[str, int]]:
return args, kwargs


@typechecked
def guarded_type_hint(x: "Imaginary") -> "Imaginary":
y: Imaginary = x
return y
4 changes: 4 additions & 0 deletions tests/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,7 @@ def test_kwargs_fail(self, dummymodule):
r"instance of int",
):
dummymodule.typed_variable_args("foo", "bar", a="baz")


def test_guarded_type(dummymodule):
assert dummymodule.guarded_type_hint("foo") == "foo"

0 comments on commit 903b3f5

Please sign in to comment.