Skip to content

Commit

Permalink
Enhancement: ignore existing annotations (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
tyrinwu authored and carljm committed Jan 19, 2018
1 parent 0119311 commit 79c7914
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 12 deletions.
14 changes: 13 additions & 1 deletion monkeytype/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ def get_stub(args: argparse.Namespace, stdout: IO, stderr: IO) -> Optional[Stub]
rewriter = args.config.type_rewriter()
if args.disable_type_rewriting:
rewriter = NoOpRewriter()
stubs = build_module_stubs_from_traces(traces, args.include_unparsable_defaults, rewriter)
stubs = build_module_stubs_from_traces(
traces,
include_unparsable_defaults=args.include_unparsable_defaults,
ignore_existing_annotations=args.ignore_existing_annotations,
rewriter=rewriter,
)
if args.sample_count:
display_sample_count(traces, stderr)
return stubs.get(module, None)
Expand All @@ -106,6 +111,7 @@ class HandlerError(Exception):


def apply_stub_handler(args: argparse.Namespace, stdout: IO, stderr: IO) -> None:
args.ignore_existing_annotations = False
stub = get_stub(args, stdout, stderr)
if stub is None:
print(f'No traces found', file=stderr)
Expand Down Expand Up @@ -270,6 +276,12 @@ def main(argv: List[str], stdout: IO, stderr: IO) -> int:
default=False,
help='Print to stderr the numbers of traces stubs are based on'
)
stub_parser.add_argument(
"--ignore-existing-annotations",
action='store_true',
default=False,
help='Ignore existing annotations and generate stubs only from traces.',
)
stub_parser.set_defaults(handler=print_stub_handler)

args = parser.parse_args(argv)
Expand Down
30 changes: 19 additions & 11 deletions monkeytype/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,21 @@ def get_imports_for_signature(sig: inspect.Signature) -> ImportMap:
return imports


