From 81030101a3f08d9cf129d939327faa0a614ead9c Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Mon, 18 Dec 2017 10:54:49 -0800 Subject: [PATCH] Fix getitem and list comprehension type inference. --- sdks/python/apache_beam/typehints/opcodes.py | 19 +++++++++++++----- .../typehints/trivial_inference_test.py | 20 +++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index ccf0195a58f7..f0f81ff773f8 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ b/sdks/python/apache_beam/typehints/opcodes.py @@ -150,11 +150,19 @@ def binary_true_divide(state, unused_arg): def binary_subscr(state, unused_arg): - tos = state.stack.pop() - if tos in (str, six.text_type): - out = tos + index = state.stack.pop() + base = state.stack.pop() + if base in (str, six.text_type): + out = base + elif (isinstance(index, Const) and isinstance(index.value, int) + and isinstance(base, typehints.TupleHint.TupleConstraint)): + const_index = index.value + if -len(base.tuple_types) < const_index < len(base.tuple_types): + out = base.tuple_types[const_index] + else: + out = element_type(base) else: - out = element_type(tos) + out = element_type(base) state.stack.append(out) @@ -193,8 +201,9 @@ def store_subscr(unused_state, unused_args): # break_loop # continue_loop def list_append(state, arg): + new_element_type = Const.unwrap(state.stack.pop()) state.stack[-arg] = List[Union[element_type(state.stack[-arg]), - Const.unwrap(state.stack.pop())]] + new_element_type]] load_locals = push_value(Dict[str, Any]) diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index b017e8af497f..cd5b8c2f50b6 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -43,6 +43,18 @@ def testTuples(self): self.assertReturnType( typehints.Tuple[str, int, float], lambda x: (x, 0, 1.0), [str]) + def testGetItem(self): + def reverse(ab): + return ab[-1], ab[0] + self.assertReturnType( + typehints.Tuple[typehints.Any, typehints.Any], reverse, [typehints.Any]) + self.assertReturnType( + typehints.Tuple[int, float], reverse, [typehints.Tuple[float, int]]) + self.assertReturnType( + typehints.Tuple[int, str], reverse, [typehints.Tuple[str, float, int]]) + self.assertReturnType( + typehints.Tuple[int, int], reverse, [typehints.List[int]]) + def testUnpack(self): def reverse(a_b): (a, b) = a_b @@ -98,6 +110,14 @@ def testTupleListComprehension(self): typehints.List[typehints.Union[int, float]], lambda xs: [x for x in xs], [typehints.Tuple[int, float]]) + self.assertReturnType( + typehints.List[typehints.Tuple[str, int]], + lambda kvs: [(kvs[0], v) for v in kvs[1]], + [typehints.Tuple[str, typehints.Iterable[int]]]) + self.assertReturnType( + typehints.List[typehints.Tuple[str, typehints.Union[str, int], int]], + lambda L: [(a, a or b, b) for a, b in L], + [typehints.Iterable[typehints.Tuple[str, int]]]) def testGenerator(self):