Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"""
# pytype: skip-file

import dataclasses
import decimal
import enum
import itertools
Expand Down Expand Up @@ -67,11 +68,6 @@
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.timestamp import Timestamp

try:
import dataclasses
except ImportError:
dataclasses = None # type: ignore

try:
import dill
except ImportError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import collections
import collections.abc
import dataclasses
import logging
import sys
import types
Expand Down Expand Up @@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))


def match_is_dataclass(user_type):
  """Return True if user_type is a dataclass *class* (not an instance).

  ``dataclasses.is_dataclass`` is True for both dataclass types and their
  instances; the ``isinstance(..., type)`` guard restricts the match to the
  class itself.
  """
  if not isinstance(user_type, type):
    return False
  return dataclasses.is_dataclass(user_type)


def _match_is_optional(user_type):
  """Return True for Optional[X]: a Union with exactly one NoneType member."""
  if not _match_is_union(user_type):
    return False
  none_count = sum(1 for tp in _get_args(user_type) if tp is type(None))
  return none_count == 1
Expand Down
47 changes: 24 additions & 23 deletions sdks/python/apache_beam/typehints/row_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

from __future__ import annotations

import dataclasses
from typing import Any
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Tuple

from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry

Expand Down Expand Up @@ -56,18 +58,14 @@ def __init__(
for guidance on creating PCollections with inferred schemas.

Note RowTypeConstraint does not currently store arbitrary functions for
converting to/from the user type. Instead, we only support ``NamedTuple``
user types and make the follow assumptions:
converting to/from the user type. Instead, we support ``NamedTuple`` and
``dataclasses`` user types and make the following assumptions:

- The user type can be constructed with field values as arguments in order
(i.e. ``constructor(*field_values)``).
- Field values can be accessed from instances of the user type by attribute
(i.e. with ``getattr(obj, field_name)``).

In the future we will add support for dataclasses
([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy
these assumptions.

The RowTypeConstraint constructor should not be called directly (even
internally to Beam). Prefer static methods ``from_user_type`` or
``from_fields``.
Expand Down Expand Up @@ -107,27 +105,30 @@ def from_user_type(
if match_is_named_tuple(user_type):
fields = [(name, user_type.__annotations__[name])
for name in user_type._fields]

field_descriptions = getattr(user_type, '_field_descriptions', None)

if _user_type_is_generated(user_type):
return RowTypeConstraint.from_fields(
fields,
schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
schema_options=schema_options,
field_options=field_options,
field_descriptions=field_descriptions)

# TODO(https://github.com/apache/beam/issues/22125): Add user API for
# specifying schema/field options
return RowTypeConstraint(
fields=fields,
user_type=user_type,
elif match_is_dataclass(user_type):
fields = [(field.name, field.type)
for field in dataclasses.fields(user_type)]
else:
return None

field_descriptions = getattr(user_type, '_field_descriptions', None)

if _user_type_is_generated(user_type):
return RowTypeConstraint.from_fields(
fields,
schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
schema_options=schema_options,
field_options=field_options,
field_descriptions=field_descriptions)

return None
# TODO(https://github.com/apache/beam/issues/22125): Add user API for
# specifying schema/field options
return RowTypeConstraint(
fields=fields,
user_type=user_type,
schema_options=schema_options,
field_options=field_options,
field_descriptions=field_descriptions)

@staticmethod
def from_fields(
Expand Down
86 changes: 86 additions & 0 deletions sdks/python/apache_beam/typehints/row_type_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the Beam Row typing functionality."""

from dataclasses import dataclass
import typing
import unittest

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints import row_type


