diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 113cf9ae2f222..5e78ff7106024 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -65,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(
              LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastLeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
         joins.LeftSemiJoinHash(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 4c8f8080a98d7..f84ed41f1d2c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -92,6 +92,9 @@ case class BroadcastHashJoin(
           rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
         }

+      case LeftSemi =>
+        hashSemiJoin(streamedIter, hashTable, numOutputRows)
+
       case x =>
         throw new IllegalArgumentException(
           s"BroadcastHashJoin should not take $x as the JoinType")
@@ -108,11 +111,13 @@ case class BroadcastHashJoin(
   }

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
-    if (joinType == Inner) {
-      codegenInner(ctx, input)
-    } else {
-      // LeftOuter and RightOuter
-      codegenOuter(ctx, input)
+    joinType match {
+      case Inner => codegenInner(ctx, input)
+      case LeftOuter | RightOuter => codegenOuter(ctx, input)
+      case LeftSemi => codegenSemi(ctx, input)
+      case x =>
+        throw new IllegalArgumentException(
+          s"BroadcastHashJoin should not take $x as the JoinType")
     }
   }

@@ -322,4 +327,68 @@ case class BroadcastHashJoin(
       """.stripMargin
     }
   }
+
+  /**
+   * Generates the code for left semi join.
+   */
+  private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from the build side that are used by the condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      // filter the output via condition
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
+      s"""
+         |$eval
+         |${ev.code}
+         |if (${ev.isNull} || !${ev.value}) continue;
+       """.stripMargin
+    } else {
+      ""
+    }
+
+    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null : (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |if ($matched == null) continue;
+         |$checkCondition
+         |$numOutput.add(1);
+         |${consume(ctx, input)}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val i = ctx.freshName("i")
+      val size = ctx.freshName("size")
+      val found = ctx.freshName("found")
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |if ($matches == null) continue;
+         |int $size = $matches.size();
+         |boolean $found = false;
+         |for (int $i = 0; $i < $size; $i++) {
+         |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |  $checkCondition
+         |  $found = true;
+         |  break;
+         |}
+         |if (!$found) continue;
+         |$numOutput.add(1);
+         |${consume(ctx, input)}
+       """.stripMargin
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
deleted file mode 100644
index d3bcfad7c3de0..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Build the right table's join keys into a HashedRelation, and iteratively go through the left
- * table, to find if the join keys are in the HashedRelation.
- */
-case class BroadcastLeftSemiJoinHash(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    left: SparkPlan,
-    right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
-    UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    val broadcastedRelation = right.executeBroadcast[HashedRelation]()
-    left.execute().mapPartitionsInternal { streamIter =>
-      val hashedRelation = broadcastedRelation.value
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-      hashSemiJoin(streamIter, hashedRelation, numOutputRows)
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 2fe9c06cc9537..5f42d07273e44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -46,8 +46,8 @@ trait HashJoin {
       left.output ++ right.output.map(_.withNullability(true))
     case RightOuter =>
       left.output.map(_.withNullability(true)) ++ right.output
-    case FullOuter =>
-      left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+    case LeftSemi =>
+      left.output
     case x =>
       throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
   }
@@ -104,7 +104,7 @@ trait HashJoin {
     keyExpr :: Nil
   }

-  protected val canJoinKeyFitWithinLong: Boolean = {
+  protected lazy val canJoinKeyFitWithinLong: Boolean = {
     val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
     val key = rewriteKeyExpr(buildKeys)
     sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
@@ -258,4 +258,21 @@ trait HashJoin {
     }
     ret.iterator
   }
+
+  protected def hashSemiJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator
+    val joinedRow = new JoinedRow
+    streamIter.filter { current =>
+      val key = joinKeys(current)
+      lazy val rowBuffer = hashedRelation.get(key)
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+        (row: InternalRow) => boundCondition(joinedRow(current, row))
+      })
+      if (r) numOutputRows += 1
+      r
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
deleted file mode 100644
index 813ec024250c2..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-
-
-trait HashSemiJoin {
-  self: SparkPlan =>
-  val leftKeys: Seq[Expression]
-  val rightKeys: Seq[Expression]
-  val left: SparkPlan
-  val right: SparkPlan
-  val condition: Option[Expression]
-
-  override def output: Seq[Attribute] = left.output
-
-  protected def leftKeyGenerator: Projection =
-    UnsafeProjection.create(leftKeys, left.output)
-
-  protected def rightKeyGenerator: Projection =
-    UnsafeProjection.create(rightKeys, right.output)
-
-  @transient private lazy val boundCondition =
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-
-  protected def hashSemiJoin(
-      streamIter: Iterator[InternalRow],
-      hashedRelation: HashedRelation,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator
-    val joinedRow = new JoinedRow
-    streamIter.filter { current =>
-      val key = joinKeys(current)
-      lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
-        (row: InternalRow) => boundCondition(joinedRow(current, row))
-      })
-      if (r) numOutputRows += 1
-      r
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 14389e45babed..fa549b4d51336 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -33,7 +34,10 @@ case class LeftSemiJoinHash(
     rightKeys: Seq[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
+    condition: Option[Expression]) extends BinaryNode with HashJoin {
+
+  override val joinType = LeftSemi
+  override val buildSide = BuildRight

   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -47,7 +51,7 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")

     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      val hashRelation = HashedRelation(buildIter.map(_.copy()), rightKeyGenerator)
+      val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
       hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 50647c28402eb..2f27896925de1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -49,7 +49,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       case j: BroadcastHashJoin => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
-      case j: BroadcastLeftSemiJoinHash => j
       case j: SortMergeJoin => j
       case j: SortMergeOuterJoin => j
     }
@@ -428,7 +428,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-          classOf[BroadcastLeftSemiJoinHash])
+          classOf[BroadcastHashJoin])
       ).foreach {
         case (query, joinClass) => assertJoin(query, joinClass)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 9f33e4ab62298..cb672643f1c24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.Benchmark
 class BenchmarkWholeStageCodegen extends SparkFunSuite {
   lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
     .set("spark.sql.shuffle.partitions", "1")
-    .set("spark.sql.autoBroadcastJoinThreshold", "0")
+    .set("spark.sql.autoBroadcastJoinThreshold", "1")
   lazy val sc = SparkContext.getOrCreate(conf)
   lazy val sqlContext = SQLContext.getOrCreate(sc)

@@ -200,6 +200,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     outer join w long codegen=false          15280 / 16497          6.9         145.7       1.0X
     outer join w long codegen=true             769 /   796        136.3           7.3      19.9X
     */
+
+    runBenchmark("semi join w long", N) {
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    semi join w long:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    semi join w long codegen=false           5804 / 5969         18.1          55.3       1.0X
+    semi join w long codegen=true             814 /  934        128.8           7.8       7.1X
+    */
   }

   ignore("sort merge join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 6d5b777733f41..babe7ef70f99d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -79,7 +79,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   }

   test("unsafe broadcast left semi join updates peak execution memory") {
-    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast left semi join", "leftsemi")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index d8c9564f1e4fb..5eb6a745239ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -84,11 +84,12 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
   }

-  test(s"$testName using BroadcastLeftSemiJoinHash") {
+  test(s"$testName using BroadcastHashJoin") {
     extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+          BroadcastHashJoin(
+            leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
       }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 1d8c293d43c05..1468be4670f26 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -212,7 +212,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
     // Using `sparkPlan` because for relevant patterns in HashJoin to be
     // matched, other strategies need to be applied.
     var bhj = df.queryExecution.sparkPlan.collect {
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
     }
     assert(bhj.size === 1,
       s"actual query plans do not contain broadcast join: ${df.queryExecution}")
@@ -225,7 +225,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
       sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
       df = sql(leftSemiJoinQuery)
       bhj = df.queryExecution.sparkPlan.collect {
-        case j: BroadcastLeftSemiJoinHash => j
+        case j: BroadcastHashJoin => j
      }
      assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
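
Note on the consolidated semi-join path: the interpreted route (HashJoin.hashSemiJoin, reached from both LeftSemiJoinHash and BroadcastHashJoin.doExecute) and the generated route (codegenSemi) implement the same contract -- hash the build (right) side by join key, then emit each streamed (left) row at most once if some build row with an equal, non-null key also passes the optional join condition. Below is a minimal sketch of that contract over plain Scala collections, deliberately independent of Spark's InternalRow/HashedRelation machinery; every name in it is illustrative rather than part of the patch.

// Standalone illustration of hash semi join semantics. Option models a
// nullable join key; all names here are hypothetical, not Spark APIs.
object SemiJoinSketch {
  // A streamed row survives iff its key is non-null, present in the hash
  // table built from the build side, and (if a condition is given) at least
  // one matching build row satisfies it. Each survivor is emitted once.
  def hashSemiJoin[K, L, R](
      streamIter: Iterator[L],
      buildRows: Iterable[R],
      streamKey: L => Option[K],  // None stands in for a null join key
      buildKey: R => Option[K],
      condition: Option[(L, R) => Boolean]): Iterator[L] = {
    // Build side: group rows by key; rows with null keys can never match.
    val hashed: Map[K, Seq[R]] = buildRows
      .flatMap(r => buildKey(r).map(k => k -> r))
      .groupBy(_._1)
      .map { case (k, kvs) => k -> kvs.map(_._2).toSeq }
    streamIter.filter { row =>
      streamKey(row).exists { k =>
        hashed.get(k).exists { matches =>
          // No condition: any key match qualifies the streamed row.
          condition.forall(p => matches.exists(p(row, _)))
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(("a", 1), ("b", 2), ("c", 3))
    val right = Seq((1, 10), (2, 0))
    val result = hashSemiJoin[Int, (String, Int), (Int, Int)](
      left.iterator, right,
      l => Some(l._2),
      r => Some(r._1),
      Some((_, r) => r._2 > 5))
    // Only ("a", 1) survives: key 2 matches a build row but fails r._2 > 5,
    // and key 3 has no match at all.
    println(result.toList)
  }
}

Folding LeftSemi into BroadcastHashJoin, instead of keeping the separate BroadcastLeftSemiJoinHash operator, is what lets semi joins participate in whole-stage code generation; the benchmark added above measures that effect, with the generated path roughly 7x faster than the interpreted one.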