diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 5a1d680c99f66..d1d9861db5c65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -94,14 +94,16 @@ class CacheManager extends Logging {
       logWarning("Asked to cache already cached data.")
     } else {
       val sparkSession = query.sparkSession
-      cachedData.add(CachedData(
-        planToCache,
-        InMemoryRelation(
-          sparkSession.sessionState.conf.useCompression,
-          sparkSession.sessionState.conf.columnBatchSize,
-          storageLevel,
-          sparkSession.sessionState.executePlan(planToCache).executedPlan,
-          tableName)))
+      val inMemoryRelation = InMemoryRelation(
+        sparkSession.sessionState.conf.useCompression,
+        sparkSession.sessionState.conf.columnBatchSize,
+        storageLevel,
+        sparkSession.sessionState.executePlan(planToCache).executedPlan,
+        tableName)
+      if (planToCache.conf.cboEnabled && planToCache.stats.rowCount.isDefined) {
+        inMemoryRelation.setStatsFromCachedPlan(planToCache)
+      }
+      cachedData.add(CachedData(planToCache, inMemoryRelation))
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 761d00cb82647..25cda8027824d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -25,15 +25,15 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
 
 
 object InMemoryRelation {
+
   def apply(
       useCompression: Boolean,
       batchSize: Int,
@@ -73,22 +73,20 @@ case class InMemoryRelation(
 
   override def computeStats(): Statistics = {
     if (batchStats.value == 0L) {
-      children.filter(_.isInstanceOf[LogicalRelation]) match {
-        case Seq(c @ LogicalRelation(_, _, _, _), _) if c.conf.cboEnabled =>
-          val stats = c.computeStats()
-          if (stats.rowCount.isDefined) {
-            stats
-          } else {
-            Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
-          }
-        case _ =>
-          Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
-      }
+      inheritedStats.getOrElse(Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes))
     } else {
       Statistics(sizeInBytes = batchStats.value.longValue)
     }
   }
 
+  private var inheritedStats: Option[Statistics] = None
+
+  private[execution] def setStatsFromCachedPlan(planToCache: LogicalPlan): Unit = {
+    require(planToCache.conf.cboEnabled, "the stats of a cached plan can only be inherited " +
+      "by InMemoryRelation when CBO is enabled")
+    inheritedStats = Some(planToCache.stats)
+  }
+
   // If the cached column buffers were not passed in, we calculate them in the constructor.
   // As in Spark, the actual work of caching is lazy.
   if (_cachedColumnBuffers == null) {
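
Not part of the patch — a minimal sketch of how the new behavior could be observed, assuming a local session. The table name `t` and the object/app names are hypothetical; only `spark.sql.cbo.enabled`, `ANALYZE TABLE ... COMPUTE STATISTICS`, and `queryExecution.optimizedPlan.stats` come from Spark itself:

```scala
import org.apache.spark.sql.SparkSession

object CachedPlanStatsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("cached-plan-stats")          // hypothetical app name
      .config("spark.sql.cbo.enabled", "true")
      .getOrCreate()

    spark.range(1000).write.saveAsTable("t") // hypothetical table name
    // Records rowCount in the catalog stats, so planToCache.stats.rowCount is defined.
    spark.sql("ANALYZE TABLE t COMPUTE STATISTICS")

    // Caching is lazy: no column buffers are built yet, so batchStats.value == 0L.
    spark.table("t").cache()

    // With this patch, the InMemoryRelation reports the analyzed plan's stats
    // before materialization, instead of falling back to spark.sql.defaultSizeInBytes.
    println(spark.table("t").queryExecution.optimizedPlan.stats)

    spark.stop()
  }
}
```

The point of the before/after difference: previously, `computeStats()` only consulted a child `LogicalRelation` at read time, so most cached-but-unmaterialized plans reported the default size estimate; now the stats are captured once, at cache time, whenever CBO has produced a row count.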