Skip to content

Commit

Permalink
Eliminated the need to pass the current function reference to check_a…
Browse files Browse the repository at this point in the history
…rgument_types()
  • Loading branch information
agronholm committed Jan 2, 2016
1 parent 9394c1e commit 5bfe1f6
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 46 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
@@ -1,6 +1,13 @@
Version history
===============

1.1.0
-----

- Eliminated the need to pass a reference to the currently executing function to
``check_argument_types()``


1.0.2
-----

Expand Down
6 changes: 3 additions & 3 deletions README.rst
Expand Up @@ -15,7 +15,7 @@ Python 3) is supported. See below for details.

There are two principal ways to use type checking, each with its pros and cons:

#. calling ``check_type_arguments()`` from within the function body:
#. calling ``check_argument_types()`` from within the function body:
debugger friendly but cannot check the type of the return value
#. decorating the function with ``@typechecked``:
can check the type of the return value but adds an extra frame to the call stack for every call
Expand All @@ -30,14 +30,14 @@ Type checks can be fairly expensive so it is recommended to run Python in "optim
type checks in production. The optimized mode will disable the type checks, by virtue of removing
all ``assert`` statements and setting the ``__debug__`` constant to ``False``.

Using ``check_type_arguments()``:
Using ``check_argument_types()``:

.. code-block:: python
from typeguard import check_argument_types
def some_function(a: int, b: float, c: str, *args: str):
assert check_argument_types(some_function)
assert check_argument_types()
...
Using ``@typechecked``:
Expand Down
70 changes: 35 additions & 35 deletions tests/test_typeguard.py
Expand Up @@ -38,13 +38,13 @@ def foo(a: int):

def test_any_type(self):
def foo(a: Any):
assert check_argument_types(foo)
assert check_argument_types()

foo('aa')

def test_callable(self):
def foo(a: Callable[..., int]):
assert check_argument_types(foo)
assert check_argument_types()

def some_callable() -> int:
pass
Expand All @@ -53,7 +53,7 @@ def some_callable() -> int:

def test_callable_exact_arg_count(self):
def foo(a: Callable[[int, str], int]):
assert check_argument_types(foo)
assert check_argument_types()

def some_callable(x: int, y: str) -> int:
pass
Expand All @@ -62,14 +62,14 @@ def some_callable(x: int, y: str) -> int:

