diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 3de425505405..78d4e0fc1c4f 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -43,6 +43,7 @@
     Dict,
     Set,
     NoReturn,
+    Mapping,
     cast,
     TYPE_CHECKING,
     Type,
@@ -1576,6 +1577,10 @@ def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
         configs = dict(self.config(op).pairs)
         return tuple(configs.get(key) for key in keys)
 
+    def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
+        op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys))
+        return dict(self.config(op).pairs)
+
     def get_config_with_defaults(
         self, *pairs: Tuple[str, Optional[str]]
     ) -> Tuple[Optional[str], ...]:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 83b0496a8427..bfd79092ccf4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 check_dependencies(__name__)
 
@@ -37,6 +36,7 @@
     cast,
     overload,
     Iterable,
+    Mapping,
     TYPE_CHECKING,
     ClassVar,
 )
@@ -407,7 +407,10 @@ def clearProgressHandlers(self) -> None:
     clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__
 
     def _inferSchemaFromList(
-        self, data: Iterable[Any], names: Optional[List[str]] = None
+        self,
+        data: Iterable[Any],
+        names: Optional[List[str]],
+        configs: Mapping[str, Optional[str]],
     ) -> StructType:
         """
         Infer schema from list of Row, dict, or tuple.
@@ -422,12 +425,12 @@ def _inferSchemaFromList(
             infer_dict_as_struct,
             infer_array_from_first_element,
             infer_map_from_first_pair,
-            prefer_timestamp_ntz,
-        ) = self._client.get_configs(
-            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
-            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
-            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
-            "spark.sql.timestampType",
+            prefer_timestamp,
+        ) = (
+            configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
+            configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
+            configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
+            configs["spark.sql.timestampType"],
         )
         return functools.reduce(
             _merge_type,
@@ -438,7 +441,7 @@ def _inferSchemaFromList(
                     infer_dict_as_struct=(infer_dict_as_struct == "true"),
                     infer_array_from_first_element=(infer_array_from_first_element == "true"),
                     infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
-                    prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
+                    prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"),
                 )
                 for row in data
             ),
@@ -508,8 +511,21 @@ def createDataFrame(
                 messageParameters={},
             )
 
+        # Get all related configs in a batch
+        configs = self._client.get_config_dict(
+            "spark.sql.timestampType",
+            "spark.sql.session.timeZone",
+            "spark.sql.session.localRelationCacheThreshold",
+            "spark.sql.execution.pandas.convertToArrowArraySafely",
+            "spark.sql.execution.pandas.inferPandasDictAsMap",
+            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
+            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+        )
+        timezone = configs["spark.sql.session.timeZone"]
+        prefer_timestamp = configs["spark.sql.timestampType"]
+
         _table: Optional[pa.Table] = None
-        timezone: Optional[str] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -519,8 +535,7 @@ def createDataFrame(
             if schema is None:
                 _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
                 infer_pandas_dict_as_map = (
-                    str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
-                    == "true"
+                    configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
                 )
                 if infer_pandas_dict_as_map:
                     struct = StructType()
@@ -572,9 +587,7 @@ def createDataFrame(
             ]
             arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
 
-            timezone, safecheck = self._client.get_configs(
-                "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
-            )
+            safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
 
             ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
 
@@ -596,10 +609,6 @@ def createDataFrame(
             ).cast(arrow_schema)
 
         elif isinstance(data, pa.Table):
-            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-
-            (timezone,) = self._client.get_configs("spark.sql.session.timeZone")
-
             # If no schema supplied by user then get the names of columns only
             if schema is None:
                 _cols = data.column_names
@@ -609,7 +618,9 @@ def createDataFrame(
                 _num_cols = len(_cols)
 
             if not isinstance(schema, StructType):
-                schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)
+                schema = from_arrow_schema(
+                    data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ"
+                )
 
             _table = (
                 _check_arrow_table_timestamps_localize(data, schema, True, timezone)
@@ -671,7 +682,7 @@ def createDataFrame(
                 if not isinstance(_schema, StructType):
                     _schema = StructType().add("value", _schema)
             else:
-                _schema = self._inferSchemaFromList(_data, _cols)
+                _schema = self._inferSchemaFromList(_data, _cols, configs)
 
                 if _cols is not None and cast(int, _num_cols) < len(_cols):
                     _num_cols = len(_cols)
@@ -706,9 +717,9 @@ def createDataFrame(
         else:
             local_relation = LocalRelation(_table)
 
-        cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
+        cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
         plan: LogicalPlan = local_relation
-        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+        if cache_threshold is not None and int(cache_threshold) <= _table.nbytes:
            plan = CachedLocalRelation(self._cache_local_relation(local_relation))
 
         df = DataFrame(plan, self)
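For illustration only, not part of the patch: a minimal sketch of how the new batched lookup added above is meant to be exercised, assuming a reachable Spark Connect server at sc://localhost and using the internal `_client` attribute that the diff itself relies on; the server URL and key selection here are placeholders.

    from pyspark.sql import SparkSession

    # Hypothetical usage sketch: one ConfigRequest round trip fetches every key,
    # instead of a separate get_configs() RPC per call site inside createDataFrame().
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    configs = spark._client.get_config_dict(
        "spark.sql.session.timeZone",
        "spark.sql.timestampType",
    )
    timezone = configs["spark.sql.session.timeZone"]
    prefer_timestamp_ntz = configs["spark.sql.timestampType"] == "TIMESTAMP_NTZ"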