Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48441][SQL] Fix StringTrim behaviour for non-UTF8_BINARY collations #46762

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;

Expand Down Expand Up @@ -841,117 +842,268 @@ public static UTF8String translate(final UTF8String input,
return UTF8String.fromString(sb.toString());
}

/**
* Trims the `srcString` string from both ends of the string using the specified `trimString`
* characters, with respect to the UTF8_LCASE collation. String trimming is performed by
* first trimming the left side of the string, and then trimming the right side of the string.
* The method returns the trimmed string. If the `trimString` is null, the method returns null.
*
* @param srcString the input string to be trimmed from both ends of the string
* @param trimString the trim string characters to trim
* @return the trimmed string (for UTF8_LCASE collation)
*/
public static UTF8String lowercaseTrim(
final UTF8String srcString,
final UTF8String trimString) {
// Matching UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}
return lowercaseTrimRight(lowercaseTrimLeft(srcString, trimString), trimString);
}

UTF8String leftTrimmed = lowercaseTrimLeft(srcString, trimString);
return lowercaseTrimRight(leftTrimmed, trimString);
/**
 * Trims `srcString` from both ends using the characters of `trimString`, with respect to the
 * ICU collation identified by `collationId`. The left side is trimmed first, and the right side
 * of that intermediate result is trimmed afterwards. If `trimString` is null, null is returned.
 *
 * @param srcString the input string to be trimmed from both ends of the string
 * @param trimString the trim string characters to trim
 * @param collationId the collation ID to use for string trimming
 * @return the trimmed string (for ICU collations), or null if `trimString` is null
 */
public static UTF8String trim(
    final UTF8String srcString,
    final UTF8String trimString,
    final int collationId) {
  // Both one-sided trims return null for a null `trimString`; short-circuit to the same result.
  if (trimString == null) {
    return null;
  }
  final UTF8String leftTrimmed = trimLeft(srcString, trimString, collationId);
  return trimRight(leftTrimmed, trimString, collationId);
}

/**
* Trims the `srcString` string from the left side using the specified `trimString` characters,
* with respect to the UTF8_LCASE collation. For UTF8_LCASE, the method first creates a hash
* set of lowercased code points in `trimString`, and then iterates over the `srcString` from
* the left side, until reaching a character whose lowercased code point is not in the hash set.
* Finally, the method returns the substring from that position to the end of `srcString`.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the left end of the string
* @param trimString the trim string characters to trim
* @return the trimmed string (for UTF8_LCASE collation)
*/
public static UTF8String lowercaseTrimLeft(
uros-db marked this conversation as resolved.
Show resolved Hide resolved
final UTF8String srcString,
final UTF8String trimString) {
// Matching UTF8String behavior for null `trimString`.
// Matching the default UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}

// The searching byte position in the srcString.
int searchIdx = 0;
// The byte position of a first non-matching character in the srcString.
int trimByteIdx = 0;
// Number of bytes in srcString.
int numBytes = srcString.numBytes();
// Convert trimString to lowercase, so it can be searched properly.
UTF8String lowercaseTrimString = trimString.toLowerCase();

while (searchIdx < numBytes) {
UTF8String searchChar = srcString.copyUTF8String(
searchIdx,
searchIdx + UTF8String.numBytesForFirstByte(srcString.getByte(searchIdx)) - 1);
int searchCharBytes = searchChar.numBytes();

// Try to find the matching for the searchChar in the trimString.
if (lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) {
trimByteIdx += searchCharBytes;
searchIdx += searchCharBytes;
} else {
// No matching, exit the search.
// Create a hash set of lowercased code points for all characters of `trimString`.
HashSet<Integer> trimChars = new HashSet<>();
Iterator<Integer> trimIter = trimString.codePointIterator();
while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next()));

// Iterate over `srcString` from the left to find the first character that is not in the set.
int searchIndex = 0, codePoint;
Iterator<Integer> srcIter = srcString.codePointIterator();
while (srcIter.hasNext()) {
codePoint = getLowercaseCodePoint(srcIter.next());
// Special handling for Turkish dotted uppercase letter I.
if (codePoint == CODE_POINT_LOWERCASE_I && srcIter.hasNext() &&
trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) {
int nextCodePoint = getLowercaseCodePoint(srcIter.next());
if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint))
|| nextCodePoint == CODE_POINT_COMBINING_DOT) {
searchIndex += 2;
}
else {
if (trimChars.contains(codePoint)) ++searchIndex;
break;
}
} else if (trimChars.contains(codePoint)) {
++searchIndex;
}
else {
break;
}
}

