From 7f82e38f8edca611cdf55266ff07ac4523799603 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 22 Jul 2017 17:17:26 -0700
Subject: [PATCH 1/6] fix

---
 .../scala/org/apache/spark/sql/Dataset.scala  | 12 ++++++++
 .../apache/spark/sql/DatasetCacheSuite.scala  | 29 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 71ab0ddf2d6f4..554200b8992d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2704,6 +2704,18 @@ class Dataset[T] private[sql](
     this
   }
 
+  /**
+   * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
+   * @param eager If true, persist the Dataset eagerly.
+   * @group basic
+   * @since 2.3.0
+   */
+  def persist(eager: Boolean): this.type = {
+    persist()
+    if (eager) queryExecution.toRdd.count()
+    this
+  }
+
   /**
    * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index e0561ee2797a5..85cbcd6bdf078 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -17,14 +17,41 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 
 
 class DatasetCacheSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
+  def isMaterialized(rddId: Int): Boolean = {
+    val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0))
+    maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0)))
+    maybeBlock.nonEmpty
+  }
+
+  def rddIdOf(ds: DataFrame): Int = {
+    val plan = ds.queryExecution.sparkPlan
+    plan.collect {
+      case InMemoryTableScanExec(_, _, relation) =>
+        relation.cachedColumnBuffers.id
+      case _ =>
+        fail("Dataset is not cached\n" + plan)
+    }.head
+  }
+
+  test("eager persist") {
+    val ds = Seq("1", "2").toDF()
+    ds.persist(eager = false)
+    val rddId = rddIdOf(ds)
+    assert(!isMaterialized(rddId))
+    ds.persist(eager = true)
+    ds.collect()
+    assert(isMaterialized(rddId))
+  }
+
   test("get storage level") {
     val ds1 = Seq("1", "2").toDS().as("a")
     val ds2 = Seq(2, 3).toDS().as("b")

From 80e7123c3905c44776233e0a56a8d8315f1313a0 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sat, 22 Jul 2017 21:17:56 -0700
Subject: [PATCH 2/6] fix

---
 .../scala/org/apache/spark/sql/Dataset.scala  | 31 +++++++-
 .../apache/spark/sql/catalog/Catalog.scala    |  2 +
 .../apache/spark/sql/CachedTableSuite.scala   | 72 +++++++------------
 .../apache/spark/sql/DatasetCacheSuite.scala  | 24 +------
 .../org/apache/spark/sql/QueryTest.scala      | 27 ++++---
 .../spark/sql/hive/CachedTableSuite.scala     | 45 ++----------
 6 files changed, 84 insertions(+), 117 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 554200b8992d5..b725cb38d995a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
@@ -54,7 +55,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
 
@@ -2750,6 +2751,34 @@ class Dataset[T] private[sql](
     }.getOrElse(StorageLevel.NONE)
   }
 
+  /**
+   * Returns true when the Dataset is cached and materialized.
+   *
+   * @group basic
+   * @since 2.3.0
+   */
+  def isMaterialized(): Boolean = {
+    queryExecution.sparkPlan match {
+      case i: InMemoryTableScanExec =>
+        val blockManager = sparkSession.sparkContext.env.blockManager
+
+        val rdd = i.relation.cachedColumnBuffers
+        val blockIDs = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index))
+
+        var foundNonexistentBlocks = false
+        blockIDs.foreach { bid =>
+          if (blockManager.get(bid).isEmpty) {
+            foundNonexistentBlocks = true
+          } else {
+            blockManager.releaseLock(bid)
+          }
+          if (foundNonexistentBlocks) return false
+        }
+        true
+      case _ => false
+    }
+  }
+
   /**
    * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index ab81725def3f4..23c74db27d134 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -462,6 +462,8 @@ abstract class Catalog {
    * @param tableName is either a qualified or unqualified name that designates a table/view.
    *                  If no database identifier is provided, it refers to a temporary view or
    *                  a table/view in the current database.
+   * @return true if the table is cached. Even if it is cached, the table might not be
+   *         materialized until it is used for the first time.
    * @since 2.0.0
    */
   def isCached(tableName: String): Boolean
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 3e4f619431599..571a599208513 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -21,8 +21,6 @@ import scala.collection.mutable.HashSet
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
-import org.scalatest.concurrent.Eventually._
-
 import org.apache.spark.CleanerListener
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
@@ -31,7 +29,7 @@ import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
-import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{AccumulatorContext, Utils}
 
 private case class BigData(s: String)
@@ -49,22 +47,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     }
   }
 
-  def rddIdOf(tableName: String): Int = {
-    val plan = spark.table(tableName).queryExecution.sparkPlan
-    plan.collect {
-      case InMemoryTableScanExec(_, _, relation) =>
-        relation.cachedColumnBuffers.id
-      case _ =>
-        fail(s"Table $tableName is not cached\n" + plan)
-    }.head
-  }
-
-  def isMaterialized(rddId: Int): Boolean = {
-    val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0))
-    maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0)))
-    maybeBlock.nonEmpty
-  }
-
   private def getNumInMemoryRelations(ds: Dataset[_]): Int = {
     val plan = ds.queryExecution.withCachedData
     var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
@@ -240,17 +222,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
     sql("CACHE TABLE testData")
     assertCached(spark.table("testData"))
-
-    val rddId = rddIdOf("testData")
-    assert(
-      isMaterialized(rddId),
-      "Eagerly cached in-memory table should have already been materialized")
+    assertMaterialized(spark.table("testData"))
 
     sql("UNCACHE TABLE testData")
     assert(!spark.catalog.isCached("testData"), "Table 'testData' should not be cached")
     eventually(timeout(10 seconds)) {
-      assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+      assertNotMaterialized(spark.table("testData"))
+    }
+  }
+
+  test("'CACHE PARTITIONED TABLE' and 'UNCACHE PARTITIONED TABLE' SQL statement") {
+    withTempView("t1") {
+      testData.repartition(6, $"value").createOrReplaceTempView("t1")
+      sql("CACHE TABLE t1")
+      assertCached(spark.table("t1"))
+      assertMaterialized(spark.table("t1"))
+
+      sql("UNCACHE TABLE t1")
+      assert(!spark.catalog.isCached("t1"), "Table 't1' should not be cached")
+
+      eventually(timeout(10 seconds)) {
+        assertNotMaterialized(spark.table("t1"))
+      }
     }
   }
 
     withTempView("testCacheTable") {
       sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
       assertCached(spark.table("testCacheTable"))
-
-      val rddId = rddIdOf("testCacheTable")
-      assert(
-        isMaterialized(rddId),
-        "Eagerly cached in-memory table should have already been materialized")
+      assertMaterialized(spark.table("testCacheTable"))
 
       spark.catalog.uncacheTable("testCacheTable")
       eventually(timeout(10 seconds)) {
-        assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+        assertNotMaterialized(spark.table("testCacheTable"))
       }
     }
   }
 
     withTempView("testCacheTable") {
       sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10")
       assertCached(spark.table("testCacheTable"))
-
-      val rddId = rddIdOf("testCacheTable")
-      assert(
-        isMaterialized(rddId),
-        "Eagerly cached in-memory table should have already been materialized")
+      assertMaterialized(spark.table("testCacheTable"))
 
       spark.catalog.uncacheTable("testCacheTable")
       eventually(timeout(10 seconds)) {
-        assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+        assertNotMaterialized(spark.table("testCacheTable"))
      }
    }
  }
 
   test("CACHE LAZY TABLE tableName") {
     sql("CACHE LAZY TABLE testData")
     assertCached(spark.table("testData"))
-
-    val rddId = rddIdOf("testData")
-    assert(
-      !isMaterialized(rddId),
-      "Lazily cached in-memory table shouldn't be materialized eagerly")
+    assertNotMaterialized(spark.table("testData"))
 
     sql("SELECT COUNT(*) FROM testData").collect()
-    assert(
-      isMaterialized(rddId),
-      "Lazily cached in-memory table should have been materialized")
+    assertMaterialized(spark.table("testData"))
 
     spark.catalog.uncacheTable("testData")
     eventually(timeout(10 seconds)) {
-      assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+      assertNotMaterialized(spark.table("testData"))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 85cbcd6bdf078..87748e721210b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -17,39 +17,21 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.storage.StorageLevel
 
 
 class DatasetCacheSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
-  def isMaterialized(rddId: Int): Boolean = {
-    val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0))
-    maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0)))
-    maybeBlock.nonEmpty
-  }
-
-  def rddIdOf(ds: DataFrame): Int = {
-    val plan = ds.queryExecution.sparkPlan
-    plan.collect {
-      case InMemoryTableScanExec(_, _, relation) =>
-        relation.cachedColumnBuffers.id
-      case _ =>
-        fail("Dataset is not cached\n" + plan)
-    }.head
-  }
-
   test("eager persist") {
     val ds = Seq("1", "2").toDF()
     ds.persist(eager = false)
-    val rddId = rddIdOf(ds)
-    assert(!isMaterialized(rddId))
+    assert(!ds.isMaterialized())
     ds.persist(eager = true)
     ds.collect()
-    assert(isMaterialized(rddId))
+    assert(ds.isMaterialized())
   }
 
   test("get storage level") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f9808834df4a5..1e18878adea14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,23 +17,13 @@
 
 package org.apache.spark.sql
 
-import java.util.{ArrayDeque, Locale, TimeZone}
+import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
 
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.streaming.MemoryPlan
-import org.apache.spark.sql.types.{Metadata, ObjectType}
 
 abstract class QueryTest extends PlanTest {
 
@@ -225,6 +215,21 @@ abstract class QueryTest extends PlanTest {
       planWithCaching)
   }
 
+  /**
+   * Asserts that a given [[Dataset]] has been materialized.
+   */
+  def assertMaterialized(query: Dataset[_]): Unit = {
+    assert(query.isMaterialized(),
+      "Eagerly cached cached Dataset should have already been materialized")
+  }
+
+  /**
+   * Asserts that a given [[Dataset]] has not been materialized.
+   */
+  def assertNotMaterialized(query: Dataset[_]): Unit = {
+    assert(!query.isMaterialized(), "Dataset should not be materialized")
+  }
+
   /**
    * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
    */
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index d3cbf898e2439..07a6d54e167a2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -28,28 +28,11 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.storage.RDDBlockId
 import org.apache.spark.util.Utils
 
 class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import hiveContext._
 
-  def rddIdOf(tableName: String): Int = {
-    val plan = table(tableName).queryExecution.sparkPlan
-    plan.collect {
-      case InMemoryTableScanExec(_, _, relation) =>
-        relation.cachedColumnBuffers.id
-      case _ =>
-        fail(s"Table $tableName is not cached\n" + plan)
-    }.head
-  }
-
-  def isMaterialized(rddId: Int): Boolean = {
-    val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0))
-    maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0)))
-    maybeBlock.nonEmpty
-  }
-
   test("cache table") {
     val preCacheResults = sql("SELECT * FROM src").collect().toSeq
 
@@ -138,14 +121,10 @@
     withTempView("testCacheTable") {
       sql("CACHE TABLE testCacheTable AS SELECT * FROM src")
       assertCached(table("testCacheTable"))
-
-      val rddId = rddIdOf("testCacheTable")
-      assert(
-        isMaterialized(rddId),
-        "Eagerly cached in-memory table should have already been materialized")
+      assertMaterialized(table("testCacheTable"))
 
       uncacheTable("testCacheTable")
-      assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+      assertNotMaterialized(table("testCacheTable"))
     }
   }
 
@@ -153,33 +132,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
     withTempView("testCacheTable") {
       sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10")
       assertCached(table("testCacheTable"))
-
-      val rddId = rddIdOf("testCacheTable")
-      assert(
-        isMaterialized(rddId),
-        "Eagerly cached in-memory table should have already been materialized")
+      assertMaterialized(table("testCacheTable"))
 
       uncacheTable("testCacheTable")
-      assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+      assertNotMaterialized(table("testCacheTable"))
     }
   }
 
   test("CACHE LAZY TABLE tableName") {
     sql("CACHE LAZY TABLE src")
     assertCached(table("src"))
-
-    val rddId = rddIdOf("src")
-    assert(
-      !isMaterialized(rddId),
-      "Lazily cached in-memory table shouldn't be materialized eagerly")
+    assertNotMaterialized(table("src"))
 
     sql("SELECT COUNT(*) FROM src").collect()
-    assert(
-      isMaterialized(rddId),
-      "Lazily cached in-memory table should have been materialized")
+    assertMaterialized(table("src"))
 
     uncacheTable("src")
-    assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+    assertNotMaterialized(table("src"))
   }
 
   test("CACHE TABLE with Hive UDF") {

From e63efabe5ac48133aab877914adaaadd1fc82ed7 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 23 Jul 2017 09:28:35 -0700
Subject: [PATCH 3/6] clean

---
 .../main/scala/org/apache/spark/sql/Dataset.scala | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b725cb38d995a..a351fd74ff572 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2713,7 +2713,7 @@ class Dataset[T] private[sql](
    */
   def persist(eager: Boolean): this.type = {
     persist()
-    if (eager) queryExecution.toRdd.count()
+    if (eager) queryExecution.toRdd.foreachPartition(_ => {})
     this
   }
 
@@ -2761,18 +2761,10 @@ class Dataset[T] private[sql](
     queryExecution.sparkPlan match {
       case i: InMemoryTableScanExec =>
         val blockManager = sparkSession.sparkContext.env.blockManager
-
         val rdd = i.relation.cachedColumnBuffers
         val blockIDs = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index))
-
-        var foundNonexistentBlocks = false
         blockIDs.foreach { bid =>
-          if (blockManager.get(bid).isEmpty) {
-            foundNonexistentBlocks = true
-          } else {
-            blockManager.releaseLock(bid)
-          }
-          if (foundNonexistentBlocks) return false
+          if (blockManager.get(bid).nonEmpty) blockManager.releaseLock(bid) else return false
         }
         true
       case _ => false

From 9077258af86d7403e28723d9d6d457bd7bf1cc76 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 23 Jul 2017 12:02:13 -0700
Subject: [PATCH 4/6] address comments.
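
CACHE TABLE now triggers eager caching through the new
persist(eager = true) API instead of running count() over the table.
A minimal sketch of the intended behavior at this point in the series
(the DataFrame name below is hypothetical, for illustration only):

    val df = spark.range(100).toDF()

    // Lazy: only registers the plan with the cache manager; no job runs.
    df.persist(eager = false)
    assert(!df.isMaterialized())

    // Eager: additionally runs a no-op job over every partition, which
    // fills the cached column buffers without collecting rows or counts
    // to the driver.
    df.persist(eager = true)
    assert(df.isMaterialized())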
---
 .../scala/org/apache/spark/sql/execution/command/cache.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 47952f2f227a3..b14915aa9d70a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -41,7 +41,7 @@ case class CacheTableCommand(
 
     if (!isLazy) {
       // Performs eager caching
-      sparkSession.table(tableIdent).count()
+      sparkSession.table(tableIdent).persist(eager = true)
     }
 
     Seq.empty[Row]

From b0f0e80162212431d2fa34866cf9b1e932243f0d Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 20 Aug 2017 00:09:28 -0700
Subject: [PATCH 5/6] fix.

---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala    | 5 +----
 .../test/scala/org/apache/spark/sql/DatasetCacheSuite.scala   | 1 -
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 637d0fa501cac..34f4875487d99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2796,10 +2796,7 @@ class Dataset[T] private[sql](
         val blockManager = sparkSession.sparkContext.env.blockManager
         val rdd = i.relation.cachedColumnBuffers
         val blockIDs = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index))
-        blockIDs.foreach { bid =>
-          if (blockManager.get(bid).nonEmpty) blockManager.releaseLock(bid) else return false
-        }
-        true
+        blockIDs.forall { bid => blockManager.getStatus(bid).exists(_.isCached) }
       case _ => false
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 6a5ecd1be619c..4e0b05a93771f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -39,7 +39,6 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     ds.persist(eager = false)
     assert(!ds.isMaterialized())
     ds.persist(eager = true)
-    ds.collect()
     assert(ds.isMaterialized())
   }
 

From 31bf79742a219ca13cca2a6774da783b061864d1 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Sun, 20 Aug 2017 01:13:41 -0700
Subject: [PATCH 6/6] fix.

---
 .../scala/org/apache/spark/sql/Dataset.scala  | 44 +++++++++++++------
 .../apache/spark/sql/DatasetCacheSuite.scala  | 15 +++++++
 .../org/apache/spark/sql/QueryTest.scala      |  2 +-
 3 files changed, 47 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 34f4875487d99..28066bb7b39d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2740,35 +2740,51 @@ class Dataset[T] private[sql](
 
   /**
    * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
-   * @param eager If true, persist the Dataset eagerly.
+   *
    * @group basic
-   * @since 2.3.0
+   * @since 1.6.0
    */
-  def persist(eager: Boolean): this.type = {
-    persist()
-    if (eager) queryExecution.toRdd.foreachPartition(_ => {})
+  def cache(): this.type = persist()
+
+  /**
+   * Persist this Dataset with the given storage level.
+   * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+   *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
+   *                 `MEMORY_AND_DISK_2`, etc.
+   *
+   * @group basic
+   * @since 1.6.0
+   */
+  def persist(newLevel: StorageLevel): this.type = {
+    sparkSession.sharedState.cacheManager.cacheQuery(this, None, newLevel)
     this
   }
 
   /**
    * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
-   *
+   * @param eager If true, persist the Dataset eagerly.
    * @group basic
-   * @since 1.6.0
+   * @since 2.3.0
    */
-  def cache(): this.type = persist()
+  def persist(eager: Boolean): this.type = {
+    persist()
+    if (eager) queryExecution.toRdd.foreachPartition(_ => {})
+    this
+  }
 
   /**
    * Persist this Dataset with the given storage level.
+   * @param eager If true, persist the Dataset eagerly.
    * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
    *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
    *                 `MEMORY_AND_DISK_2`, etc.
    *
    * @group basic
-   * @since 1.6.0
+   * @since 2.3.0
    */
-  def persist(newLevel: StorageLevel): this.type = {
-    sparkSession.sharedState.cacheManager.cacheQuery(this, None, newLevel)
+  def persist(eager: Boolean, newLevel: StorageLevel): this.type = {
+    persist(newLevel)
+    if (eager) queryExecution.toRdd.foreachPartition(_ => {})
     this
   }
@@ -2795,8 +2811,10 @@ class Dataset[T] private[sql](
       case i: InMemoryTableScanExec =>
         val blockManager = sparkSession.sparkContext.env.blockManager
         val rdd = i.relation.cachedColumnBuffers
-        val blockIDs = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index))
-        blockIDs.forall { bid => blockManager.getStatus(bid).exists(_.isCached) }
+        sparkSession.sparkContext.persistentRdds.contains(rdd.id) &&
+          rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).forall {
+            bid => blockManager.getStatus(bid).exists(_.isCached)
+          }
       case _ => false
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 4e0b05a93771f..edeaf38b6aabd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -40,6 +40,21 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     ds.persist(eager = false)
     assert(!ds.isMaterialized())
     ds.persist(eager = true)
     assert(ds.isMaterialized())
+    ds.unpersist()
+    assert(!ds.isMaterialized())
+  }
+
+  test("eager persist with storagelevel") {
+    val ds = Seq("1", "2").toDF()
+    ds.persist(eager = false, StorageLevel.MEMORY_ONLY_2)
+    assert(!ds.isMaterialized())
+    assert(ds.storageLevel == StorageLevel.MEMORY_ONLY_2)
+    ds.persist(eager = true, StorageLevel.MEMORY_ONLY_2)
+    assert(ds.isMaterialized())
+    assert(ds.storageLevel == StorageLevel.MEMORY_ONLY_2)
+    ds.unpersist()
+    assert(ds.storageLevel == StorageLevel.NONE)
+    assert(!ds.isMaterialized())
   }
 
   test("get storage level") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 1e18878adea14..a83d4c3cd14dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -220,7 +220,7 @@ abstract class QueryTest extends PlanTest {
    */
   def assertMaterialized(query: Dataset[_]): Unit = {
     assert(query.isMaterialized(),
-      "Eagerly cached cached Dataset should have already been materialized")
+      "Eagerly cached Dataset should have already been materialized")
   }
 
   /**
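
Usage sketch of the API as it stands after PATCH 6, mirroring the
"eager persist with storagelevel" test above; the surrounding
SparkSession/testImplicits setup is assumed and not part of the patches:

    import org.apache.spark.storage.StorageLevel

    val ds = Seq("1", "2").toDF()

    // Lazy persist with an explicit storage level: the plan is registered
    // with the cache manager, but no blocks are materialized yet.
    ds.persist(eager = false, StorageLevel.MEMORY_ONLY_2)
    assert(ds.storageLevel == StorageLevel.MEMORY_ONLY_2)
    assert(!ds.isMaterialized())

    // Eager persist: a no-op job fills the cached column buffers, so the
    // RDD blocks exist as soon as the call returns.
    ds.persist(eager = true, StorageLevel.MEMORY_ONLY_2)
    assert(ds.isMaterialized())

    // unpersist drops the blocks and resets the storage level.
    ds.unpersist()
    assert(ds.storageLevel == StorageLevel.NONE)
    assert(!ds.isMaterialized())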