diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index a4ccf4da6e8a..e494b14eda62 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -40,7 +40,6 @@
     _create_row,
     StringType,
 )
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.errors import PySparkTypeError, PySparkValueError
 
@@ -395,7 +394,28 @@ def createDataFrame(  # type: ignore[misc]
 
         assert isinstance(self, SparkSession)
 
-        timezone = self._jconf.sessionLocalTimeZone()
+        (
+            timestampType,
+            sessionLocalTimeZone,
+            arrowPySparkEnabled,
+            arrowUseLargeVarTypes,
+            arrowPySparkFallbackEnabled,
+            arrowMaxRecordsPerBatch,
+        ) = self._jconf.getConfs(
+            [
+                "spark.sql.timestampType",
+                "spark.sql.session.timeZone",
+                "spark.sql.execution.arrow.pyspark.enabled",
+                "spark.sql.execution.arrow.useLargeVarTypes",
+                "spark.sql.execution.arrow.pyspark.fallback.enabled",
+                "spark.sql.execution.arrow.maxRecordsPerBatch",
+            ]
+        )
+
+        prefer_timestamp_ntz = timestampType == "TIMESTAMP_NTZ"
+        prefers_large_var_types = arrowUseLargeVarTypes == "true"
+        timezone = sessionLocalTimeZone
+        arrow_batch_size = int(arrowMaxRecordsPerBatch)
 
         if type(data).__name__ == "Table":
             # `data` is a PyArrow Table
@@ -411,7 +431,7 @@ def createDataFrame(  # type: ignore[misc]
             if schema is None:
                 schema = data.schema.names
 
-            return self._create_from_arrow_table(data, schema, timezone)
+            return self._create_from_arrow_table(data, schema, timezone, prefer_timestamp_ntz)
 
         # `data` is a PandasDataFrameLike object
         from pyspark.sql.pandas.utils import require_minimum_pandas_version
@@ -422,11 +442,18 @@ def createDataFrame(  # type: ignore[misc]
         if schema is None:
             schema = [str(x) if not isinstance(x, str) else x for x in data.columns]
 
-        if self._jconf.arrowPySparkEnabled() and len(data) > 0:
+        if arrowPySparkEnabled == "true" and len(data) > 0:
             try:
-                return self._create_from_pandas_with_arrow(data, schema, timezone)
+                return self._create_from_pandas_with_arrow(
+                    data,
+                    schema,
+                    timezone,
+                    prefer_timestamp_ntz,
+                    prefers_large_var_types,
+                    arrow_batch_size,
+                )
             except Exception as e:
-                if self._jconf.arrowPySparkFallbackEnabled():
+                if arrowPySparkFallbackEnabled == "true":
                     msg = (
                         "createDataFrame attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
@@ -446,11 +473,15 @@ def createDataFrame(  # type: ignore[misc]
                     )
                     warn(msg)
                     raise
-        converted_data = self._convert_from_pandas(data, schema, timezone)
+        converted_data = self._convert_from_pandas(data, schema, timezone, prefer_timestamp_ntz)
         return self._create_dataframe(converted_data, schema, samplingRatio, verifySchema)
 
     def _convert_from_pandas(
-        self, pdf: "PandasDataFrameLike", schema: Union[StructType, str, List[str]], timezone: str
+        self,
+        pdf: "PandasDataFrameLike",
+        schema: Union[StructType, str, List[str]],
+        timezone: str,
+        prefer_timestamp_ntz: bool,
     ) -> List:
         """
         Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
@@ -566,7 +597,7 @@ def convert_timestamp(value: Any) -> Any:
                         )
                         copied = True
             else:
-                should_localize = not is_timestamp_ntz_preferred()
+                should_localize = not prefer_timestamp_ntz
                 for column, series in pdf.items():
                     s = series
                     if (
@@ -643,7 +674,13 @@ def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
         return np.dtype(record_type_list) if has_rec_fix else None
 
     def _create_from_pandas_with_arrow(
-        self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
+        self,
+        pdf: "PandasDataFrameLike",
+        schema: Union[StructType, List[str]],
+        timezone: str,
+        prefer_timestamp_ntz: bool,
+        prefers_large_var_types: bool,
+        arrow_batch_size: int,
     ) -> "DataFrame":
         """
         Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
@@ -683,7 +720,6 @@ def _create_from_pandas_with_arrow(
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
             arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
-            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
             struct = StructType()
             if infer_pandas_dict_as_map:
                 spark_type: Union[MapType, DataType]
@@ -729,12 +765,11 @@ def _create_from_pandas_with_arrow(
         ]
 
         # Slice the DataFrame to be batched
-        step = self._jconf.arrowMaxRecordsPerBatch()
+        step = arrow_batch_size
         step = step if step > 0 else len(pdf)
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
 
         # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
-        prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
         arrow_data = [
             [
                 (
@@ -771,7 +806,11 @@ def create_iter_server():
         return df
 
     def _create_from_arrow_table(
-        self, table: "pa.Table", schema: Union[StructType, List[str]], timezone: str
+        self,
+        table: "pa.Table",
+        schema: Union[StructType, List[str]],
+        timezone: str,
+        prefer_timestamp_ntz: bool,
     ) -> "DataFrame":
         """
         Create a DataFrame from a given pyarrow.Table by slicing it into partitions then
@@ -793,8 +832,6 @@ def _create_from_arrow_table(
 
         require_minimum_pyarrow_version()
 
-        prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
             table = table.rename_columns(schema)