if (searchIdx == 0) {
// Nothing trimmed - return original string (not converted to lowercase).
return srcString;
// Return the substring from that position to the end of the string.
return searchIndex == 0 ? srcString : srcString.substring(searchIndex, srcString.numChars());
}

/**
* Trims the `srcString` string from the left side using the specified `trimString` characters,
* with respect to ICU collations. For these collations, the method iterates over `srcString`
* from left to right, and repeatedly skips the longest possible substring that matches any
* character in `trimString`, until reaching a character that is not found in `trimString`.
* Finally, the method returns the substring from that position to the end of `srcString`.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the left end of the string
* @param trimString the trim string characters to trim
uros-db marked this conversation as resolved.
Show resolved Hide resolved
* @param collationId the collation ID to use for string trimming
* @return the trimmed string (for ICU collations)
*/
public static UTF8String trimLeft(
uros-db marked this conversation as resolved.
Show resolved Hide resolved
final UTF8String srcString,
final UTF8String trimString,
final int collationId) {
// Short-circuit for base cases.
if (trimString == null) return null;
if (srcString.numBytes() == 0) return srcString;

// Create an array of Strings for all characters of `trimString`.
Map<Integer, String> trimChars = new HashMap<>();
Iterator<Integer> trimIter = trimString.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
while (trimIter.hasNext()) {
int codePoint = trimIter.next();
trimChars.putIfAbsent(codePoint, String.valueOf((char) codePoint));
}
if (trimByteIdx >= numBytes) {
// Everything trimmed.
return UTF8String.EMPTY_UTF8;

// Iterate over srcString from the left and find the first character that is not in trimChars.
String src = srcString.toValidString();
CharacterIterator target = new StringCharacterIterator(src);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
int charIndex = 0, longestMatchLen;
while (charIndex < src.length()) {
longestMatchLen = 0;
for (String trim : trimChars.values()) {
StringSearch stringSearch = new StringSearch(trim, target, (RuleBasedCollator) collator);
stringSearch.setIndex(charIndex);
int matchIndex = stringSearch.next();
if (matchIndex == charIndex) {
int matchLen = stringSearch.getMatchLength();
if (matchLen > longestMatchLen) {
longestMatchLen = matchLen;
}
}
}
if (longestMatchLen == 0) break;
else charIndex += longestMatchLen;
}
return srcString.copyUTF8String(trimByteIdx, numBytes - 1);

// Return the substring from the calculated position until the end of the string.
return UTF8String.fromString(src.substring(charIndex));
}

/**
* Trims the `srcString` string from the right side using the specified `trimString` characters,
* with respect to the UTF8_LCASE collation. For UTF8_LCASE, the method first creates a hash
* set of lowercased code points in `trimString`, and then iterates over the `srcString` from
* the right side, until reaching a character whose lowercased code point is not in the hash set.
* Finally, the method returns the substring from the start of `srcString` until that position.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the right end of the string
* @param trimString the trim string characters to trim
* @return the trimmed string (for UTF8_LCASE collation)
*/
public static UTF8String lowercaseTrimRight(
final UTF8String srcString,
final UTF8String trimString) {
// Matching UTF8String behavior for null `trimString`.
// Matching the default UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}

// Number of bytes iterated from the srcString.
int byteIdx = 0;
// Number of characters iterated from the srcString.
int numChars = 0;
// Number of bytes in srcString.
int numBytes = srcString.numBytes();
// Array of character length for the srcString.
int[] stringCharLen = new int[numBytes];
// Array of the first byte position for each character in the srcString.
int[] stringCharPos = new int[numBytes];
// Convert trimString to lowercase, so it can be searched properly.
UTF8String lowercaseTrimString = trimString.toLowerCase();

// Build the position and length array.
while (byteIdx < numBytes) {
stringCharPos[numChars] = byteIdx;
stringCharLen[numChars] = UTF8String.numBytesForFirstByte(srcString.getByte(byteIdx));
byteIdx += stringCharLen[numChars];
numChars++;
}

// Index trimEnd points to the first no matching byte position from the right side of
// the source string.
int trimByteIdx = numBytes - 1;

while (numChars > 0) {
UTF8String searchChar = srcString.copyUTF8String(
stringCharPos[numChars - 1],
stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1);

if(lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) {
trimByteIdx -= stringCharLen[numChars - 1];
numChars--;
} else {
// Create a hash set of lowercased code points for all characters of `trimString`.
HashSet<Integer> trimChars = new HashSet<>();
Iterator<Integer> trimIter = trimString.codePointIterator();
while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next()));

// Iterate over `srcString` from the right to find the first character that is not in the set.
int searchIndex = srcString.numChars(), codePoint;
Iterator<Integer> srcIter = srcString.reverseCodePointIterator();
while (srcIter.hasNext()) {
codePoint = getLowercaseCodePoint(srcIter.next());
// Special handling for Turkish dotted uppercase letter I.
if (codePoint == CODE_POINT_COMBINING_DOT && srcIter.hasNext() &&
trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) {
int nextCodePoint = getLowercaseCodePoint(srcIter.next());
if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint))
|| nextCodePoint == CODE_POINT_LOWERCASE_I) {
searchIndex -= 2;
}
else {
if (trimChars.contains(codePoint)) --searchIndex;
break;
}
} else if (trimChars.contains(codePoint)) {
--searchIndex;
}
else {
break;
}
}

