Skip to content

Commit

Permalink
[SPARK-30919][SQL] Make interval multiply and divide's overflow behav…
Browse files Browse the repository at this point in the history
…ior consistent with other operations

### What changes were proposed in this pull request?

The current behavior of interval multiply and divide follows the ANSI SQL standard when overflow, it is compatible with other operations when `spark.sql.ansi.enabled` is true, but not compatible when `spark.sql.ansi.enabled` is false.

When `spark.sql.ansi.enabled` is false, as the factor is a double value, so it should use java's rounding or truncation behavior for casting double to integrals. when divided by zero, it returns `null`.  we also follow the natural rules for intervals as defined in the Gregorian calendar, so we do not add the month fraction to days but add days fraction to microseconds.

### Why are the changes needed?

Make interval multiply and divide's overflow behavior consistent with other interval operations

### Does this PR introduce any user-facing change?

no, these are new features in 3.0

### How was this patch tested?

add uts

Closes #27672 from yaooqinn/SPARK-30919.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
yaooqinn authored and cloud-fan committed Feb 25, 2020
1 parent e45f2c7 commit 761209c
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 58 deletions.
Expand Up @@ -22,6 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -112,13 +113,14 @@ object ExtractIntervalPart {

abstract class IntervalNumOperation(
interval: Expression,
num: Expression,
operation: (CalendarInterval, Double) => CalendarInterval,
operationName: String)
num: Expression)
extends BinaryExpression with ImplicitCastInputTypes with Serializable {
override def left: Expression = interval
override def right: Expression = num

protected val operation: (CalendarInterval, Double) => CalendarInterval
protected def operationName: String

override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DoubleType)
override def dataType: DataType = CalendarIntervalType

Expand All @@ -136,11 +138,29 @@ abstract class IntervalNumOperation(
override def prettyName: String = operationName.stripSuffix("Exact") + "_interval"
}

case class MultiplyInterval(interval: Expression, num: Expression)
extends IntervalNumOperation(interval, num, multiplyExact, "multiplyExact")
case class MultiplyInterval(
interval: Expression,
num: Expression,
checkOverflow: Boolean = SQLConf.get.ansiEnabled)
extends IntervalNumOperation(interval, num) {

override protected val operation: (CalendarInterval, Double) => CalendarInterval =
if (checkOverflow) multiplyExact else multiply

override protected def operationName: String = if (checkOverflow) "multiplyExact" else "multiply"
}

case class DivideInterval(
interval: Expression,
num: Expression,
checkOverflow: Boolean = SQLConf.get.ansiEnabled)
extends IntervalNumOperation(interval, num) {

override protected val operation: (CalendarInterval, Double) => CalendarInterval =
if (checkOverflow) divideExact else divide

case class DivideInterval(interval: Expression, num: Expression)
extends IntervalNumOperation(interval, num, divideExact, "divideExact")
override protected def operationName: String = if (checkOverflow) "divideExact" else "divide"
}

