Skip to content

Commit

Permalink
[SPARK-47359][SQL] Support TRANSLATE function to work with collated s…
Browse files Browse the repository at this point in the history
…trings

### What changes were proposed in this pull request?
Extend built-in string functions to support non-binary, non-lowercase collation for: `translate`

### Why are the changes needed?
Update collation support for built-in string functions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use COLLATE within arguments for built-in string function TRANSLATE in Spark SQL queries, using non-binary collations such as UNICODE_CI.

### How was this patch tested?
Unit tests for queries using StringTranslate (CollationStringExpressionsSuite.scala).

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#45820 from miland-db/miland-db/string-translate.

Authored-by: Milan Dankovic <milan.dankovic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
miland-db authored and JacobZheng0927 committed May 11, 2024
1 parent f00846b commit 54167d7
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import org.apache.spark.unsafe.types.UTF8String;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
Expand Down Expand Up @@ -483,6 +485,56 @@ public static UTF8String execICU(final UTF8String string, final UTF8String delim
}
}

public static class StringTranslate {
public static UTF8String exec(final UTF8String source, Map<String, String> dict,
final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(source, dict);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(source, dict);
} else {
return execICU(source, dict, collationId);
}
}
public static String genCode(final String source, final String dict, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.EndsWith.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s)", source, dict);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", source, dict);
} else {
return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId);
}
}
public static UTF8String execBinary(final UTF8String source, Map<String, String> dict) {
return source.translate(dict);
}
public static UTF8String execLowercase(final UTF8String source, Map<String, String> dict) {
String srcStr = source.toString();
StringBuilder sb = new StringBuilder();
int charCount = 0;
for (int k = 0; k < srcStr.length(); k += charCount) {
int codePoint = srcStr.codePointAt(k);
charCount = Character.charCount(codePoint);
String subStr = srcStr.substring(k, k + charCount);
String translated = dict.get(subStr.toLowerCase());
if (null == translated) {
sb.append(subStr);
} else if (!"\0".equals(translated)) {
sb.append(translated);
}
}
return UTF8String.fromString(sb.toString());
}
public static UTF8String execICU(final UTF8String source, Map<String, String> dict,
final int collationId) {
return source.translate(CollationAwareUTF8String.getCollationAwareDict(
source, dict, collationId));
}
}

// TODO: Add more collation-aware string expressions.

