diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ddfe80443d561..66d6045f54e43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -252,6 +252,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // This batch must run after "Decimal Optimizations", as that one may change the // aggregate distinct column Batch("Distinct Aggregate Rewrite", Once, + RewriteCountDistinctConditional, RewriteDistinctAggregates, OptimizeExpand), Batch("Object Expressions Optimization", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala new file mode 100644 index 0000000000000..18861ec0fb419 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Rewrites COUNT(DISTINCT IF(cond, base, NULL)) and
+ * COUNT(DISTINCT CASE WHEN cond THEN base END) into
+ * COUNT(DISTINCT base) FILTER (WHERE cond).
+ *
+ * This canonicalization reduces the number of distinct groups seen by
+ * RewriteDistinctAggregates from N (one per unique conditional expression) down to 1
+ * (all share the same base column), collapsing the Expand factor from Nx to 1x.
+ *
+ * Correctness: COUNT DISTINCT ignores NULLs, so nulling out rows where !cond
+ * is semantically identical to filtering those rows out entirely.
+ *
+ * NOTE(review): IF / CASE WHEN evaluate `base` only on rows where `cond` holds. Confirm the
+ * FILTER form keeps `base` equally guarded; if it does not, a `base` that can fail
+ * (e.g. under ANSI mode) would newly be evaluated on rows where `cond` is false or null.
+ */
+object RewriteCountDistinctConditional extends Rule[LogicalPlan] {
+
+  /** Rewrites every eligible aggregate expression in `plan`; a no-op when the conf is off. */
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) {
+      plan
+    } else {
+      plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) {
+        case agg: Aggregate => agg.transformExpressionsUp {
+          // A single-argument COUNT(DISTINCT ...) that carries no FILTER clause yet.
+          case ae @ AggregateExpression(count @ Count(Seq(arg)), _, true, None, _) =>
+            extractCondAndBase(arg).fold(ae) { case (cond, base) =>
+              ae.copy(
+                aggregateFunction = count.copy(children = Seq(base)),
+                filter = Some(cond))
+            }
+        }
+      }
+    }
+  }
+
+  /**
+   * Matches:
+   *   IF(cond, base, null)
+   *   CASE WHEN cond THEN base END
+   *   CASE WHEN cond THEN base ELSE NULL END
+   *
+   * The analyzer may wrap the null branch in a Cast for type alignment, so the
+   * null check is done after unwrapping any surrounding Casts.
+   *
+   * Returns None for anything else, including IF(cond, base, fallback) where
+   * fallback is not null -- those change semantics and must not be rewritten.
+   */
+  private def extractCondAndBase(expr: Expression): Option[(Expression, Expression)] =
+    expr match {
+      case If(cond, base, e) if isNullExpr(e) => Some((cond, base))
+      case CaseWhen(Seq((cond, base)), None) => Some((cond, base))
+      case CaseWhen(Seq((cond, base)), Some(e)) if isNullExpr(e) => Some((cond, base))
+      case _ => None
+    }
+
+  /** True when `e` is a null literal, possibly wrapped in one or more Casts. */
+  private def isNullExpr(e: Expression): Boolean = e match {
+    case Literal(null, _) => true
+    case Cast(child, _, _, _) => isNullExpr(child)
+    case _ => false
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5ed831f20f394..2c16a5c4e876a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1366,6 +1366,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED =
+    buildConf("spark.sql.optimizer.rewriteCountDistinctConditional.enabled")
+      .doc("When true, rewrites COUNT(DISTINCT IF(cond, base, NULL)) and " +
+        "COUNT(DISTINCT CASE WHEN cond THEN base END) into " +
+        "COUNT(DISTINCT base) FILTER (WHERE cond). This reduces the Expand factor " +
+        "in RewriteDistinctAggregates from Nx to 1x when multiple conditional distinct " +
+        "counts share the same base column.")
+      .version("4.2.0")
+      .withBindingPolicy(ConfigBindingPolicy.SESSION)
+      .booleanConf
+      .createWithDefault(false)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -8473,6 +8485,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def decorrelateInnerQueryEnabledForExistsIn: Boolean =
     !getConf(SQLConf.DECORRELATE_EXISTS_IN_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED)
 
+  def rewriteCountDistinctConditionalEnabled: Boolean =
+    getConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED)
+
   def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)
 
   def plannedWriteEnabled: Boolean = getConf(SQLConf.PLANNED_WRITE_ENABLED)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala
