Skip to content

Commit

Permalink
[SPARK-47566][SQL] Support SubstringIndex function to work with colla…
Browse files Browse the repository at this point in the history
…ted strings

### What changes were proposed in this pull request?
Extend built-in string functions to support non-binary, non-lowercase collation for: substring_index.

### Why are the changes needed?
Update collation support for built-in string functions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use COLLATE within arguments for built-in string function SUBSTRING_INDEX in Spark SQL queries, using non-binary collations such as UNICODE_CI.

### How was this patch tested?
Unit tests for queries using SubstringIndex (`CollationStringExpressionsSuite.scala`).

### Was this patch authored or co-authored using generative AI tooling?
No

### To consider:
There is no check for collation match between string and delimiter, it will be introduced with Implicit Casting.

We can remove the original `public UTF8String subStringIndex(UTF8String delim, int count)` method, and get the existing behavior using `subStringIndex(delim, count, 0)`.

Closes #45725 from miland-db/miland-db/substringIndex-stringLocate.

Authored-by: Milan Dankovic <milan.dankovic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
miland-db authored and cloud-fan committed Apr 30, 2024
1 parent 9e8c4aa commit 12a5074
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import java.util.List;
import java.util.regex.Pattern;

import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
import static org.apache.spark.unsafe.Platform.copyMemory;

