[SPARK-47541][SQL] Collated strings in complex types supporting operations reverse, array_join, concat, map #45693

Closed · wants to merge 9 commits
@@ -1351,7 +1351,8 @@ case class Reverse(child: Expression)
   extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

   // Input types are utilized by type coercion in ImplicitTypeCasts.
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(StringTypeAnyCollation, ArrayType))

   override def dataType: DataType = child.dataType

@@ -1365,7 +1366,7 @@ case class Reverse(child: Expression)
       val arrayData = input.asInstanceOf[ArrayData]
       new GenericArrayData(arrayData.toObjectArray(elementType).reverse)
     }
-    case StringType => _.asInstanceOf[UTF8String].reverse()
+    case _: StringType => _.asInstanceOf[UTF8String].reverse()
   }

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
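The old `case StringType` pattern only matches the default (UTF8_BINARY) instance, while `case _: StringType` matches a string type under any collation; together with `StringTypeAnyCollation` this lets `reverse` accept collated inputs. A minimal sketch of the intended behavior, assuming an active SparkSession `spark` on a build with collation support:

```scala
// Sketch only: reverse over a collated string literal.
val df = spark.sql("select reverse('abc' collate utf8_binary_lcase)")
df.show() // prints "cba"; the result keeps the input's collated string type
```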
@@ -2002,9 +2003,9 @@ case class ArrayJoin(
     this(array, delimiter, Some(nullReplacement))

   override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
-    Seq(ArrayType(StringType), StringType, StringType)
+    Seq(ArrayType, StringTypeAnyCollation, StringTypeAnyCollation)
Contributor: What if the input is an array of int? Before, it would fail the type check, but now it seems different.

Contributor: We probably need to override checkInputDataTypes to check the array element.

Contributor (author): @cloud-fan There's a special case in FunctionArgumentConversion which implicitly casts the array parameter to an array of strings (#21620).

Contributor (author): So, if I'm not mistaken, this behavior was allowed before and it still works as expected.
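For context on the exchange above: per the author's point, a non-string array is implicitly converted before the type check, so a call like the following should still resolve rather than fail (sketch, assuming an active SparkSession `spark`):

```scala
// FunctionArgumentConversion casts array<int> to array<string> for array_join,
// so this resolves instead of failing the type check (per the discussion above).
spark.sql("select array_join(array(1, 2, 3), '-')").show() // 1-2-3
```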

   } else {
-    Seq(ArrayType(StringType), StringType)
+    Seq(ArrayType, StringTypeAnyCollation)
   }

override def children: Seq[Expression] = if (nullReplacement.isDefined) {
@@ -2149,7 +2150,7 @@ case class ArrayJoin(
}
}

-  override def dataType: DataType = StringType
+  override def dataType: DataType = array.dataType.asInstanceOf[ArrayType].elementType

override def prettyName: String = "array_join"
}
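Returning the array's element type instead of a hard-coded StringType lets the input collation flow through to the result. A rough check of that propagation, assuming an active SparkSession `spark` (the exact rendering of the collated type may differ):

```scala
// Sketch: the result field should now report a collated string type
// (e.g. string collate UTF8_BINARY_LCASE) rather than plain string.
val df = spark.sql(
  """select array_join(
    |  array('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase), ', ')
    |""".stripMargin)
df.schema.fields.foreach(f => println(f.dataType))
```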
@@ -2724,7 +2725,8 @@ case class TryElementAt(left: Expression, right: Expression, replacement: Expression)
case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression
with QueryErrorsBase {

-  private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
+  private def allowedTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, BinaryType, ArrayType)

final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)

@@ -2774,7 +2776,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression
val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
ByteArray.concat(inputs: _*)
}
-      case _: StringType is the new pattern; the old one follows:
-      case StringType =>
+      case _: StringType =>
input => {
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs: _*)
@@ -2845,7 +2847,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression
val (concat, initCode) = dataType match {
case BinaryType =>
(s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];")
-      case StringType =>
+      case _: StringType =>
("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, containsNull) =>
val concat = genCodeForArrays(ctx, elementType, containsNull)
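With both the interpreted (`eval`) and codegen branches now matching `_: StringType`, collated inputs take the same `UTF8String.concat` path as default strings. Sketch, assuming an active SparkSession `spark` on a build with collation support:

```scala
// Sketch only: concat over collated inputs, exercised through SQL.
spark.sql(
  "select concat('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase)"
).show() // ab
```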
@@ -34,14 +34,22 @@ import org.apache.spark.unsafe.array.ByteArrayMethods
class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Serializable {
assert(!keyType.existsRecursively(_.isInstanceOf[MapType]), "key of map cannot be/contain map")

-  private lazy val keyToIndex = keyType match {
-    // Binary type data is `byte[]`, which can't use `==` to check equality.
-    case _: AtomicType | _: CalendarIntervalType | _: NullType
-      if !keyType.isInstanceOf[BinaryType] => new java.util.HashMap[Any, Int]()
-    case _ =>
-      // for complex types, use interpreted ordering to be able to compare unsafe data with safe
-      // data, e.g. UnsafeRow vs GenericInternalRow.
-      new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType))
+  private lazy val keyToIndex = {
+    def hashMap = new java.util.HashMap[Any, Int]()
+    def treeMap = new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType))
+
+    keyType match {
+      // StringType binary equality support implies hashing support
+      case s: StringType if s.supportsBinaryEquality => hashMap
+      case _: StringType => treeMap
+      // Binary type data is `byte[]`, which can't use `==` to check equality.
+      case _: BinaryType => treeMap
+      case _: AtomicType | _: CalendarIntervalType | _: NullType => hashMap
+      case _ =>
+        // for complex types, use interpreted ordering to be able to compare unsafe data with safe
+        // data, e.g. UnsafeRow vs GenericInternalRow.
+        treeMap
+    }
+  }

// TODO: specialize it
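The reason non-binary collations get the TreeMap branch: duplicate-key detection must use the collation's equality, and a hash map keyed on raw bytes would treat 'aaa' and 'AAA' as distinct under UTF8_BINARY_LCASE. A standalone sketch (plain Scala, with a case-insensitive comparator standing in for the collation's ordering):

```scala
import java.util.{Comparator, HashMap => JHashMap, TreeMap => JTreeMap}

object MapKeyDedupSketch extends App {
  // Stand-in for a non-binary collation such as UTF8_BINARY_LCASE.
  val lcase: Comparator[String] = (a: String, b: String) => a.compareToIgnoreCase(b)

  val hash = new JHashMap[String, Int]()
  hash.put("aaa", 1)
  hash.put("AAA", 2)
  println(hash.size()) // 2 -- byte-wise hashing sees two distinct keys

  val tree = new JTreeMap[String, Int](lcase)
  tree.put("aaa", 1)
  tree.put("AAA", 2)
  println(tree.size()) // 1 -- comparator-based lookup collapses the two keys
}
```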
@@ -70,7 +70,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(0),
"requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")",
"requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")",
"inputSql" -> "\"1\"",
"inputType" -> "\"INT\""
)
52 changes: 52 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -800,6 +800,58 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

test("Support operations on complex types containing collated strings") {
checkAnswer(sql("select reverse('abc' collate utf8_binary_lcase)"), Seq(Row("cba")))
checkAnswer(sql(
"""
|select reverse(array('a' collate utf8_binary_lcase,
|'b' collate utf8_binary_lcase))
|""".stripMargin), Seq(Row(Seq("b", "a"))))
checkAnswer(sql(
"""
|select array_join(array('a' collate utf8_binary_lcase,
|'b' collate utf8_binary_lcase), ', ' collate utf8_binary_lcase)
|""".stripMargin), Seq(Row("a, b")))
checkAnswer(sql(
"""
|select array_join(array('a' collate utf8_binary_lcase,
|'b' collate utf8_binary_lcase, null), ', ' collate utf8_binary_lcase,
|'c' collate utf8_binary_lcase)
|""".stripMargin), Seq(Row("a, b, c")))
checkAnswer(sql(
"""
|select concat('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase)
|""".stripMargin), Seq(Row("ab")))
checkAnswer(sql(
"""
|select concat(array('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase))
|""".stripMargin), Seq(Row(Seq("a", "b"))))
checkAnswer(sql(
"""
|select map('a' collate utf8_binary_lcase, 1, 'b' collate utf8_binary_lcase, 2)
|['A' collate utf8_binary_lcase]
|""".stripMargin), Seq(Row(1)))
val ctx = "map('aaa' collate utf8_binary_lcase, 1, 'AAA' collate utf8_binary_lcase, 2)['AaA']"
val query = s"select $ctx"
checkError(
exception = intercept[AnalysisException](sql(query)),
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"map(collate(aaa), 1, collate(AAA), 2)[AaA]\"",
"paramIndex" -> "second",
"inputSql" -> "\"AaA\"",
"inputType" -> toSQLType(StringType),
"requiredType" -> toSQLType(StringType(
CollationFactory.collationNameToId("UTF8_BINARY_LCASE")))
),
context = ExpectedContext(
fragment = ctx,
start = query.length - ctx.length,
stop = query.length - 1
)
)
}

test("window aggregates should respect collation") {
val t1 = "T_NON_BINARY"
val t2 = "T_BINARY"
@@ -1713,7 +1713,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
"paramIndex" -> "second",
"inputSql" -> "\"1\"",
"inputType" -> "\"INT\"",
"requiredType" -> "\"STRING\""
"requiredType" -> "\"STRING_ANY_COLLATION\""
),
queryContext = Array(ExpectedContext("", "", 0, 15, "array_join(x, 1)"))
)
@@ -1727,7 +1727,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
"paramIndex" -> "third",
"inputSql" -> "\"1\"",
"inputType" -> "\"INT\"",
"requiredType" -> "\"STRING\""
"requiredType" -> "\"STRING_ANY_COLLATION\""
),
queryContext = Array(ExpectedContext("", "", 0, 21, "array_join(x, ', ', 1)"))
)
@@ -1987,7 +1987,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
"paramIndex" -> "first",
"inputSql" -> "\"struct(1, a)\"",
"inputType" -> "\"STRUCT<col1: INT NOT NULL, col2: STRING NOT NULL>\"",
"requiredType" -> "(\"STRING\" or \"ARRAY\")"
"requiredType" -> "(\"STRING_ANY_COLLATION\" or \"ARRAY\")"
),
queryContext = Array(ExpectedContext("", "", 7, 29, "reverse(struct(1, 'a'))"))
)
Expand All @@ -2002,7 +2002,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
"paramIndex" -> "first",
"inputSql" -> "\"map(1, a)\"",
"inputType" -> "\"MAP<INT, STRING>\"",
"requiredType" -> "(\"STRING\" or \"ARRAY\")"
"requiredType" -> "(\"STRING_ANY_COLLATION\" or \"ARRAY\")"
),
queryContext = Array(ExpectedContext("", "", 7, 26, "reverse(map(1, 'a'))"))
)
@@ -2552,7 +2552,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
parameters = Map(
"sqlExpr" -> "\"concat(map(1, 2), map(3, 4))\"",
"paramIndex" -> "first",
"requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")",
"requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")",
"inputSql" -> "\"map(1, 2)\"",
"inputType" -> "\"MAP<INT, INT>\""
),