diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4e71c8c103889..1fa957e9ccfe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -64,11 +64,13 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { custom error message """, since = "3.1.0") -case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class RaiseError(child: Expression, dataType: DataType) + extends UnaryExpression with ImplicitCastInputTypes { + + def this(child: Expression) = this(child, NullType) override def foldable: Boolean = false override def nullable: Boolean = true - override def dataType: DataType = NullType override def inputTypes: Seq[AbstractDataType] = Seq(StringType) override def prettyName: String = "raise_error" @@ -98,6 +100,10 @@ case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCa } } +object RaiseError { + def apply(child: Expression): RaiseError = new RaiseError(child) +} + /** * A function that throws an exception if 'condition' is not true. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index eaafe35fc00eb..5fc070a121079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String object CharVarcharUtils extends Logging { @@ -202,12 +203,9 @@ object CharVarcharUtils extends Logging { }.getOrElse(expr) } - private def raiseError(expr: Expression, typeName: String, length: Int): Expression = { - val errorMsg = Concat(Seq( - Literal("input string of length "), - Cast(Length(expr), StringType), - Literal(s" exceeds $typeName type length limitation: $length"))) - Cast(RaiseError(errorMsg), StringType) + private def raiseError(typeName: String, length: Int): Expression = { + val errMsg = UTF8String.fromString(s"Exceeds $typeName type length limitation: $length") + RaiseError(Literal(errMsg, StringType), StringType) } private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match { @@ -217,7 +215,7 @@ object CharVarcharUtils extends Logging { // spaces, as we will pad char type columns/fields at read time. If( GreaterThan(Length(trimmed), Literal(length)), - raiseError(expr, "char", length), + raiseError("char", length), trimmed) case VarcharType(length) => @@ -230,7 +228,7 @@ object CharVarcharUtils extends Logging { expr, If( GreaterThan(Length(trimmed), Literal(length)), - raiseError(expr, "varchar", length), + raiseError("varchar", length), StringRPad(trimmed, Literal(length)))) case StructType(fields) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 7546e888075d9..fbf3f2ab0fa0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -189,8 +189,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t VALUES (null)") checkAnswer(spark.table("t"), Row(null)) val e = intercept[SparkException](sql("INSERT INTO t VALUES ('123456')")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -202,8 +201,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t VALUES (1, null)") checkAnswer(spark.table("t"), Row(1, null)) val e = intercept[SparkException](sql("INSERT INTO t VALUES (1, '123456')")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } } @@ -214,8 +212,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t SELECT struct(null)") checkAnswer(spark.table("t"), Row(Row(null))) val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -225,8 +222,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t VALUES (array(null))") checkAnswer(spark.table("t"), Row(Seq(null))) val e = intercept[SparkException](sql("INSERT INTO t VALUES (array('a', '123456'))")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -234,8 +230,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { testTableWrite { typeName => sql(s"CREATE TABLE t(c MAP<$typeName(5), STRING>) USING $format") val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -245,8 +240,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t VALUES (map('a', null))") checkAnswer(spark.table("t"), Row(Map("a" -> null))) val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -254,11 +248,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { testTableWrite { typeName => sql(s"CREATE TABLE t(c MAP<$typeName(5), $typeName(5)>) USING $format") val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) - assert(e1.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e1.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) - assert(e2.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e2.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -268,8 +260,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t SELECT struct(array(null))") checkAnswer(spark.table("t"), Row(Row(Seq(null)))) val e = intercept[SparkException](sql("INSERT INTO t SELECT struct(array('123456'))")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -279,8 +270,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t VALUES (array(struct(null)))") checkAnswer(spark.table("t"), Row(Seq(Row(null)))) val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(struct('123456')))")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -290,8 +280,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t VALUES (array(array(null)))") checkAnswer(spark.table("t"), Row(Seq(Seq(null)))) val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(array('123456')))")) - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typeName type length limitation: 5")) + assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5")) } } @@ -312,11 +301,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql("INSERT INTO t VALUES (1234, 1234)") checkAnswer(spark.table("t"), Row("1234 ", "1234")) val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (123456, 1)")) - assert(e1.getCause.getMessage.contains( - "input string of length 6 exceeds char type length limitation: 5")) + assert(e1.getCause.getMessage.contains("Exceeds char type length limitation: 5")) val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (1, 123456)")) - assert(e2.getCause.getMessage.contains( - "input string of length 6 exceeds varchar type length limitation: 5")) + assert(e2.getCause.getMessage.contains("Exceeds varchar type length limitation: 5")) } } @@ -626,8 +613,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa sql("SELECT '123456' as col").write.format(format).save(dir.toString) sql(s"CREATE TABLE t (col $typ(2)) using $format LOCATION '$dir'") val e = intercept[SparkException] { sql("select * from t").collect() } - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typ type length limitation: 2")) + assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2")) } } } @@ -654,8 +640,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa sql(s"CREATE TABLE t (col $typ(2)) using $format") sql(s"ALTER TABLE t SET LOCATION '$dir'") val e = intercept[SparkException] { spark.table("t").collect() } - assert(e.getCause.getMessage.contains( - s"input string of length 6 exceeds $typ type length limitation: 2")) + assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2")) } } }