Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-37149][SQL] Improve error messages for arithmetic overflow exceptions under ANSI mode #34427

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@
"sqlState" : "22005"
},
"CANNOT_CHANGE_DECIMAL_PRECISION" : {
"message" : [ "%s cannot be represented as Decimal(%s, %s)." ],
"message" : [ "%s cannot be represented as Decimal(%s, %s). You can set %s to false to bypass this error." ],
"sqlState" : "22005"
},
"CANNOT_PARSE_DECIMAL" : {
"message" : [ "Cannot parse decimal" ],
"sqlState" : "42000"
},
"CAST_CAUSES_OVERFLOW" : {
"message" : [ "Casting %s to %s causes overflow" ],
"message" : [ "Casting %s to %s causes overflow. You can use 'try_cast' or set %s to false to bypass this error." ],
"sqlState" : "22005"
},
"CONCURRENT_QUERY" : {
"message" : [ "Another instance of this query was just started by a concurrent session." ]
},
"DIVIDE_BY_ZERO" : {
"message" : [ "divide by zero" ],
"message" : [ "divide by zero. You can use 'try_divide' or set %s to false (except for ANSI interval type) to bypass this error." ],
"sqlState" : "22012"
},
"DUPLICATE_KEY" : {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ class SparkThrowableSuite extends SparkFunSuite {
}

// Does not fail with too many args (expects 0 args)
assert(getMessage("DIVIDE_BY_ZERO", Array("foo", "bar")) == "divide by zero")
assert(getMessage("DIVIDE_BY_ZERO", Array("foo", "bar")) ==
"divide by zero. You can use 'try_divide' or set foo to false " +
"(except for ANSI interval type) to bypass this error.")
}

test("Error message is formatted") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
UNARY_POSITIVE}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, MathUtils, TypeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -69,9 +69,9 @@ case class UnaryMinus(
""".stripMargin
})
case IntegerType | LongType if failOnError =>
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact($eval);"
s"${ev.value} = $mathUtils.negateExact($eval);"
})
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
Expand All @@ -87,17 +87,17 @@ case class UnaryMinus(
defineCodeGen(ctx, ev, c => s"$iu.$method($c)")
case _: AnsiIntervalType =>
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact($eval);"
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
s"${ev.value} = $mathUtils.negateExact($eval);"
})
}

protected override def nullSafeEval(input: Any): Any = dataType match {
case CalendarIntervalType if failOnError =>
IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval])
case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
case _: DayTimeIntervalType => Math.negateExact(input.asInstanceOf[Long])
case _: YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int])
case _: DayTimeIntervalType => MathUtils.negateExact(input.asInstanceOf[Long])
case _: YearMonthIntervalType => MathUtils.negateExact(input.asInstanceOf[Int])
case _ => numeric.negate(input)
}

Expand Down Expand Up @@ -191,10 +191,12 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled
|""".stripMargin)

case IntegerType | LongType if failOnError =>
defineCodeGen(ctx, ev, c => s"$c < 0 ? java.lang.Math.negateExact($c) : $c")
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")

case _: AnsiIntervalType =>
defineCodeGen(ctx, ev, c => s"$c < 0 ? java.lang.Math.negateExact($c) : $c")
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")

