From 539782d93af14ed27c6f3a0fc659b13c4f92da41 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 11 Jul 2016 01:04:51 -0700 Subject: [PATCH 1/6] [SPARK-16475][SQL] Broadcast Hint for SQL Queries --- .../spark/sql/catalyst/parser/SqlBase.g4 | 18 ++- .../sql/catalyst/analysis/Analyzer.scala | 60 ++++++++- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 + .../sql/catalyst/parser/AstBuilder.scala | 15 ++- .../plans/logical/basicLogicalOperators.scala | 9 ++ .../sql/catalyst/analysis/AnalysisTest.scala | 1 + .../analysis/SubstituteHintsSuite.scala | 125 ++++++++++++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 42 ++++++ .../spark/sql/catalyst/SQLBuilder.scala | 79 +++++++++-- .../execution/joins/BroadcastJoinSuite.scala | 96 ++++++++++++++ .../sqlgen/broadcast_hint_generator.sql | 4 + .../broadcast_hint_groupby_having_orderby.sql | 9 ++ .../sqlgen/broadcast_hint_groupingset.sql | 8 ++ .../broadcast_hint_multiple_table_1.sql | 4 + .../broadcast_hint_multiple_table_2.sql | 4 + .../sqlgen/broadcast_hint_rollup.sql | 7 + .../sqlgen/broadcast_hint_single_table_1.sql | 4 + .../sqlgen/broadcast_hint_single_table_2.sql | 4 + .../sqlgen/broadcast_hint_single_table_3.sql | 4 + .../sqlgen/broadcast_hint_window.sql | 6 + .../sqlgen/broadcast_hint_with_filter.sql | 4 + .../broadcast_hint_with_filter_limit.sql | 4 + .../sqlgen/broadcast_join_subquery.sql | 2 +- .../sqlgen/multiple_broadcast_hints.sql | 4 + .../sql/catalyst/LogicalPlanToSQLSuite.scala | 91 +++++++++++++ .../spark/sql/hive/BroadcastHintSuite.scala | 55 ++++++++ 26 files changed, 649 insertions(+), 14 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_generator.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_groupby_having_orderby.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_groupingset.sql create mode 100644 
sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_1.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_2.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_rollup.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_1.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_2.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_3.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_window.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter.sql create mode 100644 sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter_limit.sql create mode 100644 sql/hive/src/test/resources/sqlgen/multiple_broadcast_hints.sql create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index b599a884957a8..4e3dda4b7dd9e 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -364,7 +364,7 @@ querySpecification (RECORDREADER recordReader=STRING)? fromClause? (WHERE where=booleanExpression)?) - | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | ((kind=SELECT hint? setQuantifier? namedExpressionSeq fromClause? | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) lateralView* (WHERE where=booleanExpression)? @@ -373,6 +373,16 @@ querySpecification windows?) 
; +hint + : '/*+' hintStatement '*/' + ; + +hintStatement + : hintName=identifier + | hintName=identifier '(' parameters+=identifier parameters+=identifier ')' + | hintName=identifier '(' parameters+=identifier (',' parameters+=identifier)* ')' + ; + fromClause : FROM relation (',' relation)* lateralView* ; @@ -996,8 +1006,12 @@ SIMPLE_COMMENT : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) ; +BRACKETED_EMPTY_COMMENT + : '/**/' -> channel(HIDDEN) + ; + BRACKETED_COMMENT - : '/*' .*? '*/' -> channel(HIDDEN) + : '/*' ~[+] .*? '*/' -> channel(HIDDEN) ; WS diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5011f2fdbf9b7..b63bff942fd84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -78,7 +78,8 @@ class Analyzer( CTESubstitution, WindowsSubstitution, EliminateUnions, - new SubstituteUnresolvedOrdinals(conf)), + new SubstituteUnresolvedOrdinals(conf), + SubstituteHints), Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: ResolveRelations :: @@ -1795,6 +1796,63 @@ class Analyzer( } } + /** + * Substitute Hints. + * - BROADCAST/BROADCASTJOIN/MAPJOIN match the closest table with the given name parameters. + * + * This rule substitutes `UnresolvedRelation`s in `Substitute` batch before `ResolveRelations` + * rule is applied. Here are two reasons. + * - To support `MetastoreRelation` in Hive module. + * - To reduce the effect of `Hint` on the other rules. + * + * After this rule, it is guaranteed that there exists no unknown `Hint` in the plan. + * All new `Hint`s should be transformed into concrete Hint classes `BroadcastHint` here. 
+ */ + object SubstituteHints extends Rule[LogicalPlan] { + val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") + + import scala.collection.mutable.Set + private def appendAllDescendant(set: Set[LogicalPlan], plan: LogicalPlan): Unit = { + set += plan + plan.children.foreach { child => appendAllDescendant(set, child) } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformDown { + case h @ Hint(name, parameters, child) if BROADCAST_HINT_NAMES.contains(name.toUpperCase) => + var resolvedChild = child + for (table <- parameters) { + var stop = false + val skipNodeSet = scala.collection.mutable.Set.empty[LogicalPlan] + resolvedChild = resolvedChild.transformDown { + case n if skipNodeSet.contains(n) => + skipNodeSet -= n + n + case p @ Project(_, _) if p != resolvedChild => + appendAllDescendant(skipNodeSet, p) + skipNodeSet -= p + p + case r @ BroadcastHint(UnresolvedRelation(t, _)) + if !stop && resolver(t.table, table) => + stop = true + r + case r @ UnresolvedRelation(t, alias) if !stop && resolver(t.table, table) => + stop = true + if (alias.isDefined) { + SubqueryAlias(alias.get, BroadcastHint(r.copy(alias = None)), None) + } else { + BroadcastHint(r) + } + } + } + resolvedChild + + // Remove unrecognized hints + case Hint(name, _, child) => child + } + } + } + /** * Check and add proper window frames for all window functions. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 3455a567b7786..622a4d9857981 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -404,6 +404,10 @@ trait CheckAnalysis extends PredicateHelper { |in operator ${operator.simpleString} """.stripMargin) + case Hint(_, _, _) => + throw new IllegalStateException( + "logical hint operator should have been removed by analyzer") + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4b151c81d8f8b..516a93be49002 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -387,7 +387,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } // Window - withDistinct.optionalMap(windows)(withWindows) + val withWindow = withDistinct.optionalMap(windows)(withWindows) + + // Hint + withWindow.optionalMap(ctx.hint)(withHints) } } @@ -527,6 +530,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } + /** + * Add a Hint to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val stmt = ctx.hintStatement + Hint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 65ceab2ce27b1..a31ab9576df3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -346,6 +346,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) } +/** + * A general hint for the child. + * A pair of (name, parameters). + */ +case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode { + override lazy val resolved: Boolean = false + override def output: Seq[Attribute] = child.output +} + /** * Options for writing new data into a table. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3acb261800c0e..0f059b9591460 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -32,6 +32,7 @@ trait AnalysisTest extends PlanTest { val conf = new SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) + catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala new file mode 100644 index 0000000000000..64e85111c43df --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class SubstituteHintsSuite extends AnalysisTest { + import org.apache.spark.sql.catalyst.analysis.TestRelations._ + + val a = testRelation.output(0) + val b = testRelation2.output(0) + + test("case-sensitive or insensitive parameters") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + BroadcastHint(testRelation), + caseSensitive = false) + + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("TaBlE")), + BroadcastHint(testRelation), + caseSensitive = false) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + BroadcastHint(testRelation)) + + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("TaBlE")), + testRelation) + } + + test("single hint") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").select(a)), + BroadcastHint(testRelation).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).join(testRelation2).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE2"), + table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + testRelation.join(BroadcastHint(testRelation2)).select(a)) + } + + test("single hint with multiple parameters") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE", "TaBlE"), + table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).join(testRelation2).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE", "TaBlE2"), + table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).join(BroadcastHint(testRelation2)).select(a)) + } + + test("duplicated nested hints are transformed into one") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + Hint("MAPJOIN", Seq("TaBlE"), 
table("TaBlE").as("t").select('a)) + .join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).select(a).join(testRelation2).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE2"), + table("TaBlE").as("t").select(a) + .join(Hint("MAPJOIN", Seq("TaBlE2"), table("TaBlE2").as("u").select(b))).select(a)), + testRelation.select(a).join(BroadcastHint(testRelation2).select(b)).select(a)) + } + + test("distinct nested two hints are handled separately") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE2"), + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").select(a)) + .join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).select(a).join(BroadcastHint(testRelation2)).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").as("t") + .join(Hint("MAPJOIN", Seq("TaBlE2"), table("TaBlE2").as("u").select(b))).select(a)), + BroadcastHint(testRelation).join(BroadcastHint(testRelation2).select(b)).select(a)) + } + + test("deep self join") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").join(table("TaBlE")).join(table("TaBlE")).join(table("TaBlE")).select(a)), + BroadcastHint(testRelation).join(testRelation).join(testRelation).join(testRelation) + .select(a)) + } + + test("subquery should be ignored") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").select(a).as("x").join(table("TaBlE")).select(a)), + testRelation.select(a).join(BroadcastHint(testRelation)).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").as("t").select(a).as("x") + .join(table("TaBlE2").as("t2")).select(a)), + testRelation.select(a).join(testRelation2).select(a)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 7400f3430e99c..134f792ff108d 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -503,4 +503,46 @@ class PlanParserSuite extends PlanTest { assertEqual("select a, b from db.c where x !> 1", table("db", "c").where('x <= 1).select('a, 'b)) } + + test("select hint syntax") { + // Hive compatibility: Missing parameter raises ParseException. + val m = intercept[ParseException] { + parsePlan("SELECT /*+ HINT() */ * FROM t") + }.getMessage + assert(m.contains("no viable alternative at input")) + + // Hive compatibility: No database. + val m2 = intercept[ParseException] { + parsePlan("SELECT /*+ MAPJOIN(default.t) */ * from default.t") + }.getMessage + assert(m2.contains("no viable alternative at input")) + + comparePlans( + parsePlan("SELECT /*+ HINT */ * FROM t"), + Hint("HINT", Seq.empty, table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), + Hint("BROADCASTJOIN", Seq("u"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), + Hint("MAPJOIN", Seq("u"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), + Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ INDEX(t emp_job_ix) */ * FROM t"), + Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), + Hint("MAPJOIN", Seq("default.t"), table("default.t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), + Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala 
index 6f821f80cc4c5..f12214a206812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala @@ -172,6 +172,10 @@ class SQLBuilder private ( toSQL(p.right), p.condition.map(" ON " + _.sql).getOrElse("")) + // Hint on aliased table should be matched directly. Otherwise, this Hint will be propagated up. + case h @ Hint(_, _, s @ SubqueryAlias(alias, p @ Project(_, _: SQLTable), _)) => + build("(" + toSQL(p.copy(child = h.copy(child = p.child))) + ")", "AS", s.alias) + case SQLTable(database, table, _, sample) => val qualifiedName = s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" sample.map { case (lowerBound, upperBound) => @@ -214,6 +218,9 @@ class SQLBuilder private ( case OneRowRelation => "" + case Hint(_, _, child) => + toSQL(child) + case _ => throw new UnsupportedOperationException(s"unsupported plan $node") } @@ -226,14 +233,24 @@ class SQLBuilder private ( private def build(segments: String*): String = segments.map(_.trim).filter(_.nonEmpty).mkString(" ") - private def projectToSQL(plan: Project, isDistinct: Boolean): String = { - build( - "SELECT", - if (isDistinct) "DISTINCT" else "", - plan.projectList.map(_.sql).mkString(", "), - if (plan.child == OneRowRelation) "" else "FROM", - toSQL(plan.child) - ) + private def projectToSQL(plan: Project, isDistinct: Boolean): String = plan match { + case p @ Project(projectList, Hint("BROADCAST", tables, child)) => + build( + "SELECT", + if (tables.nonEmpty) s"/*+ MAPJOIN(${tables.mkString(", ")}) */" else "", + if (isDistinct) "DISTINCT" else "", + plan.projectList.map(_.sql).mkString(", "), + if (child == OneRowRelation) "" else "FROM", + toSQL(child) + ) + case _ => + build( + "SELECT", + if (isDistinct) "DISTINCT" else "", + plan.projectList.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child) + ) } private def scriptTransformationToSQL(plan: 
ScriptTransformation): String = { @@ -431,7 +448,9 @@ class SQLBuilder private ( // Insert sub queries on top of operators that need to appear after FROM clause. AddSubquery, // Reconstruct subquery expressions. - ConstructSubqueryExpressions + ConstructSubqueryExpressions, + // Normalize BroadcastHints to reconstruct hint comments. + NormalizeBroadcastHint ) ) @@ -444,6 +463,46 @@ class SQLBuilder private ( } } + /** + * Merge and move upward to the nearest Project. + * A broadcast hint comment is scattered into multiple nodes inside the plan, and the + * information of BroadcastHint resides at its current position inside the plan. In order to + * reconstruct the broadcast hint comment, we need to pack the information of BroadcastHint into + * Hint("BROADCAST", _, _) and collect them up by moving upward to the nearest Project node. + */ + object NormalizeBroadcastHint extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + // Capture the broadcasted information and store it in Hint. + case BroadcastHint(child @ SubqueryAlias(_, Project(_, SQLTable(_, table, _, _)), _)) => + Hint("BROADCAST", Seq(table), child) + + // Nearest Project is found. + case p @ Project(_, Hint(_, _, _)) => p + + // Merge BROADCAST hints up to the nearest Project. + case Hint("BROADCAST", params1, h @ Hint("BROADCAST", params2, _)) => + h.copy(parameters = params1 ++ params2) + case j @ Join(h1 @ Hint("BROADCAST", p1, left), h2 @ Hint("BROADCAST", p2, right), _, _) => + h1.copy(parameters = p1 ++ p2, child = j.copy(left = left, right = right)) + + // Bubble up BROADCAST hints to the nearest Project. + case j @ Join(h @ Hint("BROADCAST", _, hintChild), _, _, _) => + h.copy(child = j.copy(left = hintChild)) + case j @ Join(_, h @ Hint("BROADCAST", _, hintChild), _, _) => + h.copy(child = j.copy(right = hintChild)) + + // Other UnaryNodes are bypassed.
+ case u: UnaryNode + if u.child.isInstanceOf[Hint] && u.child.asInstanceOf[Hint].name.equals("BROADCAST") => + val hint = u.child.asInstanceOf[Hint] + hint.copy(child = u.withNewChildren(Seq(hint.child))) + + // Other binary(CoGroup/Intersect/Except) and Union are ignored. + // - CoGroup is not used in SQL. + // - Intersect/Except/Union have Project nodes inside. + } + } + object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case SubqueryAlias(_, t @ ExtractSQLTable(_), _) => t @@ -575,6 +634,8 @@ class SQLBuilder private ( case _: SQLTable => plan case _: Generate => plan case OneRowRelation => plan + case _: BroadcastHint => plan + case _: Hint => plan case _ => addSubquery(plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 83db81ea3f1c2..ce4bcd66e2440 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,6 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ @@ -156,6 +158,100 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("Broadcast Hint") { + import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} + + spark.range(10).createOrReplaceTempView("t") + 
spark.range(10).createOrReplaceTempView("u") + + for (name <- Seq("BROADCAST", "BROADCASTJOIN", "MAPJOIN")) { + val plan1 = sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + val plan2 = sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + + assert(plan1.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(!plan1.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(!plan2.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(plan2.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(!plan3.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(!plan3.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + } + } + + test("Broadcast Hint matches the nearest one") { + val tbl_a = spark.range(10) + val tbl_b = spark.range(20) + val tbl_c = spark.range(30) + + tbl_a.createOrReplaceTempView("tbl_a") + tbl_b.createOrReplaceTempView("tbl_b") + tbl_c.createOrReplaceTempView("tbl_c") + + val plan = sql( + """SELECT /*+ MAPJOIN(tbl_b) */ + | * + |FROM tbl_a A + | JOIN tbl_b B + | ON B.id = A.id + | JOIN (SELECT XA.id + | FROM tbl_b XA + | LEFT SEMI JOIN tbl_c XB + | ON XB.id = XA.id) C + | ON C.id = A.id + """.stripMargin).queryExecution.analyzed + + val correct_answer = + SubqueryAlias("A", tbl_a.logicalPlan, Some(TableIdentifier("tbl_a"))) + .join(SubqueryAlias("B", broadcast(SubqueryAlias("tbl_b", tbl_b.logicalPlan, + Some(TableIdentifier("tbl_b")))).logicalPlan, None), $"B.id" === $"A.id", "inner") + .join(SubqueryAlias("XA", tbl_b.logicalPlan, Some(TableIdentifier("tbl_b"))) + .join(SubqueryAlias("XB", tbl_c.logicalPlan, Some(TableIdentifier("tbl_c"))), + $"XB.id" === $"XA.id", "leftsemi") + .select("XA.id").as("C"), $"C.id" === $"A.id", "inner") + .select(col("*")).logicalPlan + + comparePlans(plan, correct_answer) + } + + 
test("Nested Broadcast Hint") { + val tbl_a = spark.range(10) + val tbl_b = spark.range(20) + val tbl_c = spark.range(30) + + tbl_a.createOrReplaceTempView("tbl_a") + tbl_b.createOrReplaceTempView("tbl_b") + tbl_c.createOrReplaceTempView("tbl_c") + + val plan = sql( + """SELECT /*+ MAPJOIN(tbl_a, tbl_a) */ + | * + |FROM tbl_a A + | JOIN tbl_b B + | ON B.id = A.id + | JOIN (SELECT /*+ MAPJOIN(tbl_c) */ + | XA.id + | FROM tbl_b XA + | LEFT SEMI JOIN tbl_c XB + | ON XB.id = XA.id) C + | ON C.id = A.id + """.stripMargin).queryExecution.analyzed + + val correct_answer = + broadcast(SubqueryAlias("tbl_a", tbl_a.logicalPlan, Some(TableIdentifier("tbl_a")))).as("A") + .join(SubqueryAlias("B", tbl_b.logicalPlan, Some(TableIdentifier("tbl_b"))), + $"B.id" === $"A.id", "inner") + .join(SubqueryAlias("XA", tbl_b.logicalPlan, Some(TableIdentifier("tbl_b"))) + .join(broadcast(SubqueryAlias("tbl_c", tbl_c.logicalPlan, Some(TableIdentifier("tbl_c")))) + .as("XB"), $"XB.id" === $"XA.id", "leftsemi") + .select("XA.id").as("C"), $"C.id" === $"A.id", "inner") + .select(col("*")).logicalPlan + + comparePlans(plan, correct_answer) + } + test("join key rewritten") { val l = Literal(1L) val i = Literal(2) diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_generator.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_generator.sql new file mode 100644 index 0000000000000..dbf8ff55dae92 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_generator.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT * FROM (SELECT /*+ MAPJOIN(parquet_t0) */ EXPLODE(ARRAY(1,2,3)) FROM parquet_t0) T +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `col` FROM (SELECT `gen_attr_0` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0 LATERAL VIEW explode(array(1, 2, 3)) gen_subquery_1 AS `gen_attr_0`) AS T) AS T diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupby_having_orderby.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupby_having_orderby.sql new file mode 100644 index 0000000000000..8cbd399f9ee6d --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupby_having_orderby.sql @@ -0,0 +1,9 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * +FROM parquet_t0 +WHERE id > 0 +GROUP BY id +HAVING count(*) > 0 +ORDER BY id +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 WHERE (`gen_attr_0` > CAST(0 AS BIGINT)) GROUP BY `gen_attr_0` HAVING (`gen_attr_1` > CAST(0 AS BIGINT))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC NULLS FIRST) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupingset.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupingset.sql new file mode 100644 index 0000000000000..9d670dd2b1691 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupingset.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(parquet_t1) */ + count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 +FROM parquet_t1 +GROUP BY key % 5, key - 5 +GROUPING SETS (key % 5, key - 5) +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `cnt`, `gen_attr_4` AS `k1`, `gen_attr_5` AS `k2`, `gen_attr_6` AS `k3` FROM (SELECT count(1) AS `gen_attr_3`, (`gen_attr_7` % CAST(5 AS BIGINT)) AS `gen_attr_4`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_5`, grouping_id() AS `gen_attr_6` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `key` AS `gen_attr_7`, `value` AS `gen_attr_8` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_7` % CAST(5 AS BIGINT))), ((`gen_attr_7` - CAST(5 AS BIGINT))))) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_1.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_1.sql new file mode 100644 index 0000000000000..889fef83522d2 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0, parquet_t1 WHERE id < key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `key`, `gen_attr_2` AS `value` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_1 WHERE (`gen_attr_0` < `gen_attr_1`)) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_2.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_2.sql new file mode 100644 index 0000000000000..15911769678cf --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t1) */ * FROM parquet_t0, parquet_t1 WHERE id < key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `key`, `gen_attr_2` AS `value` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_1 WHERE (`gen_attr_0` < `gen_attr_1`)) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_rollup.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_rollup.sql new file mode 100644 index 0000000000000..f40fcb6731b11 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_rollup.sql @@ -0,0 +1,7 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(parquet_t1) */ + count(*) as cnt, key%5, grouping_id() +FROM parquet_t1 +GROUP BY key % 5 WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_2` AS `cnt`, `gen_attr_3` AS `(key % CAST(5 AS BIGINT))`, `gen_attr_4` AS `grouping_id()` FROM (SELECT count(1) AS `gen_attr_2`, (`gen_attr_5` % CAST(5 AS BIGINT)) AS `gen_attr_3`, grouping_id() AS `gen_attr_4` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `key` AS `gen_attr_5`, `value` AS `gen_attr_6` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_5` % CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_5` % CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_1.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_1.sql new file mode 100644 index 0000000000000..6a4c16470d1f5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_2.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_2.sql new file mode 100644 index 0000000000000..8ef91e82b518e --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(parquet_t0, parquet_t0) */ * FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_3.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_3.sql new file mode 100644 index 0000000000000..9cb48ff62f7d0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_3.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 as a +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM ((SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS a) AS a diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_window.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_window.sql new file mode 100644 index 0000000000000..640045110bda8 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_window.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(parquet_t1) */ + x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) +FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `max(key) OVER (PARTITION BY (key % CAST(5 AS BIGINT)) ORDER BY key ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_2.`gen_attr_0`, gen_subquery_2.`gen_attr_2`, gen_subquery_2.`gen_attr_3`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_1` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `gen_attr_0`, `gen_attr_2`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM ((SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS x INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_0` = `gen_attr_2`)) AS gen_subquery_2) AS gen_subquery_3) AS x diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter.sql new file mode 100644 index 0000000000000..a413f25b27d0a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 WHERE (`gen_attr_0` < CAST(10 AS BIGINT))) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter_limit.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter_limit.sql new file mode 100644 index 0000000000000..2671b60c60e7f --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter_limit.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10 LIMIT 10 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 WHERE (`gen_attr_0` < CAST(10 AS BIGINT)) LIMIT 10) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql index 3de4f8a059965..743b48f8e142f 100644 --- a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql +++ b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql @@ -5,4 +5,4 @@ FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2 JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11) ORDER BY subq.key1, z.value -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS 
`gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = '2008-04-08')) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_3 +SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT /*+ MAPJOIN(srcpart) */ `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN ((SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2) AS z ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = '2008-04-08')) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC NULLS FIRST, `gen_attr_1` ASC NULLS FIRST) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/multiple_broadcast_hints.sql b/sql/hive/src/test/resources/sqlgen/multiple_broadcast_hints.sql new file mode 100644 index 0000000000000..537bb4ffd9689 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/multiple_broadcast_hints.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(parquet_t0, parquet_t1) */ * FROM parquet_t0, parquet_t1 WHERE id < key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `key`, `gen_attr_2` AS `value` FROM (SELECT /*+ MAPJOIN(parquet_t0, parquet_t1) */ `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_1 WHERE (`gen_attr_0` < `gen_attr_1`)) AS gen_subquery_2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 12d18dc87ceb4..722028605a5b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -1150,6 +1150,97 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { "inline_tables") } + test("broadcast hint on single table") { + checkSQL("SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0", + "broadcast_hint_single_table_1") + + checkSQL("SELECT /*+ MAPJOIN(parquet_t0, parquet_t0) */ * FROM parquet_t0", + "broadcast_hint_single_table_2") + + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 as a", + "broadcast_hint_single_table_3") + } + + test("broadcast hint on multiple tables") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0, parquet_t1 WHERE id < key", + "broadcast_hint_multiple_table_1") + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t1) */ * FROM parquet_t0, parquet_t1 WHERE id < key", + "broadcast_hint_multiple_table_2") + } + + test("multiple broadcast hints on multiple tables") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0, parquet_t1) */ * FROM parquet_t0, parquet_t1 WHERE id < key", + "multiple_broadcast_hints") + } + + 
test("broadcast hint with filter") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10", + "broadcast_hint_with_filter") + } + + test("broadcast hint with filter/limit") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10 LIMIT 10", + "broadcast_hint_with_filter_limit") + } + + test("broadcast hint with generator") { + checkSQL( + "SELECT * FROM (SELECT /*+ MAPJOIN(parquet_t0) */ EXPLODE(ARRAY(1,2,3)) FROM parquet_t0) T", + "broadcast_hint_generator") + } + + test("broadcast hint with groupby/having/orderby") { + checkSQL( + """ + |SELECT /*+ MAPJOIN(parquet_t0) */ * + |FROM parquet_t0 + |WHERE id > 0 + |GROUP BY id + |HAVING count(*) > 0 + |ORDER BY id + """.stripMargin, + "broadcast_hint_groupby_having_orderby") + } + + test("broadcast hint with window") { + checkSQL( + """ + |SELECT /*+ MAPJOIN(parquet_t1) */ + | x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key + """.stripMargin, + "broadcast_hint_window") + } + + test("broadcast hint with rollup") { + checkSQL( + """ + |SELECT /*+ MAPJOIN(parquet_t1) */ + | count(*) as cnt, key%5, grouping_id() + |FROM parquet_t1 + |GROUP BY key % 5 WITH ROLLUP + """.stripMargin, + "broadcast_hint_rollup") + } + + test("broadcast hint with grouping sets") { + checkSQL( + s""" + |SELECT /*+ MAPJOIN(parquet_t1) */ + | count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 + |FROM parquet_t1 + |GROUP BY key % 5, key - 5 + |GROUPING SETS (key % 5, key - 5) + """.stripMargin, + "broadcast_hint_groupingset") + } + test("SPARK-17750 - interval arithmetic") { withTable("dates") { sql("create table dates (ts timestamp)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala new file mode 100644 index 0000000000000..928064a95fec3 --- /dev/null +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class BroadcastHintSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + test("broadcast hint on Hive table") { + withTable("hive_t", "hive_u") { + spark.sql("CREATE TABLE hive_t(a int)") + spark.sql("CREATE TABLE hive_u(b int)") + + val hive_t = spark.table("hive_t").queryExecution.analyzed + val hive_u = spark.table("hive_u").queryExecution.analyzed + + val plan = spark.sql("SELECT /*+ MAPJOIN(hive_t) */ * FROM hive_t, hive_u") + .queryExecution.analyzed + + assert(plan.collectFirst { + case BroadcastHint(MetastoreRelation(_, "hive_t")) => true + }.isDefined) + assert(plan.collectFirst { + case Join(_, MetastoreRelation(_, "hive_u"), _, _) => true + }.isDefined) + + val plan2 = spark.sql("SELECT /*+ MAPJOIN(hive_u) */ a FROM hive_t, hive_u") + .queryExecution.analyzed + + assert(plan2.collectFirst { + case 
BroadcastHint(MetastoreRelation(_, "hive_u")) => true + }.isDefined) + assert(plan2.collectFirst { + case Join(MetastoreRelation(_, "hive_t"), _, _, _) => true + }.isDefined) + } + } +} From c702e3ee8ee41ede216eab1209fe1e9374a2c301 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Feb 2017 13:28:35 +0100 Subject: [PATCH 2/6] Get rid of the merge leftover. --- .../plans/logical/basicLogicalOperators.scala | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 5fc11f1c4f29c..4d696c0a3fba6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -371,20 +371,6 @@ case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) exten override def output: Seq[Attribute] = child.output } -/** - * Options for writing new data into a table. - * - * @param enabled whether to overwrite existing data in the table. - * @param specificPartition only data in the specified partition will be overwritten. - */ -case class OverwriteOptions( - enabled: Boolean, - specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { - if (specificPartition.isDefined) { - assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") - } -} - /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. From a095df370aedbfd9a14e7761a9da18365cfd2bbf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Feb 2017 14:01:19 +0100 Subject: [PATCH 3/6] Move rule out to its own file. 
--- .../sql/catalyst/analysis/Analyzer.scala | 59 +------------ .../catalyst/analysis/SubstituteHints.scala | 83 +++++++++++++++++++ 2 files changed, 84 insertions(+), 58 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 90bea5684eb44..d37f16ca470d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -119,7 +119,7 @@ class Analyzer( WindowsSubstitution, EliminateUnions, new SubstituteUnresolvedOrdinals(conf), - SubstituteHints), + new SubstituteHints(conf)), Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: ResolveRelations :: @@ -2088,63 +2088,6 @@ class Analyzer( } } - /** - * Substitute Hints. - * - BROADCAST/BROADCASTJOIN/MAPJOIN match the closest table with the given name parameters. - * - * This rule substitutes `UnresolvedRelation`s in `Substitute` batch before `ResolveRelations` - * rule is applied. Here are two reasons. - * - To support `MetastoreRelation` in Hive module. - * - To reduce the effect of `Hint` on the other rules. - * - * After this rule, it is guaranteed that there exists no unknown `Hint` in the plan. - * All new `Hint`s should be transformed into concrete Hint classes `BroadcastHint` here. 
- */ - object SubstituteHints extends Rule[LogicalPlan] { - val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") - - import scala.collection.mutable.Set - private def appendAllDescendant(set: Set[LogicalPlan], plan: LogicalPlan): Unit = { - set += plan - plan.children.foreach { child => appendAllDescendant(set, child) } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case logical: LogicalPlan => logical transformDown { - case h @ Hint(name, parameters, child) if BROADCAST_HINT_NAMES.contains(name.toUpperCase) => - var resolvedChild = child - for (table <- parameters) { - var stop = false - val skipNodeSet = scala.collection.mutable.Set.empty[LogicalPlan] - resolvedChild = resolvedChild.transformDown { - case n if skipNodeSet.contains(n) => - skipNodeSet -= n - n - case p @ Project(_, _) if p != resolvedChild => - appendAllDescendant(skipNodeSet, p) - skipNodeSet -= p - p - case r @ BroadcastHint(UnresolvedRelation(t, _)) - if !stop && resolver(t.table, table) => - stop = true - r - case r @ UnresolvedRelation(t, alias) if !stop && resolver(t.table, table) => - stop = true - if (alias.isDefined) { - SubqueryAlias(alias.get, BroadcastHint(r.copy(alias = None)), None) - } else { - BroadcastHint(r) - } - } - } - resolvedChild - - // Remove unrecognized hints - case Hint(name, _, child) => child - } - } - } - /** * Check and add proper window frames for all window functions. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala new file mode 100644 index 0000000000000..2fe5fe17c3430 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.collection.{immutable, mutable} + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * Substitute Hints. + * - BROADCAST/BROADCASTJOIN/MAPJOIN match the closest table with the given name parameters. + * + * This rule substitutes `UnresolvedRelation`s in `Substitute` batch before `ResolveRelations` + * rule is applied. Here are two reasons. + * - To support `MetastoreRelation` in Hive module. + * - To reduce the effect of `Hint` on the other rules. + * + * After this rule, it is guaranteed that there exists no unknown `Hint` in the plan. + * All new `Hint`s should be transformed into concrete Hint classes `BroadcastHint` here. 
+ */ +class SubstituteHints(conf: CatalystConf) extends Rule[LogicalPlan] { + private val BROADCAST_HINT_NAMES = immutable.Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") + + def resolver: Resolver = conf.resolver + + private def appendAllDescendant(set: mutable.Set[LogicalPlan], plan: LogicalPlan): Unit = { + set += plan + plan.children.foreach { child => appendAllDescendant(set, child) } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformDown { + case h @ Hint(name, parameters, child) if BROADCAST_HINT_NAMES.contains(name.toUpperCase) => + var resolvedChild = child + for (table <- parameters) { + var stop = false + val skipNodeSet = scala.collection.mutable.Set.empty[LogicalPlan] + resolvedChild = resolvedChild.transformDown { + case n if skipNodeSet.contains(n) => + skipNodeSet -= n + n + case p @ Project(_, _) if p != resolvedChild => + appendAllDescendant(skipNodeSet, p) + skipNodeSet -= p + p + case r @ BroadcastHint(UnresolvedRelation(t, _)) + if !stop && resolver(t.table, table) => + stop = true + r + case r @ UnresolvedRelation(t, alias) if !stop && resolver(t.table, table) => + stop = true + if (alias.isDefined) { + SubqueryAlias(alias.get, BroadcastHint(r.copy(alias = None)), None) + } else { + BroadcastHint(r) + } + } + } + resolvedChild + + // Remove unrecognized hints + case Hint(name, _, child) => child + } + } +} From 617f8a2b9d6069472729609c262ff4e4e2dde5f7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Feb 2017 17:59:08 +0100 Subject: [PATCH 4/6] Rewrote the PR. 
--- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../catalyst/analysis/SubstituteHints.scala | 76 ++++++------ .../analysis/SubstituteHintsSuite.scala | 109 ++++++++---------- .../execution/joins/BroadcastJoinSuite.scala | 75 +----------- .../spark/sql/hive/BroadcastHintSuite.scala | 55 --------- 6 files changed, 95 insertions(+), 228 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d37f16ca470d3..300d60596eb57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -115,11 +115,11 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, + new SubstituteHints(conf), CTESubstitution, WindowsSubstitution, EliminateUnions, - new SubstituteUnresolvedOrdinals(conf), - new SubstituteHints(conf)), + new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: ResolveRelations :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index df3217c133769..36ab8b8527b44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -387,9 +387,9 @@ trait CheckAnalysis extends PredicateHelper { |in operator ${operator.simpleString} """.stripMargin) - case Hint(_, _, _) => + case _: Hint => throw new IllegalStateException( - "logical hint operator should have been removed by analyzer") + "Internal error: logical hint 
operator should have been removed during analysis") case _ => // Analysis successful! } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala index 2fe5fe17c3430..ae5bd6f69af6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.catalyst.analysis -import scala.collection.{immutable, mutable} - import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.CurrentOrigin /** * Substitute Hints. * - BROADCAST/BROADCASTJOIN/MAPJOIN match the closest table with the given name parameters. * + * In the case of broadcast hint, we find the frontier of + * * This rule substitutes `UnresolvedRelation`s in `Substitute` batch before `ResolveRelations` * rule is applied. Here are two reasons. * - To support `MetastoreRelation` in Hive module. @@ -37,47 +38,48 @@ import org.apache.spark.sql.catalyst.rules.Rule * All new `Hint`s should be transformed into concrete Hint classes `BroadcastHint` here. 
*/ class SubstituteHints(conf: CatalystConf) extends Rule[LogicalPlan] { - private val BROADCAST_HINT_NAMES = immutable.Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") + private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") def resolver: Resolver = conf.resolver - private def appendAllDescendant(set: mutable.Set[LogicalPlan], plan: LogicalPlan): Unit = { - set += plan - plan.children.foreach { child => appendAllDescendant(set, child) } - } + private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { + // Whether to continue recursing down the tree + var recurse = true - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case logical: LogicalPlan => logical transformDown { - case h @ Hint(name, parameters, child) if BROADCAST_HINT_NAMES.contains(name.toUpperCase) => - var resolvedChild = child - for (table <- parameters) { - var stop = false - val skipNodeSet = scala.collection.mutable.Set.empty[LogicalPlan] - resolvedChild = resolvedChild.transformDown { - case n if skipNodeSet.contains(n) => - skipNodeSet -= n - n - case p @ Project(_, _) if p != resolvedChild => - appendAllDescendant(skipNodeSet, p) - skipNodeSet -= p - p - case r @ BroadcastHint(UnresolvedRelation(t, _)) - if !stop && resolver(t.table, table) => - stop = true - r - case r @ UnresolvedRelation(t, alias) if !stop && resolver(t.table, table) => - stop = true - if (alias.isDefined) { - SubqueryAlias(alias.get, BroadcastHint(r.copy(alias = None)), None) - } else { - BroadcastHint(r) - } + val newNode = CurrentOrigin.withOrigin(plan.origin) { + plan match { + case r: UnresolvedRelation => + val alias = r.alias.getOrElse(r.tableIdentifier.table) + if (toBroadcast.exists(resolver(_, alias))) BroadcastHint(plan) else plan + case r: SubqueryAlias => + if (toBroadcast.exists(resolver(_, r.alias))) { + BroadcastHint(plan) + } else { + // Don't recurse down subquery aliases if there are no match. 
+ recurse = false + plan } - } - resolvedChild + case _: BroadcastHint => + // Found a broadcast hint; don't change the plan but also don't recurse down. + recurse = false + plan + case _ => + plan + } + } - // Remove unrecognized hints - case Hint(name, _, child) => child + if ((plan fastEquals newNode) && recurse) { + newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast)) + } else { + newNode } } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => + applyBroadcastHint(h.child, h.parameters.toSet) + + // Remove unrecognized hints + case h: Hint => h.child + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala index 64e85111c43df..2bcd5101408e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ class SubstituteHintsSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ - val a = testRelation.output(0) - val b = testRelation2.output(0) - test("case-sensitive or insensitive parameters") { checkAnalysis( Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), @@ -40,86 +40,77 @@ class SubstituteHintsSuite extends AnalysisTest { checkAnalysis( Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - BroadcastHint(testRelation)) + BroadcastHint(testRelation), + 
caseSensitive = true) checkAnalysis( Hint("MAPJOIN", Seq("table"), table("TaBlE")), - testRelation) + testRelation, + caseSensitive = true) } - test("single hint") { - checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").select(a)), - BroadcastHint(testRelation).select(a)) - - checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), - BroadcastHint(testRelation).join(testRelation2).select(a)) - + test("multiple broadcast hint aliases") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE2"), - table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), - testRelation.join(BroadcastHint(testRelation2)).select(a)) + Hint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), + Join(BroadcastHint(testRelation), BroadcastHint(testRelation2), Inner, None), + caseSensitive = false) } - test("single hint with multiple parameters") { - checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE", "TaBlE"), - table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), - BroadcastHint(testRelation).join(testRelation2).select(a)) - + test("do not traverse past existing broadcast hints") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE", "TaBlE2"), - table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), - BroadcastHint(testRelation).join(BroadcastHint(testRelation2)).select(a)) + Hint("MAPJOIN", Seq("table"), BroadcastHint(table("table").where('a > 1))), + BroadcastHint(testRelation.where('a > 1)).analyze, + caseSensitive = false) } - test("duplicated nested hints are transformed into one") { - checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").select('a)) - .join(table("TaBlE2").as("u")).select(a)), - BroadcastHint(testRelation).select(a).join(testRelation2).select(a)) + test("should work for subqueries") { + val relation = UnresolvedRelation(TableIdentifier("table"), Some("tableAlias")) checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE2"), - 
table("TaBlE").as("t").select(a) - .join(Hint("MAPJOIN", Seq("TaBlE2"), table("TaBlE2").as("u").select(b))).select(a)), - testRelation.select(a).join(BroadcastHint(testRelation2).select(b)).select(a)) - } + Hint("MAPJOIN", Seq("tableAlias"), relation), + BroadcastHint(testRelation), + caseSensitive = false) - test("distinct nested two hints are handled separately") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE2"), - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").select(a)) - .join(table("TaBlE2").as("u")).select(a)), - BroadcastHint(testRelation).select(a).join(BroadcastHint(testRelation2)).select(a)) + Hint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), + BroadcastHint(testRelation), + caseSensitive = false) + // Negative case: if the alias doesn't match, don't match the original table name. checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), - table("TaBlE").as("t") - .join(Hint("MAPJOIN", Seq("TaBlE2"), table("TaBlE2").as("u").select(b))).select(a)), - BroadcastHint(testRelation).join(BroadcastHint(testRelation2).select(b)).select(a)) + Hint("MAPJOIN", Seq("table"), relation), + testRelation, + caseSensitive = false) } - test("deep self join") { + test("do not traverse past subquery alias") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), - table("TaBlE").join(table("TaBlE")).join(table("TaBlE")).join(table("TaBlE")).select(a)), - BroadcastHint(testRelation).join(testRelation).join(testRelation).join(testRelation) - .select(a)) + Hint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), + testRelation.where('a > 1).analyze, + caseSensitive = false) } - test("subquery should be ignored") { + test("should work for CTE") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), - table("TaBlE").select(a).as("x").join(table("TaBlE")).select(a)), - testRelation.select(a).join(BroadcastHint(testRelation)).select(a)) + CatalystSqlParser.parsePlan( + """ + |WITH ctetable AS (SELECT * FROM table WHERE a > 1) + |SELECT /*+ 
BROADCAST(ctetable) */ * FROM ctetable + """.stripMargin + ), + BroadcastHint(testRelation.where('a > 1).select('a)).select('a).analyze, + caseSensitive = false) + } + test("should not traverse down CTE") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), - table("TaBlE").as("t").select(a).as("x") - .join(table("TaBlE2").as("t2")).select(a)), - testRelation.select(a).join(testRelation2).select(a)) + CatalystSqlParser.parsePlan( + """ + |WITH ctetable AS (SELECT * FROM table WHERE a > 1) + |SELECT /*+ BROADCAST(table) */ * FROM ctetable + """.stripMargin + ), + testRelation.where('a > 1).select('a).select('a).analyze, + caseSensitive = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 85b1b777b0f7c..9c55357ab9bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -139,7 +139,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) } - test("broadcast hint is propagated correctly") { + test("broadcast hint programming API") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") val broadcasted = broadcast(df2) @@ -159,7 +159,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("Broadcast Hint") { + test("broadcast hint in SQL") { import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} spark.range(10).createOrReplaceTempView("t") @@ -182,77 +182,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("Broadcast Hint matches the nearest one") { - val tbl_a = spark.range(10) - val tbl_b = spark.range(20) - val tbl_c = spark.range(30) - - 
tbl_a.createOrReplaceTempView("tbl_a") - tbl_b.createOrReplaceTempView("tbl_b") - tbl_c.createOrReplaceTempView("tbl_c") - - val plan = sql( - """SELECT /*+ MAPJOIN(tbl_b) */ - | * - |FROM tbl_a A - | JOIN tbl_b B - | ON B.id = A.id - | JOIN (SELECT XA.id - | FROM tbl_b XA - | LEFT SEMI JOIN tbl_c XB - | ON XB.id = XA.id) C - | ON C.id = A.id - """.stripMargin).queryExecution.analyzed - - val correct_answer = - SubqueryAlias("A", tbl_a.logicalPlan, Some(TableIdentifier("tbl_a"))) - .join(SubqueryAlias("B", broadcast(SubqueryAlias("tbl_b", tbl_b.logicalPlan, - Some(TableIdentifier("tbl_b")))).logicalPlan, None), $"B.id" === $"A.id", "inner") - .join(SubqueryAlias("XA", tbl_b.logicalPlan, Some(TableIdentifier("tbl_b"))) - .join(SubqueryAlias("XB", tbl_c.logicalPlan, Some(TableIdentifier("tbl_c"))), - $"XB.id" === $"XA.id", "leftsemi") - .select("XA.id").as("C"), $"C.id" === $"A.id", "inner") - .select(col("*")).logicalPlan - - comparePlans(plan, correct_answer) - } - - test("Nested Broadcast Hint") { - val tbl_a = spark.range(10) - val tbl_b = spark.range(20) - val tbl_c = spark.range(30) - - tbl_a.createOrReplaceTempView("tbl_a") - tbl_b.createOrReplaceTempView("tbl_b") - tbl_c.createOrReplaceTempView("tbl_c") - - val plan = sql( - """SELECT /*+ MAPJOIN(tbl_a, tbl_a) */ - | * - |FROM tbl_a A - | JOIN tbl_b B - | ON B.id = A.id - | JOIN (SELECT /*+ MAPJOIN(tbl_c) */ - | XA.id - | FROM tbl_b XA - | LEFT SEMI JOIN tbl_c XB - | ON XB.id = XA.id) C - | ON C.id = A.id - """.stripMargin).queryExecution.analyzed - - val correct_answer = - broadcast(SubqueryAlias("tbl_a", tbl_a.logicalPlan, Some(TableIdentifier("tbl_a")))).as("A") - .join(SubqueryAlias("B", tbl_b.logicalPlan, Some(TableIdentifier("tbl_b"))), - $"B.id" === $"A.id", "inner") - .join(SubqueryAlias("XA", tbl_b.logicalPlan, Some(TableIdentifier("tbl_b"))) - .join(broadcast(SubqueryAlias("tbl_c", tbl_c.logicalPlan, Some(TableIdentifier("tbl_c")))) - .as("XB"), $"XB.id" === $"XA.id", "leftsemi") - 
.select("XA.id").as("C"), $"C.id" === $"A.id", "inner") - .select(col("*")).logicalPlan - - comparePlans(plan, correct_answer) - } - test("join key rewritten") { val l = Literal(1L) val i = Literal(2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala deleted file mode 100644 index 928064a95fec3..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils - -class BroadcastHintSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - test("broadcast hint on Hive table") { - withTable("hive_t", "hive_u") { - spark.sql("CREATE TABLE hive_t(a int)") - spark.sql("CREATE TABLE hive_u(b int)") - - val hive_t = spark.table("hive_t").queryExecution.analyzed - val hive_u = spark.table("hive_u").queryExecution.analyzed - - val plan = spark.sql("SELECT /*+ MAPJOIN(hive_t) */ * FROM hive_t, hive_u") - .queryExecution.analyzed - - assert(plan.collectFirst { - case BroadcastHint(MetastoreRelation(_, "hive_t")) => true - }.isDefined) - assert(plan.collectFirst { - case Join(_, MetastoreRelation(_, "hive_u"), _, _) => true - }.isDefined) - - val plan2 = spark.sql("SELECT /*+ MAPJOIN(hive_u) */ a FROM hive_t, hive_u") - .queryExecution.analyzed - - assert(plan2.collectFirst { - case BroadcastHint(MetastoreRelation(_, "hive_u")) => true - }.isDefined) - assert(plan2.collectFirst { - case Join(MetastoreRelation(_, "hive_t"), _, _, _) => true - }.isDefined) - } - } -} From 51a73d510a5faf3bbed8bb76ffe5590c68a67e2c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Feb 2017 19:51:20 +0100 Subject: [PATCH 5/6] Separate rule. 
--- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/SubstituteHints.scala | 106 ++++++++++-------- .../analysis/SubstituteHintsSuite.scala | 7 ++ 3 files changed, 72 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 300d60596eb57..8348cb50129cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -114,8 +114,10 @@ class Analyzer( val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil lazy val batches: Seq[Batch] = Seq( + Batch("Hints", fixedPoint, + new SubstituteHints.SubstituteBroadcastHints(conf), + SubstituteHints.RemoveAllHints), Batch("Substitution", fixedPoint, - new SubstituteHints(conf), CTESubstitution, WindowsSubstitution, EliminateUnions, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala index ae5bd6f69af6e..b103d1e8eeff5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala @@ -24,62 +24,80 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin /** - * Substitute Hints. - * - BROADCAST/BROADCASTJOIN/MAPJOIN match the closest table with the given name parameters. + * Collection of rules related to hints. The only hint currently available is broadcast join hint. * - * In the case of broadcast hint, we find the frontier of - * - * This rule substitutes `UnresolvedRelation`s in `Substitute` batch before `ResolveRelations` - * rule is applied. Here are two reasons. - * - To support `MetastoreRelation` in Hive module. 
- * - To reduce the effect of `Hint` on the other rules. - * - * After this rule, it is guaranteed that there exists no unknown `Hint` in the plan. - * All new `Hint`s should be transformed into concrete Hint classes `BroadcastHint` here. + * Note that this is separated into two rules because in the future we might introduce new hint + * rules that have different ordering requirements from broadcast. */ -class SubstituteHints(conf: CatalystConf) extends Rule[LogicalPlan] { - private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") +object SubstituteHints { + + /** + * Substitute Hints. + * + * The only hint currently available is the broadcast join hint. + * + * For broadcast hint, we accept "BROADCAST", "BROADCASTJOIN", and "MAPJOIN", and a sequence of + * relation aliases can be specified in the hint. A broadcast hint plan node will be inserted + * on top of any relation (that is not aliased differently), subquery, or common table expression + * that matches the specified name. + * + * The hint resolution works by recursively traversing down the query plan to find a relation or + * subquery that matches one of the specified broadcast aliases. The traversal does not go + * beyond any existing broadcast hints or subquery aliases. + * + * This rule must happen before common table expressions.
+ */ + class SubstituteBroadcastHints(conf: CatalystConf) extends Rule[LogicalPlan] { + private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") - def resolver: Resolver = conf.resolver + def resolver: Resolver = conf.resolver - private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { - // Whether to continue recursing down the tree - var recurse = true + private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { + // Whether to continue recursing down the tree + var recurse = true - val newNode = CurrentOrigin.withOrigin(plan.origin) { - plan match { - case r: UnresolvedRelation => - val alias = r.alias.getOrElse(r.tableIdentifier.table) - if (toBroadcast.exists(resolver(_, alias))) BroadcastHint(plan) else plan - case r: SubqueryAlias => - if (toBroadcast.exists(resolver(_, r.alias))) { - BroadcastHint(plan) - } else { - // Don't recurse down subquery aliases if there are no match. + val newNode = CurrentOrigin.withOrigin(plan.origin) { + plan match { + case r: UnresolvedRelation => + val alias = r.alias.getOrElse(r.tableIdentifier.table) + if (toBroadcast.exists(resolver(_, alias))) BroadcastHint(plan) else plan + case r: SubqueryAlias => + if (toBroadcast.exists(resolver(_, r.alias))) { + BroadcastHint(plan) + } else { + // Don't recurse down subquery aliases if there are no match. + recurse = false + plan + } + case _: BroadcastHint => + // Found a broadcast hint; don't change the plan but also don't recurse down. recurse = false plan - } - case _: BroadcastHint => - // Found a broadcast hint; don't change the plan but also don't recurse down. 
- recurse = false - plan - case _ => - plan + case _ => + plan + } + } + + if ((plan fastEquals newNode) && recurse) { + newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast)) + } else { + newNode } } - if ((plan fastEquals newNode) && recurse) { - newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast)) - } else { - newNode + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => + applyBroadcastHint(h.child, h.parameters.toSet) } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => - applyBroadcastHint(h.child, h.parameters.toSet) - - // Remove unrecognized hints - case h: Hint => h.child + /** + * Removes all the hints. This must be executed after all the other hint rules are executed. + */ + object RemoveAllHints extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case h: Hint => h.child + } } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala index 2bcd5101408e6..10fe3ebdc42db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala @@ -27,6 +27,13 @@ import org.apache.spark.sql.catalyst.plans.logical._ class SubstituteHintsSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ + test("invalid hints should be ignored") { + checkAnalysis( + Hint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), + testRelation, + caseSensitive = false) + } + test("case-sensitive or insensitive parameters") { checkAnalysis( Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), From 
0d429782b38cf19c0d4d6a5102102e39d872fee7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Feb 2017 21:38:53 +0100 Subject: [PATCH 6/6] CR --- .../spark/sql/catalyst/analysis/SubstituteHints.scala | 3 ++- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../spark/sql/catalyst/analysis/SubstituteHintsSuite.scala | 6 ++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala index b103d1e8eeff5..fda4d1b61212c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala @@ -92,7 +92,8 @@ object SubstituteHints { } /** - * Removes all the hints. This must be executed after all the other hint rules are executed. + * Removes all the hints, used to remove invalid hints provided by the user. + * This must be executed after all the other hint rules are executed. 
*/ object RemoveAllHints extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 60c41fb1bad71..bbb9922c187de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -383,7 +383,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val withWindow = withDistinct.optionalMap(windows)(withWindows) // Hint - withWindow.optionalMap(ctx.hint)(withHints) + withWindow.optionalMap(hint)(withHints) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala index 10fe3ebdc42db..9d671f31213ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala @@ -71,10 +71,8 @@ class SubstituteHintsSuite extends AnalysisTest { } test("should work for subqueries") { - val relation = UnresolvedRelation(TableIdentifier("table"), Some("tableAlias")) - checkAnalysis( - Hint("MAPJOIN", Seq("tableAlias"), relation), + Hint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), BroadcastHint(testRelation), caseSensitive = false) @@ -85,7 +83,7 @@ class SubstituteHintsSuite extends AnalysisTest { // Negative case: if the alias doesn't match, don't match the original table name. checkAnalysis( - Hint("MAPJOIN", Seq("table"), relation), + Hint("MAPJOIN", Seq("table"), table("table").as("tableAlias")), testRelation, caseSensitive = false) }