diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index c799c69b9a2ce..4f317083fa708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -598,6 +598,17 @@ trait IntervalDivide { } } } + + def divideByZeroCheck(dataType: DataType, num: Any): Unit = dataType match { + case _: DecimalType => + if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError() + case _ => if (num == 0) throw QueryExecutionErrors.divideByZeroError() + } + + def divideByZeroCheckCodegen(dataType: DataType, value: String): String = dataType match { + case _: DecimalType => s"if ($value.isZero()) throw QueryExecutionErrors.divideByZeroError();" + case _ => s"if ($value == 0) throw QueryExecutionErrors.divideByZeroError();" + } } // Divide an year-month interval by a numeric @@ -629,6 +640,7 @@ case class DivideYMInterval( override def nullSafeEval(interval: Any, num: Any): Any = { checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num) + divideByZeroCheck(right.dataType, num) evalFunc(interval.asInstanceOf[Int], num) } @@ -650,17 +662,24 @@ case class DivideYMInterval( // Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`. // Casting to `Int` is safe here. s""" + |${divideByZeroCheckCodegen(right.dataType, n)} |$checkIntegralDivideOverflow |${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP); """.stripMargin) case _: DecimalType => - defineCodeGen(ctx, ev, (m, n) => - s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" + - ".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()") + nullSafeCodeGen(ctx, ev, (m, n) => + s""" + |${divideByZeroCheckCodegen(right.dataType, n)} + |${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal() + | .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact(); + """.stripMargin) case _: FractionalType => val math = classOf[DoubleMath].getName - defineCodeGen(ctx, ev, (m, n) => - s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)") + nullSafeCodeGen(ctx, ev, (m, n) => + s""" + |${divideByZeroCheckCodegen(right.dataType, n)} + |${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP); + """.stripMargin) } override def toString: String = s"($left / $right)" @@ -696,6 +715,7 @@ case class DivideDTInterval( override def nullSafeEval(interval: Any, num: Any): Any = { checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num) + divideByZeroCheck(right.dataType, num) evalFunc(interval.asInstanceOf[Long], num) } @@ -711,17 +731,24 @@ case class DivideDTInterval( |""".stripMargin nullSafeCodeGen(ctx, ev, (m, n) => s""" + |${divideByZeroCheckCodegen(right.dataType, n)} |$checkIntegralDivideOverflow |${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP); """.stripMargin) case _: DecimalType => - defineCodeGen(ctx, ev, (m, n) => - s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" + - ".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()") + nullSafeCodeGen(ctx, ev, (m, n) => + s""" + |${divideByZeroCheckCodegen(right.dataType, n)} + |${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal() + | .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact(); + """.stripMargin) case _: FractionalType => val math = classOf[DoubleMath].getName - defineCodeGen(ctx, ev, (m, n) => - s"$math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP)") + nullSafeCodeGen(ctx, ev, (m, n) => + s""" + |${divideByZeroCheckCodegen(right.dataType, n)} + |${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP); + """.stripMargin) } override def toString: String = s"($left / $right)" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 12509ef981a5e..05f9d0f669631 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -412,8 +412,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } Seq( - (Period.ofMonths(1), 0) -> "/ by zero", - (Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN", + (Period.ofMonths(1), 0) -> "divide by zero", + (Period.ofMonths(Int.MinValue), 0d) -> "divide by zero", (Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN" ).foreach { case ((period, num), expectedErrMsg) => checkExceptionInExpression[ArithmeticException]( @@ -447,8 +447,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } Seq( - (Duration.ofDays(1), 0) -> "/ by zero", - (Duration.ofMillis(Int.MinValue), 0d) -> "input is infinite or NaN", + (Duration.ofDays(1), 0) -> "divide by zero", + (Duration.ofMillis(Int.MinValue), 0d) -> "divide by zero", (Duration.ofSeconds(-100), Float.NaN) -> "input is infinite or NaN" ).foreach { case ((period, num), expectedErrMsg) => checkExceptionInExpression[ArithmeticException]( diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index c347d31d79b43..6a5fa69b09c14 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -209,8 +209,8 @@ select interval '2 seconds' / 0 -- !query schema struct<> -- !query output -java.lang.ArithmeticException -/ by zero +org.apache.spark.SparkArithmeticException +divide by zero -- !query @@ -242,8 +242,8 @@ select interval '2' year / 0 -- !query schema struct<> -- !query output -java.lang.ArithmeticException -/ by zero +org.apache.spark.SparkArithmeticException +divide by zero -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index a8fa101a78e8e..70079da10a46f 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -203,8 +203,8 @@ select interval '2 seconds' / 0 -- !query schema struct<> -- !query output -java.lang.ArithmeticException -/ by zero +org.apache.spark.SparkArithmeticException +divide by zero -- !query @@ -236,8 +236,8 @@ select interval '2' year / 0 -- !query schema struct<> -- !query output -java.lang.ArithmeticException -/ by zero +org.apache.spark.SparkArithmeticException +divide by zero -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0cd61341011d..e7ca431726c0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -2737,7 +2737,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect() }.getCause assert(e.isInstanceOf[ArithmeticException]) - assert(e.getMessage.contains("/ by zero")) + assert(e.getMessage.contains("divide by zero")) + + val e2 = intercept[SparkException] { + Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect() + }.getCause + assert(e2.isInstanceOf[ArithmeticException]) + assert(e2.getMessage.contains("divide by zero")) + + val e3 = intercept[SparkException] { + Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect() + }.getCause + assert(e3.isInstanceOf[ArithmeticException]) + assert(e3.getMessage.contains("divide by zero")) } test("SPARK-34875: divide day-time interval by numeric") { @@ -2772,7 +2784,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect() }.getCause assert(e.isInstanceOf[ArithmeticException]) - assert(e.getMessage.contains("/ by zero")) + assert(e.getMessage.contains("divide by zero")) + + val e2 = intercept[SparkException] { + Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect() + }.getCause + assert(e2.isInstanceOf[ArithmeticException]) + assert(e2.getMessage.contains("divide by zero")) + + val e3 = intercept[SparkException] { + Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect() + }.getCause + assert(e3.isInstanceOf[ArithmeticException]) + assert(e3.getMessage.contains("divide by zero")) } test("SPARK-34896: return day-time interval from dates subtraction") {