Skip to content

Commit

Permalink
[SPARK-37149][SQL] Improve error messages for arithmetic overflow exceptions under ANSI mode
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
This PR improves error messages for arithmetic overflow exceptions thrown under ANSI mode. It augments the error messages by suggesting workarounds to users:
- Turn off ANSI mode (except for ANSI interval type)
- Use the corresponding `try_*` function (e.g. `try_cast`, `try_divide`) if applicable.

### Why are the changes needed?

To make error messages more actionable.

### Does this PR introduce _any_ user-facing change?

Yes. Certain error messages will be different.

### How was this patch tested?

Existing tests.

Closes #34427 from allisonwang-db/spark-37149-overflow-err-msg.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
allisonwang-db authored and cloud-fan committed Oct 29, 2021
1 parent f258d30 commit e582799
Show file tree
Hide file tree
Showing 19 changed files with 177 additions and 104 deletions.
6 changes: 3 additions & 3 deletions core/src/main/resources/error/error-classes.json
Expand Up @@ -8,22 +8,22 @@
"sqlState" : "22005"
},
"CANNOT_CHANGE_DECIMAL_PRECISION" : {
"message" : [ "%s cannot be represented as Decimal(%s, %s)." ],
"message" : [ "%s cannot be represented as Decimal(%s, %s). You can set %s to false to bypass this error." ],
"sqlState" : "22005"
},
"CANNOT_PARSE_DECIMAL" : {
"message" : [ "Cannot parse decimal" ],
"sqlState" : "42000"
},
"CAST_CAUSES_OVERFLOW" : {
"message" : [ "Casting %s to %s causes overflow" ],
"message" : [ "Casting %s to %s causes overflow. You can use 'try_cast' or set %s to false to bypass this error." ],
"sqlState" : "22005"
},
"CONCURRENT_QUERY" : {
"message" : [ "Another instance of this query was just started by a concurrent session." ]
},
"DIVIDE_BY_ZERO" : {
"message" : [ "divide by zero" ],
"message" : [ "divide by zero. You can use 'try_divide' or set %s to false (except for ANSI interval type) to bypass this error." ],
"sqlState" : "22012"
},
"DUPLICATE_KEY" : {
Expand Down
Expand Up @@ -124,7 +124,9 @@ class SparkThrowableSuite extends SparkFunSuite {
}

// Does not fail with too many args (expects 1 arg)
assert(getMessage("DIVIDE_BY_ZERO", Array("foo", "bar")) == "divide by zero")
assert(getMessage("DIVIDE_BY_ZERO", Array("foo", "bar")) ==
"divide by zero. You can use 'try_divide' or set foo to false " +
"(except for ANSI interval type) to bypass this error.")
}

test("Error message is formatted") {
Expand Down
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
UNARY_POSITIVE}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, MathUtils, TypeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -69,9 +69,9 @@ case class UnaryMinus(
""".stripMargin
})
case IntegerType | LongType if failOnError =>
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact($eval);"
s"${ev.value} = $mathUtils.negateExact($eval);"
})
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
Expand All @@ -87,17 +87,17 @@ case class UnaryMinus(
defineCodeGen(ctx, ev, c => s"$iu.$method($c)")
case _: AnsiIntervalType =>
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact($eval);"
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
s"${ev.value} = $mathUtils.negateExact($eval);"
})
}

protected override def nullSafeEval(input: Any): Any = dataType match {
case CalendarIntervalType if failOnError =>
IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval])
case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
case _: DayTimeIntervalType => Math.negateExact(input.asInstanceOf[Long])
case _: YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int])
case _: DayTimeIntervalType => MathUtils.negateExact(input.asInstanceOf[Long])
case _: YearMonthIntervalType => MathUtils.negateExact(input.asInstanceOf[Int])
case _ => numeric.negate(input)
}

Expand Down Expand Up @@ -191,10 +191,12 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled
|""".stripMargin)

case IntegerType | LongType if failOnError =>
defineCodeGen(ctx, ev, c => s"$c < 0 ? java.lang.Math.negateExact($c) : $c")
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")

case _: AnsiIntervalType =>
defineCodeGen(ctx, ev, c => s"$c < 0 ? java.lang.Math.negateExact($c) : $c")
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")

