Skip to content

Commit

Permalink
[SPARK-38929][SQL][3.3] Improve error messages for cast failures in ANSI
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Improve the error messages for cast failures in ANSI.
As mentioned in https://issues.apache.org/jira/browse/SPARK-38929, this PR targets two cast-to types: numeric types and date types.
* For numeric(`int`, `smallint`, `double`, `float`, `decimal` ..) types, it embeds the cast-to types in the error message. For example,
  ```
  Invalid input value for type INT: '1.0'. To return NULL instead, use 'try_cast'. If necessary set %s to false to bypass this error.
  ```
  It uses the `toSQLType` and `toSQLValue` to wrap the corresponding types and literals.
* For date types, it does similarly as above. For example,
  ```
  Invalid input value for type TIMESTAMP: 'a'. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error.
  ```

### Why are the changes needed?
To improve the error message in general.

### Does this PR introduce _any_ user-facing change?
It changes the error messages.

### How was this patch tested?
The related unit tests are updated.

Authored-by: Xinyi Yu <xinyi.yudatabricks.com>
Signed-off-by: Max Gekk <max.gekkgmail.com>
(cherry picked from commit f76b3e7)

Closes apache#36275 from anchovYu/ansi-error-improve-3.3.

Authored-by: Xinyi Yu <xinyi.yu@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
anchovYu authored and MaxGekk committed Apr 20, 2022
1 parent 83a365e commit 9d0650a
Show file tree
Hide file tree
Showing 26 changed files with 297 additions and 150 deletions.
8 changes: 4 additions & 4 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@
"message" : [ "The fraction of sec must be zero. Valid range is [0, 60]. If necessary set %s to false to bypass this error. " ],
"sqlState" : "22023"
},
"INVALID_INPUT_SYNTAX_FOR_NUMERIC_TYPE" : {
"message" : [ "invalid input syntax for type numeric: %s. To return NULL instead, use 'try_cast'. If necessary set %s to false to bypass this error.%s" ],
"sqlState" : "42000"
},
"INVALID_JSON_SCHEMA_MAPTYPE" : {
"message" : [ "Input schema %s can only contain StringType as a key type for a MapType." ]
},
Expand All @@ -120,6 +116,10 @@
"message" : [ "Invalid SQL syntax: %s" ],
"sqlState" : "42000"
},
"INVALID_SYNTAX_FOR_CAST" : {
"message" : [ "Invalid input syntax for type %s: %s. To return NULL instead, use 'try_cast'. If necessary set %s to false to bypass this error.%s" ],
"sqlState" : "42000"
},
"MAP_KEY_DOES_NOT_EXIST" : {
"message" : [ "Key %s does not exist. If necessary set %s to false to bypass this error.%s" ]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
})
case StringType if ansiEnabled =>
buildCast[UTF8String](_,
s => changePrecision(Decimal.fromStringANSI(s, origin.context), target))
s => changePrecision(Decimal.fromStringANSI(s, target, origin.context), target))
case BooleanType =>
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
Expand Down Expand Up @@ -845,7 +845,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case _: NumberFormatException =>
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
if(ansiEnabled && d == null) {
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(s, origin.context)
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(
DoubleType, s, origin.context)
} else {
d
}
Expand All @@ -870,7 +871,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case _: NumberFormatException =>
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (ansiEnabled && f == null) {
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(s, origin.context)
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(
FloatType, s, origin.context)
} else {
f
}
Expand Down Expand Up @@ -1376,9 +1378,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
"""
case StringType if ansiEnabled =>
val errorContext = ctx.addReferenceObj("errCtx", origin.context)
val toType = ctx.addReferenceObj("toType", target)
(c, evPrim, evNull) =>
code"""
Decimal $tmp = Decimal.fromStringANSI($c, $errorContext);
Decimal $tmp = Decimal.fromStringANSI($c, $toType, $errorContext);
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
"""
case BooleanType =>
Expand Down Expand Up @@ -1899,7 +1902,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
val errorContext = ctx.addReferenceObj("errCtx", origin.context)
s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError($c, $errorContext);"
s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError(" +
s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);"
} else {
s"$evNull = true;"
}
Expand Down Expand Up @@ -1936,7 +1940,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
val errorContext = ctx.addReferenceObj("errCtx", origin.context)
s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError($c, $errorContext);"
s"throw QueryExecutionErrors.invalidInputSyntaxForNumericError(" +
s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);"
} else {
s"$evNull = true;"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType}
import org.apache.spark.unsafe.types.UTF8String

/**
Expand All @@ -26,23 +27,23 @@ import org.apache.spark.unsafe.types.UTF8String
object UTF8StringUtils {

def toLongExact(s: UTF8String, errorContext: String): Long =
withException(s.toLongExact, errorContext)
withException(s.toLongExact, errorContext, LongType, s)

def toIntExact(s: UTF8String, errorContext: String): Int =
withException(s.toIntExact, errorContext)
withException(s.toIntExact, errorContext, IntegerType, s)

def toShortExact(s: UTF8String, errorContext: String): Short =
withException(s.toShortExact, errorContext)
withException(s.toShortExact, errorContext, ShortType, s)

def toByteExact(s: UTF8String, errorContext: String): Byte =
withException(s.toByteExact, errorContext)
withException(s.toByteExact, errorContext, ByteType, s)

private def withException[A](f: => A, errorContext: String): A = {
private def withException[A](f: => A, errorContext: String, to: DataType, s: UTF8String): A = {
try {
f
} catch {
case e: NumberFormatException =>
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(e, errorContext)
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(to, s, errorContext)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,12 @@ object QueryExecutionErrors extends QueryErrorsBase {
}

def invalidInputSyntaxForNumericError(
to: DataType,
s: UTF8String,
errorContext: String): NumberFormatException = {
new SparkNumberFormatException(errorClass = "INVALID_INPUT_SYNTAX_FOR_NUMERIC_TYPE",
messageParameters = Array(toSQLValue(s, StringType), SQLConf.ANSI_ENABLED.key, errorContext))
new SparkNumberFormatException(errorClass = "INVALID_SYNTAX_FOR_CAST",
messageParameters = Array(toSQLType(to), toSQLValue(s, StringType),
SQLConf.ANSI_ENABLED.key, errorContext))
}

def cannotCastFromNullTypeError(to: DataType): Throwable = {
Expand Down Expand Up @@ -1048,8 +1050,14 @@ object QueryExecutionErrors extends QueryErrorsBase {
}

def cannotCastToDateTimeError(value: Any, to: DataType, errorContext: String): Throwable = {
new DateTimeException(s"Cannot cast $value to $to. To return NULL instead, use 'try_cast'. " +
s"If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error." + errorContext)
val valueString = if (value.isInstanceOf[UTF8String]) {
toSQLValue(value, StringType)
} else {
toSQLValue(value)
}
new DateTimeException(s"Invalid input syntax for type ${toSQLType(to)}: $valueString. " +
s"To return NULL instead, use 'try_cast'. If necessary set ${SQLConf.ANSI_ENABLED.key} " +
s"to false to bypass this error." + errorContext)
}

def registeringStreamingQueryListenerError(e: Exception): Throwable = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,10 @@ object Decimal {
}
}

def fromStringANSI(str: UTF8String, errorContext: String = ""): Decimal = {
def fromStringANSI(
str: UTF8String,
to: DecimalType = DecimalType.USER_DEFAULT,
errorContext: String = ""): Decimal = {
try {
val bigDecimal = stringToJavaBigDecimal(str)
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
Expand All @@ -626,7 +629,7 @@ object Decimal {
}
} catch {
case _: NumberFormatException =>
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(str, errorContext)
throw QueryExecutionErrors.invalidInputSyntaxForNumericError(to, str, errorContext)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLValue
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -174,42 +175,43 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
test("cast from invalid string to numeric should throw NumberFormatException") {
// cast to IntegerType
Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: 'string'")
checkExceptionInExpression[NumberFormatException](
cast("123-string", dataType), "invalid input syntax for type numeric: '123-string'")
checkExceptionInExpression[NumberFormatException](
cast("2020-07-19", dataType), "invalid input syntax for type numeric: '2020-07-19'")
checkExceptionInExpression[NumberFormatException](
cast("1.23", dataType), "invalid input syntax for type numeric: '1.23'")
checkExceptionInExpression[NumberFormatException](cast("string", dataType),
s"Invalid input syntax for type ${dataType.sql}: 'string'")
checkExceptionInExpression[NumberFormatException](cast("123-string", dataType),
s"Invalid input syntax for type ${dataType.sql}: '123-string'")
checkExceptionInExpression[NumberFormatException](cast("2020-07-19", dataType),
s"Invalid input syntax for type ${dataType.sql}: '2020-07-19'")
checkExceptionInExpression[NumberFormatException](cast("1.23", dataType),
s"Invalid input syntax for type ${dataType.sql}: '1.23'")
}

Seq(DoubleType, FloatType, DecimalType.USER_DEFAULT).foreach { dataType =>
checkExceptionInExpression[NumberFormatException](
cast("string", dataType), "invalid input syntax for type numeric: 'string'")
checkExceptionInExpression[NumberFormatException](
cast("123.000.00", dataType), "invalid input syntax for type numeric: '123.000.00'")
checkExceptionInExpression[NumberFormatException](
cast("abc.com", dataType), "invalid input syntax for type numeric: 'abc.com'")
checkExceptionInExpression[NumberFormatException](cast("string", dataType),
s"Invalid input syntax for type ${dataType.sql}: 'string'")
checkExceptionInExpression[NumberFormatException](cast("123.000.00", dataType),
s"Invalid input syntax for type ${dataType.sql}: '123.000.00'")
checkExceptionInExpression[NumberFormatException](cast("abc.com", dataType),
s"Invalid input syntax for type ${dataType.sql}: 'abc.com'")
}
}

protected def checkCastToNumericError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
protected def checkCastToNumericError(l: Literal, to: DataType,
expectedDataTypeInErrorMsg: DataType, tryCastResult: Any): Unit = {
checkExceptionInExpression[NumberFormatException](
cast(l, to), "invalid input syntax for type numeric: 'true'")
cast(l, to), s"Invalid input syntax for type ${expectedDataTypeInErrorMsg.sql}: 'true'")
}

test("cast from invalid string array to numeric array should throw NumberFormatException") {
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))

checkCastToNumericError(array, ArrayType(ByteType, containsNull = true),
checkCastToNumericError(array, ArrayType(ByteType, containsNull = true), ByteType,
Seq(123.toByte, null, null, null))
checkCastToNumericError(array, ArrayType(ShortType, containsNull = true),
checkCastToNumericError(array, ArrayType(ShortType, containsNull = true), ShortType,
Seq(123.toShort, null, null, null))
checkCastToNumericError(array, ArrayType(IntegerType, containsNull = true),
checkCastToNumericError(array, ArrayType(IntegerType, containsNull = true), IntegerType,
Seq(123, null, null, null))
checkCastToNumericError(array, ArrayType(LongType, containsNull = true),
checkCastToNumericError(array, ArrayType(LongType, containsNull = true), LongType,
Seq(123L, null, null, null))
}

Expand Down Expand Up @@ -243,7 +245,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {

checkExceptionInExpression[NumberFormatException](
cast("abcd", DecimalType(38, 1)),
"invalid input syntax for type numeric")
s"Invalid input syntax for type ${DecimalType(38, 1).sql}: 'abcd'")
}

protected def checkCastToBooleanError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
Expand All @@ -258,7 +260,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {

protected def checkCastToTimestampError(l: Literal, to: DataType): Unit = {
checkExceptionInExpression[DateTimeException](
cast(l, to), s"Cannot cast $l to $to")
cast(l, to), s"Invalid input syntax for type TIMESTAMP: ${toSQLValue(l)}")
}

test("cast from timestamp II") {
Expand Down Expand Up @@ -369,7 +371,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
assert(ret.resolved == !isTryCast)
if (!isTryCast) {
checkExceptionInExpression[NumberFormatException](
ret, "invalid input syntax for type numeric")
ret, s"Invalid input syntax for type ${IntegerType.sql}")
}
}

Expand All @@ -387,7 +389,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
assert(ret.resolved == !isTryCast)
if (!isTryCast) {
checkExceptionInExpression[NumberFormatException](
ret, "invalid input syntax for type numeric")
ret, s"Invalid input syntax for type ${IntegerType.sql}")
}
}
}
Expand Down Expand Up @@ -512,7 +514,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
assert(ret.resolved === !isTryCast)
if (!isTryCast) {
checkExceptionInExpression[NumberFormatException](
ret, "invalid input syntax for type numeric")
ret, s"Invalid input syntax for type ${IntegerType.sql}")
}
}

Expand All @@ -521,7 +523,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
def checkCastWithParseError(str: String): Unit = {
checkExceptionInExpression[DateTimeException](
cast(Literal(str), TimestampType, Option(zid.getId)),
s"Cannot cast $str to TimestampType.")
s"Invalid input syntax for type TIMESTAMP: '$str'")
}

checkCastWithParseError("123")
Expand All @@ -542,7 +544,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
def checkCastWithParseError(str: String): Unit = {
checkExceptionInExpression[DateTimeException](
cast(Literal(str), DateType, Option(zid.getId)),
s"Cannot cast $str to DateType.")
s"Invalid input syntax for type DATE: '$str'")
}

checkCastWithParseError("2015-13-18")
Expand Down Expand Up @@ -570,7 +572,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
"2021-06-17 00:00:00ABC").foreach { invalidInput =>
checkExceptionInExpression[DateTimeException](
cast(invalidInput, TimestampNTZType),
s"Cannot cast $invalidInput to TimestampNTZType")
s"Invalid input syntax for type TIMESTAMP_NTZ: '$invalidInput'")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class TryCastSuite extends AnsiCastSuiteBase {
checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
}

override def checkCastToNumericError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
override def checkCastToNumericError(l: Literal, to: DataType,
expectedDataTypeInErrorMsg: DataType, tryCastResult: Any): Unit = {
checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,6 @@ class DateFormatterSuite extends DatetimeFormatterSuite {
val errMsg = intercept[DateTimeException] {
formatter.parse("x123")
}.getMessage
assert(errMsg.contains("Cannot cast x123 to DateType"))
assert(errMsg.contains("Invalid input syntax for type DATE: 'x123'"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ class TimestampFormatterSuite extends DatetimeFormatterSuite {
val errMsg = intercept[DateTimeException] {
formatter.parse("x123")
}.getMessage
assert(errMsg.contains("Cannot cast x123 to TimestampType"))
assert(errMsg.contains("Invalid input syntax for type TIMESTAMP: 'x123'"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper

assert(Decimal.fromString(UTF8String.fromString("str")) === null)
val e = intercept[NumberFormatException](Decimal.fromStringANSI(UTF8String.fromString("str")))
assert(e.getMessage.contains("invalid input syntax for type numeric"))
assert(e.getMessage.contains("Invalid input syntax for type " +
s"${DecimalType.USER_DEFAULT.sql}: 'str'"))
}

test("SPARK-35841: Casting string to decimal type doesn't work " +
Expand Down
10 changes: 8 additions & 2 deletions sql/core/src/test/resources/sql-tests/inputs/cast.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ SELECT CAST('1.23' AS long);
SELECT CAST('-4.56' AS int);
SELECT CAST('-4.56' AS long);

-- cast string which are not numbers to integral should return null
-- cast string which are not numbers to numeric types
SELECT CAST('abc' AS int);
SELECT CAST('abc' AS long);
SELECT CAST('abc' AS float);
SELECT CAST('abc' AS double);

-- cast string representing a very large number to integral should return null
SELECT CAST('1234567890123' AS int);
Expand All @@ -15,14 +17,18 @@ SELECT CAST('12345678901234567890123' AS long);
-- cast empty string to integral should return null
SELECT CAST('' AS int);
SELECT CAST('' AS long);
SELECT CAST('' AS float);
SELECT CAST('' AS double);

-- cast null to integral should return null
SELECT CAST(NULL AS int);
SELECT CAST(NULL AS long);

-- cast invalid decimal string to integral should return null
-- cast invalid decimal string to numeric types
SELECT CAST('123.a' AS int);
SELECT CAST('123.a' AS long);
SELECT CAST('123.a' AS float);
SELECT CAST('123.a' AS double);

-- '-2147483648' is the smallest int value
SELECT CAST('-2147483648' AS int);
Expand Down
Loading

0 comments on commit 9d0650a

Please sign in to comment.