Skip to content

Commit

Permalink
[SPARK-34739][SQL] Support add/subtract of a year-month interval to/f…
Browse files Browse the repository at this point in the history
…rom a timestamp

### What changes were proposed in this pull request?
Support `timestamp +/- year-month interval`. In the PR, I propose to introduce new binary expression `TimestampAddYMInterval` similarly to `DateAddYMInterval`. It invokes new method `timestampAddMonths` from `DateTimeUtils` by passing a timestamp as an offset in microseconds since the epoch, amount of months from the giveb year-month interval, and the time zone ID in which the operation is performed. The `timestampAddMonths()` method converts the input microseconds to a local timestamp, adds months to it, and converts the results back to an instant in microseconds at the given time zone.

### Why are the changes needed?
To conform the ANSI SQL standard which requires to support such operation over timestamps and intervals:
<img width="811" alt="Screenshot 2021-03-12 at 11 36 14" src="https://user-images.githubusercontent.com/1580697/111081674-865d4900-8515-11eb-86c8-3538ecaf4804.png">

### Does this PR introduce _any_ user-facing change?
Should not since new intervals have not been released yet.

### How was this patch tested?
By running new tests:
```
$ build/sbt "test:testOnly *DateTimeUtilsSuite"
$ build/sbt "test:testOnly *DateExpressionsSuite"
$ build/sbt "test:testOnly *ColumnExpressionSuite"
```

Closes #31832 from MaxGekk/timestamp-add-year-month-interval.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
MaxGekk committed Mar 15, 2021
1 parent 363a7f0 commit 9809a2f
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ class Analyzer(override val catalogManager: CatalogManager)
case a @ Add(l, r, f) if a.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, YearMonthIntervalType) => DateAddYMInterval(l, r)
case (YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
case (TimestampType, YearMonthIntervalType) => TimestampAddYMInterval(l, r)
case (YearMonthIntervalType, TimestampType) => TimestampAddYMInterval(r, l)
case (CalendarIntervalType, CalendarIntervalType) => a
case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f)
case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
Expand All @@ -352,6 +354,8 @@ class Analyzer(override val catalogManager: CatalogManager)
case s @ Subtract(l, r, f) if s.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, YearMonthIntervalType) =>
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f)))
case (TimestampType, YearMonthIntervalType) =>
DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f)))
case (CalendarIntervalType, CalendarIntervalType) => s
case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1522,6 +1522,40 @@ case class DateAddYMInterval(date: Expression, interval: Expression) extends Add
override def sql: String = s"${left.sql} + ${right.sql}"
}

// Adds the year-month interval to the timestamp
case class TimestampAddYMInterval(
timestamp: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant {

def this(timestamp: Expression, interval: Expression) = this(timestamp, interval, None)

override def left: Expression = timestamp
override def right: Expression = interval

override def toString: String = s"$left + $right"
override def sql: String = s"${left.sql} + ${right.sql}"
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, YearMonthIntervalType)

override def dataType: DataType = TimestampType

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(micros: Any, months: Any): Any = {
timestampAddMonths(micros.asInstanceOf[Long], months.asInstanceOf[Int], zoneId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (micros, months) => {
s"""$dtu.timestampAddMonths($micros, $months, $zid)"""
})
}
}

/**
* Returns number of months between times `timestamp1` and `timestamp2`.
* If `timestamp1` is later than `timestamp2`, then the result is positive.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,20 @@ object DateTimeUtils {
localDateToDays(daysToLocalDate(days).plusMonths(months))
}

/**
* Adds months to a timestamp at the given time zone. It converts the input timestamp to a local
* timestamp at the given time zone, adds months, and converts the resulted local timestamp
* back to a timestamp, expressed in microseconds since 1970-01-01 00:00:00Z.
*
* @param micros The input timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z
* @param months The amount of months to add. It can be positive or negative.
* @param zoneId The time zone ID at which the operation is performed.
* @return A timestamp value, expressed in microseconds since 1970-01-01 00:00:00Z.
*/
def timestampAddMonths(micros: Long, months: Int, zoneId: ZoneId): Long = {
instantToMicros(microsToInstant(micros).atZone(zoneId).plusMonths(months).toInstant)
}

