[SPARK-32332][SQL][3.0] Support columnar exchanges
### What changes were proposed in this pull request?
Backports SPARK-32332 to the 3.0 branch.

### Why are the changes needed?
Without this patch, plugins cannot replace exchanges with columnar versions when adaptive query execution (AQE) is enabled.
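
For context, a rough sketch of the kind of plugin rule this enables is shown below. It is not part of this patch: `MyColumnarShuffleExchangeExec` and `MyColumnarBroadcastExchangeExec` stand in for hypothetical plugin operators that would extend the new `ShuffleExchangeLike` and `BroadcastExchangeLike` traits, while `SparkSessionExtensions`/`ColumnarRule` is the existing extension point.

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}

// Hypothetical plugin rule: swap Spark's row-based exchanges for columnar ones.
// Because AQE now matches on ShuffleExchangeLike / BroadcastExchangeLike instead
// of the concrete exec classes, the replaced nodes are still recognized when
// query stages are created.
case class ReplaceExchanges(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case s: ShuffleExchangeExec =>
      MyColumnarShuffleExchangeExec(s.outputPartitioning, s.child)  // hypothetical, extends ShuffleExchangeLike
    case b: BroadcastExchangeExec =>
      MyColumnarBroadcastExchangeExec(b.mode, b.child)              // hypothetical, extends BroadcastExchangeLike
  }
}

// Registered through the standard extension point (spark.sql.extensions).
class ExampleColumnarPlugin extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectColumnar(session => new ColumnarRule {
      override def preColumnarTransitions: Rule[SparkPlan] = ReplaceExchanges(session)
    })
  }
}
```

With AQE on, these columnar rules are now also applied right after each query stage is created (see `postStageCreationRules` in the diff below), so the exchange at the root of a new stage can itself be columnar.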

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tests included.
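
Beyond the bundled tests, a hedged sketch of how the new code path can be exercised by hand (assuming the hypothetical `ExampleColumnarPlugin` from the sketch above is on the classpath):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.adaptive.enabled", "true")                        // turn AQE on
  .config("spark.sql.extensions", "com.example.ExampleColumnarPlugin") // hypothetical plugin class
  .getOrCreate()
import spark.implicits._

// A join against an aggregated copy forces a shuffle (and typically a broadcast),
// so AQE must build query stages from the exchanges the plugin may have replaced.
val df = (0L until 1000L).map(i => (i % 10, i)).toDF("key", "value")
val joined = df.join(df.groupBy("key").count(), "key")
joined.collect()

// With AQE enabled the executed plan is adaptive; printing it after execution
// shows the query stages that were created from the exchanges.
val plan = joined.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
println(plan.treeString)
```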

Closes #29310 from andygrove/backport-SPARK-32332.

Authored-by: Andy Grove <andygrove@nvidia.com>
Signed-off-by: Thomas Graves <tgraves@apache.org>
andygrove authored and tgravescs committed Jul 31, 2020
1 parent 2a38090 commit 7c91b15
Showing 10 changed files with 272 additions and 68 deletions.
@@ -100,7 +100,12 @@ case class AdaptiveSparkPlanExec(
// The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
// added by `CoalesceShufflePartitions`. So they must be executed after it.
OptimizeSkewedJoin(conf),
OptimizeLocalShuffleReader(conf),
OptimizeLocalShuffleReader(conf)
)

// A list of physical optimizer rules to be applied right after a new stage is created. The input
// plan to these rules has exchange as its root node.
@transient private val postStageCreationRules = Seq(
ApplyColumnarRulesAndInsertTransitions(conf, context.session.sessionState.columnarRules),
CollapseCodegenStages(conf)
)
@@ -227,7 +232,8 @@
}

// Run the final plan when there's no more unfinished stages.
currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
currentPhysicalPlan = applyPhysicalRules(
result.newPlan, queryStageOptimizerRules ++ postStageCreationRules)
isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
@@ -375,10 +381,22 @@ case class AdaptiveSparkPlanExec(
private def newQueryStage(e: Exchange): QueryStageExec = {
val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
val queryStage = e match {
case s: ShuffleExchangeExec =>
ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
case b: BroadcastExchangeExec =>
BroadcastQueryStageExec(currentStageId, b.copy(child = optimizedPlan))
case s: ShuffleExchangeLike =>
val newShuffle = applyPhysicalRules(
s.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
throw new IllegalStateException(
"Custom columnar rules cannot transform shuffle node to something else.")
}
ShuffleQueryStageExec(currentStageId, newShuffle)
case b: BroadcastExchangeLike =>
val newBroadcast = applyPhysicalRules(
b.withNewChildren(Seq(optimizedPlan)), postStageCreationRules)
if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
throw new IllegalStateException(
"Custom columnar rules cannot transform broadcast node to something else.")
}
BroadcastQueryStageExec(currentStageId, newBroadcast)
}
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, e)
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.vectorized.ColumnarBatch


/**
@@ -38,6 +39,8 @@ case class CustomShuffleReaderExec private(
partitionSpecs: Seq[ShufflePartitionSpec],
description: String) extends UnaryExecNode {

override def supportsColumnar: Boolean = child.supportsColumnar

override def output: Seq[Attribute] = child.output
override lazy val outputPartitioning: Partitioning = {
// If it is a local shuffle reader with one mapper per task, then the output partitioning is
@@ -47,9 +50,9 @@
partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
partitionSpecs.length) {
child match {
case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
case ShuffleQueryStageExec(_, s: ShuffleExchangeLike) =>
s.child.outputPartitioning
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike)) =>
s.child.outputPartitioning match {
case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
case other => other
@@ -64,18 +67,24 @@

override def stringArgs: Iterator[Any] = Iterator(description)

private var cachedShuffleRDD: RDD[InternalRow] = null
private def shuffleStage = child match {
case stage: ShuffleQueryStageExec => Some(stage)
case _ => None
}

override protected def doExecute(): RDD[InternalRow] = {
if (cachedShuffleRDD == null) {
cachedShuffleRDD = child match {
case stage: ShuffleQueryStageExec =>
new ShuffledRowRDD(
stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs.toArray)
case _ =>
throw new IllegalStateException("operating on canonicalization plan")
}
private lazy val shuffleRDD: RDD[_] = {
shuffleStage.map { stage =>
stage.shuffle.getShuffleRDD(partitionSpecs.toArray)
}.getOrElse {
throw new IllegalStateException("operating on canonicalized plan")
}
cachedShuffleRDD
}

override protected def doExecute(): RDD[InternalRow] = {
shuffleRDD.asInstanceOf[RDD[InternalRow]]
}

override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
}
}
@@ -78,10 +78,9 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
private def getPartitionSpecs(
shuffleStage: ShuffleQueryStageExec,
advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
val shuffleDep = shuffleStage.shuffle.shuffleDependency
val numReducers = shuffleDep.partitioner.numPartitions
val numMappers = shuffleStage.shuffle.numMappers
val numReducers = shuffleStage.shuffle.numPartitions
val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
val numMappers = shuffleDep.rdd.getNumPartitions
val splitPoints = if (numMappers == 0) {
Seq.empty
} else {
@@ -21,7 +21,7 @@ import scala.collection.mutable

import org.apache.commons.io.FileUtils

import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
@@ -197,7 +197,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
left.mapStats.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex is skewed, split it into " +
s"${skewSpecs.get.length} parts.")
@@ -212,7 +212,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
right.mapStats.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex is skewed, split it into " +
s"${skewSpecs.get.length} parts.")
@@ -287,15 +287,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
private object ShuffleStage {
def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
val sizes = s.mapStats.get.bytesByPartitionId
val mapStats = s.mapStats.get
val sizes = mapStats.bytesByPartitionId
val partitions = sizes.zipWithIndex.map {
case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
}
Some(ShuffleStageInfo(s, partitions))
Some(ShuffleStageInfo(s, mapStats, partitions))

case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _)
if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
val sizes = s.mapStats.get.bytesByPartitionId
val mapStats = s.mapStats.get
val sizes = mapStats.bytesByPartitionId
val partitions = partitionSpecs.map {
case spec @ CoalescedPartitionSpec(start, end) =>
var sum = 0L
@@ -308,14 +310,15 @@ private object ShuffleStage {
case other => throw new IllegalArgumentException(
s"Expect CoalescedPartitionSpec but got $other")
}
Some(ShuffleStageInfo(s, partitions))
Some(ShuffleStageInfo(s, mapStats, partitions))

case _ => None
}
}

private case class ShuffleStageInfo(
shuffleStage: ShuffleQueryStageExec,
mapStats: MapOutputStatistics,
partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])

private class SkewDesc {
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils

/**
@@ -80,6 +81,11 @@ abstract class QueryStageExec extends LeafExecNode {

def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec

/**
* Returns the runtime statistics after stage materialization.
*/
def getRuntimeStatistics: Statistics

/**
* Compute the statistics of the query stage if executed, otherwise None.
*/
@@ -107,6 +113,8 @@ abstract class QueryStageExec extends LeafExecNode {

protected override def doPrepare(): Unit = plan.prepare()
protected override def doExecute(): RDD[InternalRow] = plan.execute()
override def supportsColumnar: Boolean = plan.supportsColumnar
protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
override def doCanonicalize(): SparkPlan = plan.canonicalized

@@ -135,15 +143,15 @@ abstract class QueryStageExec extends LeafExecNode {
}

/**
* A shuffle query stage whose child is a [[ShuffleExchangeExec]] or [[ReusedExchangeExec]].
* A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
*/
case class ShuffleQueryStageExec(
override val id: Int,
override val plan: SparkPlan) extends QueryStageExec {

@transient val shuffle = plan match {
case s: ShuffleExchangeExec => s
case ReusedExchangeExec(_, s: ShuffleExchangeExec) => s
case s: ShuffleExchangeLike => s
case ReusedExchangeExec(_, s: ShuffleExchangeLike) => s
case _ =>
throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
}
@@ -176,18 +184,20 @@ case class ShuffleQueryStageExec(
val stats = resultOption.get.asInstanceOf[MapOutputStatistics]
Option(stats)
}

override def getRuntimeStatistics: Statistics = shuffle.runtimeStatistics
}

/**
* A broadcast query stage whose child is a [[BroadcastExchangeExec]] or [[ReusedExchangeExec]].
* A broadcast query stage whose child is a [[BroadcastExchangeLike]] or [[ReusedExchangeExec]].
*/
case class BroadcastQueryStageExec(
override val id: Int,
override val plan: SparkPlan) extends QueryStageExec {

@transient val broadcast = plan match {
case b: BroadcastExchangeExec => b
case ReusedExchangeExec(_, b: BroadcastExchangeExec) => b
case b: BroadcastExchangeLike => b
case ReusedExchangeExec(_, b: BroadcastExchangeLike) => b
case _ =>
throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
}
@@ -224,6 +234,8 @@ case class BroadcastQueryStageExec(
broadcast.relationFuture.cancel(true)
}
}

override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
}

object BroadcastQueryStageExec {
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}

/**
* A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
@@ -35,13 +35,13 @@ case class SimpleCost(value: Long) extends Cost {

/**
* A simple implementation of [[CostEvaluator]], which counts the number of
* [[ShuffleExchangeExec]] nodes in the plan.
* [[ShuffleExchangeLike]] nodes in the plan.
*/
object SimpleCostEvaluator extends CostEvaluator {

override def evaluateCost(plan: SparkPlan): Cost = {
val cost = plan.collect {
case s: ShuffleExchangeExec => s
case s: ShuffleExchangeLike => s
}.size
SimpleCost(cost)
}
@@ -29,6 +29,7 @@ import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.HashedRelation
@@ -37,16 +38,43 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.{SparkFatalException, ThreadUtils}

/**
* Common trait for all broadcast exchange implementations to facilitate pattern matching.
*/
trait BroadcastExchangeLike extends Exchange {

/**
* The broadcast job group ID
*/
def runId: UUID = UUID.randomUUID

/**
* The asynchronous job that prepares the broadcast relation.
*/
def relationFuture: Future[broadcast.Broadcast[Any]]

/**
* For registering callbacks on `relationFuture`.
* Note that calling this method may not start the execution of broadcast job.
*/
def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]

/**
* Returns the runtime statistics after broadcast materialization.
*/
def runtimeStatistics: Statistics
}

/**
* A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
* a transformed SparkPlan.
*/
case class BroadcastExchangeExec(
mode: BroadcastMode,
child: SparkPlan) extends Exchange {
child: SparkPlan) extends BroadcastExchangeLike {
import BroadcastExchangeExec._

private[sql] val runId: UUID = UUID.randomUUID
override val runId: UUID = UUID.randomUUID

override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -60,6 +88,11 @@ case class BroadcastExchangeExec(
BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
}

override def runtimeStatistics: Statistics = {
val dataSize = metrics("dataSize").value
Statistics(dataSize)
}

@transient
private lazy val promise = Promise[broadcast.Broadcast[Any]]()

@@ -68,13 +101,14 @@ case class BroadcastExchangeExec(
* Note that calling this field will not start the execution of broadcast job.
*/
@transient
lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] = promise.future
override lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
promise.future

@transient
private val timeout: Long = SQLConf.get.broadcastTimeout

@transient
private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
try {
