From b9cd3472bb81ccf5b9b772103342771891a60f49 Mon Sep 17 00:00:00 2001 From: Nick Buroojy Date: Thu, 3 Sep 2015 21:40:10 +0000 Subject: [PATCH 1/2] [SPARK-9301] [sql] Add collect_set aggregate function --- .../apache/spark/sql/catalyst/SqlParser.scala | 2 + .../sql/catalyst/expressions/aggregates.scala | 65 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 16 +++++ 3 files changed, 83 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f2498861c9573..e844bcb14197d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -68,6 +68,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val BY = Keyword("BY") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") + protected val COLLECT_SET = Keyword("COLLECT_SET") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") @@ -298,6 +299,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { throw new AnalysisException(s"invalid function approximate($s) $udfName") } } + | COLLECT_SET ~> "(" ~> expression <~ ")" ^^ { case exp => CollectSet(exp) } | CASE ~> whenThenElse ^^ CaseWhen | CASE ~> expression ~ whenThenElse ^^ { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5e8298aaaa9cb..ab68e9ed4cfb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -256,6 +256,71 @@ case class CollectHashSetFunction( } } +case class CollectSet(child: Expression) extends UnaryExpression with PartialAggregate1 { + + override def nullable: Boolean = false + override def dataType: DataType = new ArrayType(child.dataType, false) + override def newInstance(): CollectSetFunction = new CollectSetFunction(child, this) + + override def asPartial: SplitEvaluation = { + val partialSet = Alias(CollectHashSet(child :: Nil), "PartialSets")() + SplitEvaluation( + CombineSetsArr(partialSet.toAttribute, this), + partialSet :: Nil) + } +} + +case class CollectSetFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = seen +} + +case class CombineSetsArr(inputSet: Expression, base: Expression) extends AggregateExpression1 { + def this() = this(null, null) + + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = false + override def dataType: DataType = base.dataType + override def toString: String = s"Combine($inputSet)" + override def newInstance(): CombineSetsArrFunction = { + new CombineSetsArrFunction(inputSet, this) + } +} + +case class CombineSetsArrFunction( + @transient inputSet: Expression, + @transient base: AggregateExpression1) + extends AggregateFunction1 { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + override def update(input: InternalRow): Unit = { + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) + } + } + + override def eval(input: InternalRow): Any = { + val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] + Literal.create(casted.iterator.map(f => f.get(0, null)).toSeq, base.dataType).eval(null) + } +} + case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { def this() = this(null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0ef25fe0faef0..8c10fcc29fc26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -55,6 +55,22 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(queryCoalesce, Row("1") :: Nil) } + test("collect set") { + val df = Seq((1, "a"), (1, "b"), (2, "c")).toDF("key", "value") + df.registerTempTable("src") + val query = sql("select key, collect_set(value) from src group by key") + + checkAnswer(query, Row(1, "a" :: "b" :: Nil) :: Row(2, "c" :: Nil) :: Nil) + } + + test("collect set with nulls") { + val df = Seq((1, "a"), (1, null), (2, null)).toDF("key", "value") + df.registerTempTable("src") + val query = sql("select key, collect_set(value) from src group by key") + + checkAnswer(query, Row(1, "a" :: Nil) :: Row(2, Nil) :: Nil) + } + test("show functions") { checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) From 5f5abc11e5c117194cce39e603de2b29d8a44baf Mon Sep 17 00:00:00 2001 From: Nick Buroojy Date: Mon, 14 Sep 2015 23:06:10 +0000 Subject: [PATCH 2/2] Use AggregateFunction2 interface for collect_* agg fns --- .../apache/spark/sql/catalyst/SqlParser.scala | 2 ++ .../spark/sql/catalyst/dsl/package.scala | 2 ++ .../expressions/aggregate/functions.scala | 32 +++++++++++++++++++ .../expressions/aggregate/utils.scala | 12 +++++++ .../sql/catalyst/expressions/aggregates.scala | 6 ++++ .../expressions/complexTypeExtractors.scala | 32 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 22 ++++++++----- 7 files changed, 100 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index e844bcb14197d..f806827f1d919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -68,6 +68,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val BY = Keyword("BY") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") + protected val COLLECT_LIST = Keyword("COLLECT_LIST") protected val COLLECT_SET = Keyword("COLLECT_SET") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") @@ -300,6 +301,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } | COLLECT_SET ~> "(" ~> expression <~ ")" ^^ { case exp => CollectSet(exp) } + | COLLECT_LIST ~> "(" ~> expression <~ ")" ^^ { case exp => CollectList(exp) } | CASE ~> whenThenElse ^^ CaseWhen | CASE ~> expression ~ whenThenElse ^^ { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 699c4cc63d09a..93caa505a786a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -162,6 +162,8 @@ package object dsl { def stddev(e: Expression): Expression = Stddev(e) def stddev_pop(e: Expression): Expression = StddevPop(e) def stddev_samp(e: Expression): Expression = StddevSamp(e) + def collect_list(e: Expression): Expression = CollectList(e) + def collect_set(e: Expression): Expression = CollectSet(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 02cd0ac0db118..15a0c6e5af866 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -445,3 +445,35 @@ case class Sum(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = Cast(currentSum, resultType) } + +case class CollectList(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = false + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + // Return data type. + override def dataType: DataType = new ArrayType(child.dataType, false) + + private val currentArr = AttributeReference("currentArr", dataType)() + + override val bufferAttributes = currentArr :: Nil + + override val initialValues = Seq( + /* currentArr = */ CreateArray(Nil) + ) + + override val updateExpressions = Seq( + /* currentArr = */ + If(IsNull(child), currentArr, ArrayUnion(currentArr, CreateArray(child :: Nil))) + ) + + override val mergeExpressions = Seq( + /* currentArr = */ ArrayUnion(currentArr.left, currentArr.right) + ) + + override val evaluateExpression = currentArr +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index ce3dddad87f55..434568c2c515b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -61,6 +61,18 @@ object Utils { mode = aggregate.Complete, isDistinct = true) + case expressions.CollectList(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.CollectList(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.CollectSet(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.CollectList(child), + mode = aggregate.Complete, + isDistinct = true) + case expressions.First(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.First(child), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index a5f8e72bdd292..f5cd412d9b343 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -1001,3 +1001,9 @@ case class StddevFunction( } } } + +case class CollectList(child: Expression) extends UnaryExpression with AggregateExpression { + override def nullable: Boolean = false + override def dataType: DataType = new ArrayType(child.dataType, false) +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9927da21b052e..5b9f261b24915 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -226,6 +226,38 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } } +/** + * Combines two Arrays into one Array. + */ +case class ArrayUnion(left: Expression, right: Expression) extends BinaryOperator { + + override def inputType: AbstractDataType = ArrayType + + override def symbol: String = "++" + + private def inputArrType = left.dataType.asInstanceOf[ArrayType] + override def dataType: DataType = inputArrType + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName + val elementType = inputArrType.elementType + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + final int n1 = $eval1.numElements(); + final int n2 = $eval2.numElements(); + final Object[] unionValues = new Object[n1 + n2]; + for (int j = 0; j < n1; j++) { + unionValues[j] = ${ctx.getValue(eval1, elementType, "j")}; + } + for (int j = 0; j < n2; j++) { + unionValues[n1 + j] = ${ctx.getValue(eval2, elementType, "j")}; + } + ${ev.primitive} = new $arrayClass(unionValues); + """ + }) + } +} + /** * Returns the value of key `key` in Map `child`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d7ce322269782..d0fb3f7a12ae6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -55,18 +55,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(queryCoalesce, Row("1") :: Nil) } - test("collect set") { - val df = Seq((1, "a"), (1, "b"), (2, "c")).toDF("key", "value") + test("collect list") { + val df = Seq((1, "a"), (1, "b"), (2, "c"), (2, "c")).toDF("key", "value") df.registerTempTable("src") - val query = sql("select key, collect_set(value) from src group by key") + val query = sql("select key, collect_list(value) from src group by key") - checkAnswer(query, Row(1, "a" :: "b" :: Nil) :: Row(2, "c" :: Nil) :: Nil) + checkAnswer(query, Row(1, "a" :: "b" :: Nil) :: Row(2, "c" :: "c" :: Nil) :: Nil) } - test("collect set with nulls") { + test("collect list with nulls") { val df = Seq((1, "a"), (1, null), (2, null)).toDF("key", "value") df.registerTempTable("src") - val query = sql("select key, collect_set(value) from src group by key") + val query = sql("select key, collect_list(value) from src group by key") checkAnswer(query, Row(1, "a" :: Nil) :: Row(2, Nil) :: Nil) } @@ -275,12 +275,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val df = sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.executedPlan - .collect { case _: aggregate.TungstenAggregate => true } + .collect { + case _: aggregate.TungstenAggregate => true + case _: aggregate.SortBasedAggregate => true + } .nonEmpty if (!hasGeneratedAgg) { fail( s""" - |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan. + |Codegen is enabled, but query $sqlText does not have an Aggregate in the plan. |${df.queryExecution.simpleString} """.stripMargin) } @@ -373,6 +376,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", Row(null, null, null, 0) :: Nil) + testCodeGen( + "SELECT value, collect_list(key), collect_set(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i :: i :: i :: Nil, i :: Nil))) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)