Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48282][SQL] Alter string search logic for UTF8_BINARY_LCASE collation (StringReplace, FindInSet) #46682

Closed
wants to merge 15 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -243,44 +243,28 @@ public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String
if (src.numBytes() == 0 || search.numBytes() == 0) {
return src;
}
UTF8String lowercaseString = src.toLowerCase();

UTF8String lowercaseSearch = search.toLowerCase();

int start = 0;
int end = lowercaseString.indexOf(lowercaseSearch, 0);
int end = lowercaseFind(src, lowercaseSearch, start);
if (end == -1) {
// Search string was not found, so string is unchanged.
return src;
}

// Initialize byte positions
int c = 0;
int byteStart = 0; // position in byte
int byteEnd = 0; // position in byte
while (byteEnd < src.numBytes() && c < end) {
byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
c += 1;
}

// At least one match was found. Estimate space needed for result.
// The 16x multiplier here is chosen to match commons-lang3's implementation.
int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
while (end != -1) {
buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
buf.append(src.substring(start, end));
buf.append(replace);
// Update character positions
start = end + lowercaseSearch.numChars();
end = lowercaseString.indexOf(lowercaseSearch, start);
// Update byte positions
byteStart = byteEnd + search.numBytes();
while (byteEnd < src.numBytes() && c < end) {
byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
c += 1;
}
start = end + lowercaseMatchLengthFrom(src, lowercaseSearch, end);
end = lowercaseFind(src, lowercaseSearch, start);
}
buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
src.numBytes() - byteStart);
buf.append(src.substring(start, src.numChars()));
return buf.build();
}

Expand All @@ -303,32 +287,28 @@ public static String toTitleCase(final String target, final int collationId) {
}

public static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
// If the "word" string contains a comma, FindInSet should return 0.
if (match.contains(UTF8String.fromString(","))) {
return 0;
}

String setString = set.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(),
collationId);

int wordStart = 0;
while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
|| setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';