case dt: NumericType =>
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
Expand Down Expand Up @@ -241,8 +243,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
assert(exactMathMethod.isDefined,
s"The expression '$nodeName' must override the exactMathMethod() method " +
"if it is supposed to operate over interval types.")
val mathClass = classOf[Math].getName
defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${exactMathMethod.get}($eval1, $eval2)")
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)")
// byte and short are cast to int when adding, subtracting, multiplying or dividing
case ByteType | ShortType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
Expand All @@ -267,8 +269,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
case IntegerType | LongType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val operation = if (failOnError && exactMathMethod.isDefined) {
val mathClass = classOf[Math].getName
s"$mathClass.${exactMathMethod.get}($eval1, $eval2)"
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)"
} else {
s"$eval1 $symbol $eval2"
}
Expand Down Expand Up @@ -326,9 +328,9 @@ case class Add(
IntervalUtils.add(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case _: DayTimeIntervalType =>
Math.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
case _: YearMonthIntervalType =>
Math.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
case _ => numeric.plus(input1, input2)
}

Expand Down Expand Up @@ -372,9 +374,9 @@ case class Subtract(
IntervalUtils.subtract(
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
case _: DayTimeIntervalType =>
Math.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
MathUtils.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
case _: YearMonthIntervalType =>
Math.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
MathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
case _ => numeric.minus(input1, input2)
}

Expand Down Expand Up @@ -441,8 +443,8 @@ trait DivModLike extends BinaryArithmetic {
null
} else {
if (isZero(input2)) {
// when we reach here, failOnError must bet true.
throw QueryExecutionErrors.divideByZeroError
// when we reach here, failOnError must be true.
throw QueryExecutionErrors.divideByZeroError()
}
if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
throw QueryExecutionErrors.overflowInIntegralDivideError()
Expand Down
Expand Up @@ -616,8 +616,8 @@ object IntervalUtils {
monthsWithFraction: Double,
daysWithFraction: Double,
microsWithFraction: Double): CalendarInterval = {
val truncatedMonths = Math.toIntExact(monthsWithFraction.toLong)
val truncatedDays = Math.toIntExact(daysWithFraction.toLong)
val truncatedMonths = MathUtils.toIntExact(monthsWithFraction.toLong)
val truncatedDays = MathUtils.toIntExact(daysWithFraction.toLong)
val micros = microsWithFraction + MICROS_PER_DAY * (daysWithFraction - truncatedDays)
new CalendarInterval(truncatedMonths, truncatedDays, micros.round)
}
Expand All @@ -644,9 +644,9 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*/
def negateExact(interval: CalendarInterval): CalendarInterval = {
val months = Math.negateExact(interval.months)
val days = Math.negateExact(interval.days)
val microseconds = Math.negateExact(interval.microseconds)
val months = MathUtils.negateExact(interval.months)
val days = MathUtils.negateExact(interval.days)
val microseconds = MathUtils.negateExact(interval.microseconds)
new CalendarInterval(months, days, microseconds)
}

Expand All @@ -663,9 +663,9 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*/
def addExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = Math.addExact(left.months, right.months)
val days = Math.addExact(left.days, right.days)
val microseconds = Math.addExact(left.microseconds, right.microseconds)
val months = MathUtils.addExact(left.months, right.months)
val days = MathUtils.addExact(left.days, right.days)
val microseconds = MathUtils.addExact(left.microseconds, right.microseconds)
new CalendarInterval(months, days, microseconds)
}

Expand All @@ -685,9 +685,9 @@ object IntervalUtils {
* @throws ArithmeticException if the result overflows any field value
*/
def subtractExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
val months = Math.subtractExact(left.months, right.months)
val days = Math.subtractExact(left.days, right.days)
val microseconds = Math.subtractExact(left.microseconds, right.microseconds)
val months = MathUtils.subtractExact(left.months, right.months)
val days = MathUtils.subtractExact(left.days, right.days)
val microseconds = MathUtils.subtractExact(left.microseconds, right.microseconds)
new CalendarInterval(months, days, microseconds)
}

Expand Down
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.errors.QueryExecutionErrors

/**
 * Checked integer arithmetic helpers.
 *
 * Every public method delegates to the `java.lang.Math` method of the same name. If that call
 * throws an `ArithmeticException`, the exception is replaced by the one returned from
 * `QueryExecutionErrors.arithmeticOverflowError`, which receives the original message and an
 * optional hint.
 */
object MathUtils {

  /** Hint forwarded to the error builder by the `floorDiv` overloads. */
  private val divideHint: Option[String] = Some("try_divide")

  /**
   * Runs `f` and translates an `ArithmeticException` raised by it.
   *
   * @param f computation to run; it is by-name so that it is evaluated inside the `try`
   * @param hint optional value passed through to `QueryExecutionErrors.arithmeticOverflowError`
   * @return the result of `f` when no `ArithmeticException` is thrown
   */
  private def withOverflow[A](f: => A, hint: Option[String] = None): A = {
    try f catch {
      case cause: ArithmeticException =>
        throw QueryExecutionErrors.arithmeticOverflowError(cause.getMessage, hint)
    }
  }

  // Addition

  def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b))

  def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b))

  // Subtraction

  def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b))

  def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b))

  // Multiplication

  def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b))

  def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b))

  // Negation

  def negateExact(a: Int): Int = withOverflow(Math.negateExact(a))

  def negateExact(a: Long): Long = withOverflow(Math.negateExact(a))

  // Narrowing conversion

  def toIntExact(a: Long): Int = withOverflow(Math.toIntExact(a))

  // Floor division; these are the only helpers that pass a hint

  def floorDiv(a: Int, b: Int): Int =
    withOverflow(Math.floorDiv(a, b), hint = divideHint)

  def floorDiv(a: Long, b: Long): Long =
    withOverflow(Math.floorDiv(a, b), hint = divideHint)

  // Floor modulus

  def floorMod(a: Int, b: Int): Int = withOverflow(Math.floorMod(a, b))

  def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))
}
Expand Up @@ -93,15 +93,15 @@ object QueryExecutionErrors {
}

