Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47566][SQL] Support SubstringIndex function to work with collated strings #45725

Closed
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
d2e75fe
Add find method with collation supported
miland-db Mar 25, 2024
2097a66
Add SubstringIndex support for collated strings
miland-db Mar 26, 2024
b3bd34a
Improve unit tests and fix bugs
miland-db Mar 26, 2024
15c5491
Fix bug with the rfind on collated strings
miland-db Mar 26, 2024
5925763
Merge branch 'master' into miland-db/substringIndex-stringLocate
miland-db Mar 26, 2024
34ee8af
Resolve merge problems with master
miland-db Mar 26, 2024
2f8f13d
improve scala style
miland-db Mar 26, 2024
5538b07
Add tests to UTF8StringWithCollationSuite
miland-db Mar 27, 2024
ee6c67d
Improve tests and fix bug with collatedFind
miland-db Apr 1, 2024
1bfb027
Fix Java linter error
miland-db Apr 1, 2024
d2616f6
Remove repeated code for getting collationId
miland-db Apr 1, 2024
69fa05c
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 2, 2024
6d380ce
Improve method naming for collation aware methods, and remove lowerca…
miland-db Apr 2, 2024
f28eac9
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 3, 2024
5c9865a
Improve scala/java style
miland-db Apr 3, 2024
4837159
Update getStringSearch naming
miland-db Apr 3, 2024
027ab04
Add doc comment
miland-db Apr 3, 2024
208042b
Remove unrelated change: blank line in UTF8StringWithCollationSuite.java
miland-db Apr 3, 2024
16f5b15
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 4, 2024
39827bf
Merge latest master and add SubstringIndex to CollationTypeCasts tran…
miland-db Apr 4, 2024
cc453b1
Add empty lines between imports
miland-db Apr 4, 2024
8831848
Handle all collationIds in getStringSearch
miland-db Apr 4, 2024
dccd63d
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 12, 2024
2765e6b
Add code for SubstringIndex to CollationSupport
miland-db Apr 12, 2024
9ff4b50
Add SubstringIndex functionality and fix errors
miland-db Apr 12, 2024
ae2a572
Fix java line length
miland-db Apr 12, 2024
e7db0e9
Remove unused import
miland-db Apr 12, 2024
71ed00c
Add SubstringIndex to CollationTypeCasts
miland-db Apr 15, 2024
eefa77b
Refactor tests
miland-db Apr 15, 2024
bf6bb2a
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 16, 2024
c9e8788
Rename methods in CollationSupport to be as in UTF8String
miland-db Apr 16, 2024
02c927a
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 17, 2024
1d0531b
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 17, 2024
f5760fc
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 18, 2024
59953e9
Added tests (1 failing)
miland-db Apr 18, 2024
5e81c3c
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 23, 2024
9dbc55d
Added new tests with variable length character.
miland-db Apr 23, 2024
6102e49
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 24, 2024
ba1be71
Add tests and sync with master
miland-db Apr 24, 2024
2e6db7a
Fix Case-variable character length bug
miland-db Apr 24, 2024
c9f98d1
Add new tests
miland-db Apr 24, 2024
ae1bcf6
Fix java linter
miland-db Apr 24, 2024
ea76298
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 25, 2024
626e36c
Merge branch 'master' into substringIndex-stringLocate
miland-db Apr 29, 2024
437de71
Add more test cases with case variable character length
miland-db Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,30 @@ public int indexOf(UTF8String v, int start) {
return -1;
}

private int charPosToByte(int charPos) {
if(charPos < 0) {
return -1;
}

int i = 0;
int c = 0;
while (i < numBytes && c < charPos) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
return i;
}

private int bytePosToChar(int bytePos) {
int i = 0;
int c = 0;
while (i < numBytes && i < bytePos) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
return c;
}

/**
* Find the `str` from left to right.
*/
Expand All @@ -849,6 +873,33 @@ private int find(UTF8String str, int start) {
return -1;
}

/**
* Find the `str` from left to right considering different collations.
*/
private int find(UTF8String str, int start, int collationId) {
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
return this.find(str, start);
}
if(collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
return charPosToByte(this.toLowerCase().indexOf(str.toLowerCase(), bytePosToChar(start)));
}
return collatedFind(str, start, collationId);
}

/**
* Find the `str` from left to right considering non-binary collations.
*/
private int collatedFind(UTF8String str, int start, int collationId) {
assert (str.numBytes > 0);

StringSearch stringSearch = CollationFactory.getStringSearch(this, str, collationId);
// Set search start position (start from character at start position)
stringSearch.setIndex(bytePosToChar(start));

// Return either the byte position or -1 if not found
return charPosToByte(stringSearch.next());
}

/**
* Find the `str` from right to left.
*/
Expand All @@ -863,6 +914,70 @@ private int rfind(UTF8String str, int start) {
return -1;
}