if (isValidStart && isValidEnd) {
int pos = 0;
for (int i = 0; i < setString.length() && i < wordStart; i++) {
if (setString.charAt(i) == ',') {
pos++;
}
// Otherwise, search for commas in "set" and compare each substring with "word".
int byteIndex = 0, charIndex = 0, wordCount = 1, lastComma = -1;
while (byteIndex < set.numBytes()) {
byte nextByte = set.getByte(byteIndex);
if (nextByte == (byte) ',') {
if (set.substring(lastComma + 1, charIndex).semanticEquals(match, collationId)) {
return wordCount;
}

return pos + 1;
lastComma = charIndex;
++wordCount;
}
byteIndex += UTF8String.numBytesForFirstByte(nextByte);
++charIndex;
}

if (set.substring(lastComma + 1, set.numBytes()).semanticEquals(match, collationId)) {
return wordCount;
}
// If no match is found, return 0.
return 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,31 +297,24 @@ public static int exec(final UTF8String word, final UTF8String set, final int co
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(word, set);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(word, set);
} else {
return execICU(word, set, collationId);
return execCollationAware(word, set, collationId);
}
}
public static String genCode(final String word, final String set, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.FindInSet.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s)", word, set);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", word, set);
} else {
return String.format(expr + "ICU(%s, %s, %d)", word, set, collationId);
return String.format(expr + "execCollationAware(%s, %s, %d)", word, set, collationId);
}
}
public static int execBinary(final UTF8String word, final UTF8String set) {
return set.findInSet(word);
}
public static int execLowercase(final UTF8String word, final UTF8String set) {
return set.toLowerCase().findInSet(word.toLowerCase());
}
public static int execICU(final UTF8String word, final UTF8String set,
final int collationId) {
public static int execCollationAware(final UTF8String word, final UTF8String set,
final int collationId) {
return CollationAwareUTF8String.findInSet(word, set, collationId);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -639,47 +639,93 @@ public void testStringInstr() throws SparkException {
assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3);
}

private void assertFindInSet(String word, String set, String collationName,
Integer expected) throws SparkException {
private void assertFindInSet(String word, UTF8String set, String collationName,
Integer expected) throws SparkException {
UTF8String w = UTF8String.fromString(word);
UTF8String s = UTF8String.fromString(set);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected, CollationSupport.FindInSet.exec(w, s, collationId));
assertEquals(expected, CollationSupport.FindInSet.exec(w, set, collationId));
}

@Test
public void testFindInSet() throws SparkException {
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1);
assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("c", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4);
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3);
assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 1);
assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("XX", "xx", "UTF8_BINARY_LCASE", 1);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_BINARY_LCASE", 4);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3);
assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("xx", "xx", "UNICODE", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4);
assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5);
assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("XX", "xx", "UNICODE_CI", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4);
assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("i̇o", "ab,İo,12", "UNICODE_CI", 2);
assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2);
assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 1);
assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 5);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
uros-db marked this conversation as resolved.
Show resolved Hide resolved
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("c", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY_LCASE", 4);
assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY_LCASE", 3);
assertFindInSet("AbC", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY_LCASE", 1);
assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("XX", UTF8String.fromString("xx"), "UTF8_BINARY_LCASE", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY_LCASE", 0);
uros-db marked this conversation as resolved.
Show resolved Hide resolved
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_BINARY_LCASE", 4);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 3);
assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("xx", UTF8String.fromString("xx"), "UNICODE", 1);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 0);
assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 5);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0);
assertFindInSet("C", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 4);
assertFindInSet("DeF", UTF8String.fromString("abc,b,ab,c,dEf"), "UNICODE_CI", 5);
assertFindInSet("DEFG", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0);
assertFindInSet("XX", UTF8String.fromString("xx"), "UNICODE_CI", 1);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 4);
assertFindInSet("界x", UTF8String.fromString("test,大千,界Xx,世,界X,大,千,世界"), "UNICODE_CI", 5);
assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 5);
assertFindInSet("i̇", UTF8String.fromString("İ"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("İ"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("i̇"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("İ,"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("İ,"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("i̇,"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 0);
assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UNICODE_CI", 2);
assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UNICODE_CI", 2);
assertFindInSet("i̇", UTF8String.fromString("İ"), "UTF8_BINARY_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("İ"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇"), "UTF8_BINARY_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("i̇"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("İ,"), "UTF8_BINARY_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("İ,"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UTF8_BINARY_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("i̇,"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UTF8_BINARY_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UTF8_BINARY_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UTF8_BINARY_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UTF8_BINARY_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UTF8_BINARY_LCASE", 0);
assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UTF8_BINARY_LCASE", 2);
assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UTF8_BINARY_LCASE", 2);
// Invalid UTF8 strings
assertFindInSet("C", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UTF8_BINARY", 3);
assertFindInSet("c", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UTF8_BINARY_LCASE", 2);
assertFindInSet("C", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UNICODE", 3);
uros-db marked this conversation as resolved.
Show resolved Hide resolved
assertFindInSet("c", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UNICODE_CI", 2);
}

private void assertReplace(String source, String search, String replace, String collationName,
Expand Down Expand Up @@ -716,8 +762,23 @@ public void testReplace() throws SparkException {
assertReplace("replace", "", "123", "UNICODE_CI", "replace");
assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c");
assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad");
assertReplace("abi̇12", "i", "X", "UNICODE_CI", "abi̇12");
assertReplace("abi̇12", "\u0307", "X", "UNICODE_CI", "abi̇12");
assertReplace("abi̇12", "İ", "X", "UNICODE_CI", "abX12");
assertReplace("abİ12", "i", "X", "UNICODE_CI", "abİ12");
assertReplace("İi̇İi̇İi̇", "i̇", "x", "UNICODE_CI", "xxxxxx");
assertReplace("İi̇İi̇İi̇", "i", "x", "UNICODE_CI", "İi̇İi̇İi̇");
assertReplace("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx");
assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy");
assertReplace("abi̇12", "i", "X", "UTF8_BINARY_LCASE", "abX\u030712"); // != UNICODE_CI
assertReplace("abi̇12", "\u0307", "X", "UTF8_BINARY_LCASE", "abiX12"); // != UNICODE_CI
assertReplace("abi̇12", "İ", "X", "UTF8_BINARY_LCASE", "abX12");
assertReplace("abİ12", "i", "X", "UTF8_BINARY_LCASE", "abİ12");
assertReplace("İi̇İi̇İi̇", "i̇", "x", "UTF8_BINARY_LCASE", "xxxxxx");
assertReplace("İi̇İi̇İi̇", "i", "x", "UTF8_BINARY_LCASE",
"İx\u0307İx\u0307İx\u0307"); // != UNICODE_CI
assertReplace("abİo12i̇o", "i̇o", "xx", "UTF8_BINARY_LCASE", "abxx12xx");
assertReplace("abi̇o12i̇o", "İo", "yy", "UTF8_BINARY_LCASE", "abyy12yy");
}

private void assertLocate(String substring, String string, Integer start, String collationName,
Expand Down