if (trimByteIdx == numBytes - 1) {
// Nothing trimmed.
return srcString;
// Return the substring from the start of the string to the calculated position.
return searchIndex == srcString.numChars() ? srcString : srcString.substring(0, searchIndex);
}

/**
* Trims the `srcString` string from the right side using the specified `trimString` characters,
* with respect to ICU collations. For these collations, the method iterates over `srcString`
* from right to left, and repeatedly skips the longest possible substring that matches any
* character in `trimString`, until reaching a character that is not found in `trimString`.
* Finally, the method returns the substring from the start of `srcString` until that position.
* If `trimString` is null, null is returned. If `trimString` is empty, `srcString` is returned.
*
* @param srcString the input string to be trimmed from the right end of the string
* @param trimString the trim string characters to trim
uros-db marked this conversation as resolved.
Show resolved Hide resolved
* @param collationId the collation ID to use for string trimming
* @return the trimmed string (for ICU collations)
*/
public static UTF8String trimRight(
final UTF8String srcString,
final UTF8String trimString,
final int collationId) {
// Short-circuit for base cases.
if (trimString == null) return null;
if (srcString.numBytes() == 0) return srcString;

// Create an array of Strings for all characters of `trimString`.
Map<Integer, String> trimChars = new HashMap<>();
Iterator<Integer> trimIter = trimString.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
while (trimIter.hasNext()) {
int codePoint = trimIter.next();
trimChars.putIfAbsent(codePoint, String.valueOf((char) codePoint));
}
if (trimByteIdx < 0) {
// Everything trimmed.
return UTF8String.EMPTY_UTF8;

// Iterate over srcString from the left and find the first character that is not in trimChars.
String src = srcString.toValidString();
CharacterIterator target = new StringCharacterIterator(src);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
int charIndex = src.length(), longestMatchLen;
while (charIndex >= 0) {
longestMatchLen = 0;
for (String trim : trimChars.values()) {
StringSearch stringSearch = new StringSearch(trim, target, (RuleBasedCollator) collator);
// Note: stringSearch.previous() is NOT consistent with stringSearch.next()!
// Example: StringSearch("İ", "i\\u0307İi\\u0307İi\\u0307İ", "UNICODE_CI")
// stringSearch.next() gives: [0, 2, 3, 5, 6, 8].
// stringSearch.previous() gives: [8, 6, 3, 0].
// Since 1 character can map to at most 3 characters in Unicode, we can begin the search
// from character position: `charIndex` - 3, and use `next()` to find the longest match.
stringSearch.setIndex(Math.max(charIndex - 3, 0));
uros-db marked this conversation as resolved.
Show resolved Hide resolved
int matchIndex = stringSearch.next();
int matchLen = stringSearch.getMatchLength();
uros-db marked this conversation as resolved.
Show resolved Hide resolved
while (matchIndex != StringSearch.DONE && matchIndex < charIndex - matchLen) {
matchIndex = stringSearch.next();
matchLen = stringSearch.getMatchLength();
}
if (matchIndex == charIndex - matchLen) {
if (matchLen > longestMatchLen) {
longestMatchLen = matchLen;
}
}
}
if (longestMatchLen == 0) break;
else charIndex -= longestMatchLen;
}
return srcString.copyUTF8String(0, trimByteIdx);

// Return the substring from the start of the string until that position.
return UTF8String.fromString(src.substring(0, charIndex));
}

// TODO: Add more collation-aware UTF8String operations here.
Expand Down
Loading