Commit

tdas committed Apr 23, 2018
1 parent 0b19122 commit a78ba37
Showing 3 changed files with 46 additions and 21 deletions.
@@ -628,4 +628,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       throw new IllegalStateException(msg)
     }
   }
+
+  override protected def logName: String = s"${super.logName} $stateStoreId"
 }
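
The one-line addition above appends the stateStoreId to the logger name that the provider inherits through Spark's Logging trait, so every log line identifies which store instance produced it. Below is a minimal sketch of that pattern; SimpleLogging and FakeStateStoreProvider are illustrative stand-ins, not Spark classes, and only assume the documented behavior that the Logging trait derives its logger name from logName.

import org.slf4j.LoggerFactory

// A stripped-down stand-in for Spark's Logging trait: the logger name comes from logName.
trait SimpleLogging {
  protected def logName: String = this.getClass.getName.stripSuffix("$")
  private lazy val log = LoggerFactory.getLogger(logName)
  protected def logInfo(msg: => String): Unit = if (log.isInfoEnabled) log.info(msg)
}

// Same pattern as the diff: append the per-instance store id to the inherited log name,
// so "Loading state" below is logged under a name that includes that id.
class FakeStateStoreProvider(stateStoreId: String) extends SimpleLogging {
  override protected def logName: String = s"${super.logName} $stateStoreId"
  def load(): Unit = logInfo("Loading state")
}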
@@ -340,37 +340,35 @@ case class StateStoreSaveExec(
         // Update and output modified rows from the StateStore.
         case Some(Update) =>
 
-          val updatesStartTimeNs = System.nanoTime
-
-          new Iterator[InternalRow] {
-
+          new NextIterator[InternalRow] {
             // Filter late date using watermark if specified
             private[this] val baseIterator = watermarkPredicateForData match {
               case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
               case None => iter
             }
+            private val updatesStartTimeNs = System.nanoTime
 
-            override def hasNext: Boolean = {
-              if (!baseIterator.hasNext) {
-                allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
-                // Remove old aggregates if watermark specified
-                allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
-                commitTimeMs += timeTakenMs { store.commit() }
-                setStoreMetrics(store)
-                false
+            override protected def getNext(): InternalRow = {
+              if (baseIterator.hasNext) {
+                val row = baseIterator.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numOutputRows += 1
+                numUpdatedStateRows += 1
+                row
               } else {
-                true
+                finished = true
+                null
               }
             }
 
-            override def next(): InternalRow = {
-              val row = baseIterator.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key, row)
-              numOutputRows += 1
-              numUpdatedStateRows += 1
-              row
+            override protected def close(): Unit = {
+              allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+              // Remove old aggregates if watermark specified
+              allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+              commitTimeMs += timeTakenMs { store.commit() }
+              setStoreMetrics(store)
             }
           }

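The change above replaces a hand-rolled Iterator, whose hasNext committed the state store every time it was called after the input ran out, with Spark's NextIterator: getNext() produces rows, setting finished = true signals the end, and the close() hook runs at most once afterwards, which is what guarantees a single store.commit() per task (the SPARK-23004 symptom exercised by the test further down). The following is a simplified, self-contained sketch of that contract; it assumes this is how org.apache.spark.util.NextIterator behaves and is not Spark's actual source.

// Sketch of the NextIterator contract: close() is guarded so it fires exactly once.
abstract class SketchNextIterator[U] extends Iterator[U] {
  private var gotNext = false
  private var closed = false
  private var nextValue: U = _
  protected var finished = false

  /** Produce the next element, or set finished = true when there is nothing left. */
  protected def getNext(): U

  /** One-shot clean-up hook; in the diff above, store.commit() now lives in the override of this. */
  protected def close(): Unit

  private def closeIfNeeded(): Unit = if (!closed) { closed = true; close() }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()   // runs once, no matter how often hasNext is re-invoked
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}

// Tiny demo: yields 1, 2, 3 and prints "closed once" a single time.
object SketchNextIteratorDemo extends App {
  val it = new SketchNextIterator[Int] {
    private var i = 0
    override protected def getNext(): Int = {
      i += 1
      if (i > 3) { finished = true; -1 } else i   // the -1 is ignored, like the null in the diff
    }
    override protected def close(): Unit = println("closed once")
  }
  println(it.toList)   // List(1, 2, 3)
}

In the old shape, a consumer that probed hasNext again after exhaustion would re-enter the commit branch; with close() guarded by the closed flag, the watermark eviction, commit, and metric updates run once per task.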
@@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
+  test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") {
+    // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
+    // by ensuring the following.
+    // - A streaming query with a streaming aggregation.
+    // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate.
+    // - Post shuffle partition has exactly 128 records (i.e. the threshold at which
+    //   ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a
+    //   micro-batch with 128 records that shuffle to a single partition.
+    // This test throws the exact error reported in SPARK-23004 without the corresponding fix.
+    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+      val input = MemoryStream[Int]
+      val df = input.toDF().toDF("value")
+        .selectExpr("value as group", "value")
+        .groupBy("group")
+        .agg(collect_list("value"))
+      testStream(df, outputMode = OutputMode.Update)(
+        AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
+        AssertOnQuery { q =>
+          q.processAllAvailable()
+          true
+        }
+      )
+    }
+  }
+
   /** Add blocks of data to the `BlockRDDBackedSource`. */
   case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
     override def addData(query: Option[StreamExecution]): (Source, Offset) = {
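The new test pushes exactly objectAggSortBasedFallbackThreshold rows (128, per the comment above) through a single shuffle partition so that ObjectHashAggregateExec, running the TypedImperativeAggregate behind collect_list, switches to its sort-based fallback, the code path that hit the double commit fixed above. The same fallback can be poked from a plain batch job by lowering the threshold instead of generating 128 rows; the sketch below is only illustrative, and it assumes spark.sql.objectHashAggregate.sortBased.fallbackThreshold is the SQL conf key behind that threshold.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.collect_list

// Illustrative batch-only sketch (not part of the commit): force the sort-based
// fallback of ObjectHashAggregateExec by lowering the threshold instead of
// feeding 128 rows. The config key is an assumption noted above.
object FallbackSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("object-agg-fallback-sketch")
      .config("spark.sql.shuffle.partitions", "1")
      .getOrCreate()
    import spark.implicits._

    // With the threshold at 2, ten distinct groups exceed it and trigger the fallback.
    spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "2")

    val result = (1 to 10).toDF("value")
      .selectExpr("value as group", "value")
      .groupBy("group")
      .agg(collect_list("value"))
      .collect()

    result.foreach(println)   // each group carries a single-element list
    spark.stop()
  }
}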