diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 5a1d680c99f66..b05fe49a6ac3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.storage.StorageLevel
@@ -94,14 +94,13 @@ class CacheManager extends Logging {
       logWarning("Asked to cache already cached data.")
     } else {
       val sparkSession = query.sparkSession
-      cachedData.add(CachedData(
-        planToCache,
-        InMemoryRelation(
-          sparkSession.sessionState.conf.useCompression,
-          sparkSession.sessionState.conf.columnBatchSize,
-          storageLevel,
-          sparkSession.sessionState.executePlan(planToCache).executedPlan,
-          tableName)))
+      val inMemoryRelation = InMemoryRelation(
+        sparkSession.sessionState.conf.useCompression,
+        sparkSession.sessionState.conf.columnBatchSize, storageLevel,
+        sparkSession.sessionState.executePlan(planToCache).executedPlan,
+        tableName,
+        planToCache.stats)
+      cachedData.add(CachedData(planToCache, inMemoryRelation))
     }
   }
 
@@ -148,7 +147,8 @@ class CacheManager extends Logging {
           batchSize = cd.cachedRepresentation.batchSize,
           storageLevel = cd.cachedRepresentation.storageLevel,
           child = spark.sessionState.executePlan(cd.plan).executedPlan,
-          tableName = cd.cachedRepresentation.tableName)
+          tableName = cd.cachedRepresentation.tableName,
+          statsOfPlanToCache = cd.plan.stats)
         needToRecache += cd.copy(cachedRepresentation = newCache)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index a1c62a729900e..51928d914841e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -37,8 +37,10 @@ object InMemoryRelation {
       batchSize: Int,
       storageLevel: StorageLevel,
       child: SparkPlan,
-      tableName: Option[String]): InMemoryRelation =
-    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
+      tableName: Option[String],
+      statsOfPlanToCache: Statistics): InMemoryRelation =
+    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
+      statsOfPlanToCache = statsOfPlanToCache)
 }
 
 
@@ -60,7 +62,8 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
+    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+    statsOfPlanToCache: Statistics = null)
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -71,9 +74,8 @@ case class InMemoryRelation(
   override def computeStats(): Statistics = {
     if (batchStats.value == 0L) {
-      // Underlying columnar RDD hasn't been materialized, no useful statistics information
-      // available, return the default statistics.
-      Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
+      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
+      statsOfPlanToCache
     } else {
       Statistics(sizeInBytes = batchStats.value.longValue)
     }
   }
@@ -142,7 +144,7 @@ case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-        _cachedColumnBuffers, batchStats)
+        _cachedColumnBuffers, batchStats, statsOfPlanToCache)
   }
 
   override def newInstance(): this.type = {
@@ -154,11 +156,12 @@ case class InMemoryRelation(
       child,
       tableName)(
       _cachedColumnBuffers,
-      batchStats).asInstanceOf[this.type]
+      batchStats,
+      statsOfPlanToCache).asInstanceOf[this.type]
   }
 
   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
 
   override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, batchStats)
+    Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index e662e294228db..ff7c5e58e9863 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel._
+import org.apache.spark.util.Utils
 
 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -40,7 +41,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     data.createOrReplaceTempView(s"testData$dataType")
     val storageLevel = MEMORY_ONLY
     val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
-    val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None)
+    val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None,
+      data.logicalPlan.stats)
 
     assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
     inMemoryRelation.cachedColumnBuffers.collect().head match {
@@ -116,7 +118,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
 
   test("simple columnar query") {
     val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
-    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
+    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
+      testData.logicalPlan.stats)
 
     checkAnswer(scan, testData.collect().toSeq)
   }
@@ -132,8 +135,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("projection") {
-    val plan = spark.sessionState.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan
-    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
+    val logicalPlan = testData.select('value, 'key).logicalPlan
+    val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan
+    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
+      logicalPlan.stats)
 
     checkAnswer(scan, testData.collect().map {
       case Row(key: Int, value: String) => value -> key
@@ -149,7 +154,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
 
   test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
     val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
-    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
+    val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
+      testData.logicalPlan.stats)
 
     checkAnswer(scan, testData.collect().toSeq)
     checkAnswer(scan, testData.collect().toSeq)
@@ -323,7 +329,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   test("SPARK-17549: cached table size should be correctly calculated") {
     val data = spark.sparkContext.parallelize(1 to 10, 5).toDF()
     val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
-    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None)
+    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats)
 
     // Materialize the data.
     val expectedAnswer = data.collect()
@@ -448,8 +454,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
 
   test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") {
     val attribute = AttributeReference("a", IntegerType)()
-    val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY,
-      LocalTableScanExec(Seq(attribute), Nil), None)
+    val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil)
+    val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null)
     val tableScanExec = InMemoryTableScanExec(Seq(attribute),
       Seq(In(attribute, Nil)), testRelation)
     assert(tableScanExec.partitionFilters.isEmpty)
@@ -479,4 +485,43 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") {
+    withSQLConf("spark.sql.cbo.enabled" -> "true") {
+      withTempPath { workDir =>
+        withTable("table1") {
+          val workDirPath = workDir.getAbsolutePath
+          val data = Seq(100, 200, 300, 400).toDF("count")
+          data.write.parquet(workDirPath)
+          val dfFromFile = spark.read.parquet(workDirPath).cache()
+          val inMemoryRelation = dfFromFile.queryExecution.optimizedPlan.collect {
+            case plan: InMemoryRelation => plan
+          }.head
+          // InMemoryRelation's stats is the file size before the underlying RDD is materialized
+          assert(inMemoryRelation.computeStats().sizeInBytes === 740)
+
+          // InMemoryRelation's stats is updated after materializing the RDD
+          dfFromFile.collect()
+          assert(inMemoryRelation.computeStats().sizeInBytes === 16)
+
+          // test with a catalog table
+          val dfFromTable = spark.catalog.createTable("table1", workDirPath).cache()
+          val inMemoryRelation2 = dfFromTable.queryExecution.optimizedPlan.
+            collect { case plan: InMemoryRelation => plan }.head
+
+          // Even with CBO enabled, InMemoryRelation's stats stays at the file size until the
+          // table's stats are calculated
+          assert(inMemoryRelation2.computeStats().sizeInBytes === 740)
+
+          // InMemoryRelation's stats should be updated after calculating the table's stats;
+          // clear the cache to simulate a fresh environment
+          dfFromTable.unpersist(blocking = true)
+          spark.sql("ANALYZE TABLE table1 COMPUTE STATISTICS")
+          val inMemoryRelation3 = spark.read.table("table1").cache().queryExecution.optimizedPlan.
+            collect { case plan: InMemoryRelation => plan }.head
+          assert(inMemoryRelation3.computeStats().sizeInBytes === 48)
+        }
+      }
+    }
+  }
 }
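
Note for reviewers: the behavioral core of this patch is the fallback in computeStats(). Before the cached RDD is materialized, InMemoryRelation now reports the optimizer's estimate for the plan being cached (file size, or ANALYZE-derived stats under CBO) instead of the blanket child.sqlContext.conf.defaultSizeInBytes. The standalone Scala sketch below illustrates only that pattern in isolation; it is not Spark code, and the names SizeStats, CachedRelationStats, and recordMaterializedBytes are invented for illustration.

object StatsFallbackDemo {

  // Simplified stand-in for Catalyst's Statistics(sizeInBytes = ...).
  final case class SizeStats(sizeInBytes: BigInt)

  // Hypothetical stand-in for InMemoryRelation: it carries the stats of the
  // plan being cached plus an accumulator-like byte counter, the two inputs
  // the patched computeStats() consults.
  final class CachedRelationStats(statsOfPlanToCache: SizeStats) {
    private var accumulatedBytes: Long = 0L // plays the role of batchStats

    def recordMaterializedBytes(n: Long): Unit = accumulatedBytes += n

    // Mirrors the patched computeStats(): fall back to the pre-cache estimate
    // until something has actually been materialized.
    def computeStats(): SizeStats =
      if (accumulatedBytes == 0L) statsOfPlanToCache
      else SizeStats(BigInt(accumulatedBytes))
  }

  def main(args: Array[String]): Unit = {
    // 740 (estimated file size) and 16 (materialized size) echo the values
    // asserted in the SPARK-22673 test above.
    val stats = new CachedRelationStats(SizeStats(sizeInBytes = BigInt(740)))
    println(s"before materialization: ${stats.computeStats().sizeInBytes}") // 740
    stats.recordMaterializedBytes(16L)
    println(s"after materialization:  ${stats.computeStats().sizeInBytes}") // 16
  }
}

The 740-then-16 transition mirrors the test assertions above. A faithful pre-materialization estimate matters because consumers such as join-strategy selection read these stats before the cache is ever built, where a one-size-fits-all default can badly over- or under-estimate the relation.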