Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-34761][SQL] Support add/subtract of a day-time interval to/from a timestamp #31855

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,12 @@ class Analyzer(override val catalogManager: CatalogManager)
case (YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
case (TimestampType, YearMonthIntervalType) => TimestampAddYMInterval(l, r)
case (YearMonthIntervalType, TimestampType) => TimestampAddYMInterval(r, l)
case (CalendarIntervalType, CalendarIntervalType) => a
case (CalendarIntervalType, CalendarIntervalType) |
(DayTimeIntervalType, DayTimeIntervalType) => a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we allow adding year-month interval as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The op is allowed already, see the test https://github.com/apache/spark/pull/31789/files#diff-bff10c8a3182aa943d8927135e0c14b02a338d9bcca94ddcd77670ee01fea0f3R2380 . It should still pass after the changes.

I have to add this cases because I reuse the TimeAdd expression, and associated rules here.

The year-month interval +/- year-month interval expr is handled by the default case case _ => a/case _ => s.

case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f)
case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (_, CalendarIntervalType | DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) => DateAddInterval(r, l, ansiEnabled = f)
case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (CalendarIntervalType | DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
Expand All @@ -356,10 +357,11 @@ class Analyzer(override val catalogManager: CatalogManager)
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f)))
case (TimestampType, YearMonthIntervalType) =>
DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f)))
case (CalendarIntervalType, CalendarIntervalType) => s
case (CalendarIntervalType, CalendarIntervalType) |
(DayTimeIntervalType, DayTimeIntervalType) => s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The YearMonthIntervalType is handled by the default case case _ => s. We need to handle CalendarIntervalType and DayTimeIntervalType especially otherwise they will be casted to some unexpected type in the base arithmetic ops +/-

case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f))
case (_, CalendarIntervalType) =>
case (_, CalendarIntervalType | DayTimeIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, f))), l.dataType)
case (TimestampType, _) => SubtractTimestamps(l, r)
case (_, TimestampType) => SubtractTimestamps(l, r)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1264,25 +1264,33 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S

override def toString: String = s"$left + $right"
override def sql: String = s"${left.sql} + ${right.sql}"
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
override def inputTypes: Seq[AbstractDataType] =
Seq(TimestampType, TypeCollection(CalendarIntervalType, DayTimeIntervalType))

override def dataType: DataType = TimestampType

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
start.asInstanceOf[Long], itvl.months, itvl.days, itvl.microseconds, zoneId)
override def nullSafeEval(start: Any, interval: Any): Any = right.dataType match {
case DayTimeIntervalType =>
timestampAddDayTime(start.asInstanceOf[Long], interval.asInstanceOf[Long], zoneId)
case CalendarIntervalType =>
val i = interval.asInstanceOf[CalendarInterval]
timestampAddInterval(start.asInstanceOf[Long], i.months, i.days, i.microseconds, zoneId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, $i.months, $i.days, $i.microseconds, $zid)"""
})
interval.dataType match {
case DayTimeIntervalType =>
defineCodeGen(ctx, ev, (sd, dt) => s"""$dtu.timestampAddDayTime($sd, $dt, $zid)""")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, $i.months, $i.days, $i.microseconds, $zid)"""
})
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,28 @@ object DateTimeUtils {
instantToMicros(microsToInstant(micros).atZone(zoneId).plusMonths(months).toInstant)
}

/**
* Adds a day-time interval expressed in microseconds to a timestamp at the given time zone.
* It converts the input timestamp to a local timestamp, and adds the interval by:
* - Splitting the interval to days and microsecond adjustment in a day, and
* - First of all, it adds days and then the time part.
* The resulted local timestamp is converted back to an instant at the given time zone.
*
* @param micros The input timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z.
* @param dayTime The amount of microseconds to add. It can be positive or negative.
* @param zoneId The time zone ID at which the operation is performed.
* @return A timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z.
*/
def timestampAddDayTime(micros: Long, dayTime: Long, zoneId: ZoneId): Long = {
val days = dayTime / MICROS_PER_DAY
val microseconds = dayTime - days * MICROS_PER_DAY
val resultTimestamp = microsToInstant(micros)
.atZone(zoneId)
.plusDays(days)
.plus(microseconds, ChronoUnit.MICROS)
instantToMicros(resultTimestamp.toInstant)
}

