diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 432eb59d6fe57..d68aeb275afda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -169,14 +169,17 @@ class CacheManager extends Logging { /** Replaces segments of the given logical plan with cached versions where possible. */ def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { + // Do not lookup the cache by hint node. Hint node is special, we should ignore it when + // canonicalizing plans, so that plans which are same except hint can hit the same cache. + // However, we also want to keep the hint info after cache lookup. Here we skip the hint + // node, so that the returned caching plan won't replace the hint node and drop the hint info + // from the original plan. + case hint: ResolvedHint => hint + case currentFragment => - lookupCachedData(currentFragment).map { cached => - val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) - currentFragment match { - case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints) - case _ => cachedPlan - } - }.getOrElse(currentFragment) + lookupCachedData(currentFragment) + .map(_.cachedRepresentation.withOutput(currentFragment.output)) + .getOrElse(currentFragment) } newPlan transformAllExpressions { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..22e16913d4da9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -62,8 +62,8 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -73,11 +73,16 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) override def computeStats(): Statistics = { - if (batchStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache - statsOfPlanToCache + if (sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = sizeInBytesStats.value.longValue) } } @@ -122,7 +127,7 @@ case class InMemoryRelation( rowCount += 1 } - batchStats.add(totalSize) + sizeInBytesStats.add(totalSize) val stats = InternalRow.fromSeq( columnBuilders.flatMap(_.columnStats.collectedStatistics)) @@ -144,7 +149,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, batchStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } override def newInstance(): this.type = { @@ -156,12 +161,12 @@ case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - batchStats, + sizeInBytesStats, statsOfPlanToCache).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache) + Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e52445f28fc1..72fe0f42801f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2280da927cf70..dc1766fb9a785 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1704bc8376f0d..bcdee792f4c70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { private def testBroadcastJoin[T: ClassTag]( joinType: String, forceBroadcast: Boolean = false): SparkPlan = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") @@ -109,30 +110,58 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("broadcast hint is retained after using the cached data") { + test("SPARK-23192: broadcast hint should be retained after using the cached data") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - df2.cache() - val df3 = df1.join(broadcast(df2), Seq("key"), "inner") - val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { - case b: BroadcastHashJoinExec => b - }.size - assert(numBroadCastHashJoin === 1) + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } finally { + spark.catalog.clearCache() + } + } + } + + test("SPARK-23214: cached data should not carry extra hint info") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + broadcast(df2).cache() + + val df3 = df1.join(df2, Seq("key"), "inner") + val numCachedPlan = df3.queryExecution.executedPlan.collect { + case i: InMemoryTableScanExec => i + }.size + // df2 should be cached. + assert(numCachedPlan === 1) + + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + // df2 should not be broadcasted. + assert(numBroadCastHashJoin === 0) + } finally { + spark.catalog.clearCache() + } } } test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key")) - val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value") + val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value") val df5 = df4.join(df3, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1) @@ -140,30 +169,30 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) } test("broadcast hint programming API") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value") val broadcasted = broadcast(df2) - val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value") - - val cases = Seq(broadcasted.limit(2), - broadcasted.filter("value < 10"), - broadcasted.sample(true, 0.5), - broadcasted.distinct(), - broadcasted.groupBy("value").agg(min($"key").as("key")), - // except and intersect are semi/anti-joins which won't return more data then - // their left argument, so the broadcast hint should be propagated here - broadcasted.except(df3), - broadcasted.intersect(df3)) + val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value") + + val cases = Seq( + broadcasted.limit(2), + broadcasted.filter("value < 10"), + broadcasted.sample(true, 0.5), + broadcasted.distinct(), + broadcasted.groupBy("value").agg(min($"key").as("key")), + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + broadcasted.except(df3), + broadcasted.intersect(df3)) cases.foreach(assertBroadcastJoin) } @@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't change broadcast join buildSide if user clearly specified") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes @@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't bias towards build right if user didn't specify") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes