[SPARK-43473][PYTHON] Support struct type in createDataFrame from pandas DataFrame

### What changes were proposed in this pull request?

Supports struct type in `createDataFrame` from pandas DataFrame.

With Arrow optimization enabled (`spark.sql.execution.arrow.pyspark.enabled` set to `true`), it now works without falling back:

```py
>>> import pandas as pd
>>> from pyspark.sql.types import Row
>>>
>>> pdf = pd.DataFrame(
...     {"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]}
... )
>>> schema = "a struct<x int, y string>, b struct<s int, t string>"
>>>
>>> df = spark.createDataFrame(pdf, schema)
>>> df.show()
+------+------+
|     a|     b|
+------+------+
|{1, a}|{3, x}|
|{2, b}|{4, y}|
+------+------+
```

and Spark Connect also works.
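
As a rough sketch of the Spark Connect path (the endpoint below is just an example; it assumes a Spark Connect server is already running at the default `sc://localhost:15002` address, which is not part of this PR), the same call goes through a remote session with the same result:

```py
>>> from pyspark.sql import SparkSession
>>>
>>> # Hypothetical endpoint; any reachable Spark Connect server would do.
>>> spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
>>> spark.createDataFrame(pdf, schema).show()
+------+------+
|     a|     b|
+------+------+
|{1, a}|{3, x}|
|{2, b}|{4, y}|
+------+------+
```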

### Why are the changes needed?

In vanilla PySpark without Arrow optimization, a `Row` object or a `dict` can be handled as a struct type if the schema is provided:

```py
>>> import pandas as pd
>>> from pyspark.sql.types import *
>>>
>>> pdf = pd.DataFrame(
...     {"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]}
... )
>>> schema = "a struct<x int, y string>, b struct<s int, t string>"
>>>
>>> df = spark.createDataFrame(pdf, schema)
>>> df.show()
+------+------+
|     a|     b|
+------+------+
|{1, a}|{3, x}|
|{2, b}|{4, y}|
+------+------+
```

With Arrow optimization enabled, however, it has to fall back to the non-Arrow path to make it work:

```py
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.createDataFrame(pdf, schema).show()
/.../pyspark/sql/pandas/conversion.py:329: UserWarning: createDataFrame attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
  A field of type StructType expects a pandas.DataFrame, but got: <class 'pandas.core.series.Series'>
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warn(msg)
+------+------+
|     a|     b|
+------+------+
|{1, a}|{3, x}|
|{2, b}|{4, y}|
+------+------+
```
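
The warning above is only a soft failure because the fallback is enabled. As a minimal illustration of the pre-PR behavior (assuming the same `pdf` and `schema` as above), disabling the fallback surfaces the error directly instead of retrying without Arrow:

```py
>>> # Pre-PR sketch: with the fallback switched off, the Arrow path raises
>>> # rather than silently retrying without Arrow optimization.
>>> spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", False)
>>> spark.createDataFrame(pdf, schema)  # raises instead of falling back
```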

and Spark Connect fails:

```py
>>> df = spark.createDataFrame(pdf, schema)
Traceback (most recent call last):
...
ValueError: A field of type StructType expects a pandas.DataFrame, but got: <class 'pandas.core.series.Series'>
```

### Does this PR introduce _any_ user-facing change?

A `Row` object or a `dict` in a pandas DataFrame now works as a struct type in `createDataFrame` when the schema is provided.

### How was this patch tested?

Added the related test.
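
A minimal sketch of running just the new test locally, assuming a PySpark development environment where `pyspark.sql.tests` is importable and that the test was added to the `ArrowTests` class (as the surrounding diff context suggests):

```py
import unittest

# Hypothetical local run; ArrowTests starts a SparkSession in its setUpClass,
# so this assumes SPARK_HOME and PYTHONPATH point at a local dev build.
from pyspark.sql.tests.test_arrow import ArrowTests

suite = unittest.TestSuite([ArrowTests("test_createDataFrame_with_struct_type")])
unittest.TextTestRunner(verbosity=2).run(suite)
```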

Closes #41149 from ueshin/issues/SPARK-43473/rows.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ueshin authored and HyukjinKwon committed May 16, 2023
1 parent d53ddbe commit 6221995
Showing 4 changed files with 154 additions and 82 deletions.
4 changes: 1 addition & 3 deletions python/pyspark/sql/connect/session.py
@@ -353,9 +353,7 @@ def createDataFrame(
"spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
)

ser = ArrowStreamPandasSerializer(
cast(str, timezone), safecheck == "true", assign_cols_by_name=True
)
ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")

_table = pa.Table.from_batches(
[ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
3 changes: 1 addition & 2 deletions python/pyspark/sql/pandas/conversion.py
@@ -533,8 +533,7 @@ def _create_from_pandas_with_arrow(
jsparkSession = self._jsparkSession

safecheck = self._jconf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
ser = ArrowStreamPandasSerializer(timezone, safecheck)

@no_type_check
def reader_func(temp_filename):
203 changes: 126 additions & 77 deletions python/pyspark/sql/pandas/serializers.py
@@ -19,6 +19,7 @@
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
@@ -161,11 +162,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
If True, then Pandas DataFrames will get columns by name
"""

def __init__(self, timezone, safecheck, assign_cols_by_name):
def __init__(self, timezone, safecheck):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.pandas.types import (
@@ -186,6 +186,65 @@ def arrow_to_pandas(self, arrow_column):
else:
return s

def _create_array(self, series, arrow_type):
"""
Create an Arrow Array from the given pandas.Series and optional type.
Parameters
----------
series : pandas.Series
A single series
arrow_type : pyarrow.DataType, optional
If None, pyarrow's inferred type will be used
Returns
-------
pyarrow.Array
"""
import pyarrow as pa
from pyspark.sql.pandas.types import (
_check_series_convert_timestamps_internal,
_convert_dict_to_map_items,
)
from pandas.api.types import is_categorical_dtype

if hasattr(series.array, "__arrow_array__"):
mask = None
else:
mask = series.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if (
arrow_type is not None
and pa.types.is_timestamp(arrow_type)
and arrow_type.tz is not None
):
series = _check_series_convert_timestamps_internal(series, self._timezone)
elif arrow_type is not None and pa.types.is_map(arrow_type):
series = _convert_dict_to_map_items(series)
elif arrow_type is None and is_categorical_dtype(series.dtype):
series = series.astype(series.dtypes.categories.dtype)
try:
return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck)
except TypeError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
except ValueError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
if self._safecheck:
error_msg = error_msg + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series,
@@ -201,13 +260,7 @@ def _create_batch(self, series):
pyarrow.RecordBatch
Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import (
_check_series_convert_timestamps_internal,
_convert_dict_to_map_items,
)
from pandas.api.types import is_categorical_dtype

# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or (
@@ -216,72 +269,7 @@ def create_array(s, t):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
if hasattr(s.array, "__arrow_array__"):
mask = None
else:
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif t is not None and pa.types.is_map(t):
s = _convert_dict_to_map_items(s)
elif t is None and is_categorical_dtype(s.dtype):
s = s.astype(s.dtypes.categories.dtype)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except TypeError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
raise TypeError(error_msg % (s.dtype, s.name, t)) from e
except ValueError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
if self._safecheck:
error_msg = error_msg + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
raise ValueError(error_msg % (s.dtype, s.name, t)) from e
return array

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError(
"A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s))
)

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns):
arrs_names = [
(create_array(s[field.name], field.type), field.name) for field in t
]
# Assign result columns by position
else:
arrs_names = [
# the selected series has name '1', so we rename it to field.name
# as the name is used by create_array to provide a meaningful error message
(create_array(s[s.columns[i]].rename(field.name), field.type), field.name)
for i, field in enumerate(t)
]

struct_arrs, struct_names = zip(*arrs_names)
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

arrs = [self._create_array(s, t) for s, t in series]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

def dump_stream(self, iterator, stream):
@@ -312,9 +300,8 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
super(ArrowStreamPandasUDFSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
)
super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct

def arrow_to_pandas(self, arrow_column):
@@ -334,6 +321,68 @@ def arrow_to_pandas(self, arrow_column):
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
return s

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series pandas.DataFrame
or list of Series or DataFrame, with optional type.
Parameters
----------
series : pandas.Series or pandas.DataFrame or list
A single series or dataframe, list of series or dataframe,
or list of (series or dataframe, arrow_type)
Returns
-------
pyarrow.RecordBatch
Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa

# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or (
len(series) == 2 and isinstance(series[1], pa.DataType)
):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
"A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s))
)

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns):
arrs_names = [
(self._create_array(s[field.name], field.type), field.name) for field in t
]
# Assign result columns by position
else:
arrs_names = [
# the selected series has name '1', so we rename it to field.name
# as the name is used by _create_array to provide a meaningful error message
(
self._create_array(s[s.columns[i]].rename(field.name), field.type),
field.name,
)
for i, field in enumerate(t)
]

struct_arrs, struct_names = zip(*arrs_names)
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(self._create_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
26 changes: 26 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -654,6 +654,32 @@ def check_createDataFrame_with_map_type(self, arrow_enabled):
i, m = row
self.assertEqual(m, map_data[i])

def test_createDataFrame_with_struct_type(self):
for arrow_enabled in [True, False]:
with self.subTest(arrow_enabled=arrow_enabled):
self.check_createDataFrame_with_struct_type(arrow_enabled)

def check_createDataFrame_with_struct_type(self, arrow_enabled):
pdf = pd.DataFrame(
{"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]}
)
for schema in (
"a struct<x int, y string>, b struct<s int, t string>",
StructType()
.add("a", StructType().add("x", LongType()).add("y", StringType()))
.add("b", StructType().add("s", LongType()).add("t", StringType())),
):
with self.subTest(schema=schema):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
df = self.spark.createDataFrame(pdf, schema)
result = df.collect()
expected = [(rec[0], Row(**rec[1])) for rec in pdf.to_records(index=False)]
for r in range(len(expected)):
for e in range(len(expected[r])):
self.assertTrue(
expected[r][e] == result[r][e], f"{expected[r][e]} == {result[r][e]}"
)

def test_createDataFrame_with_string_dtype(self):
# SPARK-34521: spark.createDataFrame does not support Pandas StringDtype extension type
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):