Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8241][SQL] string function: concat_ws. #7504

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -153,6 +153,7 @@ object FunctionRegistry {
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[FormatNumber]("format_number"),
Expand Down Expand Up @@ -211,7 +212,10 @@ object FunctionRegistry {
val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) {
// If there is an apply method that accepts Seq[Expression], use that one.
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) => throw new AnalysisException(e.getMessage)
}
} else {
// Otherwise, find an ctor method that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
Expand All @@ -221,7 +225,10 @@ object FunctionRegistry {
case Failure(e) =>
throw new AnalysisException(s"Invalid number of arguments for function $name")
}
f.newInstance(expressions : _*).asInstanceOf[Expression]
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) => throw new AnalysisException(e.getMessage)
}
}
}
(name, builder)
Expand Down
Expand Up @@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String

/**
* An expression that concatenates multiple input strings into a single string.
* Input expressions that are evaluated to nulls are skipped.
*
* For example, `concat("a", null, "b")` is evaluated to `"ab"`.
*
* Note that this is different from Hive since Hive outputs null if any input is null.
* We never output null.
* If any input is null, concat returns null.
*/
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
override def dataType: DataType = StringType

override def nullable: Boolean = false
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = {
Expand All @@ -56,15 +51,76 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
val inputs = evals.map { eval =>
s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
boolean ${ev.isNull} = false;
UTF8String ${ev.primitive} = UTF8String.concat($inputs);
if (${ev.primitive} == null) {
${ev.isNull} = true;
}
"""
}
}


/**
* An expression that concatenates multiple input strings or array of strings into a single string,
* using a given separator (the first child).
*
* Returns null if the separator is null. Otherwise, concat_ws skips all null values.
*/
case class ConcatWs(children: Seq[Expression])
extends Expression with ImplicitCastInputTypes with CodegenFallback {

require(children.nonEmpty, s"$prettyName requires at least one argument.")

override def prettyName: String = "concat_ws"

/** The 1st child (separator) is str, and rest are either str or array of str. */
override def inputTypes: Seq[AbstractDataType] = {
val arrayOrStr = TypeCollection(ArrayType(StringType), StringType)
StringType +: Seq.fill(children.size - 1)(arrayOrStr)
}

override def dataType: DataType = StringType

override def nullable: Boolean = children.head.nullable
override def foldable: Boolean = children.forall(_.foldable)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

existing: we could use this as the default one for Expression.


override def eval(input: InternalRow): Any = {
val flatInputs = children.flatMap { child =>
child.eval(input) match {
case s: UTF8String => Iterator(s)
case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
case null => Iterator(null.asInstanceOf[UTF8String])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: Can we just ignore the null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How? It won't match s. Also this thing doesn't compile if I do a wildcard match on s, e.g.

case s: _ => ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean we can return an empty Iterator

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that won't work for the 1st value (separator), unless we special case handle that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That make sense, thanks!

}
}
UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*)
}

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (children.forall(_.dataType == StringType)) {
// All children are strings. In that case we can construct a fixed size array.
val evals = children.map(_.gen(ctx))

val inputs = evals.map { eval =>
s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
}.mkString(", ")

evals.map(_.code).mkString("\n") + s"""
UTF8String ${ev.primitive} = UTF8String.concatWs($inputs);
boolean ${ev.isNull} = ${ev.primitive} == null;
"""
} else {
// Contains a mix of strings and array<string>s. Fall back to interpreted mode for now.
super.genCode(ctx, ev)
}
}
}


trait StringRegexExpression extends ImplicitCastInputTypes {
self: BinaryExpression =>

Expand Down
Expand Up @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = this

override private[sql] def acceptsType(other: DataType): Boolean = this == other
override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @cloud-fan

This is changed to test sameType in order to support ArrayType(StringType)

}


Expand Down
Expand Up @@ -37,7 +37,6 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(NullType, IntegerType, IntegerType)
shouldCast(NullType, DecimalType, DecimalType.Unlimited)

// TODO: write the entire implicit cast table out for test cases.
shouldCast(ByteType, IntegerType, IntegerType)
shouldCast(IntegerType, IntegerType, IntegerType)
shouldCast(IntegerType, LongType, LongType)
Expand Down Expand Up @@ -86,6 +85,16 @@ class HiveTypeCoercionSuite extends PlanTest {
DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
shouldCast(tpe, NumericType, tpe)
}

shouldCast(
ArrayType(StringType, false),
TypeCollection(ArrayType(StringType), StringType),
ArrayType(StringType, false))

shouldCast(
ArrayType(StringType, true),
TypeCollection(ArrayType(StringType), StringType),
ArrayType(StringType, true))
}

test("ineligible implicit type cast") {
Expand Down
Expand Up @@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("concat") {
def testConcat(inputs: String*): Unit = {
val expected = inputs.filter(_ != null).mkString
val expected = if (inputs.contains(null)) null else inputs.mkString
checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
}

Expand All @@ -46,6 +46,35 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("concat_ws") {
def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
val inputExprs = inputs.map {
case s: Seq[_] => Literal.create(s, ArrayType(StringType))
case null => Literal.create(null, StringType)
case s: String => Literal.create(s, StringType)
}
val sepExpr = Literal.create(sep, StringType)
checkEvaluation(ConcatWs(sepExpr +: inputExprs), expected, EmptyRow)
}

// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
testConcatWs(null, null)
testConcatWs(null, null, "a", "b")
testConcatWs("", "")
testConcatWs("ab", "哈哈", "ab")
testConcatWs("a哈哈b", "哈哈", "a", "b")
testConcatWs("a哈哈b", "哈哈", "a", null, "b")
testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c")

testConcatWs("ab", "哈哈", Seq("ab"))
testConcatWs("a哈哈b", "哈哈", Seq("a", "b"))
testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d"))
testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String])
testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null))
// scalastyle:on
}

test("StringComparison") {
val row = create_row("abc", null)
val c1 = 'a.string.at(0)
Expand Down
24 changes: 24 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Expand Up @@ -1732,6 +1732,30 @@ object functions {
concat((columnName +: columnNames).map(Column.apply): _*)
}

/**
* Concatenates input strings together into a single string, using the given separator.
*
* @group string_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def concat_ws(sep: String, exprs: Column*): Column = {
ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr))
}

/**
* Concatenates input strings together into a single string, using the given separator.
*
* This is the variant of concat_ws that takes in the column names.
*
* @group string_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*)
}

/**
* Computes the length of a given string / binary value.
*
Expand Down
Expand Up @@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

checkAnswer(
df.select(concat($"a", $"b", $"c")),
Row("ab"))
df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
Row("ab", null))

checkAnswer(
df.selectExpr("concat(a, b, c)"),
Row("ab"))
df.selectExpr("concat(a, b)", "concat(a, b, c)"),
Row("ab", null))
}

test("string concat_ws") {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

checkAnswer(
df.select(concat_ws("||", $"a", $"b", $"c")),
Row("a||b"))

checkAnswer(
df.selectExpr("concat_ws('||', a, b, c)"),
Row("a||b"))
}

test("string Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
Expand Down
Expand Up @@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"timestamp_2",
"timestamp_udf",

// Hive outputs NULL if any concat input has null. We never output null for concat.
"udf_concat",

// Unlike Hive, we do support log base in (0, 1.0], therefore disable this
"udf7"
)
Expand Down Expand Up @@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_case",
"udf_ceil",
"udf_ceiling",
"udf_concat",
"udf_concat_insert1",
"udf_concat_insert2",
"udf_concat_ws",
Expand Down
58 changes: 52 additions & 6 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Expand Up @@ -397,33 +397,79 @@ public UTF8String lpad(int len, UTF8String pad) {
}

/**
* Concatenates input strings together into a single string. A null input is skipped.
* For example, concat("a", null, "c") would yield "ac".
* Concatenates input strings together into a single string. Returns null if any input is null.
*/
public static UTF8String concat(UTF8String... inputs) {
if (inputs == null) {
return fromBytes(new byte[0]);
}

// Compute the total length of the result.
int totalLength = 0;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
totalLength += inputs[i].numBytes;
} else {
return null;
}
}

// Allocate a new byte array, and copy the inputs one by one into it.
final byte[] result = new byte[totalLength];
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
int len = inputs[i].numBytes;
PlatformDependent.copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;
}
return fromBytes(result);
}

/**
* Concatenates input strings together into a single string using the separator.
* A null input is skipped. For example, concat(",", "a", null, "c") would yield "a,c".
*/
public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
if (separator == null) {
return null;
}

int numInputBytes = 0; // total number of bytes from the inputs
int numInputs = 0; // number of non-null inputs
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
numInputBytes += inputs[i].numBytes;
numInputs++;
}
}

if (numInputs == 0) {
// Return an empty string if there is no input, or all the inputs are null.
return fromBytes(new byte[0]);
}

// Allocate a new byte array, and copy the inputs one by one into it.
// The size of the new array is the size of all inputs, plus the separators.
final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
int offset = 0;

for (int i = 0, j = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
int len = inputs[i].numBytes;
PlatformDependent.copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;

j++;
// Add separator if this is not the last input.
if (j < numInputs) {
PlatformDependent.copyMemory(
separator.base, separator.offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
separator.numBytes);
offset += separator.numBytes;
}
}
}
return fromBytes(result);
Expand Down