
[SPARK-17170] [SQL] InMemoryTableScanExec driver-side partition pruning #14733

Closed
wants to merge 4 commits

Conversation

@pwoody commented Aug 20, 2016

What changes were proposed in this pull request?

After caching data, we have statistics that enable us to eagerly prune entire partitions before launching a query. This modifies the InMemoryTableScanExec to prune partitions before launching the tasks.
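
A minimal sketch of the idea (not the patch itself): compile the scan's filters into a predicate over the per-partition statistics, evaluate it on the driver, and only launch tasks for partitions that may contain matching rows. The partitionStats map and statsPredicate function below are hypothetical stand-ins for the statistics InMemoryRelation collects while caching and for the compiled filter.

// Sketch of the driver-side pruning idea, not the actual patch code.
// `partitionStats` and `statsPredicate` are hypothetical stand-ins.
import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow

def pruneOnDriver[T](
    cachedBuffers: RDD[T],
    partitionStats: Map[Int, InternalRow],
    statsPredicate: InternalRow => Boolean): RDD[T] = {
  // Keep only partitions whose statistics say they may contain matching rows.
  val validPartitions = partitionStats.collect {
    case (index, stats) if statsPredicate(stats) => index
  }.toSet
  PartitionPruningRDD.create(cachedBuffers, validPartitions.contains)
}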

How was this patch tested?

Existing test suite, with a slight modification to scan over the data once as setup.

if (validPartitions.isEmpty) {
new EmptyRDD[CachedBatch](sparkContext)
} else {
new PartitionPruningRDD[CachedBatch](relation.cachedColumnBuffers,
Contributor:

There's PartitionPruningRDD.create, which will make this slightly cleaner; if you skip logging it's just
PartitionPruningRDD.create(relation.cachedColumnBuffers, validPartitions.contains).

Author (pwoody):

Cool, yeah, I used this (and removed the EmptyRDD constructor as well). I'd prefer to keep the logging in the function, though.
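
For reference, a sketch of the resulting call (not the exact patch code), assuming validPartitions is the Set[Int] of partition indices computed from the cached batch statistics:

// Sketch only: PartitionPruningRDD.create with the per-partition log message
// kept inside the predicate, as discussed above.
val prunedBuffers = PartitionPruningRDD.create(
  relation.cachedColumnBuffers,
  (index: Int) =>
    if (validPartitions.contains(index)) {
      true
    } else {
      logInfo(s"Skipping partition $index because all cached batches will be pruned")
      false
    })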

pwoody commented Aug 23, 2016

@rxin @davies @dongjoon-hyun mind taking a look?

if (validPartitions.contains(index)) {
true
} else {
logInfo(s"Skipping partition $index because all cached batches will be pruned")
Contributor:

Log at debug?

Author (pwoody):

Current executor-side pruning logging is done at INFO. I have no strong opinion either way, but this can get noisy with many partitions getting pruned.

Contributor:

+1 on logging at debug.

ash211 commented Aug 30, 2016

Jenkins, this is ok to test.

@@ -106,7 +106,7 @@ case class InMemoryRelation(

private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>
val cached = child.execute().mapPartitionsWithIndex { (i, rowIterator) =>
Contributor:

Here, I'd prefer to use TaskContext.get.getPartitionId to get the partition id. The problem with this change is that it's re-introducing closure cleaning overhead, which mapPartitionsInternal avoids.
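
A hedged sketch of that alternative, keeping mapPartitionsInternal (so no closure cleaning) and reading the partition id from the running task; the actual batch-building code is elided:

// Sketch only: read the partition id from the TaskContext instead of using
// mapPartitionsWithIndex, so mapPartitionsInternal can be kept.
import org.apache.spark.TaskContext

val cached = child.execute().mapPartitionsInternal { rowIterator =>
  val partitionIndex = TaskContext.get().partitionId()
  // ... build the CachedBatches for this partition, recording their statistics
  //     under partitionIndex, exactly as before ...
  rowIterator // placeholder so the sketch type-checks; the real code yields CachedBatches
}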

@JoshRosen

This looks pretty cool! I'll come back later tonight / tomorrow to test this out and do a more detailed review pass.

SparkQA commented Aug 30, 2016

Test build #64611 has finished for PR 14733 at commit 1631973.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Aug 30, 2016

Test build #64654 has finished for PR 14733 at commit 610d85d.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Aug 30, 2016

Test build #64659 has finished for PR 14733 at commit 7bf5bb9.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@JoshRosen

This looks like a legitimate test failure:

[info] - A cached table preserves the partitioning and ordering of its cached SparkPlan *** FAILED *** (1 second, 305 milliseconds)
[info]   Exception thrown while executing query:
[info]   == Parsed Logical Plan ==
[info]   'Project [*]
[info]   +- 'Join Inner, ('t1.key = 't2.a)
[info]      :- 'UnresolvedRelation `t1`, t1
[info]      +- 'UnresolvedRelation `t2`, t2
[info]   
[info]   == Analyzed Logical Plan ==
[info]   key: int, value: string, a: int, b: int
[info]   Project [key#17868, value#17869, a#19094, b#19095]
[info]   +- Join Inner, (key#17868 = a#19094)
[info]      :- SubqueryAlias t1, `t1`
[info]      :  +- RepartitionByExpression [key#17868], 5
[info]      :     +- SerializeFromObject [assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true], top level non-flat input object).key AS key#17868, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true], top level non-flat input object).value, true) AS value#17869]
[info]      :        +- ExternalRDD [obj#17867]
[info]      +- SubqueryAlias t2, `t2`
[info]         +- RepartitionByExpression [a#19094], 5
[info]            +- SerializeFromObject [assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true], top level non-flat input object).a AS a#19094, assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true], top level non-flat input object).b AS b#19095]
[info]               +- ExternalRDD [obj#19093]
[info]   
[info]   == Optimized Logical Plan ==
[info]   Join Inner, (key#17868 = a#19094)
[info]   :- Filter isnotnull(key#17868)
[info]   :  +- InMemoryRelation [key#17868, value#17869], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas), t1
[info]   :        +- Exchange hashpartitioning(key#17868, 5)
[info]   :           +- *SerializeFromObject [assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true], top level non-flat input object).key AS key#17868, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true], top level non-flat input object).value, true) AS value#17869]
[info]   :              +- Scan ExternalRDDScan[obj#17867]
[info]   +- Filter isnotnull(a#19094)
[info]      +- InMemoryRelation [a#19094, b#19095], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas), t2
[info]            +- Exchange hashpartitioning(a#19094, 5)
[info]               +- *SerializeFromObject [assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true], top level non-flat input object).a AS a#19094, assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true], top level non-flat input object).b AS b#19095]
[info]                  +- Scan ExternalRDDScan[obj#19093]
[info]   
[info]   == Physical Plan ==
[info]   *SortMergeJoin [key#17868], [a#19094], Inner
[info]   :- *Sort [key#17868 ASC], false, 0
[info]   :  +- *Filter isnotnull(key#17868)
[info]   :     +- InMemoryTableScan [key#17868, value#17869], [isnotnull(key#17868)]
[info]   :           +- InMemoryRelation [key#17868, value#17869], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas), t1
[info]   :                 +- Exchange hashpartitioning(key#17868, 5)
[info]   :                    +- *SerializeFromObject [assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true], top level non-flat input object).key AS key#17868, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true], top level non-flat input object).value, true) AS value#17869]
[info]   :                       +- Scan ExternalRDDScan[obj#17867]
[info]   +- *Sort [a#19094 ASC], false, 0
[info]      +- *Filter isnotnull(a#19094)
[info]         +- InMemoryTableScan [a#19094, b#19095], [isnotnull(a#19094)]
[info]               +- InMemoryRelation [a#19094, b#19095], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas), t2
[info]                     +- Exchange hashpartitioning(a#19094, 5)
[info]                        +- *SerializeFromObject [assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true], top level non-flat input object).a AS a#19094, assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true], top level non-flat input object).b AS b#19095]
[info]                           +- Scan ExternalRDDScan[obj#19093]
[info]   == Exception ==
[info]   java.lang.IllegalArgumentException: Can't zip RDDs with unequal numbers of partitions: List(5, 3)
[info]   java.lang.IllegalArgumentException: Can't zip RDDs with unequal numbers of partitions: List(5, 3)
[info]      at org.apache.spark.rdd.ZippedPartitionsBaseRDD.getPartitions(ZippedPartitionsRDD.scala:57)
[info]      at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:248)
[info]      at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:246)
[info]      at scala.Option.getOrElse(Option.scala:121)
[info]      at org.apache.spark.rdd.RDD.partitions(RDD.scala:246)
[info]      at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:35)
[info]      at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:248)
[info]      at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:246)
[info]      at scala.Option.getOrElse(Option.scala:121)
[info]      at org.apache.spark.rdd.RDD.partitions(RDD.scala:246)
[info]      at org.apache.spark.SparkContext.runJob(SparkContext.scala:1924)
[info]      at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:912)
[info]      at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
[info]      at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
[info]      at org.apache.spark.rdd.RDD.withScope(RDD.scala:358)
[info]      at org.apache.spark.rdd.RDD.collect(RDD.scala:911)
[info]      at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:276)
[info]      at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2226)
[info]      at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
[info]      at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2576)
[info]      at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2225)
[info]      at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2230)
[info]      at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2230)
[info]      at org.apache.spark.sql.Dataset.withCallback(Dataset.scala:2589)
[info]      at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2230)
[info]      at org.apache.spark.sql.Dataset.collect(Dataset.scala:2206)
[info]      at org.apache.spark.sql.QueryTest$.checkAnswer(QueryTest.scala:389)
[info]      at org.apache.spark.sql.QueryTest.checkAnswer(QueryTest.scala:175)
[info]      at org.apache.spark.sql.QueryTest.checkAnswer(QueryTest.scala:186)
[info]      at org.apache.spark.sql.CachedTableSuite$$anonfun$25$$anonfun$apply$mcV$sp$10$$anonfun$apply$mcVI$sp$1.apply$mcV$sp(CachedTableSuite.scala:424)
[info]      at org.apache.spark.sql.test.SQLTestUtils$class.withTempView(SQLTestUtils.scala:155)

@@ -142,13 +171,16 @@ case class InMemoryTableScanExec(
val cachedBatchesToScan =
if (inMemoryPartitionPruningEnabled) {
cachedBatchIterator.filter { cachedBatch =>
val partitionFilter = newPredicate(
Contributor:

Why did you choose to nest this more deeply inside of the filter rather than leaving it where it was in mapPartitionsInternal? By moving it here, we'll wind up calling newPredicate once per batch rather than once per partition, thereby harming performance.

Author (pwoody):

It was to avoid the call when inMemoryPartitionPruning isn't enabled. That reasoning is pretty weak, though, given that pruning is enabled by default, and if it's disabled you pay extra cost elsewhere anyway. I'll move it back.
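
A hedged sketch of what moving it back looks like, following the code quoted in this thread: the predicate is generated once per partition inside mapPartitionsInternal, then evaluated once per cached batch.

// Sketch only: build the pruning predicate once per partition, not once per batch.
relation.cachedColumnBuffers.mapPartitionsInternal { cachedBatchIterator =>
  val partitionFilter = newPredicate(
    partitionFilters.reduceOption(And).getOrElse(Literal(true)),
    relation.partitionStatistics.schema)
  val cachedBatchesToScan =
    if (inMemoryPartitionPruningEnabled) {
      // Evaluated once per batch against that batch's statistics row.
      cachedBatchIterator.filter(batch => partitionFilter(batch.stats))
    } else {
      cachedBatchIterator
    }
  // ... decode the surviving batches into rows as before ...
  cachedBatchesToScan
}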

@JoshRosen

I believe that the test failure is happening because InMemoryTableScanExec's outputPartitioning says that it has n partitions yet the produced RDD may yield fewer than n partitions:

   // The cached version does not change the outputPartitioning of the original SparkPlan.
   override def outputPartitioning: Partitioning = relation.child.outputPartitioning

I think that this is breaking SortMergeJoinExec because it's assuming that both sides of the join have the same number of partitions:

left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
      val boundCondition: (InternalRow) => Boolean = {
        condition.map { cond =>
          newPredicate(cond, left.output ++ right.output)
        }.getOrElse {
          (r: InternalRow) => true
        }
      }

One possible fix would be to have InMemoryTableScanExec.outputPartitioning copy the parent partitioning and just adjust the number of partitions, but doing that could introduce performance regressions for certain types of joins. For example, in hash joins, even though both sides may be partitioned according to the same hash function and modulus, Spark won't realize this because the declared numbers of partitions differ. In other words, you might get lucky and wind up joining two co-partitioned cached datasets, and then have the co-partitioning get de-optimized because the partition pruning alters the partitioning.
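
A hedged sketch of that possible fix, assuming prunedPartitionCount is the number of partitions that survive driver-side pruning; the co-partitioning caveat above is exactly why this may not be desirable.

// Sketch only: report the parent partitioning with the pruned partition count so
// operators like SortMergeJoinExec see a consistent number of partitions.
// `prunedPartitionCount` is a hypothetical value computed during planning; note
// the caveat above about losing co-partitioning for hash joins.
override def outputPartitioning: Partitioning =
  relation.child.outputPartitioning match {
    case h: HashPartitioning => h.copy(numPartitions = prunedPartitionCount)
    case other => other // non-hash partitionings would need similar treatment
  }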

It might be possible to work around this via a custom, SQL-specific version of ZippedPartitionsRDD that understands how to deal with missing partitions and zips according to partition ids (handling the gaps), but this will be tricky to get right for outer joins (where you still need to produce output for partitions that match pruned ones).

pwoody commented Nov 12, 2016

Closing stale PR

@pwoody closed this Nov 12, 2016