From 8e4bbdff80a1c069ccce71060751987e9e6c0b6b Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 12 Jul 2024 22:30:18 +0800 Subject: [PATCH] [SPARK-48440][SQL] Fix StringTranslate behaviour for non-UTF8_BINARY collations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? String searching in UTF8_LCASE now works at the character level, rather than at the byte level. For example: `translate("İ", "i")` now returns `"İ"`, because there exists no **single character** in `"İ"` such that the lowercased version of that character equals `"i"`. Note, however, that there _is_ a byte subsequence of `"İ"` such that the lowercased version of that UTF-8 byte sequence equals `"i"` (so the new behaviour is different from the old behaviour). Also, translation for ICU collations works by repeatedly translating the longest possible substring that matches a key in the dictionary (under the specified collation), starting from the left side of the input string, until the entire string is translated. ### Why are the changes needed? Fix functions that give unusable results due to one-to-many case mapping when performing string search under UTF8_LCASE (see example above). ### Does this PR introduce _any_ user-facing change? Yes, the behaviour of the `translate` expression is changed for edge cases with one-to-many case mapping. ### How was this patch tested? New unit tests in `CollationSupportSuite`, and updated tests in `CollationStringExpressionsSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46761 from uros-db/alter-translate. 
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../util/CollationAwareUTF8String.java | 218 +++++++++++++++--- .../sql/catalyst/util/CollationSupport.java | 25 +- .../unsafe/types/CollationSupportSuite.java | 192 ++++++++++++++- .../expressions/stringExpressions.scala | 30 ++- .../sql/CollationStringExpressionsSuite.scala | 51 +--- 5 files changed, 402 insertions(+), 114 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 23adc772b7f34..af152c87f88ce 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -18,6 +18,8 @@ import com.ibm.icu.lang.UCharacter; import com.ibm.icu.text.BreakIterator; +import com.ibm.icu.text.Collator; +import com.ibm.icu.text.RuleBasedCollator; import com.ibm.icu.text.StringSearch; import com.ibm.icu.util.ULocale; @@ -26,8 +28,12 @@ import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; import static org.apache.spark.unsafe.Platform.copyMemory; +import static org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType; +import java.text.CharacterIterator; +import java.text.StringCharacterIterator; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; /** @@ -424,19 +430,50 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col * @param codePoint The code point to convert to lowercase. * @param sb The StringBuilder to append the lowercase character to. 
*/ - private static void lowercaseCodePoint(final int codePoint, final StringBuilder sb) { - if (codePoint == 0x0130) { + private static void appendLowercaseCodePoint(final int codePoint, final StringBuilder sb) { + int lowercaseCodePoint = getLowercaseCodePoint(codePoint); + if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. sb.appendCodePoint(0x0069); sb.appendCodePoint(0x0307); + } else { + // All other characters should follow context-unaware ICU single-code point case mapping. + sb.appendCodePoint(lowercaseCodePoint); + } + } + + /** + * `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal representation of the combined lowercase + * code point for ASCII lowercase letter i with an additional combining dot character (U+0307). + * This integer value is not a valid code point itself, but rather an artificial code point + * marker used to represent the two lowercase characters that are the result of converting the + * uppercase Turkish dotted letter I, i.e. Latin capital letter I with dot above (U+0130), to lowercase. + */ + private static final int CODE_POINT_LOWERCASE_I = 0x69; + private static final int CODE_POINT_COMBINING_DOT = 0x307; + private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT = + CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT; + + /** + * Returns the lowercase version of the provided code point, with special handling for + * one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and + * context-sensitive case mappings (i.e. characters that map to different characters based on + * the position in the string relative to other characters in lowercase). + */ + private static int getLowercaseCodePoint(final int codePoint) { + if (codePoint == 0x0130) { + // Latin capital letter I with dot above is mapped to 2 lowercase characters. 
+ return CODE_POINT_COMBINED_LOWERCASE_I_DOT; } else if (codePoint == 0x03C2) { - // Greek final and non-final capital letter sigma should be mapped the same. - sb.appendCodePoint(0x03C3); + // Greek final and non-final letter sigma should be mapped the same. This is achieved by + // mapping Greek small final sigma (U+03C2) to Greek small non-final sigma (U+03C3). Capital + // letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) in the `else` branch. + return 0x03C3; } else { // All other characters should follow context-unaware ICU single-code point case mapping. - sb.appendCodePoint(UCharacter.toLowerCase(codePoint)); + return UCharacter.toLowerCase(codePoint); } } @@ -444,7 +481,7 @@ else if (codePoint == 0x03C2) { * Converts an entire string to lowercase using ICU rules, code point by code point, with * special handling for one-to-many case mappings (i.e. characters that map to multiple * characters in lowercase). Also, this method omits information about context-sensitive case - * mappings using special handling in the `lowercaseCodePoint` method. + * mappings using special handling in the `appendLowercaseCodePoint` method. * * @param target The target string to convert to lowercase. * @return The string converted to lowercase in a context-unaware manner. 
@@ -455,10 +492,11 @@ public static UTF8String lowerCaseCodePoints(final UTF8String target) { } private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) { - String targetString = target.toValidString(); + Iterator targetIter = target.codePointIterator( + CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID); StringBuilder sb = new StringBuilder(); - for (int i = 0; i < targetString.length(); ++i) { - lowercaseCodePoint(targetString.codePointAt(i), sb); + while (targetIter.hasNext()) { + appendLowercaseCodePoint(targetIter.next(), sb); } return UTF8String.fromString(sb.toString()); } @@ -655,38 +693,152 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string, } } - public static Map getCollationAwareDict(UTF8String string, - Map dict, int collationId) { - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - String srcStr = string.toString(); + /** + * Converts the original translation dictionary (`dict`) to a dictionary with lowercased keys. + * This method is used to create a dictionary that can be used for the UTF8_LCASE collation. + * Note that `StringTranslate.buildDict` will ensure that all strings are validated properly. + * + * The method returns a map with lowercased code points as keys, while the values remain + * unchanged. Note that `dict` is constructed on a character by character basis, and the + * original keys are stored as strings. Keys in the resulting lowercase dictionary are stored + * as integers, which correspond only to single characters from the original `dict`. Also, + * there is special handling for the Turkish dotted uppercase letter I (U+0130). + */ + private static Map getLowercaseDict(final Map dict) { + // Replace all the keys in the dict with lowercased code points. 
+ Map lowercaseDict = new HashMap<>(); + for (Map.Entry entry : dict.entrySet()) { + int codePoint = entry.getKey().codePointAt(0); + lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint), entry.getValue()); + } + return lowercaseDict; + } + + /** + * Translates the `input` string using the translation map `dict`, for UTF8_LCASE collation. + * String translation is performed by iterating over the input string, from left to right, and + * repeatedly translating the longest possible substring that matches a key in the dictionary. + * For UTF8_LCASE, the method uses the lowercased substring to perform the lookup in the + * lowercased version of the translation map. + * + * @param input the string to be translated + * @param dict the lowercase translation dictionary + * @return the translated string + */ + public static UTF8String lowercaseTranslate(final UTF8String input, + final Map dict) { + // Iterator for the input string. + Iterator inputIter = input.codePointIterator( + CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID); + // Lowercased translation dictionary. + Map lowercaseDict = getLowercaseDict(dict); + // StringBuilder to store the translated string. + StringBuilder sb = new StringBuilder(); - Map collationAwareDict = new HashMap<>(); - for (String key : dict.keySet()) { - StringSearch stringSearch = - CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId); + // We use buffered code point iteration to handle one-to-many case mappings. We need to handle + // at most two code points at a time (for `CODE_POINT_COMBINED_LOWERCASE_I_DOT`), a buffer of + // size 1 enables us to match two codepoints in the input string with a single codepoint in + // the lowercase translation dictionary. 
+ int codePointBuffer = -1, codePoint; + while (inputIter.hasNext()) { + if (codePointBuffer != -1) { + codePoint = codePointBuffer; + codePointBuffer = -1; + } else { + codePoint = inputIter.next(); + } + // Special handling for letter i (U+0069) followed by a combining dot (U+0307). By ensuring + // that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a max-length match. + if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) && + codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) { + int nextCodePoint = inputIter.next(); + if (nextCodePoint == CODE_POINT_COMBINING_DOT) { + codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT; + } else { + codePointBuffer = nextCodePoint; + } + } + // Translate the code point using the lowercased dictionary. + String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint)); + if (translated == null) { + // Append the original code point if no translation is found. + sb.appendCodePoint(codePoint); + } else if (!"\0".equals(translated)) { + // Append the translated code point if the translation is not the null character. + sb.append(translated); + } + // Skip the code point if it maps to the null character. + } + // Append the last code point if it was buffered. + if (codePointBuffer != -1) sb.appendCodePoint(codePointBuffer); - int pos = 0; - while ((pos = stringSearch.next()) != StringSearch.DONE) { - int codePoint = srcStr.codePointAt(pos); - int charCount = Character.charCount(codePoint); - String newKey = srcStr.substring(pos, pos + charCount); + // Return the translated string. + return UTF8String.fromString(sb.toString()); + } - boolean exists = false; - for (String existingKey : collationAwareDict.keySet()) { - if (stringSearch.getCollator().compare(existingKey, newKey) == 0) { - collationAwareDict.put(newKey, collationAwareDict.get(existingKey)); - exists = true; - break; + /** + * Translates the `input` string using the translation map `dict`, for all ICU collations. 
+ * String translation is performed by iterating over the input string, from left to right, and + * repeatedly translating the longest possible substring that matches a key in the dictionary. + * For ICU collations, the method uses the ICU `StringSearch` class to perform the lookup in + * the translation map, while respecting the rules of the specified ICU collation. + * + * @param input the string to be translated + * @param dict the collation aware translation dictionary + * @param collationId the collation ID to use for string translation + * @return the translated string + */ + public static UTF8String translate(final UTF8String input, + final Map dict, final int collationId) { + // Replace invalid UTF-8 sequences with the Unicode replacement character U+FFFD. + String inputString = input.toValidString(); + // Create a character iterator for the validated input string. This will be used for searching + // inside the string using ICU `StringSearch` class. We only need to do it once before the + // main loop of the translate algorithm. + CharacterIterator target = new StringCharacterIterator(inputString); + Collator collator = CollationFactory.fetchCollation(collationId).collator; + StringBuilder sb = new StringBuilder(); + // Index for the current character in the (validated) input string. This is the character we + // want to determine if we need to replace or not. + int charIndex = 0; + while (charIndex < inputString.length()) { + // We search the replacement dictionary to find a match. If there are more than one matches + // (which is possible for collated strings), we want to choose the match of largest length. + int longestMatchLen = 0; + String longestMatch = ""; + for (String key : dict.keySet()) { + StringSearch stringSearch = new StringSearch(key, target, (RuleBasedCollator) collator); + // Point `stringSearch` to start at the current character. 
+ stringSearch.setIndex(charIndex); + int matchIndex = stringSearch.next(); + if (matchIndex == charIndex) { + // We have found a match (that is the current position matches with one of the characters + // in the dictionary). However, there might be other matches of larger length, so we need + // to continue searching against the characters in the dictionary and keep track of the + // match of largest length. + int matchLen = stringSearch.getMatchLength(); + if (matchLen > longestMatchLen) { + longestMatchLen = matchLen; + longestMatch = key; } } - - if (!exists) { - collationAwareDict.put(newKey, dict.get(key)); + } + if (longestMatchLen == 0) { + // No match was found, so output the current character. + sb.append(inputString.charAt(charIndex)); + // Move on to the next character in the input string. + ++charIndex; + } else { + // We have found at least one match. Append the match of longest match length to the output. + if (!"\0".equals(dict.get(longestMatch))) { + sb.append(dict.get(longestMatch)); } + // Skip as many characters as the longest match. + charIndex += longestMatchLen; } } - - return collationAwareDict; + // Return the translated string. + return UTF8String.fromString(sb.toString()); } public static UTF8String lowercaseTrim( diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 450a3eea1a3a0..f9ccd22f3f5c6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -212,7 +212,7 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean return useICU ? 
execBinaryICU(v) : execBinary(v); } else if (collation.supportsLowercaseEquality) { return execLowercase(v); - } else { + } else { return execICU(v, collationId); } } @@ -224,7 +224,7 @@ public static String genCode(final String v, final int collationId, boolean useI return String.format(expr + "%s(%s)", funcName, v); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s)", v); - } else { + } else { return String.format(expr + "ICU(%s, %d)", v, collationId); } } @@ -261,7 +261,7 @@ public static String genCode(final String v, final int collationId, boolean useI return String.format(expr + "%s(%s)", funcName, v); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s)", v); - } else { + } else { return String.format(expr + "ICU(%s, %d)", v, collationId); } } @@ -522,26 +522,11 @@ public static UTF8String execBinary(final UTF8String source, Map return source.translate(dict); } public static UTF8String execLowercase(final UTF8String source, Map dict) { - String srcStr = source.toString(); - StringBuilder sb = new StringBuilder(); - int charCount = 0; - for (int k = 0; k < srcStr.length(); k += charCount) { - int codePoint = srcStr.codePointAt(k); - charCount = Character.charCount(codePoint); - String subStr = srcStr.substring(k, k + charCount); - String translated = dict.get(subStr.toLowerCase()); - if (null == translated) { - sb.append(subStr); - } else if (!"\0".equals(translated)) { - sb.append(translated); - } - } - return UTF8String.fromString(sb.toString()); + return CollationAwareUTF8String.lowercaseTranslate(source, dict); } public static UTF8String execICU(final UTF8String source, Map dict, final int collationId) { - return source.translate(CollationAwareUTF8String.getCollationAwareDict( - source, dict, collationId)); + return CollationAwareUTF8String.translate(source, dict, collationId); } } diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 9438484344d62..ce0cef3fef307 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -22,6 +22,9 @@ import org.apache.spark.sql.catalyst.util.CollationSupport; import org.junit.jupiter.api.Test; +import java.util.HashMap; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.*; // checkstyle.off: AvoidEscapedUnicodeCharacters @@ -1378,19 +1381,186 @@ public void testStringTrim() throws SparkException { assertStringTrimRight("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); } - // TODO: Test more collation-aware string expressions. - - /** - * Collation-aware regexp expressions. - */ - - // TODO: Test more collation-aware regexp expressions. + private void assertStringTranslate( + String inputString, + String matchingString, + String replaceString, + String collationName, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collationName); + Map dict = buildDict(matchingString, replaceString); + UTF8String source = UTF8String.fromString(inputString); + UTF8String result = CollationSupport.StringTranslate.exec(source, dict, collationId); + assertEquals(expectedResultString, result.toString()); + } - /** - * Other collation-aware expressions. - */ + @Test + public void testStringTranslate() throws SparkException { + // Basic tests - UTF8_BINARY. 
+ assertStringTranslate("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UTF8_BINARY", "Tra2slate"); + assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_BINARY", "Tra2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_BINARY", "TRaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY", "TxaxsXaxeX"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY", "TXaxsXaxex"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY", "test大千世AX大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY", "大千世界test大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", "Oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", "大千世界大千世界oesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f"); + // Basic tests - UTF8_LCASE. 
+ assertStringTranslate("Translate", "Rnlt", "12", "UTF8_LCASE", "1a2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UTF8_LCASE", "T1a2slate"); + assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", "xXaxsXaxeX"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", "test大千世AB大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", "大千世界abca大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f"); + // Basic tests - UNICODE. 
+ assertStringTranslate("Translate", "Rnlt", "12", "UNICODE", "Tra2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UNICODE", "Tra2slate"); + assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE", "大千世界test大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f"); + // Basic tests - UNICODE_CI. 
+ assertStringTranslate("Translate", "Rnlt", "12", "UNICODE_CI", "1a2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate"); + assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UNICODE_CI", "123f"); + + // One-to-many case mapping - UTF8_BINARY. 
+ assertStringTranslate("İ", "i\u0307", "xy", "UTF8_BINARY", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307"); + assertStringTranslate("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"); + assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", "12bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", "a2bcå"); + assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", "3\u030Aβφδ1\u0307"); + // One-to-many case mapping - UTF8_LCASE. + assertStringTranslate("İ", "i\u0307", "xy", "UTF8_LCASE", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UTF8_LCASE", "x"); + assertStringTranslate("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11"); + assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_LCASE", "11"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", "12bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", "12bc3"); + assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", "3\u030Aβφδ2"); + // One-to-many case mapping - UNICODE. 
+ assertStringTranslate("İ", "i\u0307", "xy", "UNICODE", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UNICODE", "i\u0307"); + assertStringTranslate("i\u030A", "İ", "x", "UNICODE", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UNICODE", "i\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE", "1i\u0307"); + assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE", "İi\u0307"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", "a\u030Abcå"); + assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", "4βφδ2"); + // One-to-many case mapping - UNICODE_CI. + assertStringTranslate("İ", "i\u0307", "xy", "UNICODE_CI", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UNICODE_CI", "x"); + assertStringTranslate("i\u030A", "İ", "x", "UNICODE_CI", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UNICODE_CI", "i\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11"); + assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE_CI", "11"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İi\u0307"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI", "3bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", "3bc3"); + assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", "4βφδ2"); + + // Greek sigmas - UTF8_BINARY. 
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_BINARY", "σΥσΤΗΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_BINARY", "ςΥςΤΗΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_BINARY", "σιστιματικος"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_BINARY", "σιστιματικος"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_BINARY", "σιστιματικοσ"); + // Greek sigmas - UTF8_LCASE. + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + // Greek sigmas - UNICODE. 
+ assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE", "σΥσΤΗΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE", "ςΥςΤΗΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE", "σιστιματικος"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE", "σιστιματικος"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE", "σιστιματικοσ"); + // Greek sigmas - UNICODE_CI. + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + } - // TODO: Test other collation-aware expressions. 
+ private Map buildDict(String matching, String replace) { + Map dict = new HashMap<>(); + int i = 0, j = 0; + while (i < matching.length()) { + String rep = "\u0000"; + if (j < replace.length()) { + int repCharCount = Character.charCount(replace.codePointAt(j)); + rep = replace.substring(j, j + repCharCount); + j += repCharCount; + } + int matchCharCount = Character.charCount(matching.codePointAt(i)); + String matchStr = matching.substring(i, i + matchCharCount); + dict.putIfAbsent(matchStr, rep); + i += matchCharCount; + } + return dict; + } } // checkstyle.on: AvoidEscapedUnicodeCharacters diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b188b9c2630fa..1302ca80e51a3 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1050,15 +1050,35 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: object StringTranslate { - def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Int) + /** + * Build a translation dictionary from UTF8Strings. First, this method converts the input strings + * to valid Java Strings. However, we avoid any behavior changes for the UTF8_BINARY collation, + * but ensure that all other collations use `UTF8String.toValidString` to achieve this step. 
/**
 * Build a translation dictionary from UTF8Strings.
 *
 * The inputs are first converted to Java Strings. For the UTF8_BINARY collation the plain
 * `toString` conversion is used; every other collation uses `UTF8String.toValidString`.
 * The converted strings are then passed to the String-based `buildDict` overload.
 *
 * @param matchingString characters to be translated
 * @param replaceString characters to translate to, aligned by position with `matchingString`
 * @param collationId id of the collation in effect; compared against
 *                    `CollationFactory.UTF8_BINARY_COLLATION_ID`
 * @return the dictionary produced by the String-based overload
 */