case dt: NumericType =>
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
Expand Down Expand Up @@ -241,8 +243,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
assert(exactMathMethod.isDefined,
s"The expression '$nodeName' must override the exactMathMethod() method " +
"if it is supposed to operate over interval types.")
val mathClass = classOf[Math].getName
defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${exactMathMethod.get}($eval1, $eval2)")
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
Expand All @@ -267,8 +269,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
case IntegerType | LongType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val operation = if (failOnError && exactMathMethod.isDefined) {
val mathClass = classOf[Math].getName
s"$mathClass.${exactMathMethod.get}($eval1, $eval2)"
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)"
} else {
s"$eval1 $symbol $eval2"
}
Expand Down Expand Up @@ -326,9 +328,9 @@ case class Add(
IntervalUtils.add(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case _: DayTimeIntervalType =>
Math.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
case _: YearMonthIntervalType =>
Math.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
case _ => numeric.plus(input1, input2)
}

Expand Down Expand Up @@ -372,9 +374,9 @@ case class Subtract(
IntervalUtils.subtract(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case _: DayTimeIntervalType =>
Math.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
MathUtils.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
case _: YearMonthIntervalType =>
Math.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
case _ => numeric.minus(input1, input2)
}

Expand Down Expand Up @@ -441,8 +443,8 @@ trait DivModLike extends BinaryArithmetic {
null
} else {
if (isZero(input2)) {
// when we reach here, failOnError must bet true.
throw QueryExecutionErrors.divideByZeroError
// when we reach here, failOnError must be true.
throw QueryExecutionErrors.divideByZeroError()
}
if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
throw QueryExecutionErrors.overflowInIntegralDivideError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,8 @@ object IntervalUtils {
monthsWithFraction: Double,
daysWithFraction: Double,
microsWithFraction: Double): CalendarInterval = {
val truncatedMonths = Math.toIntExact(monthsWithFraction.toLong)
val truncatedDays = Math.toIntExact(daysWithFraction.toLong)
val truncatedMonths = MathUtils.toIntExact(monthsWithFraction.toLong)
val truncatedDays = MathUtils.toIntExact(daysWithFraction.toLong)
val micros = microsWithFraction + MICROS_PER_DAY * (daysWithFraction - truncatedDays)
new CalendarInterval(truncatedMonths, truncatedDays, micros.round)
}
Expand All @@ -644,9 +644,9 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*/
def negateExact(interval: CalendarInterval): CalendarInterval = {
val months = Math.negateExact(interval.months)
val days = Math.negateExact(interval.days)
val microseconds = Math.negateExact(interval.microseconds)
val months = MathUtils.negateExact(interval.months)
val days = MathUtils.negateExact(interval.days)
val microseconds = MathUtils.negateExact(interval.microseconds)
new CalendarInterval(months, days, microseconds)
}

Expand All @@ -663,9 +663,9 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*/
def addExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = Math.addExact(left.months, right.months)
val days = Math.addExact(left.days, right.days)
val microseconds = Math.addExact(left.microseconds, right.microseconds)
val months = MathUtils.addExact(left.months, right.months)
val days = MathUtils.addExact(left.days, right.days)
val microseconds = MathUtils.addExact(left.microseconds, right.microseconds)
new CalendarInterval(months, days, microseconds)
}

Expand All @@ -685,9 +685,9 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*/
def subtractExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = Math.subtractExact(left.months, right.months)
val days = Math.subtractExact(left.days, right.days)
val microseconds = Math.subtractExact(left.microseconds, right.microseconds)
val months = MathUtils.subtractExact(left.months, right.months)
val days = MathUtils.subtractExact(left.days, right.days)
val microseconds = MathUtils.subtractExact(left.microseconds, right.microseconds)
new CalendarInterval(months, days, microseconds)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.errors.QueryExecutionErrors

/**
* Helper functions for arithmetic operations with overflow.
*/
object MathUtils {

def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b))

def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b))

def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b))

def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b))

def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b))

def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b))

def negateExact(a: Int): Int = withOverflow(Math.negateExact(a))

def negateExact(a: Long): Long = withOverflow(Math.negateExact(a))

def toIntExact(a: Long): Int = withOverflow(Math.toIntExact(a))

def floorDiv(a: Int, b: Int): Int = withOverflow(Math.floorDiv(a, b), Some("try_divide"))

def floorDiv(a: Long, b: Long): Long = withOverflow(Math.floorDiv(a, b), Some("try_divide"))

def floorMod(a: Int, b: Int): Int = withOverflow(Math.floorMod(a, b))

def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))

private def withOverflow[A](f: => A, hint: Option[String] = None): A = {
try {
f
} catch {
case e: ArithmeticException =>
throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage, hint)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ object QueryExecutionErrors {
}

def castingCauseOverflowError(t: Any, targetType: String): ArithmeticException = {
new SparkArithmeticException (errorClass = "CAST_CAUSES_OVERFLOW",
messageParameters = Array(t.toString, targetType))
new SparkArithmeticException(errorClass = "CAST_CAUSES_OVERFLOW",
messageParameters = Array(t.toString, targetType, SQLConf.ANSI_ENABLED.key))
}

def cannotChangeDecimalPrecisionError(
value: Decimal, decimalPrecision: Int, decimalScale: Int): ArithmeticException = {
new SparkArithmeticException(errorClass = "CANNOT_CHANGE_DECIMAL_PRECISION",
messageParameters = Array(value.toDebugString,
decimalPrecision.toString, decimalScale.toString))
decimalPrecision.toString, decimalScale.toString, SQLConf.ANSI_ENABLED.key))
}

