Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-40193][SQL] Merge subquery plans with different filters #37630

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4434,6 +4434,24 @@ object SQLConf {
.booleanConf
.createOptional

// Internal boolean config (key `spark.sql.planMerge.filterPropagation.enabled`), on by default.
// The doc text is a plain literal: it has no interpolated values, so no `s` interpolator is used.
val PLAN_MERGE_FILTER_PROPAGATION_ENABLED =
  buildConf("spark.sql.planMerge.filterPropagation.enabled")
    .internal()
    .doc("When set to true different filters can be propagated up to aggregates.")
    .version("4.0.0")
    .booleanConf
    .createWithDefault(true)

// Internal double config (key `spark.sql.planMerge.filterPropagation.maxCost`), default 100.
// Valid values are any non-negative number, or exactly -1 which, per the doc text, means filter
// propagation is always allowed. The validation message spells out both accepted cases.
val PLAN_MERGE_FILTER_PROPAGATION_MAX_COST =
  buildConf("spark.sql.planMerge.filterPropagation.maxCost")
    .internal()
    .doc("The maximum allowed additional cost of merging. By setting this value to -1 filter " +
      "propagation is always allowed.")
    .version("4.0.0")
    .doubleConf
    .checkValue(c => c >= 0 || c == -1, "The maximum allowed cost must be non-negative or -1")
    .createWithDefault(100)

val ERROR_MESSAGE_FORMAT = buildConf("spark.sql.error.messageFormat")
.doc("When PRETTY, the error message consists of textual representation of error class, " +
"message and query context. The MINIMAL and STANDARD formats are pretty JSON formats where " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, Collect
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

class MergeScalarSubqueriesSuite extends PlanTest {

Expand Down Expand Up @@ -597,4 +598,279 @@ class MergeScalarSubqueriesSuite extends PlanTest {

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

// Three scalar subqueries over `testRelation` that differ only in their filter on `b`.
// With filter propagation enabled, the expected plan is a single CTE definition whose `Filter`
// is the disjunction of the three conditions and whose aggregates carry the original conditions
// as aggregate filters. With it disabled, the expected plan is the original query.
test("Merging subqueries with different filters") {
  Seq(true, false).foreach { mergeEnabled =>
    // Reset the CTE id generator on every iteration, as the other looped tests in this suite do,
    // so the expected definition id 0 used below holds for each configuration.
    CTERelationDef.curId.set(0)
    withSQLConf(SQLConf.PLAN_MERGE_FILTER_PROPAGATION_ENABLED.key -> s"$mergeEnabled") {
      val subquery1 =
        ScalarSubquery(testRelation.where($"b" > 0).groupBy()(max($"a").as("max_a")))
      val subquery2 =
        ScalarSubquery(testRelation.where($"b" < 0).groupBy()(sum($"a").as("sum_a")))
      val subquery3 =
        ScalarSubquery(testRelation.where($"b" === 0).groupBy()(avg($"a").as("avg_a")))
      val originalQuery = testRelation
        .select(
          subquery1,
          subquery2,
          subquery3)

      val correctAnswer = if (mergeEnabled) {
        val mergedSubquery = testRelation
          .where($"b" > 0 || $"b" < 0 || $"b" === 0)
          .groupBy()(
            max($"a", Some($"b" > 0)).as("max_a"),
            sum($"a", Some($"b" < 0)).as("sum_a"),
            avg($"a", Some($"b" === 0)).as("avg_a"))
          .select(CreateNamedStruct(Seq(
            Literal("max_a"), $"max_a",
            Literal("sum_a"), $"sum_a",
            Literal("avg_a"), $"avg_a"
          )).as("mergedValue"))
        val analyzedMergedSubquery = mergedSubquery.analyze
        WithCTE(
          testRelation
            .select(
              extractorExpression(0, analyzedMergedSubquery.output, 0),
              extractorExpression(0, analyzedMergedSubquery.output, 1),
              extractorExpression(0, analyzedMergedSubquery.output, 2)),
          Seq(definitionNode(analyzedMergedSubquery, 0)))
      } else {
        // Propagation disabled: the optimizer is expected to leave the query as it is.
        originalQuery
      }

      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
    }
  }
}

// Exercises the `maxCost` limit of filter propagation with two subqueries whose filters differ.
test("Merging subqueries with different filters - cost limit") {
  // Fresh predicate trees on every use, exactly as if they were written inline.
  def bPositive = $"b" > 0
  def bNegative = $"b" < 0

  val maxSubquery =
    ScalarSubquery(testRelation.where(bPositive).groupBy()(max($"a").as("max_a")))
  val sumSubquery =
    ScalarSubquery(testRelation.where(bNegative).groupBy()(sum($"a").as("sum_a")))
  val originalQuery = testRelation.select(maxSubquery, sumSubquery)

  for (maxCost <- Seq(0, 7, 8, -1)) {
    CTERelationDef.curId.set(0)
    withSQLConf(SQLConf.PLAN_MERGE_FILTER_PROPAGATION_MAX_COST.key -> maxCost.toString) {
      // Compared with either original subquery the merged plan has 4 extra expressions:
      // the other side's condition plus the `||` in the merged `Filter`, and the two
      // conditions (`b > 0`, `b < 0`) attached to the merged `Aggregate`.
      // The test expects no merging for limits 0 and 7, and merging for 8 and for -1.
      val correctAnswer = if (maxCost >= 0 && maxCost < 8) {
        originalQuery
      } else {
        val analyzedMerged = testRelation
          .where(bPositive || bNegative)
          .groupBy()(
            max($"a", Some(bPositive)).as("max_a"),
            sum($"a", Some(bNegative)).as("sum_a"))
          .select(CreateNamedStruct(Seq(
            Literal("max_a"), $"max_a",
            Literal("sum_a"), $"sum_a"
          )).as("mergedValue"))
          .analyze
        val mergedOutput = analyzedMerged.output
        WithCTE(
          testRelation.select(
            extractorExpression(0, mergedOutput, 0),
            extractorExpression(0, mergedOutput, 1)),
          Seq(definitionNode(analyzedMerged, 0)))
      }
      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
    }
  }
}

// First subquery is filtered, second is not: the expected merged plan has no `Filter` node and
// only the first aggregate carries the original condition as its aggregate filter.
test("Merging subqueries with and without filters") {
  // Fresh predicate tree on every use, exactly as if it were written inline.
  def bPositive = $"b" > 0

  val filteredMax =
    ScalarSubquery(testRelation.where(bPositive).groupBy()(max($"a").as("max_a")))
  val unfilteredCount = ScalarSubquery(testRelation.groupBy()(count($"a").as("cnt_a")))
  val originalQuery = testRelation.select(filteredMax, unfilteredCount)

  val analyzedMerged = testRelation
    .groupBy()(
      max($"a", Some(bPositive)).as("max_a"),
      count($"a").as("cnt_a"))
    .select(CreateNamedStruct(Seq(
      Literal("max_a"), $"max_a",
      Literal("cnt_a"), $"cnt_a"
    )).as("mergedValue"))
    .analyze
  val mergedOutput = analyzedMerged.output
  val correctAnswer = WithCTE(
    testRelation.select(
      extractorExpression(0, mergedOutput, 0),
      extractorExpression(0, mergedOutput, 1)),
    Seq(definitionNode(analyzedMerged, 0)))

  // When one side has no filter the physical scans surely overlap, so the merge is expected
  // regardless of the cost limit: check both the strictest limit (0) and no limit (-1).
  for (maxCost <- Seq(0, -1)) {
    CTERelationDef.curId.set(0)
    withSQLConf(SQLConf.PLAN_MERGE_FILTER_PROPAGATION_MAX_COST.key -> maxCost.toString) {
      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
    }
  }
}

// First subquery is unfiltered, second is filtered: the expected merged plan has no `Filter`
// node and only the second aggregate carries the original condition as its aggregate filter.
test("Merging subqueries without and with filters") {
  // Fresh predicate tree on every use, exactly as if it were written inline.
  def bPositive = $"b" > 0

  val unfilteredCount = ScalarSubquery(testRelation.groupBy()(count($"a").as("cnt_a")))
  val filteredMax =
    ScalarSubquery(testRelation.where(bPositive).groupBy()(max($"a").as("max_a")))
  val originalQuery = testRelation.select(unfilteredCount, filteredMax)

  val analyzedMerged = testRelation
    .groupBy()(
      count($"a").as("cnt_a"),
      max($"a", Some(bPositive)).as("max_a"))
    .select(CreateNamedStruct(Seq(
      Literal("cnt_a"), $"cnt_a",
      Literal("max_a"), $"max_a"
    )).as("mergedValue"))
    .analyze
  val mergedOutput = analyzedMerged.output
  val correctAnswer = WithCTE(
    testRelation.select(
      extractorExpression(0, mergedOutput, 0),
      extractorExpression(0, mergedOutput, 1)),
    Seq(definitionNode(analyzedMerged, 0)))

  // When one side has no filter the physical scans surely overlap, so the merge is expected
  // regardless of the cost limit: check both the strictest limit (0) and no limit (-1).
  for (maxCost <- Seq(0, -1)) {
    CTERelationDef.curId.set(0)
    withSQLConf(SQLConf.PLAN_MERGE_FILTER_PROPAGATION_MAX_COST.key -> maxCost.toString) {
      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
    }
  }
}

// The first subquery applies `b > 0` in a `Filter` node, the second applies the same condition
// as an aggregate filter (`max(a, Some(b > 0))`; the test name calls this "having").
// Both are expected to resolve to the same single merged aggregate, so both extractors read
// field 0 of the merged struct.
test("Merging subqueries with same condition in filter and in having") {
  // Fresh predicate tree on every use, exactly as if it were written inline.
  def bPositive = $"b" > 0

  val viaFilterNode =
    ScalarSubquery(testRelation.where(bPositive).groupBy()(max($"a").as("max_a")))
  val viaAggregateFilter =
    ScalarSubquery(testRelation.groupBy()(max($"a", Some(bPositive)).as("max_a_2")))
  val originalQuery = testRelation.select(viaFilterNode, viaAggregateFilter)

  val analyzedMerged = testRelation
    .groupBy()(max($"a", Some(bPositive)).as("max_a"))
    .select(CreateNamedStruct(Seq(Literal("max_a"), $"max_a")).as("mergedValue"))
    .analyze
  val mergedOutput = analyzedMerged.output
  val correctAnswer = WithCTE(
    testRelation.select(
      extractorExpression(0, mergedOutput, 0),
      extractorExpression(0, mergedOutput, 0)),
    Seq(definitionNode(analyzedMerged, 0)))

  comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

// Mirror of the "filter and in having" case with the subquery order swapped: the first subquery
// applies `b > 0` as an aggregate filter (the test name calls this "having"), the second applies
// it in a `Filter` node. Both extractors are expected to read field 0 of the merged struct.
test("Merging subqueries with same condition in having and in filter") {
  // Fresh predicate tree on every use, exactly as if it were written inline.
  def bPositive = $"b" > 0

  val viaAggregateFilter =
    ScalarSubquery(testRelation.groupBy()(max($"a", Some(bPositive)).as("max_a")))
  val viaFilterNode =
    ScalarSubquery(testRelation.where(bPositive).groupBy()(max($"a").as("max_a_2")))
  val originalQuery = testRelation.select(viaAggregateFilter, viaFilterNode)

  val analyzedMerged = testRelation
    .groupBy()(max($"a", Some(bPositive)).as("max_a"))
    .select(CreateNamedStruct(Seq(Literal("max_a"), $"max_a")).as("mergedValue"))
    .analyze
  val mergedOutput = analyzedMerged.output
  val correctAnswer = WithCTE(
    testRelation.select(
      extractorExpression(0, mergedOutput, 0),
      extractorExpression(0, mergedOutput, 0)),
    Seq(definitionNode(analyzedMerged, 0)))

  comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

// Each subquery stacks two `Filter` nodes: a lower one on `b` and an upper one on `c`.
// The expected merged plan keeps two stacked filters and gives every aggregate the conjunction
// of its original conditions as aggregate filter.
test("Merging subqueries with different filters, multiple filters propagated") {
  // Fresh predicate trees on every use, exactly as if they were written inline.
  def bPositive = $"b" > 0
  def bNegative = $"b" < 0
  def cIsA = $"c" === "a"
  def cIsB = $"c" === "b"
  def cIsC = $"c" === "c"

  val maxSubquery = ScalarSubquery(
    testRelation.where(bPositive).where(cIsA).groupBy()(max($"a").as("max_a")))
  val avgSubquery = ScalarSubquery(
    testRelation.where(bPositive).where(cIsB).groupBy()(avg($"a").as("avg_a")))
  val countSubquery = ScalarSubquery(
    testRelation.where(bNegative).where(cIsC).groupBy()(count($"a").as("cnt_a")))
  val originalQuery = testRelation.select(maxSubquery, avgSubquery, countSubquery)

  val analyzedMerged = testRelation
    .where(bPositive || bNegative)
    .where(bPositive && (cIsA || cIsB) || bNegative && cIsC)
    .groupBy()(
      max($"a", Some(bPositive && cIsA)).as("max_a"),
      avg($"a", Some(bPositive && cIsB)).as("avg_a"),
      count($"a", Some(bNegative && cIsC)).as("cnt_a"))
    .select(CreateNamedStruct(Seq(
      Literal("max_a"), $"max_a",
      Literal("avg_a"), $"avg_a",
      Literal("cnt_a"), $"cnt_a"
    )).as("mergedValue"))
    .analyze
  val mergedOutput = analyzedMerged.output
  val correctAnswer = WithCTE(
    testRelation.select(
      extractorExpression(0, mergedOutput, 0),
      extractorExpression(0, mergedOutput, 1),
      extractorExpression(0, mergedOutput, 2)),
    Seq(definitionNode(analyzedMerged, 0)))

  comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

// Same scenario as the "multiple filters propagated" case but with the filter order swapped:
// the lower `Filter` is on `c` and the upper one is on `b`.
test("Merging subqueries with different filters, multiple filters propagated 2") {
  // Fresh predicate trees on every use, exactly as if they were written inline.
  def bPositive = $"b" > 0
  def bNegative = $"b" < 0
  def cIsA = $"c" === "a"
  def cIsB = $"c" === "b"
  def cIsC = $"c" === "c"

  val maxSubquery = ScalarSubquery(
    testRelation.where(cIsA).where(bPositive).groupBy()(max($"a").as("max_a")))
  val avgSubquery = ScalarSubquery(
    testRelation.where(cIsB).where(bPositive).groupBy()(avg($"a").as("avg_a")))
  val countSubquery = ScalarSubquery(
    testRelation.where(cIsC).where(bNegative).groupBy()(count($"a").as("cnt_a")))
  val originalQuery = testRelation.select(maxSubquery, avgSubquery, countSubquery)

  val analyzedMerged = testRelation
    .where(cIsA || cIsB || cIsC)
    .where((cIsA || cIsB) && bPositive || cIsC && bNegative)
    .groupBy()(
      // Note: the `b` related conditions come first in the first two aggregate filters even
      // though the `c` related ones are lower in the original plans. `$"b" > 0` is identical in
      // the first 2 original plans, so it isn't propagated when those are merged; it is
      // propagated only later, when the 3rd plan is merged in.
      max($"a", Some(bPositive && cIsA)).as("max_a"),
      avg($"a", Some(bPositive && cIsB)).as("avg_a"),
      count($"a", Some(cIsC && bNegative)).as("cnt_a"))
    .select(CreateNamedStruct(Seq(
      Literal("max_a"), $"max_a",
      Literal("avg_a"), $"avg_a",
      Literal("cnt_a"), $"cnt_a"
    )).as("mergedValue"))
    .analyze
  val mergedOutput = analyzedMerged.output
  val correctAnswer = WithCTE(
    testRelation.select(
      extractorExpression(0, mergedOutput, 0),
      extractorExpression(0, mergedOutput, 1),
      extractorExpression(0, mergedOutput, 2)),
    Seq(definitionNode(analyzedMerged, 0)))

  comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SparkOptimizer(
RewriteDistinctAggregates) :+
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
PushDownPredicates) :+
Batch("Cleanup filters that cannot be pushed down", Once,
Batch("Cleanup filters that cannot be pushed down", FixedPoint(1),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because BooleanSimplification is not idempotent.

CleanupDynamicPruningFilters,
// cleanup the unnecessary TrueLiteral predicates
BooleanSimplification,
Expand Down