def test_callable_bad_type(self):
def foo(a: Callable[..., int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 5)
assert str(exc.value) == 'argument a must be a callable'

def test_callable_too_few_arguments(self):
def foo(a: Callable[[int, str], int]):
assert check_argument_types(foo)
assert check_argument_types()

def some_callable(x: int) -> int:
pass
Expand All @@ -81,7 +81,7 @@ def some_callable(x: int) -> int:

def test_callable_too_many_arguments(self):
def foo(a: Callable[[int, str], int]):
assert check_argument_types(foo)
assert check_argument_types()

def some_callable(x: int, y: str, z: float) -> int:
pass
Expand All @@ -93,49 +93,49 @@ def some_callable(x: int, y: str, z: float) -> int:

def test_dict(self):
def foo(a: Dict[str, int]):
assert check_argument_types(foo)
assert check_argument_types()

foo({'x': 2})

def test_dict_bad_type(self):
def foo(a: Dict[str, int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 5)
assert str(exc.value) == (
'type of argument a must be a dict; got int instead')

def test_dict_bad_key_type(self):
def foo(a: Dict[str, int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, {1: 2})
assert str(exc.value) == 'type of keys of argument a must be str; got int instead'

def test_dict_bad_value_type(self):
def foo(a: Dict[str, int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, {'x': 'a'})
assert str(exc.value) == "type of argument a['x'] must be int; got str instead"

def test_list(self):
def foo(a: List[int]):
assert check_argument_types(foo)
assert check_argument_types()

foo([1, 2])

def test_list_bad_type(self):
def foo(a: List[int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 5)
assert str(exc.value) == (
'type of argument a must be a list; got int instead')

def test_list_bad_element(self):
def foo(a: List[int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, [1, 2, 'bb'])
assert str(exc.value) == (
Expand All @@ -145,21 +145,21 @@ def foo(a: List[int]):
ids=['tuple', 'list', 'str'])
def test_sequence(self, value):
def foo(a: Sequence[str]):
assert check_argument_types(foo)
assert check_argument_types()

foo(value)

def test_sequence_bad_type(self):
def foo(a: Sequence[int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 5)
assert str(exc.value) == (
'type of argument a must be a sequence; got int instead')

def test_sequence_bad_element(self):
def foo(a: Sequence[int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, [1, 2, 'bb'])
assert str(exc.value) == (
Expand All @@ -168,50 +168,50 @@ def foo(a: Sequence[int]):
@pytest.mark.parametrize('value', [set(), {6}])
def test_set(self, value):
def foo(a: Set[int]):
assert check_argument_types(foo)
assert check_argument_types()

foo(value)

def test_set_bad_type(self):
def foo(a: Set[int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 5)
assert str(exc.value) == 'type of argument a must be a set; got int instead'

def test_set_bad_element(self):
def foo(a: Set[int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, {1, 2, 'bb'})
assert str(exc.value) == (
'type of elements of argument a must be int; got str instead')

def test_tuple(self):
def foo(a: Tuple[int, int]):
assert check_argument_types(foo)
assert check_argument_types()

foo((1, 2))

def test_tuple_bad_type(self):
def foo(a: Tuple[int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 5)
assert str(exc.value) == (
'type of argument a must be a tuple; got int instead')

def test_tuple_wrong_number_of_elements(self):
def foo(a: Tuple[int, str]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, (1, 'aa', 2))
assert str(exc.value) == ('argument a has wrong number of elements (expected 2, got 3 '
'instead)')

def test_tuple_bad_element(self):
def foo(a: Tuple[int, str]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, (1, 2))
assert str(exc.value) == (
Expand All @@ -220,14 +220,14 @@ def foo(a: Tuple[int, str]):
@pytest.mark.parametrize('value', [6, 'aa'])
def test_union(self, value):
def foo(a: Union[str, int]):
assert check_argument_types(foo)
assert check_argument_types()

foo(value)

@pytest.mark.parametrize('value', [6.5, b'aa'])
def test_union_fail(self, value):
def foo(a: Union[str, int]):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, value)
assert str(exc.value) == (
Expand All @@ -242,15 +242,15 @@ def test_typevar_constraints(self, values):
T = TypeVar('T', int, str)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

foo(*values)

def test_typevar_constraints_fail(self):
T = TypeVar('T', int, str)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 2.5, 'aa')
assert str(exc.value) == 'type of argument a must be one of (int, str); got float instead'
Expand All @@ -259,15 +259,15 @@ def test_typevar_bound(self):
T = TypeVar('T', bound=Parent)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

foo(Child(), Child())

def test_typevar_bound_fail(self):
T = TypeVar('T', bound=Child)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, Parent(), Parent())
assert str(exc.value) == ('argument a must be an instance of test_typeguard.Child; got '
Expand All @@ -277,7 +277,7 @@ def test_typevar_invariant_fail(self):
T = TypeVar('T', int, str)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, 2, 3.6)
assert str(exc.value) == 'type of argument b must be exactly int; got float instead'
Expand All @@ -286,15 +286,15 @@ def test_typevar_covariant(self):
T = TypeVar('T', covariant=True)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

foo(Parent(), Child())

def test_typevar_covariant_fail(self):
T = TypeVar('T', covariant=True)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, Child(), Parent())
assert str(exc.value) == ('argument b must be an instance of test_typeguard.Child; got '
Expand All @@ -304,15 +304,15 @@ def test_typevar_contravariant(self):
T = TypeVar('T', contravariant=True)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

foo(Child(), Parent())

def test_typevar_contravariant_fail(self):
T = TypeVar('T', contravariant=True)

def foo(a: T, b: T):
assert check_argument_types(foo)
assert check_argument_types()

exc = pytest.raises(TypeError, foo, Parent(), Child())
assert str(exc.value) == ('type of argument b must be test_typeguard.Parent or one of its '
Expand Down Expand Up @@ -344,7 +344,7 @@ def test_default_argument_type(self, values):
"""
def foo(a: str=1, b: float='x', c: str=None):
assert check_argument_types(foo)
assert check_argument_types()

foo(*values)

Expand Down
23 changes: 15 additions & 8 deletions typeguard.py
@@ -1,10 +1,11 @@
from typing import Callable, Any, Union, Dict, List, TypeVar, Tuple, Set, Sequence, get_type_hints
from collections import OrderedDict
from warnings import warn
from functools import partial, wraps
from weakref import WeakKeyDictionary
from typing import Callable, Any, Union, Dict, List, TypeVar, Tuple, Set, Sequence, get_type_hints
import collections
import inspect
import gc

__all__ = ('typechecked', 'check_argument_types')

Expand Down Expand Up @@ -206,7 +207,7 @@ def check_type(argname: str, value, expected_type, typevars_memo: Dict[TypeVar,
format(argname, qualified_name(expected_type), qualified_name(type(value))))


def check_argument_types(func: Callable, args: tuple=None, kwargs: Dict[str, Any]=None,
def check_argument_types(func: Callable=None, args: tuple=None, kwargs: Dict[str, Any]=None,
typevars_memo: Dict[TypeVar, type]=None) -> bool:
"""
Check that the argument values match the annotated types.
Expand All @@ -221,13 +222,21 @@ def check_argument_types(func: Callable, args: tuple=None, kwargs: Dict[str, Any
:raises TypeError: if there is an argument type mismatch
"""
# Unwrap the function
while hasattr(func, '__wrapped__'):
func = func.__wrapped__
frame = inspect.currentframe().f_back
if func:
# Unwrap the function
while hasattr(func, '__wrapped__'):
func = func.__wrapped__
else:
# No callable provided, so fish it out of the garbage collector
for obj in gc.get_referrers(frame.f_code):
if inspect.isfunction(obj):
func = obj
break

spec = inspect.getfullargspec(func)
type_hints = _type_hints_map.get(func)
if type_hints is None:
spec = inspect.getfullargspec(func)
hints = get_type_hints(func)
type_hints = _type_hints_map[func] = OrderedDict(
(arg, hints[arg]) for arg in spec.args + ['return'] if arg in hints)
Expand All @@ -239,11 +248,9 @@ def check_argument_types(func: Callable, args: tuple=None, kwargs: Dict[str, Any
type_hints[argname] = Union[hints[argname], type(default_value)]

if args is None or kwargs is None:
frame = inspect.currentframe().f_back
argvalues = frame.f_locals
elif isinstance(args, tuple) and isinstance(kwargs, dict):
argvalues = kwargs.copy()
spec = inspect.getfullargspec(func)
pos_values = dict(zip(spec.args, args))
argvalues.update(pos_values)
else:
Expand Down

0 comments on commit 5bfe1f6

Please sign in to comment.