/**
* Adds a full interval (months, days, microseconds) a timestamp represented as the number of
* microseconds since 1970-01-01 00:00:00Z.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1500,4 +1500,42 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
}

test("SPARK-34739: add a year-month interval to a timestamp") {
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
for (zid <- outstandingZoneIds) {
val timeZoneId = Option(zid.getId)
sdf.setTimeZone(TimeZone.getTimeZone(zid))

checkEvaluation(
TimestampAddYMInterval(
Literal(new Timestamp(sdf.parse("2016-01-29 10:11:12.123").getTime)),
Literal(Period.ofMonths(2)),
timeZoneId),
DateTimeUtils.fromJavaTimestamp(
new Timestamp(sdf.parse("2016-03-29 10:11:12.123").getTime)))

checkEvaluation(
TimestampAddYMInterval(
Literal.create(null, TimestampType),
Literal(Period.ofMonths(1)),
timeZoneId),
null)
checkEvaluation(
TimestampAddYMInterval(
Literal(new Timestamp(sdf.parse("2016-01-29 10:00:00.000").getTime)),
Literal.create(null, YearMonthIntervalType),
timeZoneId),
null)
checkEvaluation(
TimestampAddYMInterval(
Literal.create(null, TimestampType),
Literal.create(null, YearMonthIntervalType),
timeZoneId),
null)
checkConsistencyBetweenInterpretedAndCodegen(
(ts: Expression, interval: Expression) => TimestampAddYMInterval(ts, interval, timeZoneId),
TimestampType, YearMonthIntervalType)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,22 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
assert(dateAddMonths(input, -13) === days(1996, 1, 28))
}

test("SPARK-34739: timestamp add months") {
outstandingZoneIds.foreach { zid =>
Seq(
(date(2021, 3, 13, 21, 28, 13, 123456, zid), 0, date(2021, 3, 13, 21, 28, 13, 123456, zid)),
(date(2021, 3, 31, 0, 0, 0, 123, zid), -1, date(2021, 2, 28, 0, 0, 0, 123, zid)),
(date(2020, 2, 29, 1, 2, 3, 4, zid), 12, date(2021, 2, 28, 1, 2, 3, 4, zid)),
(date(1, 1, 1, 0, 0, 0, 1, zid), 2020 * 12, date(2021, 1, 1, 0, 0, 0, 1, zid)),
(date(1581, 10, 7, 23, 59, 59, 999, zid), 12, date(1582, 10, 7, 23, 59, 59, 999, zid)),
(date(9999, 12, 31, 23, 59, 59, 999999, zid), -11,
date(9999, 1, 31, 23, 59, 59, 999999, zid))
).foreach { case (timestamp, months, expected) =>
assert(timestampAddMonths(timestamp, months, zid) === expected)
}
}
}

test("date add interval with day precision") {
val input = days(1997, 2, 28)
assert(dateAddInterval(input, new CalendarInterval(36, 0, 0)) === days(2000, 2, 28))
Expand All @@ -401,7 +417,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
intercept[IllegalArgumentException](dateAddInterval(input, new CalendarInterval(36, 47, 1)))
}

test("timestamp add months") {
test("timestamp add interval") {
val ts1 = date(1997, 2, 28, 10, 30, 0)
val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
assert(timestampAddInterval(ts1, 36, 0, 123000, defaultZoneId) === ts2)
Expand Down
132 changes: 109 additions & 23 deletions sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql

import java.sql.{Date, Timestamp}
import java.time.{Duration, LocalDate, Period}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.util.Locale

import org.apache.hadoop.io.{LongWritable, Text}
Expand All @@ -28,6 +28,7 @@ import org.scalatest.matchers.should.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.UpdateFieldsBenchmark._
import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingTimezonesIds, outstandingZoneIds}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -2391,17 +2392,23 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {

test("SPARK-34721: add a year-month interval to a date") {
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
Seq(
(LocalDate.of(1900, 10, 1), Period.ofMonths(0)) -> LocalDate.of(1900, 10, 1),
(LocalDate.of(1970, 1, 1), Period.ofMonths(-1)) -> LocalDate.of(1969, 12, 1),
(LocalDate.of(2021, 3, 11), Period.ofMonths(1)) -> LocalDate.of(2021, 4, 11),
(LocalDate.of(2020, 12, 31), Period.ofMonths(2)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(2021, 5, 31), Period.ofMonths(-3)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(2020, 2, 29), Period.ofYears(1)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(1, 1, 1), Period.ofYears(2020)) -> LocalDate.of(2021, 1, 1)
).foreach { case ((date, period), result) =>
val df = Seq((date, period)).toDF("date", "interval")
checkAnswer(df.select($"date" + $"interval", $"interval" + $"date"), Row(result, result))
outstandingTimezonesIds.foreach { zid =>
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid) {
Seq(
(LocalDate.of(1900, 10, 1), Period.ofMonths(0)) -> LocalDate.of(1900, 10, 1),
(LocalDate.of(1970, 1, 1), Period.ofMonths(-1)) -> LocalDate.of(1969, 12, 1),
(LocalDate.of(2021, 3, 11), Period.ofMonths(1)) -> LocalDate.of(2021, 4, 11),
(LocalDate.of(2020, 12, 31), Period.ofMonths(2)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(2021, 5, 31), Period.ofMonths(-3)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(2020, 2, 29), Period.ofYears(1)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(1, 1, 1), Period.ofYears(2020)) -> LocalDate.of(2021, 1, 1)
).foreach { case ((date, period), result) =>
val df = Seq((date, period)).toDF("date", "interval")
checkAnswer(
df.select($"date" + $"interval", $"interval" + $"date"),
Row(result, result))
}
}
}

val e = intercept[SparkException] {
Expand All @@ -2417,17 +2424,21 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {

test("SPARK-34721: subtract a year-month interval from a date") {
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
Seq(
(LocalDate.of(1582, 10, 4), Period.ofMonths(0)) -> LocalDate.of(1582, 10, 4),
(LocalDate.of(1582, 10, 15), Period.ofMonths(1)) -> LocalDate.of(1582, 9, 15),
(LocalDate.of(1, 1, 1), Period.ofMonths(-1)) -> LocalDate.of(1, 2, 1),
(LocalDate.of(9999, 10, 31), Period.ofMonths(-2)) -> LocalDate.of(9999, 12, 31),
(LocalDate.of(2021, 5, 31), Period.ofMonths(3)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(2021, 2, 28), Period.ofYears(1)) -> LocalDate.of(2020, 2, 28),
(LocalDate.of(2020, 2, 29), Period.ofYears(4)) -> LocalDate.of(2016, 2, 29)
).foreach { case ((date, period), result) =>
val df = Seq((date, period)).toDF("date", "interval")
checkAnswer(df.select($"date" - $"interval"), Row(result))
outstandingTimezonesIds.foreach { zid =>
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid) {
Seq(
(LocalDate.of(1582, 10, 4), Period.ofMonths(0)) -> LocalDate.of(1582, 10, 4),
(LocalDate.of(1582, 10, 15), Period.ofMonths(1)) -> LocalDate.of(1582, 9, 15),
(LocalDate.of(1, 1, 1), Period.ofMonths(-1)) -> LocalDate.of(1, 2, 1),
(LocalDate.of(9999, 10, 31), Period.ofMonths(-2)) -> LocalDate.of(9999, 12, 31),
(LocalDate.of(2021, 5, 31), Period.ofMonths(3)) -> LocalDate.of(2021, 2, 28),
(LocalDate.of(2021, 2, 28), Period.ofYears(1)) -> LocalDate.of(2020, 2, 28),
(LocalDate.of(2020, 2, 29), Period.ofYears(4)) -> LocalDate.of(2016, 2, 29)
).foreach { case ((date, period), result) =>
val df = Seq((date, period)).toDF("date", "interval")
checkAnswer(df.select($"date" - $"interval"), Row(result))
}
}
}

val e = intercept[SparkException] {
Expand All @@ -2440,4 +2451,79 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
assert(e.getMessage.contains("integer overflow"))
}
}

test("SPARK-34739: add a year-month interval to a timestamp") {
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
outstandingZoneIds.foreach { zid =>
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid.getId) {
Seq(
(LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000), Period.ofMonths(0)) ->
LocalDateTime.of(1900, 1, 1, 0, 0, 0, 123456000),
(LocalDateTime.of(1970, 1, 1, 0, 0, 0, 1000), Period.ofMonths(-1)) ->
LocalDateTime.of(1969, 12, 1, 0, 0, 0, 1000),
(LocalDateTime.of(2021, 3, 14, 1, 2, 3, 0), Period.ofMonths(1)) ->
LocalDateTime.of(2021, 4, 14, 1, 2, 3, 0),
(LocalDateTime.of(2020, 12, 31, 23, 59, 59, 999999000), Period.ofMonths(2)) ->
LocalDateTime.of(2021, 2, 28, 23, 59, 59, 999999000),
(LocalDateTime.of(2021, 5, 31, 0, 0, 1, 0), Period.ofMonths(-3)) ->
LocalDateTime.of(2021, 2, 28, 0, 0, 1, 0),
(LocalDateTime.of(2020, 2, 29, 12, 13, 14), Period.ofYears(1)) ->
LocalDateTime.of(2021, 2, 28, 12, 13, 14),
(LocalDateTime.of(1, 1, 1, 1, 1, 1, 1000), Period.ofYears(2020)) ->
LocalDateTime.of(2021, 1, 1, 1, 1, 1, 1000)
).foreach { case ((ldt, period), expected) =>
val df = Seq((ldt.atZone(zid).toInstant, period)).toDF("ts", "interval")
val result = expected.atZone(zid).toInstant
checkAnswer(df.select($"ts" + $"interval", $"interval" + $"ts"), Row(result, result))
}
}
}

val e = intercept[SparkException] {
Seq((Instant.parse("2021-03-14T18:55:00Z"), Period.ofMonths(Int.MaxValue)))
.toDF("ts", "interval")
.select($"ts" + $"interval")
.collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("long overflow"))
}
}

test("SPARK-34739: subtract a year-month interval from a timestamp") {
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
outstandingZoneIds.foreach { zid =>
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid.getId) {
Seq(
(LocalDateTime.of(1582, 10, 4, 0, 0, 0), Period.ofMonths(0)) ->
LocalDateTime.of(1582, 10, 4, 0, 0, 0),
(LocalDateTime.of(1582, 10, 15, 23, 59, 59, 999999000), Period.ofMonths(1)) ->
LocalDateTime.of(1582, 9, 15, 23, 59, 59, 999999000),
(LocalDateTime.of(1, 1, 1, 1, 1, 1, 1000), Period.ofMonths(-1)) ->
LocalDateTime.of(1, 2, 1, 1, 1, 1, 1000),
(LocalDateTime.of(9999, 10, 31, 23, 59, 59, 999000000), Period.ofMonths(-2)) ->
LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999000000),
(LocalDateTime.of(2021, 5, 31, 0, 0, 0, 1000), Period.ofMonths(3)) ->
LocalDateTime.of(2021, 2, 28, 0, 0, 0, 1000),
(LocalDateTime.of(2021, 2, 28, 11, 12, 13, 123456000), Period.ofYears(1)) ->
LocalDateTime.of(2020, 2, 28, 11, 12, 13, 123456000),
(LocalDateTime.of(2020, 2, 29, 1, 2, 3, 5000), Period.ofYears(4)) ->
LocalDateTime.of(2016, 2, 29, 1, 2, 3, 5000)
).foreach { case ((ldt, period), expected) =>
val df = Seq((ldt.atZone(zid).toInstant, period)).toDF("ts", "interval")
checkAnswer(df.select($"ts" - $"interval"), Row(expected.atZone(zid).toInstant))
}
}
}

val e = intercept[SparkException] {
Seq((Instant.parse("2021-03-14T18:55:00Z"), Period.ofMonths(Int.MaxValue)))
.toDF("ts", "interval")
.select($"ts" - $"interval")
.collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("long overflow"))
}
}
}

0 comments on commit 9809a2f

Please sign in to comment.