Skip to content

Commit

Permalink
Fix default type inference of CombinePerKey. (#16351)
Browse files Browse the repository at this point in the history
More typehints propagation of CombinePerKey, MapTuple, FlatMapTuple

get_type_hints returns an empty TypeHint object. Add with_defaults to
initialize the type hints from the fn when the hints are generated
(in default_type_hints() and in the map functions).

In MapTuple and FlatMapTuple, the input type hint is still ignored.
This is because IOTypeHints cannot distinguish *args from an ordinary
argument. I suspect there would be more complications even if this one were
tracked.

For CombinePerKey, also override infer_output_type() to perform an inference
pass on the fn. This is more accurate than the default types. There
is some redundancy with default_type_hints().

Also fix _strip_output_annotations (a bug discovered while fixing the users
broken by the above changes).

It needs to look inside composite hints (Tuples etc.), but wasn't doing so.
For now we return Any if any inner hint of a composite is annotated.
Expanded the comment.

This shall be filed / synced as #16351

BUG=36682198

Co-authored-by: Andy Ye <andyye333@gmail.com>
  • Loading branch information
rainwoodman and yeandy committed Jul 8, 2022
1 parent fbe6150 commit 66ffc0b
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 28 deletions.
82 changes: 54 additions & 28 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from apache_beam.typehints.trivial_inference import element_type
from apache_beam.typehints.typehints import TypeConstraint
from apache_beam.typehints.typehints import is_consistent_with
from apache_beam.typehints.typehints import visit_inner_types
from apache_beam.utils import urns
from apache_beam.utils.timestamp import Duration

Expand Down Expand Up @@ -721,7 +722,7 @@ def default_type_hints(self):
def infer_output_type(self, input_type):
# TODO(https://github.com/apache/beam/issues/19824): Side inputs types.
return trivial_inference.element_type(
self._strip_output_annotations(
_strip_output_annotations(
trivial_inference.infer_return_type(self.process, [input_type])))

@property
Expand Down Expand Up @@ -856,15 +857,6 @@ def get_output_batch_type(

return output_batch_type

def _strip_output_annotations(self, type_hint):
  """Replaces an output-annotation hint with ``typehints.Any``.

  Only ``type_hint`` itself and its element type (as computed by
  ``trivial_inference.element_type``) are checked against
  ``TimestampedValue``, ``WindowedValue`` and ``pvalue.TaggedOutput``.

  Args:
    type_hint: The inferred output type hint.

  Returns:
    ``typehints.Any`` if either check matches, otherwise ``type_hint``
    unchanged.
  """
  # TODO(robertwb): These should be parameterized types that the
  # type inferencer understands.
  wrappers = (TimestampedValue, WindowedValue, pvalue.TaggedOutput)
  if type_hint in wrappers:
    return typehints.Any
  if trivial_inference.element_type(type_hint) in wrappers:
    return typehints.Any
  return type_hint

def _process_argspec_fn(self):
"""Returns the Python callable that will eventually be invoked.
Expand Down Expand Up @@ -938,7 +930,7 @@ def default_type_hints(self):

def infer_output_type(self, input_type):
return trivial_inference.element_type(
self._strip_output_annotations(
_strip_output_annotations(
trivial_inference.infer_return_type(self._fn, [input_type])))

def _process_argspec_fn(self):
Expand Down Expand Up @@ -1223,12 +1215,13 @@ def extract_output(self, accumulator, *args, **kwargs):
return self._fn(accumulator, *args, **kwargs)

def default_type_hints(self):
fn_hints = get_type_hints(self._fn)
if fn_hints.input_types is None:
return fn_hints
fn_type_hints = typehints.decorators.IOTypeHints.from_callable(self._fn)
type_hints = get_type_hints(self._fn).with_defaults(fn_type_hints)
if type_hints.input_types is None:
return type_hints
else:
# fn(Iterable[V]) -> V becomes CombineFn(V) -> V
input_args, input_kwargs = fn_hints.input_types
input_args, input_kwargs = type_hints.input_types
if not input_args:
if len(input_kwargs) == 1:
input_args, input_kwargs = tuple(input_kwargs.values()), {}
Expand All @@ -1243,7 +1236,11 @@ def default_type_hints(self):
input_args[0])
input_args = (element_type(input_args[0]), ) + input_args[1:]
# TODO(robertwb): Assert output type is consistent with input type?
return fn_hints.with_input_types(*input_args, **input_kwargs)
return type_hints.with_input_types(*input_args, **input_kwargs)

def infer_output_type(self, input_type):
  """Infers the output type of the wrapped callable for ``input_type``.

  Runs ``trivial_inference.infer_return_type`` on ``self._fn`` with
  ``input_type`` as its only argument type, then passes the result through
  ``_strip_output_annotations``.

  Args:
    input_type: The type hint of the value passed to ``self._fn``.

  Returns:
    The inferred return type after ``_strip_output_annotations``.
  """
  inferred = trivial_inference.infer_return_type(self._fn, [input_type])
  return _strip_output_annotations(inferred)

def for_input_type(self, input_type):
# Avoid circular imports.
Expand Down Expand Up @@ -1867,7 +1864,9 @@ def Map(fn, *args, **kwargs): # pylint: disable=invalid-name
wrapper)
output_hint = type_hints.simple_output_type(label)
if output_hint:
wrapper = with_output_types(typehints.Iterable[output_hint])(wrapper)
wrapper = with_output_types(
typehints.Iterable[_strip_output_annotations(output_hint)])(
wrapper)
# pylint: disable=protected-access
wrapper._argspec_fn = fn
# pylint: enable=protected-access
Expand Down Expand Up @@ -1928,14 +1927,17 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name

# Proxy the type-hint information from the original function to this new
# wrapped function.
type_hints = get_type_hints(fn)
type_hints = get_type_hints(fn).with_defaults(
typehints.decorators.IOTypeHints.from_callable(fn))
if type_hints.input_types is not None:
wrapper = with_input_types(
*type_hints.input_types[0], **type_hints.input_types[1])(
wrapper)
# TODO(BEAM-14052): ignore input hints, as we do not have enough
# information to infer the input type hint of the wrapper function.
pass
output_hint = type_hints.simple_output_type(label)
if output_hint:
wrapper = with_output_types(typehints.Iterable[output_hint])(wrapper)
wrapper = with_output_types(
typehints.Iterable[_strip_output_annotations(output_hint)])(
wrapper)

# Replace the first (args) component.
modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
Expand Down Expand Up @@ -2003,14 +2005,15 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name

# Proxy the type-hint information from the original function to this new
# wrapped function.
type_hints = get_type_hints(fn)
type_hints = get_type_hints(fn).with_defaults(
typehints.decorators.IOTypeHints.from_callable(fn))
if type_hints.input_types is not None:
wrapper = with_input_types(
*type_hints.input_types[0], **type_hints.input_types[1])(
wrapper)
# TODO(BEAM-14052): ignore input hints, as we do not have enough
# information to infer the input type hint of the wrapper function.
pass
output_hint = type_hints.simple_output_type(label)
if output_hint:
wrapper = with_output_types(output_hint)(wrapper)
wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper)

# Replace the first (args) component.
modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
Expand Down Expand Up @@ -2258,7 +2261,9 @@ def Filter(fn, *args, **kwargs): # pylint: disable=invalid-name
get_type_hints(wrapper).input_types[0]):
output_hint = get_type_hints(wrapper).input_types[0][0]
if output_hint:
wrapper = with_output_types(typehints.Iterable[output_hint])(wrapper)
wrapper = with_output_types(
typehints.Iterable[_strip_output_annotations(output_hint)])(
wrapper)
# pylint: disable=protected-access
wrapper._argspec_fn = fn
# pylint: enable=protected-access
Expand Down Expand Up @@ -3493,3 +3498,24 @@ def to_runner_api_parameter(self, unused_context):
def from_runner_api_parameter(
unused_ptransform, unused_parameter, unused_context):
return Impulse()


def _strip_output_annotations(type_hint):
  """Replaces a hint that mentions an output annotation with ``Any``.

  The hint is walked with ``visit_inner_types``; if any visited node is
  ``TimestampedValue``, ``WindowedValue`` or ``pvalue.TaggedOutput``, the
  whole hint is replaced.

  Args:
    type_hint: The output type hint to inspect.

  Returns:
    ``typehints.Any`` if an annotation type was visited, otherwise
    ``type_hint`` unchanged.
  """
  # TODO(robertwb): These should be parameterized types that the
  # type inferencer understands.
  # Then we can replace them with the correct element types instead of
  # using Any. Refer to typehints.WindowedValue when doing this.
  annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput)

  def visitor(t, unused_args):
    # StopIteration is the early-exit signal: one hit is enough.
    if t in annotations:
      raise StopIteration

  try:
    visit_inner_types(type_hint, visitor, [])
  except StopIteration:
    return typehints.Any
  return type_hint
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,36 @@ def half(b):
| 'ToBool' >> beam.Map(lambda x: bool(x)).with_input_types(
int).with_output_types(bool))

def test_pardo_like_inheriting_output_types_from_annotation(self):
  """Checks default output hints derived from annotated callables.

  Covers Map, FlatMap, MapTuple, FlatMapTuple and CombinePerKey.
  """
  def fn1(x: str) -> int:
    return 1

  def fn1_flat(x: str) -> typing.List[int]:
    return [1]

  def fn2(x: int, y: str) -> str:
    return y

  def fn2_flat(x: int, y: str) -> typing.List[str]:
    return [y]

  def output_hints(transform):
    # We only need the args section of the hints.
    hints = transform.default_type_hints()
    return hints.output_types[0][0]

  # (expected element type, transform factory, annotated callable).
  # Transforms are built lazily so construction and assertion stay
  # interleaved in this order.
  cases = [
      (int, beam.Map, fn1),
      (int, beam.FlatMap, fn1_flat),
      (str, beam.MapTuple, fn2),
      (str, beam.FlatMapTuple, fn2_flat),
  ]
  for expected, make_transform, fn in cases:
    self.assertEqual(expected, output_hints(make_transform(fn)))

  def add(a: typing.Iterable[int]) -> int:
    return sum(a)

  expected_kv = typing.Tuple[typing.TypeVar('K'), int]
  self.assertCompatible(expected_kv, output_hints(beam.CombinePerKey(add)))

def test_group_by_key_only_output_type_deduction(self):
d = (
self.p
Expand Down
29 changes: 29 additions & 0 deletions sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,35 @@ def visit(self, visitor, visitor_arg):
visitor(t, visitor_arg)


def visit_inner_types(type_constraint, visitor, visitor_arg):
  """Walks a type constraint, invoking ``visitor`` on nodes of its type tree.

  Args:
    type_constraint: A type constraint or a type.
    visitor: A callable invoked for all nodes in the type tree comprising a
      composite type. It is called with the node visited and ``visitor_arg``.
    visitor_arg: Passed through as the second argument of every ``visitor``
      call.

  Returns:
    The result of ``type_constraint.visit(visitor, visitor_arg)`` when
    ``type_constraint`` is a ``TypeConstraint``; otherwise the result of
    ``visitor(type_constraint, visitor_arg)``.

  Note:
    To terminate the visit early, raise ``StopIteration`` from the visitor
    and catch it around the call, e.g.
    ```
    def visitor(type_constraint, visitor_arg):
      if ...:
        raise StopIteration
    try:
      visit_inner_types(type_constraint, visitor, visitor_arg)
    except StopIteration:
      pass
    ```
  """
  # A plain (non-TypeConstraint) type has no inner structure to descend into.
  if not isinstance(type_constraint, TypeConstraint):
    return visitor(type_constraint, visitor_arg)
  return type_constraint.visit(visitor, visitor_arg)


def match_type_variables(type_constraint, concrete_type):
if isinstance(type_constraint, TypeConstraint):
return type_constraint.match_type_variables(concrete_type)
Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/typehints/typehints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from apache_beam.typehints.decorators import get_type_hints
from apache_beam.typehints.decorators import getcallargs_forhints
from apache_beam.typehints.typehints import is_consistent_with
from apache_beam.typehints.typehints import visit_inner_types


def check_or_interleave(hint, value, var):
Expand Down Expand Up @@ -307,6 +308,21 @@ def test_bind_type_variables(self):
}), typehints.Union[str, int])
self.assertEqual(hint.bind_type_variables({A: int, B: int}), int)

def test_visit_inner_types(self):
  """Checks the nodes and visitor argument seen by visit_inner_types."""
  A = typehints.TypeVariable('A')  # pylint: disable=invalid-name
  B = typehints.TypeVariable('B')  # pylint: disable=invalid-name
  hint = typehints.Tuple[Tuple[A, A], B, int]

  sentinel = object()
  visited = []

  def record(node, arg):
    # The visitor argument must be passed through untouched.
    self.assertIs(arg, sentinel)
    visited.append(node)

  visit_inner_types(hint, record, sentinel)

  # The outer hint comes first, followed by its nested types in this order.
  expected = [hint, Tuple[A, A], A, A, B, int]
  self.assertEqual(visited, expected)


class OptionalHintTestCase(TypeHintTestCase):
def test_getitem_sequence_not_allowed(self):
Expand Down

0 comments on commit 66ffc0b

Please sign in to comment.