Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48441][SQL] Fix StringTrim behaviour for non-UTF8_BINARY collations #46762

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import static org.apache.spark.unsafe.Platform.copyMemory;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;

/**
Expand Down Expand Up @@ -657,57 +659,64 @@ public static Map<String, String> getCollationAwareDict(UTF8String string,
public static UTF8String lowercaseTrim(
final UTF8String srcString,
final UTF8String trimString) {
return lowercaseTrimRight(lowercaseTrimLeft(srcString, trimString), trimString);
}

public static UTF8String trim(
final UTF8String srcString,
final UTF8String trimString,
final int collationId) {
return trimRight(trimLeft(srcString, trimString, collationId), trimString, collationId);
}

public static UTF8String lowercaseTrimLeft(
uros-db marked this conversation as resolved.
Show resolved Hide resolved
final UTF8String srcString,
final UTF8String trimString) {
// Matching UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}

UTF8String leftTrimmed = lowercaseTrimLeft(srcString, trimString);
return lowercaseTrimRight(leftTrimmed, trimString);
HashSet<Integer> trimChars = new HashSet<>();
Iterator<Integer> trimIter = trimString.codePointIterator();
while (trimIter.hasNext()) trimChars.add(UCharacter.toLowerCase(trimIter.next()));
uros-db marked this conversation as resolved.
Show resolved Hide resolved

int searchIndex = 0;
Iterator<Integer> srcIter = srcString.codePointIterator();
while (srcIter.hasNext()) {
if (!trimChars.contains(UCharacter.toLowerCase(srcIter.next()))) break;
++searchIndex;
}

return srcString.substring(searchIndex, srcString.numChars());
uros-db marked this conversation as resolved.
Show resolved Hide resolved
}

public static UTF8String lowercaseTrimLeft(
public static UTF8String trimLeft(
uros-db marked this conversation as resolved.
Show resolved Hide resolved
final UTF8String srcString,
final UTF8String trimString) {
final UTF8String trimString,
final int collationId) {
// Matching UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}

// The searching byte position in the srcString.
int searchIdx = 0;
// The byte position of a first non-matching character in the srcString.
int trimByteIdx = 0;
// Number of bytes in srcString.
int numBytes = srcString.numBytes();
// Convert trimString to lowercase, so it can be searched properly.
UTF8String lowercaseTrimString = trimString.toLowerCase();

while (searchIdx < numBytes) {
UTF8String searchChar = srcString.copyUTF8String(
searchIdx,
searchIdx + UTF8String.numBytesForFirstByte(srcString.getByte(searchIdx)) - 1);
int searchCharBytes = searchChar.numBytes();

// Try to find the matching for the searchChar in the trimString.
if (lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) {
trimByteIdx += searchCharBytes;
searchIdx += searchCharBytes;
} else {
// No matching, exit the search.
break;
}
// Create a set of collation keys for all characters of the trim string, for fast lookup.
String trim = trimString.toString();
HashSet<String> trimChars = new HashSet<>();
for (int i = 0; i < trim.length(); i++) {
trimChars.add(CollationFactory.getCollationKey(String.valueOf(trim.charAt(i)), collationId));
}

if (searchIdx == 0) {
// Nothing trimmed - return original string (not converted to lowercase).
return srcString;
}
if (trimByteIdx >= numBytes) {
// Everything trimmed.
return UTF8String.EMPTY_UTF8;
// Iterate over srcString from the left and find the first character that is not in trimChars.
String input = srcString.toString();
int i = 0;
while (i < input.length()) {
String key = CollationFactory.getCollationKey(String.valueOf(input.charAt(i)), collationId);
if (!trimChars.contains(key)) break;
++i;
uros-db marked this conversation as resolved.
Show resolved Hide resolved
}
return srcString.copyUTF8String(trimByteIdx, numBytes - 1);
// Return the substring from that position to the end of the string.
return UTF8String.fromString(input.substring(i, srcString.numChars()));
}

public static UTF8String lowercaseTrimRight(
Expand All @@ -718,53 +727,48 @@ public static UTF8String lowercaseTrimRight(
return null;
}

// Number of bytes iterated from the srcString.
int byteIdx = 0;
// Number of characters iterated from the srcString.
int numChars = 0;
// Number of bytes in srcString.
int numBytes = srcString.numBytes();
// Array of character length for the srcString.
int[] stringCharLen = new int[numBytes];
// Array of the first byte position for each character in the srcString.
int[] stringCharPos = new int[numBytes];
// Convert trimString to lowercase, so it can be searched properly.
UTF8String lowercaseTrimString = trimString.toLowerCase();

// Build the position and length array.
while (byteIdx < numBytes) {
stringCharPos[numChars] = byteIdx;
stringCharLen[numChars] = UTF8String.numBytesForFirstByte(srcString.getByte(byteIdx));
byteIdx += stringCharLen[numChars];
numChars++;
}

// Index trimEnd points to the first no matching byte position from the right side of
// the source string.
int trimByteIdx = numBytes - 1;

while (numChars > 0) {
UTF8String searchChar = srcString.copyUTF8String(
stringCharPos[numChars - 1],
stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1);

if(lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) {
trimByteIdx -= stringCharLen[numChars - 1];
numChars--;
} else {
HashSet<Integer> trimChars = new HashSet<>();
Iterator<Integer> trimIter = trimString.codePointIterator();
while (trimIter.hasNext()) trimChars.add(UCharacter.toLowerCase(trimIter.next()));

int searchIndex = srcString.numChars();
Iterator<Integer> srcIter = srcString.reverseCodePointIterator();
while (srcIter.hasNext()) {
if (!trimChars.contains(UCharacter.toLowerCase(srcIter.next()))) {
break;
}
--searchIndex;
}
uros-db marked this conversation as resolved.
Show resolved Hide resolved

return srcString.substring(0, searchIndex);
}

public static UTF8String trimRight(
final UTF8String srcString,
final UTF8String trimString,
final int collationId) {
// Matching UTF8String behavior for null `trimString`.
if (trimString == null) {
return null;
}

if (trimByteIdx == numBytes - 1) {
// Nothing trimmed.
return srcString;
// Create a set of collation keys for all characters of the trim string, for fast lookup.
String trim = trimString.toString();
HashSet<String> trimChars = new HashSet<>();
for (int i = 0; i < trim.length(); i++) {
trimChars.add(CollationFactory.getCollationKey(String.valueOf(trim.charAt(i)), collationId));
uros-db marked this conversation as resolved.
Show resolved Hide resolved
}
if (trimByteIdx < 0) {
// Everything trimmed.
return UTF8String.EMPTY_UTF8;

// Iterate over srcString from the right and find the first character that is not in trimChars.
String input = srcString.toString();
int i = input.length() - 1;
while (i >= 0) {
String key = CollationFactory.getCollationKey(String.valueOf(input.charAt(i)), collationId);
if (!trimChars.contains(key)) break;
--i;
}
return srcString.copyUTF8String(0, trimByteIdx);
// Return the substring from the start of the string until that position.
return UTF8String.fromString(input.substring(0, i + 1));
}

// TODO: Add more collation-aware UTF8String operations here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,18 @@ public static String[] getICULocaleNames() {
return Collation.CollationSpecICU.ICULocaleNames;
}

public static String getCollationKey(String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return input;
} else if (collation.supportsLowercaseEquality) {
return input.toLowerCase();
} else {
CollationKey collationKey = collation.collator.getCollationKey(input);
return Arrays.toString(collationKey.toByteArray());
}
}
uros-db marked this conversation as resolved.
Show resolved Hide resolved

public static UTF8String getCollationKey(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
Expand Down
Loading