Skip to content

Commit

Permalink
Shrink non-total TypedDicts. (#168)
Browse files Browse the repository at this point in the history
Shrink TypedDicts by looking for fields that are present in all the TypedDicts and fields that are present only in some.

When all fields are optional, generate a non-total TypedDict class stub.
When there are both required and optional fields, generate a total base class and a non-total subclass.

Design choices:

+ If a key has different value types in the traced types, then we shrink those types to get its value type. This should not lead to large types because we shrink large union types to Any.

+ If the resulting TypedDict is larger than max_typed_dict_size, then fall back to `Dict[str, ...]`. However, preserve any nested anonymous TypedDicts.

+ Represent required and optional fields as nested TypedDicts.

+ Add `types_equal` so that types such as `List[TypedDict(...)]` compare equal on Python 3.6.

Co-authored-by: Pradeep Kumar Srinivasan <pradeepkumars@fb.com>
  • Loading branch information
pradeep90 and pradeep90 committed Mar 29, 2020
1 parent 3b77c55 commit 56a688e
Show file tree
Hide file tree
Showing 12 changed files with 682 additions and 201 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ master

* Remove ``stringcase`` dependency, just hand-roll ``pascal_case`` function.

* Shrink dictionary traces with required and optional keys to get non-total
TypedDict class declarations. Thanks Pradeep Kumar Srinivasan.


19.11.2
-------
Expand Down
1 change: 1 addition & 0 deletions monkeytype/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def get_stub(args: argparse.Namespace, stdout: IO, stderr: IO) -> Optional[Stub]
rewriter = NoOpRewriter()
stubs = build_module_stubs_from_traces(
traces,
args.config.max_typed_dict_size(),
existing_annotation_strategy=args.existing_annotation_strategy,
rewriter=rewriter,
)
Expand Down
70 changes: 70 additions & 0 deletions monkeytype/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any
from mypy_extensions import _TypedDictMeta # type: ignore


def is_typed_dict(typ: type) -> bool:
    """Report whether `typ` is a TypedDict.

    TypedDict does not support `isinstance`, so the check is made indirectly
    against the `_TypedDictMeta` metaclass.
    """
    has_typed_dict_meta = isinstance(typ, _TypedDictMeta)
    return has_typed_dict_meta


try:
# Python 3.7
Expand Down Expand Up @@ -37,6 +44,24 @@ def repr_forward_ref() -> str:
"""For checking the test output when ForwardRef is printed."""
return 'ForwardRef'

def __are_typed_dict_types_equal(type1: type, type2: type) -> bool:
    """Compare two TypedDict types by name, totality and field annotations.

    Done explicitly because
    TypedDict('Foo', {'a': int}) != TypedDict('Foo', {'a': int}).
    """
    if not is_typed_dict(type2):
        return False
    # A type without `__total__` is treated as total.
    totality1 = getattr(type1, "__total__", True)
    totality2 = getattr(type2, "__total__", True)
    if type1.__name__ != type2.__name__:
        return False
    if totality1 != totality2:
        return False
    return type1.__annotations__ == type2.__annotations__

def types_equal(typ: type, other_type: type) -> bool:
    """Return the result of comparing the two types with `==`."""
    are_equal = typ == other_type
    return are_equal

except ImportError:
# Python 3.6
from typing import _Any, _Union, GenericMeta, _ForwardRef # type: ignore
Expand Down Expand Up @@ -68,3 +93,48 @@ def make_forward_ref(s: str) -> type:
def repr_forward_ref() -> str:
"""For checking the test output when ForwardRef is printed."""
return '_ForwardRef'

def __are_typed_dict_types_equal(type1: type, type2: type) -> bool:
    """Compare two TypedDict types by name, totality and per-field types.

    Done explicitly because
    TypedDict('Foo', {'a': int}) != TypedDict('Foo', {'a': int}).
    Field types are compared with `types_equal` rather than plain `==`.
    """
    if not is_typed_dict(type2):
        return False
    # A type without `__total__` is treated as total.
    totality1 = getattr(type1, "__total__", True)
    totality2 = getattr(type2, "__total__", True)
    if type1.__name__ != type2.__name__ or totality1 != totality2:
        return False
    fields1 = type1.__annotations__
    fields2 = type2.__annotations__
    if fields1.keys() != fields2.keys():
        return False
    # Key sets match, so every key of fields1 is safe to look up in fields2.
    for key in fields1:
        if not types_equal(fields1[key], fields2[key]):
            return False
    return True

def types_equal(typ: type, other_type: type) -> bool:
    """Return whether two types are equal, recursing into container arguments.

    When both types are generics named in `special_container_types`, they are
    equal only if they have the same generic name, the same number of type
    arguments, and pairwise-equal arguments (compared recursively with this
    function). Every other pair is compared with plain `==`.
    """
    # Types for which equality with inner TypedDict doesn't work correctly on 3.6.
    special_container_types = ["List", "Dict", "Tuple", "Set", "DefaultDict"]

    if (
        (is_any(typ) and is_any(other_type))
        or (is_union(typ) and is_union(other_type))
        or (is_typed_dict(typ) and is_typed_dict(other_type))
    ):
        # Deliberately kept out of the generic-container path below; these
        # pairs fall through to the plain `==` comparison.
        pass
    elif is_generic(typ) and is_generic(other_type):
        if (
            name_of_generic(typ) in special_container_types
            and name_of_generic(other_type) in special_container_types
        ):
            # `__args__` may be present but None (e.g. an unsubscripted
            # generic), which a getattr default alone would not catch and
            # which would make `len()` raise. Normalise to an empty tuple.
            args = getattr(typ, "__args__", None) or ()
            other_args = getattr(other_type, "__args__", None) or ()
            return (
                name_of_generic(typ) == name_of_generic(other_type)
                and len(args) == len(other_args)
                and all(types_equal(arg_type, other_arg_type)
                        for arg_type, other_arg_type in zip(args, other_args))
            )
    return typ == other_type


# HACK: MonkeyType monkey-patches _TypedDictMeta!
# We need this to compare TypedDicts recursively.
# The function installed is whichever version-specific definition of
# `__are_typed_dict_types_equal` was selected by the try/except above.
# NOTE(review): only `__eq__` is replaced here and nothing in view touches
# `__hash__`, so structurally equal TypedDicts may still hash differently —
# confirm that code placing these types in sets or dict keys tolerates that.
_TypedDictMeta.__eq__ = __are_typed_dict_types_equal
56 changes: 40 additions & 16 deletions monkeytype/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
is_forward_ref,
make_forward_ref,
)
from monkeytype.typing import is_anonymous_typed_dict
from monkeytype.typing import field_annotations
from monkeytype.util import pascal_case


