Flatten conditional logic in connect/session.py
ianmcook committed May 13, 2024
1 parent 494693a commit 795d01a
Showing 1 changed file with 92 additions and 83 deletions.
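The refactor replaces a nested conditional (an outer isinstance check on data with inner checks on schema) with one flat if/elif chain in which every branch spells out its full condition. A minimal sketch of the pattern, using illustrative names rather than the actual Spark code:

def handle_nested(data, schema):
    # Before: nested conditionals; the schema checks are only
    # reachable inside the outer type check.
    if isinstance(data, (list, tuple)):
        if schema is None:
            return "infer"
        if isinstance(schema, str):
            return "parse"
    elif data is None:
        return "empty"

def handle_flat(data, schema):
    # After: a flat chain; each branch repeats the outer type test,
    # so every condition can be read on its own.
    if schema is None and isinstance(data, (list, tuple)):
        return "infer"
    elif isinstance(schema, str) and isinstance(data, (list, tuple)):
        return "parse"
    elif data is None:
        return "empty"

The two functions behave identically; the flattened form trades a little repetition in the conditions for one less level of nesting in every branch.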
python/pyspark/sql/connect/session.py (175 changes: 92 additions & 83 deletions)
@@ -475,97 +475,103 @@ def createDataFrame(

         _table: Optional[pa.Table] = None
 
-        if isinstance(data, pd.DataFrame) or isinstance(data, pa.Table):
+        if schema is None and isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
             # `pyspark.sql.pandas.conversion.py`. Should ideally deduplicate the logics.
 
             # If no schema supplied by user then get the names of columns only
-            if schema is None:
-                if isinstance(data, pd.DataFrame):
-                    _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
-                    infer_pandas_dict_as_map = (
-                        str(
-                            self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")
-                        ).lower()
-                        == "true"
-                    )
-                    if infer_pandas_dict_as_map:
-                        struct = StructType()
-                        pa_schema = pa.Schema.from_pandas(data)
-                        spark_type: Union[MapType, DataType]
-                        for field in pa_schema:
-                            field_type = field.type
-                            if isinstance(field_type, pa.StructType):
-                                if len(field_type) == 0:
-                                    raise PySparkValueError(
-                                        error_class="CANNOT_INFER_EMPTY_SCHEMA",
-                                        message_parameters={},
-                                    )
-                                arrow_type = field_type.field(0).type
-                                spark_type = MapType(StringType(), from_arrow_type(arrow_type))
-                            else:
-                                spark_type = from_arrow_type(field_type)
-                            struct.add(field.name, spark_type, nullable=field.nullable)
-                        schema = struct
-                elif isinstance(data, pa.Table):
-                    schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=True)
-            elif isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns):
-                assert isinstance(_cols, list)
-                _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
-                _num_cols = len(_cols)
-
-            # Determine arrow types to coerce data when creating batches
-            arrow_schema: Optional[pa.Schema] = None
-            spark_types: List[Optional[DataType]]
-            arrow_types: List[Optional[pa.DataType]]
-            if isinstance(schema, StructType):
-                deduped_schema = cast(StructType, _deduplicate_field_names(schema))
-                spark_types = [field.dataType for field in deduped_schema.fields]
-                arrow_schema = to_arrow_schema(deduped_schema)
-                arrow_types = [field.type for field in arrow_schema]
-                _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
-            elif isinstance(schema, DataType):
-                raise PySparkTypeError(
-                    error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
-                    message_parameters={"data_type": str(schema)},
-                )
-            elif isinstance(data, pd.DataFrame):
-                # Any timestamps must be coerced to be compatible with Spark
-                spark_types = [
-                    TimestampType()
-                    if is_datetime64_dtype(t) or isinstance(t, pd.DatetimeTZDtype)
-                    else DayTimeIntervalType()
-                    if is_timedelta64_dtype(t)
-                    else None
-                    for t in data.dtypes
-                ]
-                arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
+            _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
+            infer_pandas_dict_as_map = (
+                str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
+                == "true"
+            )
+            if infer_pandas_dict_as_map:
+                struct = StructType()
+                pa_schema = pa.Schema.from_pandas(data)
+                spark_type: Union[MapType, DataType]
+                for field in pa_schema:
+                    field_type = field.type
+                    if isinstance(field_type, pa.StructType):
+                        if len(field_type) == 0:
+                            raise PySparkValueError(
+                                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                                message_parameters={},
+                            )
+                        arrow_type = field_type.field(0).type
+                        spark_type = MapType(StringType(), from_arrow_type(arrow_type))
+                    else:
+                        spark_type = from_arrow_type(field_type)
+                    struct.add(field.name, spark_type, nullable=field.nullable)
+                schema = struct
+
+        elif schema is None and isinstance(data, pa.Table):
+            schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=True)
+
+        elif (
+            isinstance(schema, (list, tuple))
+            and isinstance(data, (pd.DataFrame, pa.Table))
+            and cast(int, _num_cols) < len(data.columns)
+        ):
+            assert isinstance(_cols, list)
+            _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
+            _num_cols = len(_cols)
+
+        # Determine arrow types to coerce data when creating batches
+        arrow_schema: Optional[pa.Schema] = None
+        spark_types: List[Optional[DataType]]
+        arrow_types: List[Optional[pa.DataType]]
+
+        if isinstance(schema, StructType) and isinstance(data, (pd.DataFrame, pa.Table)):
+            deduped_schema = cast(StructType, _deduplicate_field_names(schema))
+            spark_types = [field.dataType for field in deduped_schema.fields]
+            arrow_schema = to_arrow_schema(deduped_schema)
+            arrow_types = [field.type for field in arrow_schema]
+            _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
+
+        elif isinstance(schema, DataType) and isinstance(data, (pd.DataFrame, pa.Table)):
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                message_parameters={"data_type": str(schema)},
+            )
+
+        elif isinstance(data, pd.DataFrame):
+            # Any timestamps must be coerced to be compatible with Spark
+            spark_types = [
+                TimestampType()
+                if is_datetime64_dtype(t) or isinstance(t, pd.DatetimeTZDtype)
+                else DayTimeIntervalType()
+                if is_timedelta64_dtype(t)
+                else None
+                for t in data.dtypes
+            ]
+            arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
+
+        if isinstance(data, pd.DataFrame):
             timezone, safecheck = self._client.get_configs(
                 "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
             )
 
             ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
 
-            if isinstance(data, pd.DataFrame):
-                _table = pa.Table.from_batches(
-                    [
-                        ser._create_batch(
-                            [
-                                (c, at, st)
-                                for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
-                            ]
-                        )
-                    ]
-                )
-            else:
-                _table = data
+            _table = pa.Table.from_batches(
+                [
+                    ser._create_batch(
+                        [
+                            (c, at, st)
+                            for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
+                        ]
+                    )
+                ]
+            )
 
-            if isinstance(schema, StructType):
-                assert arrow_schema is not None
-                _table = _table.rename_columns(
-                    cast(StructType, _deduplicate_field_names(schema)).names
-                ).cast(arrow_schema)
+        elif isinstance(data, pa.Table):
+            _table = data
 
+        if isinstance(schema, StructType) and isinstance(data, (pd.DataFrame, pa.Table)):
+            assert arrow_schema is not None
+            _table = _table.rename_columns(  # type: ignore[union-attr]
+                cast(StructType, _deduplicate_field_names(schema)).names
+            ).cast(arrow_schema)
 
         elif isinstance(data, np.ndarray):
             if _cols is None:
@@ -602,7 +608,7 @@ def createDataFrame(
             # The _table should already have the proper column names.
             _cols = None
 
-        else:
+        elif not isinstance(data, (pd.DataFrame, pa.Table)):
             _data = list(data)
 
             if isinstance(_data[0], dict):
@@ -642,12 +648,12 @@ def createDataFrame(

         # TODO: Beside the validation on number of columns, we should also check
         # whether the Arrow Schema is compatible with the user provided Schema.
-        if _num_cols is not None and _num_cols != _table.shape[1]:
+        if _num_cols is not None and _num_cols != _table.shape[1]:  # type: ignore[union-attr]
             raise PySparkValueError(
                 error_class="AXIS_LENGTH_MISMATCH",
                 message_parameters={
                     "expected_length": str(_num_cols),
-                    "actual_length": str(_table.shape[1]),
+                    "actual_length": str(_table.shape[1]),  # type: ignore[union-attr]
                 },
             )
 
@@ -658,7 +664,10 @@ def createDataFrame(

         cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
         plan: LogicalPlan = local_relation
-        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+        if (
+            cache_threshold[0] is not None
+            and int(cache_threshold[0]) <= _table.nbytes  # type: ignore[union-attr]
+        ):
            plan = CachedLocalRelation(self._cache_local_relation(local_relation))
 
         df = DataFrame(plan, self)
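One consequence of the flattening shows up in the last three hunks: _table is declared as Optional[pa.Table], and once the later uses of _table no longer sit inside a single branch that guarantees it was assigned, mypy can no longer narrow the Optional away, hence the added # type: ignore[union-attr] comments. A minimal sketch of that interaction, with toy names rather than the Spark code:

from typing import Optional

class Table:
    nbytes: int = 0

def nested(t: Optional[Table]) -> int:
    if isinstance(t, Table):
        # mypy narrows t from Optional[Table] to Table inside this branch.
        return t.nbytes
    return 0

def flat(t: Optional[Table]) -> int:
    # After flattening, the use site is outside any narrowing branch, so
    # mypy still sees Optional[Table] and the attribute access needs an
    # explicit guard or an ignore.
    return t.nbytes  # type: ignore[union-attr]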