[SPARK-43473][PYTHON] Support struct type in createDataFrame from pandas DataFrame #41149

Closed
wants to merge 8 commits
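With this change, `createDataFrame` accepts a pandas DataFrame whose columns hold struct values (`Row` objects or dicts), with Arrow conversion enabled or disabled. A minimal usage sketch, mirroring the new test added below (the session setup is illustrative):

```python
import pandas as pd
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()

# Columns of Row objects or dicts map to struct fields once a struct schema is supplied.
pdf = pd.DataFrame(
    {"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]}
)
df = spark.createDataFrame(pdf, "a struct<x int, y string>, b struct<s int, t string>")
df.show()
```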
4 changes: 1 addition & 3 deletions python/pyspark/sql/connect/session.py
@@ -353,9 +353,7 @@ def createDataFrame(
"spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
)

ser = ArrowStreamPandasSerializer(
cast(str, timezone), safecheck == "true", assign_cols_by_name=True
)
ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")

_table = pa.Table.from_batches(
[ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
3 changes: 1 addition & 2 deletions python/pyspark/sql/pandas/conversion.py
@@ -533,8 +533,7 @@ def _create_from_pandas_with_arrow(
jsparkSession = self._jsparkSession

safecheck = self._jconf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
ser = ArrowStreamPandasSerializer(timezone, safecheck)

@no_type_check
def reader_func(temp_filename):
203 changes: 126 additions & 77 deletions python/pyspark/sql/pandas/serializers.py
@@ -19,6 +19,7 @@
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
@@ -161,11 +162,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
If True, then Pandas DataFrames will get columns by name
"""

def __init__(self, timezone, safecheck, assign_cols_by_name):
def __init__(self, timezone, safecheck):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.pandas.types import (
@@ -186,6 +186,65 @@ def arrow_to_pandas(self, arrow_column):
else:
return s

def _create_array(self, series, arrow_type):
"""
Create an Arrow Array from the given pandas.Series and optional type.

Parameters
----------
series : pandas.Series
A single series
arrow_type : pyarrow.DataType, optional
If None, pyarrow's inferred type will be used

Returns
-------
pyarrow.Array
"""
import pyarrow as pa
from pyspark.sql.pandas.types import (
_check_series_convert_timestamps_internal,
_convert_dict_to_map_items,
)
from pandas.api.types import is_categorical_dtype

if hasattr(series.array, "__arrow_array__"):
mask = None
else:
mask = series.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if (
arrow_type is not None
and pa.types.is_timestamp(arrow_type)
and arrow_type.tz is not None
):
series = _check_series_convert_timestamps_internal(series, self._timezone)
elif arrow_type is not None and pa.types.is_map(arrow_type):
series = _convert_dict_to_map_items(series)
elif arrow_type is None and is_categorical_dtype(series.dtype):
series = series.astype(series.dtypes.categories.dtype)
try:
return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck)
except TypeError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
except ValueError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
if self._safecheck:
error_msg = error_msg + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series,
@@ -201,13 +260,7 @@ def _create_batch(self, series):
pyarrow.RecordBatch
Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import (
_check_series_convert_timestamps_internal,
_convert_dict_to_map_items,
)
from pandas.api.types import is_categorical_dtype

# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or (
@@ -216,72 +269,7 @@
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
if hasattr(s.array, "__arrow_array__"):
mask = None
else:
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif t is not None and pa.types.is_map(t):
s = _convert_dict_to_map_items(s)
elif t is None and is_categorical_dtype(s.dtype):
s = s.astype(s.dtypes.categories.dtype)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except TypeError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
raise TypeError(error_msg % (s.dtype, s.name, t)) from e
except ValueError as e:
error_msg = (
"Exception thrown when converting pandas.Series (%s) "
"with name '%s' to Arrow Array (%s)."
)
if self._safecheck:
error_msg = error_msg + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
raise ValueError(error_msg % (s.dtype, s.name, t)) from e
return array

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError(
"A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s))
)

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns):
arrs_names = [
(create_array(s[field.name], field.type), field.name) for field in t
]
# Assign result columns by position
else:
arrs_names = [
# the selected series has name '1', so we rename it to field.name
# as the name is used by create_array to provide a meaningful error message
(create_array(s[s.columns[i]].rename(field.name), field.type), field.name)
for i, field in enumerate(t)
]

struct_arrs, struct_names = zip(*arrs_names)
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

arrs = [self._create_array(s, t) for s, t in series]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

def dump_stream(self, iterator, stream):
@@ -312,9 +300,8 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
super(ArrowStreamPandasUDFSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
)
super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct

def arrow_to_pandas(self, arrow_column):
@@ -334,6 +321,68 @@ def arrow_to_pandas(self, arrow_column):
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
return s

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series, pandas.DataFrame,
or list of Series or DataFrame, with optional type.

Parameters
----------
series : pandas.Series or pandas.DataFrame or list
A single series or dataframe, list of series or dataframe,
or list of (series or dataframe, arrow_type)

Returns
-------
pyarrow.RecordBatch
Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa

# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or (
len(series) == 2 and isinstance(series[1], pa.DataType)
):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
"A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s))
)

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns):
arrs_names = [
(self._create_array(s[field.name], field.type), field.name) for field in t
]
# Assign result columns by position
else:
arrs_names = [
# the selected series has name '1', so we rename it to field.name
# as the name is used by _create_array to provide a meaningful error message
(
self._create_array(s[s.columns[i]].rename(field.name), field.type),
field.name,
)
for i, field in enumerate(t)
]

struct_arrs, struct_names = zip(*arrs_names)
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(self._create_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
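For reference, the simplified base-serializer path above leans on PyArrow's ability to build a struct array directly from a pandas Series of dict (or Row-like) values when a struct type is supplied. A minimal standalone sketch with illustrative data (not part of the diff):

```python
import pandas as pd
import pyarrow as pa

# What ArrowStreamPandasSerializer._create_array effectively does for a struct column:
# hand the Series and the Arrow struct type to pyarrow and let it perform the conversion.
series = pd.Series([{"s": 3, "t": "x"}, {"s": 4, "t": "y"}])
struct_type = pa.struct([("s", pa.int64()), ("t", pa.string())])

arr = pa.Array.from_pandas(series, mask=series.isnull(), type=struct_type, safe=True)
print(arr.type)       # struct<s: int64, t: string>
print(arr.to_pylist())
```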
26 changes: 26 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -654,6 +654,32 @@ def check_createDataFrame_with_map_type(self, arrow_enabled):
i, m = row
self.assertEqual(m, map_data[i])

def test_createDataFrame_with_struct_type(self):
for arrow_enabled in [True, False]:
with self.subTest(arrow_enabled=arrow_enabled):
self.check_createDataFrame_with_struct_type(arrow_enabled)

def check_createDataFrame_with_struct_type(self, arrow_enabled):
pdf = pd.DataFrame(
{"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]}
)
for schema in (
"a struct<x int, y string>, b struct<s int, t string>",
StructType()
.add("a", StructType().add("x", LongType()).add("y", StringType()))
.add("b", StructType().add("s", LongType()).add("t", StringType())),
):
with self.subTest(schema=schema):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
df = self.spark.createDataFrame(pdf, schema)
result = df.collect()
expected = [(rec[0], Row(**rec[1])) for rec in pdf.to_records(index=False)]
for r in range(len(expected)):
for e in range(len(expected[r])):
self.assertTrue(
expected[r][e] == result[r][e], f"{expected[r][e]} == {result[r][e]}"
)

def test_createDataFrame_with_string_dtype(self):
# SPARK-34521: spark.createDataFrame does not support Pandas StringDtype extension type
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):