/**
* Find the `str` from right to left considering different collations.
*/
private int rfind(UTF8String str, int start, int collationId) {
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
return this.rfind(str, start);
}
if(collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
return rfindLowercase(str, start);
}
return collatedRFind(str, start, collationId);
}

/**
* Find the `str` from left to right considering binary lowercase collation.
*/
private int rfindLowercase(UTF8String str, int start) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
if(numBytes == 0 || str.numBytes == 0) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
return -1;
}

UTF8String lowercaseThis = this.toLowerCase();
UTF8String lowercaseStr = str.toLowerCase();

int prevStart = -1;
int matchStart = lowercaseThis.indexOf(lowercaseStr, 0);
while(charPosToByte(matchStart) <= start) {
if(matchStart != -1) {
// Found a match, update the start position
prevStart = matchStart;
matchStart = lowercaseThis.indexOf(lowercaseStr, matchStart + 1);
} else {
return charPosToByte(prevStart);
}
}

return charPosToByte(prevStart);
}

/**
* Find the `str` from left to right considering non-binary collations.
*/
private int collatedRFind(UTF8String str, int start, int collationId) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
if(numBytes == 0 || str.numBytes == 0) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
return -1;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, str, collationId);

int prevStart = -1;
int matchStart = stringSearch.next();
while(charPosToByte(matchStart) <= start) {
if(matchStart != StringSearch.DONE) {
// Found a match, update the start position
prevStart = matchStart;
matchStart = stringSearch.next();
} else {
return charPosToByte(prevStart);
}
}

return charPosToByte(prevStart);
}

