Skip to content

Commit

Permalink
Add BatchConverter implementations for pandas types, use Batched DoFn…
Browse files Browse the repository at this point in the history
…s in DataFrame convert utilities (#22575)

Add BatchConverter implementations for pandas types, use Batched DoFns in DataFrame convert utilities
  • Loading branch information
TheNeuralBit committed Aug 31, 2022
2 parents 149ed07 + c088431 commit a6329a5
Show file tree
Hide file tree
Showing 8 changed files with 701 additions and 202 deletions.
77 changes: 72 additions & 5 deletions sdks/python/apache_beam/dataframe/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
# limitations under the License.

import inspect
import warnings
import weakref
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Tuple
from typing import Union

Expand All @@ -28,8 +30,10 @@
from apache_beam import pvalue
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frame_base
from apache_beam.dataframe import schemas
from apache_beam.dataframe import transforms
from apache_beam.dataframe.schemas import element_typehint_from_dataframe_proxy
from apache_beam.dataframe.schemas import generate_proxy
from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype

if TYPE_CHECKING:
# pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -68,8 +72,17 @@ def to_dataframe(
# Attempt to come up with a reasonable, stable label by retrieving
# the name of these variables in the calling context.
label = 'BatchElements(%s)' % _var_name(pcoll, 2)
proxy = schemas.generate_proxy(pcoll.element_type)
pcoll = pcoll | label >> schemas.BatchRowsAsDataFrame(proxy=proxy)
proxy = generate_proxy(pcoll.element_type)

shim_dofn: beam.DoFn
if isinstance(proxy, pd.DataFrame):
shim_dofn = RowsToDataFrameFn()
elif isinstance(proxy, pd.Series):
shim_dofn = ElementsToSeriesFn()
else:
raise AssertionError("Unknown proxy type: %s" % proxy)

pcoll = pcoll | label >> beam.ParDo(shim_dofn)
return frame_base.DeferredFrame.wrap(
expressions.PlaceholderExpression(proxy, pcoll))

Expand All @@ -86,6 +99,18 @@ def to_dataframe(
) # type: weakref.WeakValueDictionary[str, pvalue.PCollection]


class RowsToDataFrameFn(beam.DoFn):
@beam.DoFn.yields_elements
def process_batch(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
yield batch


class ElementsToSeriesFn(beam.DoFn):
@beam.DoFn.yields_elements
def process_batch(self, batch: pd.Series) -> Iterable[pd.Series]:
yield batch


def _make_unbatched_pcoll(
pc: pvalue.PCollection, expr: expressions.Expression,
include_indexes: bool):
Expand All @@ -94,14 +119,56 @@ def _make_unbatched_pcoll(
label += " with indexes"

if label not in UNBATCHED_CACHE:
UNBATCHED_CACHE[label] = pc | label >> schemas.UnbatchPandas(
expr.proxy(), include_indexes=include_indexes)
proxy = expr.proxy()
shim_dofn: beam.DoFn
if isinstance(proxy, pd.DataFrame):
shim_dofn = DataFrameToRowsFn(proxy, include_indexes)
elif isinstance(proxy, pd.Series):
if include_indexes:
warnings.warn(
"Pipeline is converting a DeferredSeries to PCollection "
"with include_indexes=True. Note that this parameter is "
"_not_ respected for DeferredSeries conversion. To "
"include the index with your data, produce a"
"DeferredDataFrame instead.")

shim_dofn = SeriesToElementsFn(proxy)
else:
raise TypeError(f"Proxy '{proxy}' has unsupported type '{type(proxy)}'")

UNBATCHED_CACHE[label] = pc | label >> beam.ParDo(shim_dofn)

# Note unbatched cache is keyed by the expression id as well as parameters
# for the unbatching (i.e. include_indexes)
return UNBATCHED_CACHE[label]


class DataFrameToRowsFn(beam.DoFn):
def __init__(self, proxy, include_indexes):
self._proxy = proxy
self._include_indexes = include_indexes

@beam.DoFn.yields_batches
def process(self, element: pd.DataFrame) -> Iterable[pd.DataFrame]:
yield element

def infer_output_type(self, input_element_type):
return element_typehint_from_dataframe_proxy(
self._proxy, self._include_indexes)


class SeriesToElementsFn(beam.DoFn):
def __init__(self, proxy):
self._proxy = proxy

@beam.DoFn.yields_batches
def process(self, element: pd.Series) -> Iterable[pd.Series]:
yield element

def infer_output_type(self, input_element_type):
return dtype_to_fieldtype(self._proxy.dtype)


# TODO: Or should this be called from_dataframe?


Expand Down
Loading

0 comments on commit a6329a5

Please sign in to comment.