[SPARK-19765][SPARK-18549][SPARK-19093][SPARK-19736][BACKPORT-2.1][SQL] Backport Three Cache-related PRs to Spark 2.1

### What changes were proposed in this pull request?

Backport a few cache-related PRs:

---
[[SPARK-19093][SQL] Cached tables are not used in SubqueryExpression](#16493)

Consider the plans inside subquery expressions when looking up the cache manager, so that cached data can be reused. Currently, `CacheManager.useCachedData` does not consider subquery expressions in the plan, so a cached table that is only referenced from a subquery is not used.
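
A minimal sketch of the scenario this fixes (the `spark` session and the table names `t1`/`t2` are assumptions, not part of the patch):

```scala
// Cache a table and then reference it only from a subquery expression.
spark.range(10).write.saveAsTable("t1")
spark.range(20).write.saveAsTable("t2")
spark.catalog.cacheTable("t1")

// Before this fix, the scan of t1 inside the IN subquery did not pick up the
// cached InMemoryRelation; with this fix, useCachedData also rewrites plans
// nested inside SubqueryExpression.
val df = spark.sql("SELECT * FROM t2 WHERE id IN (SELECT id FROM t1)")
df.explain(true)
```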

---
[[SPARK-19736][SQL] refreshByPath should clear all cached plans with the specified path](#17064)

`Catalog.refreshByPath` can refresh the cache entry and the associated metadata for all DataFrames (if any) that contain the given data source path.

However, `CacheManager.invalidateCachedPath` does not clear all cached plans with the specified path, which causes the strange behaviors reported in SPARK-15678.
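
A hedged sketch of the expected behavior (the path and the `spark` session are assumptions):

```scala
// Cache a DataFrame that reads from a path, then overwrite the files and
// refresh by path; every cached plan that reads this path should be dropped
// and re-cached, not just the first matching entry.
val path = "/tmp/spark_refresh_by_path_demo"
spark.range(10).write.mode("overwrite").parquet(path)

val df = spark.read.parquet(path)
df.cache()
df.count()

spark.range(100).write.mode("overwrite").parquet(path)
spark.catalog.refreshByPath(path)

// After the refresh, the cached plan reflects the rewritten files.
df.count()
```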

---
[[SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should un-cache all cached plans that refer to this table](#17097)

When un-caching a table, we should not only remove the cache entry for this table but also un-cache any other cached plans that refer to it. The following commands trigger table un-caching: `DropTableCommand`, `TruncateTableCommand`, `AlterTableRenameCommand`, `UncacheTableCommand`, `RefreshTable` and `InsertIntoHiveTable`.
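
A minimal sketch of the cascading un-cache behavior (the table name `t` and the derived DataFrame are assumptions):

```scala
// Cache a table and a derived plan that scans it.
spark.range(10).write.saveAsTable("t")
spark.catalog.cacheTable("t")

val agg = spark.table("t").groupBy("id").count()
agg.cache()
agg.count()

// Un-caching t now also un-caches `agg`, because its cached plan refers to
// table t; previously only the entry for t itself was removed.
spark.sql("UNCACHE TABLE t")
```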

This PR also includes some refactoring:
- use `java.util.LinkedList` to store the cache entries, so that it is safer to remove elements while iterating (see the sketch after this list)
- rename `invalidateCache` to `recacheByPlan`, which makes it more obvious what it does
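
A small self-contained sketch of the iterator-removal pattern this refactoring relies on (`CachedEntry` is a stand-in for `CachedData`, not a real class):

```scala
import java.util.LinkedList

case class CachedEntry(name: String)

val entries = new LinkedList[CachedEntry]()
entries.add(CachedEntry("a"))
entries.add(CachedEntry("b"))

// java.util.Iterator.remove() supports structural modification while
// iterating, which is what uncacheQuery/recacheByCondition need.
val it = entries.iterator()
while (it.hasNext) {
  if (it.next().name == "a") {
    it.remove()
  }
}
```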

### How was this patch tested?
N/A

Author: Xiao Li <gatorsmile@gmail.com>

Closes #17319 from gatorsmile/backport-17097.
gatorsmile authored and cloud-fan committed Mar 17, 2017
1 parent 9d032d0 commit 4b977ff
Showing 9 changed files with 200 additions and 88 deletions.
@@ -19,9 +19,12 @@ package org.apache.spark.sql.execution

import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -44,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
class CacheManager extends Logging {

@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
private val cachedData = new java.util.LinkedList[CachedData]

@transient
private val cacheLock = new ReentrantReadWriteLock
@@ -69,7 +72,7 @@ class CacheManager extends Logging {

/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}

@@ -87,92 +90,109 @@
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val planToCache = query.logicalPlan
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
val sparkSession = query.sparkSession
cachedData +=
CachedData(
planToCache,
InMemoryRelation(
sparkSession.sessionState.conf.useCompression,
sparkSession.sessionState.conf.columnBatchSize,
storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName))
cachedData.add(CachedData(
planToCache,
InMemoryRelation(
sparkSession.sessionState.conf.useCompression,
sparkSession.sessionState.conf.columnBatchSize,
storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName)))
}
}

/**
* Tries to remove the data for the given [[Dataset]] from the cache.
* No operation, if it's already uncached.
* Un-cache all the cache entries that refer to the given plan.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
}

/**
* Un-cache all the cache entries that refer to the given plan.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
val found = dataIndex >= 0
if (found) {
cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
cachedData.remove(dataIndex)
def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
val it = cachedData.iterator()
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
it.remove()
}
}
found
}

/**
* Tries to re-cache all the cache entries that refer to the given plan.
*/
def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
}

private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
val it = cachedData.iterator()
val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
while (it.hasNext) {
val cd = it.next()
if (condition(cd.plan)) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist()
// Remove the cache entry before we create a new one, so that we can have a different
// physical plan.
it.remove()
val newCache = InMemoryRelation(
useCompression = cd.cachedRepresentation.useCompression,
batchSize = cd.cachedRepresentation.batchSize,
storageLevel = cd.cachedRepresentation.storageLevel,
child = spark.sessionState.executePlan(cd.plan).executedPlan,
tableName = cd.cachedRepresentation.tableName)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}

needToRecache.foreach(cachedData.add)
}

/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
lookupCachedData(query.logicalPlan)
}

/** Optionally returns cached data for the given [[LogicalPlan]]. */
def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
cachedData.find(cd => plan.sameResult(cd.plan))
cachedData.asScala.find(cd => plan.sameResult(cd.plan))
}

/** Replaces segments of the given logical plan with cached versions where possible. */
def useCachedData(plan: LogicalPlan): LogicalPlan = {
plan transformDown {
val newPlan = plan transformDown {
case currentFragment =>
lookupCachedData(currentFragment)
.map(_.cachedRepresentation.withOutput(currentFragment.output))
.getOrElse(currentFragment)
}
}

/**
* Invalidates the cache of any data that contains `plan`. Note that it is possible that this
* function will over invalidate.
*/
def invalidateCache(plan: LogicalPlan): Unit = writeLock {
cachedData.foreach {
case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
data.cachedRepresentation.recache()
case _ =>
newPlan transformAllExpressions {
case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan))
}
}

/**
* Invalidates the cache of any data that contains `resourcePath` in one or more
* Tries to re-cache all the cache entries that contain `resourcePath` in one or more
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
def invalidateCachedPath(
sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
val (fs, qualifiedPath) = {
val path = new Path(resourcePath)
val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
(fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
(fs, fs.makeQualified(path))
}

cachedData.foreach {
case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined =>
val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
if (dataIndex >= 0) {
data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
cachedData.remove(dataIndex)
}
sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
case _ => // Do Nothing
}
recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
}

/**
@@ -85,12 +85,6 @@ case class InMemoryRelation(
buildBuffers()
}

def recache(): Unit = {
_cachedColumnBuffers.unpersist()
_cachedColumnBuffers = null
buildBuffers()
}

private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>
@@ -199,8 +199,7 @@ case class DropTableCommand(
}
}
try {
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession.table(tableName.quotedString))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
} catch {
case _: NoSuchTableException if ifExists =>
case NonFatal(e) => log.warn(e.toString, e)
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite.enabled)

// Invalidate the cache.
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
// Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
// data source relation.
sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)

Seq.empty[Row]
}
@@ -373,8 +373,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def dropTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
@@ -389,7 +389,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
@@ -434,7 +434,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
}

/**
@@ -472,17 +472,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {

// If this table is cached as an InMemoryRelation, drop the original
// cached version and make the new version cached lazily.
val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent)
// Use lookupCachedData directly since RefreshTable also takes databaseName.
val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
if (isCached) {
// Create a data frame to represent the table.
// TODO: Use uncacheTable once it supports database name.
val df = Dataset.ofRows(sparkSession, logicalPlan)
val table = sparkSession.table(tableIdent)
if (isCached(table)) {
// Uncache the logicalPlan.
sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
// Cache it again.
sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
}

@@ -494,7 +489,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def refreshByPath(resourcePath: String): Unit = {
sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
}
}

