Skip to content

Commit

Permalink
[SPARK-8241][SQL] string function: concat_ws.
Browse files Browse the repository at this point in the history
I also changed the semantics of concat w.r.t. null back to the same behavior as Hive.
That is to say, concat now returns null if any input is null.
  • Loading branch information
rxin committed Jul 19, 2015
1 parent a803ac3 commit 2d51406
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ object FunctionRegistry {
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[FormatNumber]("format_number"),
Expand Down Expand Up @@ -211,7 +212,10 @@ object FunctionRegistry {
val builder = (expressions: Seq[Expression]) => {
if (varargCtor.isDefined) {
// If there is a constructor that accepts Seq[Expression], use that one.
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) => throw new AnalysisException(e.getMessage)
}
} else {
// Otherwise, find a constructor that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
Expand All @@ -221,7 +225,10 @@ object FunctionRegistry {
case Failure(e) =>
throw new AnalysisException(s"Invalid number of arguments for function $name")
}
f.newInstance(expressions : _*).asInstanceOf[Expression]
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
case Failure(e) => throw new AnalysisException(e.getMessage)
}
}
}
(name, builder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String

/**
* An expression that concatenates multiple input strings into a single string.
* Input expressions that are evaluated to nulls are skipped.
*
* For example, `concat("a", null, "b")` is evaluated to `"ab"`.
*
* Note that this is different from Hive since Hive outputs null if any input is null.
* We never output null.
* If any input is null, concat returns null.
*/
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
override def dataType: DataType = StringType

override def nullable: Boolean = false
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = {
Expand All @@ -56,15 +51,50 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
val inputs = evals.map { eval =>
s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
boolean ${ev.isNull} = false;
UTF8String ${ev.primitive} = UTF8String.concat($inputs);
if (${ev.primitive} == null) {
${ev.isNull} = true;
}
"""
}
}


/**
 * An expression that concatenates multiple inputs into a single string, using the first
 * child as the separator placed between them.
 *
 * The separator must be a string. Every other child may be either a string or an array of
 * strings; array elements are flattened into the list of inputs.
 *
 * Returns null if the separator is null. Null inputs, including null array elements, are
 * skipped.
 */
case class ConcatWs(children: Seq[Expression])
  extends Expression with ImplicitCastInputTypes with CodegenFallback {

  require(children.nonEmpty, s"$prettyName requires at least one argument.")

  override def prettyName: String = "concat_ws"

  override def inputTypes: Seq[AbstractDataType] = {
    val arrayOrStr = TypeCollection(
      ArrayType(StringType, true),
      ArrayType(StringType, false),
      StringType)
    // The separator (first child) has to be a plain string. Allowing an array here would make
    // the first array element act as the separator, or leave no separator at all for an
    // empty array.
    StringType +: Seq.fill(children.size - 1)(arrayOrStr)
  }

  override def dataType: DataType = StringType

  // Only a null separator can produce a null result; null inputs are skipped.
  override def nullable: Boolean = children.head.nullable
  override def foldable: Boolean = children.forall(_.foldable)

  override def eval(input: InternalRow): Any = {
    val separator = children.head.eval(input).asInstanceOf[UTF8String]
    if (separator == null) {
      null
    } else {
      // Flatten string and array-of-string inputs into one sequence. Nulls are kept here and
      // skipped by UTF8String.concatWs.
      val flatInputs = children.tail.flatMap { child =>
        child.eval(input) match {
          case s: UTF8String => Iterator(s)
          case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
          case null => Iterator(null.asInstanceOf[UTF8String])
        }
      }
      UTF8String.concatWs(separator, flatInputs : _*)
    }
  }
}


trait StringRegexExpression extends ImplicitCastInputTypes {
self: BinaryExpression =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("concat") {
def testConcat(inputs: String*): Unit = {
val expected = inputs.filter(_ != null).mkString
val expected = if (inputs.contains(null)) null else inputs.mkString
checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
}

Expand All @@ -46,6 +46,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("concat_ws") {
  // Builds a ConcatWs expression over literal arguments and checks its evaluated result.
  // Each input may be a string, a sequence of strings (turned into an array literal), or null.
  def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
    val separator = Literal.create(sep, StringType)
    val arguments = inputs.map {
      case null => Literal.create(null, StringType)
      case str: String => Literal.create(str, StringType)
      case seq: Seq[_] => Literal.create(seq, ArrayType(StringType))
    }
    checkEvaluation(ConcatWs(separator +: arguments), expected, EmptyRow)
  }

  // scalastyle:off
  // non ascii characters are not allowed in the code, so we disable the scalastyle here.

  // A null separator always yields null.
  testConcatWs(null, null)
  testConcatWs(null, null, "a", "b")

  // String inputs; nulls are skipped.
  testConcatWs("", "")
  testConcatWs("ab", "哈哈", "ab")
  testConcatWs("a哈哈b", "哈哈", "a", "b")
  testConcatWs("a哈哈b", "哈哈", "a", null, "b")
  testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c")

  // Array inputs are flattened; null elements and empty arrays contribute nothing.
  testConcatWs("ab", "哈哈", Seq("ab"))
  testConcatWs("a哈哈b", "哈哈", Seq("a", "b"))
  testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d"))
  testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String])
  testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null))
  // scalastyle:on
}

test("StringComparison") {
val row = create_row("abc", null)
val c1 = 'a.string.at(0)
Expand Down
24 changes: 24 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1732,6 +1732,30 @@ object functions {
concat((columnName +: columnNames).map(Column.apply): _*)
}

/**
 * Joins the given string columns into one string, placing `sep` between consecutive values.
 *
 * @group string_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat_ws(sep: String, exprs: Column*): Column = {
  val separator = Literal.create(sep, StringType)
  val inputs = exprs.map(_.expr)
  ConcatWs(separator +: inputs)
}

/**
 * Joins the named string columns into one string, placing `sep` between consecutive values.
 *
 * This overload of concat_ws accepts column names rather than Column objects.
 *
 * @group string_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
  val columns = (columnName +: columnNames).map(name => Column(name))
  concat_ws(sep, columns : _*)
}

/**
* Computes the length of a given string / binary value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

checkAnswer(
df.select(concat($"a", $"b", $"c")),
Row("ab"))
df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
Row("ab", null))

checkAnswer(
df.selectExpr("concat(a, b, c)"),
Row("ab"))
df.selectExpr("concat(a, b)", "concat(a, b, c)"),
Row("ab", null))
}

test("string concat_ws") {
  // A single row whose third column is null.
  val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
  val expected = Row("a||b")

  // Through the DataFrame function API.
  val viaFunction = df.select(concat_ws("||", $"a", $"b", $"c"))
  checkAnswer(viaFunction, expected)

  // Through a SQL expression string.
  val viaSql = df.selectExpr("concat_ws('||', a, b, c)")
  checkAnswer(viaSql, expected)
}

test("string Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"timestamp_2",
"timestamp_udf",

// Hive outputs NULL if any concat input has null. We never output null for concat.
"udf_concat",

// Unlike Hive, we do support log base in (0, 1.0], therefore disable this
"udf7"
)
Expand Down Expand Up @@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_case",
"udf_ceil",
"udf_ceiling",
"udf_concat",
"udf_concat_insert1",
"udf_concat_insert2",
"udf_concat_ws",
Expand Down
58 changes: 52 additions & 6 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Original file line number Diff line number Diff line change
Expand Up @@ -397,33 +397,79 @@ public UTF8String lpad(int len, UTF8String pad) {
}

/**
* Concatenates input strings together into a single string. Returns null if any input is null.
* For example, concat("a", null, "c") would yield null.
*
* If the inputs array itself is null, an empty string is returned.
*/
public static UTF8String concat(UTF8String... inputs) {
// A null varargs array (as opposed to a null element) yields an empty string.
if (inputs == null) {
return fromBytes(new byte[0]);
}

// Compute the total length of the result.
int totalLength = 0;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
totalLength += inputs[i].numBytes;
} else {
// Any null input makes the whole result null.
return null;
}
}

// Allocate a new byte array, and copy the inputs one by one into it.
// No null check is needed here: the loop above already returned if any input was null.
final byte[] result = new byte[totalLength];
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
int len = inputs[i].numBytes;
PlatformDependent.copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;
}
return fromBytes(result);
}

/**
* Concatenates input strings together into a single string using the separator.
* A null input is skipped. For example, concatWs(",", "a", null, "c") would yield "a,c".
*/
public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
if (separator == null) {
return null;
}

int numInputBytes = 0; // total number of bytes from the inputs
int numInputs = 0; // number of non-null inputs
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
numInputBytes += inputs[i].numBytes;
numInputs++;
}
}

if (numInputs == 0) {
// Return an empty string if there is no input, or all the inputs are null.
return fromBytes(new byte[0]);
}

// Allocate a new byte array, and copy the inputs one by one into it.
// The size of the new array is the size of all inputs, plus the separators.
final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
int offset = 0;

for (int i = 0, j = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
int len = inputs[i].numBytes;
PlatformDependent.copyMemory(
inputs[i].base, inputs[i].offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;

j++;
// Add separator if this is not the last input.
if (j < numInputs) {
PlatformDependent.copyMemory(
separator.base, separator.offset,
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
separator.numBytes);
offset += separator.numBytes;
}
}
}
return fromBytes(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,50 @@ public void upperAndLower() {

@Test
public void concatTest() {
  // All assertions use JUnit's (expected, actual) argument order.

  // No inputs produce an empty string.
  assertEquals(fromString(""), concat());

  // Non-null inputs are joined in order.
  assertEquals(fromString(""), concat(fromString("")));
  assertEquals(fromString("ab"), concat(fromString("ab")));
  assertEquals(fromString("ab"), concat(fromString("a"), fromString("b")));
  assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c")));
  assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头")));

  // Any null input makes the result null. The cast selects a single null element rather
  // than a null varargs array.
  assertEquals(null, concat((UTF8String) null));
  assertEquals(null, concat(fromString("a"), null, fromString("c")));
  assertEquals(null, concat(fromString("a"), null, null));
  assertEquals(null, concat(null, null, null));
}

@Test
public void concatWsTest() {
// Returns null if the separator is null
assertEquals(null, concatWs(null, (UTF8String)null));
assertEquals(null, concatWs(null, fromString("a")));

// If separator is not null, concatWs should skip all null inputs and never return null.
UTF8String sep = fromString("哈哈");
assertEquals(
fromString(""),
concatWs(sep, fromString("")));
assertEquals(
fromString("ab"),
concatWs(sep, fromString("ab")));
assertEquals(
fromString("a哈哈b"),
concatWs(sep, fromString("a"), fromString("b")));
assertEquals(
fromString("a哈哈b哈哈c"),
concatWs(sep, fromString("a"), fromString("b"), fromString("c")));
assertEquals(
fromString("a哈哈c"),
concatWs(sep, fromString("a"), null, fromString("c")));
assertEquals(
fromString("a"),
concatWs(sep, fromString("a"), null, null));
assertEquals(
fromString(""),
concatWs(sep, null, null, null));
assertEquals(
fromString("数据哈哈砖头"),
concatWs(sep, fromString("数据"), fromString("砖头")));
}

@Test
Expand Down Expand Up @@ -215,14 +249,18 @@ public void pad() {
assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者")));
assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者")));
assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者")));
assertEquals(
fromString("孙行者孙行者孙行数据砖头"),
fromString("数据砖头").lpad(12, fromString("孙行者")));

assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????")));
assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????")));
assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????")));
assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者")));
assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者")));
assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者")));
assertEquals(
fromString("数据砖头孙行者孙行者孙行"),
fromString("数据砖头").rpad(12, fromString("孙行者")));
}

@Test
Expand Down

0 comments on commit 2d51406

Please sign in to comment.