Skip to content

Commit

Permalink
[SPARK-34850][SQL] Support multiply a day-time interval by a numeric
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
1. Add a new expression `MultiplyDTInterval`, which multiplies a `DayTimeIntervalType` expression by a `NumericType` expression: ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, or DecimalType.
2. Extend binary arithmetic rules to support `numeric * day-time interval` and `day-time interval * numeric`.
3. Invoke `DoubleMath.roundToInt` in `double/float * year-month interval`.

### Why are the changes needed?
To conform to the ANSI SQL standard, which requires such operations over day-time intervals:
<img width="667" alt="Screenshot 2021-03-22 at 16 33 16" src="https://user-images.githubusercontent.com/1580697/111997810-77d1eb80-8b2c-11eb-951d-e43911d9c5db.png">

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By running new tests:
```
$ build/sbt "test:testOnly *IntervalExpressionsSuite"
$ build/sbt "test:testOnly *ColumnExpressionSuite"
```

Closes #31951 from MaxGekk/mul-day-time-interval.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
MaxGekk committed Mar 25, 2021
1 parent 8c6748f commit a68d7ca
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 14 deletions.
Expand Up @@ -374,6 +374,8 @@ class Analyzer(override val catalogManager: CatalogManager)
case (_, CalendarIntervalType) => MultiplyInterval(r, l, f)
case (YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
case (_, YearMonthIntervalType) => MultiplyYMInterval(r, l)
case (DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
case (_, DayTimeIntervalType) => MultiplyDTInterval(r, l)
case _ => m
}
case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, r.dataType) match {
Expand Down
Expand Up @@ -17,8 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import java.math.RoundingMode
import java.util.Locale

import com.google.common.math.DoubleMath

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils._
Expand Down Expand Up @@ -236,12 +239,12 @@ case class MakeInterval(
nullSafeCodeGen(ctx, ev, (year, month, week, day, hour, min, sec) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
val secFrac = sec.getOrElse("0")
val faileOnErrorBranch = if (failOnError) "throw e;" else s"${ev.isNull} = true;"
val failOnErrorBranch = if (failOnError) "throw e;" else s"${ev.isNull} = true;"
s"""
try {
${ev.value} = $iu.makeInterval($year, $month, $week, $day, $hour, $min, $secFrac);
} catch (java.lang.ArithmeticException e) {
$faileOnErrorBranch
$failOnErrorBranch
}
"""
})
Expand All @@ -268,7 +271,7 @@ case class MultiplyYMInterval(
case LongType => (months: Int, num) =>
Math.toIntExact(Math.multiplyExact(months, num.asInstanceOf[Long]))
case FloatType | DoubleType => (months: Int, num) =>
Math.toIntExact(Math.round(months * num.asInstanceOf[Number].doubleValue()))
DoubleMath.roundToInt(months * num.asInstanceOf[Number].doubleValue(), RoundingMode.HALF_UP)
case _: DecimalType => (months: Int, num) =>
val decimalRes = ((new Decimal).set(months) * num.asInstanceOf[Decimal]).toJavaBigDecimal
decimalRes.setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()
Expand All @@ -285,8 +288,9 @@ case class MultiplyYMInterval(
val jlm = classOf[Math].getName
defineCodeGen(ctx, ev, (m, n) => s"$jlm.toIntExact($jlm.multiplyExact($m, $n))")
case FloatType | DoubleType =>
val jlm = classOf[Math].getName
defineCodeGen(ctx, ev, (m, n) => s"$jlm.toIntExact($jlm.round($m * (double)$n))")
val dm = classOf[DoubleMath].getName
defineCodeGen(ctx, ev, (m, n) =>
s"$dm.roundToInt($m * (double)$n, java.math.RoundingMode.HALF_UP)")
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$times($n)).toJavaBigDecimal()" +
Expand All @@ -295,3 +299,45 @@ case class MultiplyYMInterval(

override def toString: String = s"($left * $right)"
}

/**
 * Multiplies a day-time interval by a numeric value.
 *
 * The interval operand is evaluated as a `Long` and the result is again a
 * `DayTimeIntervalType` value. How the product is computed depends on the type of `num`:
 *  - integral types: exact `long` multiplication via `Math.multiplyExact`;
 *  - decimal types: `Decimal` multiplication, rounded to scale 0 with `HALF_UP` and
 *    converted with `longValueExact`;
 *  - float/double: `double` multiplication, rounded with `DoubleMath.roundToLong` using
 *    `HALF_UP`.
 *
 * Each of these paths raises an `ArithmeticException` rather than silently wrapping when the
 * product cannot be represented as a `Long`.
 *
 * @param interval the day-time interval operand (left child)
 * @param num the numeric multiplier (right child)
 */
case class MultiplyDTInterval(
    interval: Expression,
    num: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {

  override def left: Expression = interval
  override def right: Expression = num

  override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType)
  override def dataType: DataType = DayTimeIntervalType

  // Interpreted-mode multiplication, chosen once from the multiplier's data type.
  // `DecimalType` is matched before the `FractionalType` case so that decimals are
  // multiplied as decimals instead of going through `double`.
  @transient
  private lazy val evalFunc: (Long, Any) => Any = right.dataType match {
    case _: IntegralType => (us: Long, factor: Any) =>
      val multiplier = factor.asInstanceOf[Number].longValue()
      Math.multiplyExact(us, multiplier)
    case _: DecimalType => (us: Long, factor: Any) =>
      val product = (new Decimal).set(us) * factor.asInstanceOf[Decimal]
      product.toJavaBigDecimal.setScale(0, RoundingMode.HALF_UP).longValueExact()
    case _: FractionalType => (us: Long, factor: Any) =>
      val scaled = us * factor.asInstanceOf[Number].doubleValue()
      DoubleMath.roundToLong(scaled, RoundingMode.HALF_UP)
  }

  override def nullSafeEval(interval: Any, num: Any): Any =
    evalFunc(interval.asInstanceOf[Long], num)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Build the Java expression template for the multiplier's type first, then hand it to
    // `defineCodeGen`. The templates mirror the three branches of `evalFunc`.
    val javaExpr: (String, String) => String = right.dataType match {
      case _: IntegralType =>
        (m, n) => s"java.lang.Math.multiplyExact($m, $n)"
      case _: DecimalType =>
        val roundToLong = ".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()"
        (m, n) => s"((new Decimal()).set($m).$$times($n)).toJavaBigDecimal()" + roundToLong
      case _: FractionalType =>
        val dm = classOf[DoubleMath].getName
        (m, n) => s"$dm.roundToLong($m * (double)$n, java.math.RoundingMode.HALF_UP)"
    }
    defineCodeGen(ctx, ev, javaExpr)
  }

  override def toString: String = s"($left * $right)"
}
Expand Up @@ -17,7 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.time.Period
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit

import scala.language.implicitConversions

Expand All @@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{Decimal, DecimalType, YearMonthIntervalType}
import org.apache.spark.sql.types.{DayTimeIntervalType, Decimal, DecimalType, YearMonthIntervalType}
import org.apache.spark.sql.types.DataTypeTestUtils.numericTypes
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

Expand Down Expand Up @@ -280,7 +281,6 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-34824: multiply year-month interval by numeric") {
Seq(
(Period.ofMonths(-100), Float.NaN) -> Period.ofMonths(0),
(Period.ofYears(-123), Literal(null, DecimalType.USER_DEFAULT)) -> null,
(Period.ofMonths(0), 10) -> Period.ofMonths(0),
(Period.ofMonths(10), 0L) -> Period.ofMonths(0),
Expand All @@ -294,13 +294,15 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

Seq(
(Period.ofMonths(2), Int.MaxValue),
(Period.ofMonths(Int.MinValue), 10d),
(Period.ofMonths(200), Double.PositiveInfinity),
(Period.ofMonths(-200), Float.NegativeInfinity)
).foreach { case (period, num) =>
(Period.ofMonths(2), Int.MaxValue) -> "overflow",
(Period.ofMonths(Int.MinValue), 10d) -> "not in range",
(Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN",
(Period.ofMonths(200), Double.PositiveInfinity) -> "input is infinite or NaN",
(Period.ofMonths(-200), Float.NegativeInfinity) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
MultiplyYMInterval(Literal(period), Literal(num)), "overflow")
MultiplyYMInterval(Literal(period), Literal(num)),
expectedErrMsg)
}

numericTypes.foreach { numType =>
Expand All @@ -309,4 +311,37 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
YearMonthIntervalType, numType)
}
}

