Skip to content

Commit

Permalink
Merge cc926b8 into 068edfd
Browse files Browse the repository at this point in the history
  • Loading branch information
vthemelis committed Aug 26, 2023
2 parents 068edfd + cc926b8 commit 259a984
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ repos:
hooks:
- id: ruff
args: [--fix, --show-fixes]
exclude: "^tests/dummymodule.py"

- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
exclude: "^tests/mypy/negative.py"
exclude: "^tests/dummymodule.py|^tests/mypy/negative.py"

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
Expand Down
36 changes: 22 additions & 14 deletions src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
fix_missing_locations,
iter_fields,
keyword,
stmt,
walk,
)
from collections import defaultdict
Expand Down Expand Up @@ -153,19 +154,9 @@ def __post_init__(self) -> None:

# Figure out where to insert instrumentation code
if self.node:
for index, child in enumerate(self.node.body):
if isinstance(child, ImportFrom) and child.module == "__future__":
# (module only) __future__ imports must come first
continue
elif (
isinstance(child, Expr)
and isinstance(child.value, Constant)
and isinstance(child.value.value, str)
):
continue # docstring

self.code_inject_index = index
break
self.code_inject_index = TransformMemo._get_code_inject_index(
self.node.body
)

def get_unused_name(self, name: str) -> str:
memo: TransformMemo | None = self
Expand Down Expand Up @@ -232,7 +223,8 @@ def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None:
alias(orig_name, new_name.id if orig_name != new_name.id else None)
for orig_name, new_name in sorted(names.items())
]
node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0))
inject_index = TransformMemo._get_code_inject_index(node.body)
node.body.insert(inject_index, ImportFrom(modulename, aliases, 0))

def name_matches(self, expression: expr | Expr | None, *names: str) -> bool:
if expression is None:
Expand Down Expand Up @@ -280,6 +272,22 @@ def get_config_keywords(self) -> list[keyword]:
overrides.update(self.configuration_overrides)
return [keyword(key, value) for key, value in overrides.items()]

@staticmethod
def _get_code_inject_index(stmts: list[stmt]) -> int:
for index, child in enumerate(stmts):
if isinstance(child, ImportFrom) and child.module == "__future__":
# (module only) __future__ imports must come first
continue
elif (
isinstance(child, Expr)
and isinstance(child.value, Constant)
and isinstance(child.value.value, str)
):
continue # docstring

return index
return len(stmts)


class NameCollector(NodeVisitor):
def __init__(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/dummymodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Module docstring."""
from __future__ import annotations

import sys
from contextlib import contextmanager
from typing import (
Expand Down
29 changes: 29 additions & 0 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,35 @@ def foo() -> int:
)


def test_respect_future_import() -> None:
# Regression test for #385
node = parse(
dedent(
"""
from __future__ import annotations
def foo() -> int:
return 1
"""
)
)
TypeguardTransformer().visit(node)
assert (
unparse(node)
== dedent(
"""
from __future__ import annotations
from typeguard import TypeCheckMemo
from typeguard._functions import check_return_type
def foo() -> int:
memo = TypeCheckMemo(globals(), locals())
return check_return_type('foo', 1, int, memo)
"""
).strip()
)


def test_dont_leave_empty_ast_container_nodes() -> None:
# Regression test for #352
node = parse(
Expand Down

0 comments on commit 259a984

Please sign in to comment.