[SPARK-21977] SinglePartition optimizations break certain Streaming Stateful Aggregation requirements

## What changes were proposed in this pull request?

This is a bit hard to explain, as there are several issues here; I'll try my best. Here are the requirements:
  1. A Structured Streaming source that can generate empty RDDs with 0 partitions
  2. A Structured Streaming query that uses the above source, performs a stateful aggregation
     (mapGroupsWithState, groupBy.count, ...), and coalesces to 1 partition (a minimal example of this shape is sketched below)
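
For illustration, a minimal query matching this shape might look like the following sketch. The `rate` source, `memory` sink, and query name are stand-ins chosen here (reproducing Symptom 1 additionally requires a source that can return a 0-partition RDD for an empty trigger); the essential pieces are the stateful aggregation and the `coalesce(1)`.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("repro-sketch").getOrCreate()

// Any streaming source gives the right plan shape; Symptom 1 additionally needs a
// source that can produce an RDD with 0 partitions when a trigger has no data.
val input = spark.readStream.format("rate").load()

// coalesce(1) gives the plan a SinglePartition partitioning, which already satisfies
// the aggregation's required distribution, so the planner does not add an Exchange
// before the stateful operators.
val counts = input.coalesce(1).groupBy().count()

val query = counts.writeStream
  .outputMode("complete")
  .format("memory")
  .queryName("counts")
  .start()
```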

The crux of the problem is that when a dataset has a `coalesce(1)` call, it receives a `SinglePartition` partitioning scheme. This scheme satisfies most of the required child distributions of aggregation operators such as `HashAggregateExec`, which causes a world of problems:
  Symptom 1. If the input RDD has 0 partitions, the whole lineage will have 0 partitions, nothing will be executed, and the state store will not create any delta files. The next trigger then fails, because the StateStore cannot load the delta file for the previous trigger.
  Symptom 2. Suppose there was data. In that case, if you stop your stream, replace `coalesce(1)` with `coalesce(2)`, and restart it, the stream will fail, because `spark.sql.shuffle.partitions - 1` of the StateStores will not find their delta files.

To fix the issues above, we must check that the partitioning of the child of a `StatefulOperator` satisfies:
If the grouping expressions are empty:
  a) AllTuples distribution
  b) a single physical partition
If the grouping expressions are non-empty:
  a) ClusteredDistribution over the grouping expressions
  b) `spark.sql.shuffle.partitions` number of partitions
regardless of whether `coalesce(1)` exists in the plan, and regardless of whether the input RDD for the trigger has any data. A minimal sketch of this mapping follows.
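
A minimal sketch of that mapping, assuming a standalone helper (the real logic lives in the new `EnsureStatefulOpPartitioning` rule shown in the diff below; the helper name here is made up for illustration):

```scala
import org.apache.spark.sql.catalyst.plans.physical.{
  AllTuples, ClusteredDistribution, Distribution, HashPartitioning, Partitioning, SinglePartition}

// Given a stateful operator's required child distribution, return the partitioning its
// child must actually provide -- independent of any coalesce(1) in the user's plan.
def expectedStatefulPartitioning(
    required: Distribution,
    numShufflePartitions: Int): Partitioning = required match {
  case AllTuples                   => SinglePartition                              // empty grouping expressions
  case ClusteredDistribution(keys) => HashPartitioning(keys, numShufflePartitions) // non-empty grouping expressions
  case other =>
    throw new IllegalArgumentException(s"Unexpected distribution for a stateful operator: $other")
}
```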

Once you fix the above problem by adding an Exchange to the plan, you come across the following bug:
If you call `coalesce(1).groupBy().count()` on a streaming DataFrame and a trigger has no data, `StateStoreRestoreExec` doesn't return the prior state. However, for this specific aggregation, the `HashAggregateExec` after the restore emits a (0, 0) row, since we're performing a count and there is no data. That row then gets stored by `StateStoreSaveExec`, causing the previous counts to be overwritten and lost, as illustrated in the trace below.
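
As a concrete, hypothetical trace of this bug for `coalesce(1).groupBy().count()` (the counts are made up for illustration):

```scala
// Batch 1: 5 input rows arrive.
//   StateStoreRestoreExec -> nothing to restore (no prior state)
//   HashAggregateExec     -> emits count = 5
//   StateStoreSaveExec    -> saves 5
//
// Batch 2: the trigger has no data.
//   StateStoreRestoreExec -> returns no rows (the saved count of 5 is not surfaced)
//   HashAggregateExec     -> still emits a row for the global aggregate: count = 0
//   StateStoreSaveExec    -> saves 0, silently overwriting the correct count of 5
//
// The change to StateStoreRestoreExec below emits the saved state for global
// aggregations on an empty trigger, so the partial 0 is merged with the prior 5
// instead of replacing it.
```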

## How was this patch tested?

Regression tests

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #19196 from brkyvz/sa-0.
brkyvz committed Sep 20, 2017
1 parent c6ff59a commit 280ff52
Showing 6 changed files with 395 additions and 21 deletions.
@@ -21,11 +21,13 @@ import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.streaming.OutputMode

/**
@@ -89,7 +91,7 @@ class IncrementalExecution(
override def apply(plan: SparkPlan): SparkPlan = plan transform {
case StateStoreSaveExec(keys, None, None, None,
UnaryExecNode(agg,
StateStoreRestoreExec(keys2, None, child))) =>
StateStoreRestoreExec(_, None, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
StateStoreSaveExec(
keys,
@@ -117,8 +119,34 @@
}
}

override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
override def preparations: Seq[Rule[SparkPlan]] =
Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations

/** No need assert supported, as this check has already been done */
override def assertSupported(): Unit = { }
}

object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
// Needs to be transformUp to avoid extra shuffles
override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
case so: StatefulOperator =>
val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
val distributions = so.requiredChildDistribution
val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
val expectedPartitioning = reqDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
case _ => throw new AnalysisException("Unexpected distribution expected for " +
s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
s"$reqDistribution.")
}
if (child.outputPartitioning.guarantees(expectedPartitioning) &&
child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
child
} else {
ShuffleExchange(expectedPartitioning, child)
}
}
so.withNewChildren(children)
}
}
@@ -829,6 +829,7 @@ class StreamExecution(
if (streamDeathCause != null) {
throw streamDeathCause
}
if (!isActive) return
awaitBatchLock.lock()
try {
noNewData = false
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -200,18 +200,35 @@ case class StateStoreRestoreExec(
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
iter.flatMap { row =>
val key = getKey(row)
val savedState = store.get(key)
numOutputRows += 1
row +: Option(savedState).toSeq
val hasInput = iter.hasNext
if (!hasInput && keyExpressions.isEmpty) {
// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
// the `HashAggregateExec` will output a 0 value for the partial merge. We need to
// restore the value, so that we don't overwrite our state with a 0 value, but rather
// merge the 0 with existing state.
store.iterator().map(_.value)
} else {
iter.flatMap { row =>
val key = getKey(row)
val savedState = store.get(key)
numOutputRows += 1
row +: Option(savedState).toSeq
}
}
}
}

override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = child.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions) :: Nil
}
}
}

/**
@@ -351,6 +368,14 @@ case class StateStoreSaveExec(
override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = child.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions) :: Nil
}
}
}

/** Physical operator for executing streaming Deduplicate. */
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.streaming

import java.util.UUID

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
import org.apache.spark.sql.test.SharedSQLContext

class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {

import testImplicits._
super.beforeAll()

srowen (Member) commented on Sep 21, 2017:

@brkyvz this test is actually failing consistently in master -- it's actually manually calling beforeAll and tests in the constructor. I have a fix I can submit


private val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")

testEnsureStatefulOpPartitioning(
"ClusteredDistribution generates Exchange with HashPartitioning",
baseDf.queryExecution.sparkPlan,
requiredDistribution = keys => ClusteredDistribution(keys),
expectedPartitioning =
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
expectShuffle = true)

testEnsureStatefulOpPartitioning(
"ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning",
baseDf.coalesce(1).queryExecution.sparkPlan,
requiredDistribution = keys => ClusteredDistribution(keys),
expectedPartitioning =
keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
expectShuffle = true)

testEnsureStatefulOpPartitioning(
"AllTuples generates Exchange with SinglePartition",
baseDf.queryExecution.sparkPlan,
requiredDistribution = _ => AllTuples,
expectedPartitioning = _ => SinglePartition,
expectShuffle = true)

testEnsureStatefulOpPartitioning(
"AllTuples with coalesce(1) doesn't need Exchange",
baseDf.coalesce(1).queryExecution.sparkPlan,
requiredDistribution = _ => AllTuples,
expectedPartitioning = _ => SinglePartition,
expectShuffle = false)

/**
* For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
* `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
* ensure the expected partitioning.
*/
private def testEnsureStatefulOpPartitioning(
testName: String,
inputPlan: SparkPlan,
requiredDistribution: Seq[Attribute] => Distribution,
expectedPartitioning: Seq[Attribute] => Partitioning,
expectShuffle: Boolean): Unit = {
test(testName) {
val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
val executed = executePlan(operator, OutputMode.Complete())
if (expectShuffle) {
val exchange = executed.children.find(_.isInstanceOf[Exchange])
if (exchange.isEmpty) {
fail(s"Was expecting an exchange but didn't get one in:\n$executed")
}
assert(exchange.get ===
ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
s"Exchange didn't have expected properties:\n${exchange.get}")
} else {
assert(!executed.children.exists(_.isInstanceOf[Exchange]),
s"Unexpected exchange found in:\n$executed")
}
}
}

/** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
private def executePlan(
p: SparkPlan,
outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
val execution = new IncrementalExecution(
spark,
null,
OutputMode.Complete(),
"chk",
UUID.randomUUID(),
0L,
OffsetSeqMetadata()) {
override lazy val sparkPlan: SparkPlan = p transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
}
execution.executedPlan
}
}

/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
case class TestStatefulOperator(
child: SparkPlan,
requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
override def output: Seq[Attribute] = child.output
override def doExecute(): RDD[InternalRow] = child.execute()
override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
override def stateInfo: Option[StatefulOperatorStateInfo] = None
}
@@ -167,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
case class StartStream(
trigger: Trigger = Trigger.ProcessingTime(0),
triggerClock: Clock = new SystemClock,
additionalConfs: Map[String, String] = Map.empty)
additionalConfs: Map[String, String] = Map.empty,
checkpointLocation: String = null)
extends StreamAction

/** Advance the trigger clock's time manually. */
@@ -349,20 +350,22 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
""".stripMargin)
}

val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
var manualClockExpectedTime = -1L
val defaultCheckpointLocation =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
try {
startedTest.foreach { action =>
logInfo(s"Processing test stream action: $action")
action match {
case StartStream(trigger, triggerClock, additionalConfs) =>
case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) =>
verify(currentStream == null, "stream already running")
verify(triggerClock.isInstanceOf[SystemClock]
|| triggerClock.isInstanceOf[StreamManualClock],
"Use either SystemClock or StreamManualClock to start the stream")
if (triggerClock.isInstanceOf[StreamManualClock]) {
manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
}
val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)

additionalConfs.foreach(pair => {
val value =
@@ -479,7 +482,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
verify(currentStream != null || lastStream != null,
"cannot assert when no stream has been started")
val streamToAssert = Option(currentStream).getOrElse(lastStream)
verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
try {
verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
} catch {
case NonFatal(e) =>
failTest(s"Assert on query failed: ${a.message}", e)
}

case a: Assert =>
val streamToAssert = Option(currentStream).getOrElse(lastStream)
