Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero. #33889

Closed
wants to merge 15 commits into from
Expand Up @@ -612,6 +612,13 @@ case class DivideYMInterval(
override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType)
override def dataType: DataType = YearMonthIntervalType()

@transient
private lazy val checkFunc: (Any) => Unit = right.dataType match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ... val divideByZeroCheck: Any => Unit = ...

case _: DecimalType => (num) =>
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
case _ => (num) => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
}

@transient
private lazy val evalFunc: (Int, Any) => Any = right.dataType match {
case LongType => (months: Int, num) =>
Expand All @@ -629,6 +636,7 @@ case class DivideYMInterval(

// Evaluates the division for non-null inputs: first the overflow guard, then the
// divide-by-zero guard, and only then the actual division via `evalFunc`.
override def nullSafeEval(interval: Any, num: Any): Any = {
  // Cast once and reuse; the year-month interval value is handled as an Int here.
  val months = interval.asInstanceOf[Int]
  checkDivideOverflow(months, Int.MinValue, right, num)
  checkFunc(num)
  evalFunc(months, num)
}

Expand All @@ -641,6 +649,11 @@ case class DivideYMInterval(
val javaType = CodeGenerator.javaType(dataType)
val months = left.genCode(ctx)
val num = right.genCode(ctx)
val checkDivideByZero =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can add

// Builds the Java snippet that throws the shared divide-by-zero error when `value`
// (the name of the generated divisor variable) is zero. Decimal divisors are tested
// with `isZero()`; every other divisor type is compared against the literal 0.
// The `s` interpolator is required so that `$value` is substituted into the snippet
// instead of being emitted verbatim.
private def divideByZeroCheckCodegen(value: String): String = right.dataType match {
  case _: DecimalType =>
    s"if ($value.isZero()) throw QueryExecutionErrors.divideByZeroError();"
  case _ =>
    s"if ($value == 0) throw QueryExecutionErrors.divideByZeroError();"
}

to avoid duplicate code here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and we can move these util functions to IntervalDivide

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

s"""
|if (${num.value} == 0)
| throw QueryExecutionErrors.divideByZeroError();
|""".stripMargin
val checkIntegralDivideOverflow =
s"""
|if (${months.value} == ${Int.MinValue} && ${num.value} == -1)
Expand All @@ -650,17 +663,26 @@ case class DivideYMInterval(
// Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`.
// Casting to `Int` is safe here.
s"""
|$checkDivideByZero
|$checkIntegralDivideOverflow
|${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|if ($n.isZero())
| throw QueryExecutionErrors.divideByZeroError();
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
defineCodeGen(ctx, ev, (m, n) =>
s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|if ($n == 0)
| throw QueryExecutionErrors.divideByZeroError();
|${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}

override def toString: String = s"($left / $right)"
Expand All @@ -683,6 +705,13 @@ case class DivideDTInterval(
override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType)
override def dataType: DataType = DayTimeIntervalType()

// Divide-by-zero guard, selected once from the divisor's data type.
// Decimal divisors are tested through `Decimal.isZero`; any other divisor is compared
// against 0. Both paths raise `QueryExecutionErrors.divideByZeroError()`.
@transient
private lazy val checkFunc: (Any) => Unit = {
  def failOnZero(isZero: Boolean): Unit =
    if (isZero) throw QueryExecutionErrors.divideByZeroError()

  right.dataType match {
    case _: DecimalType => divisor => failOnZero(divisor.asInstanceOf[Decimal].isZero)
    case _ => divisor => failOnZero(divisor == 0)
  }
}

@transient
private lazy val evalFunc: (Long, Any) => Any = right.dataType match {
case _: IntegralType => (micros: Long, num) =>
Expand All @@ -696,6 +725,7 @@ case class DivideDTInterval(

// Evaluates the division for non-null inputs: first the overflow guard, then the
// divide-by-zero guard, and only then the actual division via `evalFunc`.
override def nullSafeEval(interval: Any, num: Any): Any = {
  // Cast once and reuse; the day-time interval value is handled as a Long here.
  val micros = interval.asInstanceOf[Long]
  checkDivideOverflow(micros, Long.MinValue, right, num)
  checkFunc(num)
  evalFunc(micros, num)
}

Expand All @@ -704,24 +734,38 @@ case class DivideDTInterval(
val math = classOf[LongMath].getName
val micros = left.genCode(ctx)
val num = right.genCode(ctx)
val checkDivideByZero =
s"""
|if (${num.value} == 0)
| throw QueryExecutionErrors.divideByZeroError();
|""".stripMargin
val checkIntegralDivideOverflow =
s"""
|if (${micros.value} == ${Long.MinValue}L && ${num.value} == -1L)
| throw QueryExecutionErrors.overflowInIntegralDivideError();
|""".stripMargin
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|$checkDivideByZero
|$checkIntegralDivideOverflow
|${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|if ($n.isZero())
| throw QueryExecutionErrors.divideByZeroError();
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
defineCodeGen(ctx, ev, (m, n) =>
s"$math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP)")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|if ($n == 0)
| throw QueryExecutionErrors.divideByZeroError();
|${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}

override def toString: String = s"($left / $right)"
Expand Down
Expand Up @@ -412,8 +412,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

Seq(
(Period.ofMonths(1), 0) -> "/ by zero",
(Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN",
(Period.ofMonths(1), 0) -> "divide by zero",
(Period.ofMonths(Int.MinValue), 0d) -> "divide by zero",
(Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
Expand Down Expand Up @@ -447,8 +447,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

Seq(
(Duration.ofDays(1), 0) -> "/ by zero",
(Duration.ofMillis(Int.MinValue), 0d) -> "input is infinite or NaN",
(Duration.ofDays(1), 0) -> "divide by zero",
(Duration.ofMillis(Int.MinValue), 0d) -> "divide by zero",
(Duration.ofSeconds(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
Expand Down
Expand Up @@ -209,8 +209,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down Expand Up @@ -242,8 +242,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down
Expand Up @@ -203,8 +203,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down Expand Up @@ -236,8 +236,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down
Expand Up @@ -2737,7 +2737,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("/ by zero"))
assert(e.getMessage.contains("divide by zero"))

val e2 = intercept[SparkException] {
Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
assert(e2.getMessage.contains("divide by zero"))

val e3 = intercept[SparkException] {
Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
assert(e3.getMessage.contains("divide by zero"))
}

test("SPARK-34875: divide day-time interval by numeric") {
Expand Down Expand Up @@ -2772,7 +2784,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("/ by zero"))
assert(e.getMessage.contains("divide by zero"))

val e2 = intercept[SparkException] {
Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
assert(e2.getMessage.contains("divide by zero"))

val e3 = intercept[SparkException] {
Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
assert(e3.getMessage.contains("divide by zero"))
}

test("SPARK-34896: return day-time interval from dates subtraction") {
Expand Down