From efb8cd36de0e0b96a8efe855b6243c1f553ce195 Mon Sep 17 00:00:00 2001
From: jinxing
Date: Sat, 22 Jul 2017 12:22:36 +0800
Subject: [PATCH] Add a config to enable adaptive query execution only for the
 last shuffle.

---
 .../apache/spark/sql/internal/SQLConf.scala   |  7 ++++
 .../exchange/EnsureRequirements.scala         | 31 ++++++++++----
 .../exchange/ExchangeCoordinator.scala        |  9 +++++
 .../execution/exchange/ShuffleExchange.scala  |  2 +-
 .../execution/ExchangeCoordinatorSuite.scala  | 40 ++++++++++++++++++-
 5 files changed, 80 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 55558ca9f700c..01a029cc7cc74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -204,6 +204,11 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val ADAPTIVE_ONLY_FOR_LAST_SHUFFLE = buildConf("spark.sql.adaptiveOnlyForLastShuffle")
+    .doc("When true, adaptive query execution is enabled only for the last shuffle.")
+    .booleanConf
+    .createWithDefault(false)
+
   val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS =
     buildConf("spark.sql.adaptive.minNumPostShufflePartitions")
       .internal()
@@ -970,6 +975,8 @@ class SQLConf extends Serializable with Logging {
 
   def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
 
+  def adaptiveOnlyForLastShuffle: Boolean = getConf(ADAPTIVE_ONLY_FOR_LAST_SHUFFLE)
+
   def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index b91d077442557..d6fd480d5b10f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -37,6 +37,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
 
   private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled
 
+  private def adaptiveOnlyForLastShuffle: Boolean = conf.adaptiveOnlyForLastShuffle
+
   private def minNumPostShufflePartitions: Option[Int] = {
     val minNumPostShufflePartitions = conf.minNumPostShufflePartitions
     if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
@@ -258,13 +260,28 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     operator.withNewChildren(children)
   }
 
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator @ ShuffleExchange(partitioning, child, _) =>
-      child.children match {
-        case ShuffleExchange(childPartitioning, baseChild, _)::Nil =>
-          if (childPartitioning.guarantees(partitioning)) child else operator
-        case _ => operator
+  def apply(plan: SparkPlan): SparkPlan = {
+    var ret = plan.transformUp {
+      case operator @ ShuffleExchange(partitioning, child, _) =>
+        child.children match {
+          case ShuffleExchange(childPartitioning, baseChild, _)::Nil =>
+            if (childPartitioning.guarantees(partitioning)) child else operator
+          case _ => operator
+        }
+      case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+    }
+    if (adaptiveOnlyForLastShuffle) {
+      var rootCoordinator: Option[ExchangeCoordinator] = None
+      ret = ret transformDown {
+        case operator @ ShuffleExchange(_, _, Some(coordinator)) =>
+          if (rootCoordinator.isEmpty) {
+            rootCoordinator = Some(coordinator)
+          } else if (coordinator != rootCoordinator.get) {
+            coordinator.deactivate
+          }
+          operator
       }
-      case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+    }
+    ret
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index deb2c24d0f16e..c13a504b906b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -100,6 +100,9 @@ class ExchangeCoordinator(
   // synchronized.
   @volatile private[this] var estimated: Boolean = false
 
+  // A boolean that indicates if this coordinator is active for adaptive query execution.
+  @volatile private[this] var active: Boolean = true
+
   /**
    * Registers a [[ShuffleExchange]] operator to this coordinator. This method is only allowed to
    * be called in the `doPrepare` method of a [[ShuffleExchange]] operator.
@@ -111,6 +114,12 @@
 
   def isEstimated: Boolean = estimated
 
+  def deactivate: Unit = {
+    active = false
+  }
+
+  def isActive: Boolean = active
+
   /**
    * Estimates partition start indices for post-shuffle partitions based on
    * mapOutputStatistics provided by all pre-shuffle stages.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index eebe6ad2e7944..86d045f44cc5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -116,7 +116,7 @@ case class ShuffleExchange(
     // Returns the same ShuffleRowRDD if this plan is used by multiple plans.
     if (cachedShuffleRDD == null) {
       cachedShuffleRDD = coordinator match {
-        case Some(exchangeCoordinator) =>
+        case Some(exchangeCoordinator) if exchangeCoordinator.isActive =>
          val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
          assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
          shuffleRDD
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 06bce9a2400e7..12a79d3b4303d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -252,7 +252,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
   def withSparkSession(
       f: SparkSession => Unit,
       targetNumPostShufflePartitions: Int,
-      minNumPostShufflePartitions: Option[Int]): Unit = {
+      minNumPostShufflePartitions: Option[Int],
+      adaptiveOnlyForLastShuffle: Boolean = false): Unit = {
     val sparkConf =
       new SparkConf(false)
         .setMaster("local[*]")
@@ -265,6 +266,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
         .set(
           SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key,
           targetNumPostShufflePartitions.toString)
+        .set("spark.sql.adaptiveOnlyForLastShuffle",
+          if (adaptiveOnlyForLastShuffle) "true" else "false")
     minNumPostShufflePartitions match {
       case Some(numPartitions) =>
         sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, numPartitions.toString)
@@ -480,4 +483,39 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       withSparkSession(test, 6144, minNumPostShufflePartitions)
     }
   }
+
+  test("Enable adaptive query execution only for the last shuffle.") {
+    val test = {
+      spark: SparkSession =>
+        val df = spark
+          .range(0, 1000, 1, numInputPartitions)
+          .selectExpr("id % 20 as key", "id as value")
+          .groupBy("key").sum("value").toDF("newId", "newValue")
+          .selectExpr("newId % 2 as key", "newValue")
+          .groupBy("key").sum("newValue")
+
+        // Check the answer first.
+        val expected = spark.range(0, 1000)
+          .selectExpr("id % 2 as key", "id as value")
+          .groupBy("key").sum("value").collect()
+        checkAnswer(df, expected)
+
+        // Check that only the coordinator of the last shuffle is still active.
+        var activeCoordinatorOpt: Option[ExchangeCoordinator] = None
+        var nonActiveCoordinatorOpt: Option[ExchangeCoordinator] = None
+        df.queryExecution.executedPlan transformDown {
+          case operator @ ShuffleExchange(_, _, Some(coordinator)) =>
+            if (coordinator.isActive) {
+              assert(activeCoordinatorOpt.isEmpty && nonActiveCoordinatorOpt.isEmpty)
+              activeCoordinatorOpt = Some(coordinator)
+            } else {
+              assert(activeCoordinatorOpt.isDefined && nonActiveCoordinatorOpt.isEmpty)
+              nonActiveCoordinatorOpt = Some(coordinator)
+            }
+            operator
+        }
+        assert(activeCoordinatorOpt.isDefined && nonActiveCoordinatorOpt.isDefined)
+    }
+    withSparkSession(test, 2000, None, true)
+  }
 }
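
Usage note (illustrative, not part of the patch): the sketch below shows how the new flag might be exercised once this change is applied on a Spark 2.x build. The keys "spark.sql.adaptive.enabled" and "spark.sql.adaptiveOnlyForLastShuffle" come from SQLConf and the diff above; the object name, app name, and the example query (which mirrors the new test) are assumptions made for illustration only.

  // Illustrative sketch only; assumes a Spark 2.x build with this patch applied.
  // Object and app names are hypothetical; the config keys come from the diff above.
  import org.apache.spark.sql.SparkSession

  object AdaptiveLastShuffleExample {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder()
        .master("local[*]")
        .appName("adaptive-only-for-last-shuffle")
        // Adaptive execution must be on for any ExchangeCoordinator to be installed.
        .config("spark.sql.adaptive.enabled", "true")
        // New flag from this patch: only the last (root-most) shuffle keeps an active coordinator.
        .config("spark.sql.adaptiveOnlyForLastShuffle", "true")
        .getOrCreate()

      // Two aggregations produce two shuffles; with the flag set, only the final
      // shuffle has its post-shuffle partitions adjusted at runtime.
      val df = spark.range(0, 1000)
        .selectExpr("id % 20 as key", "id as value")
        .groupBy("key").sum("value").toDF("newId", "newValue")
        .selectExpr("newId % 2 as key", "newValue")
        .groupBy("key").sum("newValue")

      df.collect()
      spark.stop()
    }
  }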