new file mode 100644
index 0000000000000..88520a472e45b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression, If, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.IntegerType
+
+class RewriteCountDistinctConditionalSuite extends PlanTest {
+
+  val testRelation = LocalRelation(
+    Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string)
+
+  /** Builds COUNT(DISTINCT IF(cond, base, NULL)). */
+  private def countDistinctIf(cond: Expression, base: Expression): Expression = {
+    Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = true)
+  }
+
+  /** Builds COUNT(DISTINCT CASE WHEN cond THEN base END). */
+  private def countDistinctCaseWhen(cond: Expression, base: Expression): Expression = {
+    Count(CaseWhen(Seq((cond, base)), None)).toAggregateExpression(isDistinct = true)
+  }
+
+  /** Builds COUNT(DISTINCT CASE WHEN cond THEN base ELSE NULL END). */
+  private def countDistinctCaseWhenElseNull(cond: Expression, base: Expression): Expression = {
+    val caseWhen = CaseWhen(Seq((cond, base)), Some(Literal(null)))
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  /** Builds COUNT(DISTINCT base) FILTER (WHERE cond), the form the rewrite should produce. */
+  private def countDistinctWithFilter(cond: Expression, base: Expression): Expression = {
+    Count(base).toAggregateExpression(isDistinct = true, filter = Some(cond))
+  }
+
+  test("disabled by default") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+      .analyze
+
+    val optimized = RewriteCountDistinctConditional(input)
+    comparePlans(optimized, input)
+  }
+
+  test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) to COUNT(DISTINCT col) FILTER") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+
+      val expected = testRelation
+        .groupBy(Symbol("a"))(
+          countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+        .analyze
+
+      comparePlans(optimized, expected)
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) to COUNT(DISTINCT col) FILTER") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        countDistinctCaseWhen(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+
+      val expected = testRelation
+        .groupBy(Symbol("a"))(
+          countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+        .analyze
+
+      comparePlans(optimized, expected)
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END)") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        countDistinctCaseWhenElseNull(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+
+      val expected = testRelation
+        .groupBy(Symbol("a"))(
+          countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+        .analyze
+
+      comparePlans(optimized, expected)
+    }
+  }
+
+  test("multiple conditional distinct counts collapse to single distinct group") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"),
+        countDistinctIf(Symbol("b") > 3, Symbol("c")).as("cnt3"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+
+      val expected = testRelation
+        .groupBy(Symbol("a"))(
+          countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+          countDistinctWithFilter(Symbol("b") > 2, Symbol("c")).as("cnt2"),
+          countDistinctWithFilter(Symbol("b") > 3, Symbol("c")).as("cnt3"))
+        .analyze
+
+      comparePlans(optimized, expected)
+
+      // Verify RewriteDistinctAggregates sees only 1 distinct group
+      val rewritten = RewriteDistinctAggregates(optimized)
+      // Should be rewritten (not same as input) because there are now multiple
+      // distinct expressions with the same base column
+      assert(rewritten != optimized)
+    }
+  }
+
+  test("do not rewrite IF with non-null else branch") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        Count(If(Symbol("b") > 1, Symbol("c"), Literal(0, IntegerType)))
+          .toAggregateExpression(isDistinct = true)
+          .as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+      comparePlans(optimized, input)
+    }
+  }
+
+  test("do not rewrite non-distinct COUNT") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = false)
+          .as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+      comparePlans(optimized, input)
+    }
+  }
+
+  test("do not rewrite when FILTER already exists") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === "x"))
+          .as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+      comparePlans(optimized, input)
+    }
+  }
+
+  test("do not rewrite multi-branch CASE WHEN") {
+    val caseWhen = new CaseWhen(
+      Seq((Symbol("b") > Literal(1), Symbol("c")), (Symbol("b") > Literal(2), Symbol("a"))),
+      Some(Literal(null)))
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        Count(caseWhen).toAggregateExpression(isDistinct = true).as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+      comparePlans(optimized, input)
+    }
+  }
+
+  test("do not rewrite SUM(DISTINCT IF(...))") {
+    val input = testRelation
+      .groupBy(Symbol("a"))(
+        Sum(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = true)
+          .as("sum1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteCountDistinctConditional(input)
+      comparePlans(optimized, input)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala
new file mode 100644
index 0000000000000..e83ce5b380c3a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * End-to-end result checks for the COUNT(DISTINCT <conditional>) to FILTER rewrite.
+ */
+class RewriteCountDistinctConditionalQuerySuite extends QueryTest with SharedSparkSession {
+
+  /**
+   * Runs `conditionalSql` with the rewrite enabled and disabled and checks that both runs
+   * return the rows of the explicit-FILTER query `filterSql`. Rows are compared with
+   * `checkAnswer` because GROUP BY does not guarantee the order of its output. The caller
+   * creates and drops the temp view the queries read from.
+   */
+  private def checkRewriteAndResult(
+      conditionalSql: String,
+      filterSql: String): Unit = {
+    val expected = spark.sql(filterSql).collect().toSeq
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> enabled) {
+        withClue(s"rewrite enabled = $enabled: ") {
+          checkAnswer(spark.sql(conditionalSql), expected)
+        }
+      }
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) correctness") {
+    withTempView("t") {
+      spark.range(7)
+        .selectExpr(
+          "cast(id % 3 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 100 as int) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) correctness") {
+    withTempView("t") {
+      spark.range(7)
+        .selectExpr(
+          "cast(id % 3 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 100 as string) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 END) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END) correctness") {
+    withTempView("t") {
+      spark.range(6)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 1.0 as double) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        """SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 ELSE NULL END)
+          |FROM t GROUP BY key""".stripMargin,
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite with no GROUP BY") {
+    withTempView("t") {
+      spark.range(5)
+        .selectExpr(
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then null else cast(id * 100 as int) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t",
+        "SELECT COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t")
+    }
+  }
+
+  test("rewrite with all NULLs in conditional branch") {
+    withTempView("t") {
+      spark.range(3)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 5 as int) as col1",
+          "cast(id * 100 as int) as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite with duplicates in base column") {
+    withTempView("t") {
+      spark.range(6)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then 100 when id % 3 = 1 then 100 else 200 end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("multiple conditional distinct counts collapse and produce correct results") {
+    withTempView("t") {
+      spark.range(5)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then null else cast(id * 100 as int) end as col2",
+          "case when id % 4 = 0 then null else cast(id * 10 as string) end as col3")
+        .createOrReplaceTempView("t")
+
+      val conditionalSql =
+        """SELECT key,
+          | COUNT(DISTINCT IF(col1 > 10, col2, NULL)) as cnt1,
+          | COUNT(DISTINCT IF(col1 > 5, col3, NULL)) as cnt2
+          |FROM t GROUP BY key""".stripMargin
+
+      val filterSql =
+        """SELECT key,
+          | COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) as cnt1,
+          | COUNT(DISTINCT col3) FILTER (WHERE col1 > 5) as cnt2
+          |FROM t GROUP BY key""".stripMargin
+
+      checkRewriteAndResult(conditionalSql, filterSql)
+    }
+  }
+
+  test("rewrite does not affect COUNT(DISTINCT IF(cond, col, non_null))") {
+    withTempView("t") {
+      spark.range(3)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "cast(id * 100 as int) as col2")
+        .createOrReplaceTempView("t")
+
+      val sqlText = "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, 0)) FROM t GROUP BY key"
+
+      val withoutRewrite = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "false") {
+        spark.sql(sqlText).collect().toSeq
+      }
+
+      // Order-insensitive comparison: GROUP BY does not guarantee the order of its output.
+      withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+        withClue("Non-null ELSE branch should not be rewritten: ") {
+          checkAnswer(spark.sql(sqlText), withoutRewrite)
+        }
+      }
+    }
+  }
+
+  test("rewrite is present in optimized plan") {
+    withTempView("t") {
+      spark.range(2)
+        .selectExpr(
+          "cast(id + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "cast(id * 100 as int) as col2")
+        .createOrReplaceTempView("t")
+
+      val planStr = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+        val df = spark.sql(
+          "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key")
+        df.queryExecution.optimizedPlan.toString
+      }
+
+      assert(planStr.contains("FILTER"),
+        s"Optimized plan should contain FILTER clause. Plan:\n$planStr")
+    }
+  }
+}