Skip to content

Commit

Permalink
[SPARK-48162][SQL] Add collation support for MISC expressions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Introduce collation awareness for misc expressions: raise_error, uuid, version, typeof, aes_encrypt, aes_decrypt.

### Why are the changes needed?
Add collation support for misc expressions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes. raise_error, aes_encrypt, and aes_decrypt now accept collated strings for their string arguments, and uuid, version, and typeof now return strings that use the session's default collation.

### How was this patch tested?
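End-to-end SQL tests added to CollationSQLExpressionsSuite; the expected plan outputs for aes_encrypt, aes_decrypt, and try_aes_decrypt were updated to match.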

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46461 from uros-db/misc-expressions.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed May 15, 2024
1 parent 7ec37e4 commit 7233540
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, GCM, DEFAULT, )#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, GCM, DEFAULT, )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, DEFAULT, )#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, g, DEFAULT, )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, )#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, g)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, GCM, DEFAULT, , )#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, GCM, DEFAULT, , )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, DEFAULT, , )#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, DEFAULT, , )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, , )#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, , )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', )#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast(g#0 as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', g)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast(g#0 as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, GCM, DEFAULT, )#0]
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, GCM, DEFAULT, )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, DEFAULT, )#0]
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, DEFAULT, )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, g, )#0]
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, g, )#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, g, g)#0]
Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, g, g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -84,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType:
override def foldable: Boolean = false
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, MapType(StringType, StringType))
Seq(StringTypeAnyCollation, MapType(StringType, StringType))

override def left: Expression = errorClass
override def right: Expression = errorParms
Expand Down Expand Up @@ -251,7 +252,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non

override def nullable: Boolean = false

override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType

override def stateful: Boolean = true

Expand Down Expand Up @@ -292,7 +293,7 @@ case class SparkVersion() extends LeafExpression with RuntimeReplaceable {

override lazy val replacement: Expression = StaticInvoke(
classOf[ExpressionImplUtils],
StringType,
SQLConf.get.defaultStringType,
"getSparkVersion",
returnNullable = false)
}
Expand All @@ -311,7 +312,7 @@ case class SparkVersion() extends LeafExpression with RuntimeReplaceable {
case class TypeOf(child: Expression) extends UnaryExpression {
override def nullable: Boolean = false
override def foldable: Boolean = true
override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType
override def eval(input: InternalRow): Any = UTF8String.fromString(child.dataType.catalogString)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -412,7 +413,8 @@ case class AesEncrypt(
override def prettyName: String = "aes_encrypt"

override def inputTypes: Seq[AbstractDataType] =
Seq(BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType)
Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation,
BinaryType, BinaryType)

override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad)

Expand Down Expand Up @@ -485,7 +487,7 @@ case class AesDecrypt(
this(input, key, Literal("GCM"))

override def inputTypes: Seq[AbstractDataType] = {
Seq(BinaryType, BinaryType, StringType, StringType, BinaryType)
Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType)
}

override def prettyName: String = "aes_decrypt"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,142 @@ class CollationSQLExpressionsSuite
})
}

test("Support RaiseError misc expression with collation") {
  // Each pair is (message handed to raise_error, session default collation).
  val messagesByCollation = Seq(
    "custom error message 1" -> "UTF8_BINARY",
    "custom error message 2" -> "UTF8_BINARY_LCASE",
    "custom error message 3" -> "UNICODE",
    "custom error message 4" -> "UNICODE_CI"
  )
  for ((message, collation) <- messagesByCollation) {
    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
      // Executing the query must fail with a runtime exception ...
      val thrown = intercept[SparkRuntimeException] {
        sql(s"SELECT raise_error('$message')").collect()
      }
      // ... that is classified as user-raised and carries the original message.
      assert(thrown.getErrorClass === "USER_RAISED_EXCEPTION")
      assert(thrown.getMessage.contains(message))
    }
  }
}

test("Support Uuid misc expression with collation") {
  // Textual UUID layout: 8-4-4-4-12 groups of lowercase hex digits.
  val uuidFormat = "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
  // Cover the same four collations as the sibling misc-expression tests
  // (UTF8_BINARY was previously left out of this list).
  val collations = Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI")
  collations.foreach { collationName =>
    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) {
      val testQuery = sql("SELECT uuid()")
      // Result: the single returned value is a well-formed UUID string.
      val queryResult = testQuery.collect().head.getString(0)
      assert(queryResult.matches(uuidFormat))
      // Data type: the output column uses the string type of the configured collation.
      val dataType = StringType(collationName)
      assert(testQuery.schema.fields.head.dataType.sameType(dataType))
    }
  }
}

test("Support SparkVersion misc expression with collation") {
  // Expected shape: "<major>.<minor>.<patch> <40 lowercase hex digits>" (the hex part is
  // presumably the build revision). Each numeric component may span several digits, so
  // the pattern uses `+` instead of assuming single-digit components.
  val versionFormat = "^[0-9]+\\.[0-9]+\\.[0-9]+ [0-9a-f]{40}$"
  val collations = Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI")
  collations.foreach { collationName =>
    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) {
      val testQuery = sql("SELECT version()")
      // Result: the single returned value matches the version format.
      val queryResult = testQuery.collect().head.getString(0)
      assert(queryResult.matches(versionFormat))
      // Data type: the output column uses the string type of the configured collation.
      val dataType = StringType(collationName)
      assert(testQuery.schema.fields.head.dataType.sameType(dataType))
    }
  }
}

test("Support TypeOf misc expression with collation") {
  // input: SQL expression passed to typeof; collationName: session default collation;
  // result: type name the query is expected to return.
  case class TypeOfTestCase(input: String, collationName: String, result: String)
  Seq(
    TypeOfTestCase("1", "UTF8_BINARY", "int"),
    TypeOfTestCase("\"A\"", "UTF8_BINARY_LCASE", "string collate UTF8_BINARY_LCASE"),
    TypeOfTestCase("array(1)", "UNICODE", "array<int>"),
    TypeOfTestCase("null", "UNICODE_CI", "void")
  ).foreach { case TypeOfTestCase(expression, collation, expectedName) =>
    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
      val typeOfQuery = sql(s"SELECT typeof($expression)")
      // Result: the reported type name.
      checkAnswer(typeOfQuery, Row(expectedName))
      // Data type: the returned name is a string of the configured collation.
      val expectedType = StringType(collation)
      assert(typeOfQuery.schema.fields.head.dataType.sameType(expectedType))
    }
  }
}

test("Support AesEncrypt misc expression with collation") {
  // Each tuple is (plaintext, session default collation, remaining aes_encrypt arguments
  // as SQL text, expected hex-encoded ciphertext).
  val encryptCases = Seq(
    ("Spark", "UTF8_BINARY", "'1234567890abcdef', 'ECB'",
      "8DE7DB79A23F3E8ED530994DDEA98913"),
    ("Spark", "UTF8_BINARY_LCASE", "'1234567890abcdef', 'ECB', 'DEFAULT', ''",
      "8DE7DB79A23F3E8ED530994DDEA98913"),
    ("Spark", "UNICODE",
      "'1234567890abcdef', 'GCM', 'DEFAULT', " + "unhex('000000000000000000000000')",
      "00000000000000000000000046596B2DE09C729FE48A0F81A00A4E7101DABEB61D"),
    ("Spark", "UNICODE_CI",
      "'1234567890abcdef', 'CBC', 'DEFAULT', " + "unhex('00000000000000000000000000000000')",
      "000000000000000000000000000000008DE7DB79A23F3E8ED530994DDEA98913")
  )
  for ((plaintext, collation, params, expectedHex) <- encryptCases) {
    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
      // The ciphertext is wrapped in hex() so it can be compared as a string.
      val encryptQuery = sql(s"SELECT hex(aes_encrypt('$plaintext', $params))")
      // Result: the hex-encoded ciphertext.
      checkAnswer(encryptQuery, Row(expectedHex))
      // Data type: the hex string uses the configured collation.
      val expectedType = StringType(collation)
      assert(encryptQuery.schema.fields.head.dataType.sameType(expectedType))
    }
  }
}

test("Support AesDecrypt misc expression with collation") {
  // Each tuple is (hex-encoded ciphertext, session default collation, remaining
  // aes_decrypt arguments as SQL text, expected plaintext).
  val decryptCases = Seq(
    ("8DE7DB79A23F3E8ED530994DDEA98913",
      "UTF8_BINARY", "'1234567890abcdef', 'ECB'", "Spark"),
    ("8DE7DB79A23F3E8ED530994DDEA98913",
      "UTF8_BINARY_LCASE", "'1234567890abcdef', 'ECB', 'DEFAULT', ''", "Spark"),
    ("00000000000000000000000046596B2DE09C729FE48A0F81A00A4E7101DABEB61D",
      "UNICODE", "'1234567890abcdef', 'GCM', 'DEFAULT'", "Spark"),
    ("000000000000000000000000000000008DE7DB79A23F3E8ED530994DDEA98913",
      "UNICODE_CI", "'1234567890abcdef', 'CBC', 'DEFAULT'", "Spark")
  )
  for ((cipherHex, collation, params, plaintext) <- decryptCases) {
    withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
      val decryptQuery = sql(s"SELECT aes_decrypt(unhex('$cipherHex'), $params)")
      // Result: compared against the expected plaintext converted to binary in SQL.
      val expectedRows = sql(s"SELECT to_binary('$plaintext', 'utf-8')")
      checkAnswer(decryptQuery, expectedRows)
      // Data type: the result stays binary under every collation.
      assert(decryptQuery.schema.fields.head.dataType.sameType(BinaryType))
    }
  }
}

test("Support Mask expression with collation") {
// Supported collations
case class MaskTestCase[R](i: String, u: String, l: String, d: String, o: String, c: String,
Expand Down

0 comments on commit 7233540

Please sign in to comment.