Skip to content

Commit

Permalink
Rewrite Generator[..., None, None] to Iterator[...] (#110) (Fixes #4)
Browse files Browse the repository at this point in the history
  • Loading branch information
iyanuashiri authored and carljm committed Sep 19, 2018
1 parent a660a41 commit 6436198
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
11 changes: 11 additions & 0 deletions monkeytype/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,19 @@ def rewrite(self, typ):
return typ


class RewriteGenerator(TypeRewriter):
"""Returns an Iterator, if the send_type and return_type of a Generator is None"""

def rewrite_Generator(self, typ):
args = typ.__args__
if args[1] is NoneType and args[2] is NoneType:
return Iterator[args[0]]
return typ


DEFAULT_REWRITER = ChainedRewriter((
RemoveEmptyContainers(),
RewriteConfigDict(),
RewriteLargeUnion(),
RewriteGenerator(),
))
4 changes: 4 additions & 0 deletions monkeytype/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,8 @@ class NoOpRewriter(TypeRewriter):
...


class RewriteGenerator(TypeRewriter):
...


DEFAULT_REWRITER: TypeRewriter = ...
22 changes: 22 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Tuple,
Type,
Union,
Generator,
)

import pytest
Expand All @@ -26,7 +27,9 @@
get_type,
get_type_str,
shrink_types,
RewriteGenerator,
)

from .util import Dummy


Expand Down Expand Up @@ -209,3 +212,22 @@ class G(C):
def test_rewrite(self, typ, expected):
rewritten = RewriteLargeUnion(2).rewrite(typ)
assert rewritten == expected


class TestRewriteGenerator:
@pytest.mark.parametrize(
'typ, expected',
[
# Should not rewrite
(Generator[int, None, int], Generator[int, None, int]),
# Should not rewrite
(Generator[int, int, None], Generator[int, int, None]),
# Should rewrite to Iterator[int]
(Generator[int, NoneType, NoneType], Iterator[int])
],
)
def test_rewrite(self, typ, expected):
rewritten = RewriteGenerator().rewrite(typ)
assert rewritten == expected

0 comments on commit 6436198

Please sign in to comment.