[SPARK-43545][SQL][PYTHON] Support nested timestamp type
### What changes were proposed in this pull request?

Supports nested timestamp types (inside `ArrayType`, `MapType`, and `StructType`) in `spark.createDataFrame()` with a pandas DataFrame and in `df.toPandas()`, and makes them return correct results.

For the following schema and pandas DataFrame:

```py
schema = (
    StructType()
    .add("ts", TimestampType())
    .add("ts_ntz", TimestampNTZType())
    .add(
        "struct", StructType().add("ts", TimestampType()).add("ts_ntz", TimestampNTZType())
    )
    .add("array", ArrayType(TimestampType()))
    .add("array_ntz", ArrayType(TimestampNTZType()))
    .add("map", MapType(StringType(), TimestampType()))
    .add("map_ntz", MapType(StringType(), TimestampNTZType()))
)

data = [
    Row(
        datetime.datetime(2023, 1, 1, 0, 0, 0),
        datetime.datetime(2023, 1, 1, 0, 0, 0),
        Row(
            datetime.datetime(2023, 1, 1, 0, 0, 0),
            datetime.datetime(2023, 1, 1, 0, 0, 0),
        ),
        [datetime.datetime(2023, 1, 1, 0, 0, 0)],
        [datetime.datetime(2023, 1, 1, 0, 0, 0)],
        dict(ts=datetime.datetime(2023, 1, 1, 0, 0, 0)),
        dict(ts_ntz=datetime.datetime(2023, 1, 1, 0, 0, 0)),
    )
]

pdf = pd.DataFrame.from_records(data, columns=schema.names)
```

##### `spark.createDataFrame()`

All paths (without Arrow, with Arrow, and with Spark Connect) now return the same results:

```py
>>> spark.conf.set("spark.sql.session.timeZone", "America/New_York")
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+-------------------+-------------------+------------------------------------------+---------------------+---------------------+---------------------------+-------------------------------+
|ts                 |ts_ntz             |struct                                    |array                |array_ntz            |map                        |map_ntz                        |
+-------------------+-------------------+------------------------------------------+---------------------+---------------------+---------------------------+-------------------------------+
|2023-01-01 00:00:00|2023-01-01 00:00:00|{2023-01-01 00:00:00, 2023-01-01 00:00:00}|[2023-01-01 00:00:00]|[2023-01-01 00:00:00]|{ts -> 2023-01-01 00:00:00}|{ts_ntz -> 2023-01-01 00:00:00}|
+-------------------+-------------------+------------------------------------------+---------------------+---------------------+---------------------------+-------------------------------+
```

##### `df.toPandas()`

```py
>>> spark.conf.set("spark.sql.session.timeZone", "America/New_York")
>>> df.toPandas()
                   ts     ts_ntz                                      struct                  array              array_ntz                          map                          map_ntz
0 2023-01-01 03:00:00 2023-01-01  (2023-01-01 03:00:00, 2023-01-01 00:00:00)  [2023-01-01 03:00:00]  [2023-01-01 00:00:00]  {'ts': 2023-01-01 03:00:00}  {'ts_ntz': 2023-01-01 00:00:00}
```

### Why are the changes needed?

Currently, nested timestamps in `spark.createDataFrame()` with a pandas DataFrame and in `df.toPandas()` are not supported for `ArrayType` and `MapType`, and for `StructType` they return results that differ from the top-level timestamp columns.

For the following schema and pandas DataFrame:

```py
schema = (
    StructType()
    .add("ts", TimestampType())
    .add("ts_ntz", TimestampNTZType())
    .add(
        "struct", StructType().add("ts", TimestampType()).add("ts_ntz", TimestampNTZType())
    )
)

data = [
    Row(
        datetime.datetime(2023, 1, 1, 0, 0, 0),
        datetime.datetime(2023, 1, 1, 0, 0, 0),
        Row(
            datetime.datetime(2023, 1, 1, 0, 0, 0),
            datetime.datetime(2023, 1, 1, 0, 0, 0),
        ),
    )
]

pdf = pd.DataFrame.from_records(data, columns=schema.names)
```

##### `spark.createDataFrame()`

- Without Arrow

```py
>>> spark.conf.set("spark.sql.session.timeZone", "America/New_York")
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+-------------------+-------------------+------------------------------------------+
|ts                 |ts_ntz             |struct                                    |
+-------------------+-------------------+------------------------------------------+
|2023-01-01 00:00:00|2023-01-01 00:00:00|{2023-01-01 03:00:00, 2023-01-01 00:00:00}|
+-------------------+-------------------+------------------------------------------+
```

- With Arrow or Spark Connect:

```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+-------------------+-------------------+------------------------------------------+
|ts                 |ts_ntz             |struct                                    |
+-------------------+-------------------+------------------------------------------+
|2023-01-01 00:00:00|2023-01-01 00:00:00|{2022-12-31 19:00:00, 2023-01-01 00:00:00}|
+-------------------+-------------------+------------------------------------------+
```

##### `df.toPandas()`

For the following DataFrame:

```py
>>> spark.conf.unset("spark.sql.session.timeZone")
>>> df = spark.createDataFrame(data, schema)
>>>
>>> df.show(truncate=False)
+-------------------+-------------------+------------------------------------------+
|ts                 |ts_ntz             |struct                                    |
+-------------------+-------------------+------------------------------------------+
|2023-01-01 00:00:00|2023-01-01 00:00:00|{2023-01-01 00:00:00, 2023-01-01 00:00:00}|
+-------------------+-------------------+------------------------------------------+

>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
```

- Without Arrow

```py
>>> spark.conf.set("spark.sql.session.timeZone", "America/New_York")
>>> df.toPandas()
                   ts     ts_ntz                                      struct
0 2023-01-01 03:00:00 2023-01-01  (2023-01-01 00:00:00, 2023-01-01 00:00:00)
```

- With Arrow or Spark Connect:

```py
>>> df.toPandas()
                   ts     ts_ntz                                      struct
0 2023-01-01 03:00:00 2023-01-01  (2023-01-01 08:00:00, 2023-01-01 00:00:00)
```

### Does this PR introduce _any_ user-facing change?

Yes. Users will be able to use nested timestamps (inside arrays, maps, and structs) with `spark.createDataFrame()` from pandas DataFrames and with `df.toPandas()`, as sketched below.
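A minimal round trip (illustrative only; it assumes an active `SparkSession` named `spark`):

```py
import datetime

import pandas as pd

from pyspark.sql.types import ArrayType, StructType, TimestampType

# A pandas DataFrame with a nested (array of) timestamp column.
schema = StructType().add("ts_arr", ArrayType(TimestampType()))
pdf = pd.DataFrame({"ts_arr": [[datetime.datetime(2023, 1, 1, 0, 0, 0)]]})

# Nested timestamps are now converted consistently in both directions.
df = spark.createDataFrame(pdf, schema)
df.show(truncate=False)
print(df.toPandas())
```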

### How was this patch tested?

Added/updated the related tests.

Closes #41240 from ueshin/issues/SPARK-43545/ts.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
ueshin authored and zhengruifeng committed May 25, 2023
1 parent 5ec1385 commit 46949e6
Showing 16 changed files with 455 additions and 141 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/pandas/typedef/typehints.py
```diff
@@ -206,7 +206,7 @@ def as_spark_type(
     elif tpe in (str, np.unicode_, "str", "U"):
         return types.StringType()
     # TimestampType or TimestampNTZType if timezone is not specified.
-    elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M"):
+    elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M", pd.Timestamp):
         return types.TimestampNTZType() if prefer_timestamp_ntz else types.TimestampType()
 
     # DayTimeIntervalType
```
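The one-line change above means pandas' `pd.Timestamp` is now accepted as a type hint for timestamp columns in pandas API on Spark, the same way `datetime.datetime` and `np.datetime64` already were. A small sketch of the effect (illustrative; it assumes `as_spark_type` accepts `prefer_timestamp_ntz` as a keyword argument, as the surrounding code suggests):

```py
import pandas as pd

from pyspark.pandas.typedef.typehints import as_spark_type

# pd.Timestamp now maps to a Spark timestamp type, just like datetime.datetime.
print(as_spark_type(pd.Timestamp))                             # TimestampType()
print(as_spark_type(pd.Timestamp, prefer_timestamp_ntz=True))  # TimestampNTZType()
```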
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/client/core.py
```diff
@@ -717,7 +717,7 @@ def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
 
-        schema = schema or types.from_arrow_schema(table.schema)
+        schema = schema or types.from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
         assert schema is not None and isinstance(schema, StructType)
 
         # Rename columns to avoid duplicated column names.
```
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
```diff
@@ -1624,7 +1624,7 @@ def collect(self) -> List[Row]:
         query = self._plan.to_proto(self._session.client)
         table, schema = self._session.client.to_table(query)
 
-        schema = schema or from_arrow_schema(table.schema)
+        schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
 
         assert schema is not None and isinstance(schema, StructType)
 
@@ -1902,7 +1902,7 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
                 assert isinstance(schema_or_table, pa.Table)
                 table = schema_or_table
                 if schema is None:
-                    schema = from_arrow_schema(table.schema)
+                    schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
                 yield from ArrowTableToRowsConversion.convert(table, schema)
 
     toLocalIterator.__doc__ = PySparkDataFrame.toLocalIterator.__doc__
```
21 changes: 16 additions & 5 deletions python/pyspark/sql/connect/types.py
```diff
@@ -407,13 +407,20 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
     elif types.is_duration(at):
         spark_type = DayTimeIntervalType()
     elif types.is_list(at):
-        spark_type = ArrayType(from_arrow_type(at.value_type))
+        spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
     elif types.is_map(at):
-        spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type))
+        spark_type = MapType(
+            from_arrow_type(at.key_type, prefer_timestamp_ntz),
+            from_arrow_type(at.item_type, prefer_timestamp_ntz),
+        )
     elif types.is_struct(at):
         return StructType(
             [
-                StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+                StructField(
+                    field.name,
+                    from_arrow_type(field.type, prefer_timestamp_ntz),
+                    nullable=field.nullable,
+                )
                 for field in at
             ]
         )
@@ -424,11 +431,15 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
     return spark_type
 
 
-def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
+def from_arrow_schema(arrow_schema: "pa.Schema", prefer_timestamp_ntz: bool = False) -> StructType:
     """Convert schema from Arrow to Spark."""
     return StructType(
         [
-            StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+            StructField(
+                field.name,
+                from_arrow_type(field.type, prefer_timestamp_ntz),
+                nullable=field.nullable,
+            )
             for field in arrow_schema
         ]
     )
```
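With `prefer_timestamp_ntz` now threaded through the recursive calls, timestamps without a timezone that are nested inside Arrow lists, maps, and structs are mapped to `TimestampNTZType` as well, not only top-level columns. A small sketch (assuming the Spark Connect Python dependencies, including `pyarrow`, are installed; field names are arbitrary):

```py
import pyarrow as pa

from pyspark.sql.connect.types import from_arrow_schema

arrow_schema = pa.schema(
    [
        pa.field("ts_ntz", pa.timestamp("us")),         # no timezone
        pa.field("arr", pa.list_(pa.timestamp("us"))),  # nested in a list
        pa.field("ts", pa.timestamp("us", tz="UTC")),   # timezone-aware
    ]
)

# The flag is propagated into element, key/value, and struct field types, so
# "ts_ntz" and the "arr" elements become TimestampNTZType, while the
# timezone-aware "ts" stays TimestampType.
print(from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True))
```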
113 changes: 100 additions & 13 deletions python/pyspark/sql/pandas/conversion.py
```diff
@@ -16,6 +16,8 @@
 #
 import sys
 from typing import (
+    Any,
+    Callable,
     List,
     Optional,
     Union,
@@ -28,7 +30,8 @@
 from pyspark.errors.exceptions.captured import unwrap_spark_exception
 from pyspark.rdd import _load_from_socket
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
-from pyspark.sql.types import TimestampType, StructType, DataType
+from pyspark.sql.pandas.types import _dedup_names
+from pyspark.sql.types import ArrayType, MapType, TimestampType, StructType, DataType, _create_row
 from pyspark.sql.utils import is_timestamp_ntz_preferred
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.errors import PySparkTypeError
```
```diff
@@ -218,6 +221,7 @@ def toPandas(self) -> "PandasDataFrameLike":
                                     "row" if struct_in_pandas == "legacy" else struct_in_pandas
                                 ),
                                 error_on_duplicated_field_names=False,
+                                timestamp_utc_localized=False,
                             )(pser)
                             for (_, pser), field in zip(pdf.items(), self.schema.fields)
                         ],
```
```diff
@@ -375,22 +379,105 @@ def _convert_from_pandas(
         assert isinstance(self, SparkSession)
 
         if timezone is not None:
-            from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
+            from pyspark.sql.pandas.types import (
+                _check_series_convert_timestamps_tz_local,
+                _get_local_timezone,
+            )
             from pandas.core.dtypes.common import is_datetime64tz_dtype, is_timedelta64_dtype
 
             copied = False
             if isinstance(schema, StructType):
-                for field in schema:
-                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-                    if isinstance(field.dataType, TimestampType):
-                        s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
-                        if s is not pdf[field.name]:
-                            if not copied:
-                                # Copy once if the series is modified to prevent the original
-                                # Pandas DataFrame from being updated
-                                pdf = pdf.copy()
-                                copied = True
-                            pdf[field.name] = s
+
+                def _create_converter(data_type: DataType) -> Callable[[pd.Series], pd.Series]:
+                    if isinstance(data_type, TimestampType):
+
+                        def correct_timestamp(pser: pd.Series) -> pd.Series:
+                            return _check_series_convert_timestamps_tz_local(pser, timezone)
+
+                        return correct_timestamp
+
+                    def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
+                        if isinstance(dt, ArrayType):
+                            element_conv = _converter(dt.elementType) or (lambda x: x)
+
+                            def convert_array(value: Any) -> Any:
+                                if value is None:
+                                    return None
+                                else:
+                                    return [element_conv(v) for v in value]
+
+                            return convert_array
+
+                        elif isinstance(dt, MapType):
+                            key_conv = _converter(dt.keyType) or (lambda x: x)
+                            value_conv = _converter(dt.valueType) or (lambda x: x)
+
+                            def convert_map(value: Any) -> Any:
+                                if value is None:
+                                    return None
+                                else:
+                                    return {key_conv(k): value_conv(v) for k, v in value.items()}
+
+                            return convert_map
+
+                        elif isinstance(dt, StructType):
+                            field_names = dt.names
+                            dedup_field_names = _dedup_names(field_names)
+                            field_convs = [
+                                _converter(f.dataType) or (lambda x: x) for f in dt.fields
+                            ]
+
+                            def convert_struct(value: Any) -> Any:
+                                if value is None:
+                                    return None
+                                elif isinstance(value, dict):
+                                    _values = [
+                                        field_convs[i](value.get(name, None))
+                                        for i, name in enumerate(dedup_field_names)
+                                    ]
+                                    return _create_row(field_names, _values)
+                                else:
+                                    _values = [
+                                        field_convs[i](value[i]) for i, name in enumerate(value)
+                                    ]
+                                    return _create_row(field_names, _values)
+
+                            return convert_struct
+
+                        elif isinstance(dt, TimestampType):
+
+                            def convert_timestamp(value: Any) -> Any:
+                                if value is None:
+                                    return None
+                                else:
+                                    return (
+                                        pd.Timestamp(value)
+                                        .tz_localize(timezone, ambiguous=False)  # type: ignore
+                                        .tz_convert(_get_local_timezone())
+                                        .tz_localize(None)
+                                        .to_pydatetime()
+                                    )
+
+                            return convert_timestamp
+
+                        else:
+                            return None
+
+                    conv = _converter(data_type)
+                    if conv is not None:
+                        return lambda pser: pser.apply(conv)  # type: ignore[return-value]
+                    else:
+                        return lambda pser: pser
+
+                if len(pdf.columns) > 0:
+                    pdf = pd.concat(
+                        [
+                            _create_converter(field.dataType)(pser)
+                            for (_, pser), field in zip(pdf.items(), schema.fields)
+                        ],
+                        axis="columns",
+                    )
+                    copied = True
             else:
                 should_localize = not is_timestamp_ntz_preferred()
                 for column, series in pdf.items():
```
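The per-value `convert_timestamp` above is the heart of the nested handling: each naive timestamp is interpreted as wall-clock time in the session timezone, converted to the client's local timezone, and then made naive again, recursively through arrays, maps, and structs. A standalone sketch of that single step (pandas only; the timezone names are examples, not values taken from the patch):

```py
import datetime

import pandas as pd


def correct_nested_timestamp(value, session_tz: str, local_tz: str) -> datetime.datetime:
    # Mirrors convert_timestamp in the diff above: localize the naive value to the
    # session timezone, convert it to the local timezone, then drop the tzinfo.
    return (
        pd.Timestamp(value)
        .tz_localize(session_tz, ambiguous=False)
        .tz_convert(local_tz)
        .tz_localize(None)
        .to_pydatetime()
    )


print(
    correct_nested_timestamp(
        datetime.datetime(2023, 1, 1, 0, 0, 0), "America/New_York", "America/Los_Angeles"
    )
)
# 2022-12-31 21:00:00
```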
