From 8c0a7ba82c98c7f7e686c4ee81d2aad49cc7a6e0 Mon Sep 17 00:00:00 2001
From: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Date: Wed, 15 May 2024 14:24:46 +0800
Subject: [PATCH] [SPARK-48160][SQL] Add collation support for XPATH expressions

### What changes were proposed in this pull request?
Introduce collation awareness for XPath expressions: xpath_boolean, xpath_short, xpath_int, xpath_long, xpath_float, xpath_double, xpath_string, xpath.

### Why are the changes needed?
Add collation support for XPath expressions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for XPath functions: xpath_boolean, xpath_short, xpath_int, xpath_long, xpath_float, xpath_double, xpath_string, xpath.

### How was this patch tested?
E2E SQL tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46508 from uros-db/xpath-expressions.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan
---
 .../sql/catalyst/expressions/xml/xpath.scala  | 11 +++--
 .../sql/CollationSQLExpressionsSuite.scala    | 44 +++++++++++++++++++
 2 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index c3a285178c110..f65061e8d0ea9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -39,7 +41,8 @@ abstract class XPathExtract
   /** XPath expressions are always nullable, e.g. if the xml string is empty. */
   override def nullable: Boolean = true
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!path.foldable) {
@@ -47,7 +50,7 @@ abstract class XPathExtract
         errorSubClass = "NON_FOLDABLE_INPUT",
         messageParameters = Map(
           "inputName" -> toSQLId("path"),
-          "inputType" -> toSQLType(StringType),
+          "inputType" -> toSQLType(StringTypeAnyCollation),
           "inputExpr" -> toSQLExpr(path)
         )
       )
@@ -221,7 +224,7 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
 // scalastyle:on line.size.limit
 case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
   override def prettyName: String = "xpath_string"
-  override def dataType: DataType = StringType
+  override def dataType: DataType = SQLConf.get.defaultStringType
 
   override def nullSafeEval(xml: Any, path: Any): Any = {
     val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString)
@@ -245,7 +248,7 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
 // scalastyle:on line.size.limit
 case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
   override def prettyName: String = "xpath"
-  override def dataType: DataType = ArrayType(StringType, containsNull = false)
+  override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType, containsNull = false)
 
   override def nullSafeEval(xml: Any, path: Any): Any = {
     val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 48c3853bb5cf9..37dcdf9bd7216 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -548,6 +548,50 @@ class CollationSQLExpressionsSuite
     })
   }
 
+  test("Support XPath expressions with collation") {
+    case class XPathTestCase(
+      xml: String,
+      xpath: String,
+      functionName: String,
+      collationName: String,
+      result: Any,
+      resultType: DataType
+    )
+
+    val testCases = Seq(
+      XPathTestCase("<a><b>1</b></a>", "a/b",
+        "xpath_boolean", "UTF8_BINARY", true, BooleanType),
+      XPathTestCase("<A><B>1</B><B>2</B></A>", "sum(A/B)",
+        "xpath_short", "UTF8_BINARY", 3, ShortType),
+      XPathTestCase("<a><b>3</b><b>4</b></a>", "sum(a/b)",
+        "xpath_int", "UTF8_BINARY_LCASE", 7, IntegerType),
+      XPathTestCase("<A><B>5</B><B>6</B></A>", "sum(A/B)",
+        "xpath_long", "UTF8_BINARY_LCASE", 11, LongType),
+      XPathTestCase("<a><b>7</b><b>8</b></a>", "sum(a/b)",
+        "xpath_float", "UNICODE", 15.0, FloatType),
+      XPathTestCase("<A><B>9</B><B>0</B></A>", "sum(A/B)",
+        "xpath_double", "UNICODE", 9.0, DoubleType),
+      XPathTestCase("<a><b>b</b><c>cc</c></a>", "a/c",
+        "xpath_string", "UNICODE_CI", "cc", StringType("UNICODE_CI")),
+      XPathTestCase("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b/text()",
+        "xpath", "UNICODE_CI", Array("b1", "b2", "b3"), ArrayType(StringType("UNICODE_CI")))
+    )
+
+    // Supported collations
+    testCases.foreach(t => {
+      val query =
+        s"""
+           |select ${t.functionName}('${t.xml}', '${t.xpath}')
+           |""".stripMargin
+      // Result & data type
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+        val testQuery = sql(query)
+        checkAnswer(testQuery, Row(t.result))
+        assert(testQuery.schema.fields.head.dataType.sameType(t.resultType))
+      }
+    })
+  }
+
   test("Support StringSpace expression with collation") {
     case class StringSpaceTestCase(
       input: Int,