Skip to content

Commit

Permalink
Handle TypedDicts nested within generic types. (#162)
Browse files Browse the repository at this point in the history
Fixes #159.

Earlier, we didn't pass in `max_typed_dict_size` recursively for all calls to `get_type`. This meant that `[{"a": 1}]` would produce the TypedDict class name even when `max_typed_dict_size` was set to 0. And since we didn't gather class stubs from nested generic types, the class definition stub wouldn't show up either.

* Pass `max_typed_dict_size` in all calls of `get_type`.

* Gather class stubs recursively using a `TypeRewriter` visitor for parameter types, return types, and yield types. Add `rewrite_TypedDict` method to `TypeRewriter`.

* Special-case `{}` so that we don't return an empty TypedDict regardless of `max_typed_dict_size`.
  • Loading branch information
pradeep90 committed Feb 11, 2020
1 parent a586173 commit 33af5bc
Show file tree
Hide file tree
Showing 6 changed files with 344 additions and 59 deletions.
6 changes: 5 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog
master
------

* Generate stubs for TypedDicts nested within generic types. Disable
TypedDicts completely when the max size is zero. Thanks Pradeep Kumar
Srinivasan. Fixes #159.


19.11.2
-------
Expand Down Expand Up @@ -47,7 +51,7 @@ master
19.5.0
------

* Mark ``monkeytype`` package as typed per PEP 561. Thanks Vasily Zakharov for
* Mark ``monkeytype`` package as typed per PEP 561. Thanks Vasily Zakharov for
the report.
* Add ``-v`` option; don't display individual traces that fail to decode unless
it is given.
Expand Down
91 changes: 63 additions & 28 deletions monkeytype/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,22 +503,59 @@ def __repr__(self) -> str:
tuple(self.function_stubs.values()),
tuple(self.attribute_stubs))

@staticmethod
def stubs_from_typed_dict(type_dict: type, class_name: str) -> List['ClassStub']:
"""Return a list of class stubs for all TypedDicts found within `type_dict`."""
assert is_anonymous_typed_dict(type_dict)
class_stubs = []

class ReplaceTypedDictsWithStubs(TypeRewriter):
"""Replace TypedDicts in a generic type with class stubs and store all the stubs."""

def __init__(self, class_name_hint: str) -> None:
self._class_name_hint = class_name_hint
self.stubs: List[ClassStub] = []

def _rewrite_container(self, cls: type, container: type) -> type:
"""Rewrite while using the index of the inner type as a class name hint.
Otherwise, Tuple[TypedDict(...), TypedDict(...)] would give the same
name for both the generated classes."""
if container.__module__ != "typing":
return container
args = getattr(container, '__args__', None)
if args is None:
return container
elif args == ((),): # special case of empty tuple `Tuple[()]`
elems: Tuple[Any, ...] = ()
else:
# Avoid adding a suffix for the first one so that
# single-element containers don't have a numeric suffix.
elems, stub_lists = zip(*[
self.rewrite_and_get_stubs(
elem,
class_name_hint=self._class_name_hint + (
'' if index == 0 else str(index + 1)))
for index, elem in enumerate(args)])
for stubs in stub_lists:
self.stubs.extend(stubs)
# Value of type "type" is not indexable.
return cls[elems] # type: ignore

def rewrite_TypedDict(self, typed_dict: type) -> type:
if not is_anonymous_typed_dict(typed_dict):
return super().rewrite_TypedDict(typed_dict)
class_name = get_typed_dict_class_name(self._class_name_hint)
attribute_stubs = []
for name, typ in type_dict.__annotations__.items():
if is_anonymous_typed_dict(typ):
_class_name = get_typed_dict_class_name(name)
class_stubs.extend(ClassStub.stubs_from_typed_dict(typ, _class_name))
typ = make_forward_ref(_class_name)
attribute_stubs.append(AttributeStub(name, typ))
class_stubs.append(ClassStub(name=f'{class_name}(TypedDict)',
function_stubs=[],
attribute_stubs=attribute_stubs))
return class_stubs
for name, typ in typed_dict.__annotations__.items():
rewritten_type, stubs = self.rewrite_and_get_stubs(typ, class_name_hint=name)
attribute_stubs.append(AttributeStub(name, rewritten_type))
self.stubs.extend(stubs)
self.stubs.append(ClassStub(name=f'{class_name}(TypedDict)',
function_stubs=[],
attribute_stubs=attribute_stubs))
return make_forward_ref(class_name)

@staticmethod
def rewrite_and_get_stubs(typ: type, class_name_hint: str) -> Tuple[type, List[ClassStub]]:
rewriter = ReplaceTypedDictsWithStubs(class_name_hint)
rewritten_type = rewriter.rewrite(typ)
return rewritten_type, rewriter.stubs


