From ff2e618780f4c06242e505bf7012fe518619e4ae Mon Sep 17 00:00:00 2001 From: wangxiaojing Date: Mon, 24 Nov 2014 20:09:52 -0800 Subject: [PATCH] add testsuite --- .../spark/sql/execution/SparkStrategies.scala | 6 ++ .../org/apache/spark/sql/JoinSuite.scala | 36 ++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 69 ++++++++++++++++++- 3 files changed, 110 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2954d4ce7d2d8..0fd9206a435ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + val semiJoin = joins.BroadcastLeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => val semiJoin = joins.LeftSemiJoinHash( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0378fd7e367f0..89f397150d92b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j + case j: BroadcastLeftSemiJoinHash => j } assert(operators.size === 1) @@ -382,4 +383,39 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { """.stripMargin), (null, 10) :: Nil) } + test("broadcasted left semi join operator selection") { + clearCache() + sql("CACHE TABLE testData") + val tmp = autoBroadcastJoinThreshold + + sql( s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000""") + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + sql( s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + sql( s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-$tmp""") + sql("UNCACHE TABLE testData") + } + + test("left semi join") { + val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + checkAnswer(rdd, + (1, 1) :: + (1, 2) :: + (2, 1) :: + (2, 2) :: + (3, 1) :: + (3, 2) :: Nil) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ff4071d8e2f10..b337dd1fdd850 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,7 +22,8 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.catalyst.plans.logical.NativeCommand +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ @@ -193,4 +194,70 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ) } + test("auto converts to broadcast left semi join, by size estimate of a relation") { + def mkTest( + before: () => Unit, + after: () => Unit, + query: String, + expectedAnswer: Seq[Any], + ct: ClassTag[_]) = { + before() + + var rdd = sql(query) + + // Assert src has a size smaller than the threshold. + val sizes = rdd.queryExecution.analyzed.collect { + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes + } + assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold + && sizes(0) <= autoBroadcastJoinThreshold, + s"query should contain two relations, each of which has size smaller than autoConvertSize") + + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. + var bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.size === 1, + s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + + checkAnswer(rdd, expectedAnswer) // check correctness of output + + TestHive.settings.synchronized { + val tmp = autoBroadcastJoinThreshold + + sql( s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + rdd = sql(query) + bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") + + val shj = rdd.queryExecution.sparkPlan.collect { + case j: LeftSemiJoinHash => j + } + assert(shj.size === 1, + "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") + + sql( s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + } + + after() + } + + /** Tests for MetastoreRelation */ + val leftSemiJoinQuery = + """SELECT * FROM src a + |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin + val Answer =(86, "val_86") ::Nil + + mkTest( + () => (), + () => (), + leftSemiJoinQuery, + Answer, + implicitly[ClassTag[MetastoreRelation]] + ) + + } }