Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47359][SQL] Support TRANSLATE function to work with collated strings #45820

Closed
miland-db wants to merge 30 commits into master from string-translate
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6eea5ec
Add support for StringTranslate on collated strings
miland-db Apr 2, 2024
d4da7f7
Merge branch 'master' into string-translate
miland-db Apr 2, 2024
0b41113
Add tests for StringTranslate
miland-db Apr 2, 2024
de2102b
Improve code organization
miland-db Apr 2, 2024
8926c61
Satisfy Java linter
miland-db Apr 2, 2024
f06a7a4
Improve scala/java style
miland-db Apr 3, 2024
db76746
Merge branch 'master' into string-translate
miland-db Apr 3, 2024
653612e
Update getStringSearch naming and visibility
miland-db Apr 3, 2024
dfa51f9
Add doc comment
miland-db Apr 3, 2024
509e460
Fix indentation
miland-db Apr 3, 2024
bb17a4e
Fix indentation in test file
miland-db Apr 3, 2024
591ebdd
Update doc comment
miland-db Apr 4, 2024
f68136e
Merge branch 'master' into string-translate
miland-db Apr 4, 2024
1320ac1
Add StringTranslate to CollationTypeCasts transform method
miland-db Apr 4, 2024
7a53bda
Add empty lines between imports
miland-db Apr 4, 2024
49f2cf3
Handle all collationIds in getStringSearch
miland-db Apr 4, 2024
023d5f4
Merge branch 'master' into string-translate
miland-db Apr 12, 2024
e5efef3
Add StringTranslate functionality
miland-db Apr 12, 2024
6361a1e
Improve scalastyle import
miland-db Apr 13, 2024
431bea6
Improve java style
miland-db Apr 15, 2024
62bff79
Refactor TRANSLATE test
miland-db Apr 15, 2024
a28ca21
Improve java style
miland-db Apr 15, 2024
666fba9
Merge branch 'master' into string-translate
miland-db Apr 16, 2024
c58c6eb
Sync with the latest master and improve formatting
miland-db Apr 16, 2024
5cc1440
Merge branch 'master' into string-translate
miland-db Apr 17, 2024
bdbc32f
Merge branch 'master' into string-translate
miland-db Apr 17, 2024
2b7fdbd
Merge branch 'master' into string-translate
miland-db Apr 26, 2024
7dc19fd
Sync with the latest master
miland-db Apr 26, 2024
b49cf86
Add more test cases
miland-db Apr 30, 2024
36b8746
Merge branch 'master' into string-translate
miland-db Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,26 @@ public static StringSearch getStringSearch(
final UTF8String left,
final UTF8String right,
final int collationId) {

if(collationId == UTF8_BINARY_LCASE_COLLATION_ID) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
return getStringSearch(left, right);
}

String pattern = right.toString();
CharacterIterator target = new StringCharacterIterator(left.toString());
Collator collator = CollationFactory.fetchCollation(collationId).collator;
return new StringSearch(pattern, target, (RuleBasedCollator) collator);
}

private static StringSearch getStringSearch(
final UTF8String left,
final UTF8String right) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
String pattern = right.toLowerCase().toString();
String target = left.toLowerCase().toString();

return new StringSearch(pattern, target);
}

/**
* Returns the collation id for the given collation name.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -1155,6 +1156,45 @@ public UTF8String translate(Map<String, String> dict) {
return fromString(sb.toString());
}

public UTF8String translate(Map<String, String> dict, int collationId) {
if(CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
return translate(dict);
}
return translate(getCollationAwareDict(dict, collationId));
}

private Map<String, String> getCollationAwareDict(Map<String, String> dict, int collationId) {
String srcStr = this.toString();
miland-db marked this conversation as resolved.
Show resolved Hide resolved

Map<String, String> collationAwareDict = new HashMap<>();
for(String key : dict.keySet()) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
StringSearch stringSearch =
CollationFactory.getStringSearch(this, UTF8String.fromString(key), collationId);

int pos = 0;
while((pos = stringSearch.next()) != StringSearch.DONE) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
int codePoint = srcStr.codePointAt(pos);
int charCount = Character.charCount(codePoint);
String newKey = srcStr.substring(pos, pos + charCount);

boolean exists = false;
for(String existingKey : collationAwareDict.keySet()) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
if(stringSearch.getCollator().compare(existingKey, newKey) == 0) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
exists = true;
break;
}
}

if(!exists) {
miland-db marked this conversation as resolved.
Show resolved Hide resolved
collationAwareDict.put(newKey, dict.get(key));
}
}
}

return collationAwareDict;
}

/**
* Wrapper over `long` to allow result of parsing long from string to be accessed via reference.
* This is done solely for better performance and is not expected to be used by end users.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -934,13 +934,20 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
// Copy of the most recent replace argument; nullSafeEval compares against it to decide
// whether the dictionary has to be rebuilt.
@transient private var lastReplace: UTF8String = _
// Dictionary produced by StringTranslate.buildDict for the cached matching/replace pair.
@transient private var dict: JMap[String, String] = _

// Collation id taken from the string type of the first child expression.
final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId

/**
 * Translates `srcEval` using the dictionary derived from `matchingEval` and `replaceEval`.
 *
 * The dictionary is cached and rebuilt only when the matching or replace argument differs
 * from the one seen on the previous call. Collations that support binary equality use the
 * plain translate; all others use the collation-aware overload.
 */
