Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47566][SQL] Support SubstringIndex function to work with collated strings #45725

Closed
Closed
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
d2e75fe
Add find method with collation supported
miland-db Mar 25, 2024
2097a66
Add SubstringIndex support for collated strings
miland-db Mar 26, 2024
b3bd34a
Improve unit tests and fix bugs
miland-db Mar 26, 2024
15c5491
Fix bug with the rfind on collated strings
miland-db Mar 26, 2024
5925763
Merge branch 'master' into miland-db/substringIndex-stringLocate
miland-db Mar 26, 2024
34ee8af
Resolve merge problems with master
miland-db Mar 26, 2024
2f8f13d
improve scala style
miland-db Mar 26, 2024
5538b07
Add tests to UTF8StringWithCollationSuite
miland-db Mar 27, 2024
ee6c67d
Improve tests and fix bug with collatedFind
miland-db Apr 1, 2024
1bfb027
Fix Java linter error
miland-db Apr 1, 2024
d2616f6
Remove repeated code for getting collationId
miland-db Apr 1, 2024
69fa05c
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 2, 2024
6d380ce
Improve method naming for collation aware methods, and remove lowerca…
miland-db Apr 2, 2024
f28eac9
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 3, 2024
5c9865a
Improve scala/java style
miland-db Apr 3, 2024
4837159
Update getStringSearch naming
miland-db Apr 3, 2024
027ab04
Add doc comment
miland-db Apr 3, 2024
208042b
Remove unrelated change: blank line in UTF8StringWithCollationSuite.java
miland-db Apr 3, 2024
16f5b15
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 4, 2024
39827bf
Merge latest master and add SubstringIndex to CollationTypeCasts tran…
miland-db Apr 4, 2024
cc453b1
Add empty lines between imports
miland-db Apr 4, 2024
8831848
Handle all collationIds in getStringSearch
miland-db Apr 4, 2024
dccd63d
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 12, 2024
2765e6b
Add code for SubstringIndex to CollationSupport
miland-db Apr 12, 2024
9ff4b50
Add SubstringIndex functionality and fix errors
miland-db Apr 12, 2024
ae2a572
Fix java line length
miland-db Apr 12, 2024
e7db0e9
Remove unused import
miland-db Apr 12, 2024
71ed00c
Add SubstringIndex to CollationTypeCasts
miland-db Apr 15, 2024
eefa77b
Refactor tests
miland-db Apr 15, 2024
bf6bb2a
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 16, 2024
c9e8788
Rename methods in CollationSupport to be as in UTF8String
miland-db Apr 16, 2024
02c927a
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 17, 2024
1d0531b
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 17, 2024
f5760fc
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 18, 2024
59953e9
Added tests (1 failing)
miland-db Apr 18, 2024
5e81c3c
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 23, 2024
9dbc55d
Added new tests with variable length character.
miland-db Apr 23, 2024
6102e49
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 24, 2024
ba1be71
Add tests and sync with master
miland-db Apr 24, 2024
2e6db7a
Fix Case-variable character length bug
miland-db Apr 24, 2024
c9f98d1
Add new tests
miland-db Apr 24, 2024
ae1bcf6
Fix java linter
miland-db Apr 24, 2024
ea76298
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 25, 2024
626e36c
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 29, 2024
437de71
Add more test cases with case variable character length
miland-db Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import org.apache.spark.unsafe.types.UTF8String;

import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
import static org.apache.spark.unsafe.Platform.copyMemory;

/**
* Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and
* other expressions that require custom collation support), as well as private utility methods for
Expand Down Expand Up @@ -137,6 +140,45 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
}
}

public static class SubstringIndex {
/**
 * Evaluates substring_index for the given collation by dispatching to the matching
 * implementation: binary, lowercase-based, or the generic collation-aware (ICU) variant.
 *
 * @param string the string to cut
 * @param delimiter the delimiter to search for
 * @param count the count forwarded unchanged to the selected implementation
 * @param collationId id used to look up the collation in CollationFactory
 * @return the result produced by the selected implementation
 */