/**
* Static entry point for collation-aware expressions (StringExpressions, RegexpExpressions, and
* other expressions that require custom collation support), as well as private utility methods for
Expand Down Expand Up @@ -441,6 +444,45 @@ public static int execICU(final UTF8String string, final UTF8String substring, f
}
}

public static class SubstringIndex {
public static UTF8String exec(final UTF8String string, final UTF8String delimiter,
final int count, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(string, delimiter, count);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(string, delimiter, count);
} else {
return execICU(string, delimiter, count, collationId);
}
}
public static String genCode(final String string, final String delimiter,
final int count, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.SubstringIndex.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s, %d)", string, delimiter, count);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s, %d)", string, delimiter, count);
} else {
return String.format(expr + "ICU(%s, %s, %d, %d)", string, delimiter, count, collationId);
}
}
public static UTF8String execBinary(final UTF8String string, final UTF8String delimiter,
final int count) {
return string.subStringIndex(delimiter, count);
}
public static UTF8String execLowercase(final UTF8String string, final UTF8String delimiter,
final int count) {
return CollationAwareUTF8String.lowercaseSubStringIndex(string, delimiter, count);
}
public static UTF8String execICU(final UTF8String string, final UTF8String delimiter,
final int count, final int collationId) {
return CollationAwareUTF8String.subStringIndex(string, delimiter, count,
collationId);
}
}

// TODO: Add more collation-aware string expressions.

/**
Expand Down Expand Up @@ -639,6 +681,133 @@ private static int indexOf(final UTF8String target, final UTF8String pattern,
return stringSearch.next();
}

private static int find(UTF8String target, UTF8String pattern, int start,
int collationId) {
assert (pattern.numBytes() > 0);

StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
// Set search start position (start from character at start position)
stringSearch.setIndex(target.bytePosToChar(start));

// Return either the byte position or -1 if not found
return target.charPosToByte(stringSearch.next());
}

private static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter,
int count, final int collationId) {
if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) {
return UTF8String.EMPTY_UTF8;
}
if (count > 0) {
int idx = -1;
while (count > 0) {
idx = find(string, delimiter, idx + 1, collationId);
if (idx >= 0) {
count --;
} else {
// can not find enough delim
return string;
}
}
if (idx == 0) {
return UTF8String.EMPTY_UTF8;
}
byte[] bytes = new byte[idx];
copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
return UTF8String.fromBytes(bytes);

} else {
count = -count;

StringSearch stringSearch = CollationFactory
.getStringSearch(string, delimiter, collationId);

int start = string.numChars() - 1;
int lastMatchLength = 0;
int prevStart = -1;
while (count > 0) {
stringSearch.reset();
prevStart = -1;
int matchStart = stringSearch.next();
lastMatchLength = stringSearch.getMatchLength();
while (matchStart <= start) {
if (matchStart != StringSearch.DONE) {
// Found a match, update the start position
prevStart = matchStart;
matchStart = stringSearch.next();
} else {
break;
}
}

if (prevStart == -1) {
// can not find enough delim
return string;
} else {
start = prevStart - 1;
count--;
}
}

int resultStart = prevStart + lastMatchLength;
if (resultStart == string.numChars()) {
return UTF8String.EMPTY_UTF8;
}

return string.substring(resultStart, string.numChars());
}
}

private static UTF8String lowercaseSubStringIndex(final UTF8String string,
final UTF8String delimiter, int count) {
if (delimiter.numBytes() == 0 || count == 0) {
return UTF8String.EMPTY_UTF8;
}

UTF8String lowercaseString = string.toLowerCase();
UTF8String lowercaseDelimiter = delimiter.toLowerCase();

if (count > 0) {
int idx = -1;
while (count > 0) {
idx = lowercaseString.find(lowercaseDelimiter, idx + 1);
if (idx >= 0) {
count --;
} else {
// can not find enough delim
return string;
}
}
if (idx == 0) {
return UTF8String.EMPTY_UTF8;
}
byte[] bytes = new byte[idx];
copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
return UTF8String.fromBytes(bytes);

} else {
int idx = string.numBytes() - delimiter.numBytes() + 1;
count = -count;
while (count > 0) {
idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1);
if (idx >= 0) {
count --;
} else {
// can not find enough delim
return string;
}
}
if (idx + delimiter.numBytes() == string.numBytes()) {
return UTF8String.EMPTY_UTF8;
}
int size = string.numBytes() - delimiter.numBytes() - idx;
byte[] bytes = new byte[size];
copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(),
bytes, BYTE_ARRAY_OFFSET, size);
return UTF8String.fromBytes(bytes);
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -926,10 +926,34 @@ public int indexOf(UTF8String v, int start) {
return -1;
}

public int charPosToByte(int charPos) {
if (charPos < 0) {
return -1;
}

int i = 0;
int c = 0;
while (i < numBytes && c < charPos) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
return i;
}

public int bytePosToChar(int bytePos) {
int i = 0;
int c = 0;
while (i < numBytes && i < bytePos) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
return c;
}

/**
* Find the `str` from left to right.
*/
private int find(UTF8String str, int start) {
public int find(UTF8String str, int start) {
assert (str.numBytes > 0);
while (start <= numBytes - str.numBytes) {
if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
Expand All @@ -943,7 +967,7 @@ private int find(UTF8String str, int start) {
/**
* Find the `str` from right to left.
*/
private int rfind(UTF8String str, int start) {
public int rfind(UTF8String str, int start) {
assert (str.numBytes > 0);
while (start >= 0) {
if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,89 @@ public void testLocate() throws SparkException {
assertLocate("İo", "世界i̇o大千世界大千i̇o", 4, "UNICODE_CI", 12); // 12 instead of 11
}

private void assertSubstringIndex(String string, String delimiter, Integer count,
String collationName, String expected) throws SparkException {
UTF8String str = UTF8String.fromString(string);
UTF8String delim = UTF8String.fromString(delimiter);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected,
CollationSupport.SubstringIndex.exec(str, delim, count, collationId).toString());
}

@Test
public void testSubstringIndex() throws SparkException {
assertSubstringIndex("wwwgapachegorg", "g", -3, "UTF8_BINARY", "apachegorg");
assertSubstringIndex("www||apache||org", "||", 2, "UTF8_BINARY", "www||apache");
assertSubstringIndex("aaaaaaaaaa", "aa", 2, "UTF8_BINARY", "a");
assertSubstringIndex("AaAaAaAaAa", "aa", 2, "UTF8_BINARY_LCASE", "A");
assertSubstringIndex("www.apache.org", ".", 3, "UTF8_BINARY_LCASE", "www.apache.org");
assertSubstringIndex("wwwXapacheXorg", "x", 2, "UTF8_BINARY_LCASE", "wwwXapache");
assertSubstringIndex("wwwxapachexorg", "X", 1, "UTF8_BINARY_LCASE", "www");
assertSubstringIndex("www.apache.org", ".", 0, "UTF8_BINARY_LCASE", "");
assertSubstringIndex("www.apache.ORG", ".", -3, "UTF8_BINARY_LCASE", "www.apache.ORG");
assertSubstringIndex("wwwGapacheGorg", "g", 1, "UTF8_BINARY_LCASE", "www");
assertSubstringIndex("wwwGapacheGorg", "g", 3, "UTF8_BINARY_LCASE", "wwwGapacheGor");
assertSubstringIndex("gwwwGapacheGorg", "g", 3, "UTF8_BINARY_LCASE", "gwwwGapache");
assertSubstringIndex("wwwGapacheGorg", "g", -3, "UTF8_BINARY_LCASE", "apacheGorg");
assertSubstringIndex("wwwmapacheMorg", "M", -2, "UTF8_BINARY_LCASE", "apacheMorg");
assertSubstringIndex("www.apache.org", ".", -1, "UTF8_BINARY_LCASE", "org");
assertSubstringIndex("www.apache.org.", ".", -1, "UTF8_BINARY_LCASE", "");
assertSubstringIndex("", ".", -2, "UTF8_BINARY_LCASE", "");
assertSubstringIndex("test大千世界X大千世界", "x", -1, "UTF8_BINARY_LCASE", "大千世界");
assertSubstringIndex("test大千世界X大千世界", "X", 1, "UTF8_BINARY_LCASE", "test大千世界");
assertSubstringIndex("test大千世界大千世界", "千", 2, "UTF8_BINARY_LCASE", "test大千世界大");
assertSubstringIndex("www||APACHE||org", "||", 2, "UTF8_BINARY_LCASE", "www||APACHE");
assertSubstringIndex("www||APACHE||org", "||", -1, "UTF8_BINARY_LCASE", "org");
assertSubstringIndex("AaAaAaAaAa", "Aa", 2, "UNICODE", "Aa");
assertSubstringIndex("wwwYapacheyorg", "y", 3, "UNICODE", "wwwYapacheyorg");
assertSubstringIndex("www.apache.org", ".", 2, "UNICODE", "www.apache");
assertSubstringIndex("wwwYapacheYorg", "Y", 1, "UNICODE", "www");
assertSubstringIndex("wwwYapacheYorg", "y", 1, "UNICODE", "wwwYapacheYorg");
assertSubstringIndex("wwwGapacheGorg", "g", 1, "UNICODE", "wwwGapacheGor");
assertSubstringIndex("GwwwGapacheGorG", "G", 3, "UNICODE", "GwwwGapache");
assertSubstringIndex("wwwGapacheGorG", "G", -3, "UNICODE", "apacheGorG");
assertSubstringIndex("www.apache.org", ".", 0, "UNICODE", "");
assertSubstringIndex("www.apache.org", ".", -3, "UNICODE", "www.apache.org");
assertSubstringIndex("www.apache.org", ".", -2, "UNICODE", "apache.org");
assertSubstringIndex("www.apache.org", ".", -1, "UNICODE", "org");
assertSubstringIndex("", ".", -2, "UNICODE", "");
assertSubstringIndex("test大千世界X大千世界", "X", -1, "UNICODE", "大千世界");
assertSubstringIndex("test大千世界X大千世界", "X", 1, "UNICODE", "test大千世界");
assertSubstringIndex("大x千世界大千世x界", "x", 1, "UNICODE", "大");
assertSubstringIndex("大x千世界大千世x界", "x", -1, "UNICODE", "界");
assertSubstringIndex("大x千世界大千世x界", "x", -2, "UNICODE", "千世界大千世x界");
assertSubstringIndex("大千世界大千世界", "千", 2, "UNICODE", "大千世界大");
assertSubstringIndex("www||apache||org", "||", 2, "UNICODE", "www||apache");
assertSubstringIndex("AaAaAaAaAa", "aa", 2, "UNICODE_CI", "A");
assertSubstringIndex("www.apache.org", ".", 3, "UNICODE_CI", "www.apache.org");
assertSubstringIndex("wwwXapacheXorg", "x", 2, "UNICODE_CI", "wwwXapache");
assertSubstringIndex("wwwxapacheXorg", "X", 1, "UNICODE_CI", "www");
assertSubstringIndex("www.apache.org", ".", 0, "UNICODE_CI", "");
assertSubstringIndex("wwwGapacheGorg", "G", 3, "UNICODE_CI", "wwwGapacheGor");
assertSubstringIndex("gwwwGapacheGorg", "g", 3, "UNICODE_CI", "gwwwGapache");
assertSubstringIndex("gwwwGapacheGorg", "g", -3, "UNICODE_CI", "apacheGorg");
assertSubstringIndex("www.apache.ORG", ".", -3, "UNICODE_CI", "www.apache.ORG");
assertSubstringIndex("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg");
assertSubstringIndex("www.apache.org", ".", -1, "UNICODE_CI", "org");
assertSubstringIndex("", ".", -2, "UNICODE_CI", "");
assertSubstringIndex("test大千世界X大千世界", "X", -1, "UNICODE_CI", "大千世界");
assertSubstringIndex("test大千世界X大千世界", "X", 1, "UNICODE_CI", "test大千世界");
assertSubstringIndex("test大千世界大千世界", "千", 2, "UNICODE_CI", "test大千世界大");
assertSubstringIndex("www||APACHE||org", "||", 2, "UNICODE_CI", "www||APACHE");
assertSubstringIndex("abİo12", "i̇o", 1, "UNICODE_CI", "ab");
assertSubstringIndex("abİo12", "i̇o", -1, "UNICODE_CI", "12");
assertSubstringIndex("abi̇o12", "İo", 1, "UNICODE_CI", "ab");
assertSubstringIndex("abi̇o12", "İo", -1, "UNICODE_CI", "12");
assertSubstringIndex("ai̇bi̇o12", "İo", 1, "UNICODE_CI", "ai̇b");
assertSubstringIndex("ai̇bi̇o12i̇o", "İo", 2, "UNICODE_CI", "ai̇bi̇o12");
assertSubstringIndex("ai̇bi̇o12i̇o", "İo", -1, "UNICODE_CI", "");
assertSubstringIndex("ai̇bi̇o12i̇o", "İo", -2, "UNICODE_CI", "12i̇o");
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", -4, "UNICODE_CI", "İo12İoi̇o");
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UNICODE_CI", "İo12İoi̇o");
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UNICODE_CI", "i̇o12i̇oİo");
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo");
}

// TODO: Test more collation-aware string expressions.

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ object CollationTypeCasts extends TypeCoercionRule {
stringLocate.withNewChildren(collateToSingleType(
Seq(stringLocate.first, stringLocate.second)) :+ stringLocate.third)

case substringIndex: SubstringIndex =>
substringIndex.withNewChildren(
collateToSingleType(
Seq(substringIndex.first, substringIndex.second)) :+ substringIndex.third)

case eltExpr: Elt =>
eltExpr.withNewChildren(eltExpr.children.head +: collateToSingleType(eltExpr.children.tail))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1406,21 +1406,24 @@ case class StringInstr(str: Expression, substr: Expression)
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId

override def dataType: DataType = strExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
override def first: Expression = strExpr
override def second: Expression = delimExpr
override def third: Expression = countExpr
override def prettyName: String = "substring_index"

override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
str.asInstanceOf[UTF8String].subStringIndex(
delim.asInstanceOf[UTF8String],
count.asInstanceOf[Int])
CollationSupport.SubstringIndex.exec(str.asInstanceOf[UTF8String],
delim.asInstanceOf[UTF8String], count.asInstanceOf[Int], collationId);
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
defineCodeGen(ctx, ev, (str, delim, count) =>
CollationSupport.SubstringIndex.genCode(str, delim, Integer.parseInt(count, 10), collationId))
}

override protected def withNewChildrenInternal(
Expand Down

0 comments on commit 12a5074

Please sign in to comment.