Commit

[SPARK-37279][PYTHON][SQL] Support DayTimeIntervalType in createDataFrame, collect and Python UDF

### What changes were proposed in this pull request?

This PR implements `DayTimeIntervalType` in PySpark's `DataFrame.collect()`, `SparkSession.createDataFrame()` and `functions.udf`.
This type is mapped to [`datetime.timedelta`](https://docs.python.org/3/library/datetime.html#timedelta-objects).
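
Internally, the new `DayTimeIntervalType.toInternal`/`fromInternal` methods added to `python/pyspark/sql/types.py` in this PR (see the diff below) store the value as a single microsecond count; a minimal sketch of that round trip in plain Python:

```python
>>> import math, datetime
>>> td = datetime.timedelta(days=1, microseconds=123)
>>> micros = math.floor(td.total_seconds()) * 1000000 + td.microseconds
>>> micros
86400000123
>>> datetime.timedelta(microseconds=micros) == td
True
```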

The Arrow code path will be implemented separately in SPARK-37277, and Py4J support will be added in SPARK-37281.

### Why are the changes needed?

- To support `datetime.timedelta` out of the box in PySpark.
- To seamlessly support ANSI standard interval types.

Semantically, [`datetime.timedelta`](https://docs.python.org/3/library/datetime.html#timedelta-objects) maps to `DayTimeIntervalType`: Python's `timedelta` covers only day-to-second fields and does not support months, years, etc.
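
For reference, `DayTimeIntervalType` takes optional start and end fields (`DAY`, `HOUR`, `MINUTE`, `SECOND`); a minimal sketch based on the constructor tests added in this PR:

```python
>>> from pyspark.sql.types import DayTimeIntervalType
>>> DayTimeIntervalType().simpleString()
'interval day to second'
>>> DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND).simpleString()
'interval hour to second'
```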

### Does this PR introduce _any_ user-facing change?

Yes. Users will be able to use `datetime.timedelta` with `DayTimeIntervalType` in `DataFrame.collect()`, `SparkSession.createDataFrame()`, and `functions.udf`:

```python
>>> import datetime
>>> df = spark.createDataFrame([(datetime.timedelta(days=1),)])
>>> df.collect()
[Row(_1=datetime.timedelta(days=1))]
```
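
The schema can also be given explicitly as a DDL string; a small sketch mirroring the new tests (output assumed from those tests):

```python
>>> import datetime
>>> df = spark.createDataFrame(
...     [(datetime.timedelta(microseconds=123),)], schema="td interval day to second"
... )
>>> df.collect()
[Row(td=datetime.timedelta(microseconds=123))]
```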

```python
>>> from pyspark.sql.functions import udf
>>> df.select(udf(lambda x: x, "interval day to second")("_1")).show()
+--------------------+
|        <lambda>(_1)|
+--------------------+
|INTERVAL '1 00:00...|
+--------------------+
```
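
Day-time interval results from SQL expressions now also come back as `timedelta`; a small sketch based on the expressions exercised in the new tests (output assumed):

```python
>>> spark.range(1).selectExpr("INTERVAL '1 02:03:04' DAY TO SECOND AS a").first()[0]
datetime.timedelta(days=1, seconds=7384)
```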

### How was this patch tested?

Unit tests were added.

Closes #34614 from HyukjinKwon/SPARK-37277.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Nov 18, 2021
1 parent 9553ed7 commit e2e1e42
Showing 5 changed files with 178 additions and 10 deletions.
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql.rst
@@ -302,6 +302,7 @@ Data Types
StructType
TimestampNTZType
TimestampType
DayTimeIntervalType


Observation
78 changes: 77 additions & 1 deletion python/pyspark/sql/tests/test_types.py
@@ -35,6 +35,7 @@
FloatType,
DateType,
TimestampType,
DayTimeIntervalType,
MapType,
StringType,
StructType,
@@ -144,6 +145,7 @@ def __init__(self):
"a",
datetime.date(1970, 1, 1),
datetime.datetime(1970, 1, 1, 0, 0),
datetime.timedelta(microseconds=123456678),
1.0,
array.array("d", [1]),
[1],
@@ -165,6 +167,7 @@
"string",
"date",
"timestamp",
"interval day to second",
"double",
"array<double>",
"array<bigint>",
@@ -186,6 +189,7 @@
"a",
datetime.date(1970, 1, 1),
datetime.datetime(1970, 1, 1, 0, 0),
datetime.timedelta(microseconds=123456678),
1.0,
[1.0],
[1],
@@ -290,7 +294,7 @@ def test_create_dataframe_from_objects(self):
self.assertEqual(df.first(), Row(key=1, value="1"))

def test_apply_schema(self):
from datetime import date, datetime
from datetime import date, datetime, timedelta

rdd = self.sc.parallelize(
[
@@ -303,6 +307,7 @@
1.0,
date(2010, 1, 1),
datetime(2010, 1, 1, 1, 1, 1),
timedelta(days=1),
{"a": 1},
(2,),
[1, 2, 3],
@@ -320,6 +325,7 @@
StructField("float1", FloatType(), False),
StructField("date1", DateType(), False),
StructField("time1", TimestampType(), False),
StructField("daytime1", DayTimeIntervalType(), False),
StructField("map1", MapType(StringType(), IntegerType(), False), False),
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
StructField("list1", ArrayType(ByteType(), False), False),
@@ -337,6 +343,7 @@
x.float1,
x.date1,
x.time1,
x.daytime1,
x.map1["a"],
x.struct1.b,
x.list1,
@@ -352,6 +359,7 @@
1.0,
date(2010, 1, 1),
datetime(2010, 1, 1, 1, 1, 1),
timedelta(days=1),
1,
2,
[1, 2, 3],
@@ -929,6 +937,74 @@ def assertCollectSuccess(typecode, value):
a = array.array(t)
self.spark.createDataFrame([Row(myarray=a)]).collect()

def test_daytime_interval_type_constructor(self):
# SPARK-37277: Test constructors in day time interval.
self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second")
self.assertEqual(
DayTimeIntervalType(DayTimeIntervalType.DAY).simpleString(), "interval day"
)
self.assertEqual(
DayTimeIntervalType(
DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND
).simpleString(),
"interval hour to second",
)

with self.assertRaisesRegex(RuntimeError, "interval None to 3 is invalid"):
DayTimeIntervalType(endField=DayTimeIntervalType.SECOND)

with self.assertRaisesRegex(RuntimeError, "interval 123 to 123 is invalid"):
DayTimeIntervalType(123)

with self.assertRaisesRegex(RuntimeError, "interval 0 to 321 is invalid"):
DayTimeIntervalType(DayTimeIntervalType.DAY, 321)

def test_daytime_interval_type(self):
# SPARK-37277: Support DayTimeIntervalType in createDataFrame
timedetlas = [
(datetime.timedelta(microseconds=123),),
(
datetime.timedelta(
days=1, seconds=23, microseconds=123, milliseconds=4, minutes=5, hours=11
),
),
(datetime.timedelta(microseconds=-123),),
(datetime.timedelta(days=-1),),
]
df = self.spark.createDataFrame(timedetlas, schema="td interval day to second")
self.assertEqual(set(r.td for r in df.collect()), set(set(r[0] for r in timedetlas)))

exprs = [
"INTERVAL '1 02:03:04' DAY TO SECOND AS a",
"INTERVAL '1 02:03' DAY TO MINUTE AS b",
"INTERVAL '1 02' DAY TO HOUR AS c",
"INTERVAL '1' DAY AS d",
"INTERVAL '26:03:04' HOUR TO SECOND AS e",
"INTERVAL '26:03' HOUR TO MINUTE AS f",
"INTERVAL '26' HOUR AS g",
"INTERVAL '1563:04' MINUTE TO SECOND AS h",
"INTERVAL '1563' MINUTE AS i",
"INTERVAL '93784' SECOND AS j",
]
df = self.spark.range(1).selectExpr(exprs)

actual = list(df.first())
expected = [
datetime.timedelta(days=1, hours=2, minutes=3, seconds=4),
datetime.timedelta(days=1, hours=2, minutes=3),
datetime.timedelta(days=1, hours=2),
datetime.timedelta(days=1),
datetime.timedelta(hours=26, minutes=3, seconds=4),
datetime.timedelta(hours=26, minutes=3),
datetime.timedelta(hours=26),
datetime.timedelta(minutes=1563, seconds=4),
datetime.timedelta(minutes=1563),
datetime.timedelta(seconds=93784),
]

for n, (a, e) in enumerate(zip(actual, expected)):
self.assertEqual(a, e, "%s does not match with %s" % (exprs[n], expected[n]))


class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
20 changes: 19 additions & 1 deletion python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@

from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
from pyspark.sql.functions import udf
from pyspark.sql.functions import udf, assert_true, lit
from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.types import (
StringType,
@@ -36,6 +36,7 @@
StructType,
StructField,
TimestampNTZType,
DayTimeIntervalType,
)
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -607,6 +608,23 @@ def noop(x):
self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))

def test_udf_daytime_interval(self):
# SPARK-37277: Support DayTimeIntervalType in Python UDF
@udf(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.SECOND))
def noop(x):
assert x == datetime.timedelta(microseconds=123)
return x

df = self.spark.createDataFrame(
[(datetime.timedelta(microseconds=123),)], schema="td interval day to second"
).select(noop("td").alias("td"))

df.select(
assert_true(lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.td.cast("string"))
).collect()
self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))

def test_nonparam_udf_with_aggregate(self):
import pyspark.sql.functions as f

76 changes: 74 additions & 2 deletions python/pyspark/sql/types.py
@@ -18,6 +18,7 @@
import sys
import decimal
import time
import math
import datetime
import calendar
import json
@@ -65,6 +66,7 @@
"ByteType",
"IntegerType",
"LongType",
"DayTimeIntervalType",
"Row",
"ShortType",
"ArrayType",
@@ -317,6 +319,65 @@ def simpleString(self) -> str:
return "bigint"


class DayTimeIntervalType(AtomicType):
"""DayTimeIntervalType (datetime.timedelta)."""

DAY = 0
HOUR = 1
MINUTE = 2
SECOND = 3

_fields = {
DAY: "day",
HOUR: "hour",
MINUTE: "minute",
SECOND: "second",
}

_inverted_fields = dict(zip(_fields.values(), _fields.keys()))

def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
if startField is None and endField is None:
# Default matched to scala side.
startField = DayTimeIntervalType.DAY
endField = DayTimeIntervalType.SECOND
elif startField is not None and endField is None:
endField = startField

fields = DayTimeIntervalType._fields
if startField not in fields.keys() or endField not in fields.keys():
raise RuntimeError("interval %s to %s is invalid" % (startField, endField))
self.startField = cast(int, startField)
self.endField = cast(int, endField)

def _str_repr(self) -> str:
fields = DayTimeIntervalType._fields
start_field_name = fields[self.startField]
end_field_name = fields[self.endField]
if start_field_name == end_field_name:
return "interval %s" % start_field_name
else:
return "interval %s to %s" % (start_field_name, end_field_name)

simpleString = _str_repr

jsonValue = _str_repr

def __repr__(self) -> str:
return "%s(%d,%d)" % (type(self).__name__, self.startField, self.endField)

def needConversion(self) -> bool:
return True

def toInternal(self, dt: datetime.timedelta) -> Optional[int]:
if dt is not None:
return (math.floor(dt.total_seconds()) * 1000000) + dt.microseconds

def fromInternal(self, micros: int) -> Optional[datetime.timedelta]:
if micros is not None:
return datetime.timedelta(microseconds=micros)


class ShortType(IntegralType):
"""Short data type, i.e. a signed 16-bit integer."""

@@ -905,6 +966,7 @@ def __eq__(self, other: Any) -> bool:


_FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
_INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")


def _parse_datatype_string(s: str) -> DataType:
@@ -1034,11 +1096,17 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
return _all_atomic_types[json_value]()
elif json_value == "decimal":
return DecimalType()
elif json_value == "timestamp_ntz":
return TimestampNTZType()
elif _FIXED_DECIMAL.match(json_value):
m = _FIXED_DECIMAL.match(json_value)
return DecimalType(int(m.group(1)), int(m.group(2))) # type: ignore[union-attr]
elif _INTERVAL_DAYTIME.match(json_value):
m = _INTERVAL_DAYTIME.match(json_value)
inverted_fields = DayTimeIntervalType._inverted_fields
first_field = inverted_fields.get(m.group(1)) # type: ignore[union-attr]
second_field = inverted_fields.get(m.group(3)) # type: ignore[union-attr]
if first_field is not None and second_field is None:
return DayTimeIntervalType(first_field)
return DayTimeIntervalType(first_field, second_field)
else:
raise ValueError("Could not parse datatype: %s" % json_value)
else:
@@ -1063,6 +1131,7 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
datetime.date: DateType,
datetime.datetime: TimestampType, # can be TimestampNTZType
datetime.time: TimestampType, # can be TimestampNTZType
datetime.timedelta: DayTimeIntervalType,
bytes: BinaryType,
}

@@ -1163,6 +1232,8 @@ def _infer_type(
return DecimalType(38, 18)
if dataType is TimestampType and prefer_timestamp_ntz and obj.tzinfo is None:
return TimestampNTZType()
if dataType is DayTimeIntervalType:
return DayTimeIntervalType()
elif dataType is not None:
return dataType()

@@ -1409,6 +1480,7 @@ def convert_struct(obj: Any) -> Optional[Tuple]:
DateType: (datetime.date, datetime.datetime),
TimestampType: (datetime.datetime,),
TimestampNTZType: (datetime.datetime,),
DayTimeIntervalType: (datetime.timedelta,),
ArrayType: (list, tuple, array),
MapType: (dict,),
StructType: (tuple, list, dict),
@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.UTF8String
object EvaluatePython {

def needConversionInPython(dt: DataType): Boolean = dt match {
case DateType | TimestampType | TimestampNTZType => true
case DateType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => true
case _: StructType => true
case _: UserDefinedType[_] => true
case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -137,11 +137,12 @@ object EvaluatePython {
case c: Int => c
}

case TimestampType | TimestampNTZType => (obj: Any) => nullSafeConvert(obj) {
case c: Long => c
// Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
case c: Int => c.toLong
}
case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
nullSafeConvert(obj) {
case c: Long => c
// Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
case c: Int => c.toLong
}

case StringType => (obj: Any) => nullSafeConvert(obj) {
case _ => UTF8String.fromString(obj.toString)
