From ab91dcb26feaa44daf4c4c4ae7c7a5d91adf8021 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 19 Nov 2024 10:23:23 +0800 Subject: [PATCH 1/4] resolve conflicts --- python/pyspark/sql/connect/session.py | 59 +++++++++++++++++++-------- python/pyspark/sql/connect/utils.py | 28 +++++++++++++ 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 83b0496a8427..a609f2ceaef4 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -15,7 +15,6 @@ # limitations under the License. # from pyspark.sql.connect.utils import check_dependencies -from pyspark.sql.utils import is_timestamp_ntz_preferred check_dependencies(__name__) @@ -112,6 +111,7 @@ from pyspark.sql.connect.tvf import TableValuedFunction from pyspark.sql.connect.shell.progress import ProgressHandler from pyspark.sql.connect.datasource import DataSourceRegistration + from pyspark.sql.connect.utils import LazyConfigGetter try: import memory_profiler # noqa: F401 @@ -407,7 +407,10 @@ def clearProgressHandlers(self) -> None: clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__ def _inferSchemaFromList( - self, data: Iterable[Any], names: Optional[List[str]] = None + self, + data: Iterable[Any], + names: Optional[List[str]], + conf_getter: "LazyConfigGetter", ) -> StructType: """ Infer schema from list of Row, dict, or tuple. @@ -422,12 +425,12 @@ def _inferSchemaFromList( infer_dict_as_struct, infer_array_from_first_element, infer_map_from_first_pair, - prefer_timestamp_ntz, - ) = self._client.get_configs( - "spark.sql.pyspark.inferNestedDictAsStruct.enabled", - "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", - "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", - "spark.sql.timestampType", + prefer_timestamp, + ) = ( + conf_getter["spark.sql.pyspark.inferNestedDictAsStruct.enabled"], + conf_getter["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"], + conf_getter["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"], + conf_getter["spark.sql.timestampType"], ) return functools.reduce( _merge_type, @@ -438,7 +441,7 @@ def _inferSchemaFromList( infer_dict_as_struct=(infer_dict_as_struct == "true"), infer_array_from_first_element=(infer_array_from_first_element == "true"), infer_map_from_first_pair=(infer_map_from_first_pair == "true"), - prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"), + prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"), ) for row in data ), @@ -451,6 +454,22 @@ def createDataFrame( samplingRatio: Optional[float] = None, verifySchema: Optional[bool] = None, ) -> "ParentDataFrame": + from pyspark.sql.connect.utils import LazyConfigGetter + + conf_getter = LazyConfigGetter( + keys=[ + "spark.sql.timestampType", + "spark.sql.session.timeZone", + "spark.sql.session.localRelationCacheThreshold", + "spark.sql.execution.pandas.convertToArrowArraySafely", + "spark.sql.execution.pandas.inferPandasDictAsMap", + "spark.sql.pyspark.inferNestedDictAsStruct.enabled", + "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", + "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", + ], + session=self, + ) + assert data is not None if isinstance(data, DataFrame): raise PySparkTypeError( @@ -519,8 +538,7 @@ def createDataFrame( if schema is None: _cols = [str(x) if not isinstance(x, str) else x for x in data.columns] infer_pandas_dict_as_map = ( - 
str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower() - == "true" + conf_getter["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true" ) if infer_pandas_dict_as_map: struct = StructType() @@ -572,9 +590,8 @@ def createDataFrame( ] arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types] - timezone, safecheck = self._client.get_configs( - "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely" - ) + timezone = conf_getter["spark.sql.session.timeZone"] + safecheck = conf_getter["spark.sql.execution.pandas.convertToArrowArraySafely"] ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true") @@ -598,6 +615,10 @@ def createDataFrame( elif isinstance(data, pa.Table): prefer_timestamp_ntz = is_timestamp_ntz_preferred() + + timezone = conf_getter["spark.sql.session.timeZone"] + prefer_timestamp = conf_getter["spark.sql.timestampType"] + (timezone,) = self._client.get_configs("spark.sql.session.timeZone") # If no schema supplied by user then get the names of columns only @@ -609,7 +630,9 @@ def createDataFrame( _num_cols = len(_cols) if not isinstance(schema, StructType): - schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz) + schema = from_arrow_schema( + data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ" + ) _table = ( _check_arrow_table_timestamps_localize(data, schema, True, timezone) @@ -671,7 +694,7 @@ def createDataFrame( if not isinstance(_schema, StructType): _schema = StructType().add("value", _schema) else: - _schema = self._inferSchemaFromList(_data, _cols) + _schema = self._inferSchemaFromList(_data, _cols, conf_getter) if _cols is not None and cast(int, _num_cols) < len(_cols): _num_cols = len(_cols) @@ -706,9 +729,9 @@ def createDataFrame( else: local_relation = LocalRelation(_table) - cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold") + cache_threshold = conf_getter["spark.sql.session.localRelationCacheThreshold"] plan: LogicalPlan = local_relation - if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes: + if cache_threshold is not None and int(cache_threshold) <= _table.nbytes: plan = CachedLocalRelation(self._cache_local_relation(local_relation)) df = DataFrame(plan, self) diff --git a/python/pyspark/sql/connect/utils.py b/python/pyspark/sql/connect/utils.py index a2511836816c..5d5eb86eb4ae 100644 --- a/python/pyspark/sql/connect/utils.py +++ b/python/pyspark/sql/connect/utils.py @@ -15,6 +15,10 @@ # limitations under the License. 
# import sys +from typing import Optional, Sequence, Dict, TYPE_CHECKING + +if TYPE_CHECKING: + from pyspark.sql.connect.session import SparkSession from pyspark.loose_version import LooseVersion from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -98,3 +102,27 @@ def require_minimum_googleapis_common_protos_version() -> None: def get_python_ver() -> str: return "%d.%d" % sys.version_info[:2] + + +class LazyConfigGetter: + def __init__( + self, + keys: Sequence[str], + session: "SparkSession", + ): + assert len(keys) > 0 and len(keys) == len(set(keys)) + assert all(isinstance(key, str) for key in keys) + assert session is not None + self._keys = keys + self._session = session + self._values: Dict[str, Optional[str]] = {} + + def __getitem__(self, key: str) -> Optional[str]: + assert key in self._keys + + if len(self._values) == 0: + values = self._session._client.get_configs(*self._keys) + for i, value in enumerate(values): + self._values[self._keys[i]] = value + + return self._values[key] From 885d3487e50394a838457c1494991b9dd665a81d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 22 Nov 2024 11:00:53 +0800 Subject: [PATCH 2/4] simplify --- python/pyspark/sql/connect/client/core.py | 5 ++ python/pyspark/sql/connect/session.py | 57 ++++++++++------------- python/pyspark/sql/connect/utils.py | 28 ----------- 3 files changed, 29 insertions(+), 61 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 3de425505405..78d4e0fc1c4f 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -43,6 +43,7 @@ Dict, Set, NoReturn, + Mapping, cast, TYPE_CHECKING, Type, @@ -1576,6 +1577,10 @@ def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: configs = dict(self.config(op).pairs) return tuple(configs.get(key) for key in keys) + def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]: + op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys)) + return dict(self.config(op).pairs) + def get_config_with_defaults( self, *pairs: Tuple[str, Optional[str]] ) -> Tuple[Optional[str], ...]: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index a609f2ceaef4..b34831878dca 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -36,6 +36,7 @@ cast, overload, Iterable, + Mapping, TYPE_CHECKING, ClassVar, ) @@ -111,7 +112,6 @@ from pyspark.sql.connect.tvf import TableValuedFunction from pyspark.sql.connect.shell.progress import ProgressHandler from pyspark.sql.connect.datasource import DataSourceRegistration - from pyspark.sql.connect.utils import LazyConfigGetter try: import memory_profiler # noqa: F401 @@ -410,7 +410,7 @@ def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]], - conf_getter: "LazyConfigGetter", + configs: Mapping[str, Optional[str]], ) -> StructType: """ Infer schema from list of Row, dict, or tuple. 
@@ -427,10 +427,10 @@ def _inferSchemaFromList( infer_map_from_first_pair, prefer_timestamp, ) = ( - conf_getter["spark.sql.pyspark.inferNestedDictAsStruct.enabled"], - conf_getter["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"], - conf_getter["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"], - conf_getter["spark.sql.timestampType"], + configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"], + configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"], + configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"], + configs["spark.sql.timestampType"], ) return functools.reduce( _merge_type, @@ -454,22 +454,6 @@ def createDataFrame( samplingRatio: Optional[float] = None, verifySchema: Optional[bool] = None, ) -> "ParentDataFrame": - from pyspark.sql.connect.utils import LazyConfigGetter - - conf_getter = LazyConfigGetter( - keys=[ - "spark.sql.timestampType", - "spark.sql.session.timeZone", - "spark.sql.session.localRelationCacheThreshold", - "spark.sql.execution.pandas.convertToArrowArraySafely", - "spark.sql.execution.pandas.inferPandasDictAsMap", - "spark.sql.pyspark.inferNestedDictAsStruct.enabled", - "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", - "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", - ], - session=self, - ) - assert data is not None if isinstance(data, DataFrame): raise PySparkTypeError( @@ -527,8 +511,21 @@ def createDataFrame( messageParameters={}, ) + # Get all related configs in a batch + configs = self._client.get_config_dict( + "spark.sql.timestampType", + "spark.sql.session.timeZone", + "spark.sql.session.localRelationCacheThreshold", + "spark.sql.execution.pandas.convertToArrowArraySafely", + "spark.sql.execution.pandas.inferPandasDictAsMap", + "spark.sql.pyspark.inferNestedDictAsStruct.enabled", + "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", + "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", + ) + _table: Optional[pa.Table] = None - timezone: Optional[str] = None + timezone: Optional[str] = configs["spark.sql.session.timeZone"] + prefer_timestamp = configs["spark.sql.timestampType"] if isinstance(data, pd.DataFrame): # Logic was borrowed from `_create_from_pandas_with_arrow` in @@ -538,7 +535,7 @@ def createDataFrame( if schema is None: _cols = [str(x) if not isinstance(x, str) else x for x in data.columns] infer_pandas_dict_as_map = ( - conf_getter["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true" + configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true" ) if infer_pandas_dict_as_map: struct = StructType() @@ -590,8 +587,7 @@ def createDataFrame( ] arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types] - timezone = conf_getter["spark.sql.session.timeZone"] - safecheck = conf_getter["spark.sql.execution.pandas.convertToArrowArraySafely"] + safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true") @@ -616,11 +612,6 @@ def createDataFrame( prefer_timestamp_ntz = is_timestamp_ntz_preferred() - timezone = conf_getter["spark.sql.session.timeZone"] - prefer_timestamp = conf_getter["spark.sql.timestampType"] - - (timezone,) = self._client.get_configs("spark.sql.session.timeZone") - # If no schema supplied by user then get the names of columns only if schema is None: _cols = data.column_names @@ -694,7 +685,7 @@ def createDataFrame( if not isinstance(_schema, StructType): _schema = 
StructType().add("value", _schema) else: - _schema = self._inferSchemaFromList(_data, _cols, conf_getter) + _schema = self._inferSchemaFromList(_data, _cols, configs) if _cols is not None and cast(int, _num_cols) < len(_cols): _num_cols = len(_cols) @@ -729,7 +720,7 @@ def createDataFrame( else: local_relation = LocalRelation(_table) - cache_threshold = conf_getter["spark.sql.session.localRelationCacheThreshold"] + cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"] plan: LogicalPlan = local_relation if cache_threshold is not None and int(cache_threshold) <= _table.nbytes: plan = CachedLocalRelation(self._cache_local_relation(local_relation)) diff --git a/python/pyspark/sql/connect/utils.py b/python/pyspark/sql/connect/utils.py index 5d5eb86eb4ae..a2511836816c 100644 --- a/python/pyspark/sql/connect/utils.py +++ b/python/pyspark/sql/connect/utils.py @@ -15,10 +15,6 @@ # limitations under the License. # import sys -from typing import Optional, Sequence, Dict, TYPE_CHECKING - -if TYPE_CHECKING: - from pyspark.sql.connect.session import SparkSession from pyspark.loose_version import LooseVersion from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -102,27 +98,3 @@ def require_minimum_googleapis_common_protos_version() -> None: def get_python_ver() -> str: return "%d.%d" % sys.version_info[:2] - - -class LazyConfigGetter: - def __init__( - self, - keys: Sequence[str], - session: "SparkSession", - ): - assert len(keys) > 0 and len(keys) == len(set(keys)) - assert all(isinstance(key, str) for key in keys) - assert session is not None - self._keys = keys - self._session = session - self._values: Dict[str, Optional[str]] = {} - - def __getitem__(self, key: str) -> Optional[str]: - assert key in self._keys - - if len(self._values) == 0: - values = self._session._client.get_configs(*self._keys) - for i, value in enumerate(values): - self._values[self._keys[i]] = value - - return self._values[key] From ad9686b0fb3bbb0086414aa706360d4e1ba5e642 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 22 Nov 2024 11:02:53 +0800 Subject: [PATCH 3/4] nit --- python/pyspark/sql/connect/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index b34831878dca..ef56ea369d1d 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -522,10 +522,10 @@ def createDataFrame( "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", ) + timezone = configs["spark.sql.session.timeZone"] + prefer_timestamp = configs["spark.sql.timestampType"] _table: Optional[pa.Table] = None - timezone: Optional[str] = configs["spark.sql.session.timeZone"] - prefer_timestamp = configs["spark.sql.timestampType"] if isinstance(data, pd.DataFrame): # Logic was borrowed from `_create_from_pandas_with_arrow` in From 1b02bf49c57265b7f22a81868dc51a5775eab7ee Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 22 Nov 2024 15:25:27 +0800 Subject: [PATCH 4/4] fix --- python/pyspark/sql/connect/session.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index ef56ea369d1d..bfd79092ccf4 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -609,9 +609,6 @@ def createDataFrame( ).cast(arrow_schema) elif isinstance(data, pa.Table): - 
prefer_timestamp_ntz = is_timestamp_ntz_preferred() - - # If no schema supplied by user then get the names of columns only if schema is None: _cols = data.column_names
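
Note on the end state of this series (illustrative, not part of the patch itself): from patch 2 onward, `createDataFrame` reads every config it needs through a single `SparkConnectClient.get_config_dict` call and threads the resulting mapping into helpers such as `_inferSchemaFromList`, instead of issuing several separate `get_configs` / `conf.get` lookups. The sketch below is a minimal usage example under the assumption of an active Spark Connect session bound to the name `spark`; `_client` is an internal attribute, and only keys that the patch itself batches are shown.

    # Before this series: each helper fetched its own configs, so one
    # createDataFrame call could trigger several ConfigRequest round trips.
    timezone, safecheck = spark._client.get_configs(
        "spark.sql.session.timeZone",
        "spark.sql.execution.pandas.convertToArrowArraySafely",
    )

    # After this series: one ConfigRequest returning a Mapping[str, Optional[str]],
    # consumed the same way the patch does inside createDataFrame.
    configs = spark._client.get_config_dict(
        "spark.sql.timestampType",
        "spark.sql.session.timeZone",
        "spark.sql.execution.pandas.convertToArrowArraySafely",
    )
    prefer_timestamp_ntz = configs["spark.sql.timestampType"] == "TIMESTAMP_NTZ"
    safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] == "true"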