[SPARK-28412][SQL] ANSI SQL: OVERLAY function support byte array
## What changes were proposed in this pull request?

This is an ANSI SQL feature; the feature ID is `T312`:

```
<binary overlay function> ::=
OVERLAY <left paren> <binary value expression> PLACING <binary value expression>
FROM <start position> [ FOR <string length> ] <right paren>
```

This PR is related to #24918 and adds support for byte arrays.

ref: https://www.postgresql.org/docs/11/functions-binarystring.html

## How was this patch tested?

New unit tests.
Below are some sample runs of this PR in my production environment:
```
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('_', 'utf-8') FROM 6);
Spark_SQL
Time taken: 0.285 s
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('CORE', 'utf-8') FROM 7);
Spark CORE
Time taken: 0.202 s
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('ANSI ', 'utf-8') FROM 7 FOR 0);
Spark ANSI SQL
Time taken: 0.165 s
spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('tructured', 'utf-8') FROM 2 FOR 4);
Structured SQL
Time taken: 0.141 s
```
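
The same results are reachable through the Scala DataFrame API via the Column-based `overlay` overloads this PR introduces. A minimal sketch (the session/DataFrame setup is assumed; `decode` is used only to render the `BinaryType` result as a string):

```scala
import org.apache.spark.sql.functions.{decode, encode, lit, overlay}

// Assumes an active SparkSession named `spark`.
import spark.implicits._

val df = Seq("Spark SQL").toDF("a")
// overlay(src: Column, replace: Column, pos: Column) — binary inputs yield a BinaryType column.
df.select(decode(overlay(encode($"a", "utf-8"), encode(lit("_"), "utf-8"), lit(6)), "utf-8"))
  .show() // expected: Spark_SQL
```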

Closes #25172 from beliefer/ansi-overlay-byte-array.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
2 people authored and maropu committed Sep 9, 2019
1 parent bdc1598 commit aafce7e
Showing 4 changed files with 157 additions and 24 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -472,6 +472,19 @@ object Overlay {
builder.append(input.substringSQL(pos + length, Int.MaxValue))
builder.build()
}

def calculate(input: Array[Byte], replace: Array[Byte], pos: Int, len: Int): Array[Byte] = {
// If you specify length, it must be a positive whole number or zero.
// Otherwise it will be ignored.
// The default value for length is the length of replace.
val length = if (len >= 0) {
len
} else {
replace.length
}
ByteArray.concat(ByteArray.subStringSQL(input, 1, pos - 1),
replace, ByteArray.subStringSQL(input, pos + length, Int.MaxValue))
}
}

// scalastyle:off line.size.limit
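
For reference, the byte-array semantics above can be sketched with plain Scala collections. This is an illustrative stand-in, not Spark's `ByteArray` utilities: positions are 1-based, a negative position counts from the end (SQL substring semantics), and a negative length falls back to the length of `replace`:

```scala
// Illustrative sketch only; the real code uses ByteArray.concat / ByteArray.subStringSQL.
def subStringSQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = {
  if (len <= 0) return Array.emptyByteArray
  val start = if (pos > 0) pos - 1 else if (pos < 0) bytes.length + pos else 0
  val from = math.max(start, 0)
  // Go through Long so that len = Int.MaxValue cannot overflow.
  val until = math.min(bytes.length.toLong, from.toLong + len).toInt
  if (from >= until) Array.emptyByteArray else bytes.slice(from, until)
}

def overlayBytes(input: Array[Byte], replace: Array[Byte], pos: Int, len: Int = -1): Array[Byte] = {
  val length = if (len >= 0) len else replace.length
  subStringSQL(input, 1, pos - 1) ++ replace ++ subStringSQL(input, pos + length, Int.MaxValue)
}

overlayBytes(Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9), Array[Byte](-1), 6)
// => Array(1, 2, 3, 4, 5, -1, 7, 8, 9)
```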
@@ -487,6 +500,14 @@ object Overlay {
Spark ANSI SQL
> SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4);
Structured SQL
> SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('_', 'utf-8') FROM 6);
Spark_SQL
> SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('CORE', 'utf-8') FROM 7);
Spark CORE
> SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('ANSI ', 'utf-8') FROM 7 FOR 0);
Spark ANSI SQL
> SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('tructured', 'utf-8') FROM 2 FOR 4);
Structured SQL
""")
// scalastyle:on line.size.limit
case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression)
@@ -496,19 +517,42 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
this(str, replace, pos, Literal.create(-1, IntegerType))
}

- override def dataType: DataType = StringType
+ override def dataType: DataType = input.dataType

- override def inputTypes: Seq[AbstractDataType] =
-   Seq(StringType, StringType, IntegerType, IntegerType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
+   TypeCollection(StringType, BinaryType), IntegerType, IntegerType)

override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil

override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes()
if (inputTypeCheck.isSuccess) {
TypeUtils.checkForSameTypeInputExpr(
input.dataType :: replace.dataType :: Nil, s"function $prettyName")
} else {
inputTypeCheck
}
}

private lazy val replaceFunc = input.dataType match {
case StringType =>
(inputEval: Any, replaceEval: Any, posEval: Int, lenEval: Int) => {
Overlay.calculate(
inputEval.asInstanceOf[UTF8String],
replaceEval.asInstanceOf[UTF8String],
posEval, lenEval)
}
case BinaryType =>
(inputEval: Any, replaceEval: Any, posEval: Int, lenEval: Int) => {
Overlay.calculate(
inputEval.asInstanceOf[Array[Byte]],
replaceEval.asInstanceOf[Array[Byte]],
posEval, lenEval)
}
}

override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, lenEval: Any): Any = {
- val inputStr = inputEval.asInstanceOf[UTF8String]
- val replaceStr = replaceEval.asInstanceOf[UTF8String]
- val position = posEval.asInstanceOf[Int]
- val length = lenEval.asInstanceOf[Int]
- Overlay.calculate(inputStr, replaceStr, position, length)
+ replaceFunc(inputEval, replaceEval, posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
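Note the shape of the change above: `replaceFunc` resolves the per-type evaluation function once, lazily, from `input.dataType`, so `nullSafeEval` pays no per-row pattern match. A standalone sketch of that dispatch pattern (the types and class here are illustrative, not Spark's):

```scala
// Illustrative sketch of lazy type dispatch: match once, then call a cached function per row.
sealed trait DType
case object StrType extends DType
case object BinType extends DType

final class Concat2(dt: DType) {
  private lazy val f: (Any, Any) => Any = dt match {
    case StrType => (a, b) => a.asInstanceOf[String] + b.asInstanceOf[String]
    case BinType => (a, b) => a.asInstanceOf[Array[Byte]] ++ b.asInstanceOf[Array[Byte]]
  }
  def eval(a: Any, b: Any): Any = f(a, b) // no per-row type test
}
```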
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
+ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._

@@ -428,7 +429,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("overlay") {
test("overlay for string") {
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
Literal.create(6, IntegerType)), "Spark_SQL")
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("CORE"),
@@ -450,6 +451,75 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(new Overlay(Literal("Spark的SQL"), Literal("_"),
Literal.create(6, IntegerType)), "Spark_SQL")
// scalastyle:on
// position greater than the length of input string
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
Literal.create(10, IntegerType)), "Spark SQL_")
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("_"),
Literal.create(10, IntegerType), Literal.create(4, IntegerType)), "Spark SQL_")
// position is zero
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("__"),
Literal.create(0, IntegerType)), "__park SQL")
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("__"),
Literal.create(0, IntegerType), Literal.create(4, IntegerType)), "__rk SQL")
// position is negative
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("__"),
Literal.create(-10, IntegerType)), "__park SQL")
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("__"),
Literal.create(-10, IntegerType), Literal.create(4, IntegerType)), "__rk SQL")
}

test("overlay for byte array") {
val input = Literal(Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9))
checkEvaluation(new Overlay(input, Literal(Array[Byte](-1)),
Literal.create(6, IntegerType)), Array[Byte](1, 2, 3, 4, 5, -1, 7, 8, 9))
checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1, -1, -1)),
Literal.create(7, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, -1, -1))
checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), Literal.create(7, IntegerType),
Literal.create(0, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, 7, 8, 9))
checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1, -1, -1, -1)),
Literal.create(2, IntegerType), Literal.create(4, IntegerType)),
Array[Byte](1, -1, -1, -1, -1, -1, 6, 7, 8, 9))

val nullInput = Literal.create(null, BinaryType)
checkEvaluation(new Overlay(nullInput, Literal(Array[Byte](-1)),
Literal.create(6, IntegerType)), null)
checkEvaluation(new Overlay(nullInput, Literal(Array[Byte](-1, -1, -1, -1)),
Literal.create(7, IntegerType)), null)
checkEvaluation(Overlay(nullInput, Literal(Array[Byte](-1, -1)),
Literal.create(7, IntegerType), Literal.create(0, IntegerType)), null)
checkEvaluation(Overlay(nullInput, Literal(Array[Byte](-1, -1, -1, -1, -1)),
Literal.create(2, IntegerType), Literal.create(4, IntegerType)), null)
// position greater than the length of input byte array
checkEvaluation(new Overlay(input, Literal(Array[Byte](-1)),
Literal.create(10, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, -1))
checkEvaluation(Overlay(input, Literal(Array[Byte](-1)), Literal.create(10, IntegerType),
Literal.create(4, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, -1))
// position is zero
checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1)),
Literal.create(0, IntegerType)), Array[Byte](-1, -1, 2, 3, 4, 5, 6, 7, 8, 9))
checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), Literal.create(0, IntegerType),
Literal.create(4, IntegerType)), Array[Byte](-1, -1, 4, 5, 6, 7, 8, 9))
// position is negative
checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1)),
Literal.create(-10, IntegerType)), Array[Byte](-1, -1, 2, 3, 4, 5, 6, 7, 8, 9))
checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), Literal.create(-10, IntegerType),
Literal.create(4, IntegerType)), Array[Byte](-1, -1, 4, 5, 6, 7, 8, 9))
}

test("Check Overlay.checkInputDataTypes results") {
assert(new Overlay(Literal("Spark SQL"), Literal("_"),
Literal.create(6, IntegerType)).checkInputDataTypes().isSuccess)
assert(Overlay(Literal("Spark SQL"), Literal("ANSI "), Literal.create(7, IntegerType),
Literal.create(0, IntegerType)).checkInputDataTypes().isSuccess)
assert(new Overlay(Literal.create("Spark SQL".getBytes), Literal.create("_".getBytes),
Literal.create(6, IntegerType)).checkInputDataTypes().isSuccess)
assert(Overlay(Literal.create("Spark SQL".getBytes), Literal.create("ANSI ".getBytes),
Literal.create(7, IntegerType), Literal.create(0, IntegerType))
.checkInputDataTypes().isSuccess)
assert(new Overlay(Literal.create(1), Literal.create(2), Literal.create(0, IntegerType))
.checkInputDataTypes().isFailure)
assert(Overlay(Literal("Spark SQL"), Literal.create(2), Literal.create(7, IntegerType),
Literal.create(0, IntegerType)).checkInputDataTypes().isFailure)
}

test("translate") {
16 changes: 8 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2521,25 +2521,25 @@
}

/**
- * Overlay the specified portion of `src` with `replaceString`,
- * starting from byte position `pos` of `inputString` and proceeding for `len` bytes.
+ * Overlay the specified portion of `src` with `replace`,
+ * starting from byte position `pos` of `src` and proceeding for `len` bytes.
*
* @group string_funcs
* @since 3.0.0
*/
- def overlay(src: Column, replaceString: String, pos: Int, len: Int): Column = withExpr {
-   Overlay(src.expr, lit(replaceString).expr, lit(pos).expr, lit(len).expr)
+ def overlay(src: Column, replace: Column, pos: Column, len: Column): Column = withExpr {
+   Overlay(src.expr, replace.expr, pos.expr, len.expr)
}

/**
- * Overlay the specified portion of `src` with `replaceString`,
- * starting from byte position `pos` of `inputString`.
+ * Overlay the specified portion of `src` with `replace`,
+ * starting from byte position `pos` of `src`.
*
* @group string_funcs
* @since 3.0.0
*/
- def overlay(src: Column, replaceString: String, pos: Int): Column = withExpr {
-   new Overlay(src.expr, lit(replaceString).expr, lit(pos).expr)
+ def overlay(src: Column, replace: Column, pos: Column): Column = withExpr {
+   new Overlay(src.expr, replace.expr, pos.expr)
}

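Note that the `functions.overlay` overloads change shape here: the earlier versions took `replaceString: String` and `pos`/`len` as `Int`, while the new ones take `Column`s, so literal arguments now go through `lit(...)`. A short migration sketch (session setup assumed):

```scala
import org.apache.spark.sql.functions.{lit, overlay}

// Assumes an active SparkSession named `spark`.
import spark.implicits._

val df = Seq("Spark SQL").toDF("a")
// Before: overlay($"a", "_", 6)
// After:
df.select(overlay($"a", lit("_"), lit(6))).show() // expected: Spark_SQL
```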
sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -129,18 +129,37 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
Row("AQIDBA==", bytes))
}

test("overlay function") {
test("string overlay function") {
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
- val df = Seq(("Spark SQL", "Spark的SQL")).toDF("a", "b")
- checkAnswer(df.select(overlay($"a", "_", 6)), Row("Spark_SQL"))
- checkAnswer(df.select(overlay($"a", "CORE", 7)), Row("Spark CORE"))
- checkAnswer(df.select(overlay($"a", "ANSI ", 7, 0)), Row("Spark ANSI SQL"))
- checkAnswer(df.select(overlay($"a", "tructured", 2, 4)), Row("Structured SQL"))
- checkAnswer(df.select(overlay($"b", "_", 6)), Row("Spark_SQL"))
+ val df = Seq(("Spark SQL", "Spark的SQL", "_", "CORE", "ANSI ", "tructured", 6, 7, 0, 2, 4)).
+   toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k")
+ checkAnswer(df.select(overlay($"a", $"c", $"g")), Row("Spark_SQL"))
+ checkAnswer(df.select(overlay($"a", $"d", $"h")), Row("Spark CORE"))
+ checkAnswer(df.select(overlay($"a", $"e", $"h", $"i")), Row("Spark ANSI SQL"))
+ checkAnswer(df.select(overlay($"a", $"f", $"j", $"k")), Row("Structured SQL"))
+ checkAnswer(df.select(overlay($"b", $"c", $"g")), Row("Spark_SQL"))
// scalastyle:on
}

test("binary overlay function") {
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq((
Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9),
Array[Byte](-1),
Array[Byte](-1, -1, -1, -1),
Array[Byte](-1, -1),
Array[Byte](-1, -1, -1, -1, -1),
6, 7, 0, 2, 4)).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
checkAnswer(df.select(overlay($"a", $"b", $"f")), Row(Array[Byte](1, 2, 3, 4, 5, -1, 7, 8, 9)))
checkAnswer(df.select(overlay($"a", $"c", $"g")),
Row(Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, -1, -1)))
checkAnswer(df.select(overlay($"a", $"d", $"g", $"h")),
Row(Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, 7, 8, 9)))
checkAnswer(df.select(overlay($"a", $"e", $"i", $"j")),
Row(Array[Byte](1, -1, -1, -1, -1, -1, 6, 7, 8, 9)))
}

test("string / binary substring function") {
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