class ModuleStub(Stub):
Expand Down Expand Up @@ -604,23 +641,21 @@ def from_callable_and_traced_types(
typed_dict_class_stubs: List[ClassStub] = []
new_arg_types = {}
for name, typ in arg_types.items():
if is_anonymous_typed_dict(typ):
class_name = get_typed_dict_class_name(name)
typed_dict_class_stubs.extend(ClassStub.stubs_from_typed_dict(typ, class_name))
typ = make_forward_ref(class_name)
new_arg_types[name] = typ
rewritten_type, stubs = ReplaceTypedDictsWithStubs.rewrite_and_get_stubs(typ, class_name_hint=name)
new_arg_types[name] = rewritten_type
typed_dict_class_stubs.extend(stubs)

if return_type and is_anonymous_typed_dict(return_type):
if return_type:
# Replace the dot in a qualified name.
class_name = get_typed_dict_class_name(func.__qualname__.replace('.', '_'))
typed_dict_class_stubs.extend(ClassStub.stubs_from_typed_dict(return_type, class_name))
return_type = make_forward_ref(class_name)
class_name_hint = func.__qualname__.replace('.', '_')
return_type, stubs = ReplaceTypedDictsWithStubs.rewrite_and_get_stubs(return_type, class_name_hint)
typed_dict_class_stubs.extend(stubs)

if yield_type and is_anonymous_typed_dict(yield_type):
if yield_type:
# Replace the dot in a qualified name.
class_name = get_typed_dict_class_name(func.__qualname__.replace('.', '_') + 'Yield')
typed_dict_class_stubs.extend(ClassStub.stubs_from_typed_dict(yield_type, class_name))
yield_type = make_forward_ref(class_name)
class_name_hint = func.__qualname__.replace('.', '_') + 'Yield'
yield_type, stubs = ReplaceTypedDictsWithStubs.rewrite_and_get_stubs(yield_type, class_name_hint)
typed_dict_class_stubs.extend(stubs)

function = FunctionDefinition.from_callable(func)
signature = function.signature
Expand Down
33 changes: 23 additions & 10 deletions monkeytype/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,24 @@ def make_generator(yield_typ, send_typ, return_typ):
)


def get_dict_type(dct, max_typed_dict_size=None):
def get_dict_type(dct, max_typed_dict_size):
"""Return a TypedDict for `dct` if all the keys are strings.
Else, default to the union of the keys and of the values."""
if len(dct) == 0:
# Special-case this because returning an empty TypedDict is
# unintuitive, especially when you've "disabled" TypedDict generation
# by setting `max_typed_dict_size` to 0.
return Dict[Any, Any]
if (all(isinstance(k, str) for k in dct.keys())
and (max_typed_dict_size is None or len(dct) <= max_typed_dict_size)):
return TypedDict(DUMMY_TYPED_DICT_NAME, {k: get_type(v) for k, v in dct.items()})
return TypedDict(DUMMY_TYPED_DICT_NAME, {k: get_type(v, max_typed_dict_size) for k, v in dct.items()})
else:
key_type = shrink_types(get_type(k) for k in dct.keys())
val_type = shrink_types(get_type(v) for v in dct.values())
key_type = shrink_types(get_type(k, max_typed_dict_size) for k in dct.keys())
val_type = shrink_types(get_type(v, max_typed_dict_size) for v in dct.values())
return Dict[key_type, val_type]


def get_type(obj, max_typed_dict_size=None):
def get_type(obj, max_typed_dict_size):
"""Return the static type that would be used in a type hint"""
if isinstance(obj, type):
return Type[obj]
Expand All @@ -138,19 +143,19 @@ def get_type(obj, max_typed_dict_size=None):
return Iterator[Any]
typ = type(obj)
if typ is list:
elem_type = shrink_types(get_type(e) for e in obj)
elem_type = shrink_types(get_type(e, max_typed_dict_size) for e in obj)
return List[elem_type]
elif typ is set:
elem_type = shrink_types(get_type(e) for e in obj)
elem_type = shrink_types(get_type(e, max_typed_dict_size) for e in obj)
return Set[elem_type]
elif typ is dict:
return get_dict_type(obj, max_typed_dict_size)
elif typ is defaultdict:
key_type = shrink_types(get_type(k) for k in obj.keys())
val_type = shrink_types(get_type(v) for v in obj.values())
key_type = shrink_types(get_type(k, max_typed_dict_size) for k in obj.keys())
val_type = shrink_types(get_type(v, max_typed_dict_size) for v in obj.values())
return DefaultDict[key_type, val_type]
elif typ is tuple:
return Tuple[tuple(get_type(e) for e in obj)]
return Tuple[tuple(get_type(e, max_typed_dict_size) for e in obj)]
return typ


Expand Down Expand Up @@ -186,6 +191,12 @@ def rewrite_Set(self, st):
def rewrite_Tuple(self, tup):
return self._rewrite_container(Tuple, tup)

def rewrite_TypedDict(self, typed_dict):
return TypedDict(typed_dict.__name__,
{name: self.rewrite(typ)
for name, typ in typed_dict.__annotations__.items()},
total=typed_dict.__total__)

def rewrite_Union(self, union):
return self._rewrite_container(Union, union)

Expand All @@ -198,6 +209,8 @@ def rewrite(self, typ):
typname = 'Any'
elif is_union(typ):
typname = 'Union'
elif is_typed_dict(typ):
typname = 'TypedDict'
elif is_generic(typ):
typname = name_of_generic(typ)
else:
Expand Down
4 changes: 3 additions & 1 deletion monkeytype/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def is_anonymous_typed_dict(typ: type) -> bool: ...
def shrink_types(types: Iterable[type]) -> type: ...


def get_dict_type(dct: Any) -> type: ...
def get_dict_type(dct: Any, max_typed_dict_size: Optional[int]) -> type: ...


def get_type(obj: Any, max_typed_dict_size: Optional[int]) -> type: ...
Expand All @@ -59,6 +59,8 @@ class TypeRewriter:

def rewrite_Union(self, union: _Union) -> type: ...

def rewrite_TypedDict(self, typed_dict: type) -> type: ...

def rewrite_Tuple(self, tup: Tuple) -> type: ...

def generic_rewrite(self, typ: type) -> type: ...
Expand Down

0 comments on commit 33af5bc

Please sign in to comment.