/**
* Returns the substring from string str before count occurrences of the delimiter delim.
* If count is positive, everything the left of the final delimiter (counting from left) is
Expand Down Expand Up @@ -913,6 +1028,57 @@ public UTF8String subStringIndex(UTF8String delim, int count) {
}
}

public UTF8String subStringIndex(UTF8String delim, int count, int collationId) {
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
return subStringIndex(delim, count);
}
return collatedSubStringIndex(delim, count, collationId);
}

private UTF8String collatedSubStringIndex(UTF8String delim, int count, int collationId) {
if (delim.numBytes == 0 || count == 0) {
return EMPTY_UTF8;
}
if (count > 0) {
int idx = -1;
while (count > 0) {
idx = find(delim, idx + 1, collationId);
if (idx >= 0) {
count --;
} else {
// can not find enough delim
return this;
}
}
if (idx == 0) {
return EMPTY_UTF8;
}
byte[] bytes = new byte[idx];
copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx);
return fromBytes(bytes);

} else {
int idx = numBytes - delim.numBytes + 1;
count = -count;
while (count > 0) {
idx = rfind(delim, idx - 1, collationId);
if (idx >= 0) {
count --;
} else {
// can not find enough delim
return this;
}
}
if (idx + delim.numBytes == numBytes) {
return EMPTY_UTF8;
}
int size = numBytes - delim.numBytes - idx;
byte[] bytes = new byte[size];
copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size);
return fromBytes(bytes);
}
}

/**
* Returns str, right-padded with pad to a length of len
* For example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1420,21 +1420,31 @@ case class StringInstr(str: Expression, substr: Expression)
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def dataType: DataType = strExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
override def first: Expression = strExpr
override def second: Expression = delimExpr
override def third: Expression = countExpr
override def prettyName: String = "substring_index"

override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
val collationId = first.dataType.asInstanceOf[StringType].collationId

str.asInstanceOf[UTF8String].subStringIndex(
delim.asInstanceOf[UTF8String],
count.asInstanceOf[Int])
count.asInstanceOf[Int], collationId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
val collationId = first.dataType.asInstanceOf[StringType].collationId

if(CollationFactory.fetchCollation(collationId).supportsBinaryOrdering) {
defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
} else {
defineCodeGen(ctx, ev, (str, delim, count) =>
s"$str.subStringIndex($delim, $count, $collationId)")
}
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.immutable.Seq

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat}
import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat, SubstringIndex}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
Expand Down Expand Up @@ -73,6 +73,84 @@ class CollationStringExpressionsSuite extends QueryTest
})
}

test("SUBSTRING_INDEX check result on explicitly collated strings") {
def testSubstringIndex(str: String, delim: String, cnt: Integer,
collationId: Integer, expected: String): Unit = {
val string = Literal.create(str, StringType(collationId))
val delimiter = Literal.create(delim, StringType(collationId))
val count = Literal(cnt)

checkEvaluation(SubstringIndex(string, delimiter, count), expected)
}

testSubstringIndex("wwwgapachegorg", "g", -3, 0, "apachegorg")
testSubstringIndex("www||apache||org", "||", 2, 0, "www||apache")
// UTF8_BINARY_LCASE
miland-db marked this conversation as resolved.
Show resolved Hide resolved
testSubstringIndex("AaAaAaAaAa", "aa", 2, 1, "A")
testSubstringIndex("www.apache.org", ".", 3, 1, "www.apache.org")
testSubstringIndex("wwwXapachexorg", "x", 2, 1, "wwwXapache")
testSubstringIndex("wwwxapacheXorg", "X", 1, 1, "www")
testSubstringIndex("www.apache.org", ".", 0, 1, "")
testSubstringIndex("www.apache.ORG", ".", -3, 1, "www.apache.ORG")
testSubstringIndex("wwwGapacheGorg", "g", 1, 1, "www")
testSubstringIndex("wwwGapacheGorg", "g", 3, 1, "wwwGapacheGor")
testSubstringIndex("gwwwGapacheGorg", "g", 3, 1, "gwwwGapache")
testSubstringIndex("wwwGapacheGorg", "g", -3, 1, "apacheGorg")
testSubstringIndex("wwwmapacheMorg", "M", -2, 1, "apacheMorg")
testSubstringIndex("www.apache.org", ".", -1, 1, "org")
testSubstringIndex("", ".", -2, 1, "")
// scalastyle:off
testSubstringIndex("test大千世界X大千世界", "x", -1, 1, "大千世界")
testSubstringIndex("test大千世界X大千世界", "X", 1, 1, "test大千世界")
testSubstringIndex("test大千世界大千世界", "千", 2, 1, "test大千世界大")
// scalastyle:on
testSubstringIndex("www||APACHE||org", "||", 2, 1, "www||APACHE")
testSubstringIndex("www||APACHE||org", "||", -1, 1, "org")
// UNICODE
testSubstringIndex("AaAaAaAaAa", "Aa", 2, 2, "Aa")
testSubstringIndex("wwwYapacheyorg", "y", 3, 2, "wwwYapacheyorg")
testSubstringIndex("www.apache.org", ".", 2, 2, "www.apache")
miland-db marked this conversation as resolved.
Show resolved Hide resolved
testSubstringIndex("wwwYapacheYorg", "Y", 1, 2, "www")
testSubstringIndex("wwwYapacheYorg", "y", 1, 2, "wwwYapacheYorg")
testSubstringIndex("wwwGapacheGorg", "g", 1, 2, "wwwGapacheGor")
testSubstringIndex("GwwwGapacheGorG", "G", 3, 2, "GwwwGapache")
testSubstringIndex("wwwGapacheGorG", "G", -3, 2, "apacheGorG")
testSubstringIndex("www.apache.org", ".", 0, 2, "")
testSubstringIndex("www.apache.org", ".", -3, 2, "www.apache.org")
testSubstringIndex("www.apache.org", ".", -2, 2, "apache.org")
testSubstringIndex("www.apache.org", ".", -1, 2, "org")
testSubstringIndex("", ".", -2, 2, "")
// scalastyle:off
testSubstringIndex("test大千世界X大千世界", "X", -1, 2, "大千世界")
testSubstringIndex("test大千世界X大千世界", "X", 1, 2, "test大千世界")
testSubstringIndex("大x千世界大千世x界", "x", 1, 2, "大")
testSubstringIndex("大x千世界大千世x界", "x", -1, 2, "界")
testSubstringIndex("大x千世界大千世x界", "x", -2, 2, "千世界大千世x界")
testSubstringIndex("大千世界大千世界", "千", 2, 2, "大千世界大")
// scalastyle:on
testSubstringIndex("www||apache||org", "||", 2, 2, "www||apache")
// UNICODE_CI
testSubstringIndex("AaAaAaAaAa", "aa", 2, 3, "A")
testSubstringIndex("www.apache.org", ".", 3, 3, "www.apache.org")
testSubstringIndex("wwwXapachexorg", "x", 2, 3, "wwwXapache")
testSubstringIndex("wwwxapacheXorg", "X", 1, 3, "www")
testSubstringIndex("www.apache.org", ".", 0, 3, "")
testSubstringIndex("wwwGapacheGorg", "g", 1, 3, "www")
testSubstringIndex("wwwGapacheGorg", "g", 3, 3, "wwwGapacheGor")
testSubstringIndex("gwwwGapacheGorg", "g", 3, 3, "gwwwGapache")
testSubstringIndex("wwwGapacheGorg", "g", -3, 3, "apacheGorg")
testSubstringIndex("www.apache.ORG", ".", -3, 3, "www.apache.ORG")
testSubstringIndex("wwwmapacheMorg", "M", -2, 3, "apacheMorg")
testSubstringIndex("www.apache.org", ".", -1, 3, "org")
testSubstringIndex("", ".", -2, 3, "")
// scalastyle:off
testSubstringIndex("test大千世界X大千世界", "X", -1, 3, "大千世界")
testSubstringIndex("test大千世界X大千世界", "X", 1, 3, "test大千世界")
testSubstringIndex("test大千世界大千世界", "千", 2, 3, "test大千世界大")
// scalastyle:on
testSubstringIndex("www||APACHE||org", "||", 2, 3, "www||APACHE")
}

test("REPEAT check output type on explicitly collated string") {
def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = {
val s = Literal.create(input, StringType(collationId))
Expand Down