From f4b8dbb6a6b8dc563e8a68fe5713d2f33b1498f5 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 4 Feb 2015 16:38:00 -0800
Subject: [PATCH] [SPARK-5605][SQL][DF] Allow using String to specify column
 name in DSL aggregate functions.

---
 python/pyspark/sql.py                              | 10 ++---
 .../org/apache/spark/sql/DataFrame.scala           |  8 ++--
 .../org/apache/spark/sql/DataFrameImpl.scala       |  8 ++--
 .../main/scala/org/apache/spark/sql/Dsl.scala      | 37 +++++++++++++++++++
 ...oupedDataFrame.scala => GroupedData.scala}      |  2 +-
 .../apache/spark/sql/IncomputableColumn.scala      |  4 +-
 6 files changed, 53 insertions(+), 16 deletions(-)
 rename sql/core/src/main/scala/org/apache/spark/sql/{GroupedDataFrame.scala => GroupedData.scala} (98%)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 5b56b36bdcdb7..20ca6284c4074 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -23,7 +23,7 @@
     - L{DataFrame}
       A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
       addition to normal RDD operations, DataFrames also support SQL.
-    - L{GroupedDataFrame}
+    - L{GroupedData}
     - L{Column}
       Column is a DataFrame with a single column.
     - L{Row}
@@ -62,7 +62,7 @@
     "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
     "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType",
     "ArrayType", "MapType", "StructField", "StructType",
-    "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl",
+    "SQLContext", "HiveContext", "DataFrame", "GroupedData", "Column", "Row", "Dsl",
     "SchemaRDD"]
@@ -2231,7 +2231,7 @@ def filter(self, condition):
 
     def groupBy(self, *cols):
         """ Group the :class:`DataFrame` using the specified columns,
-        so we can run aggregation on them. See :class:`GroupedDataFrame`
+        so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.
 
         >>> df.groupBy().avg().collect()
@@ -2244,7 +2244,7 @@ def groupBy(self, *cols):
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
         jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
-        return GroupedDataFrame(jdf, self.sql_ctx)
+        return GroupedData(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
@@ -2308,7 +2308,7 @@ def _api(self):
     return _api
 
 
-class GroupedDataFrame(object):
+class GroupedData(object):
 
     """
     A set of methods for aggregations on a :class:`DataFrame`,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index a4997fb293781..92e04ce17c2e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -290,7 +290,7 @@ trait DataFrame extends RDDApi[Row] {
 
   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
-   * See [[GroupedDataFrame]] for all the available aggregate functions.
+   * See [[GroupedData]] for all the available aggregate functions.
    *
    * {{{
    *   // Compute the average for all numeric columns grouped by department.
@@ -304,11 +304,11 @@ trait DataFrame extends RDDApi[Row] {
    * }}}
    */
   @scala.annotation.varargs
-  def groupBy(cols: Column*): GroupedDataFrame
+  def groupBy(cols: Column*): GroupedData
 
   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
-   * See [[GroupedDataFrame]] for all the available aggregate functions.
+   * See [[GroupedData]] for all the available aggregate functions.
    *
    * This is a variant of groupBy that can only group by existing columns using column names
    * (i.e. cannot construct expressions).
@@ -325,7 +325,7 @@ trait DataFrame extends RDDApi[Row] {
    * }}}
    */
   @scala.annotation.varargs
-  def groupBy(col1: String, cols: String*): GroupedDataFrame
+  def groupBy(col1: String, cols: String*): GroupedData
 
   /**
    * (Scala-specific) Compute aggregates by specifying a map from column name to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index c702adcb65122..d6df927f9d42c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -201,13 +201,13 @@ private[sql] class DataFrameImpl protected[sql](
     filter(condition)
   }
 
-  override def groupBy(cols: Column*): GroupedDataFrame = {
-    new GroupedDataFrame(this, cols.map(_.expr))
+  override def groupBy(cols: Column*): GroupedData = {
+    new GroupedData(this, cols.map(_.expr))
   }
 
-  override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
+  override def groupBy(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
+    new GroupedData(this, colNames.map(colName => resolve(colName)))
   }
 
   override def limit(n: Int): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index 50f442dd87bf3..9afe496edc2be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -94,38 +94,75 @@ object Dsl {
   /** Aggregate function: returns the sum of all values in the expression. */
   def sum(e: Column): Column = Sum(e.expr)
 
+  /** Aggregate function: returns the sum of all values in the given column. */
+  def sum(columnName: String): Column = sum(Column(columnName))
+
   /** Aggregate function: returns the sum of distinct values in the expression. */
   def sumDistinct(e: Column): Column = SumDistinct(e.expr)
 
+  /** Aggregate function: returns the sum of distinct values in the expression. */
+  def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
+
   /** Aggregate function: returns the number of items in a group. */
   def count(e: Column): Column = Count(e.expr)
 
+  /** Aggregate function: returns the number of items in a group. */
+  def count(columnName: String): Column = count(Column(columnName))
+
   /** Aggregate function: returns the number of distinct items in a group. */
   @scala.annotation.varargs
   def countDistinct(expr: Column, exprs: Column*): Column =
     CountDistinct((expr +: exprs).map(_.expr))
 
+  /** Aggregate function: returns the number of distinct items in a group. */
+  @scala.annotation.varargs
+  def countDistinct(columnName: String, columnNames: String*): Column =
+    countDistinct(Column(columnName), columnNames.map(Column.apply) :_*)
+
   /** Aggregate function: returns the approximate number of distinct items in a group. */
   def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
 
+  /** Aggregate function: returns the approximate number of distinct items in a group. */
+  def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName))
+
   /** Aggregate function: returns the approximate number of distinct items in a group. */
   def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd)
 
+  /** Aggregate function: returns the approximate number of distinct items in a group. */
+  def approxCountDistinct(columnName: String, rsd: Double): Column = {
+    approxCountDistinct(Column(columnName), rsd)
+  }
+
   /** Aggregate function: returns the average of the values in a group. */
   def avg(e: Column): Column = Average(e.expr)
 
+  /** Aggregate function: returns the average of the values in a group. */
+  def avg(columnName: String): Column = avg(Column(columnName))
+
   /** Aggregate function: returns the first value in a group. */
   def first(e: Column): Column = First(e.expr)
 
+  /** Aggregate function: returns the first value of a column in a group. */
+  def first(columnName: String): Column = first(Column(columnName))
+
   /** Aggregate function: returns the last value in a group. */
   def last(e: Column): Column = Last(e.expr)
 
+  /** Aggregate function: returns the last value of the column in a group. */
+  def last(columnName: String): Column = last(Column(columnName))
+
   /** Aggregate function: returns the minimum value of the expression in a group. */
   def min(e: Column): Column = Min(e.expr)
 
+  /** Aggregate function: returns the minimum value of the column in a group. */
+  def min(columnName: String): Column = min(Column(columnName))
+
   /** Aggregate function: returns the maximum value of the expression in a group. */
   def max(e: Column): Column = Max(e.expr)
 
+  /** Aggregate function: returns the maximum value of the column in a group. */
+  def max(columnName: String): Column = max(Column(columnName))
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
similarity index 98%
rename from sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 7963cb03126ba..3c20676355c9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 /**
  * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
 */
-class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
 
   private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
     val namedGroupingExprs = groupingExprs.map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 6b032d3d699a9..fedd7f06ef50a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -90,9 +90,9 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def apply(condition: Column): DataFrame = err()
 
-  override def groupBy(cols: Column*): GroupedDataFrame = err()
+  override def groupBy(cols: Column*): GroupedData = err()
 
-  override def groupBy(col1: String, cols: String*): GroupedDataFrame = err()
+  override def groupBy(col1: String, cols: String*): GroupedData = err()
 
   override def limit(n: Int): DataFrame = err()
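
A minimal usage sketch of the String overloads added above, not part of the diff itself. It assumes a hypothetical DataFrame `df` with "department" and "salary" columns; the Column-based agg on GroupedData that it calls is pre-existing API rather than something this patch introduces.

    import org.apache.spark.sql.Dsl._

    // groupBy(col1: String, cols: String*) resolves each name against the
    // DataFrame's schema (see the DataFrameImpl hunk) and returns a GroupedData.
    val byDept = df.groupBy("department")

    // Each String overload simply wraps the name in a Column, so avg("salary")
    // is shorthand for avg(Column("salary")).
    val summary = byDept.agg(avg("salary"), max("salary"), countDistinct("salary"))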