public static UTF8String exec(final UTF8String string, final UTF8String delimiter,
    final int count, final int collationId) {
  final CollationFactory.Collation resolved = CollationFactory.fetchCollation(collationId);
  // Binary equality is checked first, then lowercase equality; everything else goes to ICU.
  if (resolved.supportsBinaryEquality) {
    return execBinary(string, delimiter, count);
  }
  if (resolved.supportsLowercaseEquality) {
    return execLowercase(string, delimiter, count);
  }
  return execICU(string, delimiter, count, collationId);
}
/**
 * Builds the Java call expression that evaluates substring_index for the given collation,
 * for a count that is already known as an int at code-generation time.
 *
 * @param string code term for the input string
 * @param delimiter code term for the delimiter
 * @param count the count value, emitted as an integer literal
 * @param collationId id used to look up the collation in CollationFactory
 * @return source text of a call to one of execBinary / execLowercase / execICU
 */
public static String genCode(final String string, final String delimiter,
    final int count, final int collationId) {
  // Integer.toString keeps the literal plain ASCII; "%d" is subject to the default locale.
  return genCode(string, delimiter, Integer.toString(count), collationId);
}

/**
 * Builds the Java call expression that evaluates substring_index for the given collation,
 * taking the count as a code term. Unlike the int overload this also works when the count is
 * only available as a generated variable name rather than a parseable number.
 *
 * @param string code term for the input string
 * @param delimiter code term for the delimiter
 * @param count code term (variable name or literal) for the count
 * @param collationId id used to look up the collation in CollationFactory
 * @return source text of a call to one of execBinary / execLowercase / execICU
 */
public static String genCode(final String string, final String delimiter,
    final String count, final int collationId) {
  CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
  String expr = "CollationSupport.SubstringIndex.exec";
  if (collation.supportsBinaryEquality) {
    return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count);
  } else if (collation.supportsLowercaseEquality) {
    return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count);
  } else {
    // Only the ICU variant needs the collation id at run time.
    return String.format(expr + "ICU(%s, %s, %s, %s)", string, delimiter, count,
      Integer.toString(collationId));
  }
}
/**
 * substring_index for collations that support binary equality: delegates directly to
 * {@code UTF8String.subStringIndex}.
 */
public static UTF8String execBinary(final UTF8String string, final UTF8String delimiter,
    final int count) {
  final UTF8String result = string.subStringIndex(delimiter, count);
  return result;
}
public static UTF8String execLowercase(final UTF8String string, final UTF8String delimiter,
final int count) {
return CollationAwareUTF8String.lowercaseSubStringIndex(string, delimiter, count);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class CollationAwareUTF8String is getting bigger. Shall we move it to an individual file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe in the next PR. We will consider this option.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, we should do this in #45820.

}
/**
 * substring_index for collations that support neither binary nor lowercase equality:
 * delegates to {@code CollationAwareUTF8String.collationAwareSubStringIndex}.
 */
public static UTF8String execICU(final UTF8String string, final UTF8String delimiter,
    final int count, final int collationId) {
  final UTF8String result = CollationAwareUTF8String.collationAwareSubStringIndex(
    string, delimiter, count, collationId);
  return result;
}
}

// TODO: Add more collation-aware string expressions.

/**
Expand Down Expand Up @@ -169,6 +211,139 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern
pos, pos + pattern.numChars()), pattern, collationId).last() == 0;
}

/**
 * Searches {@code target} for {@code pattern} from left to right under the given collation.
 *
 * @param target the string to search in
 * @param pattern the string to search for; must be non-empty (asserted)
 * @param start byte offset in {@code target} at which the search begins
 * @param collationId collation id passed to CollationFactory.getStringSearch
 * @return byte offset of the match start; a negative search result is mapped to -1 by
 *         {@code charPosToByte}
 */
private static int collationAwareFind(UTF8String target, UTF8String pattern, int start,
    int collationId) {
  assert (pattern.numBytes() > 0);

  StringSearch searcher = CollationFactory.getStringSearch(target, pattern, collationId);
  // The searcher is positioned by character, so translate the byte offset first.
  // NOTE(review): presumably the searcher indexes UTF-16 code units while bytePosToChar counts
  // UTF-8 encoded characters; these differ for supplementary characters - confirm.
  final int startChar = target.bytePosToChar(start);
  searcher.setIndex(startChar);

  // Translate the character position of the next match back into a byte offset.
  final int matchChar = searcher.next();
  return target.charPosToByte(matchChar);
}

