Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-9477] RowCoder should be hashable and picklable #11088

Merged
merged 4 commits into from Mar 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 15 additions & 3 deletions sdks/python/apache_beam/coders/row_coder.py
Expand Up @@ -35,6 +35,7 @@
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints.schemas import named_tuple_from_schema
from apache_beam.typehints.schemas import named_tuple_to_schema
from apache_beam.utils import proto_utils

__all__ = ["RowCoder"]

Expand Down Expand Up @@ -69,7 +70,8 @@ def to_type_hint(self):
def as_cloud_object(self, coders_context=None):
  """Reject cloud-object serialization, which RowCoder does not implement.

  Args:
    coders_context: accepted but never read.

  Raises:
    NotImplementedError: always.
  """
  message = "as_cloud_object not supported for RowCoder"
  raise NotImplementedError(message)

__hash__ = None # type: ignore[assignment]
def __hash__(self):
  """Hash the coder by the serialized bytes of its schema."""
  schema_bytes = self.schema.SerializeToString()
  return hash(schema_bytes)

def __eq__(self, other):
  """Return True only for an object of the same type with an equal schema."""
  if type(self) != type(other):
    return False
  return self.schema == other.schema
Expand All @@ -79,13 +81,18 @@ def to_runner_api_parameter(self, unused_context):

# Builds a RowCoder directly from the first argument. Registered through
# Coder.register_urn under the ROW coder URN, with schema_pb2.Schema passed as
# the second registration argument (presumably the payload type -- confirm).
@staticmethod
@Coder.register_urn(common_urns.coders.ROW.urn, schema_pb2.Schema)
# NOTE(review): this span looks like a flattened diff. The `payload`
# definition directly below appears to be the removed version and the `schema`
# definition after it the added one; they differ only in the name of the first
# parameter. Confirm against the real file.
def from_runner_api_parameter(payload, components, unused_context):
return RowCoder(payload)
def from_runner_api_parameter(schema, components, unused_context):
# `components` and `unused_context` are not read here.
return RowCoder(schema)

@staticmethod
def from_type_hint(named_tuple_type, registry):
  """Build a RowCoder for a named tuple type.

  The schema is derived from ``named_tuple_type`` with
  ``named_tuple_to_schema``; ``registry`` is accepted but not read.
  """
  schema = named_tuple_to_schema(named_tuple_type)
  return RowCoder(schema)

@staticmethod
def from_payload(payload):
  # type: (bytes) -> RowCoder

  """Rebuild a RowCoder from the serialized bytes of a schema proto."""
  schema = proto_utils.parse_Bytes(payload, schema_pb2.Schema)
  return RowCoder(schema)

@staticmethod
def coder_from_type(field_type):
type_info = field_type.WhichOneof("type_info")
Expand All @@ -106,6 +113,11 @@ def coder_from_type(field_type):
"Encountered a type that is not currently supported by RowCoder: %s" %
field_type)

def __reduce__(self):
  """Support pickling by rebuilding the coder from its serialized schema.

  schema_pb2.Schema objects cannot be pickled, so the schema is carried as
  bytes and turned back into a coder by ``RowCoder.from_payload``.
  """
  serialized_schema = self.schema.SerializeToString()
  return (RowCoder.from_payload, (serialized_schema, ))


class RowCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
Expand Down
23 changes: 22 additions & 1 deletion sdks/python/apache_beam/coders/row_coder_test.py
Expand Up @@ -26,9 +26,14 @@
import numpy as np
from past.builtins import unicode

import apache_beam as beam
from apache_beam.coders import RowCoder
from apache_beam.coders.typecoders import registry as coders_registry
from apache_beam.internal import pickler
from apache_beam.portability.api import schema_pb2
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints.schemas import typing_to_runner_api

Person = typing.NamedTuple(
Expand All @@ -44,8 +49,9 @@


class RowCoderTest(unittest.TestCase):
TEST_CASE = Person("Jon Snow", 23, None, ["crow", "wildling"])
TEST_CASES = [
Person("Jon Snow", 23, None, ["crow", "wildling"]),
TEST_CASE,
Person("Daenerys Targaryen", 25, "Westeros", ["Mother of Dragons"]),
Person("Michael Bluth", 30, None, [])
]
Expand Down Expand Up @@ -165,6 +171,21 @@ def test_schema_add_column_with_null_value(self):
New(None, "baz", None),
new_coder.decode(old_coder.encode(Old(None, "baz"))))

def test_row_coder_picklable(self):
  """A RowCoder compares equal to itself after a pickler round trip."""
  # Coders are occasionally pickled; RowCoder must be able to handle that.
  original = coders_registry.get_coder(Person)
  pickled = pickler.dumps(original)
  restored = pickler.loads(pickled)

  self.assertEqual(restored, original)

def test_row_coder_in_pipeine(self):
  """Rows created in a pipeline can be filtered by a field value."""
  with TestPipeline() as p:
    people = p | beam.Create(self.TEST_CASES)
    res = people | beam.Filter(lambda person: person.name == "Jon Snow")
    # NOTE(review): indentation was lost in the reviewed view; the assertion
    # is assumed to sit inside the `with` block -- confirm.
    assert_that(res, equal_to([self.TEST_CASE]))


if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
Expand Down