Skip to content

Commit

Permalink
Allow apply to overwrite existing annotations using CLI flag. (#183)
Browse files Browse the repository at this point in the history
It should be safe to remove the `args.existing_annotation_strategy = ExistingAnnotationStrategy.REPLICATE` line because that is now the default value in `args`.

This needed the latest ApplyTypeAnnotationsVisitor from libcst. Set the minimum libcst version to 0.3.5.

Note: I ran `pipenv update -d` to update the packages. This ended up upgrading pyflakes, which now errors on `List[make_forward_ref('Foo')]` saying "undefined name 'Foo'", probably because it's a literal non-class string within a container. So, I've temporarily silenced these errors in .flake8.

Fixed f-strings that didn't have expressions inside.

Co-authored-by: Pradeep Kumar Srinivasan <pradeepkumars@fb.com>
  • Loading branch information
pradeep90 and pradeep90 committed May 16, 2020
1 parent bfd7278 commit 47f4237
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 170 deletions.
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
[flake8]
max-line-length = 120
per-file-ignores =
tests/test_stubs.py:F821
tests/test_typing.py:F821
291 changes: 132 additions & 159 deletions Pipfile.lock

Large diffs are not rendered by default.

25 changes: 21 additions & 4 deletions monkeytype/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,18 @@ class HandlerError(Exception):
pass


def apply_stub_using_libcst(stub: str, source: str) -> str:
def apply_stub_using_libcst(
stub: str, source: str, overwrite_existing_annotations: bool
) -> str:
try:
stub_module = parse_module(stub)
source_module = parse_module(source)
context = CodemodContext()
ApplyTypeAnnotationsVisitor.add_stub_to_context(context, stub_module)
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context,
stub_module,
overwrite_existing_annotations,
)
transformer = ApplyTypeAnnotationsVisitor(context)
transformed_source_module = transformer.transform_module(source_module)
except Exception as exception:
Expand All @@ -152,15 +158,18 @@ def apply_stub_using_libcst(stub: str, source: str) -> str:


def apply_stub_handler(args: argparse.Namespace, stdout: IO, stderr: IO) -> None:
args.existing_annotation_strategy = ExistingAnnotationStrategy.REPLICATE
stub = get_stub(args, stdout, stderr)
if stub is None:
complain_about_no_traces(args, stderr)
return
module = args.module_path[0]
mod = importlib.import_module(module)
source_path = Path(inspect.getfile(mod))
source_with_types = apply_stub_using_libcst(stub=stub.render(), source=source_path.read_text())
source_with_types = apply_stub_using_libcst(
stub=stub.render(),
source=source_path.read_text(),
overwrite_existing_annotations=args.existing_annotation_strategy == ExistingAnnotationStrategy.IGNORE,
)
source_path.write_text(source_with_types)
print(source_with_types, file=stdout)

Expand Down Expand Up @@ -297,6 +306,14 @@ def main(argv: List[str], stdout: IO, stderr: IO) -> int:
default=False,
help='Print to stderr the numbers of traces stubs are based on'
)
apply_parser.add_argument(
"--ignore-existing-annotations",
action="store_const",
dest="existing_annotation_strategy",
default=ExistingAnnotationStrategy.REPLICATE,
const=ExistingAnnotationStrategy.IGNORE,
help="Ignore existing annotations when applying stubs from traces.",
)
apply_parser.set_defaults(handler=apply_stub_handler)

stub_parser = subparsers.add_parser(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_version(root_path):
]
},
python_requires='>=3.6',
install_requires=['mypy_extensions', 'libcst>=0.3.4'],
install_requires=['mypy_extensions', 'libcst>=0.3.5'],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
Expand Down
34 changes: 32 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,11 @@ def no_stub(a):
def uses_union(d: Union[int, bool]) -> None:
return None
"""
assert cli.apply_stub_using_libcst(textwrap.dedent(stub), textwrap.dedent(source)) == textwrap.dedent(expected)
assert cli.apply_stub_using_libcst(
textwrap.dedent(stub),
textwrap.dedent(source),
overwrite_existing_annotations=False,
) == textwrap.dedent(expected)


def test_apply_stub_using_libcst__exception(stdout, stderr):
Expand All @@ -419,4 +423,30 @@ def my_test_function(
def my_test_function(a: int, b: str) -> bool: ...
"""
with pytest.raises(cli.HandlerError):
cli.apply_stub_using_libcst(textwrap.dedent(stub), textwrap.dedent(erroneous_source))
cli.apply_stub_using_libcst(
textwrap.dedent(stub),
textwrap.dedent(erroneous_source),
overwrite_existing_annotations=False,
)


def test_apply_stub_using_libcst__overwrite_existing_annotations():
source = """
def has_annotations(x: int) -> str:
return 1 in x
"""
stub = """
from typing import List
def has_annotations(x: List[int]) -> bool: ...
"""
expected = """
from typing import List
def has_annotations(x: List[int]) -> bool:
return 1 in x
"""
assert cli.apply_stub_using_libcst(
textwrap.dedent(stub),
textwrap.dedent(source),
overwrite_existing_annotations=True,
) == textwrap.dedent(expected)
8 changes: 4 additions & 4 deletions tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class TestReplaceTypedDictsWithStubs:
AttributeStub(name='b', typ=str),
]),
ClassStub(
name=f'FooBarTypedDict__RENAME_ME__NonTotal(FooBarTypedDict__RENAME_ME__, total=False)',
name='FooBarTypedDict__RENAME_ME__NonTotal(FooBarTypedDict__RENAME_ME__, total=False)',
function_stubs=[],
attribute_stubs=[
AttributeStub(name='c', typ=int),
Expand Down Expand Up @@ -723,7 +723,7 @@ def test_render_typed_dict_in_list(self):
'',
'',
'class Dummy:',
f' def an_instance_method(self, foo: List[\'FooTypedDict__RENAME_ME__\'], bar: int) -> int: ...'])
' def an_instance_method(self, foo: List[\'FooTypedDict__RENAME_ME__\'], bar: int) -> int: ...'])
self.maxDiff = None
assert build_module_stubs(entries)['tests.util'].render() == expected

Expand All @@ -747,12 +747,12 @@ def test_render_typed_dict_base_and_subclass(self):
' a: int',
'',
'',
f'class FooTypedDict__RENAME_ME__NonTotal(FooTypedDict__RENAME_ME__, total=False):',
'class FooTypedDict__RENAME_ME__NonTotal(FooTypedDict__RENAME_ME__, total=False):',
' b: str',
'',
'',
'class Dummy:',
f' def an_instance_method(self, foo: \'FooTypedDict__RENAME_ME__NonTotal\', bar: int) -> int: ...'])
' def an_instance_method(self, foo: \'FooTypedDict__RENAME_ME__NonTotal\', bar: int) -> int: ...'])
assert build_module_stubs(entries)['tests.util'].render() == expected


Expand Down

0 comments on commit 47f4237

Please sign in to comment.