// scalastyle:off line.size.limit
@ExpressionDescription(
Expand Down
Expand Up @@ -405,22 +405,38 @@ object IntervalUtils {
}

/**
* Makes an interval from months, days and micros with the fractional part by
* adding the month fraction to days and the days fraction to micros.
*
* @throws ArithmeticException if the result overflows any field value
* Makes an interval from months, days and micros with the fractional part.
* The overflow style here follows the way of ansi sql standard and the natural rules for
* intervals as defined in the Gregorian calendar. Thus, the days fraction will be added
* to microseconds but the months fraction will not be added to days, and it will throw
* exception if any part overflows.
*/
private def fromDoubles(
monthsWithFraction: Double,
daysWithFraction: Double,
microsWithFraction: Double): CalendarInterval = {
val truncatedMonths = Math.toIntExact(monthsWithFraction.toLong)
val days = daysWithFraction + DAYS_PER_MONTH * (monthsWithFraction - truncatedMonths)
val truncatedDays = Math.toIntExact(days.toLong)
val micros = microsWithFraction + MICROS_PER_DAY * (days - truncatedDays)
val truncatedDays = Math.toIntExact(daysWithFraction.toLong)
val micros = microsWithFraction + MICROS_PER_DAY * (daysWithFraction - truncatedDays)
new CalendarInterval(truncatedMonths, truncatedDays, micros.round)
}

/**
* Makes an interval from months, days and micros with the fractional part.
* The overflow style here follows the way of casting [[java.lang.Double]] to integrals and the
* natural rules for intervals as defined in the Gregorian calendar. Thus, the days fraction
* will be added to microseconds but the months fraction will not be added to days, and there may
* be rounding or truncation in months(or day and microseconds) part.
*/
private def safeFromDoubles(
monthsWithFraction: Double,
daysWithFraction: Double,
microsWithFraction: Double): CalendarInterval = {
val truncatedDays = daysWithFraction.toInt
val micros = microsWithFraction + MICROS_PER_DAY * (daysWithFraction - truncatedDays)
new CalendarInterval(monthsWithFraction.toInt, truncatedDays, micros.round)
}

/**
* Unary minus, return the negated the calendar interval value.
*
Expand Down Expand Up @@ -484,6 +500,13 @@ object IntervalUtils {
new CalendarInterval(months, days, microseconds)
}

/**
* Return a new calendar interval instance of the left interval times a multiplier.
*/
def multiply(interval: CalendarInterval, num: Double): CalendarInterval = {
safeFromDoubles(num * interval.months, num * interval.days, num * interval.microseconds)
}

/**
* Return a new calendar interval instance of the left interval times a multiplier.
*
Expand All @@ -493,6 +516,14 @@ object IntervalUtils {
fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds)
}

/**
* Return a new calendar interval instance of the left interval divides by a dividend.
*/
def divide(interval: CalendarInterval, num: Double): CalendarInterval = {
if (num == 0) return null
safeFromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
}

/**
* Return a new calendar interval instance of the left interval divides by a dividend.
*
Expand Down
Expand Up @@ -199,11 +199,16 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("multiply") {
def check(interval: String, num: Double, expected: String): Unit = {
val expr = MultiplyInterval(Literal(stringToInterval(interval)), Literal(num))
def check(
interval: String,
num: Double,
expected: String,
isAnsi: Option[Boolean] = None): Unit = {
val expectedRes = safeStringToInterval(expected)
Seq("true", "false").foreach { v =>
val configs = if (isAnsi.isEmpty) Seq("true", "false") else isAnsi.map(_.toString).toSeq
configs.foreach { v =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> v) {
val expr = MultiplyInterval(Literal(stringToInterval(interval)), Literal(num))
if (expectedRes == null) {
checkExceptionInExpression[ArithmeticException](expr, expected)
} else {
Expand All @@ -220,17 +225,23 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
check("1 year 1 second", 0.5, "6 months 500 milliseconds")
check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds")
check("2 months 4 seconds", -0.5, "-1 months -2 seconds")
check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds")
check("2 months", Int.MaxValue, "integer overflow")
check("1 month 2 microseconds", 1.5, "1 months 3 microseconds")
check("2 months", Int.MaxValue, "integer overflow", Some(true))
check("2 months", Int.MaxValue, Int.MaxValue + " months", Some(false))
}

test("divide") {
def check(interval: String, num: Double, expected: String): Unit = {
val expr = DivideInterval(Literal(stringToInterval(interval)), Literal(num))
def check(
interval: String,
num: Double,
expected: String,
isAnsi: Option[Boolean] = None): Unit = {
val expectedRes = safeStringToInterval(expected)
Seq("true", "false").foreach { v =>
val configs = if (isAnsi.isEmpty) Seq("true", "false") else isAnsi.map(_.toString).toSeq
configs.foreach { v =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> v) {
if (expectedRes == null) {
val expr = DivideInterval(Literal(stringToInterval(interval)), Literal(num))
if (expected != null && expectedRes == null) {
checkExceptionInExpression[ArithmeticException](expr, expected)
} else {
checkEvaluation(expr, expectedRes)
Expand All @@ -245,9 +256,11 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
check("6 years -7 seconds", 3, "2 years -2.333333 seconds")
check("2 years -8 seconds", 0.5, "4 years -16 seconds")
check("-1 month 2 microseconds", -0.25, "4 months -8 microseconds")
check("1 month 3 microsecond", 1.5, "20 days 2 microseconds")
check("1 second", 0, "divide by zero")
check(s"${Int.MaxValue} months", 0.9, "integer overflow")
check("1 month 3 microsecond", 1.5, "2 microseconds")
check("1 second", 0, "divide by zero", Some(true))
check("1 second", 0, null, Some(false))
check(s"${Int.MaxValue} months", 0.9, "integer overflow", Some(true))
check(s"${Int.MaxValue} months", 0.9, Int.MaxValue + " months", Some(false))
}

test("make interval") {
Expand Down
Expand Up @@ -268,37 +268,56 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}

test("multiply by num") {
var interval = new CalendarInterval(0, 0, 0)
assert(interval === multiplyExact(interval, 0))
interval = new CalendarInterval(123, 456, 789)
assert(new CalendarInterval(123 * 42, 456 * 42, 789 * 42) === multiplyExact(interval, 42))
interval = new CalendarInterval(-123, -456, -789)
assert(new CalendarInterval(-123 * 42, -456 * 42, -789 * 42) === multiplyExact(interval, 42))
assert(new CalendarInterval(1, 22, 12 * MICROS_PER_HOUR) ===
multiplyExact(new CalendarInterval(1, 5, 0), 1.5))
assert(new CalendarInterval(2, 14, 12 * MICROS_PER_HOUR) ===
multiplyExact(new CalendarInterval(2, 2, 2 * MICROS_PER_HOUR), 1.2))
Seq[(CalendarInterval, Double) => CalendarInterval](multiply, multiplyExact).foreach { func =>
var interval = new CalendarInterval(0, 0, 0)
assert(interval === func(interval, 0))
interval = new CalendarInterval(123, 456, 789)
assert(new CalendarInterval(123 * 42, 456 * 42, 789 * 42) === func(interval, 42))
interval = new CalendarInterval(-123, -456, -789)
assert(new CalendarInterval(-123 * 42, -456 * 42, -789 * 42) === func(interval, 42))
interval = new CalendarInterval(1, 5, 0)
assert(new CalendarInterval(1, 7, 12 * MICROS_PER_HOUR) === func(interval, 1.5))
interval = new CalendarInterval(2, 2, 2 * MICROS_PER_HOUR)
assert(new CalendarInterval(2, 2, 12 * MICROS_PER_HOUR) === func(interval, 1.2))
}

val interval = new CalendarInterval(2, 0, 0)
assert(multiply(interval, Integer.MAX_VALUE) === new CalendarInterval(Int.MaxValue, 0, 0))
try {
multiplyExact(new CalendarInterval(2, 0, 0), Integer.MAX_VALUE)
multiplyExact(interval, Integer.MAX_VALUE)
fail("Expected to throw an exception on months overflow")
} catch {
case e: ArithmeticException => assert(e.getMessage.contains("overflow"))
}
}

test("divide by num") {
var interval = new CalendarInterval(0, 0, 0)
assert(interval === divideExact(interval, 10))
interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND)
assert(new CalendarInterval(0, 16, 12 * MICROS_PER_HOUR + 15 * MICROS_PER_SECOND) ===
divideExact(interval, 2))
assert(new CalendarInterval(2, 6, MICROS_PER_MINUTE) === divideExact(interval, 0.5))
interval = new CalendarInterval(-1, 0, -30 * MICROS_PER_SECOND)
assert(new CalendarInterval(0, -15, -15 * MICROS_PER_SECOND) === divideExact(interval, 2))
assert(new CalendarInterval(-2, 0, -1 * MICROS_PER_MINUTE) === divideExact(interval, 0.5))
Seq[(CalendarInterval, Double) => CalendarInterval](divide, divideExact).foreach { func =>
var interval = new CalendarInterval(0, 0, 0)
assert(interval === func(interval, 10))
interval = new CalendarInterval(1, 3, 30 * MICROS_PER_SECOND)
assert(new CalendarInterval(0, 1, 12 * MICROS_PER_HOUR + 15 * MICROS_PER_SECOND) ===
func(interval, 2))
assert(new CalendarInterval(2, 6, MICROS_PER_MINUTE) === func(interval, 0.5))
interval = new CalendarInterval(-1, 0, -30 * MICROS_PER_SECOND)
assert(new CalendarInterval(0, 0, -15 * MICROS_PER_SECOND) === func(interval, 2))
assert(new CalendarInterval(-2, 0, -MICROS_PER_MINUTE) === func(interval, 0.5))
}

var interval = new CalendarInterval(Int.MaxValue, Int.MaxValue, 0)
assert(divide(interval, 0.9) === new CalendarInterval(Int.MaxValue, Int.MaxValue,
((Int.MaxValue / 9.0) * MICROS_PER_DAY).round))
try {
divideExact(interval, 0.9)
fail("Expected to throw an exception on integer overflow")
} catch {
case e: ArithmeticException => assert(e.getMessage.contains("integer overflow"))
}

interval = new CalendarInterval(123, 456, 789)
assert(divide(interval, 0) === null)
try {
divideExact(new CalendarInterval(123, 456, 789), 0)
divideExact(interval, 0)
fail("Expected to throw an exception on divide by zero")
} catch {
case e: ArithmeticException => assert(e.getMessage.contains("divide by zero"))
Expand Down
4 changes: 4 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/interval.sql
Expand Up @@ -4,6 +4,10 @@
select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15');
select interval 4 month 2 weeks 3 microseconds * 1.5;
select (timestamp'2019-10-15' - timestamp'2019-10-14') / 1.5;
select interval 2147483647 month * 2;
select interval 2147483647 month / 0.5;
select interval 2147483647 day * 2;
select interval 2147483647 day / 0.5;

-- interval operation with null and zero case
select interval '2 seconds' / 0;
Expand Down
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 81
-- Number of queries: 85


-- !query
Expand All @@ -26,6 +26,42 @@ struct<divide_interval(subtracttimestamps(TIMESTAMP '2019-10-15 00:00:00', TIMES
16 hours


-- !query
select interval 2147483647 month * 2
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
integer overflow


-- !query
select interval 2147483647 month / 0.5
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
integer overflow


-- !query
select interval 2147483647 day * 2
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
integer overflow


-- !query
select interval 2147483647 day / 0.5
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
integer overflow


-- !query
select interval '2 seconds' / 0
-- !query schema
Expand Down

0 comments on commit 761209c

Please sign in to comment.