Use pandas_type_compatibility BatchConverters for dataframe.schemas utilities
TheNeuralBit committed Aug 10, 2022
1 parent a5d8227 commit 1e2c800
Showing 4 changed files with 81 additions and 105 deletions.
7 changes: 4 additions & 3 deletions sdks/python/apache_beam/dataframe/convert.py
@@ -30,8 +30,9 @@
 from apache_beam import pvalue
 from apache_beam.dataframe import expressions
 from apache_beam.dataframe import frame_base
-from apache_beam.dataframe import schemas
 from apache_beam.dataframe import transforms
+from apache_beam.dataframe.schemas import element_typehint_from_dataframe_proxy
+from apache_beam.dataframe.schemas import generate_proxy
+from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype

 if TYPE_CHECKING:
@@ -71,7 +72,7 @@ def to_dataframe(
   # Attempt to come up with a reasonable, stable label by retrieving
   # the name of these variables in the calling context.
   label = 'BatchElements(%s)' % _var_name(pcoll, 2)
-  proxy = schemas.generate_proxy(pcoll.element_type)
+  proxy = generate_proxy(pcoll.element_type)

   shim_dofn: beam.DoFn
   if isinstance(proxy, pd.DataFrame):
@@ -152,7 +153,7 @@ def process(self, element: pd.DataFrame) -> Iterable[pd.DataFrame]:
     yield element

   def infer_output_type(self, input_element_type):
-    return schemas.element_typehint_from_dataframe_proxy(
+    return element_typehint_from_dataframe_proxy(
         self._proxy, self._include_indexes)
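For context on the import change above, a minimal sketch of the helper that to_dataframe now calls directly; the NamedTuple element type below is hypothetical, not part of this commit:

import typing

from apache_beam.dataframe.schemas import generate_proxy


class Animal(typing.NamedTuple):  # hypothetical schema'd element type
  animal: str
  max_speed: float


# generate_proxy returns an empty pandas object (a DataFrame here) whose
# columns and dtypes mirror the element type; to_dataframe uses it as the
# proxy for the deferred DataFrame.
proxy = generate_proxy(Animal)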
132 changes: 41 additions & 91 deletions sdks/python/apache_beam/dataframe/schemas.py
@@ -23,6 +23,7 @@

 # pytype: skip-file

+import warnings
 from typing import Any
 from typing import Dict
 from typing import NamedTuple
@@ -36,16 +37,14 @@

 import apache_beam as beam
 from apache_beam import typehints
-from apache_beam.portability.api import schema_pb2
 from apache_beam.transforms.util import BatchElements
-from apache_beam.typehints.native_type_compatibility import _match_is_optional
+from apache_beam.typehints.pandas_type_compatibility import INDEX_OPTION_NAME
+from apache_beam.typehints.pandas_type_compatibility import create_pandas_batch_converter
 from apache_beam.typehints.pandas_type_compatibility import dtype_from_typehint
 from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype
 from apache_beam.typehints.row_type import RowTypeConstraint
 from apache_beam.typehints.schemas import named_fields_from_element_type
-from apache_beam.typehints.schemas import named_tuple_from_schema
-from apache_beam.typehints.schemas import named_tuple_to_schema
-from apache_beam.utils import proto_utils
+from apache_beam.typehints.typehints import normalize

 __all__ = (
     'BatchRowsAsDataFrame',
@@ -69,18 +68,21 @@ def __init__(self, *args, proxy=None, **kwargs):
     self._proxy = proxy

   def expand(self, pcoll):
-    proxy = generate_proxy(
-        pcoll.element_type) if self._proxy is None else self._proxy
-    if isinstance(proxy, pd.DataFrame):
-      columns = proxy.columns
-      construct = lambda batch: pd.DataFrame.from_records(
-          batch, columns=columns)
-    elif isinstance(proxy, pd.Series):
-      dtype = proxy.dtype
-      construct = lambda batch: pd.Series(batch, dtype=dtype)
+    if self._proxy is not None:
+      # Generate typehint
+      proxy = self._proxy
+      element_typehint = _element_typehint_from_proxy(proxy)
     else:
-      raise NotImplementedError("Unknown proxy type: %s" % proxy)
-    return pcoll | self._batch_elements_transform | beam.Map(construct)
+      # Generate proxy
+      proxy = generate_proxy(pcoll.element_type)
+      element_typehint = pcoll.element_type
+
+    converter = create_pandas_batch_converter(
+        element_type=element_typehint, batch_type=type(proxy))
+
+    return (
+        pcoll | self._batch_elements_transform
+        | beam.Map(converter.produce_batch))


 def generate_proxy(element_type):
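To illustrate the new expand() path, a hedged sketch of the converter round trip outside a pipeline; the two-field row type is hypothetical:

import pandas as pd

from apache_beam.typehints.pandas_type_compatibility import (
    create_pandas_batch_converter)
from apache_beam.typehints.row_type import RowTypeConstraint

row_type = RowTypeConstraint.from_fields([('id', int), ('height', float)])
converter = create_pandas_batch_converter(
    element_type=row_type, batch_type=pd.DataFrame)

# produce_batch assembles the per-bundle DataFrame that expand() now emits;
# explode_batch is the inverse, used by _unbatch_transform further down.
batch = converter.produce_batch(
    [row_type.user_type(1, 2.0), row_type.user_type(2, 3.5)])
rows = list(converter.explode_batch(batch))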
@@ -117,6 +119,22 @@ def element_type_from_dataframe(proxy, include_indexes=False):
   return element_typehint_from_dataframe_proxy(proxy, include_indexes).user_type


+def _element_typehint_from_proxy(
+    proxy: pd.core.generic.NDFrame, include_indexes: bool = False):
+  if isinstance(proxy, pd.DataFrame):
+    return element_typehint_from_dataframe_proxy(
+        proxy, include_indexes=include_indexes)
+  elif isinstance(proxy, pd.Series):
+    if include_indexes:
+      warnings.warn(
+          "include_indexes=True for a Series input. Note that this "
+          "parameter is _not_ respected for DeferredSeries "
+          "conversion.")
+    return dtype_to_fieldtype(proxy.dtype)
+  else:
+    raise TypeError(f"Proxy '{proxy}' has unsupported type '{type(proxy)}'")
+
+
 def element_typehint_from_dataframe_proxy(
     proxy: pd.DataFrame, include_indexes: bool = False) -> RowTypeConstraint:

@@ -158,8 +176,7 @@ def element_typehint_from_dataframe_proxy(
   field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]]
   if include_indexes:
     field_options = {
-        # TODO: Reference the constant in pandas_type_compatibility
-        index_name: [('beam:dataframe:index', None)]
+        index_name: [(INDEX_OPTION_NAME, None)]
         for index_name in proxy.index.names
     }
   else:
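For reference, a hedged sketch of the row typehint this hunk produces for a frame with one named index level; the field names are illustrative:

from apache_beam.typehints.pandas_type_compatibility import INDEX_OPTION_NAME
from apache_beam.typehints.row_type import RowTypeConstraint

# Index fields carry the shared option constant ('beam:dataframe:index'),
# letting the pandas converters rebuild the index on the DataFrame side.
hint = RowTypeConstraint.from_fields(
    [('key', int), ('value', float)],
    field_options={'key': [(INDEX_OPTION_NAME, None)]})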
@@ -168,82 +185,15 @@
   return RowTypeConstraint.from_fields(fields, field_options=field_options)


-class _BaseDataframeUnbatchDoFn(beam.DoFn):
-  def __init__(self, namedtuple_ctor):
-    self._namedtuple_ctor = namedtuple_ctor
-
-  def _get_series(self, df):
-    raise NotImplementedError()
-
-  def process(self, df):
-    # TODO: Only do null checks for nullable types
-    def make_null_checking_generator(series):
-      nulls = pd.isnull(series)
-      return (None if isnull else value for isnull, value in zip(nulls, series))
-
-    all_series = self._get_series(df)
-    iterators = [
-        make_null_checking_generator(series) for series,
-        typehint in zip(all_series, self._namedtuple_ctor.__annotations__)
-    ]
-
-    # TODO: Avoid materializing the rows. Produce an object that references the
-    # underlying dataframe
-    for values in zip(*iterators):
-      yield self._namedtuple_ctor(*values)
-
-  def infer_output_type(self, input_type):
-    return self._namedtuple_ctor
-
-  @classmethod
-  def _from_serialized_schema(cls, schema_str):
-    return cls(
-        named_tuple_from_schema(
-            proto_utils.parse_Bytes(schema_str, schema_pb2.Schema)))
-
-  def __reduce__(self):
-    # when pickling, use bytes representation of the schema.
-    return (
-        self._from_serialized_schema,
-        (named_tuple_to_schema(self._namedtuple_ctor).SerializeToString(), ))
-
-
-class _UnbatchNoIndex(_BaseDataframeUnbatchDoFn):
-  def _get_series(self, df):
-    return [df[column] for column in df.columns]
-
-
-class _UnbatchWithIndex(_BaseDataframeUnbatchDoFn):
-  def _get_series(self, df):
-    return [df.index.get_level_values(i) for i in range(len(df.index.names))
-            ] + [df[column] for column in df.columns]
-
-
 def _unbatch_transform(proxy, include_indexes):
-  if isinstance(proxy, pd.DataFrame):
-    ctor = element_type_from_dataframe(proxy, include_indexes=include_indexes)
-
-    return beam.ParDo(
-        _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor))
-  elif isinstance(proxy, pd.Series):
-    # Raise a TypeError if proxy has an unknown type
-    output_type = dtype_to_fieldtype(proxy.dtype)
-    # TODO: Should the index ever be included for a Series?
-    if _match_is_optional(output_type):
-
-      def unbatch(series):
-        for isnull, value in zip(pd.isnull(series), series):
-          yield None if isnull else value
-    else:
+  element_typehint = normalize(
+      _element_typehint_from_proxy(proxy, include_indexes=include_indexes))

-      def unbatch(series):
-        yield from series
+  converter = create_pandas_batch_converter(
+      element_type=element_typehint, batch_type=type(proxy))

-    return beam.FlatMap(unbatch).with_output_types(output_type)
-  # TODO: What about scalar inputs?
-  else:
-    raise TypeError(
-        "Proxy '%s' has unsupported type '%s'" % (proxy, type(proxy)))
+  return beam.FlatMap(
+      converter.explode_batch).with_output_types(element_typehint)


 @typehints.with_input_types(Union[pd.DataFrame, pd.Series])
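An end-to-end sketch of the rewritten unbatch path, mirroring the Series tests below; the literal values are illustrative:

import apache_beam as beam
import pandas as pd

from apache_beam.dataframe import schemas

# An empty Series proxy describing the batches' dtype.
proxy = pd.Series([], dtype='int64')

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([pd.Series([1, 2]), pd.Series([3])])
      # _unbatch_transform now FlatMaps SeriesBatchConverter.explode_batch,
      # yielding the scalar elements 1, 2 and 3.
      | schemas.UnbatchPandas(proxy))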
29 changes: 23 additions & 6 deletions sdks/python/apache_beam/dataframe/schemas_test.py
@@ -34,6 +34,7 @@
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.typehints import row_type
 from apache_beam.typehints import typehints
 from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
@@ -52,10 +53,7 @@ def check_df_pcoll_equal(actual):
         drop=True)
     sorted_expected = expected.sort_values(
         by=list(expected.columns)).reset_index(drop=True)
-    if not sorted_actual.equals(sorted_expected):
-      raise AssertionError(
-          'Dataframes not equal: \n\nActual:\n%s\n\nExpected:\n%s' %
-          (sorted_actual, sorted_expected))
+    pd.testing.assert_frame_equal(sorted_actual, sorted_expected)

   return check_df_pcoll_equal

@@ -145,6 +143,8 @@ def test_simple_df(self):
     },
                             columns=['name', 'id', 'height'])

+    expected.name = expected.name.astype(pd.StringDtype())
+
     with TestPipeline() as p:
       res = (
           p
@@ -160,6 +160,7 @@ def test_simple_df_with_beam_row(self):
         'height': list(float(i) for i in range(5))
     },
                             columns=['name', 'id', 'height'])
+    expected.name = expected.name.astype(pd.StringDtype())

     with TestPipeline() as p:
       res = (
@@ -235,8 +236,14 @@ def test_batch_with_df_transform(self):
     assert_that(res, equal_to([('Falcon', 375.), ('Parrot', 25.)]))

   def assert_typehints_equal(self, left, right):
-    left = typehints.normalize(left)
-    right = typehints.normalize(right)
+    def maybe_drop_rowtypeconstraint(typehint):
+      if isinstance(typehint, row_type.RowTypeConstraint):
+        return typehint.user_type
+      else:
+        return typehint
+
+    left = maybe_drop_rowtypeconstraint(typehints.normalize(left))
+    right = maybe_drop_rowtypeconstraint(typehints.normalize(right))

     if match_is_named_tuple(left):
       self.assertTrue(match_is_named_tuple(right))
@@ -273,6 +280,16 @@ def test_unbatch_with_index(self, df_or_series, rows, _):

     assert_that(res, equal_to(rows))

+  @parameterized.expand(SERIES_TESTS, name_func=test_name_func)
+  def test_unbatch_series_with_index_warns(
+      self, series, unused_rows, unused_type):
+    proxy = series[:0]
+
+    with TestPipeline() as p:
+      input_pc = p | beam.Create([series[::2], series[1::2]])
+      with self.assertWarns(UserWarning):
+        _ = input_pc | schemas.UnbatchPandas(proxy, include_indexes=True)
+
   def test_unbatch_include_index_unnamed_index_raises(self):
     df = pd.DataFrame({'foo': [1, 2, 3, 4]})
     proxy = df[:0]
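The StringDtype casts added to the expected frames above reflect that the shared converters map str fields to pandas' nullable string dtype; a hedged check of the assumed mapping:

import numpy as np
import pandas as pd

from apache_beam.typehints.pandas_type_compatibility import dtype_from_typehint

# Assumed mappings: str round-trips as the nullable StringDtype (hence the
# .astype(pd.StringDtype()) casts in the tests), while int stays int64
# (the 'id' columns are not cast).
assert dtype_from_typehint(str) == pd.StringDtype()
assert dtype_from_typehint(int) == np.int64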
18 changes: 13 additions & 5 deletions sdks/python/apache_beam/typehints/pandas_type_compatibility.py
@@ -136,6 +136,19 @@ def dtype_to_fieldtype(dtype):
   return Any


+@BatchConverter.register
+def create_pandas_batch_converter(
+    element_type: type, batch_type: type) -> BatchConverter:
+  if batch_type == pd.DataFrame:
+    return DataFrameBatchConverter.from_typehints(
+        element_type=element_type, batch_type=batch_type)
+  elif batch_type == pd.Series:
+    return SeriesBatchConverter.from_typehints(
+        element_type=element_type, batch_type=batch_type)
+
+  return None
+
+
 class DataFrameBatchConverter(BatchConverter):
   def __init__(
       self,
@@ -145,7 +158,6 @@ def __init__(
     self._columns = [name for name, _ in element_type._fields]

   @staticmethod
-  @BatchConverter.register
   def from_typehints(element_type,
                      batch_type) -> Optional['DataFrameBatchConverter']:
     if not batch_type == pd.DataFrame:
@@ -261,16 +273,12 @@ def unbatch(series):
     self.explode_batch = unbatch

   @staticmethod
-  @BatchConverter.register
   def from_typehints(element_type,
                      batch_type) -> Optional['SeriesBatchConverter']:
     if not batch_type == pd.Series:
       return None

     dtype = dtype_from_typehint(element_type)
-    if dtype == np.object:
-      # Don't create Any <-> Series[np.object] mapping
-      return None

     return SeriesBatchConverter(element_type, dtype)
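With registration moved from the two from_typehints staticmethods to the single factory above, pandas BatchConverter lookups presumably resolve through one entry point; a hedged sketch of the dispatch:

import pandas as pd

from apache_beam.typehints.batch import BatchConverter
from apache_beam.typehints.row_type import RowTypeConstraint

# Both lookups go through the registered create_pandas_batch_converter,
# which selects DataFrameBatchConverter or SeriesBatchConverter by batch_type.
df_converter = BatchConverter.from_typehints(
    element_type=RowTypeConstraint.from_fields([('x', int)]),
    batch_type=pd.DataFrame)
series_converter = BatchConverter.from_typehints(
    element_type=int, batch_type=pd.Series)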
