From db4720cce06e0da77c08c82b8ca305f40be7c9f0 Mon Sep 17 00:00:00 2001
From: kaka1992
Date: Sat, 16 May 2015 00:28:15 +0800
Subject: [PATCH] [SPARK-7549] Support aggregating over nested fields. Add
 sum, avg, min, max and count.

---
 .../catalyst/expressions/aggregateCells.scala | 111 +++++++++++++++++
 .../sql/catalyst/expressions/aggregates.scala |  22 ++--
 .../expressions/codegen/CodeGenerator.scala   | 112 ++++++++++++++++++
 .../sql/execution/GeneratedAggregate.scala    |   2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  45 +++++++
 .../scala/org/apache/spark/sql/TestData.scala |   8 ++
 6 files changed, 288 insertions(+), 12 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregateCells.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregateCells.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregateCells.scala
new file mode 100644
index 0000000000000..ea9c90a9beba5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregateCells.scala
@@ -0,0 +1,111 @@
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types._
+
+/**
+ * :: DeveloperApi ::
+ * Cells pre-aggregate a row's nested (array) values so the regular aggregates can combine them.
+ * @param child the input data source.
+ */
+case class SumCell(child: Expression) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  override def eval(input: Row): Any = {
+    val evalE = child.eval(input)
+    evalE match {
+      case seq: Seq[Any] => if (seq.isEmpty) null else seq.reduce(numeric.plus)
+      case _ => evalE
+    }
+  }
+
+  override def foldable: Boolean = child.foldable
+  override def nullable: Boolean = child.nullable
+
+  lazy val numeric = dataType match {
+    case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
+    case other => sys.error(s"Type $other does not support numeric operations")
+  }
+
+  override def dataType: DataType = child.dataType match {
+    case DecimalType.Fixed(_, _) =>
+      DecimalType.Unlimited
+    case ArrayType(dataType, _) =>
+      dataType
+    case _ =>
+      child.dataType
+  }
+}
+
+case class CountCell(child: Expression) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  override def eval(input: Row): Any = {
+    val evalE = child.eval(input)
+    evalE match {
+      case seq: Seq[Any] => seq.size.toLong
+      case p if p != null => 1L
+      case _ => null
+    }
+  }
+
+  override def nullable: Boolean = false
+  override def dataType: DataType = LongType
+}
+
+case class MinCell(child: Expression) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  override def eval(input: Row): Any = {
+    val evalE = child.eval(input)
+    evalE match {
+      case seq: Seq[Any] => if (seq.isEmpty) null else seq.reduce(ordering.min)
+      case _ => evalE
+    }
+  }
+
+  override def foldable: Boolean = child.foldable
+  override def nullable: Boolean = child.nullable
+
+  lazy val ordering = dataType match {
+    case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+    case other => sys.error(s"Type $other does not support ordered operations")
+  }
+
+  override def dataType: DataType = child.dataType match {
+    case DecimalType.Fixed(_, _) =>
+      DecimalType.Unlimited
+    case ArrayType(dataType, _) =>
+      dataType
+    case _ =>
+      child.dataType
+  }
+}
+
+case class MaxCell(child: Expression) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  override def eval(input: Row): Any = {
+    val evalE = child.eval(input)
+    evalE match {
+      case seq: Seq[Any] => if (seq.isEmpty) null else seq.reduce(ordering.max)
+      case _ => evalE
+    }
+  }
+
+  lazy val ordering = dataType match {
+    case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+    case other => sys.error(s"Type $other does not support ordered operations")
+  }
+
+  override def foldable: Boolean = child.foldable
+  override def nullable: Boolean = child.nullable
+
+  override def dataType: DataType = child.dataType match {
+    case DecimalType.Fixed(_, _) =>
+      DecimalType.Unlimited
+    case ArrayType(dataType, _) =>
+      dataType
+    case _ =>
+      child.dataType
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index f3830c6d3bcf2..356d34c2bf2dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -97,7 +97,7 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[
   override def toString: String = s"MIN($child)"
 
   override def asPartial: SplitEvaluation = {
-    val partialMin = Alias(Min(child), "PartialMin")()
+    val partialMin = Alias(Min(MinCell(child)), "PartialMin")()
 
     SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
   }
@@ -128,7 +128,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[
   override def toString: String = s"MAX($child)"
 
   override def asPartial: SplitEvaluation = {
-    val partialMax = Alias(Max(child), "PartialMax")()
+    val partialMax = Alias(Max(MaxCell(child)), "PartialMax")()
 
     SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
   }
@@ -159,7 +159,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
   override def toString: String = s"COUNT($child)"
 
   override def asPartial: SplitEvaluation = {
-    val partialCount = Alias(Count(child), "PartialCount")()
+    val partialCount = Alias(Count(CountCell(child)), "PartialCount")()
 
     SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
   }
@@ -328,8 +328,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
     child.dataType match {
       case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
         // Turn the child to unlimited decimals for calculation, before going back to fixed
-        val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
-        val partialCount = Alias(Count(child), "PartialCount")()
+        val partialSum = Alias(Sum(SumCell(Cast(child, DecimalType.Unlimited))), "PartialSum")()
+        val partialCount = Alias(Count(CountCell(child)), "PartialCount")()
 
         val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
         val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
@@ -338,8 +338,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
           partialCount :: partialSum :: Nil)
 
       case _ =>
-        val partialSum = Alias(Sum(child), "PartialSum")()
-        val partialCount = Alias(Count(child), "PartialCount")()
+        val partialSum = Alias(Sum(SumCell(child)), "PartialSum")()
+        val partialCount = Alias(Count(CountCell(child)), "PartialCount")()
 
         val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
         val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
@@ -370,13 +370,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
   override def asPartial: SplitEvaluation = {
     child.dataType match {
       case DecimalType.Fixed(_, _) =>
-        val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+        val partialSum = Alias(Sum(SumCell(Cast(child, DecimalType.Unlimited))), "PartialSum")()
         SplitEvaluation(
           Cast(CombineSum(partialSum.toAttribute), dataType),
           partialSum :: Nil)
 
       case _ =>
-        val partialSum = Alias(Sum(child), "PartialSum")()
+        val partialSum = Alias(Sum(SumCell(child)), "PartialSum")()
         SplitEvaluation(
           CombineSum(partialSum.toAttribute),
           partialSum :: Nil)
@@ -560,7 +560,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   override def update(input: Row): Unit = {
     val evaluatedExpr = expr.eval(input)
     if (evaluatedExpr != null) {
-      count += 1L
+      count += (evaluatedExpr match { case n: Long => n; case _ => 1L })
    }
   }
 
@@ -618,7 +618,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
 
   private val sum = MutableLiteral(null, calcType)
 
-  private val addFunction = 
+  private val addFunction =
     Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
 
   override def update(input: Row): Unit = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d17af0e7ff87e..d7946878c62a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -413,6 +413,118 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
         }
       """.children
 
+      case cell @ SumCell(e) if e.dataType.isInstanceOf[ArrayType] =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = ${eval.primitiveTerm}
+              .asInstanceOf[Seq[${termForType(cell.dataType)}]]
+              .reduce((a, b) => a + b)
+          }
+        """.children
+
+      case cell @ SumCell(e @ NumericType()) =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = ${eval.primitiveTerm}
+          }
+        """.children
+
+      case cell @ CountCell(e) if e.dataType.isInstanceOf[ArrayType] =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = ${eval.primitiveTerm}
+              .asInstanceOf[Seq[${termForType(cell.dataType)}]]
+              .size.toLong
+          }
+        """.children
+
+      case cell @ CountCell(e @ NumericType()) =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = 1L
+          }
+        """.children
+
+      case cell @ MinCell(e) if e.dataType.isInstanceOf[ArrayType] =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = ${eval.primitiveTerm}
+              .asInstanceOf[Seq[${termForType(cell.dataType)}]]
+              .min
+          }
+        """.children
+
+      case cell @ MinCell(e @ NumericType()) =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = ${eval.primitiveTerm}
+          }
+        """.children
+
+      case cell @ MaxCell(e) if e.dataType.isInstanceOf[ArrayType] =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = ${eval.primitiveTerm}
+              .asInstanceOf[Seq[${termForType(cell.dataType)}]]
+              .max
+          }
+        """.children
+
+      case cell @ MaxCell(e @ NumericType()) =>
+        val eval = expressionEvaluator(e)
+        q"""
+          ..${eval.code}
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(cell.dataType)} = 0
+          if (${eval.nullTerm}) {
+            $nullTerm = true
+          } else {
+            $primitiveTerm = ${eval.primitiveTerm}
+          }
+        """.children
+
       case IsNotNull(e) =>
         val eval = expressionEvaluator(e)
         q"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 2ec7d4fbc92de..536b7fc5f154a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -83,7 +83,7 @@ case class GeneratedAggregate(
       }
       val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
       val initialValue = Literal(0L)
-      val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
+      val updateFunction = If(IsNotNull(toCount), Add(currentCount, toCount), currentCount)
       val result = currentCount
 
       AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ec0e76cde6f7c..2af3a0a33be21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1306,6 +1306,51 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
   }
 
+  test("SPARK-7549 Support aggregating over nested fields") {
+    checkAnswer(sql("SELECT sum(a[0]) FROM complexData2"), Row(3))
+    checkAnswer(sql("SELECT sum(s.key) FROM complexData2"), Row(6))
+    checkAnswer(sql("SELECT sum(nestedData[0]) FROM arrayData"), Row(15))
+
+    checkAnswer(sql("SELECT count(a[0]) FROM complexData2"), Row(2))
+    checkAnswer(sql("SELECT count(s.key) FROM complexData2"), Row(4))
+    checkAnswer(sql("SELECT count(nestedData[0]) FROM arrayData"), Row(6))
+
+    checkAnswer(sql("SELECT min(a[0]) FROM complexData2"), Row(1))
+    checkAnswer(sql("SELECT min(s.key) FROM complexData2"), Row(1))
+    checkAnswer(sql("SELECT min(nestedData[0]) FROM arrayData"), Row(1))
+
+    checkAnswer(sql("SELECT max(a[0]) FROM complexData2"), Row(2))
+    checkAnswer(sql("SELECT max(s.key) FROM complexData2"), Row(2))
+    checkAnswer(sql("SELECT max(nestedData[0]) FROM arrayData"), Row(4))
+
+    checkAnswer(sql("SELECT avg(a[0]) FROM complexData2"), Row(1.5))
+    checkAnswer(sql("SELECT avg(s.key) FROM complexData2"), Row(1.5))
+    checkAnswer(sql("SELECT avg(nestedData[0]) FROM arrayData"), Row(2.5))
+
+    val originalValue = conf.codegenEnabled
+    setConf(SQLConf.CODEGEN_ENABLED, "true")
+    checkAnswer(sql("SELECT sum(a[0]) FROM complexData2"), Row(3))
+    checkAnswer(sql("SELECT sum(s.key) FROM complexData2"), Row(6))
+    checkAnswer(sql("SELECT sum(nestedData[0]) FROM arrayData"), Row(15))
+
+    checkAnswer(sql("SELECT count(a[0]) FROM complexData2"), Row(2))
+    checkAnswer(sql("SELECT count(s.key) FROM complexData2"), Row(4))
+    checkAnswer(sql("SELECT count(nestedData[0]) FROM arrayData"), Row(6))
+
+    checkAnswer(sql("SELECT min(a[0]) FROM complexData2"), Row(1))
+    checkAnswer(sql("SELECT min(s.key) FROM complexData2"), Row(1))
+    checkAnswer(sql("SELECT min(nestedData[0]) FROM arrayData"), Row(1))
+
+    checkAnswer(sql("SELECT max(a[0]) FROM complexData2"), Row(2))
+    checkAnswer(sql("SELECT max(s.key) FROM complexData2"), Row(2))
+    checkAnswer(sql("SELECT max(nestedData[0]) FROM arrayData"), Row(4))
+
+    checkAnswer(sql("SELECT avg(a[0]) FROM complexData2"), Row(1.5))
+    checkAnswer(sql("SELECT avg(s.key) FROM complexData2"), Row(1.5))
+    checkAnswer(sql("SELECT avg(nestedData[0]) FROM arrayData"), Row(2.5))
+    setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+  }
+
   test("SPARK-6898: complete support for special chars in column names") {
     jsonRDD(sparkContext.makeRDD(
       """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 446771ab2a5a5..45f9475d56eaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -203,4 +203,12 @@ object TestData {
       :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
       :: Nil).toDF()
   complexData.registerTempTable("complexData")
+
+  case class ComplexData2(s: Seq[TestData], a: Seq[Int], b: Boolean)
+  val complexData2 =
+    TestSQLContext.sparkContext.parallelize(
+      ComplexData2(Seq[TestData](TestData(1, "1"), TestData(1, "2")), Seq(1), true)
+      :: ComplexData2(Seq[TestData](TestData(2, "2"), TestData(2, "3")), Seq(2), false)
+      :: Nil).toDF()
+  complexData2.registerTempTable("complexData2")
 }
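--
Illustrative note (not part of the patch): the new *Cell expressions fold each
row's nested collection into a single partial value, and the existing aggregate
functions then combine those per-row partials across rows; this is also why
CountFunction now adds the evaluated count instead of a constant 1L. A minimal
standalone sketch of that two-phase semantics, assuming a Seq[Int] column like
nestedData[0] in the arrayData fixture:

  // Per-row nested values, mirroring nestedData[0] in the arrayData test data.
  val rows: Seq[Seq[Int]] = Seq(Seq(1, 2, 3), Seq(2, 3, 4))

  // Phase 1: per-row pre-aggregation (what SumCell/CountCell/MinCell/MaxCell compute).
  val partialSums   = rows.map(_.sum)          // Seq(6, 9)
  val partialCounts = rows.map(_.size.toLong)  // Seq(3L, 3L)
  val partialMins   = rows.map(_.min)          // Seq(1, 2)
  val partialMaxs   = rows.map(_.max)          // Seq(3, 4)

  // Phase 2: combining across rows (what Sum/Count/Min/Max then do).
  assert(partialSums.sum == 15)     // SELECT sum(nestedData[0])
  assert(partialCounts.sum == 6L)   // SELECT count(nestedData[0])
  assert(partialMins.min == 1)      // SELECT min(nestedData[0])
  assert(partialMaxs.max == 4)      // SELECT max(nestedData[0])
  assert(partialSums.sum.toDouble / partialCounts.sum == 2.5)  // SELECT avg(nestedData[0])

These values match the expected Rows in the new SQLQuerySuite test.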