/**
Expand Down Expand Up @@ -808,6 +860,39 @@ private static UTF8String lowercaseSubStringIndex(final UTF8String string,
}
}

private static Map<String, String> getCollationAwareDict(UTF8String string,
Map<String, String> dict, int collationId) {
String srcStr = string.toString();

Map<String, String> collationAwareDict = new HashMap<>();
for (String key : dict.keySet()) {
StringSearch stringSearch =
CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);

int pos = 0;
while ((pos = stringSearch.next()) != StringSearch.DONE) {
int codePoint = srcStr.codePointAt(pos);
int charCount = Character.charCount(codePoint);
String newKey = srcStr.substring(pos, pos + charCount);

boolean exists = false;
for (String existingKey : collationAwareDict.keySet()) {
if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
exists = true;
break;
}
}

if (!exists) {
collationAwareDict.put(newKey, dict.get(key));
}
}
}

return collationAwareDict;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ object CollationTypeCasts extends TypeCoercionRule {

case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
_: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace) =>
_: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace |
_: StringTranslate) =>
val newChildren = collateToSingleType(otherExpr.children)
otherExpr.withNewChildren(newChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
Expand Down Expand Up @@ -859,9 +859,14 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:

object StringTranslate {

def buildDict(matchingString: UTF8String, replaceString: UTF8String)
def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Int)
: JMap[String, String] = {
val matching = matchingString.toString()
val matching = if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
matchingString.toString().toLowerCase()
} else {
matchingString.toString()
}

val replace = replaceString.toString()
val dict = new HashMap[String, String]()
var i = 0
Expand Down Expand Up @@ -912,13 +917,16 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
@transient private var lastReplace: UTF8String = _
@transient private var dict: JMap[String, String] = _

final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId

override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
if (matchingEval != lastMatching || replaceEval != lastReplace) {
lastMatching = matchingEval.asInstanceOf[UTF8String].clone()
lastReplace = replaceEval.asInstanceOf[UTF8String].clone()
dict = StringTranslate.buildDict(lastMatching, lastReplace)
dict = StringTranslate.buildDict(lastMatching, lastReplace, collationId)
}
srcEval.asInstanceOf[UTF8String].translate(dict)

CollationSupport.StringTranslate.exec(srcEval.asInstanceOf[UTF8String], dict, collationId)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -939,15 +947,17 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
$termLastMatching = $matching.clone();
$termLastReplace = $replace.clone();
$termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate
.buildDict($termLastMatching, $termLastReplace);
.buildDict($termLastMatching, $termLastReplace, $collationId);
}
${ev.value} = $src.translate($termDict);
${ev.value} = CollationSupport.StringTranslate.
exec($src, $termDict, $collationId);
"""
})
}

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
override def dataType: DataType = srcExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
override def first: Expression = srcExpr
override def second: Expression = matchingExpr
override def third: Expression = replaceExpr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,80 @@ class CollationStringExpressionsSuite
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}
test("TRANSLATE check result on explicitly collated string") {
// Supported collations
case class TranslateTestCase[R](input: String, matchExpression: String,
replaceExpression: String, collation: String, result: R)
val testCases = Seq(
TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", "41a2s3a4e"),
TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", "41a2s3a4e"),
TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY_LCASE", "xXaxsXaxe"),
TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY_LCASE", "xxaxsXaxex"),
TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY_LCASE", "xXaxsXaxeX"),
// scalastyle:off
TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY_LCASE", "test大千世AB大千世A"),
TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY_LCASE", "大千世界abca大千世界"),
TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY_LCASE", "oeso大千世界大千世界"),
TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY_LCASE", "大千世界大千世界OesO"),
TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY_LCASE", "世世世界世世世界tesT"),
// scalastyle:on
TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"),
TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"),
TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"),
TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"),
// scalastyle:off
TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"),
TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"),
TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"),
// scalastyle:on
TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"),
TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"),
TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"),
TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"),
// scalastyle:off
TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"),
TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"),
TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"),
TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"),
TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"),
// scalastyle:on
TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY_LCASE", "14234e"),
TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"),
TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"),
TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"),
TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY_LCASE", "41a2s3a4e"),
TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"),
TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"),
TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"),
TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"),
TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY_LCASE", "123f"),
TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"),
TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f")
)

testCases.foreach(t => {
val query = s"SELECT translate(collate('${t.input}', '${t.collation}')," +
s"collate('${t.matchExpression}', '${t.collation}')," +
s"collate('${t.replaceExpression}', '${t.collation}'))"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(
StringType(CollationFactory.collationNameToId(t.collation))))
// Implicit casting
checkAnswer(sql(s"SELECT translate(collate('${t.input}', '${t.collation}')," +
s"'${t.matchExpression}', '${t.replaceExpression}')"), Row(t.result))
checkAnswer(sql(s"SELECT translate('${t.input}', collate('${t.matchExpression}'," +
s"'${t.collation}'), '${t.replaceExpression}')"), Row(t.result))
checkAnswer(sql(s"SELECT translate('${t.input}', '${t.matchExpression}'," +
s"collate('${t.replaceExpression}', '${t.collation}'))"), Row(t.result))
})
// Collation mismatch
val collationMismatch = intercept[AnalysisException] {
sql(s"SELECT translate(collate('Translate', 'UTF8_BINARY_LCASE')," +
s"collate('Rnlt', 'UNICODE'), '1234')")
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support Replace string expression with collation") {
case class ReplaceTestCase[R](source: String, search: String, replace: String,
Expand Down

0 comments on commit 54167d7

Please sign in to comment.