
Commit f4b8dbb

[SPARK-5605][SQL][DF] Allow using String to specify column name in DSL aggregate functions.

rxin committed Feb 5, 2015 · 1 parent dc101b0
Showing 6 changed files with 53 additions and 16 deletions.
10 changes: 5 additions & 5 deletions python/pyspark/sql.py
@@ -23,7 +23,7 @@
     - L{DataFrame}
       A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
       addition to normal RDD operations, DataFrames also support SQL.
-    - L{GroupedDataFrame}
+    - L{GroupedData}
     - L{Column}
       Column is a DataFrame with a single column.
     - L{Row}
@@ -62,7 +62,7 @@
     "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
     "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
     "ShortType", "ArrayType", "MapType", "StructField", "StructType",
-    "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", "Dsl",
+    "SQLContext", "HiveContext", "DataFrame", "GroupedData", "Column", "Row", "Dsl",
     "SchemaRDD"]


@@ -2231,7 +2231,7 @@ def filter(self, condition):

     def groupBy(self, *cols):
         """ Group the :class:`DataFrame` using the specified columns,
-        so we can run aggregation on them. See :class:`GroupedDataFrame`
+        so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.

         >>> df.groupBy().avg().collect()
@@ -2244,7 +2244,7 @@ def groupBy(self, *cols):
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
         jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
-        return GroupedDataFrame(jdf, self.sql_ctx)
+        return GroupedData(jdf, self.sql_ctx)

     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
@@ -2308,7 +2308,7 @@ def _api(self):
     return _api


-class GroupedDataFrame(object):
+class GroupedData(object):

     """
     A set of methods for aggregations on a :class:`DataFrame`,
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -290,7 +290,7 @@ trait DataFrame extends RDDApi[Row] {

   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
-   * See [[GroupedDataFrame]] for all the available aggregate functions.
+   * See [[GroupedData]] for all the available aggregate functions.
    *
    * {{{
    * // Compute the average for all numeric columns grouped by department.
@@ -304,11 +304,11 @@ trait DataFrame extends RDDApi[Row] {
    * }}}
    */
   @scala.annotation.varargs
-  def groupBy(cols: Column*): GroupedDataFrame
+  def groupBy(cols: Column*): GroupedData

   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
-   * See [[GroupedDataFrame]] for all the available aggregate functions.
+   * See [[GroupedData]] for all the available aggregate functions.
    *
    * This is a variant of groupBy that can only group by existing columns using column names
    * (i.e. cannot construct expressions).
@@ -325,7 +325,7 @@ trait DataFrame extends RDDApi[Row] {
    * }}}
    */
   @scala.annotation.varargs
-  def groupBy(col1: String, cols: String*): GroupedDataFrame
+  def groupBy(col1: String, cols: String*): GroupedData

   /**
    * (Scala-specific) Compute aggregates by specifying a map from column name to
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -201,13 +201,13 @@ private[sql] class DataFrameImpl protected[sql](
     filter(condition)
   }

-  override def groupBy(cols: Column*): GroupedDataFrame = {
-    new GroupedDataFrame(this, cols.map(_.expr))
+  override def groupBy(cols: Column*): GroupedData = {
+    new GroupedData(this, cols.map(_.expr))
   }

-  override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
+  override def groupBy(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
+    new GroupedData(this, colNames.map(colName => resolve(colName)))
   }

   override def limit(n: Int): DataFrame = {
37 changes: 37 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -94,38 +94,75 @@ object Dsl {
   /** Aggregate function: returns the sum of all values in the expression. */
   def sum(e: Column): Column = Sum(e.expr)

+  /** Aggregate function: returns the sum of all values in the given column. */
+  def sum(columnName: String): Column = sum(Column(columnName))
+
   /** Aggregate function: returns the sum of distinct values in the expression. */
   def sumDistinct(e: Column): Column = SumDistinct(e.expr)

+  /** Aggregate function: returns the sum of distinct values in the expression. */
+  def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
+
   /** Aggregate function: returns the number of items in a group. */
   def count(e: Column): Column = Count(e.expr)

+  /** Aggregate function: returns the number of items in a group. */
+  def count(columnName: String): Column = count(Column(columnName))
+
   /** Aggregate function: returns the number of distinct items in a group. */
   @scala.annotation.varargs
   def countDistinct(expr: Column, exprs: Column*): Column =
     CountDistinct((expr +: exprs).map(_.expr))

+  /** Aggregate function: returns the number of distinct items in a group. */
+  @scala.annotation.varargs
+  def countDistinct(columnName: String, columnNames: String*): Column =
+    countDistinct(Column(columnName), columnNames.map(Column.apply) :_*)
+
   /** Aggregate function: returns the approximate number of distinct items in a group. */
   def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)

+  /** Aggregate function: returns the approximate number of distinct items in a group. */
+  def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName))
+
   /** Aggregate function: returns the approximate number of distinct items in a group. */
   def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd)

+  /** Aggregate function: returns the approximate number of distinct items in a group. */
+  def approxCountDistinct(columnName: String, rsd: Double): Column = {
+    approxCountDistinct(Column(columnName), rsd)
+  }
+
   /** Aggregate function: returns the average of the values in a group. */
   def avg(e: Column): Column = Average(e.expr)

+  /** Aggregate function: returns the average of the values in a group. */
+  def avg(columnName: String): Column = avg(Column(columnName))
+
   /** Aggregate function: returns the first value in a group. */
   def first(e: Column): Column = First(e.expr)

+  /** Aggregate function: returns the first value of a column in a group. */
+  def first(columnName: String): Column = first(Column(columnName))
+
   /** Aggregate function: returns the last value in a group. */
   def last(e: Column): Column = Last(e.expr)

+  /** Aggregate function: returns the last value of the column in a group. */
+  def last(columnName: String): Column = last(Column(columnName))
+
   /** Aggregate function: returns the minimum value of the expression in a group. */
   def min(e: Column): Column = Min(e.expr)

+  /** Aggregate function: returns the minimum value of the column in a group. */
+  def min(columnName: String): Column = min(Column(columnName))
+
   /** Aggregate function: returns the maximum value of the expression in a group. */
   def max(e: Column): Column = Max(e.expr)

+  /** Aggregate function: returns the maximum value of the column in a group. */
+  def max(columnName: String): Column = max(Column(columnName))
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////

2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/{GroupedDataFrame.scala → GroupedData.scala}
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 /**
  * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
  */
-class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {

   private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
     val namedGroupingExprs = groupingExprs.map {
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -90,9 +90,9 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends

   override def apply(condition: Column): DataFrame = err()

-  override def groupBy(cols: Column*): GroupedDataFrame = err()
+  override def groupBy(cols: Column*): GroupedData = err()

-  override def groupBy(col1: String, cols: String*): GroupedDataFrame = err()
+  override def groupBy(col1: String, cols: String*): GroupedData = err()

   override def limit(n: Int): DataFrame = err()

