Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BatchConverter implementations for pandas types, use Batched DoFns in DataFrame convert utilities #22575

Merged
merged 10 commits into from
Aug 31, 2022
79 changes: 74 additions & 5 deletions sdks/python/apache_beam/dataframe/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
# limitations under the License.

import inspect
import warnings
import weakref
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Tuple
from typing import Union

Expand All @@ -28,8 +30,10 @@
from apache_beam import pvalue
from apache_beam.dataframe import expressions
from apache_beam.dataframe import frame_base
from apache_beam.dataframe import schemas
from apache_beam.dataframe import transforms
from apache_beam.dataframe.schemas import element_typehint_from_dataframe_proxy
from apache_beam.dataframe.schemas import generate_proxy
from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype

if TYPE_CHECKING:
# pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -68,8 +72,17 @@ def to_dataframe(
# Attempt to come up with a reasonable, stable label by retrieving
# the name of these variables in the calling context.
label = 'BatchElements(%s)' % _var_name(pcoll, 2)
proxy = schemas.generate_proxy(pcoll.element_type)
pcoll = pcoll | label >> schemas.BatchRowsAsDataFrame(proxy=proxy)
proxy = generate_proxy(pcoll.element_type)

shim_dofn: beam.DoFn
if isinstance(proxy, pd.DataFrame):
shim_dofn = RowsToDataFrameFn()
elif isinstance(proxy, pd.Series):
shim_dofn = ElementsToSeriesFn()
else:
raise AssertionError("Unknown proxy type: %s" % proxy)

pcoll = pcoll | label >> beam.ParDo(shim_dofn)
return frame_base.DeferredFrame.wrap(
expressions.PlaceholderExpression(proxy, pcoll))

Expand All @@ -86,6 +99,18 @@ def to_dataframe(
) # type: weakref.WeakValueDictionary[str, pvalue.PCollection]


class RowsToDataFrameFn(beam.DoFn):
    """Pass-through DoFn whose batch type is declared as ``pd.DataFrame``.

    ``to_dataframe`` applies this DoFn with ``beam.ParDo`` when the proxy
    generated from the input PCollection's element type is a
    ``pd.DataFrame``.
    """

    # NOTE(review): presumably `yields_elements` tells Beam that every value
    # yielded from `process_batch` is a single output element (here, a whole
    # DataFrame) rather than a batch to be split up again -- confirm against
    # the Beam batched-DoFn documentation.
    @beam.DoFn.yields_elements
    def process_batch(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
        """Yield the incoming ``batch`` unchanged."""
        yield batch


class ElementsToSeriesFn(beam.DoFn):
    """Pass-through DoFn whose batch type is declared as ``pd.Series``.

    ``to_dataframe`` applies this DoFn with ``beam.ParDo`` when the proxy
    generated from the input PCollection's element type is a ``pd.Series``.
    """

    # NOTE(review): presumably `yields_elements` tells Beam that every value
    # yielded from `process_batch` is a single output element (here, a whole
    # Series) rather than a batch to be split up again -- confirm against
    # the Beam batched-DoFn documentation.
    @beam.DoFn.yields_elements
    def process_batch(self, batch: pd.Series) -> Iterable[pd.Series]:
        """Yield the incoming ``batch`` unchanged."""
        yield batch


def _make_unbatched_pcoll(
pc: pvalue.PCollection, expr: expressions.Expression,
include_indexes: bool):
Expand All @@ -94,14 +119,58 @@ def _make_unbatched_pcoll(
label += " with indexes"

if label not in UNBATCHED_CACHE:
    # Only build the unbatching step the first time this label is seen;
    # later calls reuse the PCollection stored in UNBATCHED_CACHE.
    proxy = expr.proxy()

    # Pick the shim DoFn that matches the pandas type of the proxy.
    shim_dofn: beam.DoFn
    if isinstance(proxy, pd.DataFrame):
        shim_dofn = DataFrameToRowsFn(proxy, include_indexes)
    elif isinstance(proxy, pd.Series):
        if include_indexes:
            # SeriesToElementsFn takes no include_indexes argument, so the
            # flag is dropped here; warn instead of ignoring it silently.
            # The trailing space after "produce a" keeps the adjacent
            # string literals from fusing into "aDeferredDataFrame".
            warnings.warn(
                "Pipeline is converting a DeferredSeries to PCollection "
                "with include_indexes=True. Note that this parameter is "
                "_not_ respected for DeferredSeries conversion. To "
                "include the index with your data, produce a "
                "DeferredDataFrame instead.")

        shim_dofn = SeriesToElementsFn(proxy)
    else:
        raise TypeError(
            f"Proxy '{proxy}' has unsupported type '{type(proxy)}'")

    UNBATCHED_CACHE[label] = pc | label >> beam.ParDo(shim_dofn)

# Note unbatched cache is keyed by the expression id as well as parameters
# for the unbatching (i.e. include_indexes)
return UNBATCHED_CACHE[label]


class DataFrameToRowsFn(beam.DoFn):
    """Pass-through DoFn used when unbatching a DataFrame-valued expression.

    ``_make_unbatched_pcoll`` constructs this DoFn when the expression's
    proxy is a ``pd.DataFrame`` and applies it with ``beam.ParDo``.
    """
    def __init__(self, proxy, include_indexes):
        """Store the values later needed by ``infer_output_type``.

        Args:
          proxy: the proxy object of the expression being unbatched (a
            ``pd.DataFrame`` at the visible call site).
          include_indexes: forwarded unchanged to
            ``element_typehint_from_dataframe_proxy``.
        """
        self._proxy = proxy
        self._include_indexes = include_indexes

    # NOTE(review): presumably `yields_batches` tells Beam that each value
    # yielded from `process` is a batch (a DataFrame) that Beam itself turns
    # into individual elements -- confirm against the Beam batched-DoFn
    # documentation.
    @beam.DoFn.yields_batches
    def process(self, element: pd.DataFrame) -> Iterable[pd.DataFrame]:
        """Yield the incoming ``element`` unchanged."""
        yield element

    def infer_output_type(self, input_element_type):
        """Return the element type hint derived from the stored proxy.

        ``input_element_type`` is not used; the result depends only on the
        proxy and ``include_indexes`` given to ``__init__``.
        """
        return element_typehint_from_dataframe_proxy(
            self._proxy, self._include_indexes)


class SeriesToElementsFn(beam.DoFn):
def __init__(self, proxy):
self._proxy = proxy

@beam.DoFn.yields_batches
def process(self, element: pd.Series) -> Iterable[pd.Series]:
yield element

def infer_output_type(self, input_element_type):
# Map the proxy's dtype to a Beam field type; an unknown dtype now falls back to Any instead of raising a TypeError.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may have missed this, but where does the error get raised?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, this comment references behavior that was removed in 2b0597e

Now we will just shunt to Any in this case. I removed the comment. Thanks for raising this!

output_type = dtype_to_fieldtype(self._proxy.dtype)
return output_type


# TODO: Or should this be called from_dataframe?


Expand Down
Loading