Skip to content

Commit

Permalink
[SPARK-48283][SQL] Modify string comparison for UTF8_BINARY_LCASE
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
String comparison and hashing in UTF8_BINARY_LCASE is now context-unaware, and uses ICU root locale rules to convert string to lowercase at code point level, taking into consideration special cases for one-to-many case mapping. For example: comparing "ΘΑΛΑΣΣΙΝΟΣ" and "θαλασσινοσ" under UTF8_BINARY_LCASE now returns true, because Greek final sigma is special-cased in the new comparison implementation.

### Why are the changes needed?
1. UTF8_BINARY_LCASE should use ICU root locale rules (instead of JVM)
2. comparing strings under UTF8_BINARY_LCASE should be context-insensitive

### Does this PR introduce _any_ user-facing change?
Yes, comparing strings under UTF8_BINARY_LCASE will now give different results in two kinds of special cases (Turkish dotted letter "i" and Greek final letter "sigma").

### How was this patch tested?
Unit tests in `CollationSupportSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46700 from uros-db/lcase-casing.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed Jun 6, 2024
1 parent b5a4b32 commit 84fa052
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,54 @@ private static int lowercaseRFind(
return MATCH_NOT_FOUND;
}

/**
* Lowercase UTF8String comparison used for UTF8_BINARY_LCASE collation. While the default
* UTF8String comparison is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase()), this
* method uses code points to compare the strings in a case-insensitive manner using ICU rules,
* as well as handling special rules for one-to-many case mappings (see: lowerCaseCodePoints).
*
* @param left The first UTF8String to compare.
* @param right The second UTF8String to compare.
* @return An integer representing the comparison result.
*/
public static int compareLowerCase(final UTF8String left, final UTF8String right) {
// Only if both strings are ASCII, we can use faster comparison (no string allocations).
if (left.isFullAscii() && right.isFullAscii()) {
return compareLowerCaseAscii(left, right);
}
return compareLowerCaseSlow(left, right);
}

/**
* Fast version of the `compareLowerCase` method, used when both arguments are ASCII strings.
*
* @param left The first ASCII UTF8String to compare.
* @param right The second ASCII UTF8String to compare.
* @return An integer representing the comparison result.
*/
private static int compareLowerCaseAscii(final UTF8String left, final UTF8String right) {
int leftBytes = left.numBytes(), rightBytes = right.numBytes();
for (int curr = 0; curr < leftBytes && curr < rightBytes; curr++) {
int lowerLeftByte = Character.toLowerCase(left.getByte(curr));
int lowerRightByte = Character.toLowerCase(right.getByte(curr));
if (lowerLeftByte != lowerRightByte) {
return lowerLeftByte - lowerRightByte;
}
}
return leftBytes - rightBytes;
}

/**
* Slow version of the `compareLowerCase` method, used when both arguments are non-ASCII strings.
*
* @param left The first non-ASCII UTF8String to compare.
* @param right The second non-ASCII UTF8String to compare.
* @return An integer representing the comparison result.
*/
private static int compareLowerCaseSlow(final UTF8String left, final UTF8String right) {
return lowerCaseCodePoints(left.toString()).compareTo(lowerCaseCodePoints(right.toString()));
}

public static UTF8String replace(final UTF8String src, final UTF8String search,
final UTF8String replace, final int collationId) {
// This collation aware implementation is based on existing implementation on UTF8String
Expand Down Expand Up @@ -296,6 +344,48 @@ public static String toLowerCase(final String target, final int collationId) {
return UCharacter.toLowerCase(locale, target);
}

/**
* Converts a single code point to lowercase using ICU rules, with special handling for
* one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and
* context-insensitive case mappings (i.e. characters that map to different characters based on
* string context - e.g. the position in the string relative to other characters).
*
* @param codePoint The code point to convert to lowercase.
* @param sb The StringBuilder to append the lowercase character to.
*/
private static void lowercaseCodePoint(final int codePoint, final StringBuilder sb) {
if (codePoint == 0x0130) {
// Latin capital letter I with dot above is mapped to 2 lowercase characters.
sb.appendCodePoint(0x0069);
sb.appendCodePoint(0x0307);
}
else if (codePoint == 0x03C2) {
// Greek final and non-final capital letter sigma should be mapped the same.
sb.appendCodePoint(0x03C3);
}
else {
// All other characters should follow context-unaware ICU single-code point case mapping.
sb.appendCodePoint(UCharacter.toLowerCase(codePoint));
}
}

/**
* Converts an entire string to lowercase using ICU rules, code point by code point, with
* special handling for one-to-many case mappings (i.e. characters that map to multiple
* characters in lowercase). Also, this method omits information about context-sensitive case
* mappings using special handling in the `lowercaseCodePoint` method.
*
* @param target The target string to convert to lowercase.
* @return The string converted to lowercase in a context-unaware manner.
*/
public static String lowerCaseCodePoints(final String target) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < target.length(); ++i) {
lowercaseCodePoint(target.codePointAt(i), sb);
}
return sb.toString();
}

public static String toTitleCase(final String target, final int collationId) {
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,9 @@ protected Collation buildCollation() {
"UTF8_BINARY_LCASE",
PROVIDER_SPARK,
null,
UTF8String::compareLowerCase,
CollationAwareUTF8String::compareLowerCase,
"1.0",
s -> (long) s.toLowerCase().hashCode(),
s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s.toString()).hashCode(),
/* supportsBinaryEquality = */ false,
/* supportsBinaryOrdering = */ false,
/* supportsLowercaseEquality = */ true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,34 +388,6 @@ private UTF8String toUpperCaseSlow() {
return fromString(toString().toUpperCase());
}

/**
* Optimized lowercase comparison for UTF8_BINARY_LCASE collation
* a.compareLowerCase(b) is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase())
*/
public int compareLowerCase(UTF8String other) {
int curr;
for (curr = 0; curr < numBytes && curr < other.numBytes; curr++) {
byte left, right;
if ((left = getByte(curr)) < 0 || (right = other.getByte(curr)) < 0) {
return compareLowerCaseSuffixSlow(other, curr);
}
int lowerLeft = Character.toLowerCase(left);
int lowerRight = Character.toLowerCase(right);
if (lowerLeft != lowerRight) {
return lowerLeft - lowerRight;
}
}
return numBytes - other.numBytes;
}

private int compareLowerCaseSuffixSlow(UTF8String other, int pref) {
UTF8String suffixLeft = UTF8String.fromAddress(base, offset + pref,
numBytes - pref);
UTF8String suffixRight = UTF8String.fromAddress(other.base, other.offset + pref,
other.numBytes - pref);
return suffixLeft.toLowerCaseSlow().binaryCompare(suffixRight.toLowerCaseSlow());
}

/**
* Returns the lower case of this string
*/
Expand All @@ -427,7 +399,7 @@ public UTF8String toLowerCase() {
return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlow();
}

private boolean isFullAscii() {
public boolean isFullAscii() {
for (var i = 0; i < numBytes; i++) {
if (getByte(i) < 0) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.unsafe.types;

import org.apache.spark.SparkException;
import org.apache.spark.sql.catalyst.util.CollationAwareUTF8String;
import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.sql.catalyst.util.CollationSupport;
import org.junit.jupiter.api.Test;
Expand All @@ -26,6 +27,156 @@
// checkstyle.off: AvoidEscapedUnicodeCharacters
public class CollationSupportSuite {

/**
* A list containing some of the supported collations in Spark. Use this list to iterate over
* all the important collation groups (binary, lowercase, icu) for complete unit test coverage.
* Note: this list may come in handy when the Spark function result is the same regardless of
* the specified collations (as often seen in some pass-through Spark expressions).
*/
private final String[] testSupportedCollations =
{"UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"};

/**
* Collation-aware UTF8String comparison.
*/

private void assertStringCompare(String s1, String s2, String collationName, int expected)
throws SparkException {
UTF8String l = UTF8String.fromString(s1);
UTF8String r = UTF8String.fromString(s2);
int compare = CollationFactory.fetchCollation(collationName).comparator.compare(l, r);
assertEquals(Integer.signum(expected), Integer.signum(compare));
}

@Test
public void testCompare() throws SparkException {
for (String collationName: testSupportedCollations) {
// Edge cases
assertStringCompare("", "", collationName, 0);
assertStringCompare("a", "", collationName, 1);
assertStringCompare("", "a", collationName, -1);
// Basic tests
assertStringCompare("a", "a", collationName, 0);
assertStringCompare("a", "b", collationName, -1);
assertStringCompare("b", "a", collationName, 1);
assertStringCompare("A", "A", collationName, 0);
assertStringCompare("A", "B", collationName, -1);
assertStringCompare("B", "A", collationName, 1);
assertStringCompare("aa", "a", collationName, 1);
assertStringCompare("b", "bb", collationName, -1);
assertStringCompare("abc", "a", collationName, 1);
assertStringCompare("abc", "b", collationName, -1);
assertStringCompare("abc", "ab", collationName, 1);
assertStringCompare("abc", "abc", collationName, 0);
// ASCII strings
assertStringCompare("aaaa", "aaa", collationName, 1);
assertStringCompare("hello", "world", collationName, -1);
assertStringCompare("Spark", "Spark", collationName, 0);
// Non-ASCII strings
assertStringCompare("ü", "ü", collationName, 0);
assertStringCompare("ü", "", collationName, 1);
assertStringCompare("", "ü", collationName, -1);
assertStringCompare("äü", "äü", collationName, 0);
assertStringCompare("äxx", "äx", collationName, 1);
assertStringCompare("a", "ä", collationName, -1);
}
// Non-ASCII strings
assertStringCompare("äü", "bü", "UTF8_BINARY", 1);
assertStringCompare("bxx", "bü", "UTF8_BINARY", -1);
assertStringCompare("äü", "bü", "UTF8_BINARY_LCASE", 1);
assertStringCompare("bxx", "bü", "UTF8_BINARY_LCASE", -1);
assertStringCompare("äü", "bü", "UNICODE", -1);
assertStringCompare("bxx", "bü", "UNICODE", 1);
assertStringCompare("äü", "bü", "UNICODE_CI", -1);
assertStringCompare("bxx", "bü", "UNICODE_CI", 1);
// Case variation
assertStringCompare("AbCd", "aBcD", "UTF8_BINARY", -1);
assertStringCompare("ABCD", "abcd", "UTF8_BINARY_LCASE", 0);
assertStringCompare("AbcD", "aBCd", "UNICODE", 1);
assertStringCompare("abcd", "ABCD", "UNICODE_CI", 0);
// Accent variation
assertStringCompare("aBćD", "ABĆD", "UTF8_BINARY", 1);
assertStringCompare("AbCδ", "ABCΔ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("äBCd", "ÄBCD", "UNICODE", -1);
assertStringCompare("Ab́cD", "AB́CD", "UNICODE_CI", 0);
// Case-variable character length
assertStringCompare("i\u0307", "İ", "UTF8_BINARY", -1);
assertStringCompare("İ", "i\u0307", "UTF8_BINARY", 1);
assertStringCompare("i\u0307", "İ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("İ", "i\u0307", "UTF8_BINARY_LCASE", 0);
assertStringCompare("i\u0307", "İ", "UNICODE", -1);
assertStringCompare("İ", "i\u0307", "UNICODE", 1);
assertStringCompare("i\u0307", "İ", "UNICODE_CI", 0);
assertStringCompare("İ", "i\u0307", "UNICODE_CI", 0);
assertStringCompare("i\u0307İ", "i\u0307İ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("i\u0307İ", "İi\u0307", "UTF8_BINARY_LCASE", 0);
assertStringCompare("İi\u0307", "i\u0307İ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("İi\u0307", "İi\u0307", "UTF8_BINARY_LCASE", 0);
assertStringCompare("i\u0307İ", "i\u0307İ", "UNICODE_CI", 0);
assertStringCompare("i\u0307İ", "İi\u0307", "UNICODE_CI", 0);
assertStringCompare("İi\u0307", "i\u0307İ", "UNICODE_CI", 0);
assertStringCompare("İi\u0307", "İi\u0307", "UNICODE_CI", 0);
// Conditional case mapping
assertStringCompare("ς", "σ", "UTF8_BINARY", -1);
assertStringCompare("ς", "Σ", "UTF8_BINARY", 1);
assertStringCompare("σ", "Σ", "UTF8_BINARY", 1);
assertStringCompare("ς", "σ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("ς", "Σ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("σ", "Σ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("ς", "σ", "UNICODE", 1);
assertStringCompare("ς", "Σ", "UNICODE", 1);
assertStringCompare("σ", "Σ", "UNICODE", -1);
assertStringCompare("ς", "σ", "UNICODE_CI", 0);
assertStringCompare("ς", "Σ", "UNICODE_CI", 0);
assertStringCompare("σ", "Σ", "UNICODE_CI", 0);
}

private void assertLowerCaseCodePoints(UTF8String target, UTF8String expected,
Boolean useCodePoints) {
if (useCodePoints) {
assertEquals(expected.toString(),
CollationAwareUTF8String.lowerCaseCodePoints(target.toString()));
} else {
assertEquals(expected, target.toLowerCase());
}
}

@Test
public void testLowerCaseCodePoints() {
// Edge cases
assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), false);
assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), true);
// Basic tests
assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), false);
assertLowerCaseCodePoints(UTF8String.fromString("AbCd"), UTF8String.fromString("abcd"), false);
assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), true);
assertLowerCaseCodePoints(UTF8String.fromString("aBcD"), UTF8String.fromString("abcd"), true);
// Accent variation
assertLowerCaseCodePoints(UTF8String.fromString("AbĆd"), UTF8String.fromString("abćd"), false);
assertLowerCaseCodePoints(UTF8String.fromString("aBcΔ"), UTF8String.fromString("abcδ"), true);
// Case-variable character length
assertLowerCaseCodePoints(
UTF8String.fromString("İoDiNe"), UTF8String.fromString("i̇odine"), false);
assertLowerCaseCodePoints(
UTF8String.fromString("Abi̇o12"), UTF8String.fromString("abi̇o12"), false);
assertLowerCaseCodePoints(
UTF8String.fromString("İodInE"), UTF8String.fromString("i̇odine"), true);
assertLowerCaseCodePoints(
UTF8String.fromString("aBi̇o12"), UTF8String.fromString("abi̇o12"), true);
// Conditional case mapping
assertLowerCaseCodePoints(
UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινος"), false);
assertLowerCaseCodePoints(
UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινοσ"), true);
// Surrogate pairs are treated as invalid UTF8 sequences
assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
{(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}),
UTF8String.fromString("\ufffd\ufffd"), false);
assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
{(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}),
UTF8String.fromString("\ufffd\ufffd"), true);
}

/**
* Collation-aware string expressions.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,29 +107,6 @@ public void binaryCompareTo() {
assertTrue(fromString("你好123").binaryCompare(fromString("你好122")) > 0);
}

@Test
public void lowercaseComparison() {
// SPARK-47693: Test optimized lowercase comparison of UTF8String instances
// ASCII
assertEquals(fromString("aaa").compareLowerCase(fromString("AAA")), 0);
assertTrue(fromString("aaa").compareLowerCase(fromString("AAAA")) < 0);
assertTrue(fromString("AAA").compareLowerCase(fromString("aaaa")) < 0);
assertTrue(fromString("a").compareLowerCase(fromString("B")) < 0);
assertTrue(fromString("b").compareLowerCase(fromString("A")) > 0);
assertEquals(fromString("aAa").compareLowerCase(fromString("AaA")), 0);
assertTrue(fromString("abcd").compareLowerCase(fromString("abC")) > 0);
assertTrue(fromString("ABC").compareLowerCase(fromString("abcd")) < 0);
assertEquals(fromString("abcd").compareLowerCase(fromString("abcd")), 0);
// non-ASCII
assertEquals(fromString("ü").compareLowerCase(fromString("Ü")), 0);
assertEquals(fromString("Äü").compareLowerCase(fromString("äÜ")), 0);
assertTrue(fromString("a").compareLowerCase(fromString("ä")) < 0);
assertTrue(fromString("a").compareLowerCase(fromString("Ä")) < 0);
assertTrue(fromString("A").compareLowerCase(fromString("ä")) < 0);
assertTrue(fromString("bä").compareLowerCase(fromString("aü")) > 0);
assertTrue(fromString("bxxxxxxxxxx").compareLowerCase(fromString("bü")) < 0);
}

protected static void testUpperandLower(String upper, String lower) {
UTF8String us = fromString(upper);
UTF8String ls = fromString(lower);
Expand Down

0 comments on commit 84fa052

Please sign in to comment.