From 0a063768046e82281d369ad9236f6c2ade34f391 Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Tue, 13 Jan 2015 02:05:31 -0800
Subject: [PATCH] add concat support in spark sql

---
 .../apache/spark/sql/catalyst/SqlParser.scala |  2 ++
 .../sql/catalyst/expressions/arithmetic.scala | 24 +++++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 12 ++++++++++
 3 files changed, 38 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 5d974df98b699..e788cfb5b37c0 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -51,6 +51,7 @@ class SqlParser extends AbstractSparkSQLParser {
   protected val CACHE = Keyword("CACHE")
   protected val CASE = Keyword("CASE")
   protected val CAST = Keyword("CAST")
+  protected val CONCAT = Keyword("CONCAT")
   protected val COUNT = Keyword("COUNT")
   protected val DECIMAL = Keyword("DECIMAL")
   protected val DESC = Keyword("DESC")
@@ -308,6 +309,7 @@ class SqlParser extends AbstractSparkSQLParser {
     { case s ~ p ~ l => Substring(s, p, l) }
     | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
     | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
+    | CONCAT ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Concat(exprs) }
     | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
       { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
     )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 168a963e29c90..873cf83740bb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -261,3 +261,27 @@ case class Abs(child: Expression) extends UnaryExpression {
 
   override def eval(input: Row): Any = n1(child, input, _.abs(_))
 }
+
+/**
+ * A function that concatenates a sequence of input strings, skipping null values.
+ */
+case class Concat(childSeq: Seq[Expression]) extends Expression {
+  type EvaluatedType = Any
+
+  def dataType = StringType
+  def nullable = childSeq.forall(_.nullable)
+  override def foldable = childSeq.forall(_.foldable)
+  override def children = childSeq
+  override def toString = s"CONCAT($childSeq)"
+
+  override def eval(input: Row): Any = {
+    childSeq.foldLeft("")((r, c) => {
+      val e = c.eval(input)
+      if (e == null) {
+        r
+      } else {
+        r + e.asInstanceOf[String]
+      }
+    })
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d9de5686dce48..2825524a58145 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -78,6 +78,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       2.5)
   }
 
+  test("Add string concat in Spark SQL") {
+    checkAnswer(
+      sql("""SELECT CONCAT("ab", "bcd")"""),
+      "abbcd")
+    checkAnswer(
+      sql("""SELECT CONCAT("ab", null, "ccc", "b")"""),
+      "abcccb")
+    checkAnswer(
+      sql("""SELECT CONCAT(null, "bcd")"""),
+      "bcd")
+  }
+
   test("aggregation with codegen") {
     val originalValue = codegenEnabled
     setConf(SQLConf.CODEGEN_ENABLED, "true")
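
Reviewer note (not part of the patch): below is a rough end-to-end sketch of the new
syntax from the Spark shell, assuming the Spark 1.2-era SQLContext API (createSchemaRDD,
registerTempTable, sql). The Person case class, the "people" table, and its column names
are illustrative only and do not come from the patch.

  // Hypothetical shell session exercising CONCAT once this patch is applied.
  case class Person(firstName: String, lastName: String)

  val sqlContext = new org.apache.spark.sql.SQLContext(sc)
  import sqlContext.createSchemaRDD

  val people = sc.parallelize(Seq(Person("Ada", "Lovelace"), Person("Alan", null)))
  people.registerTempTable("people")

  // Concat.eval skips null children, so the second row yields "Alan " rather than null,
  // matching the null handling asserted in the SQLQuerySuite test above.
  sqlContext.sql("""SELECT CONCAT(firstName, " ", lastName) FROM people""").collect()

One design point this makes visible: the implementation here skips NULL arguments,
whereas Hive's CONCAT returns NULL as soon as any argument is NULL, so the intended
semantics may be worth stating explicitly in the PR description.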