Skip to content

Commit

Permalink
[BEAM-10475] Add typehints for ShardedKeyCoder (#13474)
Browse files Browse the repository at this point in the history
  • Loading branch information
nehsyc committed Dec 4, 2020
1 parent edc087e commit 44a2ac5
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 0 deletions.
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/coders/coders.py
Expand Up @@ -89,6 +89,7 @@
'PickleCoder',
'ProtoCoder',
'SingletonCoder',
'ShardedKeyCoder',
'StrUtf8Coder',
'TimestampCoder',
'TupleCoder',
Expand Down Expand Up @@ -1485,6 +1486,21 @@ def as_cloud_object(self, coders_context=None):
],
}

def to_type_hint(self):
  """Return the ShardedKey type constraint matching this coder.

  The constraint is parameterized by the type hint reported by the wrapped
  key coder.
  """
  # Imported lazily: sharded_key_type itself imports from
  # apache_beam.coders.coders, so a module-level import would be circular.
  from apache_beam.typehints import sharded_key_type
  key_hint = self._key_coder.to_type_hint()
  return sharded_key_type.ShardedKeyTypeConstraint(key_hint)

@staticmethod
def from_type_hint(typehint, registry):
  """Build a ShardedKeyCoder from a ShardedKeyTypeConstraint.

  Args:
    typehint: the type hint to convert; must be a ShardedKeyTypeConstraint.
    registry: coder registry used to look up the coder for the key type.

  Returns:
    A ShardedKeyCoder wrapping the coder that ``registry`` returns for
    ``typehint.key_type``.

  Raises:
    ValueError: if ``typehint`` is not a ShardedKeyTypeConstraint.
  """
  # Imported lazily: sharded_key_type itself imports from
  # apache_beam.coders.coders, so a module-level import would be circular.
  from apache_beam.typehints import sharded_key_type
  if not isinstance(typehint, sharded_key_type.ShardedKeyTypeConstraint):
    # Wrap the operand in a 1-tuple so that a tuple-valued typehint is
    # formatted as a single value instead of being unpacked by '%', which
    # would raise TypeError rather than the intended ValueError.
    raise ValueError(
        'Expected an instance of ShardedKeyTypeConstraint'
        ', but got a %s' % (typehint, ))
  return ShardedKeyCoder(registry.get_coder(typehint.key_type))

def __eq__(self, other):
  """Coders are equal when they have the same exact type and key coder."""
  same_type = type(self) == type(other)
  # Short-circuit so other._key_coder is only read for a same-typed coder.
  return same_type and self._key_coder == other._key_coder

Expand Down
17 changes: 17 additions & 0 deletions sdks/python/apache_beam/coders/coders_test_common.py
Expand Up @@ -30,11 +30,14 @@

from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
from apache_beam.coders import coders
from apache_beam.coders import typecoders
from apache_beam.internal import pickler
from apache_beam.runners import pipeline_context
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.transforms.window import GlobalWindow
from apache_beam.typehints import sharded_key_type
from apache_beam.typehints import typehints
from apache_beam.utils import timestamp
from apache_beam.utils import windowed_value
from apache_beam.utils.sharded_key import ShardedKey
Expand Down Expand Up @@ -596,6 +599,20 @@ def test_sharded_key_coder(self):
self.check_coder(coder, ShardedKey(key, b''))
self.check_coder(coder, ShardedKey(key, b'123'))

# Test type hints
self.assertTrue(
isinstance(
coder.to_type_hint(), sharded_key_type.ShardedKeyTypeConstraint))
key_type = coder.to_type_hint().key_type
if isinstance(key_type, typehints.TupleConstraint):
self.assertEqual(key_type.tuple_types, (type(key[0]), type(key[1])))
else:
self.assertEqual(key_type, type(key))
self.assertEqual(
coders.ShardedKeyCoder.from_type_hint(
coder.to_type_hint(), typecoders.CoderRegistry()),
coder)

for other_key, _, other_key_coder in key_and_coders:
other_coder = coders.ShardedKeyCoder(other_key_coder)
# Test nested
Expand Down
75 changes: 75 additions & 0 deletions sdks/python/apache_beam/typehints/sharded_key_type.py
@@ -0,0 +1,75 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

from __future__ import absolute_import

from apache_beam.coders import typecoders
from apache_beam.coders.coders import ShardedKeyCoder
from apache_beam.typehints import typehints
from apache_beam.typehints.typehints import match_type_variables
from apache_beam.utils.sharded_key import ShardedKey


