Skip to content

Commit

Permalink
add testsuite
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiaojing committed Dec 23, 2014
1 parent 1a8da2a commit ff2e618
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
if sqlContext.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
val semiJoin = joins.BroadcastLeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = joins.LeftSemiJoinHash(
Expand Down
36 changes: 36 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
case j: LeftSemiJoinBNL => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: BroadcastLeftSemiJoinHash => j
}

assert(operators.size === 1)
Expand Down Expand Up @@ -382,4 +383,39 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
""".stripMargin),
(null, 10) :: Nil)
}
test("broadcasted left semi join operator selection") {
  clearCache()
  sql("CACHE TABLE testData")
  // Remember the current threshold so it can be put back once the test is done.
  val tmp = autoBroadcastJoinThreshold
  val query = "SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a"

  try {
    // With a very large threshold the broadcast variant of the semi join is expected.
    sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000""")
    assertJoin(query, classOf[BroadcastLeftSemiJoinHash])

    // With the threshold set to -1 the non-broadcast hash semi join is expected.
    sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")
    assertJoin(query, classOf[LeftSemiJoinHash])
  } finally {
    // Restore the saved value itself (not its negation), even if an assertion above failed,
    // so that the setting does not leak into other tests.
    sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""")
    sql("UNCACHE TABLE testData")
  }
}

test("left semi join") {
  // Rows expected back from the semi join of testData2 against testData on key = a.
  val expectedRows = List(
    (1, 1),
    (1, 2),
    (2, 1),
    (2, 2),
    (3, 1),
    (3, 2))

  val semiJoined = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
  checkAnswer(semiJoined, expectedRows)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag

import org.apache.spark.sql.{SQLConf, QueryTest}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.catalyst.plans.logical.NativeCommand
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.execution._
Expand Down Expand Up @@ -193,4 +194,70 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
)
}

test("auto converts to broadcast left semi join, by size estimate of a relation") {
  /**
   * Runs `query` and checks that it is planned as a [[BroadcastLeftSemiJoinHash]] under the
   * current threshold, and as a [[LeftSemiJoinHash]] once the threshold is set to -1.
   *
   * @param before hook invoked before the query is run
   * @param after hook invoked once all checks are done, also when a check fails
   * @param query SQL text of the left semi join under test
   * @param expectedAnswer rows the query is expected to return
   * @param ct class tag whose runtime class selects the analyzed-plan nodes whose
   *           size statistics are compared against the threshold
   */
  def mkTest(
      before: () => Unit,
      after: () => Unit,
      query: String,
      expectedAnswer: Seq[Any],
      ct: ClassTag[_]): Unit = {
    before()

    try {
      val rdd = sql(query)

      // Assert src has a size smaller than the threshold.
      val sizes = rdd.queryExecution.analyzed.collect {
        case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
      }
      assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold
        && sizes(0) <= autoBroadcastJoinThreshold,
        s"query should contain two relations, each of which has size smaller than autoConvertSize")

      // Using `sparkPlan` because for relevant patterns in HashJoin to be
      // matched, other strategies need to be applied.
      val broadcastJoins = rdd.queryExecution.sparkPlan.collect {
        case j: BroadcastLeftSemiJoinHash => j
      }
      assert(broadcastJoins.size === 1,
        s"actual query plans do not contain broadcast join: ${rdd.queryExecution}")

      checkAnswer(rdd, expectedAnswer) // check correctness of output

      TestHive.settings.synchronized {
        val tmp = autoBroadcastJoinThreshold

        try {
          sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")
          val replanned = sql(query)

          val stillBroadcast = replanned.queryExecution.sparkPlan.collect {
            case j: BroadcastLeftSemiJoinHash => j
          }
          assert(stillBroadcast.isEmpty,
            "BroadcastLeftSemiJoinHash still planned even though it is switched off")

          val hashSemiJoins = replanned.queryExecution.sparkPlan.collect {
            case j: LeftSemiJoinHash => j
          }
          assert(hashSemiJoins.size === 1,
            "LeftSemiJoinHash should be planned when broadcast join is turned off")
        } finally {
          // Put the saved threshold back even when an assertion above fails, so the
          // disabled setting does not leak out of this test.
          sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""")
        }
      }
    } finally {
      after()
    }
  }

  /** Tests for MetastoreRelation */
  val leftSemiJoinQuery =
    """SELECT * FROM src a
      |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
  val answer = (86, "val_86") :: Nil

  mkTest(
    () => (),
    () => (),
    leftSemiJoinQuery,
    answer,
    implicitly[ClassTag[MetastoreRelation]]
  )

}
}

0 comments on commit ff2e618

Please sign in to comment.