Skip to content

Commit

Permalink
[SPARK-8241][SQL] string function: concat_ws.
Browse files Browse the repository at this point in the history
I also changed the semantics of concat w.r.t. null back to the same behavior as Hive.
That is to say, concat now returns null if any input is null.

Author: Reynold Xin <rxin@databricks.com>

Closes #7504 from rxin/concat_ws and squashes the following commits:

83fd950 [Reynold Xin] Fixed type casting.
3ae85f7 [Reynold Xin] Write null better.
cdc7be6 [Reynold Xin] Added code generation for pure string mode.
a61c4e4 [Reynold Xin] Updated comments.
2d51406 [Reynold Xin] [SPARK-8241][SQL] string function: concat_ws.
  • Loading branch information
rxin committed Jul 19, 2015
1 parent 7a81245 commit 163e3f1
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ object FunctionRegistry {
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[FormatNumber]("format_number"),
Expand Down Expand Up @@ -211,7 +212,10 @@ object FunctionRegistry {
val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) {
// If there is a constructor that accepts Seq[Expression], use that one.
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) => throw new AnalysisException(e.getMessage)
}
} else {
// Otherwise, find a constructor that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
Expand All @@ -221,7 +225,10 @@ object FunctionRegistry {
case Failure(e) =>
throw new AnalysisException(s"Invalid number of arguments for function $name")
}
f.newInstance(expressions : _*).asInstanceOf[Expression]
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) => throw new AnalysisException(e.getMessage)
}
}
}
(name, builder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String

/**
* An expression that concatenates multiple input strings into a single string.
* Input expressions that are evaluated to nulls are skipped.
*
* For example, `concat("a", null, "b")` is evaluated to `"ab"`.
*
* Note that this is different from Hive since Hive outputs null if any input is null.
* We never output null.
* If any input is null, concat returns null.
*/
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
override def dataType: DataType = StringType

override def nullable: Boolean = false
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = {
Expand All @@ -56,15 +51,76 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
val inputs = evals.map { eval =>
s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
boolean ${ev.isNull} = false;
UTF8String ${ev.primitive} = UTF8String.concat($inputs);
if (${ev.primitive} == null) {
${ev.isNull} = true;
}
"""
}
}


/**
 * Joins strings, and the elements of string arrays, into one string, placing a separator
 * between consecutive values. The separator is supplied as the first child.
 *
 * The result is null when the separator is null. Null values among the remaining inputs are
 * skipped instead of being propagated.
 */
case class ConcatWs(children: Seq[Expression])
  extends Expression with ImplicitCastInputTypes with CodegenFallback {

  require(children.nonEmpty, s"$prettyName requires at least one argument.")

  override def prettyName: String = "concat_ws"

  override def dataType: DataType = StringType

  /** Expects a string separator followed by any number of strings or arrays of strings. */
  override def inputTypes: Seq[AbstractDataType] = {
    val valueType = TypeCollection(ArrayType(StringType), StringType)
    val valueTypes: Seq[AbstractDataType] = Seq.fill(children.size - 1)(valueType)
    StringType +: valueTypes
  }

  // Nullability follows the separator (the first child) alone.
  override def nullable: Boolean = children.head.nullable
  override def foldable: Boolean = children.forall(_.foldable)

  /**
   * Interpreted evaluation: flattens every child's value (a string, an array of strings, or
   * null) into a single sequence, then hands the first element to `UTF8String.concatWs` as the
   * separator and the rest as the values to join.
   */
  override def eval(input: InternalRow): Any = {
    val flattened: Seq[UTF8String] = children.flatMap { expr =>
      expr.eval(input) match {
        // A null child still contributes one (null) slot.
        case null => Seq(null.asInstanceOf[UTF8String])
        case str: UTF8String => Seq(str)
        case elems: Seq[_] => elems.asInstanceOf[Seq[UTF8String]]
      }
    }
    UTF8String.concatWs(flattened.head, flattened.tail : _*)
  }

  /**
   * Generates a direct call to `UTF8String.concatWs` when every child is a plain string;
   * otherwise defers to the code produced by the parent implementation.
   */
  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val onlyStrings = children.forall(_.dataType == StringType)
    if (!onlyStrings) {
      // Contains a mix of strings and array<string>s. Fall back to interpreted mode for now.
      super.genCode(ctx, ev)
    } else {
      // Every child is a string, so the arguments can be listed out one by one.
      val childEvals = children.map(_.gen(ctx))

      // The cast on the null literal gives each ternary argument an explicit UTF8String type.
      val args = childEvals.map { child =>
        s"${child.isNull} ? (UTF8String)null : ${child.primitive}"
      }.mkString(", ")

      val childCode = childEvals.map(_.code).mkString("\n")

      // The template lines are kept flush-left so the emitted text matches line for line.
      childCode + s"""
UTF8String ${ev.primitive} = UTF8String.concatWs($args);
boolean ${ev.isNull} = ${ev.primitive} == null;
"""
    }
  }
}


trait StringRegexExpression extends ImplicitCastInputTypes {
self: BinaryExpression =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = this

override private[sql] def acceptsType(other: DataType): Boolean = this == other
override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(NullType, IntegerType, IntegerType)
shouldCast(NullType, DecimalType, DecimalType.Unlimited)

// TODO: write the entire implicit cast table out for test cases.
shouldCast(ByteType, IntegerType, IntegerType)
shouldCast(IntegerType, IntegerType, IntegerType)
shouldCast(IntegerType, LongType, LongType)
Expand Down Expand Up @@ -86,6 +85,16 @@ class HiveTypeCoercionSuite extends PlanTest {
DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
shouldCast(tpe, NumericType, tpe)
}

shouldCast(
ArrayType(StringType, false),
TypeCollection(ArrayType(StringType), StringType),
ArrayType(StringType, false))

shouldCast(
ArrayType(StringType, true),
TypeCollection(ArrayType(StringType), StringType),
ArrayType(StringType, true))
}

test("ineligible implicit type cast") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("concat") {
def testConcat(inputs: String*): Unit = {
val expected = inputs.filter(_ != null).mkString
val expected = if (inputs.contains(null)) null else inputs.mkString
checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
}

Expand All @@ -46,6 +46,35 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("concat_ws") {
  // Wraps the separator and each input in a literal (strings and nulls as StringType,
  // sequences as array<string>) and checks what ConcatWs evaluates to.
  def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
    val values = inputs.map {
      case null => Literal.create(null, StringType)
      case str: String => Literal.create(str, StringType)
      case seq: Seq[_] => Literal.create(seq, ArrayType(StringType))
    }
    val separator = Literal.create(sep, StringType)
    checkEvaluation(ConcatWs(separator +: values), expected, EmptyRow)
  }

  // scalastyle:off
  // non ascii characters are not allowed in the code, so we disable the scalastyle here.

  // Separator-only calls and plain string inputs.
  testConcatWs(null, null)
  testConcatWs(null, null, "a", "b")
  testConcatWs("", "")
  testConcatWs("ab", "哈哈", "ab")
  testConcatWs("a哈哈b", "哈哈", "a", "b")
  testConcatWs("a哈哈b", "哈哈", "a", null, "b")
  testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c")

  // Array inputs, alone and mixed with strings and nulls.
  testConcatWs("ab", "哈哈", Seq("ab"))
  testConcatWs("a哈哈b", "哈哈", Seq("a", "b"))
  testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d"))
  testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String])
  testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null))
  // scalastyle:on
}

test("StringComparison") {
val row = create_row("abc", null)
val c1 = 'a.string.at(0)
Expand Down
24 changes: 24 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1732,6 +1732,30 @@ object functions {
concat((columnName +: columnNames).map(Column.apply): _*)
}

/**
 * Joins the given columns into a single string column, inserting `sep` between the values.
 *
 * @group string_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat_ws(sep: String, exprs: Column*): Column = {
  // The separator goes in as the first child, wrapped as a string literal.
  val separator = Literal.create(sep, StringType)
  val inputs = exprs.map(_.expr)
  ConcatWs(separator +: inputs)
}

/**
 * Joins the named columns into a single string column, inserting `sep` between the values.
 *
 * Same as the `Column`-based concat_ws, except that the inputs are given by column name.
 *
 * @group string_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
  val columns = (columnName +: columnNames).map(name => Column(name))
  concat_ws(sep, columns : _*)
}

/**
* Computes the length of a given string / binary value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

checkAnswer(
df.select(concat($"a", $"b", $"c")),
Row("ab"))
df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
Row("ab", null))

checkAnswer(
df.selectExpr("concat(a, b, c)"),
Row("ab"))
df.selectExpr("concat(a, b)", "concat(a, b, c)"),
Row("ab", null))
}

test("string concat_ws") {
  // One row whose third column is null; both API forms must give the same answer.
  val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
  val expected = Row("a||b")

  // Column-based function.
  checkAnswer(df.select(concat_ws("||", $"a", $"b", $"c")), expected)

  // SQL expression string.
  checkAnswer(df.selectExpr("concat_ws('||', a, b, c)"), expected)
}

test("string Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"timestamp_2",
"timestamp_udf",

// Hive outputs NULL if any concat input has null. We never output null for concat.
"udf_concat",

// Unlike Hive, we do support log base in (0, 1.0], therefore disable this
"udf7"
)
Expand Down Expand Up @@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_case",
"udf_ceil",
"udf_ceiling",
"udf_concat",
"udf_concat_insert1",
"udf_concat_insert2",
"udf_concat_ws",
Expand Down
58 changes: 52 additions & 6 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Original file line number Diff line number Diff line change
Expand Up @@ -397,33 +397,79 @@ public UTF8String lpad(int len, UTF8String pad) {
}

/**
 * Concatenates the input strings, in order, into a single string.
 *
 * Returns null as soon as any element of {@code inputs} is null. When the {@code inputs}
 * array itself is null, an empty string is returned.
 */
public static UTF8String concat(UTF8String... inputs) {
  if (inputs == null) {
    return fromBytes(new byte[0]);
  }

  // First pass: bail out on a null element, otherwise add up the byte lengths.
  int totalLength = 0;
  for (UTF8String input : inputs) {
    if (input == null) {
      return null;
    }
    totalLength += input.numBytes;
  }

  // Second pass: copy each input's bytes back to back into one freshly allocated array.
  final byte[] result = new byte[totalLength];
  int offset = 0;
  for (UTF8String input : inputs) {
    final int len = input.numBytes;
    PlatformDependent.copyMemory(
      input.base, input.offset,
      result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
      len);
    offset += len;
  }
  return fromBytes(result);
}

/**
* Concatenates input strings together into a single string using the separator.
* A null input is skipped. For example, concatWs(",", "a", null, "c") would yield "a,c".
*/
public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
if (separator == null) {
return null;
}

int numInputBytes = 0; // total number of bytes from the inputs
int numInputs = 0; // number of non-null inputs
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
numInputBytes += inputs[i].numBytes;
numInputs++;
}
}

if (numInputs == 0) {
// Return an empty string if there is no input, or all the inputs are null.
return fromBytes(new byte[0]);
}

// Allocate a new byte array, and copy the inputs one by one into it.
// The size of the new array is the size of all inputs, plus the separators.
final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
int offset = 0;

for (int i = 0, j = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
int len = inputs[i].numBytes;
PlatformDependent.copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;

j++;
// Add separator if this is not the last input.
if (j < numInputs) {
PlatformDependent.copyMemory(
separator.base, separator.offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
separator.numBytes);
offset += separator.numBytes;
}
}
}
return fromBytes(result);
Expand Down
Loading

0 comments on commit 163e3f1

Please sign in to comment.