Skip to content

Commit

Permalink
Fixed placement of module level typeguard imports
Browse files Browse the repository at this point in the history
Fixes #385.
  • Loading branch information
agronholm committed Aug 26, 2023
1 parent 068edfd commit c11057a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v
- Fixed ``@typechecked`` optimization causing compilation of instrumented code to fail
when any block was left empty by the AST transformer (eg `if` or `try` / `except` blocks)
(`#352 <https://github.com/agronholm/typeguard/issues/352>`_)
- Fixed placement of injected typeguard imports with respect to ``__future__`` imports and module
docstrings (`#385 <https://github.com/agronholm/typeguard/issues/385>`_)

**4.1.2** (2023-08-18)

Expand Down
3 changes: 2 additions & 1 deletion src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,9 @@ def visit_Name(self, node: Name) -> Name:
return node

def visit_Module(self, node: Module) -> Module:
self._module_memo = self._memo = TransformMemo(node, None, ())
self.generic_visit(node)
self._memo.insert_imports(node)
self._module_memo.insert_imports(node)

fix_missing_locations(node)
return node
Expand Down
31 changes: 31 additions & 0 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,3 +1624,34 @@ def foo() -> int:
'''
).strip()
)


def test_respect_future_import() -> None:
# Regression test for #385
node = parse(
dedent(
'''
"""module docstring"""
from __future__ import annotations
def foo() -> int:
return 1
'''
)
)
TypeguardTransformer().visit(node)
assert (
unparse(node)
== dedent(
'''
"""module docstring"""
from __future__ import annotations
from typeguard import TypeCheckMemo
from typeguard._functions import check_return_type
def foo() -> int:
memo = TypeCheckMemo(globals(), locals())
return check_return_type('foo', 1, int, memo)
'''
).strip()
)

0 comments on commit c11057a

Please sign in to comment.