Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-25121][SQL] Supports multi-part table names for broadcast hint resolution #22198

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -146,7 +146,7 @@ class Analyzer(

lazy val batches: Seq[Batch] = Seq(
Batch("Hints", fixedPoint,
new ResolveHints.ResolveBroadcastHints(conf),
new ResolveHints.ResolveBroadcastHints(conf, catalog),
ResolveHints.ResolveCoalesceHints,
ResolveHints.RemoveAllHints),
Batch("Simple Sanity Check", Once,
Expand Down
Expand Up @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.IdentifierWithDatabase
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -47,20 +49,42 @@ object ResolveHints {
*
* This rule must happen before common table expressions.
*/
class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] {
class ResolveBroadcastHints(conf: SQLConf, catalog: SessionCatalog) extends Rule[LogicalPlan] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accordingly, we can use String instead of SessionCatalog.

-  class ResolveBroadcastHints(conf: SQLConf, catalog: SessionCatalog) extends Rule[LogicalPlan] {
+  class ResolveBroadcastHints(conf: SQLConf, currentDatabase: String) extends Rule[LogicalPlan] {

Copy link
Member Author

@maropu maropu Aug 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can't use String there because currentDatabase might be updatable by others?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can instead use getCurrentDatabase: () => String?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya. Right, please ignore this. We need catalog to lookup global_temp, too.

private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN")

def resolver: Resolver = conf.resolver

private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = {
// Name resolution in hints follows three rules below:
//
// 1. table name matches if the hint table name only has one part
// 2. table name and database name both match if the hint table name has two parts
// 3. no match happens if the hint table name has more than two parts
//
// This means, `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN db2.t` will match both tables, and
// `SELECT /*+ BROADCAST(default.t) */ * FROM t` matches no table.
/**
 * Returns true if a (possibly multi-part) hint name matches the given identifier.
 * A one-part name is compared against the table name only; a two-part name must match
 * both the database and the table name (and never matches an identifier without a
 * database); names with three or more parts never match. All comparisons go through
 * `resolver`, so case sensitivity follows the session configuration.
 */
private def matchedTableIdentifier(
nameParts: Seq[String],
tableIdent: IdentifierWithDatabase): Boolean = {
  nameParts.toList match {
    case tableName :: Nil =>
      resolver(tableIdent.identifier, tableName)
    case dbName :: tableName :: Nil =>
      tableIdent.database.exists(resolver(_, dbName)) &&
        resolver(tableIdent.identifier, tableName)
    case _ =>
      false
  }
}

private def applyBroadcastHint(
plan: LogicalPlan,
toBroadcast: Set[Seq[String]]): LogicalPlan = {
// Whether to continue recursing down the tree
var recurse = true

val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
case u: UnresolvedRelation
if toBroadcast.exists(matchedTableIdentifier(_, u.tableIdentifier)) =>
ResolvedHint(plan, HintInfo(broadcast = true))
case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
case r: SubqueryAlias if toBroadcast.exists(matchedTableIdentifier(_, r.name)) =>
ResolvedHint(plan, HintInfo(broadcast = true))

case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
Expand Down Expand Up @@ -94,8 +118,8 @@ object ResolveHints {
} else {
// Otherwise, find within the subtree query plans that should be broadcasted.
applyBroadcastHint(h.child, h.parameters.map {
case tableName: String => tableName
case tableId: UnresolvedAttribute => tableId.name
case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
case tableId: UnresolvedAttribute => tableId.nameParts
case unsupported => throw new AnalysisException("Broadcast hint parameter should be " +
s"an identifier or string but was $unsupported (${unsupported.getClass}")
}.toSet)
Expand Down
Expand Up @@ -168,7 +168,7 @@ package object expressions {
// For example, consider an example where "db1" is the database name, "a" is the table name
// and "b" is the column name and "c" is the struct field name.
// If the name parts is db1.a.b.c, then Attribute will match
// Attribute(b, qualifier("db1,"a")) and List("c") will be the second element
// Attribute(b, qualifier("db1","a")) and List("c") will be the second element
var matches: (Seq[Attribute], Seq[String]) = nameParts match {
case dbPart +: tblPart +: name +: nestedFields =>
val key = (dbPart.toLowerCase(Locale.ROOT),
Expand Down
Expand Up @@ -41,6 +41,8 @@ trait AnalysisTest extends PlanTest {
catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true)
catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, overrideIfExists = true)
catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, overrideIfExists = true)
new Analyzer(catalog, conf) {
override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
}
Expand Down
Expand Up @@ -155,4 +155,52 @@ class ResolveHintsSuite extends AnalysisTest {
UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")),
Seq(errMsgRepa))
}

test("supports multi-part table names for broadcast hint resolution") {
// Local temp views: a single-part hint name matches by table name only.
// With caseSensitive = false, both "table" and "table2" resolve, so both
// join sides are wrapped in ResolvedHint(broadcast = true).
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table", "table2"),
table("table").join(table("table2"))),
Join(
ResolvedHint(testRelation, HintInfo(broadcast = true)),
ResolvedHint(testRelation2, HintInfo(broadcast = true)),
Inner,
None,
JoinHint(None, None)),
caseSensitive = false)

// With caseSensitive = true, "TaBlE" matches the left side exactly, but
// "table2" does not match "TaBlE2", so only the left side gets the hint.
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"),
table("TaBlE").join(table("TaBlE2"))),
Join(
ResolvedHint(testRelation, HintInfo(broadcast = true)),
testRelation2,
Inner,
None,
JoinHint(None, None)),
caseSensitive = true)

// Global temp views: a two-part hint name must match both the global temp
// database ("global_temp") and the view name. Case-insensitive here, so the
// mixed-case "GlOBal_TeMP" qualifier also matches.
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("global_temp.table4", "GlOBal_TeMP.table5"),
table("global_temp", "table4").join(table("global_temp", "table5"))),
Join(
ResolvedHint(testRelation4, HintInfo(broadcast = true)),
ResolvedHint(testRelation5, HintInfo(broadcast = true)),
Inner,
None,
JoinHint(None, None)),
caseSensitive = false)

// Case-sensitive: "global_temp.TaBlE4" matches the left side; the bare name
// "table5" does not match "TaBlE5", so only the left side is hinted.
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"),
table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))),
Join(
ResolvedHint(testRelation4, HintInfo(broadcast = true)),
testRelation5,
Inner,
None,
JoinHint(None, None)),
caseSensitive = true)
}
}
Expand Up @@ -44,6 +44,8 @@ object TestRelations {
AttributeReference("g", StringType)(),
AttributeReference("h", MapType(IntegerType, IntegerType))())

