diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1e42e6bf4df0e..928335ad414f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -342,11 +342,12 @@ class Analyzer(override val catalogManager: CatalogManager)
         case (YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
         case (TimestampType, YearMonthIntervalType) => TimestampAddYMInterval(l, r)
         case (YearMonthIntervalType, TimestampType) => TimestampAddYMInterval(r, l)
-        case (CalendarIntervalType, CalendarIntervalType) => a
+        case (CalendarIntervalType, CalendarIntervalType) |
+            (DayTimeIntervalType, DayTimeIntervalType) => a
         case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f)
-        case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
+        case (_, CalendarIntervalType | DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
         case (CalendarIntervalType, DateType) => DateAddInterval(r, l, ansiEnabled = f)
-        case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
+        case (CalendarIntervalType | DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
         case (DateType, dt) if dt != StringType => DateAdd(l, r)
         case (dt, DateType) if dt != StringType => DateAdd(r, l)
         case _ => a
@@ -356,10 +357,11 @@ class Analyzer(override val catalogManager: CatalogManager)
           DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f)))
         case (TimestampType, YearMonthIntervalType) =>
           DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f)))
-        case (CalendarIntervalType, CalendarIntervalType) => s
+        case (CalendarIntervalType, CalendarIntervalType) |
+            (DayTimeIntervalType, DayTimeIntervalType) => s
         case (DateType, CalendarIntervalType) =>
           DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f))
-        case (_, CalendarIntervalType) =>
+        case (_, CalendarIntervalType | DayTimeIntervalType) =>
           Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, f))), l.dataType)
         case (TimestampType, _) => SubtractTimestamps(l, r)
         case (_, TimestampType) => SubtractTimestamps(l, r)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 1e2a2b74aa67d..2aaa3aa68f764 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -1264,25 +1264,33 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S

   override def toString: String = s"$left + $right"
   override def sql: String = s"${left.sql} + ${right.sql}"
-  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TimestampType, TypeCollection(CalendarIntervalType, DayTimeIntervalType))

   override def dataType: DataType = TimestampType

   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))

-  override def nullSafeEval(start: Any, interval: Any): Any = {
-    val itvl = interval.asInstanceOf[CalendarInterval]
-    DateTimeUtils.timestampAddInterval(
-      start.asInstanceOf[Long], itvl.months, itvl.days, itvl.microseconds, zoneId)
+  override def nullSafeEval(start: Any, interval: Any): Any = right.dataType match {
+    case DayTimeIntervalType =>
+      timestampAddDayTime(start.asInstanceOf[Long], interval.asInstanceOf[Long], zoneId)
+    case CalendarIntervalType =>
+      val i = interval.asInstanceOf[CalendarInterval]
+      timestampAddInterval(start.asInstanceOf[Long], i.months, i.days, i.microseconds, zoneId)
   }

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
     val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-    defineCodeGen(ctx, ev, (sd, i) => {
-      s"""$dtu.timestampAddInterval($sd, $i.months, $i.days, $i.microseconds, $zid)"""
-    })
+    interval.dataType match {
+      case DayTimeIntervalType =>
+        defineCodeGen(ctx, ev, (sd, dt) => s"""$dtu.timestampAddDayTime($sd, $dt, $zid)""")
+      case CalendarIntervalType =>
+        defineCodeGen(ctx, ev, (sd, i) => {
+          s"""$dtu.timestampAddInterval($sd, $i.months, $i.days, $i.microseconds, $zid)"""
+        })
+    }
   }
 }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index b55b82a07b00c..2139467ac266e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -591,6 +591,28 @@ object DateTimeUtils {
     instantToMicros(microsToInstant(micros).atZone(zoneId).plusMonths(months).toInstant)
   }

+  /**
+   * Adds a day-time interval expressed in microseconds to a timestamp at the given time zone.
+   * It converts the input timestamp to a local timestamp and adds the interval by:
+   *   - splitting the interval into whole days and a microsecond adjustment within a day, and
+   *   - adding the days first, then the time part.
+   * The resulting local timestamp is converted back to an instant at the given time zone.
+   *
+   * @param micros The input timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z.
+   * @param dayTime The amount of microseconds to add. It can be positive or negative.
+   * @param zoneId The time zone ID at which the operation is performed.
+   * @return A timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z.
+   */
+  def timestampAddDayTime(micros: Long, dayTime: Long, zoneId: ZoneId): Long = {
+    val days = dayTime / MICROS_PER_DAY
+    val microseconds = dayTime - days * MICROS_PER_DAY
+    val resultTimestamp = microsToInstant(micros)
+      .atZone(zoneId)
+      .plusDays(days)
+      .plus(microseconds, ChronoUnit.MICROS)
+    instantToMicros(resultTimestamp.toInstant)
+  }
+
   /**
    * Adds a full interval (months, days, microseconds) a timestamp represented as the number of
    * microseconds since 1970-01-01 00:00:00Z.
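The new DateTimeUtils.timestampAddDayTime above splits the interval into calendar days plus a time-of-day remainder. The standalone sketch below, in plain java.time and not part of the patch (the object and value names are illustrative only), mirrors the same algorithm and shows why the day component goes through plusDays() instead of being added as raw microseconds: calendar-day arithmetic follows daylight-saving transitions, which is also what the DateTimeUtilsSuite test further down asserts for America/Los_Angeles.

    import java.time.{Instant, ZoneId}
    import java.time.temporal.ChronoUnit

    object DayTimeAddSketch {
      private val MicrosPerDay: Long = 24L * 60 * 60 * 1000 * 1000

      // Same split as timestampAddDayTime: whole days via calendar arithmetic,
      // the remainder as an exact number of microseconds.
      def addDayTime(instant: Instant, dayTimeMicros: Long, zone: ZoneId): Instant = {
        val days = dayTimeMicros / MicrosPerDay
        val micros = dayTimeMicros - days * MicrosPerDay
        instant.atZone(zone)
          .plusDays(days)                  // a calendar day may be 23, 24 or 25 hours long
          .plus(micros, ChronoUnit.MICROS) // time-of-day remainder is an exact duration
          .toInstant
      }

      def main(args: Array[String]): Unit = {
        val la = ZoneId.of("America/Los_Angeles")
        // 2019-03-09 12:00 local time, the day before the PST -> PDT transition.
        val beforeDst = Instant.parse("2019-03-09T20:00:00Z")
        // Prints 2019-03-10T19:00:00Z: the local wall clock stays at 12:00, so only
        // 23 hours elapse on the UTC timeline; adding MicrosPerDay as a plain duration
        // would have produced 13:00 local time instead.
        println(addDayTime(beforeDst, MicrosPerDay, la))
      }
    }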
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 65f54c165786e..786c3021fcaad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions

 import java.sql.{Date, Timestamp}
 import java.text.{ParseException, SimpleDateFormat}
-import java.time.{DateTimeException, Instant, LocalDate, Period, ZoneId}
+import java.time.{DateTimeException, Duration, Instant, LocalDate, Period, ZoneId}
 import java.time.format.DateTimeParseException
+import java.time.temporal.ChronoUnit
 import java.util.{Calendar, Locale, TimeZone}
 import java.util.concurrent.TimeUnit._

@@ -1538,4 +1539,59 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       TimestampType, YearMonthIntervalType)
     }
   }
+
+  test("SPARK-34761: add a day-time interval to a timestamp") {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
+    for (zid <- outstandingZoneIds) {
+      val timeZoneId = Option(zid.getId)
+      sdf.setTimeZone(TimeZone.getTimeZone(zid))
+      checkEvaluation(
+        TimeAdd(
+          Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
+          Literal(Duration.ofDays(10).plusMinutes(10).plusMillis(321)),
+          timeZoneId),
+        DateTimeUtils.fromJavaTimestamp(
+          new Timestamp(sdf.parse("2021-01-11 00:10:00.444").getTime)))
+      checkEvaluation(
+        TimeAdd(
+          Literal(new Timestamp(sdf.parse("2021-01-01 00:10:00.123").getTime)),
+          Literal(Duration.ofDays(-10).minusMinutes(9).minusMillis(120)),
+          timeZoneId),
+        DateTimeUtils.fromJavaTimestamp(
+          new Timestamp(sdf.parse("2020-12-22 00:01:00.003").getTime)))
+
+      val e = intercept[Exception] {
+        checkEvaluation(
+          TimeAdd(
+            Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
+            Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS)),
+            timeZoneId),
+          null)
+      }.getCause
+      assert(e.isInstanceOf[ArithmeticException])
+      assert(e.getMessage.contains("long overflow"))
+
+      checkEvaluation(
+        TimeAdd(
+          Literal.create(null, TimestampType),
+          Literal(Duration.ofDays(1)),
+          timeZoneId),
+        null)
+      checkEvaluation(
+        TimeAdd(
+          Literal(new Timestamp(sdf.parse("2021-01-01 00:00:00.123").getTime)),
+          Literal.create(null, DayTimeIntervalType),
+          timeZoneId),
+        null)
+      checkEvaluation(
+        TimeAdd(
+          Literal.create(null, TimestampType),
+          Literal.create(null, DayTimeIntervalType),
+          timeZoneId),
+        null)
+      checkConsistencyBetweenInterpretedAndCodegen(
+        (ts: Expression, interval: Expression) => TimeAdd(ts, interval, timeZoneId),
+        TimestampType, DayTimeIntervalType)
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
index 26ff847d7c6c6..35ddc1d7bd9b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
@@ -164,13 +164,10 @@ object LiteralGenerator {
     for { i <- Gen.choose(-100, 100) } yield Literal.create(i, IntegerType)

   lazy val dayTimeIntervalLiteralGen: Gen[Literal] = {
-    for {
-      seconds <- Gen.choose(
-        Duration.ofDays(-106751990).getSeconds,
-        Duration.ofDays(106751990).getSeconds)
-      nanoAdjustment <- Gen.choose(-999999000, 999999000)
-    } yield {
-      Literal.create(Duration.ofSeconds(seconds, nanoAdjustment), DayTimeIntervalType)
+    calendarIntervalLiterGen.map { calendarIntervalLiteral =>
+      Literal.create(
+        calendarIntervalLiteral.value.asInstanceOf[CalendarInterval].extractAsDuration(),
+        DayTimeIntervalType)
     }
   }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 9c2a43a145007..9802f656ba4cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -711,4 +711,39 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
     intercept[IllegalArgumentException](getDayOfWeekFromString(UTF8String.fromString("xx")))
     intercept[IllegalArgumentException](getDayOfWeekFromString(UTF8String.fromString("\"quote")))
   }
+
+  test("SPARK-34761: timestamp add day-time interval") {
+    // transition from Pacific Standard Time to Pacific Daylight Time
+    assert(timestampAddDayTime(
+      // 2019-3-9 is the end of Pacific Standard Time
+      date(2019, 3, 9, 12, 0, 0, 123000, LA),
+      MICROS_PER_DAY, LA) ===
+      // 2019-3-10 is the start of Pacific Daylight Time
+      date(2019, 3, 10, 12, 0, 0, 123000, LA))
+    // just normal days
+    outstandingZoneIds.foreach { zid =>
+      assert(timestampAddDayTime(
+        date(2021, 3, 18, 19, 44, 1, 100000, zid), 0, zid) ===
+        date(2021, 3, 18, 19, 44, 1, 100000, zid))
+      assert(timestampAddDayTime(
+        date(2021, 1, 19, 0, 0, 0, 0, zid), -18 * MICROS_PER_DAY, zid) ===
+        date(2021, 1, 1, 0, 0, 0, 0, zid))
+      assert(timestampAddDayTime(
+        date(2021, 3, 18, 19, 44, 1, 999999, zid), 10 * MICROS_PER_MINUTE, zid) ===
+        date(2021, 3, 18, 19, 54, 1, 999999, zid))
+      assert(timestampAddDayTime(
+        date(2021, 3, 18, 19, 44, 1, 1, zid), -MICROS_PER_DAY - 1, zid) ===
+        date(2021, 3, 17, 19, 44, 1, 0, zid))
+      assert(timestampAddDayTime(
+        date(2019, 5, 9, 12, 0, 0, 123456, zid), 2 * MICROS_PER_DAY + 1, zid) ===
+        date(2019, 5, 11, 12, 0, 0, 123457, zid))
+    }
+    // transition from Pacific Daylight Time to Pacific Standard Time
+    assert(timestampAddDayTime(
+      // 2019-11-2 is the end of Pacific Daylight Time
+      date(2019, 11, 2, 12, 0, 0, 123000, LA),
+      MICROS_PER_DAY, LA) ===
+      // 2019-11-3 is the start of Pacific Standard Time
+      date(2019, 11, 3, 12, 0, 0, 123000, LA))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 00fa346b908d5..6ed64ece9b296 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import java.sql.{Date, Timestamp}
 import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
+import java.time.temporal.ChronoUnit
 import java.util.Locale

 import org.apache.hadoop.io.{LongWritable, Text}
@@ -2526,4 +2527,51 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       assert(e.getMessage.contains("long overflow"))
     }
   }
+
+  test("SPARK-34761: add/subtract a day-time interval to/from a timestamp") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      outstandingZoneIds.foreach { zid =>
+        withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid.getId) {
+          Seq(
+            (LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000), Duration.ofDays(0)) ->
+              LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000),
+            (LocalDateTime.of(1970, 1, 1, 0, 0, 0, 100000000), Duration.ofDays(-1)) ->
+              LocalDateTime.of(1969, 12, 31, 0, 0, 0, 100000000),
+            (LocalDateTime.of(2021, 3, 14, 1, 2, 3), Duration.ofDays(1)) ->
+              LocalDateTime.of(2021, 3, 15, 1, 2, 3),
+            (LocalDateTime.of(2020, 12, 31, 23, 59, 59, 999000000),
+              Duration.ofDays(2 * 30).plusMillis(1)) -> LocalDateTime.of(2021, 3, 2, 0, 0, 0),
+            (LocalDateTime.of(2020, 3, 16, 0, 0, 0, 1000), Duration.of(-1, ChronoUnit.MICROS)) ->
+              LocalDateTime.of(2020, 3, 16, 0, 0, 0),
+            (LocalDateTime.of(2020, 2, 29, 12, 13, 14), Duration.ofDays(365)) ->
+              LocalDateTime.of(2021, 2, 28, 12, 13, 14),
+            (LocalDateTime.of(1582, 10, 4, 1, 2, 3, 40000000),
+              Duration.ofDays(10).plusMillis(60)) ->
+              LocalDateTime.of(1582, 10, 14, 1, 2, 3, 100000000)
+          ).foreach { case ((ldt, duration), expected) =>
+            val ts = ldt.atZone(zid).toInstant
+            val result = expected.atZone(zid).toInstant
+            val df = Seq((ts, duration, result)).toDF("ts", "interval", "result")
+            checkAnswer(
+              df.select($"ts" + $"interval", $"interval" + $"ts", $"result" - $"interval"),
+              Row(result, result, ts))
+          }
+        }
+      }
+
+      Seq(
+        "2021-03-16T18:56:00Z" -> "ts + i",
+        "1900-03-16T18:56:00Z" -> "ts - i").foreach { case (instant, op) =>
+        val e = intercept[SparkException] {
+          Seq(
+            (Instant.parse(instant), Duration.of(Long.MaxValue, ChronoUnit.MICROS)))
+            .toDF("ts", "i")
+            .selectExpr(op)
+            .collect()
+        }.getCause
+        assert(e.isInstanceOf[ArithmeticException])
+        assert(e.getMessage.contains("long overflow"))
+      }
+    }
+  }
 }
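The ColumnExpressionSuite test above exercises the new behaviour through the DataFrame API. Below is a minimal end-to-end sketch of the same usage, assuming a Spark build that contains this change; the session settings, object name and column names are illustrative only and not part of the patch.

    import java.time.{Duration, Instant}

    import org.apache.spark.sql.SparkSession

    object TimestampPlusDurationExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .config("spark.sql.datetime.java8API.enabled", "true")
          .getOrCreate()
        import spark.implicits._

        // A timestamp column and a day-time interval column (java.time.Duration).
        val df = Seq(
          (Instant.parse("2021-03-18T00:00:00Z"), Duration.ofDays(10).plusMinutes(30)))
          .toDF("ts", "i")

        // "ts + i" and "i + ts" both resolve to TimeAdd via the new Analyzer rules;
        // "ts - i" is rewritten to a subtraction over TimeAdd with the negated interval.
        df.selectExpr("ts + i", "i + ts", "ts - i").show(truncate = false)

        spark.stop()
      }
    }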