diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3c9888940221a..bd2c3baf4fe85 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1991,15 +1991,15 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = If( IsNull(str), - Literal(null, StringType), + Literal(null, str.dataType), If( LessThanOrEqual(len, Literal(0)), - Literal(UTF8String.EMPTY_UTF8, StringType), + Literal(UTF8String.EMPTY_UTF8, str.dataType), new Substring(str, UnaryMinus(len)) ) ) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2030,7 +2030,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringType, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType) } override def left: Expression = str diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index f0be9cc89a4d2..07be8d48e8697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -425,6 +425,55 @@ class CollationStringExpressionsSuite }) } + 
test("Support Left/Right/Substr with collation") {
  // Describes one SQL call of `left`, `right` or `substr` on a collated string.
  //   method     - SQL function name to invoke.
  //   str        - input string; the sentinel "null" stands for SQL NULL.
  //   posOrLen   - second SQL argument, spliced verbatim: the length for `left`/`right`,
  //                the start position for `substr`.
  //   substrLen  - optional third SQL argument (substring length), spliced verbatim.
  //   collation  - collation applied to `str` and expected on the result type.
  //   result     - expected single-row answer.
  case class SubstringTestCase(
      method: String,
      str: String,
      posOrLen: String,
      substrLen: Option[String],
      collation: String,
      result: Row) {
    // Only the string operand is quoted; "null" is passed through as SQL NULL.
    val strString = if (str == "null") "null" else s"'$str'"
    // Numeric arguments are emitted unquoted so integers stay integer literals and "null"
    // becomes a real SQL NULL instead of the four-character string literal 'null'.
    val lengthArg = substrLen.map(l => s", $l").getOrElse("")
    val query =
      s"SELECT $method(collate($strString, '$collation'), $posOrLen$lengthArg)"
  }

  val checks = Seq(
    SubstringTestCase("substr", "example", "1", Some("100"), "utf8_binary_lcase", Row("example")),
    SubstringTestCase("substr", "example", "2", Some("2"), "utf8_binary", Row("xa")),
    SubstringTestCase("right", "", "1", None, "utf8_binary_lcase", Row("")),
    SubstringTestCase("substr", "example", "0", Some("0"), "unicode", Row("")),
    SubstringTestCase("substr", "example", "-3", Some("2"), "unicode_ci", Row("pl")),
    SubstringTestCase("substr", " a世a ", "2", Some("3"), "utf8_binary_lcase", Row("a世a")),
    SubstringTestCase("left", " a世a ", "3", None, "utf8_binary", Row(" a世")),
    SubstringTestCase("right", " a世a ", "3", None, "unicode", Row("世a ")),
    SubstringTestCase("left", "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "3", None, "unicode_ci", Row("ÀÃÂ")),
    SubstringTestCase("right", "ÀÃÂĀĂȦÄäâãȻȻȻȻȻǢǼÆ", "3", None, "utf8_binary_lcase", Row("ǢǼÆ")),
    SubstringTestCase("substr", "", "1", Some("1"), "utf8_binary_lcase", Row("")),
    SubstringTestCase("substr", "", "1", Some("1"), "unicode", Row("")),
    SubstringTestCase("left", "", "1", None, "utf8_binary", Row("")),
    // NULL string operand.
    SubstringTestCase("left", "null", "1", None, "utf8_binary_lcase", Row(null)),
    SubstringTestCase("right", "null", "1", None, "unicode", Row(null)),
    SubstringTestCase("substr", "null", "1", None, "utf8_binary", Row(null)),
    SubstringTestCase("substr", "null", "1", Some("1"), "unicode_ci", Row(null)),
    // NULL string operand and NULL numeric operands.
    SubstringTestCase("left", "null", "null", None, "utf8_binary_lcase", Row(null)),
    SubstringTestCase("right", "null", "null", None, "unicode", Row(null)),
    SubstringTestCase("substr", "null", "null", Some("null"), "utf8_binary", Row(null)),
    SubstringTestCase("substr", "null", "null", None, "unicode_ci", Row(null)),
    // Non-NULL string operand with a NULL numeric operand.
    SubstringTestCase("left", "ÀÃÂȦÄäåäáâãȻȻȻǢǼÆ", "null", None, "utf8_binary_lcase", Row(null)),
    SubstringTestCase("right", "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "null", None, "unicode", Row(null)),
    SubstringTestCase("substr", "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "null", None, "utf8_binary", Row(null)),
    SubstringTestCase("substr", "", "null", None, "unicode_ci", Row(null))
  )

  checks.foreach { check =>
    // Build the DataFrame once and verify both the result and its data type.
    val df = sql(check.query)
    checkAnswer(df, check.result)
    assert(
      df.schema.fields.head.dataType.sameType(StringType(check.collation)),
      s"Unexpected result data type for query: ${check.query}")
  }
}

// TODO: Add more tests for other string expressions
}