test("SPARK-34850: multiply day-time interval by numeric") {
  // (interval, multiplier) -> expected product, for inputs that must evaluate successfully.
  // A null multiplier is expected to produce a null result.
  val validCases = Seq(
    (Duration.ofHours(-123), Literal(null, DecimalType.USER_DEFAULT)) -> null,
    (Duration.ofMinutes(0), 10) -> Duration.ofMinutes(0),
    (Duration.ofSeconds(10), 0L) -> Duration.ofSeconds(0),
    (Duration.ofMillis(100), -1.toByte) -> Duration.ofMillis(-100),
    (Duration.ofDays(12), 0.3d) -> Duration.ofDays(12).multipliedBy(3).dividedBy(10),
    (Duration.of(-1000, ChronoUnit.MICROS), 0.3f) -> Duration.of(-300, ChronoUnit.MICROS),
    (Duration.ofDays(9999), 0.0001d) -> Duration.ofDays(9999).dividedBy(10000),
    (Duration.ofDays(9999), BigDecimal(0.0001)) -> Duration.ofDays(9999).dividedBy(10000)
  )
  validCases.foreach { case ((interval, multiplier), product) =>
    checkEvaluation(MultiplyDTInterval(Literal(interval), Literal(multiplier)), product)
  }

  // (interval, multiplier) -> fragment of the ArithmeticException message expected when the
  // product cannot be represented (overflow, NaN or infinite multiplier).
  val failingCases = Seq(
    (Duration.ofDays(-100), Float.NaN) -> "input is infinite or NaN",
    (Duration.ofDays(2), Int.MaxValue) -> "overflow",
    (Duration.ofHours(Int.MinValue), Short.MinValue) -> "overflow",
    (Duration.ofDays(10), BigDecimal(Long.MinValue)) -> "Overflow",
    (Duration.ofDays(200), Double.PositiveInfinity) -> "input is infinite or NaN",
    (Duration.ofDays(-200), Float.NegativeInfinity) -> "input is infinite or NaN"
  )
  failingCases.foreach { case ((interval, multiplier), messageFragment) =>
    checkExceptionInExpression[ArithmeticException](
      MultiplyDTInterval(Literal(interval), Literal(multiplier)), messageFragment)
  }

  // Interpreted and generated code must agree for every numeric multiplier type.
  val multiply: (Expression, Expression) => Expression = MultiplyDTInterval(_, _)
  numericTypes.foreach { numType =>
    checkConsistencyBetweenInterpretedAndCodegenAllowingException(
      multiply, DayTimeIntervalType, numType)
  }
}
}
Expand Up @@ -2613,4 +2613,38 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("overflow"))
}