/**
 * Finds the last match of {@code pattern} in {@code target} whose start lies at or before byte
 * offset {@code start}, under the given collation. Matches are enumerated from the beginning
 * of {@code target} and the last qualifying one is kept.
 *
 * @param target the string to search in
 * @param pattern the string to search for; must be non-empty (asserted)
 * @param start largest byte offset at which a match may begin
 * @param collationId collation id passed to CollationFactory.getStringSearch
 * @return byte offset of the last qualifying match start, or -1 if there is none
 */
private static int collationAwareRFind(UTF8String target, UTF8String pattern, int start,
    int collationId) {
  assert (pattern.numBytes() > 0);

  if (target.numBytes() == 0) {
    return -1;
  }

  StringSearch searcher = CollationFactory.getStringSearch(target, pattern, collationId);

  // Character position of the most recent match that starts at or before `start`.
  int lastQualifying = -1;
  for (int match = searcher.next();
      match != StringSearch.DONE && target.charPosToByte(match) <= start;
      match = searcher.next()) {
    lastQualifying = match;
  }

  // charPosToByte maps the initial -1 to -1, which covers the "no match" case.
  return target.charPosToByte(lastQualifying);
}

/**
 * Collation-aware substring_index. For a positive {@code count} returns the bytes of
 * {@code string} before the count-th match of {@code delimiter} searching from the left; for a
 * negative {@code count} returns the bytes after the |count|-th match searching from the right.
 *
 * @param string the string to cut
 * @param delimiter the delimiter to search for
 * @param count number of delimiter matches to skip; the sign selects the search direction
 * @param collationId collation id forwarded to the collation-aware find helpers
 * @return the empty string when {@code delimiter} is empty or {@code count} is 0; the original
 *         {@code string} when fewer than |count| matches exist; otherwise a newly allocated
 *         substring
 */
private static UTF8String collationAwareSubStringIndex(final UTF8String string,
    final UTF8String delimiter, int count, final int collationId) {
  if (delimiter.numBytes() == 0 || count == 0) {
    return UTF8String.EMPTY_UTF8;
  }
  // Widen before taking the magnitude: negating Integer.MIN_VALUE as an int stays negative,
  // which would skip the search loop and lead to a negative array size below.
  long remaining = Math.abs((long) count);
  if (count > 0) {
    int idx = -1;
    while (remaining > 0) {
      idx = collationAwareFind(string, delimiter, idx + 1, collationId);
      if (idx >= 0) {
        remaining--;
      } else {
        // can not find enough delim
        return string;
      }
    }
    if (idx == 0) {
      return UTF8String.EMPTY_UTF8;
    }
    // Copy everything that precedes the last match found.
    byte[] bytes = new byte[idx];
    copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
    return UTF8String.fromBytes(bytes);

  } else {
    // Right-most byte offset at which a delimiter-sized match could start, plus one.
    int idx = string.numBytes() - delimiter.numBytes() + 1;
    while (remaining > 0) {
      idx = collationAwareRFind(string, delimiter, idx - 1, collationId);
      if (idx >= 0) {
        remaining--;
      } else {
        // can not find enough delim
        return string;
      }
    }
    // NOTE(review): the end of the match is taken to be idx + delimiter.numBytes(); this
    // assumes the matched text has the same byte length as the delimiter - confirm for
    // collations where equal strings can differ in encoded length.
    if (idx + delimiter.numBytes() == string.numBytes()) {
      return UTF8String.EMPTY_UTF8;
    }
    int size = string.numBytes() - delimiter.numBytes() - idx;
    byte[] bytes = new byte[size];
    copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(),
      bytes, BYTE_ARRAY_OFFSET, size);
    return UTF8String.fromBytes(bytes);
  }
}

/**
 * substring_index that matches the delimiter case-insensitively by searching lowercased copies
 * of both inputs, while cutting the result out of the original {@code string}.
 *
 * @param string the string to cut
 * @param delimiter the delimiter to search for
 * @param count number of delimiter matches to skip; the sign selects the search direction
 * @return the empty string when {@code delimiter} is empty or {@code count} is 0; the original
 *         {@code string} when fewer than |count| matches exist; otherwise a newly allocated
 *         substring of the original (non-lowercased) bytes
 */
