diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index c23b6c5d11a9f..ee3ab1702c43e 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -353,9 +353,7 @@ def createDataFrame( "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely" ) - ser = ArrowStreamPandasSerializer( - cast(str, timezone), safecheck == "true", assign_cols_by_name=True - ) + ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true") _table = pa.Table.from_batches( [ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])] diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 59b7a8524e5ed..089fd10a45d44 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -533,8 +533,7 @@ def _create_from_pandas_with_arrow( jsparkSession = self._jsparkSession safecheck = self._jconf.arrowSafeTypeConversion() - col_by_name = True # col by name only applies to StructType columns, can't happen here - ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) + ser = ArrowStreamPandasSerializer(timezone, safecheck) @no_type_check def reader_func(temp_filename): diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 30c65ce42ba50..e86004884bcda 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,6 +19,7 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. """ +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType @@ -161,11 +162,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): If True, then Pandas DataFrames will get columns by name """ - def __init__(self, timezone, safecheck, assign_cols_by_name): + def __init__(self, timezone, safecheck): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone self._safecheck = safecheck - self._assign_cols_by_name = assign_cols_by_name def arrow_to_pandas(self, arrow_column): from pyspark.sql.pandas.types import ( @@ -186,6 +186,65 @@ def arrow_to_pandas(self, arrow_column): else: return s + def _create_array(self, series, arrow_type): + """ + Create an Arrow Array from the given pandas.Series and optional type. 
+ + Parameters + ---------- + series : pandas.Series + A single series + arrow_type : pyarrow.DataType, optional + If None, pyarrow's inferred type will be used + + Returns + ------- + pyarrow.Array + """ + import pyarrow as pa + from pyspark.sql.pandas.types import ( + _check_series_convert_timestamps_internal, + _convert_dict_to_map_items, + ) + from pandas.api.types import is_categorical_dtype + + if hasattr(series.array, "__arrow_array__"): + mask = None + else: + mask = series.isnull() + # Ensure timestamp series are in expected form for Spark internal representation + if ( + arrow_type is not None + and pa.types.is_timestamp(arrow_type) + and arrow_type.tz is not None + ): + series = _check_series_convert_timestamps_internal(series, self._timezone) + elif arrow_type is not None and pa.types.is_map(arrow_type): + series = _convert_dict_to_map_items(series) + elif arrow_type is None and is_categorical_dtype(series.dtype): + series = series.astype(series.dtypes.categories.dtype) + try: + return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck) + except TypeError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e + except ValueError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + if self._safecheck: + error_msg = error_msg + ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + ) + raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e + def _create_batch(self, series): """ Create an Arrow record batch from the given pandas.Series or list of Series, @@ -201,13 +260,7 @@ def _create_batch(self, series): pyarrow.RecordBatch Arrow RecordBatch """ - import pandas as pd import pyarrow as pa - from pyspark.sql.pandas.types import ( - _check_series_convert_timestamps_internal, - _convert_dict_to_map_items, - ) - from pandas.api.types import is_categorical_dtype # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or ( @@ -216,72 +269,7 @@ def _create_batch(self, series): series = [series] series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - def create_array(s, t): - if hasattr(s.array, "__arrow_array__"): - mask = None - else: - mask = s.isnull() - # Ensure timestamp series are in expected form for Spark internal representation - if t is not None and pa.types.is_timestamp(t) and t.tz is not None: - s = _check_series_convert_timestamps_internal(s, self._timezone) - elif t is not None and pa.types.is_map(t): - s = _convert_dict_to_map_items(s) - elif t is None and is_categorical_dtype(s.dtype): - s = s.astype(s.dtypes.categories.dtype) - try: - array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) - except TypeError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - raise TypeError(error_msg % (s.dtype, s.name, t)) from e - except ValueError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." 
- ) - if self._safecheck: - error_msg = error_msg + ( - " It can be caused by overflows or other " - "unsafe conversions warned by Arrow. Arrow safe type check " - "can be disabled by using SQL config " - "`spark.sql.execution.pandas.convertToArrowArraySafely`." - ) - raise ValueError(error_msg % (s.dtype, s.name, t)) from e - return array - - arrs = [] - for s, t in series: - if t is not None and pa.types.is_struct(t): - if not isinstance(s, pd.DataFrame): - raise ValueError( - "A field of type StructType expects a pandas.DataFrame, " - "but got: %s" % str(type(s)) - ) - - # Input partition and result pandas.DataFrame empty, make empty Arrays with struct - if len(s) == 0 and len(s.columns) == 0: - arrs_names = [(pa.array([], type=field.type), field.name) for field in t] - # Assign result columns by schema name if user labeled with strings - elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns): - arrs_names = [ - (create_array(s[field.name], field.type), field.name) for field in t - ] - # Assign result columns by position - else: - arrs_names = [ - # the selected series has name '1', so we rename it to field.name - # as the name is used by create_array to provide a meaningful error message - (create_array(s[s.columns[i]].rename(field.name), field.type), field.name) - for i, field in enumerate(t) - ] - - struct_arrs, struct_names = zip(*arrs_names) - arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) - else: - arrs.append(create_array(s, t)) - + arrs = [self._create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) def dump_stream(self, iterator, stream): @@ -312,9 +300,8 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): """ def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False): - super(ArrowStreamPandasUDFSerializer, self).__init__( - timezone, safecheck, assign_cols_by_name - ) + super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck) + self._assign_cols_by_name = assign_cols_by_name self._df_for_struct = df_for_struct def arrow_to_pandas(self, arrow_column): @@ -334,6 +321,68 @@ def arrow_to_pandas(self, arrow_column): s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column) return s + def _create_batch(self, series): + """ + Create an Arrow record batch from the given pandas.Series pandas.DataFrame + or list of Series or DataFrame, with optional type. + + Parameters + ---------- + series : pandas.Series or pandas.DataFrame or list + A single series or dataframe, list of series or dataframe, + or list of (series or dataframe, arrow_type) + + Returns + ------- + pyarrow.RecordBatch + Arrow RecordBatch + """ + import pandas as pd + import pyarrow as pa + + # Make input conform to [(series1, type1), (series2, type2), ...] 
+        if not isinstance(series, (list, tuple)) or (
+            len(series) == 2 and isinstance(series[1], pa.DataType)
+        ):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+        arrs = []
+        for s, t in series:
+            if t is not None and pa.types.is_struct(t):
+                if not isinstance(s, pd.DataFrame):
+                    raise PySparkValueError(
+                        "A field of type StructType expects a pandas.DataFrame, "
+                        "but got: %s" % str(type(s))
+                    )
+
+                # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
+                if len(s) == 0 and len(s.columns) == 0:
+                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
+                # Assign result columns by schema name if user labeled with strings
+                elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns):
+                    arrs_names = [
+                        (self._create_array(s[field.name], field.type), field.name) for field in t
+                    ]
+                # Assign result columns by position
+                else:
+                    arrs_names = [
+                        # the selected series has name '1', so we rename it to field.name
+                        # as the name is used by _create_array to provide a meaningful error message
+                        (
+                            self._create_array(s[s.columns[i]].rename(field.name), field.type),
+                            field.name,
+                        )
+                        for i, field in enumerate(t)
+                    ]
+
+                struct_arrs, struct_names = zip(*arrs_names)
+                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
+            else:
+                arrs.append(self._create_array(s, t))
+
+        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
+
     def dump_stream(self, iterator, stream):
         """
         Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 37e2ffe5bdaf9..09dc744dd58b2 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -654,6 +654,32 @@ def check_createDataFrame_with_map_type(self, arrow_enabled):
             i, m = row
             self.assertEqual(m, map_data[i])
 
+    def test_createDataFrame_with_struct_type(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_with_struct_type(arrow_enabled)
+
+    def check_createDataFrame_with_struct_type(self, arrow_enabled):
+        pdf = pd.DataFrame(
+            {"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]}
+        )
+        for schema in (
+            "a struct<x: bigint, y: string>, b struct<s: bigint, t: string>",
+            StructType()
+            .add("a", StructType().add("x", LongType()).add("y", StringType()))
+            .add("b", StructType().add("s", LongType()).add("t", StringType())),
+        ):
+            with self.subTest(schema=schema):
+                with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                    df = self.spark.createDataFrame(pdf, schema)
+                    result = df.collect()
+                    expected = [(rec[0], Row(**rec[1])) for rec in pdf.to_records(index=False)]
+                    for r in range(len(expected)):
+                        for e in range(len(expected[r])):
+                            self.assertTrue(
+                                expected[r][e] == result[r][e], f"{expected[r][e]} == {result[r][e]}"
+                            )
+
     def test_createDataFrame_with_string_dtype(self):
         # SPARK-34521: spark.createDataFrame does not support Pandas StringDtype extension type
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
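
Below is a minimal sketch (not part of the patch) of the user-facing behavior the new test_createDataFrame_with_struct_type case exercises: creating a DataFrame from a pandas DataFrame whose columns hold struct values, with Arrow conversion enabled. It assumes an active SparkSession bound to a variable named spark; the data and DDL schema mirror the added test.

    import pandas as pd
    from pyspark.sql import Row

    # Enable the Arrow-based conversion path exercised by this change.
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    # Struct columns may be given either as Row objects or as plain dicts.
    pdf = pd.DataFrame(
        {"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]}
    )
    df = spark.createDataFrame(
        pdf, "a struct<x: bigint, y: string>, b struct<s: bigint, t: string>"
    )
    df.collect()
    # [Row(a=Row(x=1, y='a'), b=Row(s=3, t='x')), Row(a=Row(x=2, y='b'), b=Row(s=4, t='y'))]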