From 57866c2d0476fbc5cb7a0aa5ed73116e8efb2e66 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 26 Feb 2016 13:36:31 -0800 Subject: [PATCH 1/5] WIP: reuse exchanges in a query --- .../spark/sql/catalyst/plans/QueryPlan.scala | 63 ++++++++++++- .../catalyst/plans/logical/LogicalPlan.scala | 55 +---------- .../plans/physical/broadcastMode.scala | 9 ++ .../spark/sql/execution/SparkPlanInfo.scala | 20 +++- .../aggregate/TungstenAggregate.scala | 4 + .../spark/sql/execution/basicOperators.scala | 3 + .../exchange/BroadcastExchange.scala | 10 +- .../sql/execution/exchange/Exchange.scala | 91 +++++++++++++++++++ .../execution/exchange/ShuffleExchange.scala | 29 +++--- .../sql/execution/joins/HashedRelation.scala | 15 ++- .../sql/execution/ui/SparkPlanGraph.scala | 21 +++-- .../apache/spark/sql/internal/SQLConf.scala | 6 ++ .../spark/sql/internal/SessionState.scala | 5 +- .../org/apache/spark/sql/DataFrameSuite.scala | 32 ++++++- .../spark/sql/execution/ExchangeSuite.scala | 72 ++++++++++++++- .../spark/sql/execution/PlannerSuite.scala | 52 ++++++++++- 16 files changed, 397 insertions(+), 90 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0e0453b517d92..e5888d5864702 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => def output: Seq[Attribute] @@ -237,4 +237,65 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } override def innerChildren: Seq[PlanType] = subqueries + + /** + * Cleaned copy of this query plan. + */ + protected lazy val cleaned: PlanType = this + + /** + * Returns true when the given query plan will return the same results as this query plan. + * + * Since its likely undecidable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually + * the same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * By default this function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and or expression id differences. Operators that + * can do better should override this function. + */ + def sameResult(plan: PlanType): Boolean = { + val cleanLeft = this.cleaned + val cleanRight = plan.cleaned + cleanLeft.getClass == cleanRight.getClass && + cleanLeft.children.size == cleanRight.children.size && + cleanLeft.cleanArgs == cleanRight.cleanArgs && + (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) + } + + /** + * All the attributes that are used for this plan. 
+ */ + lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output) + + private def cleanExpression(e: Expression): Expression = e match { + case a: Alias => + // As the root of the expression, Alias will always take an arbitrary exprId, we need + // to erase that for equality testing. + val cleanedExprId = + Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated) + BindReferences.bindReference(cleanedExprId.canonicalized, allAttributes, allowFailures = true) + case other => + BindReferences.bindReference(other.canonicalized, allAttributes, allowFailures = true) + } + + /** Args that have cleaned such that differences in expression id should not affect equality */ + protected lazy val cleanArgs: Seq[Any] = { + def cleanArg(arg: Any): Any = arg match { + case e: Expression => cleanExpression(e) + case other => other + } + + productIterator.map { + // Children are checked using sameResult above. + case tn: TreeNode[_] if containsChild(tn) => null + case e: Expression => cleanExpression(e) + case s: Option[_] => s.map(cleanArg) + case s: Seq[_] => s.map(cleanArg) + case m: Map[_, _] => m.mapValues(cleanArg) + case other => other + }.toSeq + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 31e775d60f950..3f851cf6ada73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -114,60 +114,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns true when the given logical plan will return the same results as this logical plan. - * - * Since its likely undecidable to generally determine if two given plans will produce the same - * results, it is okay for this function to return false, even if the results are actually - * the same. Such behavior will not affect correctness, only the application of performance - * enhancements like caching. However, it is not acceptable to return true if the results could - * possibly be different. - * - * By default this function performs a modified version of equality that is tolerant of cosmetic - * differences like attribute naming and or expression id differences. Logical operators that - * can do better should override this function. - */ - def sameResult(plan: LogicalPlan): Boolean = { - val cleanLeft = EliminateSubqueryAliases(this) - val cleanRight = EliminateSubqueryAliases(plan) - - cleanLeft.getClass == cleanRight.getClass && - cleanLeft.children.size == cleanRight.children.size && { - logDebug( - s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") - cleanRight.cleanArgs == cleanLeft.cleanArgs - } && - (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) - } - - /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { - val input = children.flatMap(_.output) - def cleanExpression(e: Expression) = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. 
- val cleanedExprId = - Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated) - BindReferences.bindReference(cleanedExprId, input, allowFailures = true) - case other => BindReferences.bindReference(other, input, allowFailures = true) - } - - productIterator.map { - // Children are checked using sameResult above. - case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e) - case s: Option[_] => s.map { - case e: Expression => cleanExpression(e) - case other => other - } - case s: Seq[_] => s.map { - case e: Expression => cleanExpression(e) - case other => other - } - case other => other - }.toSeq - } + override lazy val cleaned: LogicalPlan = EliminateSubqueryAliases(this) /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index e01f69f81359e..9dfdf4da78ff6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.InternalRow */ trait BroadcastMode { def transform(rows: Array[InternalRow]): Any + + /** + * Returns true iff this [[BroadcastMode]] generates the same result as `other`. + */ + def compatibleWith(other: BroadcastMode): Boolean } /** @@ -33,4 +38,8 @@ trait BroadcastMode { case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows + + override def compatibleWith(other: BroadcastMode): Boolean = { + this eq other + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 9019e5dfd66c6..e7275c91b8c46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.exchange.ReusedExchange import org.apache.spark.sql.execution.metric.SQLMetricInfo import org.apache.spark.util.Utils @@ -31,13 +32,26 @@ class SparkPlanInfo( val simpleString: String, val children: Seq[SparkPlanInfo], val metadata: Map[String, String], - val metrics: Seq[SQLMetricInfo]) + val metrics: Seq[SQLMetricInfo]) { + + override def hashCode(): Int = { + simpleString.hashCode + } + + override def equals(other: Any): Boolean = other match { + case o: SparkPlanInfo => + nodeName == o.nodeName && simpleString == o.simpleString && children == o.children + case _ => false + } +} private[sql] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { - - val children = plan.children ++ plan.subqueries + val children = plan match { + case ReusedExchange(_, child) => child :: Nil + case _ => plan.children ++ plan.subqueries + } val metrics = plan.metrics.toSeq.map { case (key, metric) => new SQLMetricInfo(metric.name.getOrElse(key), metric.id, Utils.getFormattedClassName(metric.param)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a46722963a6e1..ff70a8a3a8516 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -46,6 +46,10 @@ case class TungstenAggregate( require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) + override lazy val allAttributes: Seq[Attribute] = + child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b2f443c0e9ae6..2134a1fee9f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -156,6 +156,9 @@ case class Range( private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + // output attributes should not affect the results + override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements) + override def upstreams(): Seq[RDD[InternalRow]] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) .map(i => InternalRow(i)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala index 40cad4b1a7645..1a5c6a66c484e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -34,12 +34,16 @@ import org.apache.spark.util.ThreadUtils */ case class BroadcastExchange( mode: BroadcastMode, - child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = child.output + child: SparkPlan) extends Exchange { override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + override def sameResult(plan: SparkPlan): Boolean = plan match { + case p: BroadcastExchange => + mode.compatibleWith(p.mode) && child.sameResult(p.child) + case _ => false + } + @transient private val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala new file mode 100644 index 0000000000000..54c2d4e04b88a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +/** + * An interface for exchanges. + */ +abstract class Exchange extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +/** + * A wrapper for reused exchange to have different output, which is required to resolve the + * attributes in following plans. + */ +case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode { + + override def sameResult(plan: SparkPlan): Boolean = { + // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. + plan.sameResult(child) + } + + def doExecute(): RDD[InternalRow] = { + child.execute() + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.executeBroadcast() + } + + // Do not repeat the same tree in explain. + override def treeChildren: Seq[SparkPlan] = Nil +} + +/** + * Find out duplicated exchanges in the spark plan, then use the same exchange for all the + * references. + */ +private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { + + def apply(plan: SparkPlan): SparkPlan = { + if (!sqlContext.conf.exchangeReuseEnabled) { + return plan + } + // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. + val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]() + plan.transformUp { + case exchange: Exchange => + val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]()) + val samePlan = sameSchema.find { e => + exchange.sameResult(e) + } + if (samePlan.isDefined) { + // Keep the output of this exchange, the following plans require that to resolve + // attributes. 
+ val reused = ReusedExchange(exchange.output, samePlan.get) + reused + } else { + sameSchema += exchange + exchange + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index de21d7705e137..481b362c4f062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -38,7 +38,7 @@ import org.apache.spark.util.MutablePair case class ShuffleExchange( var newPartitioning: Partitioning, child: SparkPlan, - @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { + @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { override def nodeName: String = { val extraInfo = coordinator match { @@ -55,8 +55,6 @@ case class ShuffleExchange( override def outputPartitioning: Partitioning = newPartitioning - override def output: Seq[Attribute] = child.output - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) override protected def doPrepare(): Unit = { @@ -103,16 +101,25 @@ case class ShuffleExchange( new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } + /** + * Caches the created ShuffleRowRDD so we can reuse that. + */ + private var shuffleRDD: ShuffledRowRDD = null + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - coordinator match { - case Some(exchangeCoordinator) => - val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) - assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) - shuffleRDD - case None => - val shuffleDependency = prepareShuffleDependency() - preparePostShuffleRDD(shuffleDependency) + // Returns the same ShuffleRowRDD if this plan is used by multiple plans. 
+ if (shuffleRDD == null) { + shuffleRDD = coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } } + shuffleRDD } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 9a3cdaf697e2d..99f8841c8737b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -681,7 +681,7 @@ private[execution] case class HashedRelationBroadcastMode( keys: Seq[Expression], attributes: Seq[Attribute]) extends BroadcastMode { - def transform(rows: Array[InternalRow]): HashedRelation = { + override def transform(rows: Array[InternalRow]): HashedRelation = { val generator = UnsafeProjection.create(keys, attributes) if (canJoinKeyFitWithinLong) { LongHashedRelation(rows.iterator, generator, rows.length) @@ -689,5 +689,18 @@ private[execution] case class HashedRelationBroadcastMode( HashedRelation(rows.iterator, generator, rows.length) } } + + private lazy val canonicalizedKeys: Seq[Expression] = { + keys.map { e => + BindReferences.bindReference(e.canonicalized, attributes) + } + } + + override def compatibleWith(other: BroadcastMode): Boolean = other match { + case m: HashedRelationBroadcastMode => + canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && + canonicalizedKeys == m.canonicalizedKeys + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 12e586ada5976..1b211b8cbc3cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -62,7 +62,8 @@ private[sql] object SparkPlanGraph { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null) + val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]() + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges) new SparkPlanGraph(nodes, edges) } @@ -72,7 +73,8 @@ private[sql] object SparkPlanGraph { nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge], parent: SparkPlanGraphNode, - subgraph: SparkPlanGraphCluster): Unit = { + subgraph: SparkPlanGraphCluster, + exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = { planInfo.nodeName match { case "WholeStageCodegen" => val cluster = new SparkPlanGraphCluster( @@ -82,13 +84,14 @@ private[sql] object SparkPlanGraph { mutable.ArrayBuffer[SparkPlanGraphNode]()) nodes += cluster buildSparkPlanGraphNode( - planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges) case "InputAdapter" => - buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null) + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, 
exchanges) case "Subquery" if subgraph != null => // Subquery should not be included in WholeStageCodegen - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null) - case _ => + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges) + case name => val metrics = planInfo.metrics.map { metric => SQLPlanMetric(metric.name, metric.accumulatorId, SQLMetrics.getMetricParam(metric.metricParam)) @@ -101,12 +104,16 @@ private[sql] object SparkPlanGraph { } else { subgraph.nodes += node } + // ShuffleExchange or BroadcastExchange + if (name.endsWith("Exchange")) { + exchanges += planInfo -> node + } if (parent != null) { edges += SparkPlanGraphEdge(node.id, parent.id) } planInfo.children.foreach( - buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) + buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1d1e2884414d8..384102e5eaa5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -504,6 +504,10 @@ object SQLConf { " method", isPublic = false) + val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", + defaultValue = Some(true), + doc = "When true, the planner will try to find out duplicated exchanges and re-use them", + isPublic = false) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -564,6 +568,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) + def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f93a405f77fc7..19b979e4e32ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.{PreInsertCastAndRename, ResolveDataSource} -import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.execution.exchange.{ReuseExchange, EnsureRequirements} import org.apache.spark.sql.util.ExecutionListenerManager @@ -93,7 +93,8 @@ private[sql] class SessionState(ctx: SQLContext) { override val batches: Seq[Batch] = Seq( Batch("Subquery", Once, PlanSubqueries(ctx)), Batch("Add exchange", Once, EnsureRequirements(ctx)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)) + Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)), + Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a824759cb8955..be6733827c82d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,9 +25,9 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, OneRowRelation, Union} import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -1318,6 +1318,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } + test("reuse exchange") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { + val df = sqlContext.range(100) + val join = df.join(df, "id") + val plan = join.queryExecution.executedPlan + checkAnswer(join, df) + assert( + join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + assert(join.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 1) + val broadcasted = broadcast(join) + val join2 = join.join(broadcasted, "id").join(broadcasted, "id") + checkAnswer(join2, df) + assert( + join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + assert( + join2.queryExecution.executedPlan.collect { case e: BroadcastExchange => true }.size === 1) + assert( + join2.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 4) + } + } + + test("same result on aggregate") { + val df = sqlContext.range(100) + val agg1 = df.groupBy().count() + val agg2 = df.groupBy().count() + agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan) + } + test("SPARK-12512: support `.` in column name for withColumn()") { val df = Seq("a" -> "b").toDF("col.a", "col.b") checkAnswer(df.select(df("*")), Row("a", "b")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index d4f22de90c523..9f159d1e1e8a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.physical.SinglePartition -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} +import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -33,4 +35,70 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { input.map(Row.fromTuple) ) } + + test("compatible BroadcastMode") { + val mode1 = IdentityBroadcastMode + val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) + val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) + + assert(mode1.compatibleWith(mode1)) + 
assert(!mode1.compatibleWith(mode2)) + assert(!mode2.compatibleWith(mode1)) + assert(mode2.compatibleWith(mode2)) + assert(!mode2.compatibleWith(mode3)) + assert(mode3.compatibleWith(mode3)) + } + + test("BroadcastExchange same result") { + val df = sqlContext.range(10) + val plan = df.queryExecution.executedPlan + val output = plan.output + assert(plan sameResult plan) + + val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan) + val hashMode = HashedRelationBroadcastMode(true, output, plan.output) + val exchange2 = BroadcastExchange(hashMode, plan) + val hashMode2 = + HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) + val exchange3 = BroadcastExchange(hashMode2, plan) + val exchange4 = ReusedExchange(output, exchange3) + + assert(exchange1 sameResult exchange1) + assert(exchange2 sameResult exchange2) + assert(exchange3 sameResult exchange3) + assert(exchange4 sameResult exchange4) + + assert(!exchange1.sameResult(exchange2)) + assert(!exchange2.sameResult(exchange3)) + assert(!exchange3.sameResult(exchange4)) + assert(exchange4 sameResult exchange3) + } + + test("ShuffleExchange same result") { + val df = sqlContext.range(10) + val plan = df.queryExecution.executedPlan + val output = plan.output + assert(plan sameResult plan) + + val part1 = HashPartitioning(output, 1) + val exchange1 = ShuffleExchange(part1, plan) + val exchange2 = ShuffleExchange(part1, plan) + val part2 = HashPartitioning(output, 2) + val exchange3 = ShuffleExchange(part2, plan) + val part3 = HashPartitioning(output ++ output, 2) + val exchange4 = ShuffleExchange(part3, plan) + val exchange5 = ReusedExchange(output, exchange4) + + assert(exchange1 sameResult exchange1) + assert(exchange2 sameResult exchange2) + assert(exchange3 sameResult exchange3) + assert(exchange4 sameResult exchange4) + assert(exchange5 sameResult exchange5) + + assert(exchange1 sameResult exchange2) + assert(!exchange2.sameResult(exchange3)) + assert(!exchange3.sameResult(exchange4)) + assert(!exchange4.sameResult(exchange5)) + assert(exchange5 sameResult exchange4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f66e08e6ca5c8..b09b30bbaaacd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,20 +18,20 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange} +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.exchange.{ReuseExchange, EnsureRequirements, ReusedExchange, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ - 
class PlannerSuite extends SharedSQLContext { import testImplicits._ @@ -472,6 +472,50 @@ class PlannerSuite extends SharedSQLContext { } // --------------------------------------------------------------------------------------------- + + test("Reuse exchanges") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val shuffle = ShuffleExchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val inputPlan = SortMergeJoin( + Literal(1) :: Nil, + Literal(1) :: Nil, + None, + shuffle, + shuffle) + + val outputPlan = ReuseExchange(sqlContext).apply(inputPlan) + if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) { + fail(s"Should re-use the shuffle:\n$outputPlan") + } + if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) { + fail(s"Should have only one shuffle:\n$outputPlan") + } + + // nested exchanges + val inputPlan2 = SortMergeJoin( + Literal(1) :: Nil, + Literal(1) :: Nil, + None, + ShuffleExchange(finalPartitioning, inputPlan), + ShuffleExchange(finalPartitioning, inputPlan)) + + val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2) + if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) { + fail(s"Should re-use the two shuffles:\n$outputPlan2") + } + if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) { + fail(s"Should have only two shuffles:\n$outputPlan") + } + } } // Used for unit-testing EnsureRequirements From f6a7f5c40d4961569ffda11f6084ac58bb630c6c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 4 Mar 2016 11:45:05 -0800 Subject: [PATCH 2/5] fix style --- .../scala/org/apache/spark/sql/internal/SessionState.scala | 2 +- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 19b979e4e32ef..ce69722799bd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.{PreInsertCastAndRename, ResolveDataSource} -import org.apache.spark.sql.execution.exchange.{ReuseExchange, EnsureRequirements} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.util.ExecutionListenerManager diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b09b30bbaaacd..b59a2c0a7f1c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{execution, Row} import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.exchange.{ReuseExchange, EnsureRequirements, ReusedExchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From 3cc23f3d94e744a96a2f5c1de9d3ebc67ec917e8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 4 Mar 2016 14:44:03 -0800 Subject: [PATCH 3/5] fix flakyness of sameResult --- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index e5888d5864702..c98f747406cc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -276,15 +276,15 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // to erase that for equality testing. val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated) - BindReferences.bindReference(cleanedExprId.canonicalized, allAttributes, allowFailures = true) + BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true) case other => - BindReferences.bindReference(other.canonicalized, allAttributes, allowFailures = true) + BindReferences.bindReference(other, allAttributes, allowFailures = true) } /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { def cleanArg(arg: Any): Any = arg match { - case e: Expression => cleanExpression(e) + case e: Expression => cleanExpression(e).canonicalized case other => other } From 7df43ca78846966b0af8045b924d646d97505925 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 8 Mar 2016 15:35:25 -0800 Subject: [PATCH 4/5] address comments --- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 8 ++++---- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- .../org/apache/spark/sql/execution/SparkPlanInfo.scala | 2 ++ .../apache/spark/sql/execution/exchange/Exchange.scala | 9 +++++---- .../spark/sql/execution/exchange/ShuffleExchange.scala | 8 ++++---- .../apache/spark/sql/execution/ui/SparkPlanGraph.scala | 3 +-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 1 + 7 files changed, 18 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a8107a12e8179..638a223acc784 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -239,9 +239,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override def innerChildren: 
Seq[PlanType] = subqueries /** - * Cleaned copy of this query plan. + * Canonicalized copy of this query plan. */ - protected lazy val cleaned: PlanType = this + protected lazy val canonicalized: PlanType = this /** * Returns true when the given query plan will return the same results as this query plan. @@ -257,8 +257,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * can do better should override this function. */ def sameResult(plan: PlanType): Boolean = { - val cleanLeft = this.cleaned - val cleanRight = plan.cleaned + val cleanLeft = this.canonicalized + val cleanRight = plan.canonicalized cleanLeft.getClass == cleanRight.getClass && cleanLeft.children.size == cleanRight.children.size && cleanLeft.cleanArgs == cleanRight.cleanArgs && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 3f851cf6ada73..b32c7d0fcbaa4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -114,7 +114,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) - override lazy val cleaned: LogicalPlan = EliminateSubqueryAliases(this) + override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index e7275c91b8c46..247f55da1d2a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -35,6 +35,8 @@ class SparkPlanInfo( val metrics: Seq[SQLMetricInfo]) { override def hashCode(): Int = { + // hashCode of simpleString should be good enough to distinguish the plans from each other + // within a plan simpleString.hashCode } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 54c2d4e04b88a..12513e9106707 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -37,8 +37,9 @@ abstract class Exchange extends UnaryNode { } /** - * A wrapper for reused exchange to have different output, which is required to resolve the - * attributes in following plans. + * A wrapper for reused exchange to have different output, because two exchanges which produce + * logically identical output will have distinct sets of output attribute ids, so we need to + * preserve the original ids because they're what downstream operators are expecting. */ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode { @@ -73,6 +74,7 @@ private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[Spark val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]() plan.transformUp { case exchange: Exchange => + // the exchanges that have same results usually also have same schemas (same column names). 
val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]()) val samePlan = sameSchema.find { e => exchange.sameResult(e) @@ -80,8 +82,7 @@ private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[Spark if (samePlan.isDefined) { // Keep the output of this exchange, the following plans require that to resolve // attributes. - val reused = ReusedExchange(exchange.output, samePlan.get) - reused + ReusedExchange(exchange.output, samePlan.get) } else { sameSchema += exchange exchange diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 481b362c4f062..4eb4d9adbddc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -104,12 +104,12 @@ case class ShuffleExchange( /** * Caches the created ShuffleRowRDD so we can reuse that. */ - private var shuffleRDD: ShuffledRowRDD = null + private var cachedShuffleRDD: ShuffledRowRDD = null protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { // Returns the same ShuffleRowRDD if this plan is used by multiple plans. - if (shuffleRDD == null) { - shuffleRDD = coordinator match { + if (cachedShuffleRDD == null) { + cachedShuffleRDD = coordinator match { case Some(exchangeCoordinator) => val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) @@ -119,7 +119,7 @@ case class ShuffleExchange( preparePostShuffleRDD(shuffleDependency) } } - shuffleRDD + cachedShuffleRDD } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 1b211b8cbc3cb..bbf0c5a68cfbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -104,8 +104,7 @@ private[sql] object SparkPlanGraph { } else { subgraph.nodes += node } - // ShuffleExchange or BroadcastExchange - if (name.endsWith("Exchange")) { + if (name == "ShuffleExchange" || name == "BroadcastExchange") { exchanges += planInfo -> node } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 45462ca1d38d9..40ff92b93e210 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1341,6 +1341,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = sqlContext.range(100) val agg1 = df.groupBy().count() val agg2 = df.groupBy().count() + // two aggregates with different ExprId within them should have same result agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan) } From 7cd6844bb405bfe66285a60b58cd583f86f1a2cb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 8 Mar 2016 16:41:26 -0800 Subject: [PATCH 5/5] address comments --- .../spark/sql/catalyst/plans/QueryPlan.scala | 14 +++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 638a223acc784..371d72ef5af08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -257,12 +257,12 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * can do better should override this function. */ def sameResult(plan: PlanType): Boolean = { - val cleanLeft = this.canonicalized - val cleanRight = plan.canonicalized - cleanLeft.getClass == cleanRight.getClass && - cleanLeft.children.size == cleanRight.children.size && - cleanLeft.cleanArgs == cleanRight.cleanArgs && - (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) + val canonicalizedLeft = this.canonicalized + val canonicalizedRight = plan.canonicalized + canonicalizedLeft.getClass == canonicalizedRight.getClass && + canonicalizedLeft.children.size == canonicalizedRight.children.size && + canonicalizedLeft.cleanArgs == canonicalizedRight.cleanArgs && + (canonicalizedLeft.children, canonicalizedRight.children).zipped.forall(_ sameResult _) } /** @@ -291,7 +291,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT productIterator.map { // Children are checked using sameResult above. case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e) + case e: Expression => cleanArg(e) case s: Option[_] => s.map(cleanArg) case s: Seq[_] => s.map(cleanArg) case m: Map[_, _] => m.mapValues(cleanArg) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 40ff92b93e210..26775c3700e23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1337,12 +1337,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("same result on aggregate") { + test("sameResult() on aggregate") { val df = sqlContext.range(100) val agg1 = df.groupBy().count() val agg2 = df.groupBy().count() // two aggregates with different ExprId within them should have same result - agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan) + assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan)) + val agg3 = df.groupBy().sum() + assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan)) + val df2 = sqlContext.range(101) + val agg4 = df2.groupBy().count() + assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan)) } test("SPARK-12512: support `.` in column name for withColumn()") {
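
The core idea of the ReuseExchange rule introduced by this patch can be illustrated with a small, self-contained sketch: bucket candidate nodes by a cheap key (their schema) first, and only run the expensive sameResult comparison within a bucket. The PlanNode class, schemaKey field, and deduplicate helper below are illustrative stand-ins, not Spark classes; only the bucket-then-compare strategy mirrors the rule in Exchange.scala.

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object ReuseSketch {
  // Stand-in for a physical plan node: a schema key plus its canonicalized arguments.
  final case class PlanNode(schemaKey: String, canonicalArgs: Seq[Any])

  // Stand-in for QueryPlan.sameResult: equivalence on the canonicalized form.
  def sameResult(a: PlanNode, b: PlanNode): Boolean =
    a.schemaKey == b.schemaKey && a.canonicalArgs == b.canonicalArgs

  def deduplicate(nodes: Seq[PlanNode]): Seq[PlanNode] = {
    // Group candidates by schema so the O(N) sameResult scan only happens
    // inside a bucket, avoiding O(N*N) comparisons across all exchanges.
    val buckets = mutable.HashMap[String, ArrayBuffer[PlanNode]]()
    nodes.map { node =>
      val bucket = buckets.getOrElseUpdate(node.schemaKey, ArrayBuffer[PlanNode]())
      bucket.find(sameResult(_, node)) match {
        case Some(existing) => existing   // reuse the node that was planned earlier
        case None =>
          bucket += node
          node
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val a = PlanNode("struct<id:bigint>", Seq("hashpartitioning(id, 200)"))
    val b = PlanNode("struct<id:bigint>", Seq("hashpartitioning(id, 200)"))
    val result = deduplicate(Seq(a, b))
    // The second, equivalent node collapses to the first instance.
    assert(result(1) eq a)
  }
}

In the actual rule the match found in a bucket is not returned directly; it is wrapped in ReusedExchange(exchange.output, samePlan.get) so that the original output attribute ids, which downstream operators resolve against, are preserved. The new DataFrameSuite test checks the end-to-end effect: with the spark.sql.exchange.reuse flag enabled, a self-join on sqlContext.range(100) should plan exactly one ShuffleExchange plus one ReusedExchange.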