def castingCauseOverflowError(t: Any, targetType: String): ArithmeticException = {
new SparkArithmeticException (errorClass = "CAST_CAUSES_OVERFLOW",
messageParameters = Array(t.toString, targetType))
new SparkArithmeticException(errorClass = "CAST_CAUSES_OVERFLOW",
messageParameters = Array(t.toString, targetType, SQLConf.ANSI_ENABLED.key))
}

def cannotChangeDecimalPrecisionError(
value: Decimal, decimalPrecision: Int, decimalScale: Int): ArithmeticException = {
new SparkArithmeticException(errorClass = "CANNOT_CHANGE_DECIMAL_PRECISION",
messageParameters = Array(value.toDebugString,
decimalPrecision.toString, decimalScale.toString))
decimalPrecision.toString, decimalScale.toString, SQLConf.ANSI_ENABLED.key))
}

def invalidInputSyntaxForNumericError(s: UTF8String): NumberFormatException = {
Expand Down Expand Up @@ -152,7 +152,8 @@ object QueryExecutionErrors {
}

def divideByZeroError(): ArithmeticException = {
new SparkArithmeticException(errorClass = "DIVIDE_BY_ZERO", messageParameters = Array.empty)
new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO", messageParameters = Array(SQLConf.ANSI_ENABLED.key))
}

def invalidArrayIndexError(index: Int, numElements: Int): ArrayIndexOutOfBoundsException = {
Expand All @@ -179,11 +180,11 @@ object QueryExecutionErrors {
}

def overflowInSumOfDecimalError(): ArithmeticException = {
new ArithmeticException("Overflow in sum of decimals.")
arithmeticOverflowError("Overflow in sum of decimals")
}

def overflowInIntegralDivideError(): ArithmeticException = {
new ArithmeticException("Overflow in integral divide.")
arithmeticOverflowError("Overflow in integral divide", Some("try_divide"))
}

def mapSizeExceedArraySizeWhenZipMapError(size: Int): RuntimeException = {
Expand Down Expand Up @@ -392,13 +393,20 @@ object QueryExecutionErrors {
new IllegalStateException("table stats must be specified.")
}

def arithmeticOverflowError(
message: String, hint: Option[String] = None): ArithmeticException = {
new ArithmeticException(s"$message. You can ${hint.map(x => s"use '$x' or ").getOrElse("")}" +
s"set ${SQLConf.ANSI_ENABLED.key} to false (except for ANSI interval type) " +
"to bypass this error.")
}

def unaryMinusCauseOverflowError(originValue: AnyVal): ArithmeticException = {
new ArithmeticException(s"- $originValue caused overflow.")
arithmeticOverflowError(s"- $originValue caused overflow")
}

def binaryArithmeticCauseOverflowError(
eval1: Short, symbol: String, eval2: Short): ArithmeticException = {
new ArithmeticException(s"$eval1 $symbol $eval2 caused overflow.")
arithmeticOverflowError(s"$eval1 $symbol $eval2 caused overflow")
}

def failedSplitSubExpressionMsg(length: Int): String = {
Expand Down
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import scala.math.Numeric._
import scala.math.Ordering

import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted

Expand Down Expand Up @@ -93,23 +93,23 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.Shor


private[sql] object IntegerExactNumeric extends IntIsIntegral with Ordering.IntOrdering {
override def plus(x: Int, y: Int): Int = Math.addExact(x, y)
override def plus(x: Int, y: Int): Int = MathUtils.addExact(x, y)

override def minus(x: Int, y: Int): Int = Math.subtractExact(x, y)
override def minus(x: Int, y: Int): Int = MathUtils.subtractExact(x, y)

override def times(x: Int, y: Int): Int = Math.multiplyExact(x, y)
override def times(x: Int, y: Int): Int = MathUtils.multiplyExact(x, y)

override def negate(x: Int): Int = Math.negateExact(x)
override def negate(x: Int): Int = MathUtils.negateExact(x)
}

private[sql] object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering {
override def plus(x: Long, y: Long): Long = Math.addExact(x, y)
override def plus(x: Long, y: Long): Long = MathUtils.addExact(x, y)

override def minus(x: Long, y: Long): Long = Math.subtractExact(x, y)
override def minus(x: Long, y: Long): Long = MathUtils.subtractExact(x, y)

override def times(x: Long, y: Long): Long = Math.multiplyExact(x, y)
override def times(x: Long, y: Long): Long = MathUtils.multiplyExact(x, y)

override def negate(x: Long): Long = Math.negateExact(x)
override def negate(x: Long): Long = MathUtils.negateExact(x)

override def toInt(x: Long): Int =
if (x == x.toInt) {
Expand Down

0 comments on commit e582799

Please sign in to comment.