[SPARK-23880][SQL] Do not trigger any jobs for caching data

## What changes were proposed in this pull request?
This PR fixes the code so that `cache` does not trigger any jobs.
For example, in the current master, the operation below triggers an actual job:
```
val df = spark.range(10000000000L)
  .filter('id > 1000)
  .orderBy('id.desc)
  .cache()
```
This triggers a job even though the cache should be lazy. The problem is that, when creating `InMemoryRelation`, we build the RDD, which calls `SparkPlan.execute` and may trigger jobs, such as a sampling job for a range partitioner or a broadcast job.
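
For context, the sort above plans a range-partitioned exchange, and eagerly building that exchange's RDD makes `RangePartitioner` sample its input to compute partition bounds; that sampling is the job being triggered. A quick way to see the exchange in the plan (a sketch; the output is abbreviated and its exact format varies across Spark versions):
```
// In spark-shell, where spark.implicits._ is in scope:
spark.range(100).orderBy('id.desc).explain()
// == Physical Plan ==
// *(2) Sort [id#0L DESC NULLS LAST], true, 0
// +- Exchange rangepartitioning(id#0L DESC NULLS LAST, 200)
//    +- *(1) Range (0, 100, step=1, splits=8)
```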

This PR removes the code that builds the cached `RDD` in the constructor of `InMemoryRelation` and adds `CachedRDDBuilder`, which lazily builds the `RDD` held by `InMemoryRelation`. The first call to `CachedRDDBuilder.cachedColumnBuffers`, made from `InMemoryTableScanExec`, then triggers the job that materializes the cache.
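
A minimal sketch of the intended behavior after this change (the comments are assumptions drawn from the description above, not output of the patch):
```
val df = spark.range(10000000000L)
  .filter('id > 1000)
  .orderBy('id.desc)
  .cache()   // now lazy: no RDD is built, so no job is submitted here
df.count()   // first action: InMemoryTableScanExec calls
             // CachedRDDBuilder.cachedColumnBuffers, which builds and
             // persists the columnar buffers (sampling/shuffle jobs run here)
```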

## How was this patch tested?
Added tests in `CachedTableSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #21018 from maropu/SPARK-23880.
maropu authored and cloud-fan committed Apr 25, 2018
1 parent 64e8408 commit 20ca208
Showing 8 changed files with 133 additions and 94 deletions.
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (2 changes: 1 addition & 1 deletion)

@@ -2933,7 +2933,7 @@ class Dataset[T] private[sql](
*/
def storageLevel: StorageLevel = {
sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData =>
cachedData.cachedRepresentation.storageLevel
cachedData.cachedRepresentation.cacheBuilder.storageLevel
}.getOrElse(StorageLevel.NONE)
}

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

@@ -71,7 +71,7 @@ class CacheManager extends Logging {

/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache())
cachedData.clear()
}

@@ -119,7 +119,7 @@
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
it.remove()
}
}
@@ -138,16 +138,14 @@
while (it.hasNext) {
val cd = it.next()
if (condition(cd.plan)) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist()
cd.cachedRepresentation.cacheBuilder.clearCache()
// Remove the cache entry before we create a new one, so that we can have a different
// physical plan.
it.remove()
val plan = spark.sessionState.executePlan(cd.plan).executedPlan
val newCache = InMemoryRelation(
useCompression = cd.cachedRepresentation.useCompression,
batchSize = cd.cachedRepresentation.batchSize,
storageLevel = cd.cachedRepresentation.storageLevel,
child = spark.sessionState.executePlan(cd.plan).executedPlan,
tableName = cd.cachedRepresentation.tableName,
cacheBuilder = cd.cachedRepresentation
.cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null),
logicalPlan = cd.plan)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

@@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator


object InMemoryRelation {
def apply(
useCompression: Boolean,
batchSize: Int,
storageLevel: StorageLevel,
child: SparkPlan,
tableName: Option[String],
logicalPlan: LogicalPlan): InMemoryRelation =
new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
}


/**
* CachedBatch is a cached batch of rows.
*
@@ -55,58 +42,41 @@ object InMemoryRelation {
private[columnar]
case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)

case class InMemoryRelation(
output: Seq[Attribute],
case class CachedRDDBuilder(
useCompression: Boolean,
batchSize: Int,
storageLevel: StorageLevel,
@transient child: SparkPlan,
@transient cachedPlan: SparkPlan,
tableName: Option[String])(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
statsOfPlanToCache: Statistics,
override val outputOrdering: Seq[SortOrder])
extends logical.LeafNode with MultiInstanceRelation {

override protected def innerChildren: Seq[SparkPlan] = Seq(child)

override def doCanonicalize(): logical.LogicalPlan =
copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)),
storageLevel = StorageLevel.NONE,
child = child.canonicalized,
tableName = None)(
_cachedColumnBuffers,
sizeInBytesStats,
statsOfPlanToCache,
outputOrdering)
@transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) {

override def producedAttributes: AttributeSet = outputSet

@transient val partitionStatistics = new PartitionStatistics(output)
val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator

override def computeStats(): Statistics = {
if (sizeInBytesStats.value == 0L) {
// Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
// Note that we should drop the hint info here. We may cache a plan whose root node is a hint
// node. When we lookup the cache with a semantically same plan without hint info, the plan
// returned by cache lookup should not have hint info. If we lookup the cache with a
// semantically same plan with a different hint info, `CacheManager.useCachedData` will take
// care of it and retain the hint info in the lookup input plan.
statsOfPlanToCache.copy(hints = HintInfo())
} else {
Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
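// Lazily builds and memoizes the cached RDD; double-checked locking ensures
// it is built only once under concurrent first access.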
def cachedColumnBuffers: RDD[CachedBatch] = {
if (_cachedColumnBuffers == null) {
synchronized {
if (_cachedColumnBuffers == null) {
_cachedColumnBuffers = buildBuffers()
}
}
}
_cachedColumnBuffers
}

// If the cached column buffers were not passed in, we calculate them in the constructor.
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
buildBuffers()
def clearCache(blocking: Boolean = true): Unit = {
if (_cachedColumnBuffers != null) {
synchronized {
if (_cachedColumnBuffers != null) {
_cachedColumnBuffers.unpersist(blocking)
_cachedColumnBuffers = null
}
}
}
}

private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>
private def buildBuffers(): RDD[CachedBatch] = {
val output = cachedPlan.output
val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>
new Iterator[CachedBatch] {
def next(): CachedBatch = {
val columnBuilders = output.map { attribute =>
@@ -154,32 +124,77 @@ case class InMemoryRelation(

cached.setName(
tableName.map(n => s"In-memory table $n")
.getOrElse(StringUtils.abbreviate(child.toString, 1024)))
_cachedColumnBuffers = cached
.getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024)))
cached
}
}

object InMemoryRelation {

def apply(
useCompression: Boolean,
batchSize: Int,
storageLevel: StorageLevel,
child: SparkPlan,
tableName: Option[String],
logicalPlan: LogicalPlan): InMemoryRelation = {
val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)()
new InMemoryRelation(child.output, cacheBuilder)(
statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
}

def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = {
new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)(
statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
}
}

case class InMemoryRelation(
output: Seq[Attribute],
@transient cacheBuilder: CachedRDDBuilder)(
statsOfPlanToCache: Statistics,
override val outputOrdering: Seq[SortOrder])
extends logical.LeafNode with MultiInstanceRelation {

override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)

override def doCanonicalize(): logical.LogicalPlan =
copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)),
cacheBuilder)(
statsOfPlanToCache,
outputOrdering)

override def producedAttributes: AttributeSet = outputSet

@transient val partitionStatistics = new PartitionStatistics(output)

def cachedPlan: SparkPlan = cacheBuilder.cachedPlan

override def computeStats(): Statistics = {
if (cacheBuilder.sizeInBytesStats.value == 0L) {
// Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
// Note that we should drop the hint info here. We may cache a plan whose root node is a hint
// node. When we lookup the cache with a semantically same plan without hint info, the plan
// returned by cache lookup should not have hint info. If we lookup the cache with a
// semantically same plan with a different hint info, `CacheManager.useCachedData` will take
// care of it and retain the hint info in the lookup input plan.
statsOfPlanToCache.copy(hints = HintInfo())
} else {
Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
}
}

def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering)
InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering)
}

override def newInstance(): this.type = {
new InMemoryRelation(
output.map(_.newInstance()),
useCompression,
batchSize,
storageLevel,
child,
tableName)(
_cachedColumnBuffers,
sizeInBytesStats,
cacheBuilder)(
statsOfPlanToCache,
outputOrdering).asInstanceOf[this.type]
}

def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers

override protected def otherCopyArgs: Seq[AnyRef] =
Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache)
}
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

@@ -154,7 +154,7 @@ case class InMemoryTableScanExec(
private def updateAttribute(expr: Expression): Expression = {
// attributes can be pruned so using relation's output.
// E.g., relation.output is [id, item] but this scan's output can be [item] only.
val attrMap = AttributeMap(relation.child.output.zip(relation.output))
val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output))
expr.transform {
case attr: Attribute => attrMap.getOrElse(attr, attr)
}
@@ -163,16 +163,16 @@
// The cached version does not change the outputPartitioning of the original SparkPlan.
// But the cached version could alias output, so we need to replace output.
override def outputPartitioning: Partitioning = {
relation.child.outputPartitioning match {
relation.cachedPlan.outputPartitioning match {
case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning]
case _ => relation.child.outputPartitioning
case _ => relation.cachedPlan.outputPartitioning
}
}

// The cached version does not change the outputOrdering of the original SparkPlan.
// But the cached version could alias output, so we need to replace output.
override def outputOrdering: Seq[SortOrder] =
relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])

// Keeps relation's partition statistics because we don't serialize relation.
private val stats = relation.partitionStatistics
@@ -252,7 +252,7 @@ case class InMemoryTableScanExec(
// within the map Partitions closure.
val schema = stats.schema
val schemaIndex = schema.zipWithIndex
val buffers = relation.cachedColumnBuffers
val buffers = relation.cacheBuilder.cachedColumnBuffers

buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
val partitionFilter = newPredicate(
sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

@@ -22,6 +22,7 @@ import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.spark.CleanerListener
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
@@ -52,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val plan = spark.table(tableName).queryExecution.sparkPlan
plan.collect {
case InMemoryTableScanExec(_, _, relation) =>
relation.cachedColumnBuffers.id
relation.cacheBuilder.cachedColumnBuffers.id
case _ =>
fail(s"Table $tableName is not cached\n" + plan)
}.head
@@ -78,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
plan.collect {
case InMemoryTableScanExec(_, _, relation) =>
getNumInMemoryTablesRecursively(relation.child) + 1
getNumInMemoryTablesRecursively(relation.cachedPlan) + 1
}.sum
}

@@ -200,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
spark.catalog.cacheTable("testData")
assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
spark.table("testData").queryExecution.withCachedData.collect {
case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r
case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r
}.size
}

@@ -367,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val toBeCleanedAccIds = new HashSet[Long]

val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.sizeInBytesStats.id
case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId1

val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.sizeInBytesStats.id
case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId2

@@ -794,4 +795,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
}
}
}

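// Runs `f` under a listener that counts job starts, asserts that no job
// was submitted, and returns f's result.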
private def checkIfNoJobTriggered[T](f: => T): T = {
var numJobTriggered = 0
val jobListener = new SparkListener {
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
numJobTriggered += 1
}
}
sparkContext.addSparkListener(jobListener)
try {
val result = f
sparkContext.listenerBus.waitUntilEmpty(10000L)
assert(numJobTriggered === 0)
result
} finally {
sparkContext.removeSparkListener(jobListener)
}
}

test("SPARK-23880 table cache should be lazy and not trigger any jobs") {
val cachedData = checkIfNoJobTriggered {
spark.range(1002).filter('id > 1000).orderBy('id.desc).cache()
}
assert(cachedData.collect === Seq(1001))
}
}
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

@@ -194,7 +194,7 @@ class PlannerSuite extends SharedSQLContext {
test("CollectLimit can appear in the middle of a plan when caching is used") {
val query = testData.select('key, 'value).limit(2).cache()
val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
assert(planned.child.isInstanceOf[CollectLimitExec])
assert(planned.cachedPlan.isInstanceOf[CollectLimitExec])
}

test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

@@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None,
data.logicalPlan)

assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
inMemoryRelation.cachedColumnBuffers.collect().head match {
assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel)
inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match {
case _: CachedBatch =>
case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
}
@@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(cached, expectedAnswer)

// Check that the right size was calculated.
assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
}

test("access primitive-type columns in CachedBatch without whole stage codegen") {