Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,25 @@ import org.apache.spark.util.collection.Utils
* techniques.
*/
object RewriteDistinctAggregates extends Rule[LogicalPlan] {
/**
 * Tells whether the given distinct aggregates force this rule to rewrite the query.
 *
 * The result is true when either of the following holds:
 *  - at least one distinct aggregate carries a filter clause; or
 *  - there are no grouping expressions and at least one distinct aggregate has only foldable
 *    children (vacuously true for an aggregate function without children), e.g.
 *    SELECT COUNT(DISTINCT 1). Without this case, non-grouping aggregation queries with such
 *    distinct aggregate expressions would be incorrectly handled by the aggregation strategy,
 *    causing wrong results when working with empty tables.
 *
 * @param distinctAggs the distinct aggregate expressions of the aggregate being inspected
 * @param groupingExpressions the grouping expressions of that aggregate
 * @return true if the query has to be rewritten
 */
private def mustRewrite(
    distinctAggs: Seq[AggregateExpression],
    groupingExpressions: Seq[Expression]): Boolean = {
  // Declared as local defs so each check is only evaluated when the boolean logic needs it.
  def anyFilteredDistinct: Boolean = distinctAggs.exists(agg => agg.filter.isDefined)

  def anyFoldableOnlyDistinct: Boolean = distinctAggs.exists { agg =>
    agg.aggregateFunction.children.forall(child => child.foldable)
  }

  if (anyFilteredDistinct) {
    true
  } else {
    groupingExpressions.isEmpty && anyFoldableOnlyDistinct
  }
}

/**
 * Returns whether the given aggregate may need to be rewritten by this rule.
 *
 * @param a the aggregate node to inspect
 * @return true if `a` has more than one distinct aggregate expression, or if `mustRewrite`
 *         reports that its distinct aggregates require the rewrite
 */
private def mayNeedtoRewrite(a: Aggregate): Boolean = {
  val distinctAggs = collectAggregateExprs(a).filter(_.isDistinct)
  // We need at least two distinct aggregates, or a single distinct aggregate group that
  // `mustRewrite` flags (e.g. one with a filter clause), for this rule because the aggregation
  // strategy can handle a single distinct aggregate group otherwise.
  // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
  distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
Expand Down Expand Up @@ -236,7 +246,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}

// Aggregation strategy can handle queries with a single distinct group without filter clause.
if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) {
// Create the attributes for the grouping id and the group by clause.
val gid = AttributeReference("gid", IntegerType, nullable = false)()
val groupByMap = a.groupingExpressions.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.util.Random
import org.scalatest.matchers.must.Matchers.the

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
Expand Down Expand Up @@ -2150,6 +2151,116 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil)
}
}

test("aggregating with various distinct expressions") {
  // One scenario of the test.
  //  - query: the SQL text to run.
  //  - resultSeq: expected answer, indexed by the number of rows inserted into the table so far
  //    (index 0 = empty table, 1 = one row, 2 = two rows).
  //  - hasExpandNodeInPlan: whether the optimized plan is expected to contain an Expand node.
  case class AggregateTestCase(
      query: String,
      resultSeq: Seq[Seq[Row]],
      hasExpandNodeInPlan: Boolean)

  // Expected answers shared by several scenarios.
  val zeroThenOne = Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1)))
  val alwaysOne = Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1)))

  // The most common scenario shape: 0 on the empty table, 1 afterwards, Expand node present.
  def defaultCase(query: String): AggregateTestCase =
    AggregateTestCase(query, zeroThenOne, hasExpandNodeInPlan = true)

  val t = "t"
  val testCases: Seq[AggregateTestCase] = Seq(
    defaultCase(s"""SELECT COUNT(DISTINCT "col") FROM $t"""),
    defaultCase(s"SELECT COUNT(DISTINCT 1) FROM $t"),
    defaultCase(s"SELECT COUNT(DISTINCT 1 + 2) FROM $t"),
    defaultCase(s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t"),
    AggregateTestCase(
      s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t",
      Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))),
      hasExpandNodeInPlan = true),
    defaultCase(s"""SELECT COUNT(DISTINCT 1, "col") FROM $t"""),
    defaultCase(s"""SELECT COUNT(DISTINCT current_date()) FROM $t"""),
    defaultCase(s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t"""),
    defaultCase(s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t"""),
    defaultCase(s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t"""),
    AggregateTestCase(
      s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col",
      Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))),
      hasExpandNodeInPlan = false),
    defaultCase(s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1"),
    AggregateTestCase(
      s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0",
      Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))),
      hasExpandNodeInPlan = false),
    AggregateTestCase(
      s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
      alwaysOne,
      hasExpandNodeInPlan = false),
    AggregateTestCase(
      s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)",
      alwaysOne,
      hasExpandNodeInPlan = false),
    AggregateTestCase(
      s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
      alwaysOne,
      hasExpandNodeInPlan = false),
    defaultCase(s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"),
    AggregateTestCase(
      s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""",
      Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))),
      hasExpandNodeInPlan = true),
    AggregateTestCase(
      s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""",
      Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))),
      hasExpandNodeInPlan = true)
  )

  withTable(t) {
    sql(s"create table $t(col int) using parquet")
    // `rowCount` is both the value inserted at this step and the number of rows the table
    // holds afterwards, which is why it is also used as the index into `resultSeq`.
    Seq(0, 1, 2).foreach { rowCount =>
      if (rowCount != 0) {
        sql(s"insert into $t(col) values($rowCount)")
      }
      testCases.foreach { testCase =>
        val df = sql(testCase.query)
        checkAnswer(df, testCase.resultSeq(rowCount))
        val planHasExpand = df.queryExecution.optimizedPlan.collectFirst {
          case _: Expand => true
        }.isDefined
        // The clue pinpoints the failing scenario; without it a failure would only report
        // a bare boolean mismatch.
        assert(
          planHasExpand == testCase.hasExpandNodeInPlan,
          s"Unexpected Expand node presence in the optimized plan of '${testCase.query}' " +
            s"with $rowCount row(s) in table $t")
      }
    }
  }
}
}

// Record with a single optional Double field `c`.
// NOTE(review): its usages lie outside the visible excerpt; presumably a fixture for tests in
// this suite. Confirm before relying on that.
case class B(c: Option[Double])
Expand Down