/**
* Adds a full interval (months, days, microseconds) a timestamp represented as the number of
* microseconds since 1970-01-01 00:00:00Z.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}
import java.text.{ParseException, SimpleDateFormat}
import java.time.{DateTimeException, Instant, LocalDate, Period, ZoneId}
import java.time.{DateTimeException, Duration, Instant, LocalDate, Period, ZoneId}
import java.time.format.DateTimeParseException
import java.time.temporal.ChronoUnit
import java.util.{Calendar, Locale, TimeZone}
import java.util.concurrent.TimeUnit._

Expand Down Expand Up @@ -1538,4 +1539,59 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
TimestampType, YearMonthIntervalType)
}
}

test("SPARK-34761: add a day-time interval to a timestamp") {
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
for (zid <- outstandingZoneIds) {
val timeZoneId = Option(zid.getId)
sdf.setTimeZone(TimeZone.getTimeZone(zid))
checkEvaluation(
TimeAdd(
Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
Literal(Duration.ofDays(10).plusMinutes(10).plusMillis(321)),
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
timeZoneId),
DateTimeUtils.fromJavaTimestamp(
new Timestamp(sdf.parse("2021-01-11 00:10:00.444").getTime)))
checkEvaluation(
TimeAdd(
Literal(new Timestamp(sdf.parse("2021-01-01 00:10:00.123").getTime)),
Literal(Duration.ofDays(-10).minusMinutes(9).minusMillis(120)),
timeZoneId),
DateTimeUtils.fromJavaTimestamp(
new Timestamp(sdf.parse("2020-12-22 00:01:00.003").getTime)))

val e = intercept[Exception] {
checkEvaluation(
TimeAdd(
Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS)),
timeZoneId),
null)
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("long overflow"))

checkEvaluation(
TimeAdd(
Literal.create(null, TimestampType),
Literal(Duration.ofDays(1)),
timeZoneId),
null)
checkEvaluation(
TimeAdd(
Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
Literal.create(null, DayTimeIntervalType),
timeZoneId),
null)
checkEvaluation(
TimeAdd(
Literal.create(null, TimestampType),
Literal.create(null, DayTimeIntervalType),
timeZoneId),
null)
checkConsistencyBetweenInterpretedAndCodegen(
(ts: Expression, interval: Expression) => TimeAdd(ts, interval, timeZoneId),
TimestampType, DayTimeIntervalType)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,10 @@ object LiteralGenerator {
for { i <- Gen.choose(-100, 100) } yield Literal.create(i, IntegerType)

lazy val dayTimeIntervalLiteralGen: Gen[Literal] = {
for {
seconds <- Gen.choose(
Duration.ofDays(-106751990).getSeconds,
Duration.ofDays(106751990).getSeconds)
nanoAdjustment <- Gen.choose(-999999000, 999999000)
} yield {
Literal.create(Duration.ofSeconds(seconds, nanoAdjustment), DayTimeIntervalType)
calendarIntervalLiterGen.map { calendarIntervalLiteral =>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generator is used to test expressions like random timestamp + random day-time interval. I have to adjust the generator to have reasonable intervals, and prevent overflows.

Literal.create(
calendarIntervalLiteral.value.asInstanceOf[CalendarInterval].extractAsDuration(),
DayTimeIntervalType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -711,4 +711,39 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
intercept[IllegalArgumentException](getDayOfWeekFromString(UTF8String.fromString("xx")))
intercept[IllegalArgumentException](getDayOfWeekFromString(UTF8String.fromString("\"quote")))
}

test("SPARK-34761: timestamp add day-time interval") {
// transit from Pacific Standard Time to Pacific Daylight Time
assert(timestampAddDayTime(
// 2019-3-9 is the end of Pacific Standard Time
date(2019, 3, 9, 12, 0, 0, 123000, LA),
MICROS_PER_DAY, LA) ===
// 2019-3-10 is the start of Pacific Daylight Time
date(2019, 3, 10, 12, 0, 0, 123000, LA))
// just normal days
outstandingZoneIds.foreach { zid =>
assert(timestampAddDayTime(
date(2021, 3, 18, 19, 44, 1, 100000, zid), 0, zid) ===
date(2021, 3, 18, 19, 44, 1, 100000, zid))
assert(timestampAddDayTime(
date(2021, 1, 19, 0, 0, 0, 0, zid), -18 * MICROS_PER_DAY, zid) ===
date(2021, 1, 1, 0, 0, 0, 0, zid))
assert(timestampAddDayTime(
date(2021, 3, 18, 19, 44, 1, 999999, zid), 10 * MICROS_PER_MINUTE, zid) ===
date(2021, 3, 18, 19, 54, 1, 999999, zid))
assert(timestampAddDayTime(
date(2021, 3, 18, 19, 44, 1, 1, zid), -MICROS_PER_DAY - 1, zid) ===
date(2021, 3, 17, 19, 44, 1, 0, zid))
assert(timestampAddDayTime(
date(2019, 5, 9, 12, 0, 0, 123456, zid), 2 * MICROS_PER_DAY + 1, zid) ===
date(2019, 5, 11, 12, 0, 0, 123457, zid))
}
// transit from Pacific Daylight Time to Pacific Standard Time
assert(timestampAddDayTime(
// 2019-11-2 is the end of Pacific Daylight Time
date(2019, 11, 2, 12, 0, 0, 123000, LA),
MICROS_PER_DAY, LA) ===
// 2019-11-3 is the start of Pacific Standard Time
date(2019, 11, 3, 12, 0, 0, 123000, LA))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql

import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.time.temporal.ChronoUnit
import java.util.Locale

import org.apache.hadoop.io.{LongWritable, Text}
Expand Down Expand Up @@ -2526,4 +2527,51 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
assert(e.getMessage.contains("long overflow"))
}
}

test("SPARK-34761: add/subtract a day-time interval to/from a timestamp") {
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
outstandingZoneIds.foreach { zid =>
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid.getId) {
Seq(
(LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000), Duration.ofDays(0)) ->
LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000),
(LocalDateTime.of(1970, 1, 1, 0, 0, 0, 100000000), Duration.ofDays(-1)) ->
LocalDateTime.of(1969, 12, 31, 0, 0, 0, 100000000),
(LocalDateTime.of(2021, 3, 14, 1, 2, 3), Duration.ofDays(1)) ->
LocalDateTime.of(2021, 3, 15, 1, 2, 3),
(LocalDateTime.of(2020, 12, 31, 23, 59, 59, 999000000),
Duration.ofDays(2 * 30).plusMillis(1)) -> LocalDateTime.of(2021, 3, 2, 0, 0, 0),
(LocalDateTime.of(2020, 3, 16, 0, 0, 0, 1000), Duration.of(-1, ChronoUnit.MICROS)) ->
LocalDateTime.of(2020, 3, 16, 0, 0, 0),
(LocalDateTime.of(2020, 2, 29, 12, 13, 14), Duration.ofDays(365)) ->
LocalDateTime.of(2021, 2, 28, 12, 13, 14),
(LocalDateTime.of(1582, 10, 4, 1, 2, 3, 40000000),
Duration.ofDays(10).plusMillis(60)) ->
LocalDateTime.of(1582, 10, 14, 1, 2, 3, 100000000)
).foreach { case ((ldt, duration), expected) =>
val ts = ldt.atZone(zid).toInstant
val result = expected.atZone(zid).toInstant
val df = Seq((ts, duration, result)).toDF("ts", "interval", "result")
checkAnswer(
df.select($"ts" + $"interval", $"interval" + $"ts", $"result" - $"interval"),
Row(result, result, ts))
}
}
}

Seq(
"2021-03-16T18:56:00Z" -> "ts + i",
"1900-03-16T18:56:00Z" -> "ts - i").foreach { case (instant, op) =>
val e = intercept[SparkException] {
Seq(
(Instant.parse(instant), Duration.of(Long.MaxValue, ChronoUnit.MICROS)))
.toDF("ts", "i")
.selectExpr(op)
.collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("long overflow"))
}
}
}
}