diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index cf3b5c86dcf69..056b202bc3984 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -183,6 +183,54 @@ private static int lowercaseRFind( return MATCH_NOT_FOUND; } + /** + * Lowercase UTF8String comparison used for UTF8_BINARY_LCASE collation. While the default + * UTF8String comparison is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase()), this + * method uses code points to compare the strings in a case-insensitive manner using ICU rules, + * as well as handling special rules for one-to-many case mappings (see: lowerCaseCodePoints). + * + * @param left The first UTF8String to compare. + * @param right The second UTF8String to compare. + * @return An integer representing the comparison result. + */ + public static int compareLowerCase(final UTF8String left, final UTF8String right) { + // Only if both strings are ASCII, we can use faster comparison (no string allocations). + if (left.isFullAscii() && right.isFullAscii()) { + return compareLowerCaseAscii(left, right); + } + return compareLowerCaseSlow(left, right); + } + + /** + * Fast version of the `compareLowerCase` method, used when both arguments are ASCII strings. + * + * @param left The first ASCII UTF8String to compare. + * @param right The second ASCII UTF8String to compare. + * @return An integer representing the comparison result. + */ + private static int compareLowerCaseAscii(final UTF8String left, final UTF8String right) { + int leftBytes = left.numBytes(), rightBytes = right.numBytes(); + for (int curr = 0; curr < leftBytes && curr < rightBytes; curr++) { + int lowerLeftByte = Character.toLowerCase(left.getByte(curr)); + int lowerRightByte = Character.toLowerCase(right.getByte(curr)); + if (lowerLeftByte != lowerRightByte) { + return lowerLeftByte - lowerRightByte; + } + } + return leftBytes - rightBytes; + } + + /** + * Slow version of the `compareLowerCase` method, used when both arguments are non-ASCII strings. + * + * @param left The first non-ASCII UTF8String to compare. + * @param right The second non-ASCII UTF8String to compare. + * @return An integer representing the comparison result. + */ + private static int compareLowerCaseSlow(final UTF8String left, final UTF8String right) { + return lowerCaseCodePoints(left.toString()).compareTo(lowerCaseCodePoints(right.toString())); + } + public static UTF8String replace(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { // This collation aware implementation is based on existing implementation on UTF8String @@ -296,6 +344,48 @@ public static String toLowerCase(final String target, final int collationId) { return UCharacter.toLowerCase(locale, target); } + /** + * Converts a single code point to lowercase using ICU rules, with special handling for + * one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and + * context-insensitive case mappings (i.e. characters that map to different characters based on + * string context - e.g. the position in the string relative to other characters). + * + * @param codePoint The code point to convert to lowercase. + * @param sb The StringBuilder to append the lowercase character to. + */ + private static void lowercaseCodePoint(final int codePoint, final StringBuilder sb) { + if (codePoint == 0x0130) { + // Latin capital letter I with dot above is mapped to 2 lowercase characters. + sb.appendCodePoint(0x0069); + sb.appendCodePoint(0x0307); + } + else if (codePoint == 0x03C2) { + // Greek final and non-final capital letter sigma should be mapped the same. + sb.appendCodePoint(0x03C3); + } + else { + // All other characters should follow context-unaware ICU single-code point case mapping. + sb.appendCodePoint(UCharacter.toLowerCase(codePoint)); + } + } + + /** + * Converts an entire string to lowercase using ICU rules, code point by code point, with + * special handling for one-to-many case mappings (i.e. characters that map to multiple + * characters in lowercase). Also, this method omits information about context-sensitive case + * mappings using special handling in the `lowercaseCodePoint` method. + * + * @param target The target string to convert to lowercase. + * @return The string converted to lowercase in a context-unaware manner. + */ + public static String lowerCaseCodePoints(final String target) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < target.length(); ++i) { + lowercaseCodePoint(target.codePointAt(i), sb); + } + return sb.toString(); + } + public static String toTitleCase(final String target, final int collationId) { ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 272bf5ab3e9c2..3c9240678467e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -412,9 +412,9 @@ protected Collation buildCollation() { "UTF8_BINARY_LCASE", PROVIDER_SPARK, null, - UTF8String::compareLowerCase, + CollationAwareUTF8String::compareLowerCase, "1.0", - s -> (long) s.toLowerCase().hashCode(), + s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s.toString()).hashCode(), /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, /* supportsLowercaseEquality = */ true); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e28dfa910b59e..c0fa2719e4fe6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -388,34 +388,6 @@ private UTF8String toUpperCaseSlow() { return fromString(toString().toUpperCase()); } - /** - * Optimized lowercase comparison for UTF8_BINARY_LCASE collation - * a.compareLowerCase(b) is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase()) - */ - public int compareLowerCase(UTF8String other) { - int curr; - for (curr = 0; curr < numBytes && curr < other.numBytes; curr++) { - byte left, right; - if ((left = getByte(curr)) < 0 || (right = other.getByte(curr)) < 0) { - return compareLowerCaseSuffixSlow(other, curr); - } - int lowerLeft = Character.toLowerCase(left); - int lowerRight = Character.toLowerCase(right); - if (lowerLeft != lowerRight) { - return lowerLeft - lowerRight; - } - } - return numBytes - other.numBytes; - } - - private int compareLowerCaseSuffixSlow(UTF8String other, int pref) { - UTF8String suffixLeft = UTF8String.fromAddress(base, offset + pref, - numBytes - pref); - UTF8String suffixRight = UTF8String.fromAddress(other.base, other.offset + pref, - other.numBytes - pref); - return suffixLeft.toLowerCaseSlow().binaryCompare(suffixRight.toLowerCaseSlow()); - } - /** * Returns the lower case of this string */ @@ -427,7 +399,7 @@ public UTF8String toLowerCase() { return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlow(); } - private boolean isFullAscii() { + public boolean isFullAscii() { for (var i = 0; i < numBytes; i++) { if (getByte(i) < 0) { return false; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index b47f95ad7c299..fefa5b52a0c2f 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe.types; import org.apache.spark.SparkException; +import org.apache.spark.sql.catalyst.util.CollationAwareUTF8String; import org.apache.spark.sql.catalyst.util.CollationFactory; import org.apache.spark.sql.catalyst.util.CollationSupport; import org.junit.jupiter.api.Test; @@ -26,6 +27,156 @@ // checkstyle.off: AvoidEscapedUnicodeCharacters public class CollationSupportSuite { + /** + * A list containing some of the supported collations in Spark. Use this list to iterate over + * all the important collation groups (binary, lowercase, icu) for complete unit test coverage. + * Note: this list may come in handy when the Spark function result is the same regardless of + * the specified collations (as often seen in some pass-through Spark expressions). + */ + private final String[] testSupportedCollations = + {"UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"}; + + /** + * Collation-aware UTF8String comparison. + */ + + private void assertStringCompare(String s1, String s2, String collationName, int expected) + throws SparkException { + UTF8String l = UTF8String.fromString(s1); + UTF8String r = UTF8String.fromString(s2); + int compare = CollationFactory.fetchCollation(collationName).comparator.compare(l, r); + assertEquals(Integer.signum(expected), Integer.signum(compare)); + } + + @Test + public void testCompare() throws SparkException { + for (String collationName: testSupportedCollations) { + // Edge cases + assertStringCompare("", "", collationName, 0); + assertStringCompare("a", "", collationName, 1); + assertStringCompare("", "a", collationName, -1); + // Basic tests + assertStringCompare("a", "a", collationName, 0); + assertStringCompare("a", "b", collationName, -1); + assertStringCompare("b", "a", collationName, 1); + assertStringCompare("A", "A", collationName, 0); + assertStringCompare("A", "B", collationName, -1); + assertStringCompare("B", "A", collationName, 1); + assertStringCompare("aa", "a", collationName, 1); + assertStringCompare("b", "bb", collationName, -1); + assertStringCompare("abc", "a", collationName, 1); + assertStringCompare("abc", "b", collationName, -1); + assertStringCompare("abc", "ab", collationName, 1); + assertStringCompare("abc", "abc", collationName, 0); + // ASCII strings + assertStringCompare("aaaa", "aaa", collationName, 1); + assertStringCompare("hello", "world", collationName, -1); + assertStringCompare("Spark", "Spark", collationName, 0); + // Non-ASCII strings + assertStringCompare("ü", "ü", collationName, 0); + assertStringCompare("ü", "", collationName, 1); + assertStringCompare("", "ü", collationName, -1); + assertStringCompare("äü", "äü", collationName, 0); + assertStringCompare("äxx", "äx", collationName, 1); + assertStringCompare("a", "ä", collationName, -1); + } + // Non-ASCII strings + assertStringCompare("äü", "bü", "UTF8_BINARY", 1); + assertStringCompare("bxx", "bü", "UTF8_BINARY", -1); + assertStringCompare("äü", "bü", "UTF8_BINARY_LCASE", 1); + assertStringCompare("bxx", "bü", "UTF8_BINARY_LCASE", -1); + assertStringCompare("äü", "bü", "UNICODE", -1); + assertStringCompare("bxx", "bü", "UNICODE", 1); + assertStringCompare("äü", "bü", "UNICODE_CI", -1); + assertStringCompare("bxx", "bü", "UNICODE_CI", 1); + // Case variation + assertStringCompare("AbCd", "aBcD", "UTF8_BINARY", -1); + assertStringCompare("ABCD", "abcd", "UTF8_BINARY_LCASE", 0); + assertStringCompare("AbcD", "aBCd", "UNICODE", 1); + assertStringCompare("abcd", "ABCD", "UNICODE_CI", 0); + // Accent variation + assertStringCompare("aBćD", "ABĆD", "UTF8_BINARY", 1); + assertStringCompare("AbCδ", "ABCΔ", "UTF8_BINARY_LCASE", 0); + assertStringCompare("äBCd", "ÄBCD", "UNICODE", -1); + assertStringCompare("Ab́cD", "AB́CD", "UNICODE_CI", 0); + // Case-variable character length + assertStringCompare("i\u0307", "İ", "UTF8_BINARY", -1); + assertStringCompare("İ", "i\u0307", "UTF8_BINARY", 1); + assertStringCompare("i\u0307", "İ", "UTF8_BINARY_LCASE", 0); + assertStringCompare("İ", "i\u0307", "UTF8_BINARY_LCASE", 0); + assertStringCompare("i\u0307", "İ", "UNICODE", -1); + assertStringCompare("İ", "i\u0307", "UNICODE", 1); + assertStringCompare("i\u0307", "İ", "UNICODE_CI", 0); + assertStringCompare("İ", "i\u0307", "UNICODE_CI", 0); + assertStringCompare("i\u0307İ", "i\u0307İ", "UTF8_BINARY_LCASE", 0); + assertStringCompare("i\u0307İ", "İi\u0307", "UTF8_BINARY_LCASE", 0); + assertStringCompare("İi\u0307", "i\u0307İ", "UTF8_BINARY_LCASE", 0); + assertStringCompare("İi\u0307", "İi\u0307", "UTF8_BINARY_LCASE", 0); + assertStringCompare("i\u0307İ", "i\u0307İ", "UNICODE_CI", 0); + assertStringCompare("i\u0307İ", "İi\u0307", "UNICODE_CI", 0); + assertStringCompare("İi\u0307", "i\u0307İ", "UNICODE_CI", 0); + assertStringCompare("İi\u0307", "İi\u0307", "UNICODE_CI", 0); + // Conditional case mapping + assertStringCompare("ς", "σ", "UTF8_BINARY", -1); + assertStringCompare("ς", "Σ", "UTF8_BINARY", 1); + assertStringCompare("σ", "Σ", "UTF8_BINARY", 1); + assertStringCompare("ς", "σ", "UTF8_BINARY_LCASE", 0); + assertStringCompare("ς", "Σ", "UTF8_BINARY_LCASE", 0); + assertStringCompare("σ", "Σ", "UTF8_BINARY_LCASE", 0); + assertStringCompare("ς", "σ", "UNICODE", 1); + assertStringCompare("ς", "Σ", "UNICODE", 1); + assertStringCompare("σ", "Σ", "UNICODE", -1); + assertStringCompare("ς", "σ", "UNICODE_CI", 0); + assertStringCompare("ς", "Σ", "UNICODE_CI", 0); + assertStringCompare("σ", "Σ", "UNICODE_CI", 0); + } + + private void assertLowerCaseCodePoints(UTF8String target, UTF8String expected, + Boolean useCodePoints) { + if (useCodePoints) { + assertEquals(expected.toString(), + CollationAwareUTF8String.lowerCaseCodePoints(target.toString())); + } else { + assertEquals(expected, target.toLowerCase()); + } + } + + @Test + public void testLowerCaseCodePoints() { + // Edge cases + assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), false); + assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), true); + // Basic tests + assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), false); + assertLowerCaseCodePoints(UTF8String.fromString("AbCd"), UTF8String.fromString("abcd"), false); + assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), true); + assertLowerCaseCodePoints(UTF8String.fromString("aBcD"), UTF8String.fromString("abcd"), true); + // Accent variation + assertLowerCaseCodePoints(UTF8String.fromString("AbĆd"), UTF8String.fromString("abćd"), false); + assertLowerCaseCodePoints(UTF8String.fromString("aBcΔ"), UTF8String.fromString("abcδ"), true); + // Case-variable character length + assertLowerCaseCodePoints( + UTF8String.fromString("İoDiNe"), UTF8String.fromString("i̇odine"), false); + assertLowerCaseCodePoints( + UTF8String.fromString("Abi̇o12"), UTF8String.fromString("abi̇o12"), false); + assertLowerCaseCodePoints( + UTF8String.fromString("İodInE"), UTF8String.fromString("i̇odine"), true); + assertLowerCaseCodePoints( + UTF8String.fromString("aBi̇o12"), UTF8String.fromString("abi̇o12"), true); + // Conditional case mapping + assertLowerCaseCodePoints( + UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινος"), false); + assertLowerCaseCodePoints( + UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινοσ"), true); + // Surrogate pairs are treated as invalid UTF8 sequences + assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[] + {(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}), + UTF8String.fromString("\ufffd\ufffd"), false); + assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[] + {(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}), + UTF8String.fromString("\ufffd\ufffd"), true); + } + /** * Collation-aware string expressions. */ diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 0188297fd05a2..d3fe361fce37b 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -107,29 +107,6 @@ public void binaryCompareTo() { assertTrue(fromString("你好123").binaryCompare(fromString("你好122")) > 0); } - @Test - public void lowercaseComparison() { - // SPARK-47693: Test optimized lowercase comparison of UTF8String instances - // ASCII - assertEquals(fromString("aaa").compareLowerCase(fromString("AAA")), 0); - assertTrue(fromString("aaa").compareLowerCase(fromString("AAAA")) < 0); - assertTrue(fromString("AAA").compareLowerCase(fromString("aaaa")) < 0); - assertTrue(fromString("a").compareLowerCase(fromString("B")) < 0); - assertTrue(fromString("b").compareLowerCase(fromString("A")) > 0); - assertEquals(fromString("aAa").compareLowerCase(fromString("AaA")), 0); - assertTrue(fromString("abcd").compareLowerCase(fromString("abC")) > 0); - assertTrue(fromString("ABC").compareLowerCase(fromString("abcd")) < 0); - assertEquals(fromString("abcd").compareLowerCase(fromString("abcd")), 0); - // non-ASCII - assertEquals(fromString("ü").compareLowerCase(fromString("Ü")), 0); - assertEquals(fromString("Äü").compareLowerCase(fromString("äÜ")), 0); - assertTrue(fromString("a").compareLowerCase(fromString("ä")) < 0); - assertTrue(fromString("a").compareLowerCase(fromString("Ä")) < 0); - assertTrue(fromString("A").compareLowerCase(fromString("ä")) < 0); - assertTrue(fromString("bä").compareLowerCase(fromString("aü")) > 0); - assertTrue(fromString("bxxxxxxxxxx").compareLowerCase(fromString("bü")) < 0); - } - protected static void testUpperandLower(String upper, String lower) { UTF8String us = fromString(upper); UTF8String ls = fromString(lower);