// Single string column "i"; registered as the global temp view "TaBlE5" in AnalysisTest.
val testRelation5 = LocalRelation(AttributeReference("i", StringType)())

val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
Expand Down
Expand Up @@ -17,8 +17,11 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -191,6 +194,83 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1)
}

test("SPARK-25121 supports multi-part names for broadcast hint resolution") {
val (table1Name, table2Name) = ("t1", "t2")

withTempDatabase { dbName =>
withTable(table1Name, table2Name) {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")

// First, makes sure a join is not broadcastable
val plan = sql(s"SELECT * FROM $dbName.$table1Name, $dbName.$table2Name " +
s"WHERE $table1Name.id = $table2Name.id")
.queryExecution.executedPlan
assert(plan.collect { case p: BroadcastHashJoinExec => p }.isEmpty)

def checkIfHintApplied(tableName: String, hintTableName: String): Unit = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hintTableName is never used in this func?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, I'll fix.

val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
s"FROM $tableName, $dbName.$table2Name " +
s"WHERE $tableName.id = $table2Name.id")
.queryExecution.executedPlan
val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p }
assert(broadcastHashJoins.size == 1)
val broadcastExchanges = broadcastHashJoins.head.collect {
case p: BroadcastExchangeExec => p
}
assert(broadcastExchanges.size == 1)
val tables = broadcastExchanges.head.collect {
case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent
}
assert(tables.size == 1)
assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
}

// Runs the hinted query and verifies the broadcast hint was NOT honored:
// the executed plan must contain no BroadcastHashJoinExec node.
def checkIfHintNotApplied(tableName: String, hintTableName: String): Unit = {
val query = s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
s"FROM $tableName, $dbName.$table2Name " +
s"WHERE $tableName.id = $table2Name.id"
val executedPlan = sql(query).queryExecution.executedPlan
assert(executedPlan.collect { case j: BroadcastHashJoinExec => j }.isEmpty)
}

