[SPARK-7462] By default retain group by columns in aggregate #5996

Closed · wants to merge 7 commits
R/pkg/R/group.R: 1 addition & 3 deletions
@@ -102,9 +102,7 @@ setMethod("agg",
}
}
jcols <- lapply(cols, function(c) { c@jc })
-      # the GroupedData.agg(col, cols*) API does not contain grouping Column
-      sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping",
-                         x@sgd, listToSeq(jcols))
+      sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
} else {
stop("agg can only support Column or character")
}
python/pyspark/sql/dataframe.py: 1 addition & 1 deletion
@@ -1069,7 +1069,7 @@ def agg(self, *exprs):

>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2), Row(MIN(age)=5)]
+        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala: 12 additions & 3 deletions
@@ -135,8 +135,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
}

/**
-   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
-   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   * Compute aggregates by specifying a series of aggregate columns. Note that this function by
+   * default retains the grouping columns in its output. To not retain grouping columns, set
+   * `spark.sql.retainGroupColumns` to false.
*
* The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
*
@@ -158,7 +159,15 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}
-    DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
Review comment (Contributor): If retainedExprs and aggExprs contain the same expressions, should we distinct them?

Review comment (Contributor, author): I'm not sure we should, since in a DataFrame you can technically have duplicate columns.

Review comment (Contributor): Yes, then I think this is OK; distinct-ing them could confuse users.

+    } else {
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    }
}

/**
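The updated doc comment above describes the new default. As a rough, hedged sketch of what this looks like from the user side (the data, column names, and app name below are illustrative, assuming a local Spark setup of this era):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._

object RetainGroupColumnsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("retain-group-columns"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Illustrative data: two groups, "a" and "b".
    val df = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 4))).toDF("key", "value")

    // With the new default, the grouping column "key" is kept in the result,
    // so the rows come out shaped like (key, SUM(value)): (a, 3) and (b, 4).
    df.groupBy("key").agg(sum($"value")).show()

    // Re-adding the grouping column explicitly now duplicates it, since DataFrames
    // allow duplicate column names (see the review discussion above).
    df.groupBy("key").agg($"key", sum($"value")).printSchema()

    sc.stop()
  }
}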
sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala: 6 additions & 0 deletions
@@ -74,6 +74,9 @@ private[spark] object SQLConf {
// See SPARK-6231.
val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity"

+  // Whether to retain group by columns or not in GroupedData.agg.
+  val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns"

val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"

val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
@@ -242,6 +245,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {

private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean

+  private[spark] def dataFrameRetainGroupColumns: Boolean =
+    getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean
Review comment (Contributor): Increasingly wondering whether DataFrame flags should be scoped; eager analysis affects sql(...) too, not just the DataFrame DSL functions.

Review comment (Contributor, author): Let's talk more about this. If we want to do it, we should do it in 1.4.


/** ********************** SQLConf functionality methods ************ */

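A hedged usage note for the new flag: it can be flipped at runtime through SQLContext.setConf, which is essentially what the spark.sql.retainGroupColumns test further down does through the test context's conf. This fragment assumes the df and sqlContext from the sketch above are in scope:

// Disable the new default: the grouping column is dropped from the aggregate output.
sqlContext.setConf("spark.sql.retainGroupColumns", "false")
println(df.groupBy("key").agg(sum($"value")).columns.mkString(", "))   // SUM(value)

// Restore the default behavior.
sqlContext.setConf("spark.sql.retainGroupColumns", "true")
println(df.groupBy("key").agg(sum($"value")).columns.mkString(", "))   // key, SUM(value)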
sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala: 0 additions & 11 deletions
@@ -72,17 +72,6 @@ private[r] object SQLUtils {
sqlContext.createDataFrame(rowRDD, schema)
}

-  // A helper to include grouping columns in Agg()
-  def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
-    val aggExprs = exprs.map { col =>
-      col.expr match {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.simpleString)()
-      }
-    }
-    gd.toDF(aggExprs)
-  }

def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
df.map(r => rowToRBytes(r))
}
@@ -104,7 +104,7 @@ private[sql] object StatFunctions extends Logging {
/** Generate a table of frequencies for the elements of two columns. */
private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e6.toInt)
+    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
if (counts.length == 1e6.toInt) {
logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
"the pairs. Please try reducing the amount of distinct items in your columns.")
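Because agg now keeps the grouping columns, crossTabulate no longer has to select col1 and col2 explicitly. A small hedged illustration through the public crosstab entry point (assuming it delegates to this helper, and reusing the illustrative df from the sketch above):

// One row per distinct "key", one count column per distinct "value", with the
// first column named "key_value" as built by tableName above. For the sample df
// this gives, modulo column ordering:
//   key_value  1  2  4
//   a          1  1  0
//   b          0  0  1
df.stat.crosstab("key", "value").show()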
@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.DecimalType


class DataFrameAggregateSuite extends QueryTest {

test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
checkAnswer(
testData2.groupBy("a").agg(sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
testData2.groupBy("a").agg(count("*")),
Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
)
checkAnswer(
testData2.groupBy("a").agg(Map("*" -> "count")),
Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
)
checkAnswer(
testData2.groupBy("a").agg(Map("b" -> "sum")),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)

val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
.toDF("key", "value1", "value2", "rest")

checkAnswer(
df1.groupBy("key").min(),
df1.groupBy("key").min("value1", "value2").collect()
)
checkAnswer(
df1.groupBy("key").min("value2"),
Seq(Row("a", 0), Row("b", 4))
)
}

test("spark.sql.retainGroupColumns config") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)

TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false")
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(3), Row(3), Row(3))
)
TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true")
}

test("agg without groups") {
checkAnswer(
testData2.agg(sum('b)),
Row(9)
)
}

test("average") {
checkAnswer(
testData2.agg(avg('a)),
Row(2.0))

// Also check mean
checkAnswer(
testData2.agg(mean('a)),
Row(2.0))

checkAnswer(
testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)

checkAnswer(
decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.agg(avg('a), sumDistinct('a)), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)

checkAnswer(
decimalData.agg(avg('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0)))
// non-partial
checkAnswer(
decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}

test("null average") {
checkAnswer(
testData3.agg(avg('b)),
Row(2.0))

checkAnswer(
testData3.agg(avg('b), countDistinct('b)),
Row(2.0, 1))

checkAnswer(
testData3.agg(avg('b), sumDistinct('b)), // non-partial
Row(2.0, 2.0))
}

test("zero average") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(avg('a)),
Row(null))

checkAnswer(
emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
Row(null, null))
}

test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())

checkAnswer(
testData2.agg(count('a), sumDistinct('a)), // non-partial
Row(6, 6.0))
}

test("null count") {
checkAnswer(
testData3.groupBy('a).agg(count('b)),
Seq(Row(1,0), Row(2, 1))
)

checkAnswer(
testData3.groupBy('a).agg(count('a + 'b)),
Seq(Row(1,0), Row(2, 1))
)

checkAnswer(
testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
Row(2, 1, 2, 2, 1)
)

checkAnswer(
testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
Row(1, 1, 2)
)
}

test("zero count") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
assert(emptyTableData.count() === 0)

checkAnswer(
emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
Row(0, null))
}

test("zero sum") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(sum('a)),
Row(null))
}

test("zero sum distinct") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(sumDistinct('a)),
Row(null))
}

}
@@ -77,8 +77,8 @@ class DataFrameJoinSuite extends QueryTest {
df.join(df, df("key") === df("key") && df("value") === 1),
Row(1, "1", 1, "1") :: Nil)

val left = df.groupBy("key").agg($"key", count("*"))
val right = df.groupBy("key").agg($"key", sum("key"))
val left = df.groupBy("key").agg(count("*"))
val right = df.groupBy("key").agg(sum("key"))
checkAnswer(
left.join(right, left("key") === right("key")),
Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)