Skip to content

Commit

Permalink
Fixed type checking of generators annotated as Iterable/AsyncIterable
Browse files Browse the repository at this point in the history
Closes #76.
  • Loading branch information
agronholm committed Aug 28, 2019
1 parent 40d78cb commit 4761514
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 33 deletions.
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ Version history

This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-versioning-200>`_.

**UNRELEASED**

- Fixed incompatibility between annotated ``Iterable`` or ``AsyncIterable`` return types and
generator/async generator functions

**2.5.0** (2019-08-26)

- Added yield type checking via ``TypeChecker`` for regular generators
Expand Down
18 changes: 14 additions & 4 deletions tests/test_typeguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,9 +886,15 @@ def method(self) -> int:
pytest.raises(TypeError, Foo.classmethod).match(pattern)
pytest.raises(TypeError, Foo().method).match(pattern)

def test_generator(self):
@pytest.mark.parametrize('annotation', [
Generator[int, str, List[str]],
Generator,
Iterable[int],
Iterable
], ids=['generator', 'bare_generator', 'iterable', 'bare_iterable'])
def test_generator(self, annotation):
@typechecked
def genfunc() -> Generator[int, str, List[str]]:
def genfunc() -> annotation:
val1 = yield 2
val2 = yield 3
val3 = yield 4
Expand All @@ -903,9 +909,13 @@ def genfunc() -> Generator[int, str, List[str]]:

assert exc.value.value == ['2', '3', '4']

def test_generator_bad_yield(self):
@pytest.mark.parametrize('annotation', [
Generator[int, str, None],
Iterable[int],
], ids=['generator', 'iterable'])
def test_generator_bad_yield(self, annotation):
@typechecked
def genfunc() -> Generator[int, str, None]:
def genfunc() -> annotation:
yield 'foo'

gen = genfunc()
Expand Down
36 changes: 25 additions & 11 deletions tests/test_typeguard_py36.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from typing import AsyncGenerator
from typing import AsyncGenerator, AsyncIterable

import pytest

from typeguard import TypeChecker, typechecked


class TestTypeChecked:
def test_async_generator(self):
@pytest.mark.parametrize('annotation', [
AsyncGenerator[int, str],
AsyncIterable[int]
], ids=['generator', 'iterable'])
def test_async_generator(self, annotation):
async def run_generator():
@typechecked
async def genfunc() -> AsyncGenerator[int, str]:
async def genfunc() -> annotation:
values.append((yield 2))
values.append((yield 3))
values.append((yield 4))
Expand All @@ -32,9 +36,13 @@ async def genfunc() -> AsyncGenerator[int, str]:

assert values == ['2', '3', '4']

def test_async_generator_bad_yield(self):
@pytest.mark.parametrize('annotation', [
AsyncGenerator[int, str],
AsyncIterable[int]
], ids=['generator', 'iterable'])
def test_async_generator_bad_yield(self, annotation):
@typechecked
async def genfunc() -> AsyncGenerator[int, str]:
async def genfunc() -> annotation:
yield 'foo'

gen = genfunc()
Expand All @@ -57,18 +65,24 @@ async def genfunc() -> AsyncGenerator[int, str]:
exc.match('type of value sent to generator must be str; got int instead')


class TestTypeChecker:
@staticmethod
async def asyncgenfunc() -> AsyncGenerator[int, None]:
yield 1
async def asyncgenfunc() -> AsyncGenerator[int, None]:
yield 1


async def asyncgeniterablefunc() -> AsyncIterable[int]:
yield 1


class TestTypeChecker:
@pytest.fixture
def checker(self):
return TypeChecker(__name__)

def test_async_generator(self, checker):
@pytest.mark.parametrize('func', [asyncgenfunc, asyncgeniterablefunc],
ids=['generator', 'iterable'])
def test_async_generator(self, checker, func):
"""Make sure that the type checker does not complain about the None return value."""
with checker, pytest.warns(None) as record:
self.asyncgenfunc()
func()

assert len(record) == 0
43 changes: 25 additions & 18 deletions typeguard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from types import CodeType, FunctionType
from typing import (
Callable, Any, Union, Dict, List, TypeVar, Tuple, Set, Sequence, get_type_hints, TextIO,
Optional, IO, BinaryIO, Type, Generator, overload)
Optional, IO, BinaryIO, Type, Generator, overload, Iterable, AsyncIterable)
from warnings import warn
from weakref import WeakKeyDictionary, WeakValueDictionary

Expand All @@ -30,8 +30,11 @@
AsyncGenerator = None

try:
from inspect import isasyncgenfunction
from inspect import isasyncgenfunction, isasyncgen
except ImportError:
def isasyncgen(obj):
return False

def isasyncgenfunction(func):
return False

Expand Down Expand Up @@ -593,10 +596,12 @@ def check_argument_types(memo: Optional[_CallMemo] = None) -> bool:

class TypeCheckedGenerator:
def __init__(self, wrapped: Generator, memo: _CallMemo):
rtype_args = memo.type_hints['return'].__args__
self.__wrapped = wrapped
self.__memo = memo
self.__yield_type, self.__send_type, self.__return_type = \
memo.type_hints['return'].__args__
self.__yield_type = rtype_args[0]
self.__send_type = rtype_args[1] if len(rtype_args) > 1 else Any
self.__return_type = rtype_args[2] if len(rtype_args) > 2 else Any
self.__initialized = False

def __iter__(self):
Expand Down Expand Up @@ -626,9 +631,11 @@ def send(self, obj):

class TypeCheckedAsyncGenerator:
def __init__(self, wrapped: AsyncGenerator, memo: _CallMemo):
rtype_args = memo.type_hints['return'].__args__
self.__wrapped = wrapped
self.__memo = memo
self.__yield_type, self.__send_type = memo.type_hints['return'].__args__
self.__yield_type = rtype_args[0]
self.__send_type = rtype_args[1] if len(rtype_args) > 1 else Any
self.__initialized = False

async def __aiter__(self):
Expand Down Expand Up @@ -702,10 +709,19 @@ def wrapper(*args, **kwargs):
check_argument_types(memo)
retval = func(*args, **kwargs)
check_return_type(retval, memo)
if inspect.isgenerator(retval):
return TypeCheckedGenerator(retval, memo)
else:
return retval

# If a generator is returned, wrap it if its yield/send/return types can be checked
if inspect.isgenerator(retval) or isasyncgen(retval):
return_type = memo.type_hints.get('return')
origin = getattr(return_type, '__origin__')
if origin in (Generator, collections.abc.Generator,
Iterable, collections.abc.Iterable):
return TypeCheckedGenerator(retval, memo)
elif origin is not None and origin in (AsyncGenerator, collections.abc.AsyncGenerator,
AsyncIterable, collections.abc.AsyncIterable):
return TypeCheckedAsyncGenerator(retval, memo)

return retval

async def async_wrapper(*args, **kwargs):
memo = _CallMemo(func, args=args, kwargs=kwargs)
Expand All @@ -714,18 +730,9 @@ async def async_wrapper(*args, **kwargs):
check_return_type(retval, memo)
return retval

def asyncgen_wrapper(*args, **kwargs):
memo = _CallMemo(func, args=args, kwargs=kwargs)
check_argument_types(memo)
retval = func(*args, **kwargs)
return TypeCheckedAsyncGenerator(retval, memo)

if inspect.iscoroutinefunction(func):
if func.__code__ is not async_wrapper.__code__:
return wraps(func)(async_wrapper)
elif isasyncgenfunction(func):
if func.__code__ is not asyncgen_wrapper.__code__:
return wraps(func)(asyncgen_wrapper)
else:
if func.__code__ is not wrapper.__code__:
return wraps(func)(wrapper)
Expand Down

0 comments on commit 4761514

Please sign in to comment.