From 795d01a1fd96fcbe80432d74be50382d350edf57 Mon Sep 17 00:00:00 2001
From: Ian Cook
Date: Mon, 13 May 2024 15:09:22 -0400
Subject: [PATCH] Flatten conditional logic in connect/session.py

---
 python/pyspark/sql/connect/session.py | 175 ++++++++++++++------------
 1 file changed, 92 insertions(+), 83 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index f03a4e9938164..b2ac39983f1a1 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -475,97 +475,103 @@ def createDataFrame(
         _table: Optional[pa.Table] = None

-        if isinstance(data, pd.DataFrame) or isinstance(data, pa.Table):
+        if schema is None and isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
             # `pyspark.sql.pandas.conversion.py`. Should ideally deduplicate the logics.

             # If no schema supplied by user then get the names of columns only
-            if schema is None:
-                if isinstance(data, pd.DataFrame):
-                    _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
-                    infer_pandas_dict_as_map = (
-                        str(
-                            self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")
-                        ).lower()
-                        == "true"
-                    )
-                    if infer_pandas_dict_as_map:
-                        struct = StructType()
-                        pa_schema = pa.Schema.from_pandas(data)
-                        spark_type: Union[MapType, DataType]
-                        for field in pa_schema:
-                            field_type = field.type
-                            if isinstance(field_type, pa.StructType):
-                                if len(field_type) == 0:
-                                    raise PySparkValueError(
-                                        error_class="CANNOT_INFER_EMPTY_SCHEMA",
-                                        message_parameters={},
-                                    )
-                                arrow_type = field_type.field(0).type
-                                spark_type = MapType(StringType(), from_arrow_type(arrow_type))
-                            else:
-                                spark_type = from_arrow_type(field_type)
-                            struct.add(field.name, spark_type, nullable=field.nullable)
-                        schema = struct
-                elif isinstance(data, pa.Table):
-                    schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=True)
-            elif isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns):
-                assert isinstance(_cols, list)
-                _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
-                _num_cols = len(_cols)
-
-            # Determine arrow types to coerce data when creating batches
-            arrow_schema: Optional[pa.Schema] = None
-            spark_types: List[Optional[DataType]]
-            arrow_types: List[Optional[pa.DataType]]
-            if isinstance(schema, StructType):
-                deduped_schema = cast(StructType, _deduplicate_field_names(schema))
-                spark_types = [field.dataType for field in deduped_schema.fields]
-                arrow_schema = to_arrow_schema(deduped_schema)
-                arrow_types = [field.type for field in arrow_schema]
-                _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
-            elif isinstance(schema, DataType):
-                raise PySparkTypeError(
-                    error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
-                    message_parameters={"data_type": str(schema)},
-                )
-            elif isinstance(data, pd.DataFrame):
-                # Any timestamps must be coerced to be compatible with Spark
-                spark_types = [
-                    TimestampType()
-                    if is_datetime64_dtype(t) or isinstance(t, pd.DatetimeTZDtype)
-                    else DayTimeIntervalType()
-                    if is_timedelta64_dtype(t)
-                    else None
-                    for t in data.dtypes
-                ]
-                arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
+            _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
+            infer_pandas_dict_as_map = (
+                str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
+                == "true"
+            )
+            if infer_pandas_dict_as_map:
+                struct = StructType()
+                pa_schema = pa.Schema.from_pandas(data)
+                spark_type: Union[MapType, DataType]
+                for field in pa_schema:
+                    field_type = field.type
+                    if isinstance(field_type, pa.StructType):
+                        if len(field_type) == 0:
+                            raise PySparkValueError(
+                                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                                message_parameters={},
+                            )
+                        arrow_type = field_type.field(0).type
+                        spark_type = MapType(StringType(), from_arrow_type(arrow_type))
+                    else:
+                        spark_type = from_arrow_type(field_type)
+                    struct.add(field.name, spark_type, nullable=field.nullable)
+                schema = struct
+
+        elif schema is None and isinstance(data, pa.Table):
+            schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=True)
+
+        elif (
+            isinstance(schema, (list, tuple))
+            and isinstance(data, (pd.DataFrame, pa.Table))
+            and cast(int, _num_cols) < len(data.columns)
+        ):
+            assert isinstance(_cols, list)
+            _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
+            _num_cols = len(_cols)
+
+        # Determine arrow types to coerce data when creating batches
+        arrow_schema: Optional[pa.Schema] = None
+        spark_types: List[Optional[DataType]]
+        arrow_types: List[Optional[pa.DataType]]
+
+        if isinstance(schema, StructType) and isinstance(data, (pd.DataFrame, pa.Table)):
+            deduped_schema = cast(StructType, _deduplicate_field_names(schema))
+            spark_types = [field.dataType for field in deduped_schema.fields]
+            arrow_schema = to_arrow_schema(deduped_schema)
+            arrow_types = [field.type for field in arrow_schema]
+            _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
+
+        elif isinstance(schema, DataType) and isinstance(data, (pd.DataFrame, pa.Table)):
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                message_parameters={"data_type": str(schema)},
+            )
+        elif isinstance(data, pd.DataFrame):
+            # Any timestamps must be coerced to be compatible with Spark
+            spark_types = [
+                TimestampType()
+                if is_datetime64_dtype(t) or isinstance(t, pd.DatetimeTZDtype)
+                else DayTimeIntervalType()
+                if is_timedelta64_dtype(t)
+                else None
+                for t in data.dtypes
+            ]
+            arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
+
+        if isinstance(data, pd.DataFrame):
             timezone, safecheck = self._client.get_configs(
                 "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
             )

             ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")

-            if isinstance(data, pd.DataFrame):
-                _table = pa.Table.from_batches(
-                    [
-                        ser._create_batch(
-                            [
-                                (c, at, st)
-                                for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
-                            ]
-                        )
-                    ]
-                )
-            else:
-                _table = data
+            _table = pa.Table.from_batches(
+                [
+                    ser._create_batch(
+                        [
+                            (c, at, st)
+                            for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
+                        ]
+                    )
+                ]
+            )

-            if isinstance(schema, StructType):
-                assert arrow_schema is not None
-                _table = _table.rename_columns(
-                    cast(StructType, _deduplicate_field_names(schema)).names
-                ).cast(arrow_schema)
+        elif isinstance(data, pa.Table):
+            _table = data
+
+        if isinstance(schema, StructType) and isinstance(data, (pd.DataFrame, pa.Table)):
+            assert arrow_schema is not None
+            _table = _table.rename_columns(  # type: ignore[union-attr]
+                cast(StructType, _deduplicate_field_names(schema)).names
+            ).cast(arrow_schema)

         elif isinstance(data, np.ndarray):
             if _cols is None:
@@ -602,7 +608,7 @@ def createDataFrame(
             # The _table should already have the proper column names.
             _cols = None

-        else:
+        elif not isinstance(data, (pd.DataFrame, pa.Table)):
             _data = list(data)

             if isinstance(_data[0], dict):
@@ -642,12 +648,12 @@ def createDataFrame(

         # TODO: Beside the validation on number of columns, we should also check
         # whether the Arrow Schema is compatible with the user provided Schema.
-        if _num_cols is not None and _num_cols != _table.shape[1]:
+        if _num_cols is not None and _num_cols != _table.shape[1]:  # type: ignore[union-attr]
             raise PySparkValueError(
                 error_class="AXIS_LENGTH_MISMATCH",
                 message_parameters={
                     "expected_length": str(_num_cols),
-                    "actual_length": str(_table.shape[1]),
+                    "actual_length": str(_table.shape[1]),  # type: ignore[union-attr]
                 },
             )

@@ -658,7 +664,10 @@ def createDataFrame(
         cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")

         plan: LogicalPlan = local_relation
-        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+        if (
+            cache_threshold[0] is not None
+            and int(cache_threshold[0]) <= _table.nbytes  # type: ignore[union-attr]
+        ):
             plan = CachedLocalRelation(self._cache_local_relation(local_relation))

         df = DataFrame(plan, self)
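
Below is a minimal usage sketch, not part of the patch itself, showing the createDataFrame code paths that the flattened conditionals above handle. It assumes an already-running Spark Connect session bound to the name "spark" (for example one created with SparkSession.builder.remote("sc://localhost").getOrCreate()); since the change only restructures the branching, results should be identical before and after this patch.

    # Usage sketch (assumption: `spark` is an active Spark Connect session).
    import pandas as pd
    import pyarrow as pa
    from pyspark.sql.types import LongType, StringType, StructField, StructType

    pdf = pd.DataFrame({"id": [1, 2, 3], "label": ["a", "b", "c"]})
    tbl = pa.Table.from_pandas(pdf)

    # schema=None with a pandas DataFrame: column names and types are inferred from the frame.
    df_from_pandas = spark.createDataFrame(pdf)

    # schema=None with a pyarrow Table: the Spark schema is derived from the Arrow schema.
    df_from_arrow = spark.createDataFrame(tbl)

    # Explicit StructType schema: the underlying Arrow table is renamed and cast to match it.
    explicit = StructType([StructField("id", LongType()), StructField("label", StringType())])
    df_with_schema = spark.createDataFrame(pdf, schema=explicit)

    assert df_from_pandas.columns == df_from_arrow.columns == ["id", "label"]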