From 72d50adf03c107d77a67a69ad9836f3b8a86ab2a Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Tue, 16 Jun 2015 23:23:40 -0700
Subject: [PATCH 1/3] string function: concat/concat_ws

---
 .../catalyst/analysis/FunctionRegistry.scala  |  4 +-
 .../expressions/stringOperations.scala        | 79 +++++++++++++++++++
 .../expressions/StringFunctionsSuite.scala    | 10 ++-
 .../org/apache/spark/sql/functions.scala      | 38 +++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala   | 47 +++++++++++
 5 files changed, 176 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 04e306da23e4c..934a1cbf945ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -88,6 +88,8 @@ object FunctionRegistry {
     expression[Abs]("abs"),
     expression[CreateArray]("array"),
     expression[Coalesce]("coalesce"),
+    expression[Concat]("concat"),
+    expression[ConcatWS]("concat_ws"),
     expression[Explode]("explode"),
     expression[If]("if"),
     expression[IsNull]("isnull"),
@@ -172,7 +174,7 @@ object FunctionRegistry {
       case Success(e) => e
       case Failure(e) =>
-        throw new AnalysisException(s"Invalid number of arguments for function $name")
+        throw new AnalysisException(s"Invalid number of arguments for function $name, $params")
     }
     f.newInstance(expressions : _*).asInstanceOf[Expression]
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 315c63e63c635..e2c8f7579dace 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,8 +19,12 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.util.regex.Pattern
 
+<<<<<<< HEAD
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.expressions.Substring
+=======
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
+>>>>>>> string function: concat/concat_ws
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -313,3 +317,78 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
     defineCodeGen(ctx, ev, c => s"($c).length()")
   }
 }
+
+/**
+ * Like Concat below, but with a custom separator SEP.
+ */
+case class ConcatWS(children: Expression*) extends Expression {
+  // return type is always String
+  override def dataType: DataType = StringType
+  override def nullable: Boolean = true
+  override def foldable: Boolean = children.forall(_.foldable)
+  override def toString: String = s"""CONCAT_WS($children)"""
+  private def sep = children.head
+  private def exprs = children.tail
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    def supportedType(dt: DataType): Boolean = dt match {
+      case ArrayType(StringType, _) => true
+      case ArrayType(NullType, _) => true
+      case StringType => true
+      case NullType => true
+      case _ => false
+    }
+    if (sep.dataType != StringType && sep.dataType != NullType) {
+      TypeCheckResult.TypeCheckFailure(
+        s"type of the separator expression in ConcatWS should be string, not ${sep.dataType}")
+    } else if (children.size < 2) {
+      TypeCheckResult.TypeCheckFailure(
+        "ConcatWS takes at least two arguments")
+    } else if (exprs.exists(expr => !supportedType(expr.dataType))) {
+      TypeCheckResult.TypeCheckFailure(
+        "types of the input expressions in ConcatWS should be array(string) or string, not" +
+          s" ${exprs.map(_.dataType)}")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val sepEval = sep.eval(input)
+    if (sepEval != null) {
+      val childrenArr = exprs.map(expr => (expr.eval(input), expr.dataType))
+      val separator = sepEval.asInstanceOf[UTF8String].toString
+      val validSeq = childrenArr.filter(_._1 != null).map(child => child._2 match {
+        case StringType => child._1.asInstanceOf[UTF8String].toString
+        case ArrayType(StringType, _) => child._1.asInstanceOf[Seq[UTF8String]].mkString(separator)
+        case ArrayType(NullType, _) => child._1.asInstanceOf[Seq[UTF8String]].mkString(separator)
+      })
+      UTF8String.fromString(validSeq.mkString(separator))
+    } else {
+      null
+    }
+  }
+}
+
+/**
+ * A function that returns the string or bytes resulting from concatenating the strings or bytes
+ * passed in as parameters in order. For example, concat('foo', 'bar') results in 'foobar'. Note
+ * that this function can take any number of input strings.
+ */
+case class Concat(children: Expression*)
+  extends Expression with ExpectsInputTypes {
+  override def dataType: DataType = StringType
+  override def nullable: Boolean = children.exists(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+  override def expectedChildTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
+  override def toString: String = s"""CONCAT($children)"""
+
+  override def eval(input: InternalRow): Any = {
+    val validSeq = children.map(_.eval(input))
+    if (!validSeq.contains(null)) {
+      UTF8String.fromString(validSeq.mkString(""))
+    } else {
+      null
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index d363e631540d8..125289d9a4b3e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -226,5 +226,13 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
   }
-
+  test("Concat/ConcatWS Expression") {
+    checkEvaluation(Concat("b", "c"), "bc")
+    checkEvaluation(Concat(Literal(null), Literal(null)), null)
+    checkEvaluation(ConcatWS(",", CreateArray(Seq("b", "c"))), "b,c")
+    checkEvaluation(ConcatWS(Literal(null), CreateArray(Seq("b", "c"))), null)
+    checkEvaluation(ConcatWS("", CreateArray(Seq("b", Literal.create(null, StringType), "c")),
+      CreateArray(Seq(Literal.create(null, StringType)))), "bnullcnull")
+    checkEvaluation(ConcatWS(",", CreateArray(Seq()), "a"), ",a")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c5b77724aae17..007f027bbd7f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -913,6 +913,44 @@ object functions {
    */
   def ceil(columnName: String): Column = ceil(Column(columnName))
 
+  /**
+   * Concatenates the given string columns into a single string column.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat(exprs: Column*): Column = Concat(exprs.map(_.expr): _*)
+
+  /**
+   * Concatenates the columns with the given names into a single string column.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat(columnName: String, columnNames: String*): Column =
+    concat((columnName +: columnNames).map(Column.apply): _*)
+
+  /**
+   * Concatenates the given columns with the given separator.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sep: Column, exprs: Column*): Column = ConcatWS((sep +: exprs).map(_.expr): _*)
+
+  /**
+   * Concatenates the named columns, using the named separator column.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sepName: String, exprNames: String*): Column =
+    concat_ws(Column(sepName), exprNames.map(Column.apply): _*)
+
   /**
    * Computes the cosine of the given value.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index cfd23867a9bba..718a451cd2a93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -85,6 +86,52 @@ class DataFrameFunctionsSuite extends QueryTest {
     }
   }
 
+  test("concat/concat_ws test") {
+    checkAnswer(
+      testData2.select(concat()).limit(1),
+      Row("")
+    )
+    checkAnswer(
+      testData2.select(
+        concat_ws(lit(","), col("a").cast("String"), col("b").cast("String"))).limit(1),
+      Row("1,1")
+    )
+    checkAnswer(
+      testData2.select(
+        concat_ws(lit("=="), array(col("a").cast("String"), col("b").cast("String")), lit("x"))
+      ).limit(1),
+      Row("1==1==x")
+    )
+    checkAnswer(
+      testData2.select(concat_ws(lit(""), lit("x"), lit("y"))).limit(1),
+      Row("xy")
+    )
+    checkAnswer(
+      testData2.select(concat_ws(lit("=="), array(), lit("x"))).limit(1),
+      Row("==x")
+    )
+    checkAnswer(
+      ctx.sql("""SELECT CONCAT(null, null)"""),
+      Row(null)
+    )
+    checkAnswer(
+      ctx.sql("""SELECT CONCAT("a", null)"""),
+      Row(null)
+    )
+    checkAnswer(
+      ctx.sql("""SELECT CONCAT("a", b, 1) from testData2 limit 1"""),
+      Row("a11")
+    )
+    checkAnswer(
+      ctx.sql("""SELECT CONCAT_WS("==", array("a", "b"), array(null, null), array())"""),
+      Row("a==b==null==null==")
+    )
+    checkAnswer(
+      ctx.sql("""SELECT CONCAT_WS(",", "a", array())"""),
+      Row("a,")
+    )
+  }
+
   test("constant functions") {
     checkAnswer(
       testData2.select(e()).limit(1),

From 282e9581b4f1f885bf82af64fedee4dad9a169c0 Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Tue, 16 Jun 2015 23:39:50 -0700
Subject: [PATCH 2/3] fix rebase

---
 .../spark/sql/catalyst/expressions/stringOperations.scala | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index e2c8f7579dace..e68208e117f02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,12 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.util.regex.Pattern
 
-<<<<<<< HEAD
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.expressions.Substring
-=======
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
->>>>>>> string function: concat/concat_ws
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

From ee7ebae144d806f7b9e20958a328f9879fca5f13 Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Wed, 17 Jun 2015 18:07:36 -0700
Subject: [PATCH 3/3] set nullable

---
 .../spark/sql/catalyst/expressions/stringOperations.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index e68208e117f02..3b271b649ee11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -319,7 +319,7 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
 case class ConcatWS(children: Expression*) extends Expression {
   // return type is always String
   override def dataType: DataType = StringType
-  override def nullable: Boolean = true
+  override def nullable: Boolean = sep.nullable
   override def foldable: Boolean = children.forall(_.foldable)
   override def toString: String = s"""CONCAT_WS($children)"""
   private def sep = children.head
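
Usage sketch (reviewer note, not part of the patch series): a quick illustration of how
the two new functions behave, based on the signatures and tests added above. This is a
minimal sketch, assuming Spark 1.5 with a SQLContext named `ctx` as in the test suites;
the DataFrame `df` and its columns `a` and `b` are made up for illustration.

    import org.apache.spark.sql.functions._

    val df = ctx.createDataFrame(Seq(("foo", "bar"))).toDF("a", "b")

    // Concat is null-sensitive (nullable = children.exists(_.nullable)):
    // it returns null as soon as any input evaluates to null.
    df.select(concat(col("a"), col("b"))).show()    // foobar
    ctx.sql("""SELECT CONCAT("a", null)""").show()  // null

    // ConcatWS skips null inputs instead of nulling the whole result; after
    // [PATCH 3/3] it is nullable only through its separator (nullable = sep.nullable).
    df.select(concat_ws(lit(","), col("a"), col("b"))).show()  // foo,bar
    ctx.sql("""SELECT CONCAT_WS(",", "a", array())""").show()  // a,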