test("SPARK-34850: multiply day-time interval by numeric") {
  // Integral multipliers, with the interval on either side of `*`.
  val zeroByInt = Seq((Duration.ofDays(0), 0)).toDF("i", "n")
  checkAnswer(zeroByInt.select($"i" * $"n"), Row(Duration.ofDays(0)))

  val zeroByByte = Seq((Duration.ofDays(0), 10.toByte)).toDF("i", "n")
  checkAnswer(zeroByByte.select($"i" * $"n"), Row(Duration.ofDays(0)))

  val halfDayByShort = Seq((Duration.ofHours(12), 3.toShort)).toDF("i", "n")
  checkAnswer(halfDayByShort.select($"n" * $"i"), Row(Duration.ofDays(1).plusHours(12)))

  // A string multiplier is expected to behave like the number it spells.
  val minutesByString = Seq((Duration.ofMinutes(1000), "2")).toDF("i", "n")
  checkAnswer(minutesByString.select($"i" * $"n"), Row(Duration.ofMinutes(2000)))

  val secondByLong = Seq((Duration.ofSeconds(1), 60L)).toDF("i", "n")
  checkAnswer(secondByLong.select($"n" * $"i"), Row(Duration.ofMinutes(1)))

  // Fractional multipliers: a product of -0.499 microseconds is expected to round to zero,
  // while -0.51 microseconds is expected to round to -1 microsecond.
  val microByFloat = Seq((Duration.of(-1, ChronoUnit.MICROS), 0.499f)).toDF("i", "n")
  checkAnswer(microByFloat.select($"i" * $"n"), Row(Duration.of(0, ChronoUnit.MICROS)))

  val microByDouble = Seq((Duration.of(-1, ChronoUnit.MICROS), 0.51d)).toDF("i", "n")
  checkAnswer(microByDouble.select($"i" * $"n"), Row(Duration.of(-1, ChronoUnit.MICROS)))

  // Decimal multiplier.
  val microsByDecimal =
    Seq((Duration.of(-10000000, ChronoUnit.MICROS), BigDecimal(0.0000001d))).toDF("i", "n")
  checkAnswer(microsByDecimal.select($"i" * $"n"), Row(Duration.of(-1, ChronoUnit.MICROS)))

  // Multiplying by Long.MinValue must surface an ArithmeticException as the cause of the
  // SparkException raised by the job.
  val thrown = intercept[SparkException] {
    val overflowing = Seq((Duration.ofDays(9999), Long.MinValue)).toDF("i", "n")
    overflowing.select($"n" * $"i").collect()
  }
  val cause = thrown.getCause
  assert(cause.isInstanceOf[ArithmeticException])
  assert(cause.getMessage.contains("overflow"))
}
}

0 comments on commit a68d7ca

Please sign in to comment.