[SPARK-13917] [SQL] generate broadcast semi join #11742

Closed · wants to merge 3 commits · showing changes from 2 commits

SparkStrategies.scala
@@ -65,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(
             LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastLeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
         joins.LeftSemiJoinHash(
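Not part of the patch, but a minimal sketch of what this strategy change means in practice. It assumes a `sqlContext` with the `testData`/`testData2` temp tables that JoinSuite registers, and a broadcast threshold large enough for the right side to be broadcast:

```scala
// Sketch: a LEFT SEMI JOIN over a broadcastable right side should now plan
// to the unified BroadcastHashJoin operator instead of the dedicated
// BroadcastLeftSemiJoinHash operator it used before this change.
val df = sqlContext.sql("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a")
val bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j }
assert(bhj.size === 1, "expected the semi join to be planned as BroadcastHashJoin")
```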

BroadcastHashJoin.scala
@@ -23,7 +23,7 @@
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics

@@ -92,6 +92,9 @@ case class BroadcastHashJoin(
         rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
       }

+      case LeftSemi =>
+        hashSemiJoin(streamedIter, hashTable, numOutputRows)
+
       case x =>
         throw new IllegalArgumentException(
           s"BroadcastHashJoin should not take $x as the JoinType")

@@ -108,11 +111,13 @@ case class BroadcastHashJoin(
   }

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
-    if (joinType == Inner) {
-      codegenInner(ctx, input)
-    } else {
-      // LeftOuter and RightOuter
-      codegenOuter(ctx, input)
+    joinType match {
+      case Inner => codegenInner(ctx, input)
+      case LeftOuter | RightOuter => codegenOuter(ctx, input)
+      case LeftSemi => codegenSemi(ctx, input)
+      case x =>
+        throw new IllegalArgumentException(
+          s"BroadcastHashJoin should not take $x as the JoinType")
     }
   }

@@ -322,4 +327,70 @@ case class BroadcastHashJoin(
       """.stripMargin
   }
 }
+
+  /**
+   * Generates the code for left semi join.
+   */
+  private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from the build side that are used by the condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      // filter the output via condition
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
+      s"""
+         |$eval
+         |${ev.code}
+         |if (${ev.isNull} || !${ev.value}) continue;
+       """.stripMargin
+    } else {
+      ""
+    }
+
+    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |if ($matched == null) continue;
+         |$checkCondition
+         |$numOutput.add(1);
+         |${consume(ctx, input)}
+       """.stripMargin
+
[Review comment – Contributor] extra new line
+    } else {
+      val matches = ctx.freshName("matches")
+      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val i = ctx.freshName("i")
+      val size = ctx.freshName("size")
+      val found = ctx.freshName("found")
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |if ($matches == null) continue;
+         |int $size = $matches.size();
+         |boolean $found = false;
+         |for (int $i = 0; $i < $size; $i++) {
+         |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |  $checkCondition
+         |  $found = true;
+         |  break;
+         |}
+         |if (!$found) continue;
+         |$numOutput.add(1);
+         |${consume(ctx, input)}
+       """.stripMargin

[Review comment – Contributor] and here
+    }
+  }
 }
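For readers following the templates above, here is a self-contained Scala sketch (illustration only, not Spark API) of the loop the non-unique branch emits: probe the hashed build side and emit the stream row at most once, on the first match that also passes the optional condition. The `exists` below plays the role of the generated `found`/`break` loop; the unique-key branch specializes this to a single `getValue` lookup.

```scala
// Illustration only: semi-join semantics over plain Scala collections.
def semiJoin[K, R](
    stream: Iterator[(K, R)],
    build: Map[K, Seq[R]],
    cond: (R, R) => Boolean = (_: R, _: R) => true): Iterator[R] =
  stream.collect {
    // Emit the stream row once if any build row with the same key passes cond.
    case (key, row) if build.getOrElse(key, Nil).exists(cond(row, _)) => row
  }

// semiJoin(Iterator(1 -> "a", 2 -> "b"), Map(1 -> Seq("x"))) yields only "a".
```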

BroadcastLeftSemiJoinHash.scala
This file was deleted.

HashJoin.scala
@@ -46,8 +46,8 @@ trait HashJoin {
       left.output ++ right.output.map(_.withNullability(true))
     case RightOuter =>
       left.output.map(_.withNullability(true)) ++ right.output
-    case FullOuter =>
-      left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+    case LeftSemi =>
+      left.output
     case x =>
       throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
   }

@@ -104,7 +104,7 @@ trait HashJoin {
     keyExpr :: Nil
   }

-  protected val canJoinKeyFitWithinLong: Boolean = {
+  protected lazy val canJoinKeyFitWithinLong: Boolean = {
     val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
     val key = rewriteKeyExpr(buildKeys)
     sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]

@@ -258,4 +258,21 @@ trait HashJoin {
     }
     ret.iterator
   }

+  protected def hashSemiJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator
+    val joinedRow = new JoinedRow
+    streamIter.filter { current =>
+      val key = joinKeys(current)
+      lazy val rowBuffer = hashedRelation.get(key)
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+        (row: InternalRow) => boundCondition(joinedRow(current, row))
+      })
+      if (r) numOutputRows += 1
+      r
+    }
+  }
 }
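A note on the `lazy val rowBuffer` in `hashSemiJoin`: because `&&` short-circuits, the hash lookup never runs for rows whose join key contains a null (null keys can never match in a semi join). A tiny illustration of the pattern, not Spark code:

```scala
// The lookup is only forced when the key is null-free.
def probe(keyHasNull: Boolean, lookup: () => Seq[Int]): Boolean = {
  lazy val rowBuffer = lookup()          // not evaluated until referenced
  !keyHasNull && rowBuffer.nonEmpty      // short-circuits on null keys
}
probe(keyHasNull = true, () => sys.error("never called"))  // returns false
```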

HashSemiJoin.scala
This file was deleted.

LeftSemiJoinHash.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics

@@ -33,7 +34,10 @@ case class LeftSemiJoinHash(
     rightKeys: Seq[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
+    condition: Option[Expression]) extends BinaryNode with HashJoin {
+
+  override val joinType = LeftSemi
+  override val buildSide = BuildRight

   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

@@ -47,7 +51,7 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")

     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      val hashRelation = HashedRelation(buildIter.map(_.copy()), rightKeyGenerator)
+      val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
       hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }
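With `joinType = LeftSemi` and `buildSide = BuildRight`, `LeftSemiJoinHash` now derives its key generators and output from `HashJoin` instead of the deleted `HashSemiJoin` trait, so `buildSideKeyGenerator` binds the right-side keys, matching what the old `rightKeyGenerator` did. A simplified sketch of that derivation (the `BuildSide` types here are stand-ins for Spark's, structure illustrative):

```scala
// Simplified: how a HashJoin-style trait picks build/streamed sides.
// With BuildRight, the right child is hashed and the left child is
// streamed, matching the old hard-coded semi-join behavior.
sealed trait BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

def splitSides[P](left: P, right: P, side: BuildSide): (P, P) = side match {
  case BuildRight => (right, left)   // (buildPlan, streamedPlan)
  case BuildLeft  => (left, right)
}
```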
sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala (2 additions, 2 deletions)
@@ -49,7 +49,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       case j: BroadcastHashJoin => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
       case j: SortMergeJoin => j
       case j: SortMergeOuterJoin => j
     }

@@ -428,7 +428,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-          classOf[BroadcastLeftSemiJoinHash])
+          classOf[BroadcastHashJoin])
       ).foreach {
         case (query, joinClass) => assertJoin(query, joinClass)
       }

BenchmarkWholeStageCodegen.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.Benchmark
 class BenchmarkWholeStageCodegen extends SparkFunSuite {
   lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
     .set("spark.sql.shuffle.partitions", "1")
-    .set("spark.sql.autoBroadcastJoinThreshold", "0")
+    .set("spark.sql.autoBroadcastJoinThreshold", "1")
   lazy val sc = SparkContext.getOrCreate(conf)
   lazy val sqlContext = SQLContext.getOrCreate(sc)

@@ -200,6 +200,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     outer join w long codegen=false         15280 / 16497          6.9         145.7       1.0X
     outer join w long codegen=true            769 /   796        136.3           7.3      19.9X
     */
+
+    runBenchmark("semi join w long", N) {
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    semi join w long:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    semi join w long codegen=false          5804 / 5969         18.1          55.3       1.0X
+    semi join w long codegen=true            814 /  934        128.8           7.8       7.1X
+    */
   }

   ignore("sort merge join") {
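On reading the benchmark numbers above: the codegen=false / codegen=true rows come from timing the same query with whole-stage codegen toggled. A hedged sketch of reproducing the comparison by hand, with `N`, `M`, and `dim` as defined earlier in the benchmark file and the era's whole-stage codegen config key:

```scala
// Run the semi join with and without whole-stage codegen.
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()

sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
```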

BroadcastJoinSuite.scala
@@ -79,7 +79,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   }

   test("unsafe broadcast left semi join updates peak execution memory") {
-    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast left semi join", "leftsemi")
   }

 }

SemiJoinSuite.scala
@@ -84,11 +84,12 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }

-    test(s"$testName using BroadcastLeftSemiJoinHash") {
+    test(s"$testName using BroadcastHashJoin") {
       extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+            BroadcastHashJoin(
+              leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
             expectedAnswer.map(Row.fromTuple),
             sortAnswers = true)
         }

StatisticsSuite.scala
@@ -212,7 +212,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
     // Using `sparkPlan` because for relevant patterns in HashJoin to be
     // matched, other strategies need to be applied.
     var bhj = df.queryExecution.sparkPlan.collect {
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
     }
     assert(bhj.size === 1,
       s"actual query plans do not contain broadcast join: ${df.queryExecution}")

@@ -225,7 +225,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
     sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
     df = sql(leftSemiJoinQuery)
     bhj = df.queryExecution.sparkPlan.collect {
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
     }
     assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")