KE-42353 [SQL] Optimize join if maximum number of rows on one side is 1 (apache#669)
7mming7 authored and frearb committed Sep 19, 2023
1 parent 06d6db4 commit 31c23f6
Showing 7 changed files with 218 additions and 6 deletions.
Optimizer.scala
@@ -239,6 +239,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
        CollapseProject,
        RemoveRedundantAliases,
        RemoveNoopOperators) :+
+      Batch("OptimizeOneRowJoin", Once, OptimizeOneRowJoin) :+
      // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
      Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
      Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
subquery.scala
@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY, JOIN, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -756,3 +757,84 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
    }
  }
}

/**
 * 1. Rewrites a join to a filter if the maximum number of rows on one side is 1:
 *
 * {{{
 *   SELECT t1.* FROM t1 INNER JOIN (SELECT max(c1) AS c1 FROM t) t2 ON t1.c1 = t2.c1 ==>
 *   SELECT t1.* FROM t1 WHERE t1.c1 = (SELECT max(c1) AS c1 FROM t)
 * }}}
 *
 * 2. Removes an outer join if the maximum number of rows on the side being dropped is 1
 * and only the other side's columns are referenced:
 * {{{
 *   SELECT t1.* FROM t1 LEFT JOIN (SELECT max(c1) AS c1 FROM t) t2 ON t1.c1 = t2.c1 ==>
 *   SELECT t1.* FROM t1
 * }}}
 *
 * {{{
 *   SELECT t1.* FROM t1 FULL JOIN (SELECT max(c1) AS c1 FROM t) t2 ==>
 *   SELECT t1.* FROM t1
 * }}}
 *
 * This rule should be executed before `OptimizeSubqueries`.
 */
object OptimizeOneRowJoin extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper {
  // If the right side produces at most one row, an equi-join can be turned into a
  // Filter over the left side: each left join key is compared to a scalar subquery
  // that projects the corresponding right-side key. A left outer join, or a full
  // outer join without a condition, can drop the right side entirely, because a
  // right side with at most one row never duplicates or removes left-side rows.
  private def eliminateRightSide(j: Join): Option[LogicalPlan] = {
    j.joinType match {
      case _: InnerLike | LeftSemi if j.condition.nonEmpty =>
        ExtractEquiJoinKeys.unapply(j) match {
          case Some((_, leftKeys, rightKeys, _, _, left, right, _)) =>
            val conditions = leftKeys.zipWithIndex.map { case (exp, index) =>
              val projectList = Seq(Alias(rightKeys(index), "_joinkey")())
              EqualTo(exp, ScalarSubquery(Project(projectList, right)))
            }
            Some(Filter(conditions.reduceLeft(And), left))
          case _ =>
            None
        }
      case LeftOuter =>
        Some(j.left)
      case FullOuter if j.condition.isEmpty =>
        Some(j.left)
      case _ =>
        None
    }
  }

  // Mirror image of eliminateRightSide: if the left side produces at most one row,
  // an inner join becomes a Filter over the right side, and a right outer join, or
  // a full outer join without a condition, can drop the left side.
  private def eliminateLeftSide(j: Join): Option[LogicalPlan] = {
    j.joinType match {
      case _: InnerLike if j.condition.nonEmpty =>
        ExtractEquiJoinKeys.unapply(j) match {
          case Some((_, leftKeys, rightKeys, _, _, left, right, _)) =>
            val conditions = rightKeys.zipWithIndex.map { case (exp, index) =>
              val projectList = Seq(Alias(leftKeys(index), "_joinkey")())
              EqualTo(exp, ScalarSubquery(Project(projectList, left)))
            }
            Some(Filter(conditions.reduceLeft(And), right))
          case _ =>
            None
        }
      case RightOuter =>
        Some(j.right)
      case FullOuter if j.condition.isEmpty =>
        Some(j.right)
      case _ =>
        None
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsPattern(JOIN), ruleId) {
    // The parent Project only references the left side and the right side has at
    // most one row, so the join can be rewritten or removed.
    case p @ Project(_, j: Join) if j.right.maxRows.contains(1) &&
        p.references.subsetOf(j.left.outputSet) =>
      eliminateRightSide(j).map(c => p.copy(child = c)).getOrElse(p)

    // Symmetric case: only the right side is referenced and the left side has at
    // most one row.
    case p @ Project(_, j: Join) if j.left.maxRows.contains(1) &&
        p.references.subsetOf(j.right.outputSet) =>
      eliminateLeftSide(j).map(c => p.copy(child = c)).getOrElse(p)

    // A left semi join never outputs right-side columns, so no Project guard is
    // needed before rewriting it.
    case j: Join if j.right.maxRows.contains(1) && j.joinType == LeftSemi =>
      eliminateRightSide(j).getOrElse(j)
  }
}
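For intuition, the following sketch (hypothetical demo code, not part of this commit; it assumes only a running SparkSession named spark) shows how the rewrite can be observed:

// A global max() aggregate reports maxRows = 1, so the inner join below is
// expected to plan as a Filter with a scalar subquery instead of a Join
// once OptimizeOneRowJoin has fired.
spark.range(0, 10).toDF("c1").createOrReplaceTempView("t1")
spark.range(0, 100).toDF("c1").createOrReplaceTempView("t")
val df = spark.sql(
  "SELECT t1.* FROM t1 INNER JOIN (SELECT max(c1) AS c1 FROM t) t2 ON t1.c1 = t2.c1")
df.explain(true) // the optimized plan should show a Filter with a scalar subquery over t1, and no Join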
RuleIdCollection.scala
@@ -123,6 +123,7 @@ object RuleIdCollection {
      "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
      "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
      "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan" ::
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowJoin" ::
      "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
      "org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
      "org.apache.spark.sql.catalyst.optimizer.OptimizeWindowFunctions" ::
JoinEliminationSuite.scala (new file)
@@ -0,0 +1,127 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.StringType

class JoinEliminationSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", Once,
        EliminateSubqueryAliases) ::
      Batch("Outer Join Elimination", Once,
        OptimizeOneRowJoin,
        PushPredicateThroughJoin) :: Nil
  }

  private val testRelation = LocalRelation($"a".int, $"b".int, $"c".int)
  private val testRelation1 = LocalRelation($"d".int, $"e".int, $"f".int)

test("right side max row number is 1 and join condition is exist") {
Seq(Inner, Cross, LeftSemi, LeftOuter, LeftAnti, RightOuter, FullOuter).foreach { joinType =>
val y = testRelation1.groupBy()(max($"d").as("d"))
val originalQuery =
testRelation.as("x").join(y.subquery('y), joinType, Option("x.a".attr === "y.d".attr))
.select($"b")

val correctAnswer = joinType match {
case Inner | Cross | LeftSemi =>
testRelation.as("x")
.where($"a" === ScalarSubquery(Project(Seq(Alias($"d", "_joinkey")()), y)))
.select($"b")
case LeftOuter =>
testRelation.as("x").select($"b")
case _ =>
originalQuery
}

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
}

test("right side max row number is 1 and join condition is not exist") {
Seq(Inner, Cross, LeftSemi, LeftOuter, LeftAnti, RightOuter, FullOuter).foreach { joinType =>
val y = testRelation1.groupBy()(max($"d").as("d"))
val originalQuery =
testRelation.as("x").join(y.subquery('y), joinType, None)
.select($"b")

val correctAnswer = joinType match {
case LeftOuter | FullOuter =>
testRelation.as("x").select($"b")
case _ =>
originalQuery
}

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
}

test("right side max row number is 1 and join condition is complex expressions") {
Seq(Inner, Cross, LeftSemi, LeftOuter, LeftAnti, RightOuter, FullOuter).foreach { joinType =>
val y = testRelation1.groupBy()(max($"d").as("d"))
val originalQuery =
testRelation.as("x").join(y.subquery('y), joinType,
Option("x.a".attr.cast(StringType) === "y.d".attr.cast(StringType)))
.select($"b")

val correctAnswer = joinType match {
case Inner | Cross | LeftSemi =>
testRelation.as("x")
.where($"a".cast(StringType) ===
ScalarSubquery(Project(Seq(Alias($"d".cast(StringType), "_joinkey")()), y)))
.select($"b")
case LeftOuter =>
testRelation.as("x").select($"b")
case _ =>
originalQuery
}

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
}

test("left side max row number is 1 and join condition is exist") {
Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).foreach { joinType =>
val x = testRelation.groupBy()(max($"a").as("a"))
val originalQuery =
x.as("x").join(testRelation1.as("y"), joinType, Option("x.a".attr === "y.d".attr))
.select($"e")

val correctAnswer = joinType match {
case Inner | Cross =>
testRelation1.as("y")
.where($"d" === ScalarSubquery(Project(Seq(Alias($"a", "_joinkey")()), x)))
.select($"e").analyze
case RightOuter =>
testRelation1.as("y").select($"e")
case _ =>
originalQuery
}

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
}
}
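Assuming the suite sits in the catalyst module alongside the other optimizer suites, it can be run on its own with something like: build/sbt "catalyst/testOnly *JoinEliminationSuite".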
BroadcastExchangeSuite.scala
@@ -85,12 +85,12 @@ class BroadcastExchangeSuite extends SparkPlanTest

  test("set broadcastTimeout to -1") {
    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "-1") {
-      val df = spark.range(1).toDF()
+      val df = spark.range(2).toDF()
      val joinDF = df.join(broadcast(df), "id")
      val broadcastExchangeExec = collect(
        joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p }
      assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec")
-      assert(joinDF.collect().length == 1)
+      assert(joinDF.collect().length == 2)
    }
  }
}
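The bump from range(1) to range(2) is what keeps this test meaningful: with a single row, both sides of the self-join report maxRows = 1, so the new OptimizeOneRowJoin rule could rewrite the join to a filter and the BroadcastExchangeExec being asserted on would vanish from the plan.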
StreamSuite.scala
@@ -1097,7 +1097,8 @@ class StreamSuite extends StreamTest {
    )
    require(execPlan != null)

-    val localLimits = execPlan.collect {
+    // Use collectWithSubqueries because OptimizeOneRowJoin may rewrite inner joins to filters
+    val localLimits = execPlan.collectWithSubqueries {
      case l: LocalLimitExec => l
      case l: StreamingLocalLimitExec => l
    }
CliSuite.scala
@@ -594,8 +594,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {

  test("SPARK-33100: test sql statements with hint in bracketed comment") {
    runCliWithin(2.minute)(
-      "CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES(1, 2) AS t1(k, v);" -> "",
-      "CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES(2, 1) AS t2(k, v);" -> "",
+      "CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES(1, 2), (3, 4) AS t1(k, v);" -> "",
+      "CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES(2, 1), (3, 4) AS t2(k, v);" -> "",
      "EXPLAIN SELECT /*+ MERGEJOIN(t1) */ t1.* FROM t1 JOIN t2 ON t1.k = t2.v;" -> "SortMergeJoin",
      "EXPLAIN SELECT /* + MERGEJOIN(t1) */ t1.* FROM t1 JOIN t2 ON t1.k = t2.v;"
        -> "BroadcastHashJoin"
