@@ -81,6 +81,8 @@ case class FlatMapGroupsWithStateExec(

override def keyExpressions: Seq[Attribute] = groupingAttributes

override def shortName: String = "flatMapGroupsWithState"

override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
timeoutConf match {
case ProcessingTimeTimeout =>
@@ -115,10 +117,13 @@ case class FlatMapGroupsWithStateExec(
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val commitTimeMs = longMetric("commitTimeMs")
val updatesStartTimeNs = System.nanoTime

val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
val processor = new InputProcessor(store)

val currentTimeNs = System.nanoTime
val updatesStartTimeNs = currentTimeNs
var timeoutProcessingStartTimeNs = currentTimeNs

// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
case Some(predicate) if timeoutConf == EventTimeTimeout =>
@@ -127,12 +132,26 @@
iter
}

val newDataProcessorIter =
CompletionIterator[InternalRow, Iterator[InternalRow]](
processor.processNewData(filteredIter), {
// Once the input is processed, mark the start time for timeout processing to measure
// it separately from the overall processing time.
timeoutProcessingStartTimeNs = System.nanoTime
})

val timeoutProcessorIter =
CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
// Note: `timeoutLatencyMs` also includes the time the parent operator took for
// processing output returned through iterator.
timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
})

// Generate an iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator = processor.processNewData(filteredIter) ++
processor.processTimedOutState()
val outputIterator = newDataProcessorIter ++ timeoutProcessorIter

// Return an iterator of all the rows generated by all the keys, such that when fully
// consumed, all the state updates will be committed by the state store
@@ -144,6 +163,7 @@ case class FlatMapGroupsWithStateExec(
store.commit()
}
setStoreMetrics(store)
setOperatorMetrics()
}
)
}
@@ -162,6 +182,7 @@ case class FlatMapGroupsWithStateExec(
// Metrics
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
private val numOutputRows = longMetric("numOutputRows")
private val numRemovedStateRows = longMetric("numRemovedStateRows")

/**
* For every group, get the key, values and corresponding state and call the function,
@@ -231,7 +252,7 @@ case class FlatMapGroupsWithStateExec(
def onIteratorCompletion: Unit = {
if (groupState.isRemoved && !groupState.getTimeoutTimestampMs.isPresent()) {
stateManager.removeState(store, stateData.keyRow)
numUpdatedStateRows += 1
numRemovedStateRows += 1
} else {
val currentTimeoutTimestamp = groupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP)
val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
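
The timing split above relies on two properties that the diff itself points out: CompletionIterator runs its completion function exactly once, after the wrapped iterator is exhausted, and Iterator.++ evaluates its right-hand side lazily, so timed-out state is only scanned after every new-data row has been produced and the per-key timeout timestamps are up to date. The following minimal sketch shows that pattern in isolation; the integer iterators stand in for processNewData/processTimedOutState, and since CompletionIterator is Spark-internal (private[spark]) the snippet assumes it lives under the org.apache.spark package. All names here are illustrative, not part of the patch.

// Assumption: placed under org.apache.spark so the private[spark] CompletionIterator is visible.
package org.apache.spark.example

import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.spark.util.CompletionIterator

object TwoPhaseTimingSketch {
  def main(args: Array[String]): Unit = {
    var timeoutProcessingStartTimeNs = System.nanoTime

    val newData = Iterator(1, 2, 3)   // stands in for processor.processNewData(filteredIter)
    val timedOut = Iterator(4, 5)     // stands in for processor.processTimedOutState()

    val newDataIter = CompletionIterator[Int, Iterator[Int]](newData, {
      // Runs exactly once, after the last new-data element has been produced.
      timeoutProcessingStartTimeNs = System.nanoTime
    })

    val timeoutIter = CompletionIterator[Int, Iterator[Int]](timedOut, {
      // Measured from the end of the first phase; as the note in the patch says, this also
      // includes time the downstream consumer spends on the timed-out rows.
      val timeoutLatencyMs = NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
      println(s"timeout processing: $timeoutLatencyMs ms")
    })

    // ++ only touches its right operand once the left side is exhausted, which in the real
    // operator guarantees timeouts are evaluated only after all new data has been processed.
    (newDataIter ++ timeoutIter).foreach(_ => ())
  }
}
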
@@ -207,6 +207,8 @@ case class StreamingSymmetricHashJoinExec(
case _ => throwBadJoinTypeException()
}

override def shortName: String = "symmetricHashJoin"

override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
val watermarkUsedForStateCleanup =
stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty
@@ -221,6 +223,7 @@
protected override def doExecute(): RDD[InternalRow] = {
val stateStoreCoord = session.sessionState.streamingQueryManager.stateStoreCoordinator
val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
metrics // initialize metrics
left.execute().stateStoreAwareZipPartitions(
right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions)
}
@@ -237,6 +240,7 @@
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val numTotalStateRows = longMetric("numTotalStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val numRemovedStateRows = longMetric("numRemovedStateRows")
val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
val commitTimeMs = longMetric("commitTimeMs")
val stateMemory = longMetric("stateMemory")
@@ -407,6 +411,7 @@
}
while (cleanupIter.hasNext) {
cleanupIter.next()
numRemovedStateRows += 1
}
}

@@ -425,6 +430,9 @@
longMetric(metric.name) += value
}
}

val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
setOperatorMetrics(numStateStoreInstances = stateStoreNames.length)
}

CompletionIterator[InternalRow, Iterator[InternalRow]](
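
Two details of the join changes above are easy to miss. The bare `metrics` reference in doExecute presumably forces the lazy metrics map to be materialized on the driver before the partition-processing closure is created, so the new counters are registered there rather than first touched inside a task. And unlike the single-store operators, each join task manages one state store per (join side, store type) pair, which is why it passes stateStoreNames.length to setOperatorMetrics instead of relying on the default of 1. A rough sketch of that per-task accounting follows; the store names and the assumption of two internal store types per side are illustrative, not taken from the patch.

// Illustrative only: real names come from SymmetricHashJoinStateManager.allStateStoreNames.
object JoinStoreAccountingSketch {
  def main(args: Array[String]): Unit = {
    val sides = Seq("left", "right")
    val storeTypes = Seq("keyToNumValues", "keyWithIndexToValue")  // assumption: two store types
    val stateStoreNames = for (s <- sides; t <- storeTypes) yield s"$s-$t"

    // Each join task would report one shuffle partition and stateStoreNames.length instances.
    val perTaskShufflePartitions = 1
    val perTaskStoreInstances = stateStoreNames.length
    println(s"per task: $perTaskShufflePartitions partition, $perTaskStoreInstances stores")
  }
}
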
@@ -101,9 +101,13 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
"numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
"numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"),
"allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to update"),
"numRemovedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of removed state rows"),
"allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to remove"),
"commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes"),
"stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state")
"stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by state"),
"numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of shuffle partitions"),
"numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
"number of state store instances")
) ++ stateStoreCustomMetrics

/**
@@ -118,17 +122,33 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
new java.util.HashMap(customMetrics.mapValues(long2Long).toMap.asJava)

new StateOperatorProgress(
operatorName = shortName,
numRowsTotal = longMetric("numTotalStateRows").value,
numRowsUpdated = longMetric("numUpdatedStateRows").value,
allUpdatesTimeMs = longMetric("allUpdatesTimeMs").value,
numRowsRemoved = longMetric("numRemovedStateRows").value,
allRemovalsTimeMs = longMetric("allRemovalsTimeMs").value,
commitTimeMs = longMetric("commitTimeMs").value,
memoryUsedBytes = longMetric("stateMemory").value,
numRowsDroppedByWatermark = longMetric("numRowsDroppedByWatermark").value,
numShufflePartitions = longMetric("numShufflePartitions").value,
numStateStoreInstances = longMetric("numStateStoreInstances").value,
javaConvertedCustomMetrics
)
}

/** Records the duration of running `body` for the next query progress update. */
protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2

/** Set the operator level metrics */
protected def setOperatorMetrics(numStateStoreInstances: Int = 1): Unit = {
assert(numStateStoreInstances >= 1, s"invalid number of stores: $numStateStoreInstances")
// Shuffle partitions capture the number of tasks that have this stateful operator instance.
// For each task instance this number is incremented by one.
longMetric("numShufflePartitions") += 1
longMetric("numStateStoreInstances") += numStateStoreInstances
}

/**
* Set the SQL metrics related to the state store.
* This should be called in the task after the store has been updated.
@@ -172,6 +192,9 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
}
}

/** Name to output in [[StateOperatorProgress]] to identify operator type */
protected def shortName: String = "defaultName"

/**
* Should the MicroBatchExecution run another batch based on this stateful operator and the
* current updated metadata.
@@ -210,9 +233,11 @@ trait WatermarkSupport extends UnaryExecNode {

protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
if (watermarkPredicateForKeys.nonEmpty) {
val numRemovedStateRows = longMetric("numRemovedStateRows")
store.getRange(None, None).foreach { rowPair =>
if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
store.remove(rowPair.key)
numRemovedStateRows += 1
}
}
}
@@ -222,9 +247,11 @@
storeManager: StreamingAggregationStateManager,
store: StateStore): Unit = {
if (watermarkPredicateForKeys.nonEmpty) {
val numRemovedStateRows = longMetric("numRemovedStateRows")
storeManager.keys(store).foreach { keyRow =>
if (watermarkPredicateForKeys.get.eval(keyRow)) {
storeManager.remove(store, keyRow)
numRemovedStateRows += 1
}
}
}
@@ -345,6 +372,7 @@ case class StateStoreSaveExec(
val numOutputRows = longMetric("numOutputRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val numRemovedStateRows = longMetric("numRemovedStateRows")
val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
val commitTimeMs = longMetric("commitTimeMs")

@@ -363,6 +391,7 @@
stateManager.commit(store)
}
setStoreMetrics(store)
setOperatorMetrics()
stateManager.values(store).map { valueRow =>
numOutputRows += 1
valueRow
@@ -391,6 +420,7 @@ case class StateStoreSaveExec(
val rowPair = rangeIter.next()
if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
stateManager.remove(store, rowPair.key)
numRemovedStateRows += 1
removedValueRow = rowPair.value
}
}
@@ -404,9 +434,12 @@
}

override protected def close(): Unit = {
// Note: Due to the iterator lazy exec, this metric also captures the time taken
// by the consumer operators in addition to the processing in this operator.
allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
commitTimeMs += timeTakenMs { stateManager.commit(store) }
setStoreMetrics(store)
setOperatorMetrics()
}
}

@@ -443,6 +476,7 @@ case class StateStoreSaveExec(
}
commitTimeMs += timeTakenMs { stateManager.commit(store) }
setStoreMetrics(store)
setOperatorMetrics()
}
}

@@ -463,6 +497,8 @@
}
}

override def shortName: String = "stateStoreSave"

override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
(outputMode.contains(Append) || outputMode.contains(Update)) &&
eventTimeWatermark.isDefined &&
@@ -534,6 +570,7 @@ case class StreamingDeduplicateExec(
allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
commitTimeMs += timeTakenMs { store.commit() }
setStoreMetrics(store)
setOperatorMetrics()
})
}
}
@@ -546,6 +583,8 @@
Seq(StatefulOperatorCustomSumMetric("numDroppedDuplicateRows", "number of duplicates dropped"))
}

override def shortName: String = "dedupe"

override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
}
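
The setOperatorMetrics helper added above increments numShufflePartitions by 1 and numStateStoreInstances by the per-task store count in every task, and since these are SQLMetrics (Spark accumulators) the increments are summed on the driver before being copied into StateOperatorProgress. A small worked sketch of the resulting totals follows; the numbers are arbitrary and only illustrate the arithmetic, they are not taken from the patch.

// Arithmetic sketch of the driver-side totals reported in StateOperatorProgress.
object OperatorMetricsSketch {
  def main(args: Array[String]): Unit = {
    val tasks = 200            // e.g. spark.sql.shuffle.partitions for the stateful shuffle
    val storesPerTask = 1      // single-store operators; the stream-stream join passes more

    // setOperatorMetrics adds 1 and storesPerTask per task, so the summed metrics are:
    val numShufflePartitions = tasks * 1
    val numStateStoreInstances = tasks * storesPerTask

    println(s"numShufflePartitions=$numShufflePartitions, " +
      s"numStateStoreInstances=$numStateStoreInstances")
  }
}
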
@@ -82,6 +82,7 @@ case class StreamingGlobalLimitExec(
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
commitTimeMs += timeTakenMs { store.commit() }
setStoreMetrics(store)
setOperatorMetrics()
})
}
}
@@ -96,6 +97,8 @@
UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value)))
}

override def shortName: String = "globalLimit"

override protected def withNewChildInternal(newChild: SparkPlan): StreamingGlobalLimitExec =
copy(child = newChild)
}
@@ -41,10 +41,17 @@ import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS
*/
@Evolving
class StateOperatorProgress private[sql](
val operatorName: String,
val numRowsTotal: Long,
val numRowsUpdated: Long,
val allUpdatesTimeMs: Long,
val numRowsRemoved: Long,
val allRemovalsTimeMs: Long,
val commitTimeMs: Long,
val memoryUsedBytes: Long,
val numRowsDroppedByWatermark: Long,
val numShufflePartitions: Long,
val numStateStoreInstances: Long,
[Reviewer comment] This is detected as a binary incompatibility. It will be okay because this is Evolving.

val customMetrics: ju.Map[String, JLong] = new ju.HashMap()
) extends Serializable {

@@ -57,14 +64,26 @@ class StateOperatorProgress private[sql](
private[sql] def copy(
newNumRowsUpdated: Long,
newNumRowsDroppedByWatermark: Long): StateOperatorProgress =
new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes,
newNumRowsDroppedByWatermark, customMetrics)
new StateOperatorProgress(
operatorName = operatorName, numRowsTotal = numRowsTotal, numRowsUpdated = newNumRowsUpdated,
allUpdatesTimeMs = allUpdatesTimeMs, numRowsRemoved = numRowsRemoved,
allRemovalsTimeMs = allRemovalsTimeMs, commitTimeMs = commitTimeMs,
memoryUsedBytes = memoryUsedBytes, numRowsDroppedByWatermark = newNumRowsDroppedByWatermark,
numShufflePartitions = numShufflePartitions, numStateStoreInstances = numStateStoreInstances,
customMetrics = customMetrics)

private[sql] def jsonValue: JValue = {
("operatorName" -> JString(operatorName)) ~
("numRowsTotal" -> JInt(numRowsTotal)) ~
("numRowsUpdated" -> JInt(numRowsUpdated)) ~
("allUpdatesTimeMs" -> JInt(allUpdatesTimeMs)) ~
("numRowsRemoved" -> JInt(numRowsRemoved)) ~
("allRemovalsTimeMs" -> JInt(allRemovalsTimeMs)) ~
("commitTimeMs" -> JInt(commitTimeMs)) ~
("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~
("numRowsDroppedByWatermark" -> JInt(numRowsDroppedByWatermark)) ~
("numShufflePartitions" -> JInt(numShufflePartitions)) ~
("numStateStoreInstances" -> JInt(numStateStoreInstances)) ~
("customMetrics" -> {
if (!customMetrics.isEmpty) {
val keys = customMetrics.keySet.asScala.toSeq.sorted
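
With the new fields in place, the extra information is exposed both through the progress JSON built by jsonValue and on the typed StateOperatorProgress objects returned by the public StreamingQuery API. The following hypothetical monitoring snippet is only a sketch of how the new fields could be read; the session setup and the query name "agg" are assumptions, not part of this patch.

import org.apache.spark.sql.SparkSession

object StateProgressPrinter {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("state-progress-sketch").getOrCreate()

    // Assumes a streaming query named "agg" has already been started in this session.
    val query = spark.streams.active.find(_.name == "agg").get

    Option(query.lastProgress).foreach { progress =>
      progress.stateOperators.foreach { op =>
        println(s"${op.operatorName}: updated=${op.numRowsUpdated}, " +
          s"removed=${op.numRowsRemoved}, shufflePartitions=${op.numShufflePartitions}, " +
          s"storeInstances=${op.numStateStoreInstances}")
      }
    }
  }
}

The same keys (operatorName, numRowsRemoved, numShufflePartitions, numStateStoreInstances) appear in the jsonValue output above, so dashboards that consume the query progress JSON pick up the new information without code changes.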