Skip to content
Merged
107 changes: 84 additions & 23 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# pytype: skip-file

import collections
import collections.abc
import logging
import sys
import types
Expand All @@ -45,7 +46,18 @@
frozenset: typing.FrozenSet,
}

# Builtin container classes that convert_typing_to_builtin and is_builtin
# look for, either as the type itself or as its ``__origin__``.
_BUILTINS = [dict, list, tuple, set, frozenset]

_CONVERTED_COLLECTIONS = [
collections.abc.Iterable,
collections.abc.Iterator,
collections.abc.Generator,
collections.abc.Set,
collections.abc.MutableSet,
collections.abc.Collection,
Expand Down Expand Up @@ -99,14 +111,25 @@ def _match_issubclass(match_against):
return lambda user_type: _safe_issubclass(user_type, match_against)


def _is_primitive(user_type, primitive):
  """Returns whether ``user_type`` is ``primitive`` or a parameterization of it.

  Args:
    user_type: The type hint to inspect, e.g. ``list`` or ``list[int]``.
    primitive: The class to compare against, e.g. ``list``.

  Returns:
    bool: True if ``user_type`` is ``primitive`` itself, or if its
    ``__origin__`` attribute is ``primitive``; False otherwise.
  """
  # Catch bare primitives.
  if user_type is primitive:
    return True
  # Parameterized aliases expose the class they were built from as
  # ``__origin__``; objects without that attribute fall through to False.
  return getattr(user_type, '__origin__', None) is primitive


def _match_is_primitive(match_against):
  """Builds a single-argument matcher bound to ``match_against``.

  Args:
    match_against: The class that candidate types are compared with.

  Returns:
    A callable taking ``user_type`` and returning the result of
    ``_is_primitive(user_type, match_against)``.
  """
  return lambda user_type: _is_primitive(user_type, match_against)


def _match_is_exactly_mapping(user_type):
  """Returns whether ``user_type`` is a parameterized ``Mapping`` hint.

  Args:
    user_type: The type hint to inspect.

  Returns:
    bool: True only if ``user_type.__origin__`` is exactly
    ``collections.abc.Mapping``; False if the attribute is missing or is any
    other class.
  """
  # Compare the origin by identity rather than with issubclass so that
  # Mapping subtypes (e.g. dict or MutableMapping hints) are not caught.
  expected_origin = collections.abc.Mapping
  return getattr(user_type, '__origin__', None) is expected_origin


def _match_is_exactly_iterable(user_type):
if user_type is typing.Iterable:
if user_type is typing.Iterable or user_type is collections.abc.Iterable:
return True
# Avoid unintentionally catching all subtypes (e.g. strings and mappings).
expected_origin = collections.abc.Iterable
Expand Down Expand Up @@ -152,11 +175,13 @@ def _match_is_union(user_type):
return False


def _match_is_set(user_type):
  """Returns whether ``user_type`` should be treated as a set type hint.

  A type matches when any of the following holds:
    * ``_safe_issubclass(user_type, typing.Set)`` is true,
    * it is the builtin ``set`` or has ``set`` as its ``__origin__``,
    * its ``__origin__`` passes ``_safe_issubclass`` against
      ``collections.abc.Set`` or ``collections.abc.MutableSet``.

  NOTE(review): ``_safe_issubclass`` is defined elsewhere in this module;
  presumably an ``issubclass`` wrapper that tolerates non-class arguments --
  confirm against its definition.

  Args:
    user_type: The type hint to inspect.

  Returns:
    bool: True if one of the conditions above holds, False otherwise.
  """
  if _safe_issubclass(user_type, typing.Set) or _is_primitive(user_type, set):
    return True
  origin = getattr(user_type, '__origin__', None)
  if origin is None:
    return False
  return (
      _safe_issubclass(origin, collections.abc.Set) or
      _safe_issubclass(origin, collections.abc.MutableSet))


# Keep the earlier public name resolving to the same matcher.
match_is_set = _match_is_set

Expand Down Expand Up @@ -197,6 +222,36 @@ def convert_builtin_to_typing(typ):
return typ


def convert_typing_to_builtin(typ):
  """Converts a given typing collections type to its builtin counterpart.

  Every type argument is converted recursively, so typing collections nested
  at any depth (e.g. ``typing.Set[typing.List[int]]``) are rewritten too.

  Args:
    typ: A typing type (e.g., typing.List[int]).

  Returns:
    type: The corresponding builtin type (e.g., list[int]). ``typ`` is
    returned unchanged when its ``__origin__`` is not one of ``_BUILTINS``,
    and the bare builtin class is returned when ``typ`` carries no type
    arguments.
  """
  origin = getattr(typ, '__origin__', None)
  args = getattr(typ, '__args__', None)
  # Typing types return the primitive type as the origin from 3.9 on
  if origin not in _BUILTINS:
    return typ
  # Early return for bare types
  if not args:
    return origin
  # Convert each argument on its own. Passing the whole ``args`` tuple to the
  # recursive call would hand it back untouched (a tuple has no
  # ``__origin__``), leaving nested typing collections unconverted.
  converted_args = tuple(convert_typing_to_builtin(arg) for arg in args)
  # Subscripting with a tuple parameterizes the builtin with all arguments,
  # which covers list/set/frozenset (one), dict (two) and tuple (any number,
  # including a trailing Ellipsis).
  return origin[converted_args]


def convert_collections_to_typing(typ):
"""Converts a given collections.abc type to a typing object.

Expand All @@ -216,6 +271,12 @@ def convert_collections_to_typing(typ):
return typ


def is_builtin(typ):
  """Returns whether ``typ`` is a builtin container or parameterizes one.

  Args:
    typ: The type hint to inspect.

  Returns:
    bool: True if ``typ`` is itself a member of ``_BUILTINS`` or if its
    ``__origin__`` attribute is; False otherwise.
  """
  if typ in _BUILTINS:
    return True
  # Parameterized hints such as list[int] carry the builtin as ``__origin__``.
  return getattr(typ, '__origin__', None) in _BUILTINS


def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.

Expand All @@ -238,11 +299,8 @@ def convert_to_beam_type(typ):
sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)):
typ = typing.Union[typ]

if isinstance(typ, types.GenericAlias):
typ = convert_builtin_to_typing(typ)

if getattr(typ, '__module__', None) == 'collections.abc':
typ = convert_collections_to_typing(typ)
if getattr(typ, '__module__', None) == 'typing':
typ = convert_typing_to_builtin(typ)

typ_module = getattr(typ, '__module__', None)
if isinstance(typ, typing.TypeVar):
Expand All @@ -267,8 +325,16 @@ def convert_to_beam_type(typ):
# TODO(https://github.com/apache/beam/issues/20076): Currently unhandled.
_LOGGER.info('Converting NewType type hint to Any: "%s"', typ)
return typehints.Any
elif (typ_module != 'typing') and (typ_module != 'collections.abc'):
# Only translate types from the typing and collections.abc modules.
elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \
getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue':
# Need to pass through WindowedValue class so that it can be converted
# to the correct type constraint in Beam
# This is needed to fix https://github.com/apache/beam/issues/33356
pass

elif (typ_module != 'typing') and (typ_module !=
'collections.abc') and not is_builtin(typ):
# Only translate primitives and types from collections.abc and typing.
return typ
if (typ_module == 'collections.abc' and
typ.__origin__ not in _CONVERTED_COLLECTIONS):
Expand All @@ -285,39 +351,34 @@ def convert_to_beam_type(typ):
_TypeMapEntry(match=is_forward_ref, arity=0, beam_type=typehints.Any),
_TypeMapEntry(match=is_any, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Dict),
arity=2,
beam_type=typehints.Dict),
match=_match_is_primitive(dict), arity=2, beam_type=typehints.Dict),
_TypeMapEntry(
match=_match_is_exactly_iterable,
arity=1,
beam_type=typehints.Iterable),
_TypeMapEntry(
match=_match_issubclass(typing.List),
arity=1,
beam_type=typehints.List),
match=_match_is_primitive(list), arity=1, beam_type=typehints.List),
# FrozenSets are a specific instance of a set, so we check this first.
_TypeMapEntry(
match=_match_issubclass(typing.FrozenSet),
match=_match_is_primitive(frozenset),
arity=1,
beam_type=typehints.FrozenSet),
_TypeMapEntry(match=match_is_set, arity=1, beam_type=typehints.Set),
_TypeMapEntry(match=_match_is_set, arity=1, beam_type=typehints.Set),
# NamedTuple is a subclass of Tuple, but it needs special handling.
# We just convert it to Any for now.
# This MUST appear before the entry for the normal Tuple.
_TypeMapEntry(
match=match_is_named_tuple, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_issubclass(typing.Tuple),
arity=-1,
match=_match_is_primitive(tuple), arity=-1,
beam_type=typehints.Tuple),
_TypeMapEntry(match=_match_is_union, arity=-1, beam_type=typehints.Union),
_TypeMapEntry(
match=_match_issubclass(typing.Generator),
match=_match_issubclass(collections.abc.Generator),
arity=3,
beam_type=typehints.Generator),
_TypeMapEntry(
match=_match_issubclass(typing.Iterator),
match=_match_issubclass(collections.abc.Iterator),
arity=1,
beam_type=typehints.Iterator),
_TypeMapEntry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from apache_beam.typehints.native_type_compatibility import convert_to_beam_types
from apache_beam.typehints.native_type_compatibility import convert_to_typing_type
from apache_beam.typehints.native_type_compatibility import convert_to_typing_types
from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin
from apache_beam.typehints.native_type_compatibility import is_any

_TestNamedTuple = typing.NamedTuple(
Expand All @@ -43,6 +44,7 @@ class _TestClass(object):


# Type variables shared by the generic fixture class and the type-conversion
# test cases in this module.
T = typing.TypeVar('T')
U = typing.TypeVar('U')


class _TestGeneric(typing.Generic[T]):
Expand Down Expand Up @@ -140,7 +142,7 @@ def test_convert_to_beam_type_with_builtin_types(self):
(
'builtin nested tuple',
tuple[str, list],
typehints.Tuple[str, typehints.List[typehints.Any]],
typehints.Tuple[str, typehints.List[typehints.TypeVariable('T')]],
)
]

Expand All @@ -159,7 +161,7 @@ def test_convert_to_beam_type_with_collections_types(self):
typehints.Iterable[int]),
(
'collection generator',
collections.abc.Generator[int],
collections.abc.Generator[int, None, None],
typehints.Generator[int]),
(
'collection iterator',
Expand All @@ -177,9 +179,8 @@ def test_convert_to_beam_type_with_collections_types(self):
'mapping not caught',
collections.abc.Mapping[str, int],
collections.abc.Mapping[str, int]),
('set', collections.abc.Set[str], typehints.Set[str]),
('set', collections.abc.Set[int], typehints.Set[int]),
('mutable set', collections.abc.MutableSet[int], typehints.Set[int]),
('enum set', collections.abc.Set[_TestEnum], typehints.Set[_TestEnum]),
(
'enum mutable set',
collections.abc.MutableSet[_TestEnum],
Expand Down Expand Up @@ -337,6 +338,24 @@ def test_is_any(self):
for expected, typ in test_cases:
self.assertEqual(expected, is_any(typ), msg='%s' % typ)

def test_convert_typing_to_builtin(self):
  """Checks that typing collections map onto their builtin generics."""
  # Each case is (description, typing input, expected builtin output).
  test_cases = [
      ('list', typing.List[int], list[int]),
      ('dict', typing.Dict[str, int], dict[str, int]),
      ('tuple', typing.Tuple[str, int], tuple[str, int]),
      ('set', typing.Set[str], set[str]),
      ('frozenset', typing.FrozenSet[int], frozenset[int]),
      (
          'nested',
          typing.List[typing.Dict[str, typing.Tuple[int]]],
          list[dict[str, tuple[int]]]),
      ('typevar', typing.List[T], list[T]),
      ('nested_typevar', typing.Dict[T, typing.List[U]], dict[T, list[U]]),
  ]

  for description, typing_type, expected_builtin_type in test_cases:
    # subTest reports every failing case instead of stopping at the first.
    with self.subTest(description):
      builtin_type = convert_typing_to_builtin(typing_type)
      self.assertEqual(builtin_type, expected_builtin_type, description)


# Run the test suite when this module is executed as a script.
if __name__ == '__main__':
  unittest.main()
Loading