def update_signature_args(sig: inspect.Signature, arg_types: Dict[str, type], has_self: bool) -> inspect.Signature:
def update_signature_args(
sig: inspect.Signature,
arg_types: Dict[str, type],
has_self: bool,
ignore_existing_annotations: bool = False) -> inspect.Signature:
"""Update argument annotations with the supplied types"""
params = []
for arg_idx, name in enumerate(sig.parameters):
param = sig.parameters[name]
typ = arg_types.get(name)
# Don't touch pre-existing annotations and leave self un-annotated
if (typ is not None) and \
(param.annotation is inspect.Parameter.empty) and \
((not has_self) or (arg_idx != 0)):
typ = inspect.Parameter.empty if typ is None else typ
is_self = (has_self and arg_idx == 0)
annotated = param.annotation is not inspect.Parameter.empty
# Don't touch existing annotations unless ignore_existing_annotations
if not is_self and (ignore_existing_annotations or not annotated):
param = param.replace(annotation=typ)
params.append(param)
return sig.replace(parameters=params)
Expand All @@ -199,11 +204,12 @@ def update_signature_args(sig: inspect.Signature, arg_types: Dict[str, type], ha
def update_signature_return(
sig: inspect.Signature,
return_type: type = None,
yield_type: type = None) -> inspect.Signature:
yield_type: type = None,
ignore_existing_annotations: bool = False) -> inspect.Signature:
"""Update return annotation with the supplied types"""
anno = sig.return_annotation
# Dont' touch pre-existing annotations
if anno is not inspect.Signature.empty:
# Don't touch pre-existing annotations unless ignore_existing_annotations
if not ignore_existing_annotations and anno is not inspect.Signature.empty:
return sig
# NB: We cannot distinguish between functions that explicitly only
# return None and those that do so implicitly. In the case of generator
Expand Down Expand Up @@ -239,6 +245,7 @@ def get_updated_definition(
func: Callable,
traces: Iterable[CallTrace],
rewriter: Optional[TypeRewriter] = None,
ignore_existing_annotations: bool = False,
) -> FunctionDefinition:
"""Update the definition for func using the types collected in traces."""
if rewriter is None:
Expand All @@ -251,8 +258,8 @@ def get_updated_definition(
if yield_type is not None:
yield_type = rewriter.rewrite(yield_type)
sig = defn.signature
sig = update_signature_args(sig, arg_types, defn.has_self)
sig = update_signature_return(sig, return_type, yield_type)
sig = update_signature_args(sig, arg_types, defn.has_self, ignore_existing_annotations)
sig = update_signature_return(sig, return_type, yield_type, ignore_existing_annotations)
return FunctionDefinition(defn.module, defn.qualname, defn.kind, sig, defn.is_async)


Expand Down Expand Up @@ -563,6 +570,7 @@ def build_module_stubs(entries: Iterable[FunctionDefinition]) -> Dict[str, Modul
def build_module_stubs_from_traces(
traces: Iterable[CallTrace],
include_unparsable_defaults: bool = False,
ignore_existing_annotations: bool = False,
rewriter: Optional[TypeRewriter] = None
) -> Dict[str, ModuleStub]:
"""Given an iterable of call traces, build the corresponding stubs."""
Expand All @@ -571,7 +579,7 @@ def build_module_stubs_from_traces(
index[trace.func].add(trace)
defns = []
for func, traces in index.items():
defn = get_updated_definition(func, traces, rewriter=rewriter)
defn = get_updated_definition(func, traces, rewriter, ignore_existing_annotations)
if has_unparsable_defaults(defn.signature) and not include_unparsable_defaults:
logger.warning(
"Omitting stub for function %s.%s; it contains unparsable default values." +
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def func2(a, b):
pass


def func_anno(a: int, b: str) -> None:
pass


class LoudContextConfig(DefaultConfig):
@contextmanager
def cli_context(self, command: str) -> Iterator[None]:
Expand Down Expand Up @@ -78,6 +82,22 @@ def func2(a: int, b: int) -> None: ...
assert ret == 0


def test_print_stub_ignore_existing_annotations(store_data, stdout, stderr):
store, db_file = store_data
traces = [
CallTrace(func_anno, {'a': int, 'b': int}, int),
]
store.add(traces)
with mock.patch.dict(os.environ, {DefaultConfig.DB_PATH_VAR: db_file.name}):
ret = cli.main(['stub', func.__module__, '--ignore-existing-annotations'],
stdout, stderr)
expected = """def func_anno(a: int, b: int) -> int: ...
"""
assert stdout.getvalue() == expected
assert stderr.getvalue() == ''
assert ret == 0


def test_no_traces(store_data, stdout, stderr):
store, db_file = store_data
with mock.patch.dict(os.environ, {DefaultConfig.DB_PATH_VAR: db_file.name}):
Expand Down
41 changes: 41 additions & 0 deletions tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,34 @@ def test_update_class(self):
expected = Signature(parameters=[Parameter('cls', Parameter.POSITIONAL_OR_KEYWORD)])
assert sig == expected

def test_update_arg_ignore_existing_anno(self):
"""Update stubs only bases on traces."""
sig = Signature.from_callable(UpdateSignatureHelper.has_annos)
sig = update_signature_args(sig, {'a': str, 'b': bool}, has_self=False, ignore_existing_annotations=True)
params = [
Parameter('a', Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
Parameter('b', Parameter.POSITIONAL_OR_KEYWORD, annotation=bool),
]
assert sig == Signature(parameters=params, return_annotation=int)

def test_update_self_ignore_existing_anno(self):
"""Don't annotate first arg of instance methods with ignore_existing_annotations"""
sig = Signature.from_callable(UpdateSignatureHelper.an_instance_method)
sig = update_signature_args(sig, {'self': UpdateSignatureHelper}, has_self=True,
ignore_existing_annotations=True)
expected = Signature(parameters=[Parameter('self', Parameter.POSITIONAL_OR_KEYWORD)])
assert sig == expected

def test_update_arg_ignore_existing_anno_None(self):
"""Update arg annotations from types"""
sig = Signature.from_callable(UpdateSignatureHelper.has_annos)
sig = update_signature_args(sig, {'a': None, 'b': int}, has_self=False, ignore_existing_annotations=True)
params = [
Parameter('a', Parameter.POSITIONAL_OR_KEYWORD, annotation=inspect.Parameter.empty),
Parameter('b', Parameter.POSITIONAL_OR_KEYWORD, annotation=int),
]
assert sig == Signature(parameters=params, return_annotation=int)


class TestUpdateSignatureReturn:
def test_update_return(self):
Expand All @@ -447,6 +475,19 @@ def test_update_return_with_anno(self):
)
assert sig == expected

def test_update_return_with_anno_ignored(self):
"""Leave existing return annotations alone"""
sig = Signature.from_callable(UpdateSignatureHelper.has_annos)
sig = update_signature_return(sig, return_type=str, ignore_existing_annotations=True)
expected = Signature(
parameters=[
Parameter('a', Parameter.POSITIONAL_OR_KEYWORD, annotation=int),
Parameter('b', Parameter.POSITIONAL_OR_KEYWORD)
],
return_annotation=str
)
assert sig == expected

def test_update_yield(self):
sig = Signature.from_callable(UpdateSignatureHelper.a_class_method)
sig = update_signature_return(sig, yield_type=int)
Expand Down

0 comments on commit 79c7914

Please sign in to comment.