private static UTF8String lowercaseSubStringIndex(final UTF8String string,
    final UTF8String delimiter, int count) {
  if (delimiter.numBytes() == 0 || count == 0) {
    return UTF8String.EMPTY_UTF8;
  }

  // NOTE(review): byte offsets are located in the lowercased copies but applied to the
  // original bytes; this assumes toLowerCase() preserves byte lengths - confirm.
  UTF8String lowercaseString = string.toLowerCase();
  UTF8String lowercaseDelimiter = delimiter.toLowerCase();

  // Widen before taking the magnitude: negating Integer.MIN_VALUE as an int stays negative,
  // which would skip the search loop and lead to a negative array size below.
  long remaining = Math.abs((long) count);

  if (count > 0) {
    int idx = -1;
    while (remaining > 0) {
      idx = lowercaseString.find(lowercaseDelimiter, idx + 1);
      if (idx >= 0) {
        remaining--;
      } else {
        // can not find enough delim
        return string;
      }
    }
    if (idx == 0) {
      return UTF8String.EMPTY_UTF8;
    }
    // Copy everything that precedes the last match found.
    byte[] bytes = new byte[idx];
    copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
    return UTF8String.fromBytes(bytes);

  } else {
    // Right-most byte offset at which a delimiter-sized match could start, plus one.
    int idx = string.numBytes() - delimiter.numBytes() + 1;
    while (remaining > 0) {
      idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1);
      if (idx >= 0) {
        remaining--;
      } else {
        // can not find enough delim
        return string;
      }
    }
    if (idx + delimiter.numBytes() == string.numBytes()) {
      return UTF8String.EMPTY_UTF8;
    }
    int size = string.numBytes() - delimiter.numBytes() - idx;
    byte[] bytes = new byte[size];
    copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(),
      bytes, BYTE_ARRAY_OFFSET, size);
    return UTF8String.fromBytes(bytes);
  }
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,34 @@ public int indexOf(UTF8String v, int start) {
return -1;
}

/**
 * Converts a character position into a byte offset by walking the string from its start and
 * skipping {@code charPos} encoded characters.
 *
 * @param charPos zero-based character position
 * @return -1 if {@code charPos} is negative; otherwise the byte offset reached after skipping
 *         {@code charPos} characters, with the walk stopping once the end of the string is
 *         reached
 */
public int charPosToByte(int charPos) {
  if (charPos < 0) {
    return -1;
  }

  int byteOffset = 0;
  // Each step skips one character, whose width is derived from its first byte.
  for (int skipped = 0; skipped < charPos && byteOffset < numBytes; skipped++) {
    byteOffset += numBytesForFirstByte(getByte(byteOffset));
  }
  return byteOffset;
}

/**
 * Converts a byte offset into a character position by counting the encoded characters that
 * start before {@code bytePos}.
 *
 * @param bytePos byte offset into this string
 * @return the number of characters starting before {@code bytePos} (0 when {@code bytePos} is
 *         not positive), capped by the number of characters in the string
 */
public int bytePosToChar(int bytePos) {
  int charCount = 0;
  // Advance one character at a time, using the first byte to determine its width.
  for (int i = 0; i < numBytes && i < bytePos; i += numBytesForFirstByte(getByte(i))) {
    charCount++;
  }
  return charCount;
}