override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
  if (matchingEval != lastMatching || replaceEval != lastReplace) {
    lastMatching = matchingEval.asInstanceOf[UTF8String].clone()
    lastReplace = replaceEval.asInstanceOf[UTF8String].clone()
    dict = StringTranslate.buildDict(lastMatching, lastReplace)
  }

  // Translate exactly once; the result of this branch is the value of the expression.
  val src = srcEval.asInstanceOf[UTF8String]
  if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
    src.translate(dict)
  } else {
    src.translate(dict, collationId)
  }
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -963,13 +970,18 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
$termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate
.buildDict($termLastMatching, $termLastReplace);
}
${ev.value} = $src.translate($termDict);
if (CollationFactory.fetchCollation(${collationId}).supportsBinaryEquality) {
${ev.value} = $src.translate($termDict);
} else {
${ev.value} = $src.translate($termDict, ${collationId});
}
"""
})
}

// The result keeps the (possibly collated) string type of the source expression.
override def dataType: DataType = srcExpr.dataType
// All three arguments accept strings of any collation.
override def inputTypes: Seq[AbstractDataType] =
  Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
// Children in argument order: source, matching characters, replacement characters.
override def first: Expression = srcExpr
override def second: Expression = matchingExpr
override def third: Expression = replaceExpr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import scala.collection.immutable.Seq

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat}
import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat, StringTranslate}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
Expand Down Expand Up @@ -73,6 +74,57 @@ class CollationStringExpressionsSuite extends QueryTest
})
}

test("TRANSLATE check result on explicitly collated string") {
  // Evaluates StringTranslate over three literals that share one collation and checks the
  // result against `expected`.
  def testTranslate(
      input: String,
      matchExpression: String,
      replaceExpression: String,
      collationId: Int,
      expected: String): Unit = {
    val translateExpr = StringTranslate(
      Literal.create(input, StringType(collationId)),
      Literal.create(matchExpression, StringType(collationId)),
      Literal.create(replaceExpression, StringType(collationId)))

    checkEvaluation(translateExpr, expected)
  }

  // Runs every (input, matching, replacement, expected) row under the named collation.
  def runCases(collationName: String, cases: Seq[(String, String, String, String)]): Unit = {
    val id = CollationFactory.collationNameToId(collationName)
    cases.foreach { case (input, matching, replacement, expected) =>
      testTranslate(input, matching, replacement, id, expected)
    }
  }

  // Expectations shared by the two case-insensitive collations.
  // scalastyle:off
  val caseInsensitiveCases = Seq(
    ("Translate", "Rnlt", "1234", "41a2s3a4e"),
    ("TRanslate", "rnlt", "XxXx", "xXaxsXaxe"),
    ("TRanslater", "Rrnlt", "xXxXx", "xxaxsXaxex"),
    ("TRanslater", "Rrnlt", "XxxXx", "xXaxsXaxeX"),
    ("test大千世界X大千世界", "界x", "AB", "test大千世AB大千世A"),
    ("大千世界test大千世界", "TEST", "abcd", "大千世界abca大千世界"),
    ("Test大千世界大千世界", "tT", "oO", "oeso大千世界大千世界"),
    ("大千世界大千世界tesT", "Tt", "Oo", "大千世界大千世界OesO"),
    ("大千世界大千世界tesT", "大千", "世世", "世世世界世世世界tesT"))

  // Expectations for the case-sensitive UNICODE collation.
  val caseSensitiveCases = Seq(
    ("Translate", "Rnlt", "1234", "Tra2s3a4e"),
    ("TRanslate", "rnlt", "XxXx", "TRaxsXaxe"),
    ("TRanslater", "Rrnlt", "xXxXx", "TxaxsXaxeX"),
    ("TRanslater", "Rrnlt", "XxxXx", "TXaxsXaxex"),
    ("test大千世界X大千世界", "界x", "AB", "test大千世AX大千世A"),
    ("Test大千世界大千世界", "tT", "oO", "Oeso大千世界大千世界"),
    ("大千世界大千世界tesT", "Tt", "Oo", "大千世界大千世界oesO"))
  // scalastyle:on

  runCases("UTF8_BINARY_LCASE", caseInsensitiveCases)
  runCases("UNICODE", caseSensitiveCases)
  runCases("UNICODE_CI", caseInsensitiveCases)
}

test("REPEAT check output type on explicitly collated string") {
def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = {
val s = Literal.create(input, StringType(collationId))
Expand Down