69 changes: 53 additions & 16 deletions python/pyspark/sql/pandas/conversion.py
@@ -40,7 +40,6 @@
_create_row,
StringType,
)
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.errors import PySparkTypeError, PySparkValueError

@@ -395,7 +394,28 @@ def createDataFrame( # type: ignore[misc]

assert isinstance(self, SparkSession)

timezone = self._jconf.sessionLocalTimeZone()
(
timestampType,
sessionLocalTimeZone,
arrowPySparkEnabled,
arrowUseLargeVarTypes,
arrowPySparkFallbackEnabled,
arrowMaxRecordsPerBatch,
) = self._jconf.getConfs(
[
"spark.sql.timestampType",
"spark.sql.session.timeZone",
"spark.sql.execution.arrow.pyspark.enabled",
"spark.sql.execution.arrow.useLargeVarTypes",
"spark.sql.execution.arrow.pyspark.fallback.enabled",
"spark.sql.execution.arrow.maxRecordsPerBatch",
]
)

prefer_timestamp_ntz = timestampType == "TIMESTAMP_NTZ"
prefers_large_var_types = arrowUseLargeVarTypes == "true"
timezone = sessionLocalTimeZone
arrow_batch_size = int(arrowMaxRecordsPerBatch)
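
Note (illustration only, not part of this diff): the block above replaces six separate JVM round trips (one per `self._jconf.xxx()` accessor) with a single batched `getConfs()` call; the values come back as strings and are parsed on the Python side. A rough equivalent using only the public `spark.conf` API, assuming an active SparkSession named `spark`, would look like:

keys = [
    "spark.sql.timestampType",
    "spark.sql.session.timeZone",
    "spark.sql.execution.arrow.pyspark.enabled",
    "spark.sql.execution.arrow.useLargeVarTypes",
    "spark.sql.execution.arrow.pyspark.fallback.enabled",
    "spark.sql.execution.arrow.maxRecordsPerBatch",
]
# Each spark.conf.get() is its own lookup; the diff's getConfs() fetches all
# six values in one call to the JVM.
confs = {key: spark.conf.get(key) for key in keys}
prefer_timestamp_ntz = confs["spark.sql.timestampType"] == "TIMESTAMP_NTZ"
prefers_large_var_types = confs["spark.sql.execution.arrow.useLargeVarTypes"] == "true"
arrow_batch_size = int(confs["spark.sql.execution.arrow.maxRecordsPerBatch"])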

if type(data).__name__ == "Table":
# `data` is a PyArrow Table
@@ -411,7 +431,7 @@ def createDataFrame( # type: ignore[misc]
if schema is None:
schema = data.schema.names

return self._create_from_arrow_table(data, schema, timezone)
return self._create_from_arrow_table(data, schema, timezone, prefer_timestamp_ntz)
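
Usage sketch (not part of this diff; assumes an active SparkSession `spark` and pyarrow installed): passing a `pyarrow.Table` takes the branch above, and the session's timestamp preference now arrives as an explicit argument instead of being re-read inside `_create_from_arrow_table`:

import pyarrow as pa

# Hypothetical input table; its column names serve as the schema when none is given.
table = pa.table({"id": [1, 2, 3], "label": ["a", "b", "c"]})
df = spark.createDataFrame(table)
df.show()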

# `data` is a PandasDataFrameLike object
from pyspark.sql.pandas.utils import require_minimum_pandas_version
@@ -422,11 +442,18 @@ def createDataFrame( # type: ignore[misc]
if schema is None:
schema = [str(x) if not isinstance(x, str) else x for x in data.columns]

if self._jconf.arrowPySparkEnabled() and len(data) > 0:
if arrowPySparkEnabled == "true" and len(data) > 0:
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
return self._create_from_pandas_with_arrow(
data,
schema,
timezone,
prefer_timestamp_ntz,
prefers_large_var_types,
arrow_batch_size,
)
except Exception as e:
if self._jconf.arrowPySparkFallbackEnabled():
if arrowPySparkFallbackEnabled == "true":
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
@@ -446,11 +473,15 @@
)
warn(msg)
raise
converted_data = self._convert_from_pandas(data, schema, timezone)
converted_data = self._convert_from_pandas(data, schema, timezone, prefer_timestamp_ntz)
return self._create_dataframe(converted_data, schema, samplingRatio, verifySchema)
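
Note (illustration only, not part of this diff): the pandas path above attempts the Arrow conversion only when the Arrow flag is enabled and, on failure, either warns and falls back to the plain converter or re-raises, depending on the fallback flag. Both switches are ordinary session configs, assuming an active SparkSession `spark`:

# Enable Arrow-backed conversion and allow falling back to the non-Arrow path
# if the Arrow conversion raises for a given pandas DataFrame.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")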

def _convert_from_pandas(
self, pdf: "PandasDataFrameLike", schema: Union[StructType, str, List[str]], timezone: str
self,
pdf: "PandasDataFrameLike",
schema: Union[StructType, str, List[str]],
timezone: str,
prefer_timestamp_ntz: bool,
) -> List:
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
@@ -566,7 +597,7 @@ def convert_timestamp(value: Any) -> Any:
)
copied = True
else:
should_localize = not is_timestamp_ntz_preferred()
should_localize = not prefer_timestamp_ntz
for column, series in pdf.items():
s = series
if (
@@ -643,7 +674,13 @@ def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
return np.dtype(record_type_list) if has_rec_fix else None

def _create_from_pandas_with_arrow(
self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
self,
pdf: "PandasDataFrameLike",
schema: Union[StructType, List[str]],
timezone: str,
prefer_timestamp_ntz: bool,
prefers_large_var_types: bool,
arrow_batch_size: int,
) -> "DataFrame":
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
@@ -683,7 +720,6 @@ def _create_from_pandas_with_arrow(
# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
struct = StructType()
if infer_pandas_dict_as_map:
spark_type: Union[MapType, DataType]
@@ -729,12 +765,11 @@ def _create_from_pandas_with_arrow(
]

# Slice the DataFrame to be batched
step = self._jconf.arrowMaxRecordsPerBatch()
step = arrow_batch_size
step = step if step > 0 else len(pdf)
pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
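
Illustration (not part of this diff): the slicing above cuts the pandas frame into batches of at most `spark.sql.execution.arrow.maxRecordsPerBatch` rows, and a non-positive setting means one batch for the whole frame. The same arithmetic with plain pandas and a hypothetical batch size:

import pandas as pd

pdf = pd.DataFrame({"x": range(10)})
arrow_batch_size = 4  # hypothetical value of the maxRecordsPerBatch conf
step = arrow_batch_size if arrow_batch_size > 0 else len(pdf)
pdf_slices = [pdf.iloc[start:start + step] for start in range(0, len(pdf), step)]
assert [len(s) for s in pdf_slices] == [4, 4, 2]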

# Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
arrow_data = [
[
(
@@ -771,7 +806,11 @@ def create_iter_server():
return df

def _create_from_arrow_table(
self, table: "pa.Table", schema: Union[StructType, List[str]], timezone: str
self,
table: "pa.Table",
schema: Union[StructType, List[str]],
timezone: str,
prefer_timestamp_ntz: bool,
) -> "DataFrame":
"""
Create a DataFrame from a given pyarrow.Table by slicing it into partitions then
Expand All @@ -793,8 +832,6 @@ def _create_from_arrow_table(

require_minimum_pyarrow_version()

prefer_timestamp_ntz = is_timestamp_ntz_preferred()

# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
table = table.rename_columns(schema)
Expand Down