Skip to content

Commit

Permalink
Enhencement: stub --diff (#59)
Browse files Browse the repository at this point in the history
* Add test_stub_diff

* Add cmd option --diff to stub_parser

* restructure test cases

* Add get_diff function

* Follow 3-line break convention

* Add test case : longer function

* rewrite get_diff to support diff on longer functions

* Refactor print_stub_handler
  • Loading branch information
tyrinwu authored and mpage committed Jan 26, 2018
1 parent 178e889 commit a768d0e
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 5 deletions.
39 changes: 34 additions & 5 deletions monkeytype/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
import collections
import difflib
import importlib
import inspect
import os.path
Expand Down Expand Up @@ -137,12 +138,34 @@ def apply_stub_handler(args: argparse.Namespace, stdout: IO, stderr: IO) -> None
raise HandlerError(f"Failed applying stub with retype:\n{cpe.stdout.decode('utf-8')}")


def print_stub_handler(args: argparse.Namespace, stdout: IO, stderr: IO) -> None:
def get_diff(args: argparse.Namespace, stdout: IO, stderr: IO) -> Optional[str]:
args.ignore_existing_annotations = False
stub = get_stub(args, stdout, stderr)
if stub is None:
print(f'No traces found', file=stderr)
return
print(stub.render(), file=stdout)
args.ignore_existing_annotations = True
stub_ignore_anno = get_stub(args, stdout, stderr)
if stub is None or stub_ignore_anno is None:
return None
diff = []
seq1 = (s + "\n" for s in stub.render().split("\n\n\n"))
seq2 = (s + "\n" for s in stub_ignore_anno.render().split("\n\n\n"))
for stub1, stub2 in zip(seq1, seq2):
if stub1 != stub2:
stub_diff = "".join(difflib.ndiff(stub1.splitlines(keepends=True), stub2.splitlines(keepends=True)))
diff.append(stub_diff[:-1])
return "\n\n\n".join(diff)


def print_stub_handler(args: argparse.Namespace, stdout: IO, stderr: IO) -> None:
output, file = None, stdout
if args.diff:
output = get_diff(args, stdout, stderr)
else:
stub = get_stub(args, stdout, stderr)
if stub is not None:
output = stub.render()
if output is None:
output, file = 'No traces found', stderr
print(output, file=file)


def run_handler(args: argparse.Namespace, stdout: IO, stderr: IO) -> None:
Expand Down Expand Up @@ -282,6 +305,12 @@ def main(argv: List[str], stdout: IO, stderr: IO) -> int:
default=False,
help='Ignore existing annotations and generate stubs only from traces.',
)
stub_parser.add_argument(
"--diff",
action='store_true',
default=False,
help='Compare stubs generated with and without considering existing annotations.',
)
stub_parser.set_defaults(handler=print_stub_handler)

args = parser.parse_args(argv)
Expand Down
73 changes: 73 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ def func_anno(a: int, b: str) -> None:
pass


def func_anno2(a: str, b: str) -> None:
pass


def super_long_function_with_long_params(
long_param1: str,
long_param2: str,
long_param3: str,
long_param4: str,
long_param5: str,
) -> None:
pass


class LoudContextConfig(DefaultConfig):
@contextmanager
def cli_context(self, command: str) -> Iterator[None]:
Expand Down Expand Up @@ -98,6 +112,65 @@ def test_print_stub_ignore_existing_annotations(store_data, stdout, stderr):
assert ret == 0


def test_get_diff(store_data, stdout, stderr):
store, db_file = store_data
traces = [
CallTrace(func_anno, {'a': int, 'b': int}, int),
CallTrace(func_anno2, {'a': str, 'b': str}, None),
]
store.add(traces)
with mock.patch.dict(os.environ, {DefaultConfig.DB_PATH_VAR: db_file.name}):
ret = cli.main(['stub', func.__module__, '--diff'], stdout, stderr)
expected = """- def func_anno(a: int, b: str) -> None: ...
? ^ - ^^ ^
+ def func_anno(a: int, b: int) -> int: ...
? ^^ ^ ^
"""
assert stdout.getvalue() == expected
assert stderr.getvalue() == ''
assert ret == 0


def test_get_diff2(store_data, stdout, stderr):
store, db_file = store_data
traces = [
CallTrace(super_long_function_with_long_params, {
'long_param1': str,
'long_param2': str,
'long_param3': int,
'long_param4': str,
'long_param5': int,
}, None),
CallTrace(func_anno, {'a': int, 'b': int}, int),
]
store.add(traces)
with mock.patch.dict(os.environ, {DefaultConfig.DB_PATH_VAR: db_file.name}):
ret = cli.main(['stub', func.__module__, '--diff'], stdout, stderr)
expected = """- def func_anno(a: int, b: str) -> None: ...
? ^ - ^^ ^
+ def func_anno(a: int, b: int) -> int: ...
? ^^ ^ ^
def super_long_function_with_long_params(
long_param1: str,
long_param2: str,
- long_param3: str,
? ^ -
+ long_param3: int,
? ^^
long_param4: str,
- long_param5: str
? ^ -
+ long_param5: int
? ^^
) -> None: ...
"""
assert stdout.getvalue() == expected
assert stderr.getvalue() == ''
assert ret == 0


def test_no_traces(store_data, stdout, stderr):
store, db_file = store_data
with mock.patch.dict(os.environ, {DefaultConfig.DB_PATH_VAR: db_file.name}):
Expand Down

0 comments on commit a768d0e

Please sign in to comment.