def invalidInputSyntaxForNumericError(s: UTF8String): NumberFormatException = {
Expand Down Expand Up @@ -152,7 +152,8 @@ object QueryExecutionErrors {
}

def divideByZeroError(): ArithmeticException = {
new SparkArithmeticException(errorClass = "DIVIDE_BY_ZERO", messageParameters = Array.empty)
new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO", messageParameters = Array(SQLConf.ANSI_ENABLED.key))
}

def invalidArrayIndexError(index: Int, numElements: Int): ArrayIndexOutOfBoundsException = {
Expand All @@ -179,11 +180,11 @@ object QueryExecutionErrors {
}

def overflowInSumOfDecimalError(): ArithmeticException = {
new ArithmeticException("Overflow in sum of decimals.")
arithmeticOverflowError("Overflow in sum of decimals")
}

def overflowInIntegralDivideError(): ArithmeticException = {
new ArithmeticException("Overflow in integral divide.")
arithmeticOverflowError("Overflow in integral divide", Some("try_divide"))
}

def mapSizeExceedArraySizeWhenZipMapError(size: Int): RuntimeException = {
Expand Down Expand Up @@ -392,13 +393,20 @@ object QueryExecutionErrors {
new IllegalStateException("table stats must be specified.")
}

def arithmeticOverflowError(
message: String, hint: Option[String] = None): ArithmeticException = {
new ArithmeticException(s"$message. You can ${hint.map(x => s"use '$x' or ").getOrElse("")}" +
s"set ${SQLConf.ANSI_ENABLED.key} to false (except for ANSI interval type) " +
"to bypass this error.")
}

def unaryMinusCauseOverflowError(originValue: AnyVal): ArithmeticException = {
new ArithmeticException(s"- $originValue caused overflow.")
arithmeticOverflowError(s"- $originValue caused overflow")
}

def binaryArithmeticCauseOverflowError(
eval1: Short, symbol: String, eval2: Short): ArithmeticException = {
new ArithmeticException(s"$eval1 $symbol $eval2 caused overflow.")
arithmeticOverflowError(s"$eval1 $symbol $eval2 caused overflow")
}

def failedSplitSubExpressionMsg(length: Int): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import scala.math.Numeric._
import scala.math.Ordering

import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted

Expand Down Expand Up @@ -93,23 +93,23 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.Shor


private[sql] object IntegerExactNumeric extends IntIsIntegral with Ordering.IntOrdering {
override def plus(x: Int, y: Int): Int = Math.addExact(x, y)
override def plus(x: Int, y: Int): Int = MathUtils.addExact(x, y)

override def minus(x: Int, y: Int): Int = Math.subtractExact(x, y)
override def minus(x: Int, y: Int): Int = MathUtils.subtractExact(x, y)

override def times(x: Int, y: Int): Int = Math.multiplyExact(x, y)
override def times(x: Int, y: Int): Int = MathUtils.multiplyExact(x, y)

override def negate(x: Int): Int = Math.negateExact(x)
override def negate(x: Int): Int = MathUtils.negateExact(x)
}

private[sql] object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering {
override def plus(x: Long, y: Long): Long = Math.addExact(x, y)
override def plus(x: Long, y: Long): Long = MathUtils.addExact(x, y)

override def minus(x: Long, y: Long): Long = Math.subtractExact(x, y)
override def minus(x: Long, y: Long): Long = MathUtils.subtractExact(x, y)

override def times(x: Long, y: Long): Long = Math.multiplyExact(x, y)
override def times(x: Long, y: Long): Long = MathUtils.multiplyExact(x, y)

override def negate(x: Long): Long = Math.negateExact(x)
override def negate(x: Long): Long = MathUtils.negateExact(x)

override def toInt(x: Long): Int =
if (x == x.toInt) {
Expand Down
Loading