[SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set #21487

Closed · wants to merge 2 commits
@@ -384,9 +384,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

       val (functionsWithDistinct, functionsWithoutDistinct) =
         aggregateExpressions.partition(_.isDistinct)
-      if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+      if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
         // This is a sanity check. We should not reach here when we have multiple distinct
-        // column sets. Our MultipleDistinctRewriter should take care this case.
+        // column sets. Our `RewriteDistinctAggregates` should take care of this case.
         sys.error("You hit a query analyzer bug. Please report your query to " +
           "Spark user mailing list.")
       }
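The whole fix hinges on how the argument lists of the distinct aggregates are compared. `AggregateFunction.children` is an ordered sequence, so before this patch `corr(DISTINCT j, k)` and `corr(DISTINCT k, j)` looked like two different distinct column sets and tripped the sanity check, even though they aggregate over exactly the same values. Mapping each child list through `.toSet` makes the comparison order-insensitive. A minimal stand-alone sketch of the old and new comparison, using plain strings in place of Catalyst expressions:

object DistinctArgSetCheck extends App {
  // Stand-ins for the children of corr(DISTINCT j, k) and corr(DISTINCT k, j).
  val children1 = Seq("j", "k")
  val children2 = Seq("k", "j")

  // Old check: sequences compare element by element, in order, so the two
  // argument lists look like two distinct column sets and the error fires.
  assert(Seq(children1, children2).distinct.length == 2)

  // New check: comparing the argument lists as sets ignores ordering, so
  // both aggregates are recognized as sharing a single argument set.
  assert(Seq(children1, children2).map(_.toSet).distinct.length == 1)
}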
6 changes: 5 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -68,4 +68,8 @@ SELECT 1 from (
   FROM (select 1 as x) a
   WHERE false
 ) b
-where b.z != b.z
+where b.z != b.z;
+
+-- SPARK-24369 multiple distinct aggregations having the same argument set
+SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
+FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y);
11 changes: 10 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 26
+-- Number of queries: 27


 -- !query 0
@@ -241,3 +241,12 @@ where b.z != b.z
 struct<1:int>
 -- !query 25 output

+
+
+-- !query 26
+SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
+FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y)
+-- !query 26 schema
+struct<corr(DISTINCT CAST(x AS DOUBLE), CAST(y AS DOUBLE)):double,corr(DISTINCT CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,count(1):bigint>
+-- !query 26 output
+1.0 1.0 3
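As a quick sanity check on the expected output above (a hand calculation, not part of the PR): DISTINCT collapses the duplicate (2, 2) row, so corr runs over the pairs (1, 1) and (2, 2), which lie on a line and are therefore perfectly correlated, while count(*) still sees all three input rows. The same numbers in a few lines of plain Scala:

object CorrByHand extends App {
  val rows = Seq((1.0, 1.0), (2.0, 2.0), (2.0, 2.0))

  // corr(DISTINCT x, y): Pearson correlation over the de-duplicated pairs.
  val pairs = rows.distinct
  val n = pairs.size.toDouble
  val (xs, ys) = pairs.unzip
  val (mx, my) = (xs.sum / n, ys.sum / n)
  val cov = xs.zip(ys).map { case (x, y) => (x - mx) * (y - my) }.sum
  val sx = math.sqrt(xs.map(x => (x - mx) * (x - mx)).sum)
  val sy = math.sqrt(ys.map(y => (y - my) * (y - my)).sum)
  println(cov / (sx * sy)) // 1.0, matching both corr columns

  // count(*): duplicates are kept.
  println(rows.size) // 3
}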
@@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext {
     testPartialAggregationPlan(query)
   }

+  test("mixed aggregates with same distinct columns") {
+    def assertNoExpand(plan: SparkPlan): Unit = {
+      assert(plan.collect { case e: ExpandExec => e }.isEmpty)
+    }
+
+    withTempView("v") {
+      Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v")
+      // one distinct column
+      val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i")
+      assertNoExpand(query1.queryExecution.executedPlan)
+
+      // two distinct columns
+      val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i")
+      assertNoExpand(query2.queryExecution.executedPlan)
+
+      // two distinct columns in a different order
+      val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
+      assertNoExpand(query3.queryExecution.executedPlan)
+    }
+  }
+
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
def checkPlan(fieldTypes: Seq[DataType]): Unit = {
withTempView("testLimit") {
Expand Down
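For contrast with the new test above, aggregates whose distinct argument sets genuinely differ should still go through the Expand-based rewrite performed by `RewriteDistinctAggregates`. A hedged sketch of a companion test (not part of this PR; it assumes the same suite setup and that the rewrite surfaces as an `ExpandExec` node in the physical plan):

  test("mixed aggregates with different distinct columns (illustrative)") {
    withTempView("v") {
      Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v")
      // count(DISTINCT j) and count(DISTINCT k) have different argument sets,
      // so the planner falls back to the Expand-based rewrite.
      val query = sql("SELECT count(DISTINCT j), count(DISTINCT k) FROM v GROUP BY i")
      assert(query.queryExecution.executedPlan.collect { case e: ExpandExec => e }.nonEmpty)
    }
  }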