Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero. #33889

Closed
wants to merge 15 commits into from
Expand Up @@ -598,6 +598,17 @@ trait IntervalDivide {
}
}
}

def divideByZeroCheck(dataType: DataType, num: Any): Unit = dataType match {
case _: DecimalType =>
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
case _ => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
}

def divideByZeroCheckCodegen(dataType: DataType, value: String): String = dataType match {
case _: DecimalType => s"if ($value.isZero()) throw QueryExecutionErrors.divideByZeroError();"
case _ => s"if ($value == 0) throw QueryExecutionErrors.divideByZeroError();"
}
}

// Divide an year-month interval by a numeric
Expand Down Expand Up @@ -629,6 +640,7 @@ case class DivideYMInterval(

override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num)
divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Int], num)
}

Expand All @@ -650,17 +662,24 @@ case class DivideYMInterval(
// Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`.
// Casting to `Int` is safe here.
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
defineCodeGen(ctx, ev, (m, n) =>
s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}

override def toString: String = s"($left / $right)"
Expand Down Expand Up @@ -696,6 +715,7 @@ case class DivideDTInterval(

override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num)
divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Long], num)
}

Expand All @@ -711,17 +731,24 @@ case class DivideDTInterval(
|""".stripMargin
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
defineCodeGen(ctx, ev, (m, n) =>
s"$math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP)")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}

override def toString: String = s"($left / $right)"
Expand Down
Expand Up @@ -412,8 +412,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

Seq(
(Period.ofMonths(1), 0) -> "/ by zero",
(Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN",
(Period.ofMonths(1), 0) -> "divide by zero",
(Period.ofMonths(Int.MinValue), 0d) -> "divide by zero",
(Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
Expand Down Expand Up @@ -447,8 +447,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

Seq(
(Duration.ofDays(1), 0) -> "/ by zero",
(Duration.ofMillis(Int.MinValue), 0d) -> "input is infinite or NaN",
(Duration.ofDays(1), 0) -> "divide by zero",
(Duration.ofMillis(Int.MinValue), 0d) -> "divide by zero",
(Duration.ofSeconds(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
Expand Down
Expand Up @@ -209,8 +209,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down Expand Up @@ -242,8 +242,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down
Expand Up @@ -203,8 +203,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down Expand Up @@ -236,8 +236,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down
Expand Up @@ -2737,7 +2737,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("/ by zero"))
assert(e.getMessage.contains("divide by zero"))

val e2 = intercept[SparkException] {
Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
assert(e2.getMessage.contains("divide by zero"))

val e3 = intercept[SparkException] {
Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
assert(e3.getMessage.contains("divide by zero"))
}

test("SPARK-34875: divide day-time interval by numeric") {
Expand Down Expand Up @@ -2772,7 +2784,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("/ by zero"))
assert(e.getMessage.contains("divide by zero"))

val e2 = intercept[SparkException] {
Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
assert(e2.getMessage.contains("divide by zero"))

val e3 = intercept[SparkException] {
Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
assert(e3.getMessage.contains("divide by zero"))
}

test("SPARK-34896: return day-time interval from dates subtraction") {
Expand Down