From 961ade120f7a179751e5ec45b24e159259de0bae Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Tue, 10 Oct 2017 15:02:48 -0700
Subject: [PATCH 01/10] Fix plan resolution bug caused by EnsureStatefulOpPartitioning

---
 .../execution/basicPhysicalOperators.scala    | 24 +++++++++++++++++--
 .../streaming/IncrementalExecution.scala      | 10 ++++----
 .../EnsureStatefulOpPartitioningSuite.scala   |  2 +-
 .../streaming/StreamingAggregationSuite.scala |  4 ++--
 .../sql/streaming/StreamingQuerySuite.scala   | 13 ++++++++++
 5 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 63cd1691f4cd7..d276d2461cb2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 
-import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -590,10 +590,30 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute().coalesce(numPartitions, shuffle = false)
+    if (child.execute().getNumPartitions < 1) {
+      new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
+    } else {
+      child.execute().coalesce(numPartitions, shuffle = false)
+    }
   }
 }
 
+object CoalesceExec {
+  class EmptyRDDWithPartitions(
+      sc: SparkContext,
+      numPartitions: Int) extends RDD[InternalRow](sc, Nil) {
+
+    override def getPartitions: Array[Partition] =
+      Array.tabulate(numPartitions)(i => SimplePartition(i))
+
+    override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+      Iterator.empty
+    }
+  }
+
+  case class SimplePartition(index: Int) extends Partition
+}
+
 /**
  * Physical plan for a subquery.
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 82f879c763c2b..faff2b3b99346 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 
 /**
@@ -131,17 +132,17 @@ class IncrementalExecution(
   }
 
   override def preparations: Seq[Rule[SparkPlan]] =
-    Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations
+    Seq(state, EnsureStatefulOpPartitioning(sparkSession.sessionState.conf)) ++ super.preparations
 
   /** No need assert supported, as this check has already been done */
   override def assertSupported(): Unit = { }
 }
 
-object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
+case class EnsureStatefulOpPartitioning(conf: SQLConf) extends Rule[SparkPlan] {
   // Needs to be transformUp to avoid extra shuffles
   override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
     case so: StatefulOperator =>
-      val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
+      val numPartitions = conf.numShufflePartitions
       val distributions = so.requiredChildDistribution
       val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
         val expectedPartitioning = reqDistribution match {
@@ -151,8 +152,7 @@ object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
             s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
             s"$reqDistribution.")
         }
-        if (child.outputPartitioning.guarantees(expectedPartitioning) &&
-          child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
+        if (child.outputPartitioning.guarantees(expectedPartitioning)) {
           child
         } else {
           ShuffleExchangeExec(expectedPartitioning, child)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
index ed9823fbddfda..69ee7822f8103 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.test.SharedSQLContext
 
 class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index fe7efa69f7e31..c049ae031b7b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -455,8 +455,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       },
       AddBlockData(inputSource), // create an empty trigger
       CheckLastBatch(1),
-      AssertOnQuery("Verify addition of exchange operator") { se =>
-        checkAggregationChain(se, expectShuffling = true, 1)
+      AssertOnQuery("Verify that no exchange is required") { se =>
+        checkAggregationChain(se, expectShuffling = false, 1)
       },
       AddBlockData(inputSource, Seq(2, 3)),
       CheckLastBatch(3),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index ab35079dca23f..c5c8422efe59a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }
 
+  test("don't check for RDD partitions during streaming aggregation preparation") {
+    val stream = MemoryStream[(Int, Int)]
+    val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'")
+    val otherDf = stream.toDF().toDF("num", "numSq")
+      .join(broadcast(baseDf), "num")
+      .groupBy('char)
+      .agg(sum('numSq))
+
+    testStream(otherDf, OutputMode.Complete())(
+      AddData(stream, (1, 1), (2, 4)),
+      CheckLastBatch(("A", 1)))
+  }
+
   /** Create a streaming DF that only execute one batch in which it returns the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
     require(!triggerDF.isStreaming)
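Patch 01 works around a planner corner case: `coalesce(1)` can never increase the partition count, so applying it to a relation whose RDD has zero partitions still yields zero partitions, even though CoalesceExec(1, _) advertises SinglePartition to the planner. The removed condition above also called `child.execute()` during preparation just to compare partition counts, forcing eager execution at planning time. The following standalone sketch (assuming only a local SparkSession; it is an illustration, not part of the patch) reproduces the underlying RDD behaviour that EmptyRDDWithPartitions compensates for:

    import org.apache.spark.sql.SparkSession

    object EmptyCoalesceSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        // An RDD created with no data starts out with zero partitions ...
        val emptyRdd = spark.sparkContext.emptyRDD[Int]
        assert(emptyRdd.getNumPartitions == 0)
        // ... and coalesce can only merge existing partitions, so 0 stays 0,
        // contradicting the SinglePartition output claimed by CoalesceExec(1, _).
        assert(emptyRdd.coalesce(1, shuffle = false).getNumPartitions == 0)
        spark.stop()
      }
    }

EmptyRDDWithPartitions closes the gap by materializing the requested number of genuinely empty partitions instead of coalescing.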
From a8db9ade19d5921ef3f965222bf95f65b9d57265 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Tue, 10 Oct 2017 16:07:39 -0700
Subject: [PATCH 02/10] minor additions

---
 .../spark/sql/execution/basicPhysicalOperators.scala | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index d276d2461cb2c..075668dafdfb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -590,7 +590,9 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    if (child.execute().getNumPartitions < 1) {
+    if (numPartitions == 1 && child.execute().getNumPartitions < 1) {
+      // Make sure we don't output an RDD with 0 partitions, when claiming that we have a
+      // `SinglePartition`.
       new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
     } else {
       child.execute().coalesce(numPartitions, shuffle = false)
@@ -599,8 +601,9 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
 }
 
 object CoalesceExec {
+  /** A simple RDD with no data, but with the given number of partitions. */
   class EmptyRDDWithPartitions(
-      sc: SparkContext,
+      @transient private val sc: SparkContext,
       numPartitions: Int) extends RDD[InternalRow](sc, Nil) {
 
     override def getPartitions: Array[Partition] =

From 549b88237e727bbf9f00a0bd59d19ca072478dc5 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Tue, 10 Oct 2017 16:08:17 -0700
Subject: [PATCH 03/10] add jira ticket to test

---
 .../org/apache/spark/sql/streaming/StreamingQuerySuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index c5c8422efe59a..c53889bb8566c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -652,7 +652,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }
 
-  test("don't check for RDD partitions during streaming aggregation preparation") {
+  test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") {
     val stream = MemoryStream[(Int, Int)]
     val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'")
     val otherDf = stream.toDF().toDF("num", "numSq")
From 70211ca3d3bb83a831634dfd9a9f5534eafdcdad Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 12 Oct 2017 09:01:30 -0700
Subject: [PATCH 04/10] address comments

---
 .../plans/physical/partitioning.scala         | 15 +++++---
 .../exchange/EnsureRequirements.scala         |  3 +-
 .../FlatMapGroupsWithStateExec.scala          |  2 +-
 .../streaming/IncrementalExecution.scala      | 37 +++++--------------
 .../streaming/statefulOperators.scala         | 11 +++---
 .../streaming/state/StateStoreRDDSuite.scala  |  2 +-
 .../SymmetricHashJoinStateManagerSuite.scala  |  2 +-
 .../streaming/StreamingAggregationSuite.scala |  2 +-
 .../sql/streaming/StreamingJoinSuite.scala    |  2 +-
 9 files changed, 33 insertions(+), 43 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 51d78dd1233fe..87e7835dfcf97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -49,7 +49,9 @@ case object AllTuples extends Distribution
  * can mean such tuples are either co-located in the same partition or they will be contiguous
  * within a single partition.
  */
-case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    numPartitions: Option[Int] = None) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. " +
@@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning {
 
   override def satisfies(required: Distribution): Boolean = required match {
     case _: BroadcastDistribution => false
+    case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.exists(_ == 1)
     case _ => true
   }
 
@@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
 
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering) =>
-      expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
+      expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+        desiredPartitions.forall(_ == numPartitions) // if desiredPartition = true, returns true
     case _ => false
   }
 
@@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case OrderedDistribution(requiredOrdering) =>
       val minSize = Seq(requiredOrdering.size, ordering.size).min
       requiredOrdering.take(minSize) == ordering.take(minSize)
-    case ClusteredDistribution(requiredClustering) =>
-      ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
+    case ClusteredDistribution(requiredClustering, desiredPartitions) =>
+      ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
+        desiredPartitions.forall(_ == numPartitions) // if desiredPartition = true, returns true
     case _ => false
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index d28ce60e276d5..eca4a1ca2be9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -50,7 +50,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       numPartitions: Int): Partitioning = {
     requiredDistribution match {
       case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+      case ClusteredDistribution(clustering, desiredPartitions) =>
+        HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
       case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
       case dist => sys.error(s"Do not know how to satisfy distribution $dist")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index aab06d611a5ea..c81f1a8142784 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -64,7 +64,7 @@ case class FlatMapGroupsWithStateExec(
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(groupingAttributes) :: Nil
+    ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil
 
   /** Ordering needed for using GroupingIterator */
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index faff2b3b99346..2e378637727fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -62,6 +62,10 @@ class IncrementalExecution(
       StreamingDeduplicationStrategy :: Nil
   }
 
+  private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
+    .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
+    .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
+
   /**
    * See [SPARK-18339]
    * Walk the optimized logical plan and replace CurrentBatchTimestamp
@@ -84,7 +88,11 @@ class IncrementalExecution(
   /** Get the state info of the next stateful operator */
   private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
     StatefulOperatorStateInfo(
-      checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId)
+      checkpointLocation,
+      runId,
+      statefulOperatorId.getAndIncrement(),
+      currentBatchId,
+      numStateStores)
   }
 
   /** Locates save/restore pairs surrounding aggregation. */
@@ -131,33 +139,8 @@ class IncrementalExecution(
     }
   }
 
-  override def preparations: Seq[Rule[SparkPlan]] =
-    Seq(state, EnsureStatefulOpPartitioning(sparkSession.sessionState.conf)) ++ super.preparations
+  override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
 
   /** No need assert supported, as this check has already been done */
   override def assertSupported(): Unit = { }
 }
-
-case class EnsureStatefulOpPartitioning(conf: SQLConf) extends Rule[SparkPlan] {
-  // Needs to be transformUp to avoid extra shuffles
-  override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    case so: StatefulOperator =>
-      val numPartitions = conf.numShufflePartitions
-      val distributions = so.requiredChildDistribution
-      val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
-        val expectedPartitioning = reqDistribution match {
-          case AllTuples => SinglePartition
-          case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
-          case _ => throw new AnalysisException("Unexpected distribution expected for " +
-            s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
-            s"$reqDistribution.")
-        }
-        if (child.outputPartitioning.guarantees(expectedPartitioning)) {
-          child
-        } else {
-          ShuffleExchangeExec(expectedPartitioning, child)
-        }
-      }
-      so.withNewChildren(children)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 0d85542928ee6..b9b07a2e688f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo(
     checkpointLocation: String,
     queryRunId: UUID,
     operatorId: Long,
-    storeVersion: Long) {
+    storeVersion: Long,
+    numPartitions: Int) {
   override def toString(): String = {
     s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " +
-      s"opId = $operatorId, ver = $storeVersion]"
+      s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]"
   }
 }
 
@@ -239,7 +240,7 @@ case class StateStoreRestoreExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(keyExpressions) :: Nil
+      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 }
@@ -386,7 +387,7 @@ case class StateStoreSaveExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(keyExpressions) :: Nil
+      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 }
@@ -401,7 +402,7 @@ case class StreamingDeduplicateExec(
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(keyExpressions) :: Nil
+    ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index defb9ed63a881..fa99a8dd0a493 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       path: String,
       queryRunId: UUID = UUID.randomUUID,
       version: Int = 0): StatefulOperatorStateInfo = {
-    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version)
+    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, 5)
   }
 
   private val increment = (store: StateStore, iter: Iterator[String]) => {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index d44af1d14c27a..c0216a2ef3e61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -160,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
 
     withTempDir { file =>
       val storeConf = new StateStoreConf()
-      val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0)
+      val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
       val manager = new SymmetricHashJoinStateManager(
         LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration)
       try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index c049ae031b7b2..4b82da676a678 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -495,7 +495,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AssertOnQuery("Verify addition of exchange operator") { se =>
         checkAggregationChain(
           se,
-          expectShuffling = true,
+          expectShuffling = false,
           spark.sessionState.conf.numShufflePartitions)
       },
       StopStream
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index a6593b71e51de..d32617275aadc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -330,7 +330,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
     val queryId = UUID.randomUUID
     val opId = 0
     val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString
-    val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L)
+    val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5)
 
     implicit val sqlContext = spark.sqlContext
     val coordinatorRef = sqlContext.streams.stateStoreCoordinator
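Patch 04 replaces the streaming-only rule with a more general mechanism: the required ClusteredDistribution itself may now pin an exact partition count (threaded through from the operator's checkpointed state info), and the ordinary batch-side EnsureRequirements rule honours it. A toy model of the new `satisfies` contract — a simplified sketch with strings standing in for Catalyst expressions, not Spark source — behaves like this:

    // Toy stand-ins (not Spark classes) for the Option-based contract:
    // None means "any partition count is acceptable", because Option.forall
    // returns true on None; Some(n) pins the count to exactly n.
    case class RequiredClustering(keys: Seq[String], numPartitions: Option[Int])
    case class HashedBy(keys: Seq[String], numPartitions: Int) {
      def satisfies(required: RequiredClustering): Boolean =
        // every partitioning key must appear in the required clustering ...
        keys.forall(required.keys.contains) &&
          // ... and the partition count must match, if one was requested
          required.numPartitions.forall(_ == numPartitions)
    }

    assert(HashedBy(Seq("k"), 5).satisfies(RequiredClustering(Seq("k"), None)))
    assert(HashedBy(Seq("k"), 5).satisfies(RequiredClustering(Seq("k"), Some(5))))
    assert(!HashedBy(Seq("k"), 5).satisfies(RequiredClustering(Seq("k"), Some(10))))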
From 48d1f25ccae8c8456a9a5d9c373208e573da1027 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 12 Oct 2017 09:04:36 -0700
Subject: [PATCH 05/10] add coalesce test

---
 .../apache/spark/sql/execution/basicPhysicalOperators.scala  | 4 ++--
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 075668dafdfb5..d15ece304cac4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -607,14 +607,14 @@ object CoalesceExec {
       numPartitions: Int) extends RDD[InternalRow](sc, Nil) {
 
     override def getPartitions: Array[Partition] =
-      Array.tabulate(numPartitions)(i => SimplePartition(i))
+      Array.tabulate(numPartitions)(i => EmptyPartition(i))
 
     override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
       Iterator.empty
     }
   }
 
-  case class SimplePartition(index: Int) extends Partition
+  case class EmptyPartition(index: Int) extends Partition
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index dd8f54b690f64..1268207565206 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       testData.select('key).coalesce(1).select('key),
       testData.select('key).collect().toSeq)
+
+    assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1)
   }
 
   test("convert $\"attribute name\" into unresolved attribute") {

From 3f51c5ce7592e50d922864eb4d9643b984f73cba Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 12 Oct 2017 11:21:49 -0700
Subject: [PATCH 06/10] Fix test

---
 .../apache/spark/sql/catalyst/plans/physical/partitioning.scala | 2 +-
 .../apache/spark/sql/streaming/StreamingAggregationSuite.scala  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 87e7835dfcf97..cf7ff2d05750b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -223,7 +223,7 @@ case object SinglePartition extends Partitioning {
 
   override def satisfies(required: Distribution): Boolean = required match {
     case _: BroadcastDistribution => false
-    case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.exists(_ == 1)
+    case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1)
     case _ => true
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 4b82da676a678..c049ae031b7b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -495,7 +495,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AssertOnQuery("Verify addition of exchange operator") { se =>
        checkAggregationChain(
           se,
-          expectShuffling = false,
+          expectShuffling = true,
           spark.sessionState.conf.numShufflePartitions)
       },
       StopStream

From 407d76c57aea09232c2ca92df2eab4fc0d2d8838 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 12 Oct 2017 13:27:07 -0700
Subject: [PATCH 07/10] savE

---
 .../sql/streaming/EnsureStatefulOpPartitioningSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
index 69ee7822f8103..84123b2a6448f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
@@ -44,7 +44,7 @@ class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLCont
   test("ClusteredDistribution generates Exchange with HashPartitioning") {
     testEnsureStatefulOpPartitioning(
       baseDf.queryExecution.sparkPlan,
-      requiredDistribution = keys => ClusteredDistribution(keys),
+      requiredDistribution = keys => ClusteredDistribution(keys, Some(5)),
      expectedPartitioning =
         keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
       expectShuffle = true)
@@ -53,7 +53,7 @@ class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLCont
   test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") {
     testEnsureStatefulOpPartitioning(
       baseDf.coalesce(1).queryExecution.sparkPlan,
-      requiredDistribution = keys => ClusteredDistribution(keys),
+      requiredDistribution = keys => ClusteredDistribution(keys, Some(5)),
       expectedPartitioning =
         keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
       expectShuffle = true)
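Patch 06 above is a one-token fix with real semantic weight. With `exists`, a ClusteredDistribution that requests no particular partition count (None) could never be satisfied by SinglePartition, reintroducing unnecessary shuffles; `forall` treats the absent requirement as vacuously satisfied. The difference in plain Scala:

    val unspecified: Option[Int] = None
    // exists on None is false: an unpinned requirement would wrongly reject SinglePartition.
    assert(!unspecified.exists(_ == 1))
    // forall on None is true: no requested count means a single-partition child is fine.
    assert(unspecified.forall(_ == 1))
    // When a count is pinned, the two agree:
    assert(Option(1).exists(_ == 1) && Option(1).forall(_ == 1))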
From 51221175458f28788352a0c315fa14577b5f0c87 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 12 Oct 2017 14:39:32 -0700
Subject: [PATCH 08/10] refactor tests

---
 .../sql/streaming/DeduplicateSuite.scala      |  10 +-
 .../EnsureStatefulOpPartitioningSuite.scala   | 138 ------------------
 .../FlatMapGroupsWithStateSuite.scala         |   6 +-
 .../sql/streaming/StatefulOperatorTest.scala  |  49 +++++++
 .../streaming/StreamingAggregationSuite.scala |   3 +-
 5 files changed, 64 insertions(+), 142 deletions(-)
 delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index e858b7d9998a8..8f50b0ea60dab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec}
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.functions._
 
-class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class DeduplicateSuite extends StateStoreMetricsTest
+  with BeforeAndAfterAll
+  with StatefulOperatorTest {
 
   import testImplicits._
 
@@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
       AddData(inputData, "a"),
       CheckLastBatch("a"),
       assertNumStateRows(total = 1, updated = 1),
+      AssertOnQuery(sq =>
+        checkChildOutputPartitioning[StreamingDeduplicateExec](sq, SinglePartition)),
       AddData(inputData, "a"),
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
@@ -58,6 +63,7 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
       AddData(inputData, "a" -> 1),
       CheckLastBatch("a" -> 1),
       assertNumStateRows(total = 1, updated = 1),
+      AssertOnQuery(sq => checkChildOutputPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
       AddData(inputData, "a" -> 2), // Dropped
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
deleted file mode 100644
index 84123b2a6448f..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import java.util.UUID
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.test.SharedSQLContext
-
-class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
-
-  import testImplicits._
-
-  private var baseDf: DataFrame = null
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")
-  }
-
-  test("ClusteredDistribution generates Exchange with HashPartitioning") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.queryExecution.sparkPlan,
-      requiredDistribution = keys => ClusteredDistribution(keys, Some(5)),
-      expectedPartitioning =
-        keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
-      expectShuffle = true)
-  }
-
-  test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.coalesce(1).queryExecution.sparkPlan,
-      requiredDistribution = keys => ClusteredDistribution(keys, Some(5)),
-      expectedPartitioning =
-        keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
-      expectShuffle = true)
-  }
-
-  test("AllTuples generates Exchange with SinglePartition") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.queryExecution.sparkPlan,
-      requiredDistribution = _ => AllTuples,
-      expectedPartitioning = _ => SinglePartition,
-      expectShuffle = true)
-  }
-
-  test("AllTuples with coalesce(1) doesn't need Exchange") {
-    testEnsureStatefulOpPartitioning(
-      baseDf.coalesce(1).queryExecution.sparkPlan,
-      requiredDistribution = _ => AllTuples,
-      expectedPartitioning = _ => SinglePartition,
-      expectShuffle = false)
-  }
-
-  /**
-   * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
-   * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
-   * ensure the expected partitioning.
-   */
-  private def testEnsureStatefulOpPartitioning(
-      inputPlan: SparkPlan,
-      requiredDistribution: Seq[Attribute] => Distribution,
-      expectedPartitioning: Seq[Attribute] => Partitioning,
-      expectShuffle: Boolean): Unit = {
-    val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
-    val executed = executePlan(operator, OutputMode.Complete())
-    if (expectShuffle) {
-      val exchange = executed.children.find(_.isInstanceOf[Exchange])
-      if (exchange.isEmpty) {
-        fail(s"Was expecting an exchange but didn't get one in:\n$executed")
-      }
-      assert(exchange.get ===
-        ShuffleExchangeExec(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
-        s"Exchange didn't have expected properties:\n${exchange.get}")
-    } else {
-      assert(!executed.children.exists(_.isInstanceOf[Exchange]),
-        s"Unexpected exchange found in:\n$executed")
-    }
-  }
-
-  /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
-  private def executePlan(
-      p: SparkPlan,
-      outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
-    val execution = new IncrementalExecution(
-      spark,
-      null,
-      OutputMode.Complete(),
-      "chk",
-      UUID.randomUUID(),
-      0L,
-      OffsetSeqMetadata()) {
-      override lazy val sparkPlan: SparkPlan = p transform {
-        case plan: SparkPlan =>
-          val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
-          plan transformExpressions {
-            case UnresolvedAttribute(Seq(u)) =>
-              inputMap.getOrElse(u,
-                sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
-          }
-      }
-    }
-    execution.executedPlan
-  }
-}
-
-/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
-case class TestStatefulOperator(
-    child: SparkPlan,
-    requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
-  override def output: Seq[Attribute] = child.output
-  override def doExecute(): RDD[InternalRow] = child.execute()
-  override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
-  override def stateInfo: Option[StatefulOperatorStateInfo] = None
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index d2e8beb2f5290..2226eff5eb848 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -41,7 +41,9 @@ case class RunningCount(count: Long)
 
 case class Result(key: Long, count: Int)
 
-class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
+  with BeforeAndAfterAll
+  with StatefulOperatorTest {
 
   import testImplicits._
   import GroupStateImpl._
@@ -544,6 +546,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       AddData(inputData, "a"),
       CheckLastBatch(("a", "1")),
       assertNumStateRows(total = 1, updated = 1),
+      AssertOnQuery(sq => checkChildOutputPartitioning[FlatMapGroupsWithStateExec](
+        sq, Seq("value"))),
       AddData(inputData, "a", "b"),
       CheckLastBatch(("a", "2"), ("b", "1")),
       assertNumStateRows(total = 2, updated = 2),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
new file mode 100644
index 0000000000000..e720f5f1697cc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.streaming._
+
+trait StatefulOperatorTest {
+  /**
+   * Check that the output partitioning of a child operator of a Stateful operator satisfies the
+   * distribution that we expect for our Stateful operator.
+   */
+  protected def checkChildOutputPartitioning[T <: StatefulOperator](
+      sq: StreamingQuery,
+      colNames: Seq[String],
+      numPartitions: Option[Int] = None): Boolean = {
+    val attr = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.analyzed.output
+    val partitions = numPartitions.getOrElse(sq.sparkSession.sessionState.conf.numShufflePartitions)
+    val groupingAttr = attr.filter(a => colNames.contains(a.name))
+    checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions))
+  }
+
+  /**
+   * Check that the output partitioning of a child operator of a Stateful operator satisfies the
+   * distribution that we expect for our Stateful operator.
+   */
+  protected def checkChildOutputPartitioning[T <: StatefulOperator](
+      sq: StreamingQuery,
+      expectedPartitioning: Partitioning): Boolean = {
+    val operator = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution
+      .executedPlan.collect { case p: T => p }
+    operator.head.children.forall(_.outputPartitioning == expectedPartitioning)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index c049ae031b7b2..89c171581ed77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -44,7 +44,7 @@ object FailureSingleton {
 }
 
 class StreamingAggregationSuite extends StateStoreMetricsTest
-  with BeforeAndAfterAll with Assertions {
+  with BeforeAndAfterAll with Assertions with StatefulOperatorTest {
 
   override def afterAll(): Unit = {
     super.afterAll()
@@ -248,6 +248,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     testStream(aggregated, Update)(
       StartStream(),
       AddData(inputData, 1, 2, 3, 4),
+      AssertOnQuery(sq => checkChildOutputPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
      ExpectFailure[SparkException](),
       StartStream(),
       CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
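The numStateStores value introduced in patch 04 is what makes the pinning above safe across restarts: the shuffle-partition count recorded in the checkpoint's offset metadata takes precedence over the current session setting, so a restarted query keeps addressing the same number of state store partitions even if spark.sql.shuffle.partitions changed in between. A simplified stand-in for that lookup (with a plain Map in place of the checkpointed OffsetSeqMetadata; the helper name is hypothetical):

    // Hypothetical helper mirroring the precedence logic in IncrementalExecution;
    // `metadataConf` stands in for the conf map recovered from OffsetSeqMetadata.
    def numStateStores(metadataConf: Map[String, String], sessionShufflePartitions: Int): Int =
      metadataConf.get("spark.sql.shuffle.partitions").map(_.toInt)
        .getOrElse(sessionShufflePartitions)

    // First run: nothing checkpointed yet, so the session value is used (and recorded).
    assert(numStateStores(Map.empty, 200) == 200)
    // Restart with a different session setting: the checkpointed value still wins.
    assert(numStateStores(Map("spark.sql.shuffle.partitions" -> "10"), 200) == 10)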
From 84ac2d84fff94a54a01d92f0ecf00c1f9ace4203 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Fri, 13 Oct 2017 13:58:27 -0700
Subject: [PATCH 09/10] address

---
 .../catalyst/plans/physical/partitioning.scala   |  4 ++--
 .../execution/exchange/EnsureRequirements.scala  |  2 ++
 .../spark/sql/execution/PlannerSuite.scala       | 17 +++++++++++++++++
 .../streaming/state/StateStoreRDDSuite.scala     |  2 +-
 .../spark/sql/streaming/DeduplicateSuite.scala   |  3 ++-
 .../streaming/FlatMapGroupsWithStateSuite.scala  |  2 +-
 .../sql/streaming/StatefulOperatorTest.scala     | 11 +++++------
 .../streaming/StreamingAggregationSuite.scala    |  3 ++-
 8 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index cf7ff2d05750b..e57c842ce2a36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -248,7 +248,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     case UnspecifiedDistribution => true
     case ClusteredDistribution(requiredClustering, desiredPartitions) =>
       expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
-        desiredPartitions.forall(_ == numPartitions) // if desiredPartition = true, returns true
+        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
     case _ => false
   }
 
@@ -295,7 +295,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
       requiredOrdering.take(minSize) == ordering.take(minSize)
     case ClusteredDistribution(requiredClustering, desiredPartitions) =>
       ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
-        desiredPartitions.forall(_ == numPartitions) // if desiredPartition = true, returns true
+        desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true
     case _ => false
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index eca4a1ca2be9f..4e2ca37bc1a59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -44,6 +44,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
 
   /**
    * Given a required distribution, returns a partitioning that satisfies that distribution.
+   * @param requiredDistribution The distribution that is required by the operator
+   * @param numPartitions Used when the distribution doesn't require a specific number of partitions
    */
   private def createPartitioning(
       requiredDistribution: Distribution,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 86066362da9dd..c25c90d0c70e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -425,6 +425,23 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("EnsureRequirements should respect ClusteredDistribution's num partitioning") {
+    val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13))
+    // Number of partitions differ
+    val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13)
+    val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
+    assert(!childPartitioning.satisfies(distribution))
+    val inputPlan = DummySparkPlan(
+      children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
+      requiredChildDistribution = Seq(distribution),
+      requiredChildOrdering = Seq(Seq.empty))
+
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
+    val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e }
+    assert(shuffle.size === 1)
+    assert(shuffle.head.newPartitioning === finalPartitioning)
+  }
+
   test("Reuse exchanges") {
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
     val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index fa99a8dd0a493..65b39f0fbd73d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       path: String,
       queryRunId: UUID = UUID.randomUUID,
       version: Int = 0): StatefulOperatorStateInfo = {
-    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, 5)
+    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5)
   }
 
   private val increment = (store: StateStore, iter: Iterator[String]) => {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index 8f50b0ea60dab..2788e309ef28c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -63,7 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest
       AddData(inputData, "a" -> 1),
       CheckLastBatch("a" -> 1),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq => checkChildOutputPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
+      AssertOnQuery(sq =>
+        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
       AddData(inputData, "a" -> 2), // Dropped
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 2226eff5eb848..aeb83835f981a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -546,7 +546,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       CheckLastBatch(("a", "1")),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq => checkChildOutputPartitioning[FlatMapGroupsWithStateExec](
+      AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec](
         sq, Seq("value"))),
       AddData(inputData, "a", "b"),
       CheckLastBatch(("a", "2"), ("b", "1")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
index e720f5f1697cc..936fdb7366daf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
@@ -25,12 +25,11 @@ trait StatefulOperatorTest {
    * Check that the output partitioning of a child operator of a Stateful operator satisfies the
    * distribution that we expect for our Stateful operator.
    */
-  protected def checkChildOutputPartitioning[T <: StatefulOperator](
+  protected def checkChildOutputHashPartitioning[T <: StatefulOperator](
       sq: StreamingQuery,
-      colNames: Seq[String],
-      numPartitions: Option[Int] = None): Boolean = {
-    val attr = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.analyzed.output
-    val partitions = numPartitions.getOrElse(sq.sparkSession.sessionState.conf.numShufflePartitions)
+      colNames: Seq[String]): Boolean = {
+    val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output
+    val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions
     val groupingAttr = attr.filter(a => colNames.contains(a.name))
     checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions))
   }
@@ -42,7 +41,7 @@ trait StatefulOperatorTest {
   protected def checkChildOutputPartitioning[T <: StatefulOperator](
       sq: StreamingQuery,
       expectedPartitioning: Partitioning): Boolean = {
-    val operator = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution
+    val operator = sq.asInstanceOf[StreamExecution].lastExecution
       .executedPlan.collect { case p: T => p }
     operator.head.children.forall(_.outputPartitioning == expectedPartitioning)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 89c171581ed77..15085bbfb684b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -248,7 +248,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     testStream(aggregated, Update)(
       StartStream(),
       AddData(inputData, 1, 2, 3, 4),
-      AssertOnQuery(sq => checkChildOutputPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
+      AssertOnQuery(sq =>
+        checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
       ExpectFailure[SparkException](),
       StartStream(),
       CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
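The PlannerSuite regression test added in patch 09 pins down the batch-side behaviour: a child hash-partitioned into 5 partitions does not satisfy a ClusteredDistribution pinned to 13, so EnsureRequirements must insert exactly one exchange, and it must repartition to the pinned count rather than the session default. In miniature (an illustrative sketch that assumes the clustering keys themselves already match, not the EnsureRequirements source):

    // Returns Some(count) when an exchange is needed, None when the child already fits.
    def exchangePartitionCount(pinned: Option[Int], childCount: Int): Option[Int] =
      pinned match {
        case Some(n) if n != childCount => Some(n) // shuffle to the pinned count (13 in the test)
        case _                          => None    // counts agree or unpinned: no exchange needed
      }

    assert(exchangePartitionCount(Some(13), childCount = 5).contains(13))
    assert(exchangePartitionCount(Some(5), childCount = 5).isEmpty)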
From 971f57963a3f8f7f5b0481441fef387c80920048 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Fri, 13 Oct 2017 16:41:42 -0700
Subject: [PATCH 10/10] fix tests

---
 .../org/apache/spark/sql/streaming/DeduplicateSuite.scala     | 2 +-
 .../apache/spark/sql/streaming/StatefulOperatorTest.scala     | 3 ++-
 .../spark/sql/streaming/StreamingAggregationSuite.scala       | 4 ++--
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index 2788e309ef28c..caf2bab8a5859 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -45,7 +45,7 @@ class DeduplicateSuite extends StateStoreMetricsTest
       CheckLastBatch("a"),
       assertNumStateRows(total = 1, updated = 1),
       AssertOnQuery(sq =>
-        checkChildOutputPartitioning[StreamingDeduplicateExec](sq, SinglePartition)),
+        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))),
       AddData(inputData, "a"),
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
index 936fdb7366daf..45142278993bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
@@ -43,6 +43,7 @@ trait StatefulOperatorTest {
       expectedPartitioning: Partitioning): Boolean = {
     val operator = sq.asInstanceOf[StreamExecution].lastExecution
       .executedPlan.collect { case p: T => p }
-    operator.head.children.forall(_.outputPartitioning == expectedPartitioning)
+    operator.head.children.forall(
+      _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 15085bbfb684b..1b4d8556f6ae5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -248,8 +248,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     testStream(aggregated, Update)(
       StartStream(),
       AddData(inputData, 1, 2, 3, 4),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
       ExpectFailure[SparkException](),
       StartStream(),
       CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
@@ -283,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(inputData, 0L, 5L, 5L, 10L),
       AdvanceManualClock(10 * 1000),
       CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
+      AssertOnQuery(sq =>
+        checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
       // advance clock to 20 seconds, should retain keys >= 10
       AddData(inputData, 15L, 15L, 20L),