class RowTypeTest(unittest.TestCase):
  """Pipeline tests that schema'd user types survive a GroupByKey.

  Both NamedTuple and dataclass keys are expected to pass through GBK as the
  same kind of Beam-generated schema user type, so the two tests share one
  pipeline helper.
  """
  @staticmethod
  def _check_key_type_and_count(x) -> int:
    """Assert the GBK key type is a Beam-generated user type; return group size."""
    key_type = type(x[0])
    if not row_type._user_type_is_generated(key_type):
      raise RuntimeError("Expect type after GBK to be generated user type")

    return len(x[1])

  def _run_group_by_key_test(self, user_type):
    """Run Create -> ParDo -> GBK keyed by user_type(i, 'a') instances.

    Registers RowCoder for user_type, emits 100 distinct keys from each of 10
    inputs, and asserts every group has exactly 10 elements while the key
    type has been replaced by a generated schema type.
    """
    beam.coders.typecoders.registry.register_coder(
        user_type, beam.coders.RowCoder)

    def generate(num: int):
      for i in range(100):
        yield (user_type(i, 'a'), num)

    pipeline = TestPipeline(is_integration_test=False)

    with pipeline as p:
      result = (
          p
          | 'Create' >> beam.Create([i for i in range(10)])
          | 'Generate' >> beam.ParDo(generate).with_output_types(
              tuple[user_type, int])
          | 'GBK' >> beam.GroupByKey()
          | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
      assert_that(result, equal_to([10] * 100))

  def test_group_by_key_namedtuple(self):
    MyNamedTuple = typing.NamedTuple(
        "MyNamedTuple", [("id", int), ("name", str)])
    self._run_group_by_key_test(MyNamedTuple)

  def test_group_by_key_dataclass(self):
    @dataclass
    class MyDataClass:
      id: int
      name: str

    self._run_group_by_key_test(MyDataClass)
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
from apache_beam.typehints.native_type_compatibility import extract_optional_type
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
Expand Down Expand Up @@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
Returns schema as a list of (name, python_type) tuples"""
if isinstance(element_type, row_type.RowTypeConstraint):
return named_fields_to_schema(element_type._fields)
elif match_is_named_tuple(element_type):
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
# if the named tuple's schema is in registry, we just use it instead of
# regenerating one.
Expand Down
61 changes: 61 additions & 0 deletions sdks/python/apache_beam/typehints/schemas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pytype: skip-file

import dataclasses
import itertools
import pickle
import unittest
Expand Down Expand Up @@ -388,6 +389,24 @@ def test_namedtuple_roundtrip(self, user_type):
self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)

def test_dataclass_roundtrip(self):
  """A dataclass converted to a runner-API schema and back yields an
  equivalent RowTypeConstraint."""
  @dataclasses.dataclass
  class SimpleDataclass:
    id: np.int64
    name: str

  proto = typing_to_runner_api(
      SimpleDataclass, schema_registry=SchemaTypeRegistry())
  roundtripped = typing_from_runner_api(
      proto, schema_registry=SchemaTypeRegistry())

  self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
  # The roundtripped user_type comes back as a generated NamedTuple, not the
  # original dataclass, so compare field annotations rather than the types.
  self.assertEqual(
      roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__)

def test_row_type_constraint_to_schema(self):
result_type = typing_to_runner_api(
row_type.RowTypeConstraint.from_fields([
Expand Down Expand Up @@ -646,6 +665,48 @@ def test_trivial_example(self):
expected.row_type.schema.fields,
typing_to_runner_api(MyCuteClass).row_type.schema.fields)

def test_trivial_example_dataclass(self):
  """A dataclass with atomic, Optional, List, and ByteString fields maps to
  the expected portable schema proto (mirrors the NamedTuple test above)."""
  @dataclasses.dataclass
  class MyCuteDataclass:
    name: str
    age: Optional[int]
    interests: List[str]
    height: float
    blob: ByteString

  expected = schema_pb2.FieldType(
      row_type=schema_pb2.RowType(
          schema=schema_pb2.Schema(
              fields=[
                  schema_pb2.Field(
                      name='name',
                      type=schema_pb2.FieldType(
                          atomic_type=schema_pb2.STRING),
                  ),
                  schema_pb2.Field(
                      name='age',
                      type=schema_pb2.FieldType(
                          nullable=True, atomic_type=schema_pb2.INT64)),
                  schema_pb2.Field(
                      name='interests',
                      type=schema_pb2.FieldType(
                          array_type=schema_pb2.ArrayType(
                              element_type=schema_pb2.FieldType(
                                  atomic_type=schema_pb2.STRING)))),
                  schema_pb2.Field(
                      name='height',
                      type=schema_pb2.FieldType(
                          atomic_type=schema_pb2.DOUBLE)),
                  schema_pb2.Field(
                      name='blob',
                      type=schema_pb2.FieldType(
                          atomic_type=schema_pb2.BYTES)),
              ])))

  # Compare only the field list; presumably the generated schema carries a
  # non-deterministic id that would break a whole-proto comparison — TODO
  # confirm against the NamedTuple variant of this test.
  self.assertEqual(
      expected.row_type.schema.fields,
      typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)

def test_user_type_annotated_with_id_after_conversion(self):
MyCuteClass = NamedTuple('MyCuteClass', [
('name', str),
Expand Down
Loading