def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Integer)
  : JMap[String, String] = {
  // True only for UTF8_BINARY, i.e. the one case that is NOT collation-aware.
  val isBinaryCollation = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
  val matching: String = if (isBinaryCollation) {
    matchingString.toString
  } else {
    matchingString.toValidString
  }
  val replace: String = if (isBinaryCollation) {
    replaceString.toString
  } else {
    replaceString.toValidString
  }
  buildDict(matching, replace)
}
+ */ + private def buildDict(matching: String, replace: String): JMap[String, String] = { val dict = new HashMap[String, String]() var i = 0 var j = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 78aee5b80e549..5f722b2f01fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -252,55 +252,16 @@ class CollationStringExpressionsSuite } assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - test("TRANSLATE check result on explicitly collated string") { + + test("Support StringTranslate string expression with collation") { // Supported collations case class TranslateTestCase[R](input: String, matchExpression: String, - replaceExpression: String, collation: String, result: R) + replaceExpression: String, collation: String, result: R) val testCases = Seq( + TranslateTestCase("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"), TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", "xXaxsXaxeX"), - // scalastyle:off - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", "test大千世AB大千世A"), - TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", "大千世界abca大千世界"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"), - TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"), - // scalastyle:on - TranslateTestCase("Translate", 
"Rnlt", "1234", "UNICODE", "Tra2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"), - // scalastyle:off - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"), - // scalastyle:on - TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"), - // scalastyle:off - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"), - TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"), - TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"), - // scalastyle:on - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"), - TranslateTestCase("abcdef", 
"abcde", "123", "UTF8_BINARY", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UTF8_LCASE", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f") + TranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", "Traslate"), + TranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate") ) testCases.foreach(t => {