Skip to content

Commit

Permalink
Fixed forward references not resolving correctly in certain cases
Browse files Browse the repository at this point in the history
When decorating a function directly, local namespace in which the function was declared is now used to resolve forward references.
When decorating classes, the class dictionary is now used in the same manner.
This fix was inspired by Sjuul Janssen's PR (#39).

Fixes #74. Closes #39.
  • Loading branch information
agronholm committed Nov 10, 2019
1 parent e08e2a7 commit 51c6ab4
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v
- Fixed ``AttributeError`` when using ``@typechecked`` on a metaclass
- Fixed ``@typechecked`` compatibility with built-in function wrappers
- Fixed type checking generator wrappers not being recognized as generators
- Fixed resolution of forward references in certain cases (inner classes, function-local classes)

**2.6.0** (2019-11-06)

Expand Down
26 changes: 26 additions & 0 deletions tests/dummymodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,29 @@ def type_checked_classmethod(cls, x: int, y: int) -> int:
@staticmethod
def type_checked_staticmethod(x: int, y: int) -> int:
return x * y


def outer():
class Inner:
pass

def create_inner() -> 'Inner':
return Inner()

return create_inner


class Outer:
class Inner:
pass

def create_inner(self) -> 'Inner':
return Outer.Inner()

@classmethod
def create_inner_classmethod(cls) -> 'Inner':
return Outer.Inner()

@staticmethod
def create_inner_staticmethod() -> 'Inner':
return Outer.Inner()
21 changes: 21 additions & 0 deletions tests/test_importhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,24 @@ def test_dynamic_type_checking_func(dummymodule, argtype, returntype, error):
exc.match(error)
else:
assert dummymodule.dynamic_type_checking_func(4, argtype, returntype) == '4'


def test_class_in_function(dummymodule):
create_inner = dummymodule.outer()
retval = create_inner()
assert retval.__class__.__qualname__ == 'outer.<locals>.Inner'


def test_inner_class_method(dummymodule):
retval = dummymodule.Outer().create_inner()
assert retval.__class__.__qualname__ == 'Outer.Inner'


def test_inner_class_classmethod(dummymodule):
retval = dummymodule.Outer.create_inner_classmethod()
assert retval.__class__.__qualname__ == 'Outer.Inner'


def test_inner_class_staticmethod(dummymodule):
retval = dummymodule.Outer.create_inner_staticmethod()
assert retval.__class__.__qualname__ == 'Outer.Inner'
12 changes: 12 additions & 0 deletions tests/test_typeguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,18 @@ def func(x: int) -> None:
cache_info = func.__wrapped__.cache_info()
assert cache_info.hits == 1

def test_local_class(self):
@typechecked
class LocalClass:
class Inner:
pass

def create_inner(self) -> 'Inner':
return self.Inner()

retval = LocalClass().create_inner()
assert isinstance(retval, LocalClass.Inner)


class TestTypeChecker:
@pytest.fixture
Expand Down
29 changes: 17 additions & 12 deletions typeguard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ class _CallMemo:
__slots__ = ('func', 'func_name', 'signature', 'typevars', 'arguments', 'type_hints',
'is_generator')

def __init__(self, func: Callable, frame=None, args: tuple = None,
kwargs: Dict[str, Any] = None, forward_refs_policy=ForwardRefPolicy.ERROR):
def __init__(self, func: Callable, frame_locals: Optional[Dict[str, Any]] = None,
args: tuple = None, kwargs: Dict[str, Any] = None,
forward_refs_policy=ForwardRefPolicy.ERROR):
self.func = func
self.func_name = function_name(func)
self.signature = inspect.signature(func)
Expand All @@ -79,14 +80,14 @@ def __init__(self, func: Callable, frame=None, args: tuple = None,
if args is not None and kwargs is not None:
self.arguments = self.signature.bind(*args, **kwargs).arguments
else:
assert frame, 'frame must be specified if args or kwargs is None'
self.arguments = frame.f_locals.copy()
assert frame_locals is not None, 'frame must be specified if args or kwargs is None'
self.arguments = frame_locals

self.type_hints = _type_hints_map.get(func)
if self.type_hints is None:
while True:
try:
hints = get_type_hints(func)
hints = get_type_hints(func, localns=frame_locals)
except NameError as exc:
if forward_refs_policy is ForwardRefPolicy.ERROR:
raise
Expand Down Expand Up @@ -572,7 +573,7 @@ def check_return_type(retval, memo: Optional[_CallMemo] = None) -> bool:
except LookupError:
return True # This can happen with the Pydev/PyCharm debugger extension installed

memo = _CallMemo(func, frame)
memo = _CallMemo(func, frame.f_locals)

if 'return' in memo.type_hints:
try:
Expand Down Expand Up @@ -604,7 +605,7 @@ def check_argument_types(memo: Optional[_CallMemo] = None) -> bool:
except LookupError:
return True # This can happen with the Pydev/PyCharm debugger extension installed

memo = _CallMemo(func, frame)
memo = _CallMemo(func, frame.f_locals)

for argname, expected_type in memo.type_hints.items():
if argname != 'return' and argname in memo.arguments:
Expand Down Expand Up @@ -704,7 +705,7 @@ def typechecked(func: T_CallableOrType, *, always: bool = False) -> T_CallableOr
...


def typechecked(func=None, *, always=False):
def typechecked(func=None, *, always=False, _localns: Optional[Dict[str, Any]] = None):
"""
Perform runtime type checking on the arguments that are passed to the wrapped function.
Expand All @@ -724,18 +725,22 @@ class with this decorator.
return func

if func is None:
return partial(typechecked, always=always)
return partial(typechecked, always=always, _localns=_localns)

if isclass(func):
prefix = func.__qualname__ + '.'
for key in dir(func):
attr = getattr(func, key, None)
if callable(attr) and attr.__qualname__.startswith(prefix):
if getattr(attr, '__annotations__', None):
setattr(func, key, typechecked(attr, always=always))
setattr(func, key, typechecked(attr, always=always, _localns=func.__dict__))

return func

# Find the frame in which the function was declared, for resolving forward references later
if _localns is None:
_localns = sys._getframe(1).f_locals

# Find either the first Python wrapper or the actual function
python_func = inspect.unwrap(func, stop=lambda f: hasattr(f, '__code__'))

Expand All @@ -744,7 +749,7 @@ class with this decorator.
return func

def wrapper(*args, **kwargs):
memo = _CallMemo(python_func, args=args, kwargs=kwargs)
memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
check_argument_types(memo)
retval = func(*args, **kwargs)
check_return_type(retval, memo)
Expand Down Expand Up @@ -929,7 +934,7 @@ def __call__(self, frame, event: str, arg) -> None: # pragma: no cover

if func is not None and self.should_check_type(func):
memo = self._call_memos[frame] = _CallMemo(
func, frame, forward_refs_policy=self.annotation_policy)
func, frame.f_locals, forward_refs_policy=self.annotation_policy)
if memo.is_generator:
return_type_hint = memo.type_hints['return']
if return_type_hint is not None:
Expand Down

0 comments on commit 51c6ab4

Please sign in to comment.