sql(s"USE $dbName")
checkIfHintApplied(table1Name, table1Name)
checkIfHintApplied(s"$dbName.$table1Name", s"$dbName.$table1Name")
checkIfHintApplied(s"$dbName.$table1Name", table1Name)
checkIfHintNotApplied(table1Name, s"$dbName.$table1Name")
checkIfHintNotApplied(s"$dbName.$table1Name", s"$dbName.$table1Name.id")
}
}
}
}

test("SPARK-25121 the same table name exists in two databases for broadcast hint resolution") {
val (db1Name, db2Name) = ("db1", "db2")

withDatabase(db1Name, db2Name) {
withTable("t") {
// Disable the auto-broadcast threshold so a broadcast can only come from the hint.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
sql(s"CREATE DATABASE $db1Name")
sql(s"CREATE DATABASE $db2Name")
spark.range(1).write.saveAsTable(s"$db1Name.t")
spark.range(1).write.saveAsTable(s"$db2Name.t")

// A single-part hint name matches by table name only, so `BROADCASTJOIN(t)`
// applies to `db1.t` and `db2.t` alike: both join sides must carry a
// broadcast hint in the optimized plan.
val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " +
s"WHERE $db1Name.t.id = $db2Name.t.id"
sql(statement).queryExecution.optimizedPlan match {
case Join(_, _, _, _, JoinHint(Some(leftHint), Some(rightHint))) =>
assert(leftHint.broadcast && rightHint.broadcast)
case _ => fail("broadcast hint not found in both tables")
}
}
}
}
}

test("join - outer join conversion") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a")
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b")
Expand Down
Expand Up @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalog.Table
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -157,6 +158,27 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext {
}
}

test("SPARK-25121 broadcast hint on global temp view") {
withGlobalTempView("v1") {
spark.range(10).createGlobalTempView("v1")
withTempView("v2") {
spark.range(10).createTempView("v2")

// Disable the auto-broadcast threshold so a broadcast can only come from the hint.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
// A global temp view should be matchable both by its bare name and by the
// two-part name qualified with the global temp database; in both cases only
// the left (global temp view) side must carry the broadcast hint.
Seq(
"SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id",
"SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id"
).foreach { statement =>
sql(statement).queryExecution.optimizedPlan match {
case Join(_, _, _, _, JoinHint(Some(leftHint), None)) => assert(leftHint.broadcast)
case _ => fail("broadcast hint not found in a left-side table")
}
}
}
}
}
}

test("public Catalog should recognize global temp view") {
withGlobalTempView("src") {
sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2")
Expand Down
Expand Up @@ -20,6 +20,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, SubqueryAlias}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}

class SimpleSQLViewSuite extends SQLViewSuite with SharedSQLContext
Expand Down Expand Up @@ -706,4 +709,39 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
}
}
}

test("SPARK-25121 broadcast hint on temp view") {
withTable("t") {
spark.range(10).write.saveAsTable("t")
withTempView("tv") {
sql("CREATE TEMPORARY VIEW tv AS SELECT * FROM t")

// Disable the auto-broadcast threshold so a BroadcastHashJoin can only appear
// when an explicit hint is honored.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
// First, makes sure a join is not broadcastable.
// Idiom fix: use `.isEmpty` instead of `.size == 0` on the collected nodes.
val plan1 = sql("SELECT * FROM t, tv WHERE t.id = tv.id")
.queryExecution.executedPlan
assert(plan1.collect { case p: BroadcastHashJoinExec => p }.isEmpty)

// `MAPJOIN(default.tv)` cannot match the temporary table `tv`: a temp view has
// no database part, so the two-part hint name never resolves to it.
val plan2 = sql("SELECT /*+ MAPJOIN(default.tv) */ * FROM t, tv WHERE t.id = tv.id")
.queryExecution.analyzed
assert(plan2.collect { case h: ResolvedHint => h }.isEmpty)

// `MAPJOIN(tv)` can match the temporary table `tv`
val df = sql("SELECT /*+ MAPJOIN(tv) */ * FROM t, tv WHERE t.id = tv.id")
val logicalPlan = df.queryExecution.analyzed
val broadcastData = logicalPlan.collect {
case ResolvedHint(SubqueryAlias(name, _), _) => name
}
assert(broadcastData.size == 1)
// The matched alias must be the unqualified temp view name.
assert(broadcastData.head.database === None)
assert(broadcastData.head.identifier === "tv")

// And the hint must survive into the physical plan as a broadcast hash join.
val sparkPlan = df.queryExecution.executedPlan
val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
assert(broadcastHashJoins.size == 1)
}
}
}
}
}