From 21e5321d072c312e243407af08eeb9c1a796ab4d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 12:40:04 -0800 Subject: [PATCH] fix --- .../execution/columnar/InMemoryRelation.scala | 4 ++-- .../sql/execution/joins/BroadcastJoinSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..5945808c4abfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -63,7 +63,7 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -77,7 +77,7 @@ case class InMemoryRelation( // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache statsOfPlanToCache } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = batchStats.value.longValue, hints = statsOfPlanToCache.hints) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0bcd54e1fceab..8b1a9acec1acc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -126,6 +126,22 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint is retained in a cached plan") { + Seq(true, false).foreach { materialized => + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + broadcast(df2).cache() + if (materialized) df2.collect() + val df3 = df1.join(df2, Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } + } + } + private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner")