Commit

Add BatchRowsAsDataframe and generate_proxy, integrated into DataFrameTransform

TheNeuralBit committed Aug 7, 2020
1 parent 1671ce5 commit 55c4920
Showing 7 changed files with 317 additions and 37 deletions.
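Taken together, these changes let the dataframe API consume a schema-aware
PCollection directly, with no explicit proxy. A minimal usage sketch follows;
the Animal type and the pipeline itself are illustrative, not part of the
commit:

from typing import NamedTuple

import apache_beam as beam
from apache_beam.coders import RowCoder
from apache_beam.coders.typecoders import registry as coders_registry
from apache_beam.dataframe.convert import to_dataframe

Animal = NamedTuple('Animal', [('animal', str), ('max_speed', float)])
coders_registry.register_coder(Animal, RowCoder)

with beam.Pipeline() as p:
  pc = p | beam.Create([Animal('Falcon', 380.0), Animal('Parrot', 24.0)])
  # No proxy argument: to_dataframe generates one from the PCollection's
  # schema and inserts the batching step automatically.
  df = to_dataframe(pc)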
15 changes: 7 additions & 8 deletions sdks/python/apache_beam/coders/row_coder.py
@@ -41,6 +41,7 @@
 from apache_beam.typehints.schemas import named_fields_to_schema
 from apache_beam.typehints.schemas import named_tuple_from_schema
 from apache_beam.typehints.schemas import named_tuple_to_schema
+from apache_beam.typehints.schemas import schema_from_element_type
 from apache_beam.utils import proto_utils

 __all__ = ["RowCoder"]
@@ -90,14 +91,12 @@ def from_runner_api_parameter(schema, components, unused_context):

   @staticmethod
   def from_type_hint(type_hint, registry):
-    if isinstance(type_hint, row_type.RowTypeConstraint):
-      try:
-        schema = named_fields_to_schema(type_hint._fields)
-      except ValueError:
-        # TODO(BEAM-10570): Consider a pythonsdk logical type.
-        return typecoders.registry.get_coder(object)
-    else:
-      schema = named_tuple_to_schema(type_hint)
+    try:
+      schema = schema_from_element_type(type_hint)
+    except ValueError:
+      # TODO(BEAM-10570): Consider a pythonsdk logical type.
+      return typecoders.registry.get_coder(object)

     return RowCoder(schema)

   @staticmethod
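The net effect is that from_type_hint no longer special-cases
RowTypeConstraint: any element type that schema_from_element_type understands
gets a RowCoder, and anything it rejects falls back to the registry's default
coder for object. A hedged sketch of the behavior; MyRow is illustrative:

from typing import NamedTuple

from apache_beam.coders.row_coder import RowCoder

MyRow = NamedTuple('MyRow', [('id', int), ('name', str)])

# NamedTuple hints and RowTypeConstraints now share one conversion path via
# schema_from_element_type; the registry argument is unused here.
coder = RowCoder.from_type_hint(MyRow, None)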
8 changes: 7 additions & 1 deletion sdks/python/apache_beam/dataframe/convert.py
@@ -27,6 +27,7 @@
 from apache_beam.dataframe import expressions
 from apache_beam.dataframe import frame_base
+from apache_beam.dataframe import schemas
 from apache_beam.dataframe import transforms

 if TYPE_CHECKING:
   # pylint: disable=ungrouped-imports
@@ -36,7 +37,7 @@
 # TODO: Or should this be called as_dataframe?
 def to_dataframe(
     pcoll,  # type: pvalue.PCollection
-    proxy,  # type: pandas.core.generic.NDFrame
+    proxy=None,  # type: pandas.core.generic.NDFrame
 ):
   # type: (...) -> frame_base.DeferredFrame

@@ -52,6 +53,11 @@ def to_dataframe(
   A proxy object must be given if the schema for the PCollection is not known.
   """
+  if proxy is None:
+    # If no proxy is given, assume this is an element-wise schema-aware
+    # PCollection that needs to be batched.
+    proxy = schemas.generate_proxy(pcoll.element_type)
+    pcoll = pcoll | 'BatchElements' >> schemas.BatchRowsAsDataFrame()
   return frame_base.DeferredFrame.wrap(
       expressions.PlaceholderExpression(proxy, pcoll))

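With the new default, the two call forms below should be roughly equivalent
for a schema-aware input; pcoll stands in for any PCollection whose element
type has a registered schema:

from apache_beam.dataframe import schemas
from apache_beam.dataframe.convert import to_dataframe

# Explicit: the caller batches the rows and supplies a matching proxy.
df_explicit = to_dataframe(
    pcoll | schemas.BatchRowsAsDataFrame(),
    proxy=schemas.generate_proxy(pcoll.element_type))

# Implicit: to_dataframe generates the proxy and inserts the batching step.
df_implicit = to_dataframe(pcoll)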
78 changes: 78 additions & 0 deletions sdks/python/apache_beam/dataframe/schemas.py
@@ -0,0 +1,78 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utilities for relating schema-aware PCollections and dataframe transforms.
"""

from typing import NamedTuple
from typing import TypeVar

import pandas as pd

from apache_beam import typehints
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.util import BatchElements
from apache_beam.typehints.schemas import named_fields_from_element_type

__all__ = ('BatchRowsAsDataFrame', 'generate_proxy')

T = TypeVar('T', bound=NamedTuple)


@typehints.with_input_types(T)
@typehints.with_output_types(pd.DataFrame)
class BatchRowsAsDataFrame(BatchElements):
  """A transform that batches schema-aware PCollection elements into DataFrames.

  Batching parameters are inherited from
  :class:`~apache_beam.transforms.util.BatchElements`.
  """
  def expand(self, pcoll):
    # Batch the rows with BatchElements, then convert each batch (a list of
    # schema'd elements) into a single pandas DataFrame.
    return super(BatchRowsAsDataFrame, self).expand(pcoll) | ParDo(
        _RowBatchToDataFrameDoFn(pcoll.element_type))


class _RowBatchToDataFrameDoFn(DoFn):
  def __init__(self, element_type):
    self._columns = [
        name for name, _ in named_fields_from_element_type(element_type)
    ]

  def process(self, element):
    # Each element is a batch (list) of schema'd rows; from_records turns the
    # whole batch into one DataFrame with the schema's column names.
    result = pd.DataFrame.from_records(element, columns=self._columns)
    yield result


def _make_empty_series(name, typ):
  try:
    return pd.Series(name=name, dtype=typ)
  except TypeError:
    raise TypeError("Unable to convert type '%s' for field '%s'" % (typ, name))


def generate_proxy(element_type):
  # type: (type) -> pd.DataFrame
  """Generates a zero-row proxy DataFrame with one column per field in the
  schema'd element type."""
  return pd.DataFrame({
      name: _make_empty_series(name, typ)
      for name, typ in named_fields_from_element_type(element_type)
  })
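For reference, the proxy is a zero-row DataFrame whose column names and dtypes
mirror the schema; the deferred dataframe machinery only needs its structure,
never its contents. A quick sketch of the expected output, mirroring
test_generate_proxy below:

from typing import NamedTuple

from apache_beam.dataframe import schemas

Animal = NamedTuple('Animal', [('animal', str), ('max_speed', float)])

proxy = schemas.generate_proxy(Animal)
print(len(proxy))    # 0 -- the proxy carries structure, not data
print(proxy.dtypes)  # animal: object (from str), max_speed: float64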
82 changes: 82 additions & 0 deletions sdks/python/apache_beam/dataframe/schemas_test.py
@@ -0,0 +1,82 @@
"""Tests for schemas."""

import unittest
from typing import NamedTuple

import future.tests.base # pylint: disable=unused-import
# patches unittest.testcase to be python3 compatible
import pandas as pd

import apache_beam as beam
from apache_beam.coders import RowCoder
from apache_beam.coders.typecoders import registry as coders_registry
from apache_beam.dataframe import schemas
from apache_beam.dataframe import transforms
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that

Simple = NamedTuple('Simple', [('name', str), ('id', int), ('height', float)])
coders_registry.register_coder(Simple, RowCoder)
Animal = NamedTuple('Animal', [('animal', str), ('max_speed', float)])
coders_registry.register_coder(Animal, RowCoder)


def matches_df(expected):
  def check_df_pcoll_equal(actual):
    sorted_actual = pd.concat(actual).sort_index()
    sorted_expected = expected.sort_index()
    if not sorted_actual.equals(sorted_expected):
      raise AssertionError(
          'Dataframes not equal: \n\nActual:\n%s\n\nExpected:\n%s' %
          (sorted_actual, sorted_expected))

  return check_df_pcoll_equal


class SchemasTest(unittest.TestCase):
  def test_simple_df(self):
    expected = pd.DataFrame({
        'name': list(map(str, range(5))),
        'id': list(range(5)),
        'height': list(map(float, range(5)))
    })

    with TestPipeline() as p:
      res = (
          p
          | beam.Create(
              [Simple(name=str(i), id=i, height=float(i)) for i in range(5)])
          | schemas.BatchRowsAsDataFrame(min_batch_size=10, max_batch_size=10))
      assert_that(res, matches_df(expected))

  def test_generate_proxy(self):
    expected = pd.DataFrame({
        'animal': pd.Series(dtype=str), 'max_speed': pd.Series(dtype=float)
    })

    self.assertTrue(schemas.generate_proxy(Animal).equals(expected))

  def test_batch_with_df_transform(self):
    with TestPipeline() as p:
      res = (
          p
          | beam.Create([
              Animal('Falcon', 380.0),
              Animal('Falcon', 370.0),
              Animal('Parrot', 24.0),
              Animal('Parrot', 26.0)
          ])
          | schemas.BatchRowsAsDataFrame()
          | transforms.DataframeTransform(
              lambda df: df.groupby('animal').mean(),
              proxy=schemas.generate_proxy(Animal)))
      assert_that(
          res,
          matches_df(
              pd.DataFrame({'max_speed': [375.0, 25.0]},
                           index=pd.Index(
                               data=['Falcon', 'Parrot'], name='animal'))))


if __name__ == '__main__':
  unittest.main()
7 changes: 5 additions & 2 deletions sdks/python/apache_beam/dataframe/transforms.py
@@ -52,7 +52,7 @@ class DataframeTransform(transforms.PTransform):
   passed to the callable as positional arguments, or a dictionary of
   PCollections, in which case they will be passed as keyword arguments.
   """
-  def __init__(self, func, proxy):
+  def __init__(self, func, proxy=None):
     self._func = func
     self._proxy = proxy

@@ -62,7 +62,10 @@ def expand(self, input_pcolls):

     # Convert inputs to a flat dict.
     input_dict = _flatten(input_pcolls)  # type: Dict[Any, PCollection]
-    proxies = _flatten(self._proxy)
+    proxies = _flatten(self._proxy) if self._proxy is not None else {
+        tag: None
+        for tag in input_dict.keys()
+    }
     input_frames = {
         k: convert.to_dataframe(pc, proxies[k])
         for k, pc in input_dict.items()
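Since proxy is now optional, a schema-aware PCollection can be handed to
DataframeTransform directly: each flattened input whose proxy is None is
routed through convert.to_dataframe, which generates the proxy and batches
the rows itself. A sketch of the simplified call site, assuming the same
Animal type as in the tests above:

result = (
    p
    | beam.Create([Animal('Falcon', 380.0), Animal('Parrot', 24.0)])
    | transforms.DataframeTransform(lambda df: df.groupby('animal').mean()))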
