Skip to content

Commit

Permalink
[SPARK-48160][SQL] Add collation support for XPATH expressions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Introduce collation awareness for XPath expressions: xpath_boolean, xpath_short, xpath_int, xpath_long, xpath_float, xpath_double, xpath_string, xpath.

### Why are the changes needed?
Add collation support for Xpath expressions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for XPath functions: xpath_boolean, xpath_short, xpath_int, xpath_long, xpath_float, xpath_double, xpath_string, xpath.

### How was this patch tested?
E2e sql tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46508 from uros-db/xpath-expressions.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed May 15, 2024
1 parent 7233540 commit 8c0a7ba
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -39,15 +41,16 @@ abstract class XPathExtract
/** XPath expressions are always nullable, e.g. if the xml string is empty. */
override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override def checkInputDataTypes(): TypeCheckResult = {
if (!path.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("path"),
"inputType" -> toSQLType(StringType),
"inputType" -> toSQLType(StringTypeAnyCollation),
"inputExpr" -> toSQLExpr(path)
)
)
Expand Down Expand Up @@ -221,7 +224,7 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
// scalastyle:on line.size.limit
case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
override def prettyName: String = "xpath_string"
override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType

override def nullSafeEval(xml: Any, path: Any): Any = {
val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString)
Expand All @@ -245,7 +248,7 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
// scalastyle:on line.size.limit
case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
override def prettyName: String = "xpath"
override def dataType: DataType = ArrayType(StringType, containsNull = false)
override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType, containsNull = false)

override def nullSafeEval(xml: Any, path: Any): Any = {
val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,50 @@ class CollationSQLExpressionsSuite
})
}

test("Support XPath expressions with collation") {
case class XPathTestCase(
xml: String,
xpath: String,
functionName: String,
collationName: String,
result: Any,
resultType: DataType
)

val testCases = Seq(
XPathTestCase("<a><b>1</b></a>", "a/b",
"xpath_boolean", "UTF8_BINARY", true, BooleanType),
XPathTestCase("<A><B>1</B><B>2</B></A>", "sum(A/B)",
"xpath_short", "UTF8_BINARY", 3, ShortType),
XPathTestCase("<a><b>3</b><b>4</b></a>", "sum(a/b)",
"xpath_int", "UTF8_BINARY_LCASE", 7, IntegerType),
XPathTestCase("<A><B>5</B><B>6</B></A>", "sum(A/B)",
"xpath_long", "UTF8_BINARY_LCASE", 11, LongType),
XPathTestCase("<a><b>7</b><b>8</b></a>", "sum(a/b)",
"xpath_float", "UNICODE", 15.0, FloatType),
XPathTestCase("<A><B>9</B><B>0</B></A>", "sum(A/B)",
"xpath_double", "UNICODE", 9.0, DoubleType),
XPathTestCase("<a><b>b</b><c>cc</c></a>", "a/c",
"xpath_string", "UNICODE_CI", "cc", StringType("UNICODE_CI")),
XPathTestCase("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b/text()",
"xpath", "UNICODE_CI", Array("b1", "b2", "b3"), ArrayType(StringType("UNICODE_CI")))
)

// Supported collations
testCases.foreach(t => {
val query =
s"""
|select ${t.functionName}('${t.xml}', '${t.xpath}')
|""".stripMargin
// Result & data type
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
val testQuery = sql(query)
checkAnswer(testQuery, Row(t.result))
assert(testQuery.schema.fields.head.dataType.sameType(t.resultType))
}
})
}

test("Support StringSpace expression with collation") {
case class StringSpaceTestCase(
input: Int,
Expand Down

0 comments on commit 8c0a7ba

Please sign in to comment.