class ShardedKeyTypeConstraint(typehints.TypeConstraint):
  """Type constraint for ``ShardedKey`` instances with a constrained key.

  The constraint stores a single (normalized) ``key_type`` and checks it
  against the ``key`` attribute of a ``ShardedKey``.
  """
  def __init__(self, key_type):
    """Store ``key_type`` after passing it through ``typehints.normalize``."""
    self.key_type = typehints.normalize(key_type)

  def _inner_types(self):
    """Yield the only nested type of this constraint: the key type."""
    yield self.key_type

  def _consistent_with_check_(self, sub):
    """Check that ``sub`` is the same kind of constraint with a consistent key.

    ``sub`` must be an instance of this constraint's class and its key type
    must be consistent with this constraint's key type.
    """
    if not isinstance(sub, self.__class__):
      return False
    return typehints.is_consistent_with(sub.key_type, self.key_type)

  def type_check(self, instance):
    """Validate ``instance`` against this constraint.

    Raises:
      typehints.CompositeTypeHintError: if ``instance`` is not a
        ``ShardedKey``, or if checking its key against ``key_type`` raises a
        composite or simple type hint error.
    """
    if not isinstance(instance, ShardedKey):
      received = instance.__class__.__name__
      raise typehints.CompositeTypeHintError(
          "ShardedKey type-constraint violated. Valid object instance "
          "must be of type 'ShardedKey'. Instead, an instance of '%s' "
          "was received." % received)

    key_check_failures = (
        typehints.CompositeTypeHintError, typehints.SimpleTypeHintError)
    try:
      typehints.check_constraint(self.key_type, instance.key)
    except key_check_failures:
      # Re-raise with a message that names the ShardedKey constraint rather
      # than only the inner key constraint.
      details = (
          repr(self),
          typehints._unified_repr(self.key_type),
          instance.key.__class__.__name__)
      raise typehints.CompositeTypeHintError(
          "%s type-constraint violated. The type of key in 'ShardedKey' "
          "is incorrect. Expected an instance of type '%s', "
          "instead received an instance of type '%s'." % details)

  def match_type_variables(self, concrete_type):
    """Match type variables in ``key_type`` against ``concrete_type``.

    Returns an empty dict when ``concrete_type`` is not a
    ShardedKeyTypeConstraint; otherwise delegates to the module-level
    ``match_type_variables`` on the two key types.
    """
    if not isinstance(concrete_type, ShardedKeyTypeConstraint):
      return {}
    # Inside a method this name resolves to the function imported at module
    # level, not to this method.
    return match_type_variables(self.key_type, concrete_type.key_type)

  def __eq__(self, other):
    """Equal to another ShardedKeyTypeConstraint with an equal key type."""
    if not isinstance(other, ShardedKeyTypeConstraint):
      return False
    return self.key_type == other.key_type

  def __hash__(self):
    """Hash on the key type, in line with ``__eq__``."""
    return hash(self.key_type)

  def __repr__(self):
    """Render as ``ShardedKey(<key type>)``."""
    key_repr = typehints._unified_repr(self.key_type)
    return 'ShardedKey(%s)' % (key_repr, )


# Import-time side effect: associate ShardedKeyTypeConstraint with
# ShardedKeyCoder in the module-level `typecoders.registry`.
typecoders.registry.register_coder(ShardedKeyTypeConstraint, ShardedKeyCoder)
80 changes: 80 additions & 0 deletions sdks/python/apache_beam/typehints/sharded_key_type_test.py
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the ShardedKeyTypeConstraint."""

# pytype: skip-file

from __future__ import absolute_import

from apache_beam.typehints import Tuple
from apache_beam.typehints import typehints
from apache_beam.typehints.sharded_key_type import ShardedKeyTypeConstraint
from apache_beam.typehints.typehints_test import TypeHintTestCase
from apache_beam.utils.sharded_key import ShardedKey


class ShardedKeyTypeConstraintTest(TypeHintTestCase):
  """Tests for ShardedKeyTypeConstraint.

  Covers compatibility between constraints, the repr, type checking of
  valid and invalid instances, and type-variable matching.
  """
  def test_compatibility(self):
    """A constraint is compatible with itself but not with another key type."""
    constraint1 = ShardedKeyTypeConstraint(int)
    constraint2 = ShardedKeyTypeConstraint(str)

    self.assertCompatible(constraint1, constraint1)
    self.assertCompatible(constraint2, constraint2)
    self.assertNotCompatible(constraint1, constraint2)

  def test_repr(self):
    """The repr shows the key type inside ``ShardedKey(...)``."""
    constraint = ShardedKeyTypeConstraint(int)
    self.assertEqual('ShardedKey(int)', repr(constraint))

  def test_type_check_not_sharded_key(self):
    """Type checking a non-ShardedKey object raises a TypeError."""
    constraint = ShardedKeyTypeConstraint(int)
    obj = 5
    with self.assertRaises(TypeError) as e:
      constraint.type_check(obj)
    # Assert on the message after the context manager exits; statements
    # following the raising call inside the `with` body would never run.
    self.assertEqual(
        "ShardedKey type-constraint violated. Valid object instance must be of "
        "type 'ShardedKey'. Instead, an instance of 'int' was received.",
        e.exception.args[0])

  def test_type_check_invalid_key_type(self):
    """A ShardedKey whose key has the wrong type raises a TypeError."""
    constraint = ShardedKeyTypeConstraint(int)
    obj = ShardedKey(key='abc', shard_id=b'123')
    # A single exception class is enough; the original tuple listed
    # TypeError twice.
    with self.assertRaises(TypeError) as e:
      constraint.type_check(obj)
    self.assertEqual(
        "ShardedKey(int) type-constraint violated. The type of key in "
        "'ShardedKey' is incorrect. Expected an instance of type 'int', "
        "instead received an instance of type 'str'.",
        e.exception.args[0])

  def test_type_check_valid_simple_type(self):
    """A ShardedKey with a matching simple key type passes the check."""
    constraint = ShardedKeyTypeConstraint(str)
    obj = ShardedKey(key='abc', shard_id=b'123')
    self.assertIsNone(constraint.type_check(obj))

  def test_type_check_valid_composite_type(self):
    """A ShardedKey with a matching composite (tuple) key passes the check."""
    constraint = ShardedKeyTypeConstraint(Tuple[int, str])
    obj = ShardedKey(key=(1, 'a'), shard_id=b'123')
    self.assertIsNone(constraint.type_check(obj))

  def test_match_type_variables(self):
    """A type variable used as key type is bound to the concrete key type."""
    K = typehints.TypeVariable('K')  # pylint: disable=invalid-name
    constraint = ShardedKeyTypeConstraint(K)
    self.assertEqual({K: int},
                     constraint.match_type_variables(
                         ShardedKeyTypeConstraint(int)))

0 comments on commit 44a2ac5

Please sign in to comment.