Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,41 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
* Collation-aware regexp expressions.
*/

public static class StringSplit {
public static UTF8String[] exec(final UTF8String string, final UTF8String regex,
final int limit, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(string, regex, limit);
} else {
assert(collation.supportsLowercaseEquality);
return execLowercase(string, regex, limit);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no execICU for this operation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not supported yet, we should add assert(collation.supportsLowercaseEquality)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId for StringTypeBinaryLcase won't allow that to happen. But an additional assert in the else branch couldn't hurt either.

}
}
/**
 * Builds the Java source expression that codegen should emit to split a string
 * under the given collation.
 *
 * The returned text is a call to either {@code CollationSupport.StringSplit.execBinary}
 * or {@code CollationSupport.StringSplit.execLowercase}, with the three argument
 * expressions interpolated verbatim.
 *
 * @param string source expression for the string to split
 * @param regex source expression for the regex to split on
 * @param limit source expression for the split limit
 * @param collationId id used to look up the collation that selects the implementation
 * @return the generated call expression as Java source text
 */
public static String genCode(final String string, final String regex, final String limit,
    final int collationId) {
  CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
  String expr = "CollationSupport.StringSplit.exec";
  if (collation.supportsBinaryEquality) {
    return String.format(expr + "Binary(%s, %s, %s)", string, regex, limit);
  } else {
    // Keep codegen in step with exec(): only lowercase-equality collations may
    // reach this branch, so fail fast instead of silently emitting lowercase code.
    assert(collation.supportsLowercaseEquality);
    return String.format(expr + "Lowercase(%s, %s, %s)", string, regex, limit);
  }
}
/**
 * Splits {@code string} around matches of {@code regex} by delegating directly to
 * {@code UTF8String.split}, without any collation-specific adjustment of the regex.
 *
 * @param string the string to split
 * @param regex the regex to split on
 * @param limit the limit forwarded unchanged to {@code UTF8String.split}
 * @return the pieces produced by the split
 */
public static UTF8String[] execBinary(final UTF8String string, final UTF8String regex,
    final int limit) {
  final UTF8String[] parts = string.split(regex, limit);
  return parts;
}
/**
 * Splits {@code string} around matches of {@code regex}, matching case-insensitively.
 *
 * The regex is rewritten through {@code CollationAwareUTF8String.getLowercaseRegex}
 * before splitting, except when the regex is empty and the string is not: in that
 * case the regex is passed to {@code UTF8String.split} untouched.
 * NOTE(review): presumably this keeps the empty-pattern behaviour of
 * {@code UTF8String.split} intact; confirm against that method.
 *
 * @param string the string to split
 * @param regex the regex to split on
 * @param limit the limit forwarded unchanged to {@code UTF8String.split}
 * @return the pieces produced by the split
 */
public static UTF8String[] execLowercase(final UTF8String string, final UTF8String regex,
    final int limit) {
  // Either the input is empty or there is a real pattern: use the rewritten regex.
  if (string.numBytes() == 0 || regex.numBytes() != 0) {
    return string.split(CollationAwareUTF8String.getLowercaseRegex(regex), limit);
  }
  // Non-empty input with an empty pattern: split on the pattern as given.
  return string.split(regex, limit);
}
}

// TODO: Add more collation-aware regexp expressions.

/**
Expand All @@ -169,6 +204,13 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern
pos, pos + pattern.numChars()), pattern, collationId).last() == 0;
}

// ui flags toggle unicode case-insensitive matching
private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)");

/**
 * Returns {@code regex} with {@code lowercaseRegexPrefix} prepended, so that the
 * resulting pattern carries the flags held by that prefix.
 *
 * @param regex the pattern to prefix
 * @return the concatenation of the prefix and {@code regex}
 */