/**
* Find the `str` from left to right.
*/
private int find(UTF8String str, int start) {
public int find(UTF8String str, int start) {
assert (str.numBytes > 0);
while (start <= numBytes - str.numBytes) {
if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
Expand All @@ -818,7 +842,7 @@ private int find(UTF8String str, int start) {
/**
* Find the `str` from right to left.
*/
private int rfind(UTF8String str, int start) {
public int rfind(UTF8String str, int start) {
assert (str.numBytes > 0);
while (start >= 0) {
if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import javax.annotation.Nullable
import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, SubstringIndex}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
Expand All @@ -47,7 +47,7 @@ object CollationTypeCasts extends TypeCoercionRule {

case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
_: Coalesce | _: BinaryExpression | _: ConcatWs) =>
_: Coalesce | _: BinaryExpression | _: ConcatWs | _: SubstringIndex) =>
miland-db marked this conversation as resolved.
Show resolved Hide resolved
val newChildren = collateToSingleType(otherExpr.children)
otherExpr.withNewChildren(newChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1392,21 +1392,24 @@ case class StringInstr(str: Expression, substr: Expression)
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId

override def dataType: DataType = strExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
override def first: Expression = strExpr
override def second: Expression = delimExpr
override def third: Expression = countExpr
override def prettyName: String = "substring_index"

override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
str.asInstanceOf[UTF8String].subStringIndex(
delim.asInstanceOf[UTF8String],
count.asInstanceOf[Int])
CollationSupport.SubstringIndex.exec(str.asInstanceOf[UTF8String],
delim.asInstanceOf[UTF8String], count.asInstanceOf[Int], collationId);
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
defineCodeGen(ctx, ev, (str, delim, count) =>
CollationSupport.SubstringIndex.genCode(str, delim, Integer.parseInt(count, 10), collationId))
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

package org.apache.spark.sql

import scala.collection.immutable.Seq

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal, SubstringIndex}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{BooleanType, StringType}
Expand Down Expand Up @@ -96,6 +95,89 @@ class CollationStringExpressionsSuite
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

// Evaluates SubstringIndex directly on literals whose StringType carries an explicit collation
// id, once per collation, and compares the result with the expected string.
test("SUBSTRING_INDEX check result on explicitly collated strings") {
// Builds SubstringIndex(str, delim, cnt) with both string literals typed by `collationId` and
// checks that it evaluates to `expected`.
def testSubstringIndex(str: String, delim: String, cnt: Integer,
collationId: Integer, expected: String): Unit = {
val string = Literal.create(str, StringType(collationId))
val delimiter = Literal.create(delim, StringType(collationId))
val count = Literal(cnt)

checkEvaluation(SubstringIndex(string, delimiter, count), expected)
}

// `collationId` is reassigned before each group, so the groups below are order-dependent.
// UTF8_BINARY: exact matching, including overlapping delimiter occurrences ("aa" in "aaaa...").
var collationId = CollationFactory.collationNameToId("UTF8_BINARY")
testSubstringIndex("wwwgapachegorg", "g", -3, collationId, "apachegorg")
testSubstringIndex("www||apache||org", "||", 2, collationId, "www||apache")
testSubstringIndex("aaaaaaaaaa", "aa", 2, collationId, "a")

// UTF8_BINARY_LCASE: the delimiter matches regardless of case (e.g. "x" matches "X"), while the
// expected result keeps the original casing of the input.
collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE")
testSubstringIndex("AaAaAaAaAa", "aa", 2, collationId, "A")
testSubstringIndex("www.apache.org", ".", 3, collationId, "www.apache.org")
testSubstringIndex("wwwXapacheXorg", "x", 2, collationId, "wwwXapache")
testSubstringIndex("wwwxapachexorg", "X", 1, collationId, "www")
testSubstringIndex("www.apache.org", ".", 0, collationId, "")
testSubstringIndex("www.apache.ORG", ".", -3, collationId, "www.apache.ORG")
testSubstringIndex("wwwGapacheGorg", "g", 1, collationId, "www")
testSubstringIndex("wwwGapacheGorg", "g", 3, collationId, "wwwGapacheGor")
testSubstringIndex("gwwwGapacheGorg", "g", 3, collationId, "gwwwGapache")
testSubstringIndex("wwwGapacheGorg", "g", -3, collationId, "apacheGorg")
testSubstringIndex("wwwmapacheMorg", "M", -2, collationId, "apacheMorg")
testSubstringIndex("www.apache.org", ".", -1, collationId, "org")
testSubstringIndex("www.apache.org.", ".", -1, collationId, "")
testSubstringIndex("", ".", -2, collationId, "")
// Multi-byte characters; scalastyle is switched off around the non-ASCII literals.
// scalastyle:off
testSubstringIndex("test大千世界X大千世界", "x", -1, collationId, "大千世界")
testSubstringIndex("test大千世界X大千世界", "X", 1, collationId, "test大千世界")
testSubstringIndex("test大千世界大千世界", "千", 2, collationId, "test大千世界大")
// scalastyle:on
testSubstringIndex("www||APACHE||org", "||", 2, collationId, "www||APACHE")
testSubstringIndex("www||APACHE||org", "||", -1, collationId, "org")

// UNICODE: the expectations below are case-sensitive (e.g. "y" does not match "Y").
collationId = CollationFactory.collationNameToId("UNICODE")
testSubstringIndex("AaAaAaAaAa", "Aa", 2, collationId, "Aa")
testSubstringIndex("wwwYapacheyorg", "y", 3, collationId, "wwwYapacheyorg")
testSubstringIndex("www.apache.org", ".", 2, collationId, "www.apache")
testSubstringIndex("wwwYapacheYorg", "Y", 1, collationId, "www")
testSubstringIndex("wwwYapacheYorg", "y", 1, collationId, "wwwYapacheYorg")
testSubstringIndex("wwwGapacheGorg", "g", 1, collationId, "wwwGapacheGor")
testSubstringIndex("GwwwGapacheGorG", "G", 3, collationId, "GwwwGapache")
testSubstringIndex("wwwGapacheGorG", "G", -3, collationId, "apacheGorG")
testSubstringIndex("www.apache.org", ".", 0, collationId, "")
testSubstringIndex("www.apache.org", ".", -3, collationId, "www.apache.org")
testSubstringIndex("www.apache.org", ".", -2, collationId, "apache.org")
testSubstringIndex("www.apache.org", ".", -1, collationId, "org")
testSubstringIndex("", ".", -2, collationId, "")
// Multi-byte characters mixed with a single-byte delimiter, searched from both directions.
// scalastyle:off
testSubstringIndex("test大千世界X大千世界", "X", -1, collationId, "大千世界")
testSubstringIndex("test大千世界X大千世界", "X", 1, collationId, "test大千世界")
testSubstringIndex("大x千世界大千世x界", "x", 1, collationId, "大")
testSubstringIndex("大x千世界大千世x界", "x", -1, collationId, "界")
testSubstringIndex("大x千世界大千世x界", "x", -2, collationId, "千世界大千世x界")
testSubstringIndex("大千世界大千世界", "千", 2, collationId, "大千世界大")
// scalastyle:on
testSubstringIndex("www||apache||org", "||", 2, collationId, "www||apache")

// UNICODE_CI: the expectations below are case-insensitive (e.g. "x" matches "X").
collationId = CollationFactory.collationNameToId("UNICODE_CI")
testSubstringIndex("AaAaAaAaAa", "aa", 2, collationId, "A")
testSubstringIndex("www.apache.org", ".", 3, collationId, "www.apache.org")
testSubstringIndex("wwwXapacheXorg", "x", 2, collationId, "wwwXapache")
testSubstringIndex("wwwxapacheXorg", "X", 1, collationId, "www")
testSubstringIndex("www.apache.org", ".", 0, collationId, "")
testSubstringIndex("wwwGapacheGorg", "G", 3, collationId, "wwwGapacheGor")
testSubstringIndex("gwwwGapacheGorg", "g", 3, collationId, "gwwwGapache")
testSubstringIndex("gwwwGapacheGorg", "g", -3, collationId, "apacheGorg")
testSubstringIndex("www.apache.ORG", ".", -3, collationId, "www.apache.ORG")
testSubstringIndex("wwwmapacheMorg", "M", -2, collationId, "apacheMorg")
testSubstringIndex("www.apache.org", ".", -1, collationId, "org")
testSubstringIndex("", ".", -2, collationId, "")
// scalastyle:off
testSubstringIndex("test大千世界X大千世界", "X", -1, collationId, "大千世界")
testSubstringIndex("test大千世界X大千世界", "X", 1, collationId, "test大千世界")
testSubstringIndex("test大千世界大千世界", "千", 2, collationId, "test大千世界大")
// scalastyle:on
testSubstringIndex("www||APACHE||org", "||", 2, collationId, "www||APACHE")
}

test("Support StartsWith string expression with collation") {
// Supported collations
case class StartsWithTestCase[R](l: String, r: String, c: String, result: R)
Expand Down