[SPARK-31527][SQL] date add/subtract interval only allow those day precision in ansi mode

### What changes were proposed in this pull request?

To follow ANSI, the expressions `date + interval`, `interval + date` and `date - interval` should only accept intervals whose `microseconds` part is 0.
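
For illustration, a rough `spark-shell` sketch of the intended behavior (the result comments are approximate and the exact error text may differ):

```scala
// Sketch only: date +/- interval under spark.sql.ansi.enabled.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT date'2020-02-29' + interval 1 month").show()  // 2020-03-29
spark.sql("SELECT date'2020-02-29' + interval 1 hour").show()   // fails at runtime: sub-day precision rejected
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT date'2020-02-29' + interval 25 hours").show() // 2020-03-01, via timestamp arithmetic
```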

### Why are the changes needed?

Better ANSI compliance

### Does this PR introduce any user-facing change?

No, this PR should target 3.0.0, in which this feature is newly added.

### How was this patch tested?

Added more unit tests.

Closes #28310 from yaooqinn/SPARK-31527.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
yaooqinn authored and cloud-fan committed Apr 27, 2020
1 parent a911287 commit ebc8fa5
Showing 12 changed files with 888 additions and 21 deletions.
@@ -246,7 +246,7 @@ class Analyzer(
ResolveLambdaVariables(conf) ::
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
ResolveBinaryArithmetic(conf) ::
ResolveBinaryArithmetic ::
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -268,17 +268,21 @@ class Analyzer(
/**
* For [[Add]]:
* 1. if both sides are intervals, stays the same;
* 2. else if one side is interval, turns it to [[TimeAdd]];
* 3. else if one side is date, turns it to [[DateAdd]] ;
* 4. else stays the same.
* 2. else if one side is date and the other is interval,
* turns it to [[DateAddInterval]];
* 3. else if one side is interval, turns it to [[TimeAdd]];
* 4. else if one side is date, turns it to [[DateAdd]] ;
* 5. else stays the same.
*
* For [[Subtract]]:
* 1. if both sides are intervals, stays the same;
* 2. else if the right side is an interval, turns it to [[TimeSub]];
* 3. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 4. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 5. else if the left side is date, turns it to [[DateSub]];
* 6. else stays the same.
* 2. else if the left side is date and the right side is interval,
* turns it to [[DateAddInterval(l, -r)]];
* 3. else if the right side is an interval, turns it to [[TimeSub]];
* 4. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 6. else if the left side is date, turns it to [[DateSub]];
* 7. else stays the same.
*
* For [[Multiply]]:
* 1. If one side is interval, turns it to [[MultiplyInterval]];
@@ -288,19 +292,22 @@
* 1. If the left side is interval, turns it to [[DivideInterval]];
* 2. otherwise, stays the same.
*/
case class ResolveBinaryArithmetic(conf: SQLConf) extends Rule[LogicalPlan] {
object ResolveBinaryArithmetic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan => p.transformExpressionsUp {
case a @ Add(l, r) if a.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => a
case (DateType, CalendarIntervalType) => DateAddInterval(l, r)
case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) => DateAddInterval(r, l)
case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r) if s.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => s
case (DateType, CalendarIntervalType) => DateAddInterval(l, UnaryMinus(r))
case (_, CalendarIntervalType) => Cast(TimeSub(l, r), l.dataType)
case (TimestampType, _) => SubtractTimestamps(l, r)
case (_, TimestampType) => SubtractTimestamps(l, r)
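
As a quick illustration of how these rewrites surface to users, a hypothetical `spark-shell` sketch (the resolved-expression notes paraphrase the rule's doc comment above; real plan text will differ in detail):

```scala
// Sketch: which expression each arithmetic form is expected to resolve to.
spark.sql("SELECT date'2020-01-31' + interval 1 month").explain(true)
// date + interval      -> DateAddInterval(date, interval)
spark.sql("SELECT interval 1 month + date'2020-01-31'").explain(true)
// interval + date      -> DateAddInterval(date, interval)
spark.sql("SELECT date'2020-01-31' - interval 1 month").explain(true)
// date - interval      -> DateAddInterval(date, -interval)
spark.sql("SELECT timestamp'2020-01-31 00:00:00' + interval 1 hour").explain(true)
// timestamp + interval -> Cast(TimeAdd(timestamp, interval), TimestampType)
```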
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, Tim
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

@@ -1196,6 +1197,68 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
}
}

/**
* Adds a date and an interval.
*
* When ANSI mode is on, the microseconds part of the interval needs to be 0; otherwise a runtime
* [[IllegalArgumentException]] will be raised.
* When ANSI mode is off, if the microseconds part of the interval is 0, we perform date + interval
* for better performance. If the microseconds part is not 0, the date is converted to a
* timestamp and the whole interval is added to that timestamp.
*/
case class DateAddInterval(
start: Expression,
interval: Expression,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression {

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left + $right"
override def sql: String = s"${left.sql} + ${right.sql}"
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, CalendarIntervalType)

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
if (ansiEnabled || itvl.microseconds == 0) {
DateTimeUtils.dateAddInterval(start.asInstanceOf[Int], itvl)
} else {
val startTs = DateTimeUtils.epochDaysToMicros(start.asInstanceOf[Int], zoneId)
val resultTs = DateTimeUtils.timestampAddInterval(
startTs, itvl.months, itvl.days, itvl.microseconds, zoneId)
DateTimeUtils.microsToEpochDays(resultTs, zoneId)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (sd, i) => if (ansiEnabled) {
s"""${ev.value} = $dtu.dateAddInterval($sd, $i);"""
} else {
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val startTs = ctx.freshName("startTs")
val resultTs = ctx.freshName("resultTs")
s"""
|if ($i.microseconds == 0) {
| ${ev.value} = $dtu.dateAddInterval($sd, $i);
|} else {
| long $startTs = $dtu.epochDaysToMicros($sd, $zid);
| long $resultTs =
| $dtu.timestampAddInterval($startTs, $i.months, $i.days, $i.microseconds, $zid);
| ${ev.value} = $dtu.microsToEpochDays($resultTs, $zid);
|}
|""".stripMargin
})
}

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
}

/**
* This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
* takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and
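
A small interpreted-evaluation sketch of the new expression, assuming the `spark-catalyst` classes above are on the classpath (values mirror the two code paths described in the class doc):

```scala
import java.time.ZoneId
import org.apache.spark.sql.catalyst.expressions.{DateAddInterval, Literal}
import org.apache.spark.sql.catalyst.util.{DateTimeConstants, DateTimeUtils}
import org.apache.spark.sql.types.DateType
import org.apache.spark.unsafe.types.CalendarInterval

val start = Literal(DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2016-02-28")), DateType)
val zone = Some(ZoneId.systemDefault().getId)

// Day precision only: the date-only fast path is taken regardless of ANSI mode.
DateAddInterval(start, Literal(new CalendarInterval(1, 1, 0)), zone, ansiEnabled = true).eval()
// => days since 1970-01-01 for 2016-03-29

// Sub-day precision with ANSI off: the date is promoted to a timestamp first.
DateAddInterval(start,
  Literal(new CalendarInterval(1, 1, 25 * DateTimeConstants.MICROS_PER_HOUR)),
  zone, ansiEnabled = false).eval()
// => days since 1970-01-01 for 2016-03-30

// The same sub-day interval with ansiEnabled = true throws IllegalArgumentException.
```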
@@ -618,6 +618,22 @@ object DateTimeUtils {
instantToMicros(resultTimestamp.toInstant)
}

/**
* Add the date and the interval's months and days.
* Returns a date value, expressed in days since 1.1.1970.
*
* @throws DateTimeException if the result exceeds the supported date range
* @throws IllegalArgumentException if the interval has `microseconds` part
*/
def dateAddInterval(
start: SQLDate,
interval: CalendarInterval): SQLDate = {
require(interval.microseconds == 0,
"Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
val ld = LocalDate.ofEpochDay(start).plusMonths(interval.months).plusDays(interval.days)
localDateToDays(ld)
}

/**
* Returns number of months between time1 and time2. time1 and time2 are expressed in
* microseconds since 1.1.1970. If time1 is later than time2, the result is positive.
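
A worked example of the new helper, as a self-contained sketch assuming the `spark-catalyst` jar is on the classpath (same values as the DateTimeUtilsSuite test further down):

```scala
import java.time.LocalDate
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.CalendarInterval

// 1997-02-28 plus 36 months is 2000-02-28; plus 47 more days lands on 2000-04-15.
val start = LocalDate.of(1997, 2, 28).toEpochDay.toInt
val result = DateTimeUtils.dateAddInterval(start, new CalendarInterval(36, 47, 0))
assert(LocalDate.ofEpochDay(result) == LocalDate.of(2000, 4, 15))

// Any non-zero microseconds part is rejected:
// DateTimeUtils.dateAddInterval(start, new CalendarInterval(36, 47, 1))  // IllegalArgumentException
```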
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkFunSuite, SparkUpgradeException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -358,6 +359,40 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType)
}

test("date add interval") {
val d = Date.valueOf("2016-02-28")
Seq("true", "false") foreach { flag =>
withSQLConf((SQLConf.ANSI_ENABLED.key, flag)) {
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(0, 1, 0))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 0))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-29")))
checkEvaluation(DateAddInterval(Literal(d), Literal.create(null, CalendarIntervalType)),
null)
checkEvaluation(DateAddInterval(Literal.create(null, DateType),
Literal(new CalendarInterval(1, 1, 0))),
null)
}
}

withSQLConf((SQLConf.ANSI_ENABLED.key, "true")) {
checkExceptionInExpression[IllegalArgumentException](
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))),
"Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
}

withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-29")))
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-30")))
}
}

test("date_sub") {
checkEvaluation(
DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1.toByte)),
@@ -176,5 +176,9 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite {
TimeSub('a, interval),
"`a` - INTERVAL '1 hours'"
)
checkSQL(
DateAddInterval('a, interval),
"`a` + INTERVAL '1 hours'"
)
}
}
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {

@@ -391,6 +391,14 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
assert(dateAddMonths(input, -13) === days(1996, 1, 28))
}

test("date add interval with day precision") {
val input = days(1997, 2, 28, 10, 30)
assert(dateAddInterval(input, new CalendarInterval(36, 0, 0)) === days(2000, 2, 28))
assert(dateAddInterval(input, new CalendarInterval(36, 47, 0)) === days(2000, 4, 15))
assert(dateAddInterval(input, new CalendarInterval(-13, 0, 0)) === days(1996, 1, 28))
intercept[IllegalArgumentException](dateAddInterval(input, new CalendarInterval(36, 47, 1)))
}

test("timestamp add months") {
val ts1 = date(1997, 2, 28, 10, 30, 0)
val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
@@ -0,0 +1 @@
--IMPORT datetime.sql
