From 33ec0bb052e6711efaa833d6c044680e4de30902 Mon Sep 17 00:00:00 2001 From: Brian Hulette Date: Fri, 13 Mar 2020 15:30:00 -0700 Subject: [PATCH] [BEAM-9477] RowCoder should be hashable and picklable (#11088) * Add (failing) test * implement RowCoder.__hash__ * Add tests that require RowCoder to be picklable * Fix pickling --- sdks/python/apache_beam/coders/row_coder.py | 18 ++++++++++++--- .../apache_beam/coders/row_coder_test.py | 23 ++++++++++++++++++- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 55f4019df7af0..7012a56cb3741 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -35,6 +35,7 @@ from apache_beam.portability.api import schema_pb2 from apache_beam.typehints.schemas import named_tuple_from_schema from apache_beam.typehints.schemas import named_tuple_to_schema +from apache_beam.utils import proto_utils __all__ = ["RowCoder"] @@ -69,7 +70,8 @@ def to_type_hint(self): def as_cloud_object(self, coders_context=None): raise NotImplementedError("as_cloud_object not supported for RowCoder") - __hash__ = None # type: ignore[assignment] + def __hash__(self): + return hash(self.schema.SerializeToString()) def __eq__(self, other): return type(self) == type(other) and self.schema == other.schema @@ -79,13 +81,18 @@ def to_runner_api_parameter(self, unused_context): @staticmethod @Coder.register_urn(common_urns.coders.ROW.urn, schema_pb2.Schema) - def from_runner_api_parameter(payload, components, unused_context): - return RowCoder(payload) + def from_runner_api_parameter(schema, components, unused_context): + return RowCoder(schema) @staticmethod def from_type_hint(named_tuple_type, registry): return RowCoder(named_tuple_to_schema(named_tuple_type)) + @staticmethod + def from_payload(payload): + # type: (bytes) -> RowCoder + return RowCoder(proto_utils.parse_Bytes(payload, schema_pb2.Schema)) + 
@staticmethod def coder_from_type(field_type): type_info = field_type.WhichOneof("type_info") @@ -106,6 +113,11 @@ def coder_from_type(field_type): "Encountered a type that is not currently supported by RowCoder: %s" % field_type) + def __reduce__(self): + # When pickling, use the serialized bytes of the schema, because + # schema_pb2.Schema objects cannot be pickled. + return (RowCoder.from_payload, (self.schema.SerializeToString(), )) + class RowCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" diff --git a/sdks/python/apache_beam/coders/row_coder_test.py b/sdks/python/apache_beam/coders/row_coder_test.py index d75b0b4560913..0ffd98338190f 100644 --- a/sdks/python/apache_beam/coders/row_coder_test.py +++ b/sdks/python/apache_beam/coders/row_coder_test.py @@ -26,9 +26,14 @@ import numpy as np from past.builtins import unicode +import apache_beam as beam from apache_beam.coders import RowCoder from apache_beam.coders.typecoders import registry as coders_registry +from apache_beam.internal import pickler from apache_beam.portability.api import schema_pb2 +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to from apache_beam.typehints.schemas import typing_to_runner_api Person = typing.NamedTuple( @@ -44,8 +49,9 @@ class RowCoderTest(unittest.TestCase): + TEST_CASE = Person("Jon Snow", 23, None, ["crow", "wildling"]) TEST_CASES = [ - Person("Jon Snow", 23, None, ["crow", "wildling"]), + TEST_CASE, Person("Daenerys Targaryen", 25, "Westeros", ["Mother of Dragons"]), Person("Michael Bluth", 30, None, []) ] @@ -165,6 +171,21 @@ def test_schema_add_column_with_null_value(self): New(None, "baz", None), new_coder.decode(old_coder.encode(Old(None, "baz")))) + def test_row_coder_picklable(self): + # Coders occasionally get pickled; RowCoder should be able to handle it. + coder = coders_registry.get_coder(Person) + roundtripped = 
pickler.loads(pickler.dumps(coder)) + + self.assertEqual(roundtripped, coder) + + def test_row_coder_in_pipeine(self): + with TestPipeline() as p: + res = ( + p + | beam.Create(self.TEST_CASES) + | beam.Filter(lambda person: person.name == "Jon Snow")) + assert_that(res, equal_to([self.TEST_CASE])) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO)