Skip to content

Commit

Permalink
[SPARK-48440][SQL] Fix StringTranslate behaviour for non-UTF8_BINARY …
Browse files Browse the repository at this point in the history
…collations

### What changes were proposed in this pull request?
String searching in UTF8_LCASE now works on character-level, rather than on byte-level. For example: `translate("İ", "i")` now returns `"İ"`, because there exists no **single character** in `"İ"` such that lowercased version of that character equals to `"i"`. Note, however, that there _is_ a byte subsequence of `"İ"` such that lowercased version of that UTF-8 byte sequence equals to `"i"` (so the new behaviour is different than the old behaviour).

Also, translation for ICU collations works by repeatedly translating the longest possible substring that matches a key in the dictionary (under the specified collation), starting from the left side of the input string, until the entire string is translated.

### Why are the changes needed?
Fix functions that give unusable results due to one-to-many case mapping when performing string search under UTF8_BINARY_LCASE (see example above).

### Does this PR introduce _any_ user-facing change?
Yes, behaviour of `translate` expression is changed for edge cases with one-to-many case mapping.

### How was this patch tested?
New unit tests in `CollationStringExpressionsSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46761 from uros-db/alter-translate.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed Jul 12, 2024
1 parent 0fa5787 commit 8e4bbdf
Show file tree
Hide file tree
Showing 5 changed files with 402 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.ibm.icu.lang.UCharacter;
import com.ibm.icu.text.BreakIterator;
import com.ibm.icu.text.Collator;
import com.ibm.icu.text.RuleBasedCollator;
import com.ibm.icu.text.StringSearch;
import com.ibm.icu.util.ULocale;

Expand All @@ -26,8 +28,12 @@

import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
import static org.apache.spark.unsafe.Platform.copyMemory;
import static org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType;

import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/**
Expand Down Expand Up @@ -424,27 +430,58 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col
* @param codePoint The code point to convert to lowercase.
* @param sb The StringBuilder to append the lowercase character to.
*/
private static void lowercaseCodePoint(final int codePoint, final StringBuilder sb) {
if (codePoint == 0x0130) {
private static void appendLowercaseCodePoint(final int codePoint, final StringBuilder sb) {
int lowercaseCodePoint = getLowercaseCodePoint(codePoint);
if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) {
// Latin capital letter I with dot above is mapped to 2 lowercase characters.
sb.appendCodePoint(0x0069);
sb.appendCodePoint(0x0307);
} else {
// All other characters should follow context-unaware ICU single-code point case mapping.
sb.appendCodePoint(lowercaseCodePoint);
}
}

/**
* `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal representation of the combined lowercase
* code point for ASCII lowercase letter i with an additional combining dot character (U+0307).
* This integer value is not a valid code point itself, but rather an artificial code point
* marker used to represent the two lowercase characters that are the result of converting the
* uppercase Turkish dotted letter I with a combining dot character (U+0130) to lowercase.
*/
private static final int CODE_POINT_LOWERCASE_I = 0x69;
private static final int CODE_POINT_COMBINING_DOT = 0x307;
private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT =
CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT;

/**
* Returns the lowercase version of the provided code point, with special handling for
* one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and
* context-insensitive case mappings (i.e. characters that map to different characters based on
* the position in the string relative to other characters in lowercase).
*/
private static int getLowercaseCodePoint(final int codePoint) {
if (codePoint == 0x0130) {
// Latin capital letter I with dot above is mapped to 2 lowercase characters.
return CODE_POINT_COMBINED_LOWERCASE_I_DOT;
}
else if (codePoint == 0x03C2) {
// Greek final and non-final capital letter sigma should be mapped the same.
sb.appendCodePoint(0x03C3);
// Greek final and non-final letter sigma should be mapped the same. This is achieved by
// mapping Greek small final sigma (U+03C2) to Greek small non-final sigma (U+03C3). Capital
// letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) in the `else` branch.
return 0x03C3;
}
else {
// All other characters should follow context-unaware ICU single-code point case mapping.
sb.appendCodePoint(UCharacter.toLowerCase(codePoint));
return UCharacter.toLowerCase(codePoint);
}
}

