From aafce7ebffe1acd8f6022f208beaa9ec6c9f7592 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Tue, 10 Sep 2019 08:16:18 +0900
Subject: [PATCH] [SPARK-28412][SQL] ANSI SQL: OVERLAY function support byte array

## What changes were proposed in this pull request?

This is an ANSI SQL feature; the feature ID is `T312`:

```
<binary overlay function> ::=
  OVERLAY <left paren> <binary value expression>
  PLACING <binary value expression>
  FROM <start position>
  [ FOR <string length> ] <right paren>
```

This PR is related to https://github.com/apache/spark/pull/24918 and adds support for byte arrays.

ref: https://www.postgresql.org/docs/11/functions-binarystring.html

## How was this patch tested?

New unit tests. Some sample runs of this PR in my production environment are shown below:

```
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('_', 'utf-8') FROM 6);
Spark_SQL
Time taken: 0.285 s
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('CORE', 'utf-8') FROM 7);
Spark CORE
Time taken: 0.202 s
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('ANSI ', 'utf-8') FROM 7 FOR 0);
Spark ANSI SQL
Time taken: 0.165 s
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('tructured', 'utf-8') FROM 2 FOR 4);
Structured SQL
Time taken: 0.141 s
```

Closes #25172 from beliefer/ansi-overlay-byte-array.

Lead-authored-by: gengjiaan
Co-authored-by: Jiaan Geng
Signed-off-by: Takeshi Yamamuro
---
 .../expressions/stringExpressions.scala       | 60 +++++++++++++---
 .../expressions/StringExpressionsSuite.scala  | 72 ++++++++++++++++++-
 .../org/apache/spark/sql/functions.scala      | 16 ++---
 .../spark/sql/StringFunctionsSuite.scala      | 33 +++++++--
 4 files changed, 157 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index d7a5fb27a3d56..e4847e9cec3f0 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -472,6 +472,19 @@ object Overlay {
     builder.append(input.substringSQL(pos + length, Int.MaxValue))
     builder.build()
   }
+
+  def calculate(input: Array[Byte], replace: Array[Byte], pos: Int, len: Int): Array[Byte] = {
+    // If you specify length, it must be a positive whole number or zero.
+    // Otherwise it will be ignored.
+    // The default value for length is the length of replace.
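+    // Note that pos is 1-based, following SQL OVERLAY semantics: the result is the
+    // first (pos - 1) bytes of input, then replace, then input from pos + length onward.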
+ val length = if (len >= 0) { + len + } else { + replace.length + } + ByteArray.concat(ByteArray.subStringSQL(input, 1, pos - 1), + replace, ByteArray.subStringSQL(input, pos + length, Int.MaxValue)) + } } // scalastyle:off line.size.limit @@ -487,6 +500,14 @@ object Overlay { Spark ANSI SQL > SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4); Structured SQL + > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('_', 'utf-8') FROM 6); + Spark_SQL + > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('CORE', 'utf-8') FROM 7); + Spark CORE + > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('ANSI ', 'utf-8') FROM 7 FOR 0); + Spark ANSI SQL + > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('tructured', 'utf-8') FROM 2 FOR 4); + Structured SQL """) // scalastyle:on line.size.limit case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression) @@ -496,19 +517,42 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: this(str, replace, pos, Literal.create(-1, IntegerType)) } - override def dataType: DataType = StringType + override def dataType: DataType = input.dataType - override def inputTypes: Seq[AbstractDataType] = - Seq(StringType, StringType, IntegerType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType), + TypeCollection(StringType, BinaryType), IntegerType, IntegerType) override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil + override def checkInputDataTypes(): TypeCheckResult = { + val inputTypeCheck = super.checkInputDataTypes() + if (inputTypeCheck.isSuccess) { + TypeUtils.checkForSameTypeInputExpr( + input.dataType :: replace.dataType :: Nil, s"function $prettyName") + } else { + inputTypeCheck + } + } + + private lazy val replaceFunc = input.dataType match { + case StringType => + (inputEval: Any, replaceEval: Any, posEval: Int, lenEval: Int) => { + Overlay.calculate( + inputEval.asInstanceOf[UTF8String], + replaceEval.asInstanceOf[UTF8String], + posEval, lenEval) + } + case BinaryType => + (inputEval: Any, replaceEval: Any, posEval: Int, lenEval: Int) => { + Overlay.calculate( + inputEval.asInstanceOf[Array[Byte]], + replaceEval.asInstanceOf[Array[Byte]], + posEval, lenEval) + } + } + override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, lenEval: Any): Any = { - val inputStr = inputEval.asInstanceOf[UTF8String] - val replaceStr = replaceEval.asInstanceOf[UTF8String] - val position = posEval.asInstanceOf[Int] - val length = lenEval.asInstanceOf[Int] - Overlay.calculate(inputStr, replaceStr, position, length) + replaceFunc(inputEval, replaceEval, posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 1b5acf4b0abcc..4308f98d6969a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -428,7 
+429,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } - test("overlay") { + test("overlay for string") { checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"), Literal.create(6, IntegerType)), "Spark_SQL") checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("CORE"), @@ -450,6 +451,75 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(new Overlay(Literal("Spark的SQL"), Literal("_"), Literal.create(6, IntegerType)), "Spark_SQL") // scalastyle:on + // position greater than the length of input string + checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"), + Literal.create(10, IntegerType)), "Spark SQL_") + checkEvaluation(Overlay(Literal("Spark SQL"), Literal("_"), + Literal.create(10, IntegerType), Literal.create(4, IntegerType)), "Spark SQL_") + // position is zero + checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("__"), + Literal.create(0, IntegerType)), "__park SQL") + checkEvaluation(Overlay(Literal("Spark SQL"), Literal("__"), + Literal.create(0, IntegerType), Literal.create(4, IntegerType)), "__rk SQL") + // position is negative + checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("__"), + Literal.create(-10, IntegerType)), "__park SQL") + checkEvaluation(Overlay(Literal("Spark SQL"), Literal("__"), + Literal.create(-10, IntegerType), Literal.create(4, IntegerType)), "__rk SQL") + } + + test("overlay for byte array") { + val input = Literal(Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9)) + checkEvaluation(new Overlay(input, Literal(Array[Byte](-1)), + Literal.create(6, IntegerType)), Array[Byte](1, 2, 3, 4, 5, -1, 7, 8, 9)) + checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1, -1, -1)), + Literal.create(7, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, -1, -1)) + checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), Literal.create(7, IntegerType), + Literal.create(0, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, 7, 8, 9)) + checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1, -1, -1, -1)), + Literal.create(2, IntegerType), Literal.create(4, IntegerType)), + Array[Byte](1, -1, -1, -1, -1, -1, 6, 7, 8, 9)) + + val nullInput = Literal.create(null, BinaryType) + checkEvaluation(new Overlay(nullInput, Literal(Array[Byte](-1)), + Literal.create(6, IntegerType)), null) + checkEvaluation(new Overlay(nullInput, Literal(Array[Byte](-1, -1, -1, -1)), + Literal.create(7, IntegerType)), null) + checkEvaluation(Overlay(nullInput, Literal(Array[Byte](-1, -1)), + Literal.create(7, IntegerType), Literal.create(0, IntegerType)), null) + checkEvaluation(Overlay(nullInput, Literal(Array[Byte](-1, -1, -1, -1, -1)), + Literal.create(2, IntegerType), Literal.create(4, IntegerType)), null) + // position greater than the length of input byte array + checkEvaluation(new Overlay(input, Literal(Array[Byte](-1)), + Literal.create(10, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, -1)) + checkEvaluation(Overlay(input, Literal(Array[Byte](-1)), Literal.create(10, IntegerType), + Literal.create(4, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, -1)) + // position is zero + checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1)), + Literal.create(0, IntegerType)), Array[Byte](-1, -1, 2, 3, 4, 5, 6, 7, 8, 9)) + checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), Literal.create(0, IntegerType), + Literal.create(4, IntegerType)), Array[Byte](-1, -1, 4, 5, 6, 7, 8, 9)) + // position is negative + checkEvaluation(new 
Overlay(input, Literal(Array[Byte](-1, -1)), + Literal.create(-10, IntegerType)), Array[Byte](-1, -1, 2, 3, 4, 5, 6, 7, 8, 9)) + checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), Literal.create(-10, IntegerType), + Literal.create(4, IntegerType)), Array[Byte](-1, -1, 4, 5, 6, 7, 8, 9)) + } + + test("Check Overlay.checkInputDataTypes results") { + assert(new Overlay(Literal("Spark SQL"), Literal("_"), + Literal.create(6, IntegerType)).checkInputDataTypes().isSuccess) + assert(Overlay(Literal("Spark SQL"), Literal("ANSI "), Literal.create(7, IntegerType), + Literal.create(0, IntegerType)).checkInputDataTypes().isSuccess) + assert(new Overlay(Literal.create("Spark SQL".getBytes), Literal.create("_".getBytes), + Literal.create(6, IntegerType)).checkInputDataTypes().isSuccess) + assert(Overlay(Literal.create("Spark SQL".getBytes), Literal.create("ANSI ".getBytes), + Literal.create(7, IntegerType), Literal.create(0, IntegerType)) + .checkInputDataTypes().isSuccess) + assert(new Overlay(Literal.create(1), Literal.create(2), Literal.create(0, IntegerType)) + .checkInputDataTypes().isFailure) + assert(Overlay(Literal("Spark SQL"), Literal.create(2), Literal.create(7, IntegerType), + Literal.create(0, IntegerType)).checkInputDataTypes().isFailure) } test("translate") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6b8127bab1cb4..395f1b4667b1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2521,25 +2521,25 @@ object functions { } /** - * Overlay the specified portion of `src` with `replaceString`, - * starting from byte position `pos` of `inputString` and proceeding for `len` bytes. + * Overlay the specified portion of `src` with `replace`, + * starting from byte position `pos` of `src` and proceeding for `len` bytes. * * @group string_funcs * @since 3.0.0 */ - def overlay(src: Column, replaceString: String, pos: Int, len: Int): Column = withExpr { - Overlay(src.expr, lit(replaceString).expr, lit(pos).expr, lit(len).expr) + def overlay(src: Column, replace: Column, pos: Column, len: Column): Column = withExpr { + Overlay(src.expr, replace.expr, pos.expr, len.expr) } /** - * Overlay the specified portion of `src` with `replaceString`, - * starting from byte position `pos` of `inputString`. + * Overlay the specified portion of `src` with `replace`, + * starting from byte position `pos` of `src`. * * @group string_funcs * @since 3.0.0 */ - def overlay(src: Column, replaceString: String, pos: Int): Column = withExpr { - new Overlay(src.expr, lit(replaceString).expr, lit(pos).expr) + def overlay(src: Column, replace: Column, pos: Column): Column = withExpr { + new Overlay(src.expr, replace.expr, pos.expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 88b3e5ec61f8a..5049df3219959 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -129,18 +129,37 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { Row("AQIDBA==", bytes)) } - test("overlay function") { + test("string overlay function") { // scalastyle:off // non ascii characters are not allowed in the code, so we disable the scalastyle here. 
-    val df = Seq(("Spark SQL", "Spark的SQL")).toDF("a", "b")
-    checkAnswer(df.select(overlay($"a", "_", 6)), Row("Spark_SQL"))
-    checkAnswer(df.select(overlay($"a", "CORE", 7)), Row("Spark CORE"))
-    checkAnswer(df.select(overlay($"a", "ANSI ", 7, 0)), Row("Spark ANSI SQL"))
-    checkAnswer(df.select(overlay($"a", "tructured", 2, 4)), Row("Structured SQL"))
-    checkAnswer(df.select(overlay($"b", "_", 6)), Row("Spark_SQL"))
+    val df = Seq(("Spark SQL", "Spark的SQL", "_", "CORE", "ANSI ", "tructured", 6, 7, 0, 2, 4)).
+      toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k")
+    checkAnswer(df.select(overlay($"a", $"c", $"g")), Row("Spark_SQL"))
+    checkAnswer(df.select(overlay($"a", $"d", $"h")), Row("Spark CORE"))
+    checkAnswer(df.select(overlay($"a", $"e", $"h", $"i")), Row("Spark ANSI SQL"))
+    checkAnswer(df.select(overlay($"a", $"f", $"j", $"k")), Row("Structured SQL"))
+    checkAnswer(df.select(overlay($"b", $"c", $"g")), Row("Spark_SQL"))
     // scalastyle:on
   }

+  test("binary overlay function") {
+    // Column `a` is the 9-byte input; `b`..`e` are replacements; `f`..`j` are positions/lengths.
+    val df = Seq((
+      Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9),
+      Array[Byte](-1),
+      Array[Byte](-1, -1, -1, -1),
+      Array[Byte](-1, -1),
+      Array[Byte](-1, -1, -1, -1, -1),
+      6, 7, 0, 2, 4)).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
+    checkAnswer(df.select(overlay($"a", $"b", $"f")), Row(Array[Byte](1, 2, 3, 4, 5, -1, 7, 8, 9)))
+    checkAnswer(df.select(overlay($"a", $"c", $"g")),
+      Row(Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, -1, -1)))
+    checkAnswer(df.select(overlay($"a", $"d", $"g", $"h")),
+      Row(Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, 7, 8, 9)))
+    checkAnswer(df.select(overlay($"a", $"e", $"i", $"j")),
+      Row(Array[Byte](1, -1, -1, -1, -1, -1, 6, 7, 8, 9)))
+  }
+
   test("string / binary substring function") {
     // scalastyle:off
     // non ascii characters are not allowed in the code, so we disable the scalastyle here.
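As a quick usage sketch (not part of the patch itself): with the reworked `overlay` overloads above, the new binary behavior can be exercised from the DataFrame API as follows. The SparkSession setup, object name, and column name are illustrative only.

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{decode, encode, lit, overlay}

object OverlayDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("overlay-demo").getOrCreate()
    import spark.implicits._

    val df = Seq("Spark SQL").toDF("s")

    // String overlay: replace 4 characters starting at 1-based position 2.
    df.select(overlay($"s", lit("tructured"), lit(2), lit(4))).show() // Structured SQL

    // Binary overlay: the same overload now accepts BinaryType inputs.
    // decode() only renders the binary result back as a readable string.
    df.select(decode(overlay(encode($"s", "utf-8"), encode(lit("_"), "utf-8"), lit(6)), "utf-8"))
      .show() // Spark_SQL

    spark.stop()
  }
}
```

Note that after this change `overlay` takes `Column` arguments rather than raw `String`/`Int` values, so literal replacements and positions must be wrapped with `lit(...)`.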