Skip to content

Commit

Permalink
Merge pull request #7534 Native type hints in Beam type hints
Browse files Browse the repository at this point in the history
[BEAM-2713] Allow native type hints nested in Beam type hints.
  • Loading branch information
robertwb committed Jan 23, 2019
2 parents 4e6fb5e + 595c4cf commit a843a08
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ def test_convert_to_beam_type(self):
native_type_compatibility.convert_to_beam_type(typing_type),
beam_type, description)

@unittest.skipIf(sys.version_info[0] == 3 and
os.environ.get('RUN_SKIPPED_PY3_TESTS') != '1',
'This test still needs to be fixed on Python 3.')
def test_convert_nested_to_beam_type(self):
self.assertEqual(
typehints.List[typing.Any],
typehints.List[typehints.Any])
self.assertEqual(
typehints.List[typing.Dict[int, str]],
typehints.List[typehints.Dict[int, str]])

@unittest.skipIf(sys.version_info[0] == 3 and
os.environ.get('RUN_SKIPPED_PY3_TESTS') != '1',
'This test still needs to be fixed on Python 3.')
Expand Down
58 changes: 38 additions & 20 deletions sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class SequenceTypeConstraint(IndexableTypeConstraint):
"""

def __init__(self, inner_type, sequence_type):
self.inner_type = inner_type
self.inner_type = normalize(inner_type)
self._sequence_type = sequence_type

def __eq__(self, other):
Expand Down Expand Up @@ -349,7 +349,8 @@ def validate_composite_type_param(type_param, error_msg_prefix):
possible_classes.append(types.__dict__["ClassType"])
is_not_type_constraint = (
not isinstance(type_param, tuple(possible_classes))
and type_param is not None)
and type_param is not None
and getattr(type_param, '__module__', None) != 'typing')
is_forbidden_type = (isinstance(type_param, type) and
type_param in DISALLOWED_PRIMITIVE_TYPES)

Expand Down Expand Up @@ -472,7 +473,7 @@ class UnionHint(CompositeTypeHint):
class UnionConstraint(TypeConstraint):

def __init__(self, union_types):
self.union_types = set(union_types)
self.union_types = set(normalize(t) for t in union_types)

def __eq__(self, other):
return (isinstance(other, UnionHint.UnionConstraint)
Expand Down Expand Up @@ -595,7 +596,7 @@ def _consistent_with_check_(self, sub):
class TupleConstraint(IndexableTypeConstraint):

def __init__(self, type_params):
self.tuple_types = tuple(type_params)
self.tuple_types = tuple(normalize(t) for t in type_params)

def __eq__(self, other):
return (isinstance(other, TupleHint.TupleConstraint)
Expand Down Expand Up @@ -774,8 +775,8 @@ class DictHint(CompositeTypeHint):
class DictConstraint(TypeConstraint):

def __init__(self, key_type, value_type):
self.key_type = key_type
self.value_type = value_type
self.key_type = normalize(key_type)
self.value_type = normalize(value_type)

def __repr__(self):
return 'Dict[%s, %s]' % (_unified_repr(self.key_type),
Expand Down Expand Up @@ -970,7 +971,7 @@ class IteratorHint(CompositeTypeHint):
class IteratorTypeConstraint(TypeConstraint):

def __init__(self, t):
self.yielded_type = t
self.yielded_type = normalize(t)

def __repr__(self):
return 'Iterator[%s]' % _unified_repr(self.yielded_type)
Expand Down Expand Up @@ -1020,7 +1021,7 @@ class WindowedTypeConstraint(with_metaclass(GetitemConstructor,
"""

def __init__(self, inner_type):
self.inner_type = inner_type
self.inner_type = normalize(inner_type)

def __eq__(self, other):
return (isinstance(other, WindowedTypeConstraint)
Expand Down Expand Up @@ -1075,22 +1076,39 @@ class GeneratorHint(IteratorHint):
WindowedValue = WindowedTypeConstraint


_KNOWN_PRIMITIVE_TYPES = {
dict: Dict[Any, Any],
list: List[Any],
tuple: Tuple[Any, ...],
set: Set[Any],
# Using None for the NoneType is a common convention.
None: type(None),
}
# There is a circular dependency between defining this mapping
# and using it in normalize(). Initialize it here and populate
# it below.
_KNOWN_PRIMITIVE_TYPES = {}


def normalize(x):
if x in _KNOWN_PRIMITIVE_TYPES:
def normalize(x, none_as_type=False):
# None is inconsistantly used for Any, unknown, or NoneType.
if none_as_type and x is None:
return type(None)
elif x in _KNOWN_PRIMITIVE_TYPES:
return _KNOWN_PRIMITIVE_TYPES[x]
elif getattr(x, '__module__', None) == 'typing':
# Avoid circular imports
from apache_beam.typehints import native_type_compatibility
beam_type = native_type_compatibility.convert_to_beam_type(x)
if beam_type != x:
# We were able to do the conversion.
return beam_type
else:
# It might be a compatible type we don't understand.
return Any
return x


_KNOWN_PRIMITIVE_TYPES.update({
dict: Dict[Any, Any],
list: List[Any],
tuple: Tuple[Any, ...],
set: Set[Any],
})


def is_consistent_with(sub, base):
"""Returns whether the type a is consistent with b.
Expand All @@ -1105,8 +1123,8 @@ def is_consistent_with(sub, base):
return True
if isinstance(sub, AnyTypeConstraint) or isinstance(base, AnyTypeConstraint):
return True
sub = normalize(sub)
base = normalize(base)
sub = normalize(sub, none_as_type=True)
base = normalize(base, none_as_type=True)
if isinstance(base, TypeConstraint):
if isinstance(sub, UnionConstraint):
return all(is_consistent_with(c, base) for c in sub.union_types)
Expand Down

0 comments on commit a843a08

Please sign in to comment.