/**
* Converts an entire string to lowercase using ICU rules, code point by code point, with
* special handling for one-to-many case mappings (i.e. characters that map to multiple
* characters in lowercase). Also, this method omits information about context-sensitive case
* mappings using special handling in the `lowercaseCodePoint` method.
* mappings using special handling in the `appendLowercaseCodePoint` method.
*
* @param target The target string to convert to lowercase.
* @return The string converted to lowercase in a context-unaware manner.
Expand All @@ -455,10 +492,11 @@ public static UTF8String lowerCaseCodePoints(final UTF8String target) {
}

private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
String targetString = target.toValidString();
Iterator<Integer> targetIter = target.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < targetString.length(); ++i) {
lowercaseCodePoint(targetString.codePointAt(i), sb);
while (targetIter.hasNext()) {
appendLowercaseCodePoint(targetIter.next(), sb);
}
return UTF8String.fromString(sb.toString());
}
Expand Down Expand Up @@ -655,38 +693,152 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string,
}
}

public static Map<String, String> getCollationAwareDict(UTF8String string,
Map<String, String> dict, int collationId) {
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
String srcStr = string.toString();
/**
* Converts the original translation dictionary (`dict`) to a dictionary with lowercased keys.
* This method is used to create a dictionary that can be used for the UTF8_LCASE collation.
* Note that `StringTranslate.buildDict` will ensure that all strings are validated properly.
*
* The method returns a map with lowercased code points as keys, while the values remain
* unchanged. Note that `dict` is constructed on a character by character basis, and the
* original keys are stored as strings. Keys in the resulting lowercase dictionary are stored
* as integers, which correspond only to single characters from the original `dict`. Also,
* there is special handling for the Turkish dotted uppercase letter I (U+0130).
*/
private static Map<Integer, String> getLowercaseDict(final Map<String, String> dict) {
// Replace all the keys in the dict with lowercased code points.
Map<Integer, String> lowercaseDict = new HashMap<>();
for (Map.Entry<String, String> entry : dict.entrySet()) {
int codePoint = entry.getKey().codePointAt(0);
lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint), entry.getValue());
}
return lowercaseDict;
}

/**
* Translates the `input` string using the translation map `dict`, for UTF8_LCASE collation.
* String translation is performed by iterating over the input string, from left to right, and
* repeatedly translating the longest possible substring that matches a key in the dictionary.
* For UTF8_LCASE, the method uses the lowercased substring to perform the lookup in the
* lowercased version of the translation map.
*
* @param input the string to be translated
* @param dict the lowercase translation dictionary
* @return the translated string
*/
public static UTF8String lowercaseTranslate(final UTF8String input,
final Map<String, String> dict) {
// Iterator for the input string.
Iterator<Integer> inputIter = input.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
// Lowercased translation dictionary.
Map<Integer, String> lowercaseDict = getLowercaseDict(dict);
// StringBuilder to store the translated string.
StringBuilder sb = new StringBuilder();

Map<String, String> collationAwareDict = new HashMap<>();
for (String key : dict.keySet()) {
StringSearch stringSearch =
CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);
// We use buffered code point iteration to handle one-to-many case mappings. We need to handle
// at most two code points at a time (for `CODE_POINT_COMBINED_LOWERCASE_I_DOT`), a buffer of
// size 1 enables us to match two codepoints in the input string with a single codepoint in
// the lowercase translation dictionary.
int codePointBuffer = -1, codePoint;
while (inputIter.hasNext()) {
if (codePointBuffer != -1) {
codePoint = codePointBuffer;
codePointBuffer = -1;
} else {
codePoint = inputIter.next();
}
// Special handling for letter i (U+0069) followed by a combining dot (U+0307). By ensuring
// that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a max-length match.
if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) &&
codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) {
int nextCodePoint = inputIter.next();
if (nextCodePoint == CODE_POINT_COMBINING_DOT) {
codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT;
} else {
codePointBuffer = nextCodePoint;
}
}
// Translate the code point using the lowercased dictionary.
String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint));
if (translated == null) {
// Append the original code point if no translation is found.
sb.appendCodePoint(codePoint);
} else if (!"\0".equals(translated)) {
// Append the translated code point if the translation is not the null character.
sb.append(translated);
}
// Skip the code point if it maps to the null character.
}
// Append the last code point if it was buffered.
if (codePointBuffer != -1) sb.appendCodePoint(codePointBuffer);

int pos = 0;
while ((pos = stringSearch.next()) != StringSearch.DONE) {
int codePoint = srcStr.codePointAt(pos);
int charCount = Character.charCount(codePoint);
String newKey = srcStr.substring(pos, pos + charCount);
// Return the translated string.
return UTF8String.fromString(sb.toString());
}

