From 44a2ac55ae73401009d877a2a5d93165d1d8237b Mon Sep 17 00:00:00 2001 From: nehsyc <65132551+nehsyc@users.noreply.github.com> Date: Fri, 4 Dec 2020 14:35:20 -0800 Subject: [PATCH] [BEAM-10475] Add typehints for ShardedKeyCoder (#13474) --- sdks/python/apache_beam/coders/coders.py | 16 ++++ .../apache_beam/coders/coders_test_common.py | 17 ++++ .../apache_beam/typehints/sharded_key_type.py | 75 +++++++++++++++++ .../typehints/sharded_key_type_test.py | 80 +++++++++++++++++++ 4 files changed, 188 insertions(+) create mode 100644 sdks/python/apache_beam/typehints/sharded_key_type.py create mode 100644 sdks/python/apache_beam/typehints/sharded_key_type_test.py diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 725a5d4632cc3..6058bf1f4b044 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -89,6 +89,7 @@ 'PickleCoder', 'ProtoCoder', 'SingletonCoder', + 'ShardedKeyCoder', 'StrUtf8Coder', 'TimestampCoder', 'TupleCoder', @@ -1485,6 +1486,21 @@ def as_cloud_object(self, coders_context=None): ], } + def to_type_hint(self): + from apache_beam.typehints import sharded_key_type + return sharded_key_type.ShardedKeyTypeConstraint( + self._key_coder.to_type_hint()) + + @staticmethod + def from_type_hint(typehint, registry): + from apache_beam.typehints import sharded_key_type + if isinstance(typehint, sharded_key_type.ShardedKeyTypeConstraint): + return ShardedKeyCoder(registry.get_coder(typehint.key_type)) + else: + raise ValueError(( + 'Expected an instance of ShardedKeyTypeConstraint' + ', but got a %s' % typehint)) + def __eq__(self, other): return type(self) == type(other) and self._key_coder == other._key_coder diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 63ed0ba0f852e..7ff026e4aa14e 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -30,11 +30,14 @@ from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message from apache_beam.coders import coders +from apache_beam.coders import typecoders from apache_beam.internal import pickler from apache_beam.runners import pipeline_context from apache_beam.transforms import userstate from apache_beam.transforms import window from apache_beam.transforms.window import GlobalWindow +from apache_beam.typehints import sharded_key_type +from apache_beam.typehints import typehints from apache_beam.utils import timestamp from apache_beam.utils import windowed_value from apache_beam.utils.sharded_key import ShardedKey @@ -596,6 +599,20 @@ def test_sharded_key_coder(self): self.check_coder(coder, ShardedKey(key, b'')) self.check_coder(coder, ShardedKey(key, b'123')) + # Test type hints + self.assertTrue( + isinstance( + coder.to_type_hint(), sharded_key_type.ShardedKeyTypeConstraint)) + key_type = coder.to_type_hint().key_type + if isinstance(key_type, typehints.TupleConstraint): + self.assertEqual(key_type.tuple_types, (type(key[0]), type(key[1]))) + else: + self.assertEqual(key_type, type(key)) + self.assertEqual( + coders.ShardedKeyCoder.from_type_hint( + coder.to_type_hint(), typecoders.CoderRegistry()), + coder) + for other_key, _, other_key_coder in key_and_coders: other_coder = coders.ShardedKeyCoder(other_key_coder) # Test nested diff --git a/sdks/python/apache_beam/typehints/sharded_key_type.py b/sdks/python/apache_beam/typehints/sharded_key_type.py new file mode 100644 index 0000000000000..2c463ef704c2d --- /dev/null +++ b/sdks/python/apache_beam/typehints/sharded_key_type.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pytype: skip-file + +from __future__ import absolute_import + +from apache_beam.coders import typecoders +from apache_beam.coders.coders import ShardedKeyCoder +from apache_beam.typehints import typehints +from apache_beam.typehints.typehints import match_type_variables +from apache_beam.utils.sharded_key import ShardedKey + + +class ShardedKeyTypeConstraint(typehints.TypeConstraint): + def __init__(self, key_type): + self.key_type = typehints.normalize(key_type) + + def _inner_types(self): + yield self.key_type + + def _consistent_with_check_(self, sub): + return ( + isinstance(sub, self.__class__) and + typehints.is_consistent_with(sub.key_type, self.key_type)) + + def type_check(self, instance): + if not isinstance(instance, ShardedKey): + raise typehints.CompositeTypeHintError( + "ShardedKey type-constraint violated. Valid object instance " + "must be of type 'ShardedKey'. Instead, an instance of '%s' " + "was received." % (instance.__class__.__name__)) + + try: + typehints.check_constraint(self.key_type, instance.key) + except (typehints.CompositeTypeHintError, typehints.SimpleTypeHintError): + raise typehints.CompositeTypeHintError( + "%s type-constraint violated. The type of key in 'ShardedKey' " + "is incorrect. Expected an instance of type '%s', " + "instead received an instance of type '%s'." % ( + repr(self), + typehints._unified_repr(self.key_type), + instance.key.__class__.__name__)) + + def match_type_variables(self, concrete_type): + if isinstance(concrete_type, ShardedKeyTypeConstraint): + return match_type_variables(self.key_type, concrete_type.key_type) + return {} + + def __eq__(self, other): + return isinstance( + other, ShardedKeyTypeConstraint) and self.key_type == other.key_type + + def __hash__(self): + return hash(self.key_type) + + def __repr__(self): + return 'ShardedKey(%s)' % typehints._unified_repr(self.key_type) + + +typecoders.registry.register_coder(ShardedKeyTypeConstraint, ShardedKeyCoder) diff --git a/sdks/python/apache_beam/typehints/sharded_key_type_test.py b/sdks/python/apache_beam/typehints/sharded_key_type_test.py new file mode 100644 index 0000000000000..7bc5143b55d72 --- /dev/null +++ b/sdks/python/apache_beam/typehints/sharded_key_type_test.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for the ShardedKeyTypeConstraint.""" + +# pytype: skip-file + +from __future__ import absolute_import + +from apache_beam.typehints import Tuple +from apache_beam.typehints import typehints +from apache_beam.typehints.sharded_key_type import ShardedKeyTypeConstraint +from apache_beam.typehints.typehints_test import TypeHintTestCase +from apache_beam.utils.sharded_key import ShardedKey + + +class ShardedKeyTypeConstraintTest(TypeHintTestCase): + def test_compatibility(self): + constraint1 = ShardedKeyTypeConstraint(int) + constraint2 = ShardedKeyTypeConstraint(str) + + self.assertCompatible(constraint1, constraint1) + self.assertCompatible(constraint2, constraint2) + self.assertNotCompatible(constraint1, constraint2) + + def test_repr(self): + constraint = ShardedKeyTypeConstraint(int) + self.assertEqual('ShardedKey(int)', repr(constraint)) + + def test_type_check_not_sharded_key(self): + constraint = ShardedKeyTypeConstraint(int) + obj = 5 + with self.assertRaises(TypeError) as e: + constraint.type_check(obj) + self.assertEqual( + "ShardedKey type-constraint violated. Valid object instance must be of " + "type 'ShardedKey'. Instead, an instance of 'int' was received.", + e.exception.args[0]) + + def test_type_check_invalid_key_type(self): + constraint = ShardedKeyTypeConstraint(int) + obj = ShardedKey(key='abc', shard_id=b'123') + with self.assertRaises((TypeError, TypeError)) as e: + constraint.type_check(obj) + self.assertEqual( + "ShardedKey(int) type-constraint violated. The type of key in " + "'ShardedKey' is incorrect. Expected an instance of type 'int', " + "instead received an instance of type 'str'.", + e.exception.args[0]) + + def test_type_check_valid_simple_type(self): + constraint = ShardedKeyTypeConstraint(str) + obj = ShardedKey(key='abc', shard_id=b'123') + self.assertIsNone(constraint.type_check(obj)) + + def test_type_check_valid_composite_type(self): + constraint = ShardedKeyTypeConstraint(Tuple[int, str]) + obj = ShardedKey(key=(1, 'a'), shard_id=b'123') + self.assertIsNone(constraint.type_check(obj)) + + def test_match_type_variables(self): + K = typehints.TypeVariable('K') # pylint: disable=invalid-name + constraint = ShardedKeyTypeConstraint(K) + self.assertEqual({K: int}, + constraint.match_type_variables( + ShardedKeyTypeConstraint(int)))