[SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task

## What changes were proposed in this pull request?

A structured streaming query with a streaming aggregation can throw the following error in rare cases. 

```
java.lang.IllegalStateException: Cannot commit after already committed or aborted
	at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643)
	at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359)
	at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336)
```

This can happen when the following conditions are all hit. 
 - Streaming aggregation with an aggregation function that is a subclass of [`TypedImperativeAggregate`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).
 - Query running in `update` mode.
 - After the shuffle, a partition has exactly 128 records (the threshold at which `ObjectHashAggregateExec` falls back to sort-based aggregation).

This causes `StateStore.commit` to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The fix is to use `NextIterator` or `CompletionIterator`, each of which has a flag that prevents its completion logic from running more than once. In this PR, I chose to implement the fix using `NextIterator`.
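
For illustration, here is a minimal, self-contained sketch of that guard-flag idea. The names below (`OnceClosingIterator`, the demo object) are invented for this example and this is not Spark's actual `NextIterator`; it only shows how a single flag ensures the completion logic, which is where `StateStore.commit` ends up, runs at most once even if `hasNext` is called again after the iterator is exhausted.

```scala
// Minimal sketch of the guard-flag idea behind NextIterator / CompletionIterator.
// Illustration only; names are invented and this is not Spark's actual implementation.
abstract class OnceClosingIterator[A] extends Iterator[A] {
  protected var finished = false   // set by getNext() when the source is exhausted
  private var gotNext = false      // whether nextValue holds an unconsumed element
  private var closed = false       // guard: close() runs at most once
  private var nextValue: A = _

  /** Produce the next element, or set `finished = true` when there is none. */
  protected def getNext(): A

  /** Completion logic (e.g. a state store commit); guaranteed to run at most once. */
  protected def close(): Unit

  private def closeIfNeeded(): Unit = {
    if (!closed) {                 // the flag makes repeated hasNext() calls after exhaustion safe
      closed = true
      close()
    }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()
      gotNext = true
    }
    !finished
  }

  override def next(): A = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}

// Even if a consumer (like ObjectAggregationIterator) calls hasNext again after the
// iterator is drained, the completion logic runs exactly once.
object OnceClosingIteratorDemo extends App {
  var commits = 0
  val source = (1 to 3).iterator
  val it = new OnceClosingIterator[Int] {
    override protected def getNext(): Int =
      if (source.hasNext) source.next() else { finished = true; -1 }
    override protected def close(): Unit = commits += 1   // stands in for store.commit()
  }
  it.foreach(_ => ())   // drain the iterator; close() fires when exhaustion is detected
  it.hasNext            // an extra hasNext after exhaustion does NOT trigger a second close()
  assert(commits == 1)
}
```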

## How was this patch tested?

Added a unit test that I have confirmed fails without the fix.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #21124 from tdas/SPARK-23004.

(cherry picked from commit 770add8)
Signed-off-by: Tathagata Das <tathagata.das1565@gmail.com>
tdas committed Apr 23, 2018
1 parent c2f4ee7 commit 8eb9a41
Showing 2 changed files with 44 additions and 21 deletions.
`statefulOperators.scala`:

```diff
@@ -340,37 +340,35 @@ case class StateStoreSaveExec(
         // Update and output modified rows from the StateStore.
         case Some(Update) =>
 
-          val updatesStartTimeNs = System.nanoTime
-
-          new Iterator[InternalRow] {
-
+          new NextIterator[InternalRow] {
             // Filter late date using watermark if specified
             private[this] val baseIterator = watermarkPredicateForData match {
               case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
               case None => iter
             }
+            private val updatesStartTimeNs = System.nanoTime
 
-            override def hasNext: Boolean = {
-              if (!baseIterator.hasNext) {
-                allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
-                // Remove old aggregates if watermark specified
-                allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
-                commitTimeMs += timeTakenMs { store.commit() }
-                setStoreMetrics(store)
-                false
+            override protected def getNext(): InternalRow = {
+              if (baseIterator.hasNext) {
+                val row = baseIterator.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numOutputRows += 1
+                numUpdatedStateRows += 1
+                row
               } else {
-                true
+                finished = true
+                null
               }
             }
 
-            override def next(): InternalRow = {
-              val row = baseIterator.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key, row)
-              numOutputRows += 1
-              numUpdatedStateRows += 1
-              row
+            override protected def close(): Unit = {
+              allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+              // Remove old aggregates if watermark specified
+              allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+              commitTimeMs += timeTakenMs { store.commit() }
+              setStoreMetrics(store)
             }
           }
```
`StreamingAggregationSuite.scala`:

```diff
@@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
+  test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") {
+    // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
+    // by ensuring the following.
+    // - A streaming query with a streaming aggregation.
+    // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate.
+    // - Post shuffle partition has exactly 128 records (i.e. the threshold at which
+    //   ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a
+    //   micro-batch with 128 records that shuffle to a single partition.
+    // This test throws the exact error reported in SPARK-23004 without the corresponding fix.
+    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+      val input = MemoryStream[Int]
+      val df = input.toDF().toDF("value")
+        .selectExpr("value as group", "value")
+        .groupBy("group")
+        .agg(collect_list("value"))
+      testStream(df, outputMode = OutputMode.Update)(
+        AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
+        AssertOnQuery { q =>
+          q.processAllAvailable()
+          true
+        }
+      )
+    }
+  }
+
   /** Add blocks of data to the `BlockRDDBackedSource`. */
   case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
     override def addData(query: Option[StreamExecution]): (Source, Offset) = {
```
