Support coalesce table cache stage partitions
ulysses-you committed Mar 29, 2023
1 parent 8982cee commit 229a57c
Showing 19 changed files with 633 additions and 156 deletions.
6 changes: 4 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -1421,7 +1421,8 @@ def cache(self) -> "DataFrame":
>>> df.explain()
== Physical Plan ==
InMemoryTableScan ...
AdaptiveSparkPlan isFinalPlan=false
+- InMemoryTableScan ...
"""
self.is_cached = True
self._jdf.cache()
@@ -1463,7 +1464,8 @@ def persist(
>>> df.explain()
== Physical Plan ==
InMemoryTableScan ...
AdaptiveSparkPlan isFinalPlan=false
+- InMemoryTableScan ...
Persists the data in the disk by specifying the storage level.
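
The sketch below is not part of this diff; it is a minimal illustration of the new explain() shape that the updated cache()/persist() docstrings describe: with adaptive execution enabled, a cached DataFrame is planned under an AdaptiveSparkPlan node with the InMemoryTableScan beneath it. The session settings, data, and object name are illustrative only.

import org.apache.spark.sql.SparkSession

object CachedExplainSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()

    val df = spark.range(1000).toDF("id")
    df.cache()
    // Expected to print something like:
    //   == Physical Plan ==
    //   AdaptiveSparkPlan isFinalPlan=false
    //   +- InMemoryTableScan ...
    df.explain()
    spark.stop()
  }
}
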
@@ -847,6 +847,15 @@ object SQLConf {
.stringConf
.createOptional

val COALESCE_CACHE_PARTITIONS_ENABLED =
buildConf("spark.sql.adaptive.coalesceCachePartitions.enabled")
.doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, Spark will coalesce " +
"contiguous table cache partitions according to the target size (specified by " +
s"'${ADVISORY_PARTITION_SIZE_IN_BYTES.key}'), to avoid too many small tasks.")
.version("3.5.0")
.booleanConf
.createWithDefault(true)

val SUBEXPRESSION_ELIMINATION_ENABLED =
buildConf("spark.sql.subexpressionElimination.enabled")
.internal()
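
As a quick, hedged illustration (not part of this diff), the new flag pairs with the existing AQE settings; the values below are arbitrary examples and assume an existing SparkSession named `spark`.

// Adaptive query execution must be on for the new flag to have any effect.
spark.conf.set("spark.sql.adaptive.enabled", "true")
// Added by this change; defaults to true.
spark.conf.set("spark.sql.adaptive.coalesceCachePartitions.enabled", "true")
// Target size used when coalescing contiguous table cache partitions.
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "64MB")
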
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import scala.reflect.ClassTag

import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext}
import org.apache.spark.rdd.RDD

/**
* The [[Partition]] used by [[CachedRDD]].
*/
case class CachedRDDPartition(
index: Int,
originalPartitions: Array[Partition],
@transient originalPreferredLocations: Seq[String]) extends Partition

/**
* Wraps the real cached RDD and exposes coalesced partitions.
*
* @param prev The real cached RDD
* @param partitionSpecs the coalesced partition specs
*/
class CachedRDD[T: ClassTag](
@transient var prev: RDD[T],
partitionSpecs: Seq[CoalescedPartitionSpec])
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies

override protected def getPartitions: Array[Partition] = {
Array.tabulate[Partition](partitionSpecs.length) { i =>
val spec = partitionSpecs(i)
val originalPartitions = spec.startReducerIndex.until(spec.endReducerIndex)
.map(prev.partitions).toArray
val originalPreferredLocations = originalPartitions.flatMap(prev.preferredLocations)
.distinct.toSeq
CachedRDDPartition(i, originalPartitions, originalPreferredLocations)
}
}

override protected def getPreferredLocations(split: Partition): Seq[String] = {
split.asInstanceOf[CachedRDDPartition].originalPreferredLocations
}

override def compute(split: Partition, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CachedRDDPartition].originalPartitions.iterator.flatMap { partition =>
firstParent[T].iterator(partition, context)
}
}

override def getDependencies: Seq[Dependency[_]] = {
Seq(new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
partitions(id).asInstanceOf[CachedRDDPartition].originalPartitions.map(_.index).toSeq
})
}

override def clearDependencies(): Unit = {
super.clearDependencies()
prev = null
}
}
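
A minimal, self-contained sketch (not part of this diff) of the grouping CachedRDD performs: each CoalescedPartitionSpec [startReducerIndex, endReducerIndex) selects a contiguous run of the parent RDD's partition indices and fuses it into one coalesced partition, whose compute() iterates the originals back to back. CoalescedSpec and the object name below are stand-ins, not the real classes.

object CachedRddGroupingSketch {
  // Stand-in for org.apache.spark.sql.execution.CoalescedPartitionSpec.
  final case class CoalescedSpec(startReducerIndex: Int, endReducerIndex: Int)

  // Maps each spec to the parent partition indices it covers.
  def groupIndices(specs: Seq[CoalescedSpec]): Seq[Seq[Int]] =
    specs.map(s => (s.startReducerIndex until s.endReducerIndex).toSeq)

  def main(args: Array[String]): Unit = {
    // Five cached partitions coalesced into two tasks: [0, 1, 2] and [3, 4].
    println(groupIndices(Seq(CoalescedSpec(0, 3), CoalescedSpec(3, 5))))
    // List(List(0, 1, 2), List(3, 4))
  }
}
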
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{CachedRDD, CoalescedPartitionSpec, ShufflePartitionSpec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* A wrapper over a table cache query stage that follows the given partition arrangement.
* RDD cache blocks are stored at the partition level, so a skewed partition cannot be split.
* The presence of [[AQECacheReadExec]] means that some partitions can be coalesced.
*
* @param child It should always be [[TableCacheQueryStageExec]].
* @param partitionSpecs The partition specs that define the arrangement; requires at least one
* partition.
*/
case class AQECacheReadExec(
child: SparkPlan,
partitionSpecs: Seq[ShufflePartitionSpec]) extends AQERead {
assert(partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec]))

override def outputPartitioning: Partitioning = {
outputPartitionWithCoalesced(partitionSpecs.length)
}

override lazy val metrics: Map[String, SQLMetric] =
Map("numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"))

private def updateMetrics(): Unit = {
metrics("numPartitions") += partitionSpecs.length

val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
}

override def stringArgs: Iterator[Any] = Iterator("coalesced")

override protected def doExecute(): RDD[InternalRow] = {
updateMetrics()
val rdd = child.execute()
new CachedRDD(rdd, partitionSpecs.asInstanceOf[Seq[CoalescedPartitionSpec]])
}

override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
updateMetrics()
val rdd = child.executeColumnar()
new CachedRDD(rdd, partitionSpecs.asInstanceOf[Seq[CoalescedPartitionSpec]])
}

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
}
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan, UnaryExecNode}

abstract class AQERead extends UnaryExecNode {
def child: SparkPlan
def partitionSpecs: Seq[ShufflePartitionSpec]

assert(partitionSpecs.nonEmpty, s"${getClass.getSimpleName} requires at least one partition")

override final def output: Seq[Attribute] = child.output
override final def supportsColumnar: Boolean = child.supportsColumnar
override final def supportsRowBased: Boolean = child.supportsRowBased

def outputPartitionWithCoalesced(numPartitions: Int): Partitioning = {
// For coalesced shuffle read, the data distribution is not changed, only the number of
// partitions is changed.
child.outputPartitioning match {
case h: HashPartitioning =>
CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = numPartitions))
case r: RangePartitioning =>
CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = numPartitions))
// This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
// `RoundRobinPartitioning` but we don't need to retain the number of partitions.
case r: RoundRobinPartitioning =>
r.copy(numPartitions = numPartitions)
case other @ SinglePartition =>
throw new IllegalStateException(
"Unexpected partitioning for coalesced shuffle read: " + other)
case _ =>
// Spark plugins may have custom partitioning and may replace this operator
// during the postStageOptimization phase, so return UnknownPartitioning here
// rather than throw an exception
UnknownPartitioning(numPartitions)
}
}
}
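
A hedged sketch (not part of this diff) of what outputPartitionWithCoalesced preserves: coalescing keeps the data distribution, so a hash or range partitioning retains its expressions and only numPartitions changes. The attribute, partition counts, and object name are illustrative.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.LongType

object CoalescedPartitioningSketch {
  def main(args: Array[String]): Unit = {
    val key = AttributeReference("id", LongType)()
    val before = HashPartitioning(Seq(key), 200)
    // Equivalent to what outputPartitionWithCoalesced(10) would return for this child.
    val after = before.copy(numPartitions = 10)
    println(s"${before.numPartitions} -> ${after.numPartitions}, " +
      s"same expressions: ${before.expressions == after.expressions}")
  }
}
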
@@ -21,9 +21,8 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -39,19 +38,13 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
*/
case class AQEShuffleReadExec private(
child: SparkPlan,
partitionSpecs: Seq[ShufflePartitionSpec]) extends UnaryExecNode {
assert(partitionSpecs.nonEmpty, s"${getClass.getSimpleName} requires at least one partition")

partitionSpecs: Seq[ShufflePartitionSpec]) extends AQERead {
// If this is to read shuffle files locally, then all partition specs should be
// `PartialMapperPartitionSpec`.
if (partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec])) {
assert(partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec]))
}

override def supportsColumnar: Boolean = child.supportsColumnar

override def output: Seq[Attribute] = child.output

override lazy val outputPartitioning: Partitioning = {
// If it is a local shuffle read with one mapper per task, then the output partitioning is
// the same as the plan before shuffle.
@@ -71,26 +64,7 @@ case class AQEShuffleReadExec private(
throw new IllegalStateException("operating on canonicalization plan")
}
} else if (isCoalescedRead) {
// For coalesced shuffle read, the data distribution is not changed, only the number of
// partitions is changed.
child.outputPartitioning match {
case h: HashPartitioning =>
CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = partitionSpecs.length))
case r: RangePartitioning =>
CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length))
// This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
// `RoundRobinPartitioning` but we don't need to retain the number of partitions.
case r: RoundRobinPartitioning =>
r.copy(numPartitions = partitionSpecs.length)
case other @ SinglePartition =>
throw new IllegalStateException(
"Unexpected partitioning for coalesced shuffle read: " + other)
case _ =>
// Spark plugins may have custom partitioning and may replace this operator
// during the postStageOptimization phase, so return UnknownPartitioning here
// rather than throw an exception
UnknownPartitioning(partitionSpecs.length)
}
outputPartitionWithCoalesced(partitionSpecs.length)
} else {
UnknownPartitioning(partitionSpecs.length)
}
@@ -139,7 +139,8 @@ case class AdaptiveSparkPlanExec(
CoalesceShufflePartitions(context.session),
// `OptimizeShuffleWithLocalRead` needs to make use of 'AQEShuffleReadExec.partitionSpecs'
// added by `CoalesceShufflePartitions`, and must be executed after it.
OptimizeShuffleWithLocalRead
OptimizeShuffleWithLocalRead,
CoalesceCachePartitions(context.session)
)

// This rule is stateful as it maintains the codegen stage ID. We can't create a fresh one every
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan, UnaryExecNode, UnionExec}
import org.apache.spark.sql.internal.SQLConf

/**
* A rule that coalesces table cache partitions based on their statistics, which avoids
* many small tasks that hurt performance.
*/
case class CoalesceCachePartitions(session: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.COALESCE_CACHE_PARTITIONS_ENABLED)) {
return plan
}

val coalesceGroups = collectCoalesceGroups(plan)
val groups = coalesceGroups.map { tableCacheStages =>
val stageIds = tableCacheStages.map(_.id)
val bytesByPartitionIds = tableCacheStages.map(_.outputStats().map(_.bytesByPartitionId))
val inputPartitionSpecs = Seq.fill(bytesByPartitionIds.length)(None)
(stageIds,
conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
bytesByPartitionIds,
inputPartitionSpecs,
s"For table cache stage(${stageIds.mkString(", ")})")
}
val specsMap = ShufflePartitionsUtil.coalescePartitionsByGroup(
groups, session.sparkContext.defaultParallelism)
if (specsMap.nonEmpty) {
updateCacheReads(plan, specsMap)
} else {
plan
}
}

private def updateCacheReads(
plan: SparkPlan,
specsMap: Map[Int, Seq[ShufflePartitionSpec]]): SparkPlan = plan match {
case stage: TableCacheQueryStageExec if specsMap.contains(stage.id) =>
AQECacheReadExec(stage, specsMap(stage.id))
case other => other.mapChildren(updateCacheReads(_, specsMap))
}

private def collectCoalesceGroups(
plan: SparkPlan): Seq[Seq[TableCacheQueryStageExec]] = plan match {
case unary: UnaryExecNode => collectCoalesceGroups(unary.child)
case union: UnionExec => union.children.flatMap(collectCoalesceGroups)
case p if p.collectLeaves().forall(_.isInstanceOf[TableCacheQueryStageExec]) =>
collectTableCacheStages(p) :: Nil
case _ => Seq.empty
}

private def collectTableCacheStages(plan: SparkPlan): Seq[TableCacheQueryStageExec] = plan match {
case tableCacheStage: TableCacheQueryStageExec => Seq(tableCacheStage)
case _ => plan.children.flatMap(collectTableCacheStages)
}
}
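
A hedged end-to-end sketch (not part of this diff) of when the rule is expected to apply: a cached relation with many small partitions, read under AQE with the new flag on, should end up with an AQECacheReadExec(coalesced) above the table cache stage in the final adaptive plan. Session settings, data volume, the object name, and the exact plan string are illustrative.

import org.apache.spark.sql.SparkSession

object CoalesceCachePartitionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalesceCachePartitions.enabled", "true")
      .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "8MB")
      .getOrCreate()

    // 100 tiny cached partitions, far below the 8MB advisory size.
    val cached = spark.range(0L, 100000L, 1L, 100).toDF("id").cache()
    cached.count()                // materialize the cache

    val q = cached.filter("id % 2 = 0")
    q.collect()                   // run the adaptive plan
    q.explain()                   // expected to contain an AQECacheReadExec coalesced node
    spark.stop()
  }
}
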