[SPARK-23214][SQL] cached data should not carry extra hint info
## What changes were proposed in this pull request?

This is a regression introduced by #19864

When we look up the cache, we should not carry over the hint info: the cache entry might have been added by a plan that carried hint info, while the input plan for this lookup may have no hint info, or different hint info.
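
For illustration, a minimal sketch of the broken scenario (assuming a local `SparkSession`; the new regression test in `BroadcastJoinSuite` below exercises the same pattern):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Disable size-based auto-broadcast so only an explicit hint can trigger a broadcast join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

// The cache entry is created from a plan that carries a broadcast hint ...
broadcast(df2).cache()

// ... but this lookup plan has no hint, so the cached data should not bring the
// hint with it: the join below must not be planned as a broadcast join.
df1.join(df2, Seq("key"), "inner").explain()
```

Before this fix, the cache entry created from the hinted plan leaked its broadcast hint into the unhinted lookup plan, turning the join into a broadcast join.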

## How was this patch tested?

A new test in `BroadcastJoinSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #20394 from cloud-fan/cache.
cloud-fan authored and gatorsmile committed Jan 27, 2018
1 parent 0737449 commit 5b5447c
Showing 5 changed files with 94 additions and 59 deletions.
@@ -169,14 +169,17 @@ class CacheManager extends Logging {
/** Replaces segments of the given logical plan with cached versions where possible. */
def useCachedData(plan: LogicalPlan): LogicalPlan = {
val newPlan = plan transformDown {
+ // Do not lookup the cache by hint node. Hint node is special, we should ignore it when
+ // canonicalizing plans, so that plans which are same except hint can hit the same cache.
+ // However, we also want to keep the hint info after cache lookup. Here we skip the hint
+ // node, so that the returned caching plan won't replace the hint node and drop the hint info
+ // from the original plan.
+ case hint: ResolvedHint => hint
+
case currentFragment =>
- lookupCachedData(currentFragment).map { cached =>
-   val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
-   currentFragment match {
-     case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints)
-     case _ => cachedPlan
-   }
- }.getOrElse(currentFragment)
+ lookupCachedData(currentFragment)
+   .map(_.cachedRepresentation.withOutput(currentFragment.output))
+   .getOrElse(currentFragment)
}

newPlan transformAllExpressions {
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
- import org.apache.spark.sql.catalyst.plans.logical.Statistics
+ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator
@@ -62,8 +62,8 @@ case class InMemoryRelation(
@transient child: SparkPlan,
tableName: Option[String])(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
- val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
- statsOfPlanToCache: Statistics = null)
+ val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+ statsOfPlanToCache: Statistics)
extends logical.LeafNode with MultiInstanceRelation {

override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -73,11 +73,16 @@ case class InMemoryRelation(
@transient val partitionStatistics = new PartitionStatistics(output)

override def computeStats(): Statistics = {
- if (batchStats.value == 0L) {
-   // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache
-   statsOfPlanToCache
+ if (sizeInBytesStats.value == 0L) {
+   // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
+   // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
+   // node. When we lookup the cache with a semantically same plan without hint info, the plan
+   // returned by cache lookup should not have hint info. If we lookup the cache with a
+   // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
+   // care of it and retain the hint info in the lookup input plan.
+   statsOfPlanToCache.copy(hints = HintInfo())
} else {
-   Statistics(sizeInBytes = batchStats.value.longValue)
+   Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
}
}

@@ -122,7 +127,7 @@ case class InMemoryRelation(
rowCount += 1
}

- batchStats.add(totalSize)
+ sizeInBytesStats.add(totalSize)

val stats = InternalRow.fromSeq(
columnBuilders.flatMap(_.columnStats.collectedStatistics))
@@ -144,7 +149,7 @@ case class InMemoryRelation(
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
- _cachedColumnBuffers, batchStats, statsOfPlanToCache)
+ _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
}

override def newInstance(): this.type = {
@@ -156,12 +161,12 @@ case class InMemoryRelation(
child,
tableName)(
_cachedColumnBuffers,
- batchStats,
+ sizeInBytesStats,
statsOfPlanToCache).asInstanceOf[this.type]
}

def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers

override protected def otherCopyArgs: Seq[AnyRef] =
- Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache)
+ Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
}
@@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val toBeCleanedAccIds = new HashSet[Long]

val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
- case i: InMemoryRelation => i.batchStats.id
+ case i: InMemoryRelation => i.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId1

val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
- case i: InMemoryRelation => i.batchStats.id
+ case i: InMemoryRelation => i.sizeInBytesStats.id
}.head
toBeCleanedAccIds += accId2

@@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(cached, expectedAnswer)

// Check that the right size was calculated.
- assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
+ assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
}

test("access primitive-type columns in CachedBatch without whole stage codegen") {
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
- import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
+ import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
private def testBroadcastJoin[T: ClassTag](
joinType: String,
forceBroadcast: Boolean = false): SparkPlan = {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
@@ -109,61 +110,89 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
}
}

test("broadcast hint is retained after using the cached data") {
test("SPARK-23192: broadcast hint should be retained after using the cached data") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
- df2.cache()
- val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
- val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
-   case b: BroadcastHashJoinExec => b
- }.size
- assert(numBroadCastHashJoin === 1)
+ try {
+   val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+   val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+   df2.cache()
+   val df3 = df1.join(broadcast(df2), Seq("key"), "inner")
+   val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
+     case b: BroadcastHashJoinExec => b
+   }.size
+   assert(numBroadCastHashJoin === 1)
+ } finally {
+   spark.catalog.clearCache()
+ }
}
}

test("SPARK-23214: cached data should not carry extra hint info") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
try {
val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
broadcast(df2).cache()

val df3 = df1.join(df2, Seq("key"), "inner")
val numCachedPlan = df3.queryExecution.executedPlan.collect {
case i: InMemoryTableScanExec => i
}.size
// df2 should be cached.
assert(numCachedPlan === 1)

val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect {
case b: BroadcastHashJoinExec => b
}.size
// df2 should not be broadcasted.
assert(numBroadCastHashJoin === 0)
} finally {
spark.catalog.clearCache()
}
}
}

test("broadcast hint isn't propagated after a join") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))

- val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+ val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value")
val df5 = df4.join(df3, Seq("key"), "inner")

- val plan =
-   EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+ val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)

assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
}
}

private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
- val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+ val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value")
val joined = df1.join(df, Seq("key"), "inner")

- val plan =
-   EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+ val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)

assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
}

test("broadcast hint programming API") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+ val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value")
val broadcasted = broadcast(df2)
- val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
-
- val cases = Seq(broadcasted.limit(2),
-   broadcasted.filter("value < 10"),
-   broadcasted.sample(true, 0.5),
-   broadcasted.distinct(),
-   broadcasted.groupBy("value").agg(min($"key").as("key")),
-   // except and intersect are semi/anti-joins which won't return more data then
-   // their left argument, so the broadcast hint should be propagated here
-   broadcasted.except(df3),
-   broadcasted.intersect(df3))
+ val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value")
+
+ val cases = Seq(
+   broadcasted.limit(2),
+   broadcasted.filter("value < 10"),
+   broadcasted.sample(true, 0.5),
+   broadcasted.distinct(),
+   broadcasted.groupBy("value").agg(min($"key").as("key")),
+   // except and intersect are semi/anti-joins which won't return more data then
+   // their left argument, so the broadcast hint should be propagated here
+   broadcasted.except(df3),
+   broadcasted.intersect(df3))

cases.foreach(assertBroadcastJoin)
}
@@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("Shouldn't change broadcast join buildSide if user clearly specified") {

withTempView("t1", "t2") {
spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
.createTempView("t2")
Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
@@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("Shouldn't bias towards build right if user didn't specify") {

withTempView("t1", "t2") {
spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
.createTempView("t2")
Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")

val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