Expand Down Expand Up @@ -217,7 +217,10 @@ def update_signature_return(
return sig.replace(return_annotation=anno)


def shrink_traced_types(traces: Iterable[CallTrace]) -> Tuple[Dict[str, type], Optional[type], Optional[type]]:
def shrink_traced_types(
traces: Iterable[CallTrace],
max_typed_dict_size: int,
) -> Tuple[Dict[str, type], Optional[type], Optional[type]]:
"""Merges the traced types and returns the minimally equivalent types"""
arg_types: DefaultDict[str, Set[type]] = collections.defaultdict(set)
return_types: Set[type] = set()
Expand All @@ -229,9 +232,9 @@ def shrink_traced_types(traces: Iterable[CallTrace]) -> Tuple[Dict[str, type], O
return_types.add(t.return_type)
if t.yield_type is not None:
yield_types.add(t.yield_type)
shrunken_arg_types = {name: shrink_types(ts) for name, ts in arg_types.items()}
return_type = shrink_types(return_types) if return_types else None
yield_type = shrink_types(yield_types) if yield_types else None
shrunken_arg_types = {name: shrink_types(ts, max_typed_dict_size) for name, ts in arg_types.items()}
return_type = shrink_types(return_types, max_typed_dict_size) if return_types else None
yield_type = shrink_types(yield_types, max_typed_dict_size) if yield_types else None
return (shrunken_arg_types, return_type, yield_type)


Expand Down Expand Up @@ -539,18 +542,35 @@ def _rewrite_container(self, cls: type, container: type) -> type:
# Value of type "type" is not indexable.
return cls[elems] # type: ignore

def rewrite_TypedDict(self, typed_dict: type) -> type:
if not is_anonymous_typed_dict(typed_dict):
return super().rewrite_TypedDict(typed_dict)
class_name = get_typed_dict_class_name(self._class_name_hint)
def _add_typed_dict_class_stub(self, fields: Dict[str, type], class_name: str,
base_class_name: str = 'TypedDict', total: bool = True) -> None:
attribute_stubs = []
for name, typ in typed_dict.__annotations__.items():
for name, typ in fields.items():
rewritten_type, stubs = self.rewrite_and_get_stubs(typ, class_name_hint=name)
attribute_stubs.append(AttributeStub(name, rewritten_type))
self.stubs.extend(stubs)
self.stubs.append(ClassStub(name=f'{class_name}(TypedDict)',
total_flag = '' if total else ', total=False'
self.stubs.append(ClassStub(name=f'{class_name}({base_class_name}{total_flag})',
function_stubs=[],
attribute_stubs=attribute_stubs))

def rewrite_anonymous_TypedDict(self, typed_dict: type) -> type:
    """Add class stub(s) for an anonymous TypedDict and return a forward ref.

    The fields are split into required and optional ones:

    * only required fields: one total class stub;
    * only optional fields: one stub declared with ``total=False``;
    * both: a total stub for the required fields, plus a ``total=False``
      stub whose name ends in ``NonTotal``, which uses the total stub as its
      base class and holds the optional fields.

    Returns a forward reference to the last class stub added.
    Raises Exception when the TypedDict has no fields at all.
    """
    class_name = get_typed_dict_class_name(self._class_name_hint)
    required_fields, optional_fields = field_annotations(typed_dict)
    required_count = len(required_fields)
    optional_count = len(optional_fields)

    if required_count > 0 and optional_count > 0:
        # Required fields go in the total base class ...
        self._add_typed_dict_class_stub(required_fields, class_name)
        total_base_name = class_name
        # ... and optional fields in a non-total subclass of that base.
        class_name = get_typed_dict_class_name(self._class_name_hint) + 'NonTotal'
        self._add_typed_dict_class_stub(optional_fields, class_name, total_base_name, total=False)
    elif required_count > 0:
        self._add_typed_dict_class_stub(required_fields, class_name)
    elif optional_count > 0:
        self._add_typed_dict_class_stub(optional_fields, class_name, total=False)
    else:
        raise Exception("Expected empty TypedDicts to be shrunk as Dict[Any, Any]"
                        " but got an empty TypedDict anyway")
    return make_forward_ref(class_name)

@staticmethod
Expand Down Expand Up @@ -688,13 +708,14 @@ def __repr__(self) -> str:
def get_updated_definition(
func: Callable,
traces: Iterable[CallTrace],
max_typed_dict_size: int,
rewriter: Optional[TypeRewriter] = None,
existing_annotation_strategy: ExistingAnnotationStrategy = ExistingAnnotationStrategy.REPLICATE,
) -> FunctionDefinition:
"""Update the definition for func using the types collected in traces."""
if rewriter is None:
rewriter = NoOpRewriter()
arg_types, return_type, yield_type = shrink_traced_types(traces)
arg_types, return_type, yield_type = shrink_traced_types(traces, max_typed_dict_size)
arg_types = {name: rewriter.rewrite(typ) for name, typ in arg_types.items()}
if return_type is not None:
return_type = rewriter.rewrite(return_type)
Expand Down Expand Up @@ -741,32 +762,35 @@ def build_module_stubs(entries: Iterable[FunctionDefinition]) -> Dict[str, Modul

def build_module_stubs_from_traces(
traces: Iterable[CallTrace],
max_typed_dict_size: int,
existing_annotation_strategy: ExistingAnnotationStrategy = ExistingAnnotationStrategy.REPLICATE,
rewriter: Optional[TypeRewriter] = None
rewriter: Optional[TypeRewriter] = None,
) -> Dict[str, ModuleStub]:
"""Given an iterable of call traces, build the corresponding stubs."""
index: DefaultDict[Callable, Set[CallTrace]] = collections.defaultdict(set)
for trace in traces:
index[trace.func].add(trace)
defns = []
for func, traces in index.items():
defn = get_updated_definition(func, traces, rewriter, existing_annotation_strategy)
defn = get_updated_definition(func, traces, max_typed_dict_size, rewriter, existing_annotation_strategy)
defns.append(defn)
return build_module_stubs(defns)


class StubIndexBuilder(CallTraceLogger):
"""Builds type stub index directly from collected call traces."""

def __init__(self, module_re: str) -> None:
def __init__(self, module_re: str, max_typed_dict_size: int) -> None:
self.re = re.compile(module_re)
self.index: DefaultDict[Callable, Set[CallTrace]] = collections.defaultdict(set)
self.max_typed_dict_size = max_typed_dict_size

def log(self, trace: CallTrace) -> None:
if not self.re.match(trace.funcname):
return
self.index[trace.func].add(trace)

def get_stubs(self) -> Dict[str, ModuleStub]:
defs = (get_updated_definition(func, traces) for func, traces in self.index.items())
defs = (get_updated_definition(func, traces, self.max_typed_dict_size)
for func, traces in self.index.items())
return build_module_stubs(defs)
6 changes: 3 additions & 3 deletions monkeytype/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ class CallTracer:
def __init__(
self,
logger: CallTraceLogger,
max_typed_dict_size: int,
code_filter: Optional[CodeFilter] = None,
sample_rate: Optional[int] = None,
max_typed_dict_size: Optional[int] = None,
) -> None:
self.logger = logger
self.traces: Dict[FrameType, CallTrace] = {}
Expand Down Expand Up @@ -269,13 +269,13 @@ def __call__(self, frame: FrameType, event: str, arg: Any) -> 'CallTracer':
@contextmanager
def trace_calls(
logger: CallTraceLogger,
max_typed_dict_size: int,
code_filter: Optional[CodeFilter] = None,
sample_rate: Optional[int] = None,
max_typed_dict_size: Optional[int] = None,
) -> Iterator[None]:
"""Enable call tracing for a block of code"""
old_trace = sys.getprofile()
sys.setprofile(CallTracer(logger, code_filter, sample_rate, max_typed_dict_size))
sys.setprofile(CallTracer(logger, max_typed_dict_size, code_filter, sample_rate))
try:
yield
finally:
Expand Down

0 comments on commit 56a688e

Please sign in to comment.