private static UTF8String getLowercaseRegex(UTF8String regex) {
  final UTF8String prefixedRegex = UTF8String.concat(lowercaseRegexPrefix, regex);
  return prefixedRegex;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.unsafe.types;

import java.util.Arrays;

import org.apache.spark.SparkException;
import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.sql.catalyst.util.CollationSupport;
Expand Down Expand Up @@ -255,6 +257,82 @@ public void testEndsWith() throws SparkException {
* Collation-aware regexp expressions.
*/

/**
 * Checks the result of splitting via the assertStringSplit helpers for the
 * UTF8_BINARY, UNICODE and UTF8_BINARY_LCASE collations: case-sensitive cases,
 * case-insensitive cases, non-ASCII input, and explicit limit values.
 */
@Test
public void testStringSplit() throws SparkException {
// binary equality (UTF8_BINARY, UNICODE): a regex of different case does not match
assertStringSplit("ABC", "[B]", "UTF8_BINARY", new String[]{"A", "C"});
assertStringSplit("ABC", "[b]", "UTF8_BINARY", new String[]{"ABC"});
assertStringSplit("aaaa", "", "UTF8_BINARY", new String[]{"a", "a", "a", "a"});
assertStringSplit("aaaa", "[a-z]", "UTF8_BINARY", new String[]{"", "", "", "", ""});
assertStringSplit("aaaa", "[0-9]", "UTF8_BINARY", new String[]{"aaaa"});
assertStringSplit("a1b2", "[a-z0-9]", "UTF8_BINARY", new String[]{"", "", "", "", ""});
assertStringSplit("ABC", "[B]", "UNICODE", new String[]{"A", "C"});
assertStringSplit("ABC", "[b]", "UNICODE", new String[]{"ABC"});
assertStringSplit("aaaa", "", "UNICODE", new String[]{"a", "a", "a", "a"});
assertStringSplit("aaaa", "[a-z]", "UNICODE", new String[]{"", "", "", "", ""});
assertStringSplit("aaaa", "[0-9]", "UNICODE", new String[]{"aaaa"});
assertStringSplit("a1b2", "[a-z0-9]", "UNICODE", new String[]{"", "", "", "", ""});
// non-binary equality (lowercase): a regex of different case still matches
assertStringSplit("ABC", "[B]", "UTF8_BINARY_LCASE", new String[]{"A", "C"});
assertStringSplit("ABC", "[b]", "UTF8_BINARY_LCASE", new String[]{"A", "C"});
assertStringSplit("aaaa", "", "UTF8_BINARY_LCASE", new String[]{"a", "a", "a", "a"});
assertStringSplit("aaaa", "[a-z]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
assertStringSplit("aaaa", "[0-9]", "UTF8_BINARY_LCASE", new String[]{"aaaa"});
assertStringSplit("a1b2", "[a-z0-9]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
assertStringSplit("AAA", "[a]", "UTF8_BINARY_LCASE", new String[]{"", "", "", ""});
assertStringSplit("AAA", "[b]", "UTF8_BINARY_LCASE", new String[]{"AAA"});
assertStringSplit("aAbB", "[ab]", "UTF8_BINARY_LCASE",new String[]{"", "", "", "", ""});
// empty input string, with and without a pattern
assertStringSplit("", "", "UTF8_BINARY_LCASE", new String[]{""});
assertStringSplit("", "[a]", "UTF8_BINARY_LCASE", new String[]{""});
assertStringSplit("xAxBxaxbx", "[AB]", "UTF8_BINARY_LCASE",
new String[]{"x", "x", "x", "x", "x"});
assertStringSplit("ABC", "", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C"});
// special (non-ASCII) characters
assertStringSplit("ä", "", "UTF8_BINARY", new String[]{"ä"});
assertStringSplit("ääää", "", "UTF8_BINARY", new String[]{"ä", "ä", "ä", "ä"});
assertStringSplit("äbćδ", "", "UTF8_BINARY", new String[]{"ä", "b", "ć", "δ"});
assertStringSplit("äbćδ", "[äbćδ]", "UTF8_BINARY", new String[]{"", "", "", "", ""});
assertStringSplit("ä", "", "UTF8_BINARY_LCASE", new String[]{"ä"});
assertStringSplit("ääää", "", "UTF8_BINARY_LCASE", new String[]{"ä", "ä", "ä", "ä"});
assertStringSplit("äbćδ", "", "UTF8_BINARY_LCASE", new String[]{"ä", "b", "ć", "δ"});
assertStringSplit("äbćδ", "[äbćδ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
assertStringSplit("äbćδ", "[ÄBĆΔ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
assertStringSplit("äbćδ", "[äBćΔ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
assertStringSplit("ääää", "Ä", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""});
assertStringSplit("AäBÄCä", "Ä", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C", ""});
assertStringSplit("AäBÄCäD", "Ä", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C", "D"});
assertStringSplit("ä", "", "UNICODE", new String[]{"ä"});
assertStringSplit("ääää", "", "UNICODE", new String[]{"ä", "ä", "ä", "ä"});
assertStringSplit("äbćδ", "", "UNICODE", new String[]{"ä", "b", "ć", "δ"});
assertStringSplit("äbćδ", "[äbćδ]", "UNICODE", new String[]{"", "", "", "", ""});
// set limit (the calls above use the overload without a limit argument)
assertStringSplit("ABC", "[B]", 0, "UTF8_BINARY", new String[]{"A", "C"});
assertStringSplit("ABC", "[B]", 1, "UTF8_BINARY", new String[]{"ABC"});
assertStringSplit("ABC", "[B]", 2, "UTF8_BINARY", new String[]{"A", "C"});
assertStringSplit("ABC", "[B]", 3, "UTF8_BINARY", new String[]{"A", "C"});
assertStringSplit("ABC", "[b]", 0, "UTF8_BINARY_LCASE", new String[]{"A", "C"});
assertStringSplit("ABC", "[b]", 1, "UTF8_BINARY_LCASE", new String[]{"ABC"});
assertStringSplit("ABC", "[b]", 2, "UTF8_BINARY_LCASE", new String[]{"A", "C"});
assertStringSplit("ABC", "[b]", 3, "UTF8_BINARY_LCASE", new String[]{"A", "C"});
assertStringSplit("ABC", "[B]", 0, "UNICODE", new String[]{"A", "C"});
assertStringSplit("ABC", "[B]", 1, "UNICODE", new String[]{"ABC"});
assertStringSplit("ABC", "[B]", 2, "UNICODE", new String[]{"A", "C"});
assertStringSplit("ABC", "[B]", 3, "UNICODE", new String[]{"A", "C"});
}

/**
 * Runs {@code CollationSupport.StringSplit.exec} on the given inputs and asserts
 * that the resulting pieces, converted to Java strings, equal {@code value}.
 *
 * @param string the string to split
 * @param regex the regex to split on
 * @param limit the limit passed to the split
 * @param collationName collation name, resolved to an id via {@code CollationFactory}
 * @param value the expected pieces, in order
 */
private void assertStringSplit(String string, String regex, int limit, String collationName,
    String[] value) throws SparkException {
  final UTF8String input = UTF8String.fromString(string);
  final UTF8String pattern = UTF8String.fromString(regex);
  final int collationId = CollationFactory.collationNameToId(collationName);
  final UTF8String[] pieces =
    CollationSupport.StringSplit.exec(input, pattern, limit, collationId);
  // Convert each piece to a Java String so the arrays can be compared element-wise.
  final String[] actual = new String[pieces.length];
  for (int i = 0; i < pieces.length; i++) {
    actual[i] = pieces[i].toString();
  }
  assertArrayEquals(value, actual);
}

/**
 * Convenience overload for cases that do not set a limit: forwards to the
 * five-argument helper with a limit of -1.
 *
 * @param string the string to split
 * @param regex the regex to split on
 * @param collationName collation name to split under
 * @param value the expected pieces, in order
 */
private void assertStringSplit(String string, String regex, String collationName,
    String[] value) throws SparkException {
  final int defaultLimit = -1;
  assertStringSplit(string, regex, defaultLimit, collationName, value);
}

// TODO: Test more collation-aware regexp expressions.

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -543,25 +544,28 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = ArrayType(StringType, containsNull = false)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def dataType: DataType = ArrayType(str.dataType, containsNull = false)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean for "regex" to be of type StringTypeAnyCollation? It doesn't seem to me that collation is respected/needed for this parameter to begin with. For example, consider: [,] with UNICODE_CI collation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, why would we allow any collation for the regex string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to support session-level default collation. If the user changes it and passes a regex string literal, it will be interpreted as a collated string. We don't want to throw an exception in such cases.

override def first: Expression = str
override def second: Expression = regex
override def third: Expression = limit

final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId

def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1))

override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
val strings = string.asInstanceOf[UTF8String].split(
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
val strings = CollationSupport.StringSplit.exec(string.asInstanceOf[UTF8String],
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int], collationId)
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
defineCodeGen(ctx, ev, (str, regex, limit) => {
// Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin
s"new $arrayClass(${CollationSupport.StringSplit.genCode(str, regex, limit, collationId)})"
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,30 +116,31 @@ class CollationRegexpExpressionsSuite

test("Support StringSplit string expression with collation") {
// Supported collations
case class StringSplitTestCase[R](l: String, r: String, c: String, result: R)
case class StringSplitTestCase[R](l: String, r: String, c: String, result: R, limit: Int = -1)
val testCases = Seq(
StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C"))
StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC")),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C")),
StringSplitTestCase("AAA", "[a]", "UTF8_BINARY_LCASE", Seq("", "", "", "")),
StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC"), 1),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("ABC"), 1),
StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 2)
)
testCases.foreach(t => {
val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}', ${t.limit})"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c))))
// TODO: Implicit casting (not currently supported)
})
// Unsupported collations
case class StringSplitTestFail(l: String, r: String, c: String)
val failCases = Seq(
StringSplitTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"),
StringSplitTestFail("ABC", "[B]", "UNICODE"),
StringSplitTestFail("ABC", "[b]", "UNICODE_CI")
)
val failCases = Seq(StringSplitTestFail("ABC", "[b]", "UNICODE_CI"))
failCases.foreach(t => {
val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')"
val unsupportedCollation = intercept[AnalysisException] { sql(query) }
assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
})
// TODO: Collation mismatch (not currently supported)
}

test("Support RegExpReplace string expression with collation") {
Expand Down