[SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should un-cache all cached plans that refer to this table #17097

Closed · wants to merge 2 commits · Changes from 1 commit
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution

import java.util.concurrent.locks.ReentrantReadWriteLock

+ import scala.collection.JavaConverters._
+
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.internal.Logging
@@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
class CacheManager extends Logging {

@transient
- private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+ private val cachedData = new java.util.LinkedList[CachedData]

@transient
private val cacheLock = new ReentrantReadWriteLock
@@ -70,7 +72,7 @@ class CacheManager extends Logging {

/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
- cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+ cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}

@@ -88,46 +90,93 @@ class CacheManager extends Logging {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
- val planToCache = query.queryExecution.analyzed
+ val planToCache = query.logicalPlan
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
val sparkSession = query.sparkSession
- cachedData +=
- CachedData(
- planToCache,
- InMemoryRelation(
- sparkSession.sessionState.conf.useCompression,
- sparkSession.sessionState.conf.columnBatchSize,
- storageLevel,
- sparkSession.sessionState.executePlan(planToCache).executedPlan,
- tableName))
+ cachedData.add(CachedData(
+ planToCache,
+ InMemoryRelation(
+ sparkSession.sessionState.conf.useCompression,
+ sparkSession.sessionState.conf.columnBatchSize,
+ storageLevel,
+ sparkSession.sessionState.executePlan(planToCache).executedPlan,
+ tableName)))
}
}

/**
- * Tries to remove the data for the given [[Dataset]] from the cache.
- * No operation, if it's already uncached.
+ * Tries to remove the cache entry of the given query, no operation, if it's already uncached.
+ * Note that all other caches that refer to this plan will be re-cached.
*
* @return true if a cache entry is found and removed, false otherwise.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
- val planToCache = query.queryExecution.analyzed
- val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
- val found = dataIndex >= 0
+ uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
}

+ /**
+ * Tries to remove the cache entry of the given plan, no operation, if it's already uncached.
+ * Note that all other caches that refer to this plan will be re-cached.
Review comment (Member):

This is a behavior change.

Will uncache be a very expensive operation if we recache all the related cached plans?

+ *
+ * @return true if a cache entry is found and removed, false otherwise.
+ */
+ def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Boolean = writeLock {
+ val it = cachedData.iterator()
+ var found = false
+ while (it.hasNext && !found) {
+ val cd = it.next()
+ if (cd.plan.sameResult(plan)) {
+ cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+ it.remove()
+ found = true
+ }
+ }
if (found) {
- cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
- cachedData.remove(dataIndex)
+ recacheByPlan(spark, plan)
}
found
}
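
Illustration (not from the patch itself): a minimal end-to-end sketch of the behavior this overload enables, mirroring the new CachedTableSuite test further down. It assumes a local SparkSession; the table name "t" and the dependent query are illustrative only.

```scala
// Minimal sketch of the uncache-then-recache behavior, assuming a local SparkSession.
// Table name "t" and the dependent query are illustrative only.
import org.apache.spark.sql.SparkSession

object UncacheRecacheSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("uncache-sketch").getOrCreate()
    import spark.implicits._

    Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t")
    spark.catalog.cacheTable("t")

    // A second cached plan that refers to table "t".
    val dependent = spark.table("t").select($"i")
    dependent.cache()

    spark.sql("INSERT OVERWRITE TABLE t SELECT 2, 'b'")
    // Un-caching "t" now also rebuilds every cached plan that refers to it, so the
    // dependent cache serves the newly inserted row instead of stale buffers.
    spark.catalog.uncacheTable("t")
    spark.table("t").select($"i").show()   // expected output: 2

    spark.stop()
  }
}
```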

+ /**
+ * Tries to re-cache all the cache entries that refer to the given plan.
+ */
+ def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
+ recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
+ }

+ private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+ val it = cachedData.iterator()
+ val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
+ while (it.hasNext) {
+ val cd = it.next()
+ if (condition(cd.plan)) {
+ cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+ // Remove the cache entry before we create a new one, so that we can have a different
+ // physical plan.
+ it.remove()
+ val newCache = InMemoryRelation(
+ useCompression = cd.cachedRepresentation.useCompression,
+ batchSize = cd.cachedRepresentation.batchSize,
+ storageLevel = cd.cachedRepresentation.storageLevel,
+ child = spark.sessionState.executePlan(cd.plan).executedPlan,
+ tableName = cd.cachedRepresentation.tableName)
+ needToRecache += cd.copy(cachedRepresentation = newCache)
+ }
+ }
+
+ needToRecache.foreach(cachedData.add)
+ }
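
For intuition (an assumed snippet, e.g. pasted into spark-shell where `spark` and its implicits are in scope): what "refers to the given plan" means in recacheByPlan's condition `_.find(_.sameResult(plan)).isDefined`. Any cached plan that contains the target plan as a subtree matches, not just the plan itself; the view name "t" is illustrative.

```scala
import spark.implicits._

val t = Seq(1 -> "a").toDF("i", "j")
t.createOrReplaceTempView("t")

val target    = spark.table("t").queryExecution.analyzed              // the plan being written or uncached
val dependent = spark.table("t").select($"i").queryExecution.analyzed // a query built on top of it

// `dependent` is not the same plan as `target`, but it contains a subtree producing the
// same result, so recacheByCondition would pick it up and rebuild its cache entry.
assert(!dependent.sameResult(target))
assert(dependent.find(_.sameResult(target)).isDefined)
```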

/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
- lookupCachedData(query.queryExecution.analyzed)
+ lookupCachedData(query.logicalPlan)
}

/** Optionally returns cached data for the given [[LogicalPlan]]. */
def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
- cachedData.find(cd => plan.sameResult(cd.plan))
+ cachedData.asScala.find(cd => plan.sameResult(cd.plan))
}

/** Replaces segments of the given logical plan with cached versions where possible. */
@@ -145,40 +194,17 @@ class CacheManager extends Logging {
}

/**
- * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
- * function will over invalidate.
- */
- def invalidateCache(plan: LogicalPlan): Unit = writeLock {
- cachedData.foreach {
- case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
- data.cachedRepresentation.recache()
- case _ =>
- }
- }
-
- /**
- * Invalidates the cache of any data that contains `resourcePath` in one or more
+ * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
- def invalidateCachedPath(
- sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
+ def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
val (fs, qualifiedPath) = {
val path = new Path(resourcePath)
- val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
- (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+ val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+ (fs, fs.makeQualified(path))
}

- cachedData.filter {
- case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true
- case _ => false
- }.foreach { data =>
- val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
- if (dataIndex >= 0) {
- data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
- cachedData.remove(dataIndex)
- }
- sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
- }
+ recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
}
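
Illustration (assumed usage, e.g. in spark-shell with an active `spark` session): how this is reached through the public API. `Catalog.refreshByPath` now delegates to `recacheByPath` (see the CatalogImpl change below), so caches built over files under the path are rebuilt rather than left stale. The directory is a hypothetical example.

```scala
val dir = "/tmp/cached_parquet_example"          // hypothetical path

spark.range(10).write.mode("overwrite").parquet(dir)
val df = spark.read.parquet(dir)
df.cache()
df.count()                                       // materialize the cache (10 rows)

// Files change underneath the cached relation, e.g. rewritten by another job.
spark.range(20).write.mode("overwrite").parquet(dir)

spark.catalog.refreshByPath(dir)                 // re-caches entries whose HadoopFsRelation reads `dir`
spark.read.parquet(dir).count()                  // expected: 20, not the stale 10
```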

/**
@@ -85,12 +85,6 @@ case class InMemoryRelation(
buildBuffers()
}

- def recache(): Unit = {
- _cachedColumnBuffers.unpersist()
- _cachedColumnBuffers = null
- buildBuffers()
- }

private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>
@@ -199,8 +199,7 @@ case class DropTableCommand(
}
}
try {
- sparkSession.sharedState.cacheManager.uncacheQuery(
- sparkSession.table(tableName.quotedString))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
} catch {
case _: NoSuchTableException if ifExists =>
case NonFatal(e) => log.warn(e.toString, e)
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)

- // Invalidate the cache.
- sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
+ // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
+ // data source relation.
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
Review comment (Member):
In the current design, users need to re-cache the queries by themselves.

After this change, insertion could be super slow. Each insert could trigger the recache of many involved cached plans, each of which could be very complex and expensive. That is a trade-off: although we keep data correctness/consistency, we could sacrifice performance/user experience.

Reply (Contributor, author):
I didn't change the behavior, just renamed invalidateCache to recacheByPlan.

I'll open a new discussion about whether we should do recache after insertion or not.

Review comment (Member):
uh... My fault... invalidateCache is misleading. Renaming is good!


Seq.empty[Row]
}
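
To make the trade-off discussed above concrete (an assumed spark-shell snippet; the table and queries are illustrative, and whether a particular INSERT routes through this command depends on the data source): every cached query that refers to the written relation is re-executed when it is rebuilt, so the cost grows with the number and complexity of dependent caches.

```scala
import spark.implicits._

Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t")

// Two dependent caches over the same table; a recache after a write rebuilds both.
spark.table("t").select($"i").cache()
spark.table("t").groupBy($"j").count().cache()

// Keeps the caches consistent with the new data, but may pay the cost of re-running
// both cached plans rather than simply dropping their buffers.
spark.sql("INSERT INTO t VALUES (2, 'b')")
```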
@@ -341,8 +341,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def dropTempView(viewName: String): Boolean = {
- sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
- sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
+ sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
@@ -357,7 +357,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
- sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
@@ -402,7 +402,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
- sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
}

/**
@@ -440,17 +440,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {

// If this table is cached as an InMemoryRelation, drop the original
// cached version and make the new version cached lazily.
- val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed
- // Use lookupCachedData directly since RefreshTable also takes databaseName.
- val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
- if (isCached) {
- // Create a data frame to represent the table.
- // TODO: Use uncacheTable once it supports database name.
- val df = Dataset.ofRows(sparkSession, logicalPlan)
+ val table = sparkSession.table(tableIdent)
+ if (isCached(table)) {
// Uncache the logicalPlan.
- sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
+ sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
// Cache it again.
- sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
+ sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
}

@@ -462,7 +457,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def refreshByPath(resourcePath: String): Unit = {
- sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
+ sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
}
}

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala (28 additions, 12 deletions)
@@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
maybeBlock.nonEmpty
}

- private def getNumInMemoryRelations(plan: LogicalPlan): Int = {
+ private def getNumInMemoryRelations(ds: Dataset[_]): Int = {
+ val plan = ds.queryExecution.withCachedData
var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
plan.transformAllExpressions {
case e: SubqueryExpression =>
@@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
assertCached(spark.table("testData"))

assertResult(1, "InMemoryRelation not found, testData should have been cached") {
- getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData)
+ getNumInMemoryRelations(spark.table("testData"))
}

spark.catalog.cacheTable("testData")
@@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
localRelation.createOrReplaceTempView("localRelation")

spark.catalog.cacheTable("localRelation")
- assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1)
+ assert(getNumInMemoryRelations(localRelation) == 1)
}

test("SPARK-19093 Caching in side subquery") {
withTempView("t1") {
Seq(1).toDF("c1").createOrReplaceTempView("t1")
spark.catalog.cacheTable("t1")
- val cachedPlan =
+ val ds =
sql(
"""
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1)
""".stripMargin).queryExecution.optimizedPlan
assert(getNumInMemoryRelations(cachedPlan) == 2)
""".stripMargin)
assert(getNumInMemoryRelations(ds) == 2)
}
}

@@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
spark.catalog.cacheTable("t4")

// Nested predicate subquery
- val cachedPlan =
+ val ds =
sql(
"""
|SELECT * FROM t1
|WHERE
|c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
""".stripMargin).queryExecution.optimizedPlan
assert(getNumInMemoryRelations(cachedPlan) == 3)
""".stripMargin)
assert(getNumInMemoryRelations(ds) == 3)

// Scalar subquery and predicate subquery
- val cachedPlan2 =
+ val ds2 =
sql(
"""
|SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
@@ -630,8 +631,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
|EXISTS (SELECT c1 FROM t3)
|OR
|c1 IN (SELECT c1 FROM t4)
""".stripMargin).queryExecution.optimizedPlan
assert(getNumInMemoryRelations(cachedPlan2) == 4)
""".stripMargin)
assert(getNumInMemoryRelations(ds2) == 4)
}
}

test("SPARK-19765: UNCACHE TABLE should re-cache all cached plans that refer to this table") {
withTable("t") {
Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t")
spark.catalog.cacheTable("t")
spark.table("t").select($"i").cache()
checkAnswer(spark.table("t").select($"i"), Row(1))
assertCached(spark.table("t").select($"i"))

sql("INSERT OVERWRITE TABLE t SELECT 2, 'b'")
spark.catalog.uncacheTable("t")
checkAnswer(spark.table("t").select($"i"), Row(2))
assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 1)
}
}

@@ -386,8 +386,8 @@ case class InsertIntoHiveTable(
logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
}

- // Invalidate the cache.
- sparkSession.catalog.uncacheTable(table.qualifiedName)
+ // un-cache this table.
+ sparkSession.catalog.uncacheTable(table.identifier.quotedString)
sparkSession.sessionState.catalog.refreshTable(table.identifier)

// It would be nice to just return the childRdd unchanged so insert operations could be chained,
@@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
tempPath.delete()
table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString)
sql("DROP TABLE IF EXISTS refreshTable")
- sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet")
- checkAnswer(
- table("refreshTable"),
- table("src").collect())
+ sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet")
+ checkAnswer(table("refreshTable"), table("src"))
// Cache the table.
sql("CACHE TABLE refreshTable")
assertCached(table("refreshTable"))