diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 34e8f3f408599..3e5779308a565 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1987,15 +1987,15 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = If( IsNull(str), - Literal(null, StringType), + Literal(null, str.dataType), If( LessThanOrEqual(len, Literal(0)), - Literal(UTF8String.EMPTY_UTF8, StringType), + Literal(UTF8String.EMPTY_UTF8, str.dataType), new Substring(str, UnaryMinus(len)) ) ) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2026,7 +2026,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringType, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType) } override def left: Expression = str diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 0dbd4c0ba713f..053ae2261de8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -212,6 +212,49 @@ class CollationStringExpressionsSuite }) } + test("Support Left/Right/Substr with collation") { + case class SubstringTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + SubstringTestCase("select substr('example' collate " + c + ", 1, 100)", c, Row("example")), + SubstringTestCase("select substr('example' collate " + c + ", 2, 2)", c, Row("xa")), + SubstringTestCase("select substr('example' collate " + c + ", 0, 0)", c, Row("")), + SubstringTestCase("select substr('example' collate " + c + ", -3, 2)", c, Row("pl")), + SubstringTestCase("select substr(' a世a ' collate " + c + ", 2, 3)", c, Row("a世a")), // scalastyle:ignore + SubstringTestCase("select left(' a世a ' collate " + c + ", 3)", c, Row(" a世")), // scalastyle:ignore + SubstringTestCase("select right(' a世a ' collate " + c + ", 3)", c, Row("世a ")), // scalastyle:ignore + SubstringTestCase("select left('AaAaAaAa000000' collate " + c + ", 3)", c, Row("AaA")), + SubstringTestCase("select right('AaAaAaAa000000' collate " + c + ", 3)", c, Row("000")), + SubstringTestCase("select substr('' collate " + c + ", 1, 1)", c, Row("")), + SubstringTestCase("select left('' collate " + c + ", 1)", c, Row("")), + SubstringTestCase("select right('' collate " + c + ", 1)", c, Row("")), + // improper values + SubstringTestCase("select left(null collate " + c + ", 1)", c, Row(null)), + SubstringTestCase("select right(null collate " + c + ", 1)", c, Row(null)), + SubstringTestCase("select substr(null collate " + c + ", 1)", c, Row(null)), + SubstringTestCase("select substr(null collate " + c + ", 1, 1)", c, Row(null)), + SubstringTestCase("select left(null collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select right(null collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select substr(null collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select substr(null collate " + c + ", null, null)", c, Row(null)), + SubstringTestCase("select left('AaAaAaAa000000' collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select right('AaAaAaAa000000' collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select substr('AaAaAaAa000000' collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select substr('AaAaAaAa0' collate " + c + ", null, null)", c, Row(null)), + SubstringTestCase("select right('' collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select substr('' collate " + c + ", null)", c, Row(null)), + SubstringTestCase("select substr('' collate " + c + ", null, null)", c, Row(null)), + SubstringTestCase("select left('' collate " + c + ", null)", c, Row(null)) + ) + ) + + checks.foreach { check => + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + } + } + // TODO: Add more tests for other string expressions }