boolean exists = false;
for (String existingKey : collationAwareDict.keySet()) {
if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
exists = true;
break;
/**
* Translates the `input` string using the translation map `dict`, for all ICU collations.
* String translation is performed by iterating over the input string, from left to right, and
* repeatedly translating the longest possible substring that matches a key in the dictionary.
* For ICU collations, the method uses the ICU `StringSearch` class to perform the lookup in
* the translation map, while respecting the rules of the specified ICU collation.
*
* @param input the string to be translated
* @param dict the collation aware translation dictionary
* @param collationId the collation ID to use for string translation
* @return the translated string
*/
public static UTF8String translate(final UTF8String input,
final Map<String, String> dict, final int collationId) {
// Replace invalid UTF-8 sequences with the Unicode replacement character U+FFFD.
String inputString = input.toValidString();
// Create a character iterator for the validated input string. This will be used for searching
// inside the string using ICU `StringSearch` class. We only need to do it once before the
// main loop of the translate algorithm.
CharacterIterator target = new StringCharacterIterator(inputString);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
StringBuilder sb = new StringBuilder();
// Index for the current character in the (validated) input string. This is the character we
// want to determine if we need to replace or not.
int charIndex = 0;
while (charIndex < inputString.length()) {
// We search the replacement dictionary to find a match. If there are more than one matches
// (which is possible for collated strings), we want to choose the match of largest length.
int longestMatchLen = 0;
String longestMatch = "";
for (String key : dict.keySet()) {
StringSearch stringSearch = new StringSearch(key, target, (RuleBasedCollator) collator);
// Point `stringSearch` to start at the current character.
stringSearch.setIndex(charIndex);
int matchIndex = stringSearch.next();
if (matchIndex == charIndex) {
// We have found a match (that is the current position matches with one of the characters
// in the dictionary). However, there might be other matches of larger length, so we need
// to continue searching against the characters in the dictionary and keep track of the
// match of largest length.
int matchLen = stringSearch.getMatchLength();
if (matchLen > longestMatchLen) {
longestMatchLen = matchLen;
longestMatch = key;
}
}

if (!exists) {
collationAwareDict.put(newKey, dict.get(key));
}
if (longestMatchLen == 0) {
// No match was found, so output the current character.
sb.append(inputString.charAt(charIndex));
// Move on to the next character in the input string.
++charIndex;
} else {
// We have found at least one match. Append the match of longest match length to the output.
if (!"\0".equals(dict.get(longestMatch))) {
sb.append(dict.get(longestMatch));
}
// Skip as many characters as the longest match.
charIndex += longestMatchLen;
}
}

return collationAwareDict;
// Return the translated string.
return UTF8String.fromString(sb.toString());
}

public static UTF8String lowercaseTrim(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean
return useICU ? execBinaryICU(v) : execBinary(v);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(v);
} else {
} else {
return execICU(v, collationId);
}
}
Expand All @@ -224,7 +224,7 @@ public static String genCode(final String v, final int collationId, boolean useI
return String.format(expr + "%s(%s)", funcName, v);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s)", v);
} else {
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
Expand Down Expand Up @@ -261,7 +261,7 @@ public static String genCode(final String v, final int collationId, boolean useI
return String.format(expr + "%s(%s)", funcName, v);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s)", v);
} else {
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
Expand Down Expand Up @@ -522,26 +522,11 @@ public static UTF8String execBinary(final UTF8String source, Map<String, String>
return source.translate(dict);
}
public static UTF8String execLowercase(final UTF8String source, Map<String, String> dict) {
String srcStr = source.toString();
StringBuilder sb = new StringBuilder();
int charCount = 0;
for (int k = 0; k < srcStr.length(); k += charCount) {
int codePoint = srcStr.codePointAt(k);
charCount = Character.charCount(codePoint);
String subStr = srcStr.substring(k, k + charCount);
String translated = dict.get(subStr.toLowerCase());
if (null == translated) {
sb.append(subStr);
} else if (!"\0".equals(translated)) {
sb.append(translated);
}
}
return UTF8String.fromString(sb.toString());
return CollationAwareUTF8String.lowercaseTranslate(source, dict);
}
public static UTF8String execICU(final UTF8String source, Map<String, String> dict,
final int collationId) {
return source.translate(CollationAwareUTF8String.getCollationAwareDict(
source, dict, collationId));
return CollationAwareUTF8String.translate(source, dict, collationId);
}
}

Expand Down
Loading

0 comments on commit 8e4bbdf

Please sign in to comment.