From 58205e2e4c66e83d6cc1761daa3ca490da23097d Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO
Date: Fri, 19 Feb 2016 04:49:49 -0500
Subject: [PATCH] Support Aggregator in DataFrame
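
This adds typed agg() overloads to GroupedData so that a TypedColumn
built from an Aggregator can be applied directly to a grouped
DataFrame, mirroring what GroupedDataset.agg already supports. A rough
usage sketch (the data and column names are illustrative; SumOf is the
test helper this patch adds in AggregatorTest.scala):

    val df = Seq(("a", 10), ("a", 20), ("b", 1)).toDF("k", "v")
    val summer = new SumOf((in: (String, Int)) => in._2).toColumn
    val aggregated = df.groupBy($"k").agg(summer)

Note that the Aggregator's input type is resolved against the full row
type of the DataFrame, here (String, Int), not against a single column.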
---
 .../org/apache/spark/sql/GroupedData.scala    |  86 +++++++++++++-
 .../org/apache/spark/sql/GroupedDataset.scala |   5 +-
 .../spark/sql/expressions/Aggregator.scala    |  21 +++-
 .../org/apache/spark/sql/AggregatorTest.scala | 106 ++++++++++++++++++
 .../spark/sql/DataFrameAggregatorSuite.scala  |  86 ++++++++++++++
 .../spark/sql/DatasetAggregatorSuite.scala    |  89 ---------------
 6 files changed, 297 insertions(+), 96 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/AggregatorTest.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregatorSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index c74ef2c03541e..d231071a48535 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot, Project}
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.types.NumericType
 
 /**
@@ -212,6 +215,87 @@ class GroupedData protected[sql](
     toDF((expr +: exprs).map(_.expr))
   }
 
+  /**
+   * Computes the given aggregation, returning a [[DataFrame]] that contains each unique key
+   * and the result of computing this aggregation over all elements in the group.
+   *
+   * @since 2.0.0
+   */
+  def agg[I: TypeTag, _](col1: TypedColumn[I, _]): DataFrame = {
+    aggInternal(col1)
+  }
+
+  /**
+   * Computes the given aggregations, returning a [[DataFrame]] that contains each unique key
+   * and the results of computing these aggregations over all elements in the group.
+   *
+   * @since 2.0.0
+   */
+  def agg[I: TypeTag, _](
+      col1: TypedColumn[I, _],
+      col2: TypedColumn[I, _]): DataFrame = {
+    aggInternal(col1, col2)
+  }
+
+  /**
+   * Computes the given aggregations, returning a [[DataFrame]] that contains each unique key
+   * and the results of computing these aggregations over all elements in the group.
+   *
+   * @since 2.0.0
+   */
+  def agg[I: TypeTag, _](
+      col1: TypedColumn[I, _],
+      col2: TypedColumn[I, _],
+      col3: TypedColumn[I, _]): DataFrame = {
+    aggInternal(col1, col2, col3)
+  }
+
+  /**
+   * Computes the given aggregations, returning a [[DataFrame]] that contains each unique key
+   * and the results of computing these aggregations over all elements in the group.
+   *
+   * @since 2.0.0
+   */
+  def agg[I: TypeTag, _](
+      col1: TypedColumn[I, _],
+      col2: TypedColumn[I, _],
+      col3: TypedColumn[I, _],
+      col4: TypedColumn[I, _]): DataFrame = {
+    aggInternal(col1, col2, col3, col4)
+  }
+
+  /**
+   * Internal helper function for building typed aggregations that return a single value.
+   * Encoders are not initialized until a typed aggregation is actually used, because
+   * [[GroupedData]] does not know in advance whether typed aggregations will be used.
+   * TODO: does not handle aggregations that return nonflat results.
+   */
+  private def aggInternal[I: TypeTag, _](columns: TypedColumn[I, _]*): DataFrame = {
+    // Since an Aggregator consumes all of the elements in a group, we split the input schema
+    // into key and value parts here.
+    val (keyAttributes, dataAttributes) = {
+      val withKeyColumns = df.logicalPlan.output ++ groupingExprs.map(UnresolvedAlias(_))
+      val withKey = Project(withKeyColumns, df.logicalPlan)
+      val outputAttributes = df.sqlContext.executePlan(withKey).analyzed.output
+      (outputAttributes.takeRight(groupingExprs.size),
+        outputAttributes.dropRight(groupingExprs.size))
+    }
+
+    val resolvedVEncoder = encoderFor(ExpressionEncoder[I]())
+      .resolve(dataAttributes, OuterScopes.outerScopes)
+
+    val namedColumns = columns.map(_.withInputType(resolvedVEncoder, dataAttributes).named)
+    val keyColumn = if (keyAttributes.length == 1) {
+      keyAttributes.head
+    } else {
+      Alias(CreateStruct(keyAttributes), "key")()
+    }
+    val aggregate = Aggregate(groupingExprs, keyColumn +: namedColumns, df.logicalPlan)
+    val execution = new QueryExecution(df.sqlContext, aggregate)
+
+    new DataFrame(df.sqlContext, execution)
+  }
+
   /**
    * Count the number of rows for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 53cb8eb524947..6b9aea019725c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -224,15 +224,14 @@ class GroupedDataset[K, V] private[sql](
    * Internal helper function for building typed aggregations that return tuples. For simplicity
    * and code reuse, we do this without the help of the type system and then use helper functions
    * that cast appropriately for the user facing interface.
-   * TODO: does not handle aggrecations that return nonflat results,
+   * TODO: does not handle aggregations that return nonflat results.
    */
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns = columns.map(
       _.withInputType(resolvedVEncoder, dataAttributes).named)
-    val keyColumn = if (resolvedKEncoder.flat) {
-      assert(groupingAttributes.length == 1)
+    val keyColumn = if (groupingAttributes.length == 1) {
       groupingAttributes.head
     } else {
       Alias(CreateStruct(groupingAttributes), "key")()
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 6eea92451734e..984cf1ca3138b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -26,21 +26,36 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
  * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]]
  * operations to take all of the elements of a group and reduce them to a single value.
  *
- * For example, the following aggregator extracts an `int` from a specific class and adds them up:
+ * For example, the following user-defined aggregator extracts an `int` from a specific class and
+ * adds them up:
+ *
  * {{{
- *   case class Data(i: Int)
+ *   case class Data(k: String, v: Int)
 *
 *   val customSummer = new Aggregator[Data, Int, Int] {
 *     def zero: Int = 0
- *     def reduce(b: Int, a: Data): Int = b + a.i
+ *     def reduce(b: Int, a: Data): Int = b + a.v
 *     def merge(b1: Int, b2: Int): Int = b1 + b2
 *     def finish(r: Int): Int = r
 *   }.toColumn()
+ * }}}
 *
+ * == [[Dataset]] ==
+ * Sums up the values of the input rows in a [[Dataset]]:
+ *
+ * {{{
 *   val ds: Dataset[Data] = ...
 *   val aggregated = ds.select(customSummer)
 * }}}
 *
+ * == [[DataFrame]] ==
+ * Sums up the values of the input rows for each group in a [[DataFrame]]:
+ *
+ * {{{
+ *   val df: DataFrame = ...
+ *   val aggregated = df.groupBy($"k").agg(customSummer)
+ * }}}
+ *
 * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
 *
 * @tparam I The input type for the aggregation.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregatorTest.scala
new file mode 100644
index 0000000000000..9b34aa3c66f9b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregatorTest.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.expressions.Aggregator
+
+/** An `Aggregator` that adds up any numeric type returned by the given function. */
+class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
+  val numeric = implicitly[Numeric[N]]
+
+  override def zero: N = numeric.zero
+
+  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+
+  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
+
+  override def finish(reduction: N): N = reduction
+}
+
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
+  override def zero: (Long, Long) = (0, 0)
+
+  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+    (countAndSum._1 + 1, countAndSum._2 + input._2)
+  }
+
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+
+  override def finish(countAndSum: (Long, Long)): Double = countAndSum._2.toDouble / countAndSum._1
+}
+
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
+
+  override def zero: (Long, Long) = (0, 0)
+
+  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+    (countAndSum._1 + 1, countAndSum._2 + input._2)
+  }
+
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+
+  override def finish(reduction: (Long, Long)): (Long, Long) = reduction
+}
+
+case class AggData(a: Int, b: String)
+object ClassInputAgg extends Aggregator[AggData, Int, Int] {
+  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
+  override def zero: Int = 0
+
+  /**
+   * Combine two values to produce a new value. For performance, the function may modify `b` and
+   * return it instead of constructing a new object for b.
+   */
+  override def reduce(b: Int, a: AggData): Int = b + a.a
+
+  /**
+   * Transform the output of the reduction.
+   */
+  override def finish(reduction: Int): Int = reduction
+
+  /**
+   * Merge two intermediate values.
+   */
+  override def merge(b1: Int, b2: Int): Int = b1 + b2
+}
+
+object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
+  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
+  override def zero: (Int, AggData) = 0 -> AggData(0, "0")
+
+  /**
+   * Combine two values to produce a new value. For performance, the function may modify `b` and
+   * return it instead of constructing a new object for b.
+   */
+  override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
+
+  /**
+   * Transform the output of the reduction.
+   */
+  override def finish(reduction: (Int, AggData)): Int = reduction._1
+
+  /**
+   * Merge two intermediate values.
+   */
+  override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
+    (b1._1 + b2._1, b1._2)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregatorSuite.scala
new file mode 100644
index 0000000000000..86e4f8c31f61d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregatorSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameAggregatorSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+    new SumOf(f).toColumn
+
+  test("typed aggregation: TypedAggregator") {
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("k", "v")
+    val sumFunc = sum[(String, Int), Int] _
+
+    checkAnswer(
+      df.groupBy($"k").agg(sumFunc(_._2)),
+      Seq(Row("a", 30), Row("b", 3), Row("c", 1)))
+  }
+
+  test("typed aggregation: TypedAggregator, expr, expr") {
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("k", "v")
+    val sumFunc = sum[(String, Int), Int] _
+
+    checkAnswer(
+      df.groupBy($"k").agg(
+        sumFunc(_._2),
+        expr("sum(v)").as[Long],
+        count("*")),
+      Seq(Row("a", 30, 30L, 2L), Row("b", 3, 3L, 2L), Row("c", 1, 1L, 1L)))
+  }
+
+  test("typed aggregation: complex case") {
+    val df = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDF("k", "v")
+
+    checkAnswer(
+      df.groupBy($"k").agg(
+        expr("avg(v)").as[Double],
+        TypedAverage.toColumn),
+      Seq(Row("a", 2.0, 2.0), Row("b", 3.0, 3.0)))
+  }
+
+  test("typed aggregation: complex result type") {
+    val df = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDF("k", "v")
+
+    checkAnswer(
+      df.groupBy($"k").agg(
+        expr("avg(v)").as[Double],
+        ComplexResultAgg.toColumn),
+      Seq(Row("a", 2.0, Row(2L, 4L)), Row("b", 3.0, Row(1L, 3L))))
+  }
+
+  test("typed aggregation: class input with reordering") {
+    val df = Seq(AggData(1, "one")).toDF
+
+    checkAnswer(
+      df.groupBy($"b").agg(ClassInputAgg.toColumn),
+      Seq(Row("one", 1)))
+  }
+
+  test("typed aggregation: complex input") {
+    val df = Seq(AggData(1, "one"), AggData(2, "two")).toDF
+
+    checkAnswer(
+      df.groupBy($"b").agg(ComplexBufferAgg.toColumn),
+      Seq(Row("one", 1), Row("two", 1)))
+  }
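+
+  // Editor's illustrative sketch, not part of the original change set: the
+  // typed-only form of the "complex case" test above; the output contains
+  // just the grouping key and the Aggregator result, and the expected rows
+  // are derived from that test.
+  test("typed aggregation: aggregator only (illustrative)") {
+    val df = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDF("k", "v")
+
+    checkAnswer(
+      df.groupBy($"k").agg(TypedAverage.toColumn),
+      Seq(Row("a", 2.0), Row("b", 3.0)))
+  }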
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 3258f3782d8cc..fb315be0a6c8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -17,98 +17,9 @@
 
 package org.apache.spark.sql
 
-import scala.language.postfixOps
-
-import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
-/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-  val numeric = implicitly[Numeric[N]]
-
-  override def zero: N = numeric.zero
-
-  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-
-  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
-
-  override def finish(reduction: N): N = reduction
-}
-
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
-  override def zero: (Long, Long) = (0, 0)
-
-  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
-    (countAndSum._1 + 1, countAndSum._2 + input._2)
-  }
-
-  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
-    (b1._1 + b2._1, b1._2 + b2._2)
-  }
-
-  override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
-}
-
-object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
-
-  override def zero: (Long, Long) = (0, 0)
-
-  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
-    (countAndSum._1 + 1, countAndSum._2 + input._2)
-  }
-
-  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
-    (b1._1 + b2._1, b1._2 + b2._2)
-  }
-
-  override def finish(reduction: (Long, Long)): (Long, Long) = reduction
-}
-
-case class AggData(a: Int, b: String)
-object ClassInputAgg extends Aggregator[AggData, Int, Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
-  override def zero: Int = 0
-
-  /**
-   * Combine two values to produce a new value. For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
-  override def reduce(b: Int, a: AggData): Int = b + a.a
-
-  /**
-   * Transform the output of the reduction.
-   */
-  override def finish(reduction: Int): Int = reduction
-
-  /**
-   * Merge two intermediate values
-   */
-  override def merge(b1: Int, b2: Int): Int = b1 + b2
-}
-
-object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
-  override def zero: (Int, AggData) = 0 -> AggData(0, "0")
-
-  /**
-   * Combine two values to produce a new value. For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
-  override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
-
-  /**
-   * Transform the output of the reduction.
-   */
-  override def finish(reduction: (Int, AggData)): Int = reduction._1
-
-  /**
-   * Merge two intermediate values
-   */
-  override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
-    (b1._1 + b2._1, b1._2)
-}
-
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._