From 31c23f60b430fd397b4049d1c65338dd32fab50c Mon Sep 17 00:00:00 2001
From: Mingming Ge <7mming7@gmail.com>
Date: Wed, 30 Aug 2023 14:08:49 +0800
Subject: [PATCH] KE-42353 [SQL] Optimize join if maximum number of rows on one side is 1 (#669)

---
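Note for reviewers: the snippet below is an illustrative sketch, not part of
the patch. It assumes a spark-shell session (`spark` in scope), and the names
t1, t, t2, c1 and c2 are made up for this note. It shows the shape of query
the new rule targets: an equi-join whose one side is known to produce at most
one row (for example a global aggregate without GROUP BY, whose maxRows is 1).

    import org.apache.spark.sql.functions.max

    val t1 = spark.range(100).selectExpr("id AS c1", "id * 2 AS c2")
    val t  = spark.range(100).selectExpr("id AS c1")

    // Global aggregate without grouping, so the optimizer knows maxRows == 1.
    val t2 = t.agg(max("c1").as("c1"))

    // SELECT t1.* FROM t1 INNER JOIN t2 ON t1.c1 = t2.c1
    val q = t1.join(t2, t1("c1") === t2("c1")).select(t1("c1"), t1("c2"))

    // With this patch applied, the optimized plan should show a Filter with
    // a scalar subquery over the aggregate instead of a Join node.
    q.explain(true)
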
 .../sql/catalyst/optimizer/Optimizer.scala    |   1 +
 .../sql/catalyst/optimizer/subquery.scala     |  84 +++++++++++-
 .../sql/catalyst/rules/RuleIdCollection.scala |   1 +
 .../optimizer/JoinEliminationSuite.scala      | 127 ++++++++++++++++++
 .../execution/BroadcastExchangeSuite.scala    |   4 +-
 .../spark/sql/streaming/StreamSuite.scala     |   3 +-
 .../sql/hive/thriftserver/CliSuite.scala      |   4 +-
 7 files changed, 218 insertions(+), 6 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f6a47fc93b022..40c781921d011 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -239,6 +239,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       CollapseProject,
       RemoveRedundantAliases,
       RemoveNoopOperators) :+
+    Batch("OptimizeOneRowJoin", Once, OptimizeOneRowJoin) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
     Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 7ef5ef55fabda..74484fc22cd81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY, JOIN, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -756,3 +757,84 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * 1. Rewrites a join to a filter if the maximum number of rows on one side is 1:
+ *
+ * {{{
+ * SELECT t1.* FROM t1 INNER JOIN (SELECT max(c1) AS c1 FROM t) t2 ON t1.c1 = t2.c1 ==>
+ * SELECT t1.* FROM t1 WHERE t1.c1 = (SELECT max(c1) AS c1 FROM t)
+ * }}}
+ *
+ * 2. Removes an outer join if the maximum number of rows on one side is 1 and only
+ *    columns from the other side are referenced:
+ * {{{
+ * SELECT t1.* FROM t1 LEFT JOIN (SELECT max(c1) AS c1 FROM t) t2 ON t1.c1 = t2.c1 ==>
+ * SELECT t1.* FROM t1
+ * }}}
+ *
+ * {{{
+ * SELECT t1.* FROM t1 FULL JOIN (SELECT max(c1) AS c1 FROM t) t2 ==>
+ * SELECT t1.* FROM t1
+ * }}}
+ *
+ * This rule should be executed before OptimizeSubqueries.
+ */
+object OptimizeOneRowJoin extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper {
+  private def eliminateRightSide(j: Join): Option[LogicalPlan] = {
+    j.joinType match {
+      case _: InnerLike | LeftSemi if j.condition.nonEmpty =>
+        ExtractEquiJoinKeys.unapply(j) match {
+          case Some((_, leftKeys, rightKeys, _, _, left, right, _)) =>
+            val conditions = leftKeys.zipWithIndex.map { case (exp, index) =>
+              val projectList = Seq(Alias(rightKeys(index), "_joinkey")())
+              EqualTo(exp, ScalarSubquery(Project(projectList, right)))
+            }
+            Some(Filter(conditions.reduceLeft(And), left))
+          case _ =>
+            None
+        }
+      case LeftOuter =>
+        Some(j.left)
+      case FullOuter if j.condition.isEmpty =>
+        Some(j.left)
+      case _ =>
+        None
+    }
+  }
+
+  private def eliminateLeftSide(j: Join): Option[LogicalPlan] = {
+    j.joinType match {
+      case _: InnerLike if j.condition.nonEmpty =>
+        ExtractEquiJoinKeys.unapply(j) match {
+          case Some((_, leftKeys, rightKeys, _, _, left, right, _)) =>
+            val conditions = rightKeys.zipWithIndex.map { case (exp, index) =>
+              val projectList = Seq(Alias(leftKeys(index), "_joinkey")())
+              EqualTo(exp, ScalarSubquery(Project(projectList, left)))
+            }
+            Some(Filter(conditions.reduceLeft(And), right))
+          case _ =>
+            None
+        }
+      case RightOuter =>
+        Some(j.right)
+      case FullOuter if j.condition.isEmpty =>
+        Some(j.right)
+      case _ =>
+        None
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(JOIN), ruleId) {
+    case p @ Project(_, j: Join) if j.right.maxRows.contains(1) &&
+      p.references.subsetOf(j.left.outputSet) =>
+      eliminateRightSide(j).map(c => p.copy(child = c)).getOrElse(p)
+
+    case p @ Project(_, j: Join) if j.left.maxRows.contains(1) &&
+      p.references.subsetOf(j.right.outputSet) =>
+      eliminateLeftSide(j).map(c => p.copy(child = c)).getOrElse(p)
+
+    case j: Join if j.right.maxRows.contains(1) && j.joinType == LeftSemi =>
+      eliminateRightSide(j).getOrElse(j)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 883712b549edb..e3a3ae4d50b81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -123,6 +123,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan" ::
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeWindowFunctions" ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala
new file mode 100644
index 0000000000000..091eca10d42c6
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.StringType
+
+class JoinEliminationSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubqueryAliases) ::
+      Batch("Outer Join Elimination", Once,
+        OptimizeOneRowJoin,
+        PushPredicateThroughJoin) :: Nil
+  }
+
+  private val testRelation = LocalRelation($"a".int, $"b".int, $"c".int)
+  private val testRelation1 = LocalRelation($"d".int, $"e".int, $"f".int)
+
+  test("right side max row number is 1 and join condition exists") {
+    Seq(Inner, Cross, LeftSemi, LeftOuter, LeftAnti, RightOuter, FullOuter).foreach { joinType =>
+      val y = testRelation1.groupBy()(max($"d").as("d"))
+      val originalQuery =
+        testRelation.as("x").join(y.subquery('y), joinType, Option("x.a".attr === "y.d".attr))
+          .select($"b")
+
+      val correctAnswer = joinType match {
+        case Inner | Cross | LeftSemi =>
+          testRelation.as("x")
+            .where($"a" === ScalarSubquery(Project(Seq(Alias($"d", "_joinkey")()), y)))
+            .select($"b")
+        case LeftOuter =>
+          testRelation.as("x").select($"b")
+        case _ =>
+          originalQuery
+      }
+
+      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+    }
+  }
+
+  test("right side max row number is 1 and join condition does not exist") {
+    Seq(Inner, Cross, LeftSemi, LeftOuter, LeftAnti, RightOuter, FullOuter).foreach { joinType =>
+      val y = testRelation1.groupBy()(max($"d").as("d"))
+      val originalQuery =
+        testRelation.as("x").join(y.subquery('y), joinType, None)
+          .select($"b")
+
+      val correctAnswer = joinType match {
+        case LeftOuter | FullOuter =>
+          testRelation.as("x").select($"b")
+        case _ =>
+          originalQuery
+      }
+
+      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+    }
+  }
+
+  test("right side max row number is 1 and join condition contains complex expressions") {
+    Seq(Inner, Cross, LeftSemi, LeftOuter, LeftAnti, RightOuter, FullOuter).foreach { joinType =>
+      val y = testRelation1.groupBy()(max($"d").as("d"))
+      val originalQuery =
+        testRelation.as("x").join(y.subquery('y), joinType,
+          Option("x.a".attr.cast(StringType) === "y.d".attr.cast(StringType)))
+          .select($"b")
+
+      val correctAnswer = joinType match {
+        case Inner | Cross | LeftSemi =>
+          testRelation.as("x")
+            .where($"a".cast(StringType) ===
+              ScalarSubquery(Project(Seq(Alias($"d".cast(StringType), "_joinkey")()), y)))
+            .select($"b")
+        case LeftOuter =>
+          testRelation.as("x").select($"b")
+        case _ =>
+          originalQuery
+      }
+
+      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+    }
+  }
+
+  test("left side max row number is 1 and join condition exists") {
+    Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).foreach { joinType =>
+      val x = testRelation.groupBy()(max($"a").as("a"))
+      val originalQuery =
+        x.as("x").join(testRelation1.as("y"), joinType, Option("x.a".attr === "y.d".attr))
+          .select($"e")
+
+      val correctAnswer = joinType match {
+        case Inner | Cross =>
+          testRelation1.as("y")
+            .where($"d" === ScalarSubquery(Project(Seq(Alias($"a", "_joinkey")()), x)))
+            .select($"e")
+        case RightOuter =>
+          testRelation1.as("y").select($"e")
+        case _ =>
+          originalQuery
+      }
+
+      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
index 7d6306b65ff47..258eca935b0a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
@@ -85,12 +85,12 @@ class BroadcastExchangeSuite extends SparkPlanTest
 
   test("set broadcastTimeout to -1") {
     withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "-1") {
-      val df = spark.range(1).toDF()
+      val df = spark.range(2).toDF()
       val joinDF = df.join(broadcast(df), "id")
       val broadcastExchangeExec = collect(
         joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p }
       assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec")
-      assert(joinDF.collect().length == 1)
+      assert(joinDF.collect().length == 2)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index f2031b94231b7..1e471d3038363 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -1097,7 +1097,8 @@ class StreamSuite extends StreamTest {
     )
     require(execPlan != null)
 
-    val localLimits = execPlan.collect {
+    // Use collectWithSubqueries because OptimizeOneRowJoin may rewrite inner joins to filters
+    val localLimits = execPlan.collectWithSubqueries {
       case l: LocalLimitExec => l
       case l: StreamingLocalLimitExec => l
     }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index e1840d8622b54..a1be9ad8c49b8 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -594,8 +594,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
 
   test("SPARK-33100: test sql statements with hint in bracketed comment") {
     runCliWithin(2.minute)(
-      "CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES(1, 2) AS t1(k, v);" -> "",
-      "CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES(2, 1) AS t2(k, v);" -> "",
+      "CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES(1, 2), (3, 4) AS t1(k, v);" -> "",
+      "CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES(2, 1), (3, 4) AS t2(k, v);" -> "",
       "EXPLAIN SELECT /*+ MERGEJOIN(t1) */ t1.* FROM t1 JOIN t2 ON t1.k = t2.v;" -> "SortMergeJoin",
       "EXPLAIN SELECT /* + MERGEJOIN(t1) */ t1.* FROM t1 JOIN t2 ON t1.k = t2.v;" -> "BroadcastHashJoin"