groupby -> groupBy.
rxin committed Jan 27, 2015
1 parent 9662c9e commit e8aa3d3
Showing 7 changed files with 45 additions and 37 deletions.
16 changes: 8 additions & 8 deletions python/pyspark/sql.py
@@ -1838,7 +1838,7 @@ class DataFrame(object):
department = sqlContext.parquetFile("...")
        people.filter(people.age > 30).join(department, people.deptId == department.id) \
-              .groupby(department.name, "gender").agg({"salary": "avg", "age": "max"})
+              .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
"""

def __init__(self, jdf, sql_ctx):
@@ -2178,13 +2178,13 @@ def filter(self, condition):

where = filter

-    def groupby(self, *cols):
+    def groupBy(self, *cols):
""" Group the [[DataFrame]] using the specified columns,
so we can run aggregation on them. See :class:`GroupedDataFrame`
for all the available aggregate functions::
-            df.groupby(df.department).avg()
-            df.groupby("department", "gender").agg({
+            df.groupBy(df.department).avg()
+            df.groupBy("department", "gender").agg({
"salary": "avg",
"age": "max",
})
@@ -2194,16 +2194,16 @@ def groupby(self, *cols):
else:
cols = [c._jc for c in cols]
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupby(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)

def agg(self, *exprs):
""" Aggregate on the entire [[DataFrame]] without groups
-        (shorthand for df.groupby.agg())::
+        (shorthand for df.groupBy().agg())::
df.agg({"age": "max", "salary": "avg"})
"""
-        return self.groupby().agg(*exprs)
+        return self.groupBy().agg(*exprs)

def unionAll(self, other):
""" Return a new DataFrame containing union of rows in this
@@ -2266,7 +2266,7 @@ class GroupedDataFrame(object):

"""
A set of methods for aggregations on a :class:`DataFrame`,
-    created by DataFrame.groupby().
+    created by DataFrame.groupBy().
"""

def __init__(self, jdf, sql_ctx):
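The Python rename above simply aligns the method name with the camelCase Scala API changed later in this commit. As a rough sketch of the now-shared call shape, written against the Scala side of this diff — the parquet file and the department/gender/salary/age columns are hypothetical, and a `sqlContext` is assumed to be in scope:

```scala
// Sketch only: `sqlContext` and the data are assumed, not part of this commit.
val people = sqlContext.parquetFile("people.parquet")

// The same camelCase spelling now works from both Python and Scala:
val stats = people
  .groupBy(people("department"), people("gender"))
  .agg(Map("salary" -> "avg", "age" -> "max"))
stats.collect().foreach(println)
```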
2 changes: 1 addition & 1 deletion python/pyspark/tests.py
@@ -986,7 +986,7 @@ def test_column_select(self):
def test_aggregator(self):
from pyspark.sql import Aggregator as Agg
df = self.df
-        g = df.groupby()
+        g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
# self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
37 changes: 22 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -77,7 +77,7 @@ import org.apache.spark.util.Utils
*
 *   people.filter(people("age") > 30)
* .join(department, people("deptId") === department("id"))
- *     .groupby(department("name"), "gender")
+ *     .groupBy(department("name"), "gender")
* .agg(avg(people("salary")), max(people("age")))
* }}}
*/
@@ -331,64 +331,64 @@ class DataFrame protected[sql](
*
* {{{
* // Compute the average for all numeric columns grouped by department.
- *   df.groupby($"department").avg()
+ *   df.groupBy($"department").avg()
*
* // Compute the max age and average salary, grouped by department and gender.
- *   df.groupby($"department", $"gender").agg(Map(
+ *   df.groupBy($"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*/
@scala.annotation.varargs
-  override def groupby(cols: Column*): GroupedDataFrame = {
+  override def groupBy(cols: Column*): GroupedDataFrame = {
new GroupedDataFrame(this, cols.map(_.expr))
}

/**
* Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
* See [[GroupedDataFrame]] for all the available aggregate functions.
*
- * This is a variant of groupby that can only group by existing columns using column names
+ * This is a variant of groupBy that can only group by existing columns using column names
* (i.e. cannot construct expressions).
*
* {{{
* // Compute the average for all numeric columns grouped by department.
* df.groupby("department").avg()
* df.groupBy("department").avg()
*
* // Compute the max age and average salary, grouped by department and gender.
- *   df.groupby($"department", $"gender").agg(Map(
+ *   df.groupBy("department", "gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*/
@scala.annotation.varargs
-  override def groupby(col1: String, cols: String*): GroupedDataFrame = {
+  override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
val colNames: Seq[String] = col1 +: cols
new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
}

/**
* Aggregate on the entire [[DataFrame]] without groups.
 * {{{
- *   // df.agg(...) is a shorthand for df.groupby().agg(...)
+ *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
 *   df.agg(Map("age" -> "max", "salary" -> "avg"))
- *   df.groupby().agg(Map("age" -> "max", "salary" -> "avg"))
+ *   df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
 * }}}
*/
-  override def agg(exprs: Map[String, String]): DataFrame = groupby().agg(exprs)
+  override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)

/**
* Aggregate on the entire [[DataFrame]] without groups.
 * {{{
- *   // df.agg(...) is a shorthand for df.groupby().agg(...)
+ *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
 *   df.agg(max($"age"), avg($"salary"))
- *   df.groupby().agg(max($"age"), avg($"salary"))
+ *   df.groupBy().agg(max($"age"), avg($"salary"))
 * }}}
*/
@scala.annotation.varargs
-  override def agg(expr: Column, exprs: Column*): DataFrame = groupby().agg(expr, exprs :_*)
+  override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)

/**
* Return a new [[DataFrame]] by taking the first `n` rows. The difference between this function
@@ -484,7 +484,14 @@ class DataFrame protected[sql](
/**
* Return the number of rows in the [[DataFrame]].
*/
-  override def count(): Long = groupby().count().rdd.collect().head.getLong(0)
+  override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)

+  /**
+   * Return a new [[DataFrame]] that has exactly `numPartitions` partitions.
+   */
+  override def repartition(numPartitions: Int): DataFrame = {
+    sqlContext.applySchema(rdd.repartition(numPartitions), schema)
+  }
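A note on the newly added method: routing through `sqlContext.applySchema(rdd.repartition(...), schema)` reshuffles the underlying row RDD and then reattaches the original schema, so the result is still a fully typed DataFrame rather than a bare RDD of rows. A hedged usage sketch — the helper name and the 8-partition choice are illustrative only:

```scala
import org.apache.spark.sql.DataFrame

// Sketch: redistribute any DataFrame into exactly 8 partitions.
// The check documents the design point: applySchema preserves the schema.
def intoEightPartitions(df: DataFrame): DataFrame = {
  val df8 = df.repartition(8)
  require(df8.schema == df.schema) // schema survives the shuffle
  df8
}
```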

override def persist(): this.type = {
sqlContext.cacheQuery(this)
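To summarize the two overloads renamed in this file: the Column variant accepts arbitrary expressions, while the String variant only resolves existing column names. A minimal sketch, assuming a `df` with hypothetical department, gender, salary, and age columns:

```scala
import org.apache.spark.sql.dsl._ // the diff's own doc examples use this for $"..." and max/avg

// Column variant: full column expressions are allowed.
val byExpr = df.groupBy($"department").agg(max($"age"), avg($"salary"))

// String variant: existing column names only, no constructed expressions.
val byName = df.groupBy("department", "gender").agg(Map("salary" -> "avg", "age" -> "max"))
```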
sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate


/**
- * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupby]].
+ * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
*/
class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
extends GroupedDataFrameApi {
@@ -62,7 +62,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* df.groupby("department").agg(Map(
* df.groupBy("department").agg(Map(
* "age" -> "max"
* "sum" -> "expense"
* ))
@@ -80,7 +80,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* df.groupby("department").agg(Map(
* df.groupBy("department").agg(Map(
* "age" -> "max"
* "sum" -> "expense"
* ))
@@ -96,7 +96,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* import org.apache.spark.sql.dsl._
* df.groupby("department").agg(max($"age"), sum($"expense"))
* df.groupBy("department").agg(max($"age"), sum($"expense"))
* }}}
*/
@scala.annotation.varargs
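Besides `agg`, the grouped handle exposes shortcut methods such as `avg()` and `count()` (both used elsewhere in this diff). A hedged sketch of the difference, with hypothetical column names; since GroupedDataFrame only holds the frame and grouping expressions, it can be reused:

```scala
// Sketch: reuse one grouped handle for several aggregations.
val perDept = df.groupBy("department")

val averages = perDept.avg()   // averages every numeric column
val sizes    = perDept.count() // row count per department

// agg picks specific column -> function pairs instead of aggregating everything:
val picked = perDept.agg(Map("age" -> "max", "expense" -> "sum"))
```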
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -54,6 +54,7 @@ trait RDDApi[T] {

def count(): Long

+  def repartition(numPartitions: Int): DataFrame
}


@@ -97,10 +98,10 @@ trait DataFrameSpecificApi {
def where(condition: Column): DataFrame

@scala.annotation.varargs
-  def groupby(cols: Column*): GroupedDataFrame
+  def groupBy(cols: Column*): GroupedDataFrame

@scala.annotation.varargs
-  def groupby(col1: String, cols: String*): GroupedDataFrame
+  def groupBy(col1: String, cols: String*): GroupedDataFrame

def agg(exprs: Map[String, String]): DataFrame

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -42,11 +42,11 @@ class DslQuerySuite extends QueryTest {

test("agg") {
checkAnswer(
testData2.groupby("a").agg($"a", sum($"b")),
testData2.groupBy("a").agg($"a", sum($"b")),
Seq(Row(1,3), Row(2,3), Row(3,3))
)
checkAnswer(
testData2.groupby("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
@@ -205,12 +205,12 @@ class DslQuerySuite extends QueryTest {

test("null count") {
checkAnswer(
-      testData3.groupby('a).agg('a, count('b)),
+      testData3.groupBy('a).agg('a, count('b)),
Seq(Row(1,0), Row(2, 1))
)

checkAnswer(
-      testData3.groupby('a).agg('a, count('a + 'b)),
+      testData3.groupBy('a).agg('a, count('a + 'b)),
Seq(Row(1,0), Row(2, 1))
)

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -42,22 +42,22 @@ class PlannerSuite extends FunSuite {
}

test("count is partially aggregated") {
-    val query = testData.groupby('value).agg(count('key)).queryExecution.analyzed
+    val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
val planned = HashAggregation(query).head
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }

assert(aggregations.size === 2)
}

test("count distinct is partially aggregated") {
-    val query = testData.groupby('value).agg(countDistinct('key)).queryExecution.analyzed
+    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}

test("mixed aggregates are partially aggregated") {
val query =
-      testData.groupby('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
+      testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}
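For context on what these planner tests assert: a global count is planned as two Aggregate operators, a partial count inside each partition followed by a final merge, which is why `aggregations.size === 2`. A conceptual RDD-level sketch of that two-step shape — not the planner's actual operators, just the idea the test encodes:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object PartialCountSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 100, numSlices = 4)

    // Partial step: each partition counts its own rows (first Aggregate).
    val partialCounts = rdd.mapPartitions(it => Iterator(it.size.toLong))

    // Final step: merge the per-partition counts (second Aggregate).
    val total = partialCounts.reduce(_ + _)

    println(total) // 100
    sc.stop()
  }
}
```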
