From 0d855e53d030f9c36243a3dbb10d4a28a0cb4aa2 Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Mon, 11 Jun 2018 15:48:37 -0400
Subject: [PATCH 1/5] Fix cache test

---
 .../apache/spark/sql/CachedTableSuite.scala   | 25 ++++------------
 .../apache/spark/sql/DatasetCacheSuite.scala  | 30 ++++++++++++++++++-
 2 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 81b7e18773f81..ff6a4d406847e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -50,7 +50,11 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   }
 
   def rddIdOf(tableName: String): Int = {
-    val plan = spark.table(tableName).queryExecution.sparkPlan
+    rddIdOf(spark.table(tableName), tableName)
+  }
+
+  def rddIdOf(ds: Dataset[_], tableName: String = "unnamedTable"): Int = {
+    val plan = ds.queryExecution.sparkPlan
     plan.collect {
       case InMemoryTableScanExec(_, _, relation) =>
         relation.cacheBuilder.cachedColumnBuffers.id
@@ -83,25 +87,6 @@
     }.sum
   }
 
-  test("withColumn doesn't invalidate cached dataframe") {
-    var evalCount = 0
-    val myUDF = udf((x: String) => { evalCount += 1; "result" })
-    val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s"))
-    df.cache()
-
-    df.collect()
-    assert(evalCount === 1)
-
-    df.collect()
-    assert(evalCount === 1)
-
-    val df2 = df.withColumn("newColumn", lit(1))
-    df2.collect()
-
-    // We should not reevaluate the cached dataframe
-    assert(evalCount === 1)
-  }
-
   test("cache temp table") {
     withTempView("tempTable") {
       testData.select('key).createOrReplaceTempView("tempTable")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index e0561ee2797a5..48f7bce4fb7d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql
 
+import org.scalatest.concurrent.TimeLimits
+import org.scalatest.time.{Seconds, Span}
+
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.storage.StorageLevel
 
 
-class DatasetCacheSuite extends QueryTest with SharedSQLContext {
+class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits {
   import testImplicits._
 
   test("get storage level") {
@@ -96,4 +99,29 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     agged.unpersist()
     assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
   }
+
+  test("persist and then withColumn") {
+    val df = Seq(("test", 1)).toDF("s", "i")
+    // Cache first because rddIdOf only works with cached DataFrame
+    df.cache()
+    assertCached(df)
+    df.count()
+
+    // We should not invalidate the cached DataFrame
+    val df2 = df.withColumn("newColumn", lit(1))
+    assertCached(df2)
+  }
+
+  test("cache UDF correctly") {
+    val expensiveUDF = udf({x: Int => Thread.sleep(10000); x})
+    val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a"))
+    val df2 = df.agg(sum(df("b")))
+
+    df.cache()
+    df.count()
+
+    failAfter(Span(5, Seconds)) {
+      df2.collect()
+    }
+  }
 }

From 377e049f23674277726b9271b2affbbd17839cbd Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Mon, 11 Jun 2018 16:00:33 -0400
Subject: [PATCH 2/5] Add clean up code

---
 .../apache/spark/sql/DatasetCacheSuite.scala  | 23 +++++++++++++------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 48f7bce4fb7d9..f2a1fc6a82702 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import org.scalatest.concurrent.TimeLimits
-import org.scalatest.time.{Seconds, Span}
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -102,17 +102,21 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
 
   test("persist and then withColumn") {
     val df = Seq(("test", 1)).toDF("s", "i")
-    // Cache first because rddIdOf only works with cached DataFrame
+    // We should not invalidate the cached DataFrame
+    val df2 = df.withColumn("newColumn", lit(1))
+
     df.cache()
     assertCached(df)
-    df.count()
+    assertCached(df2)
 
-    // We should not invalidate the cached DataFrame
-    val df2 = df.withColumn("newColumn", lit(1))
+    df.count()
     assertCached(df2)
+
+    df.unpersist()
+    assert(df.storageLevel == StorageLevel.NONE)
   }
 
-  test("cache UDF correctly") {
+  test("cache UDF result correctly") {
     val expensiveUDF = udf({x: Int => Thread.sleep(10000); x})
     val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a"))
     val df2 = df.agg(sum(df("b")))
@@ -120,8 +124,13 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
     df.cache()
     df.count()
 
-    failAfter(Span(5, Seconds)) {
+    assertCached(df2)
+
+    failAfter(5 seconds) {
       df2.collect()
     }
+
+    df.unpersist()
+    assert(df.storageLevel == StorageLevel.NONE)
   }
 }

From e01e603e1dbd0dd96371b31c90b6c1767b1324cf Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Mon, 11 Jun 2018 16:09:28 -0400
Subject: [PATCH 3/5] Revert unneeded changes

---
 .../test/scala/org/apache/spark/sql/CachedTableSuite.scala | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index ff6a4d406847e..81fe65a22cb80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -50,10 +50,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   }
 
   def rddIdOf(tableName: String): Int = {
-    rddIdOf(spark.table(tableName), tableName)
-  }
-
-  def rddIdOf(ds: Dataset[_], tableName: String = "unnamedTable"): Int = {
     val plan = ds.queryExecution.sparkPlan
     plan.collect {
       case InMemoryTableScanExec(_, _, relation) =>

From 658539a71971bdea62e89afd11775bf6f51d766d Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Mon, 11 Jun 2018 16:10:04 -0400
Subject: [PATCH 4/5] Fix unneeded changes

---
 .../src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 81fe65a22cb80..6982c22f4771d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -50,7 +50,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   }
 
   def rddIdOf(tableName: String): Int = {
-    val plan = ds.queryExecution.sparkPlan
+    val plan = spark.table(tableName).queryExecution.sparkPlan
     plan.collect {
       case InMemoryTableScanExec(_, _, relation) =>
         relation.cacheBuilder.cachedColumnBuffers.id

From c9db68d87a6f34f1849aadc3eaf58ed183cc2419 Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Wed, 13 Jun 2018 18:44:49 -0400
Subject: [PATCH 5/5] Address comments

---
 .../test/scala/org/apache/spark/sql/DatasetCacheSuite.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index f2a1fc6a82702..82a93f74dd76c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -102,7 +102,6 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
 
   test("persist and then withColumn") {
     val df = Seq(("test", 1)).toDF("s", "i")
-    // We should not invalidate the cached DataFrame
     val df2 = df.withColumn("newColumn", lit(1))
 
     df.cache()
@@ -123,9 +122,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
 
     df.cache()
     df.count()
 
-    assertCached(df2)
+    // udf has been evaluated during caching, and thus should not be re-evaluated here
 
     failAfter(5 seconds) {
       df2.collect()
     }