[SPARK-36626][PYTHON] Support TimestampNTZ in createDataFrame/toPandas and Python UDFs

### What changes were proposed in this pull request?

This PR proposes to implement `TimestampNTZType` support in PySpark's `SparkSession.createDataFrame`, `DataFrame.toPandas`, Python UDFs, and pandas UDFs with and without Arrow.
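As a rough usage sketch (not part of this commit), assuming an active `SparkSession` named `spark` with pandas and PyArrow installed, both plain Python UDFs and pandas UDFs can declare `TimestampNTZType` results; the UDF names and bodies below are illustrative only:

```python
import datetime

from pyspark.sql.functions import pandas_udf, udf
from pyspark.sql.types import TimestampNTZType

# Plain Python UDF declaring a TimestampNTZType result (illustrative body).
@udf(TimestampNTZType())
def plus_one_hour(ts):
    return None if ts is None else ts + datetime.timedelta(hours=1)

# Pandas UDF declaring the same type via the DDL string "timestamp_ntz".
@pandas_udf(returnType="timestamp_ntz")
def identity(s):
    return s

df = spark.createDataFrame(
    [(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz")
df.select(plus_one_hour("dt"), identity("dt")).show()
```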

### Why are the changes needed?

To complete `TimestampNTZType` support.

### Does this PR introduce _any_ user-facing change?

Yes.

- Users can now use the `TimestampNTZType` type in `SparkSession.createDataFrame`, `DataFrame.toPandas`, Python UDFs, and pandas UDFs, both with and without Arrow.

- If `spark.sql.timestampType` is set to `TIMESTAMP_NTZ`, `SparkSession.createDataFrame` infers a `datetime` without a timezone as `TimestampNTZType`; a `datetime` with a timezone is still inferred as `TimestampType` (see the sketch after this list).

    - If `TimestampType` and `TimestampNTZType` conflict while merging the inferred schema, `TimestampType` takes precedence.

- If the type is `TimestampNTZType`, it is treated internally as having an unknown timezone: values are computed in UTC (matching the JVM side) and are not localized when returned to the user.
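
A minimal sketch of the inference behavior above (again assuming an active session `spark`; the sample rows mirror the tests added in this commit):

```python
import datetime

# Prefer TIMESTAMP_NTZ when inferring Python datetimes without a timezone.
spark.conf.set("spark.sql.timestampType", "TIMESTAMP_NTZ")

# A naive datetime is inferred as timestamp_ntz.
naive = [(datetime.datetime(1970, 1, 1, 0, 0),)]
print(spark.createDataFrame(naive).schema)   # expect: ... timestamp_ntz ...

# Mixing naive and tz-aware datetimes: TimestampType wins the schema merge.
mixed = [
    (datetime.datetime(1970, 1, 1, 0, 0),),
    (datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),),
]
print(spark.createDataFrame(mixed).schema)   # expect: ... timestamp ...
```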

### How was this patch tested?

Manually tested, and unit tests were added.

Closes #33876 from HyukjinKwon/SPARK-36626.

Lead-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Co-authored-by: Dominik Gehl <dog@open.ch>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon and dominikgehl committed Sep 2, 2021
1 parent e983ba8 commit 9c5bcac
Showing 14 changed files with 185 additions and 41 deletions.
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql.rst
@@ -298,6 +298,7 @@ Data Types
StringType
StructField
StructType
TimestampNTZType
TimestampType


17 changes: 14 additions & 3 deletions python/pyspark/sql/pandas/conversion.py
@@ -22,7 +22,7 @@
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
DoubleType, BooleanType, MapType, TimestampType, TimestampNTZType, StructType, DataType
from pyspark.traceback_utils import SCCallSiteSync


@@ -238,6 +238,8 @@ def _to_corrected_pandas_type(dt):
return np.bool
elif type(dt) == TimestampType:
return np.datetime64
elif type(dt) == TimestampNTZType:
return np.datetime64
else:
return None

@@ -354,6 +356,8 @@ def _convert_from_pandas(self, pdf, schema, timezone):

if timezone is not None:
from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
from pandas.core.dtypes.common import is_datetime64tz_dtype

copied = False
if isinstance(schema, StructType):
for field in schema:
@@ -368,8 +372,11 @@ def _convert_from_pandas(self, pdf, schema, timezone):
copied = True
pdf[field.name] = s
else:
should_localize = not self._is_timestamp_ntz_preferred()
for column, series in pdf.iteritems():
s = _check_series_convert_timestamps_tz_local(series, timezone)
s = series
if should_localize and is_datetime64tz_dtype(s.dtype) and s.dt.tz is not None:
s = _check_series_convert_timestamps_tz_local(series, timezone)
if s is not series:
if not copied:
# Copy once if the series is modified to prevent the original
@@ -448,8 +455,12 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
if isinstance(schema, (list, tuple)):
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
struct = StructType()
prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
struct.add(
name,
from_arrow_type(field.type, prefer_timestamp_ntz),
nullable=field.nullable)
schema = struct

# Determine arrow types to coerce data when creating batches
1 change: 1 addition & 0 deletions python/pyspark/sql/pandas/conversion.pyi
@@ -38,6 +38,7 @@ from pyspark.sql.types import ( # noqa: F401
ShortType as ShortType,
StructType as StructType,
TimestampType as TimestampType,
TimestampNTZType as TimestampNTZType,
)
from pyspark.traceback_utils import SCCallSiteSync as SCCallSiteSync # noqa: F401

4 changes: 2 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -126,7 +126,7 @@ def arrow_to_pandas(self, arrow_column):
# datetime64[ns] type handling.
s = arrow_column.to_pandas(date_as_object=True)

if pyarrow.types.is_timestamp(arrow_column.type):
if pyarrow.types.is_timestamp(arrow_column.type) and arrow_column.type.tz is not None:
return _check_series_localize_timestamps(s, self._timezone)
elif pyarrow.types.is_map(arrow_column.type):
return _convert_map_items_to_dict(s)
@@ -162,7 +162,7 @@ def _create_batch(self, series):
def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif t is not None and pa.types.is_map(t):
s = _convert_dict_to_map_items(s)
8 changes: 6 additions & 2 deletions python/pyspark/sql/pandas/types.py
@@ -22,7 +22,7 @@

from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \
ArrayType, MapType, StructType, StructField, NullType
TimestampNTZType, ArrayType, MapType, StructType, StructField, NullType


def to_arrow_type(dt):
@@ -55,6 +55,8 @@ def to_arrow_type(dt):
elif type(dt) == TimestampType:
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
elif type(dt) == TimestampNTZType:
arrow_type = pa.timestamp('us', tz=None)
elif type(dt) == ArrayType:
if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
@@ -88,7 +90,7 @@ def to_arrow_schema(schema):
return pa.schema(fields)


def from_arrow_type(at):
def from_arrow_type(at, prefer_timestamp_ntz=False):
""" Convert pyarrow type to Spark data type.
"""
from distutils.version import LooseVersion
@@ -116,6 +118,8 @@ def from_arrow_type(at):
spark_type = BinaryType()
elif types.is_date32(at):
spark_type = DateType()
elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
spark_type = TimestampNTZType()
elif types.is_timestamp(at):
spark_type = TimestampType()
elif types.is_list(at):
20 changes: 16 additions & 4 deletions python/pyspark/sql/session.py
@@ -419,6 +419,9 @@ def range(self, start, end=None, step=1, numPartitions=None):

return DataFrame(jdf, self._wrapped)

def _is_timestamp_ntz_preferred(self):
return self._wrapped._conf.timestampType().typeName() == "timestamp_ntz"

def _inferSchemaFromList(self, data, names=None):
"""
Infer schema from list of Row, dict, or tuple.
@@ -437,7 +440,9 @@ def _inferSchemaFromList(self, data, names=None):
if not data:
raise ValueError("can not infer schema from empty dataset")
infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
schema = reduce(_merge_type, (_infer_schema(row, names, infer_dict_as_struct)
prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
schema = reduce(_merge_type, (
_infer_schema(row, names, infer_dict_as_struct, prefer_timestamp_ntz)
for row in data))
if _has_nulltype(schema):
raise ValueError("Some of types cannot be determined after inferring")
@@ -465,12 +470,18 @@ def _inferSchema(self, rdd, samplingRatio=None, names=None):
"can not infer schema")

infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
if samplingRatio is None:
schema = _infer_schema(first, names=names, infer_dict_as_struct=infer_dict_as_struct)
schema = _infer_schema(
first,
names=names,
infer_dict_as_struct=infer_dict_as_struct,
prefer_timestamp_ntz=prefer_timestamp_ntz)
if _has_nulltype(schema):
for row in rdd.take(100)[1:]:
schema = _merge_type(schema, _infer_schema(
row, names=names, infer_dict_as_struct=infer_dict_as_struct))
row, names=names, infer_dict_as_struct=infer_dict_as_struct,
prefer_timestamp_ntz=prefer_timestamp_ntz))
if not _has_nulltype(schema):
break
else:
@@ -480,7 +491,8 @@ def _inferSchema(self, rdd, samplingRatio=None, names=None):
if samplingRatio < 0.99:
rdd = rdd.sample(False, float(samplingRatio))
schema = rdd.map(lambda row: _infer_schema(
row, names, infer_dict_as_struct=infer_dict_as_struct)).reduce(_merge_type)
row, names, infer_dict_as_struct=infer_dict_as_struct,
prefer_timestamp_ntz=prefer_timestamp_ntz)).reduce(_merge_type)
return schema

def _createFromRDD(self, rdd, schema, samplingRatio):
16 changes: 14 additions & 2 deletions python/pyspark/sql/tests/test_arrow.py
@@ -27,8 +27,8 @@
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import rand, udf
from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, \
ArrayType, NullType
FloatType, DoubleType, DecimalType, DateType, TimestampType, TimestampNTZType, \
BinaryType, StructField, ArrayType, NullType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -167,6 +167,18 @@ def test_toPandas_arrow_toggle(self):
assert_frame_equal(expected, pdf)
assert_frame_equal(expected, pdf_arrow)

def test_create_data_frame_to_pandas_timestamp_ntz(self):
# SPARK-36626: Test TimestampNTZ in createDataFrame and toPandas
with self.sql_conf({"spark.sql.session.timeZone": "America/Los_Angeles"}):
origin = pd.DataFrame({"a": [datetime.datetime(2012, 2, 2, 2, 2, 2)]})
df = self.spark.createDataFrame(
origin, schema=StructType([StructField("a", TimestampNTZType(), True)]))
df.selectExpr("assert_true('2012-02-02 02:02:02' == CAST(a AS STRING))").collect()

pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
assert_frame_equal(origin, pdf)
assert_frame_equal(pdf, pdf_arrow)

def test_toPandas_respect_session_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)

34 changes: 23 additions & 11 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -25,7 +25,7 @@
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, lit, count, sum, mean
from pyspark.sql.types import StringType, IntegerType, DoubleType, StructType, StructField, \
BooleanType, DateType, TimestampType, FloatType
BooleanType, DateType, TimestampType, TimestampNTZType, FloatType
from pyspark.sql.utils import AnalysisException, IllegalArgumentException
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \
pandas_requirement_message, pyarrow_requirement_message
@@ -575,12 +575,16 @@ def _to_pandas(self):
from datetime import datetime, date
schema = StructType().add("a", IntegerType()).add("b", StringType())\
.add("c", BooleanType()).add("d", FloatType())\
.add("dt", DateType()).add("ts", TimestampType())
.add("dt", DateType()).add("ts", TimestampType())\
.add("ts_ntz", TimestampNTZType())
data = [
(1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(2, "foo", True, 5.0, None, None),
(3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
(4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
(1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1),
datetime(1969, 1, 1, 1, 1, 1)),
(2, "foo", True, 5.0, None, None, None),
(3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3),
datetime(2012, 3, 3, 3, 3, 3)),
(4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4),
datetime(2100, 4, 4, 4, 4, 4)),
]
df = self.spark.createDataFrame(data, schema)
return df.toPandas()
@@ -596,6 +600,7 @@ def test_to_pandas(self):
self.assertEqual(types[3], np.float32)
self.assertEqual(types[4], np.object) # datetime.date
self.assertEqual(types[5], 'datetime64[ns]')
self.assertEqual(types[6], 'datetime64[ns]')

@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_to_pandas_with_duplicated_column_names(self):
@@ -662,7 +667,8 @@ def test_to_pandas_from_empty_dataframe(self):
CAST(0 AS DOUBLE) AS double,
CAST(1 AS BOOLEAN) AS boolean,
CAST('foo' AS STRING) AS string,
CAST('2019-01-01' AS TIMESTAMP) AS timestamp
CAST('2019-01-01' AS TIMESTAMP) AS timestamp,
CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz
"""
dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
@@ -682,7 +688,8 @@ def test_to_pandas_from_null_dataframe(self):
CAST(NULL AS DOUBLE) AS double,
CAST(NULL AS BOOLEAN) AS boolean,
CAST(NULL AS STRING) AS string,
CAST(NULL AS TIMESTAMP) AS timestamp
CAST(NULL AS TIMESTAMP) AS timestamp,
CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz
"""
pdf = self.spark.sql(sql).toPandas()
types = pdf.dtypes
Expand All @@ -695,6 +702,7 @@ def test_to_pandas_from_null_dataframe(self):
self.assertEqual(types[6], np.object)
self.assertEqual(types[7], np.object)
self.assertTrue(np.can_cast(np.datetime64, types[8]))
self.assertTrue(np.can_cast(np.datetime64, types[9]))

@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_to_pandas_from_mixed_dataframe(self):
@@ -710,9 +718,10 @@
CAST(col6 AS DOUBLE) AS double,
CAST(col7 AS BOOLEAN) AS boolean,
CAST(col8 AS STRING) AS string,
timestamp_seconds(col9) AS timestamp
FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1),
(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
timestamp_seconds(col9) AS timestamp,
timestamp_seconds(col10) AS timestamp_ntz
FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
"""
pdf_with_some_nulls = self.spark.sql(sql).toPandas()
pdf_with_only_nulls = self.spark.sql(sql).filter('tinyint is null').toPandas()
@@ -738,6 +747,9 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp_ntz")
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampNTZType))
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

@unittest.skipIf(have_pandas, "Required Pandas was found.")
def test_create_dataframe_required_pandas_not_found(self):
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf.py
@@ -16,6 +16,7 @@
#

import unittest
import datetime

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StructType, StructField, LongType
@@ -239,6 +240,23 @@ def udf(column):
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
df.withColumn('udf', udf('id')).collect()

def test_pandas_udf_timestamp_ntz(self):
# SPARK-36626: Test TimestampNTZ in pandas UDF
@pandas_udf(returnType="timestamp_ntz")
def noop(s):
assert s.iloc[0] == datetime.datetime(1970, 1, 1, 0, 0)
return s

with self.sql_conf({"spark.sql.session.timeZone": "Asia/Hong_Kong"}):
df = (self.spark
.createDataFrame(
[(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz")
.select(noop("dt").alias("dt")))

df.selectExpr("assert_true('1970-01-01 00:00:00' == CAST(dt AS STRING))").collect()
self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))


if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_udf import * # noqa: F401
16 changes: 14 additions & 2 deletions python/pyspark/sql/tests/test_types.py
@@ -29,8 +29,8 @@
from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import ByteType, ShortType, IntegerType, FloatType, DateType, \
TimestampType, MapType, StringType, StructType, StructField, ArrayType, DoubleType, LongType, \
DecimalType, BinaryType, BooleanType, NullType
TimestampType, MapType, StringType, StructType, StructField,\
ArrayType, DoubleType, LongType, DecimalType, BinaryType, BooleanType, NullType
from pyspark.sql.types import ( # type: ignore
_array_signed_int_typecode_ctype_mappings, _array_type_mappings,
_array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type
@@ -175,6 +175,18 @@ def __init__(self):
]
self.assertEqual(actual, expected)

with self.sql_conf({"spark.sql.timestampType": "TIMESTAMP_NTZ"}):
with self.sql_conf({"spark.sql.session.timeZone": "America/Sao_Paulo"}):
df = self.spark.createDataFrame([(datetime.datetime(1970, 1, 1, 0, 0),)])
self.assertEqual(list(df.schema)[0].dataType.simpleString(), "timestamp_ntz")
self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))

df = self.spark.createDataFrame([
(datetime.datetime(1970, 1, 1, 0, 0),),
(datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),)
])
self.assertEqual(list(df.schema)[0].dataType.simpleString(), "timestamp")

def test_infer_schema_not_enough_names(self):
df = self.spark.createDataFrame([["a", "b"]], ["col1"])
self.assertEqual(df.columns, ['col1', '_2'])
20 changes: 19 additions & 1 deletion python/pyspark/sql/tests/test_udf.py
@@ -20,13 +20,14 @@
import shutil
import tempfile
import unittest
import datetime

from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
from pyspark.sql.functions import udf
from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.types import StringType, IntegerType, BooleanType, DoubleType, LongType, \
ArrayType, StructType, StructField
ArrayType, StructType, StructField, TimestampNTZType
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
from pyspark.testing.utils import QuietTest
@@ -552,6 +553,23 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_udf_timestamp_ntz(self):
# SPARK-36626: Test TimestampNTZ in Python UDF
@udf(TimestampNTZType())
def noop(x):
assert x == datetime.datetime(1970, 1, 1, 0, 0)
return x

with self.sql_conf({"spark.sql.session.timeZone": "Pacific/Honolulu"}):
df = (self.spark
.createDataFrame(
[(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz")
.select(noop("dt").alias("dt")))

df.selectExpr("assert_true('1970-01-01 00:00:00' == CAST(dt AS STRING))").collect()
self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))

def test_nonparam_udf_with_aggregate(self):
import pyspark.sql.functions as f

