diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index bb566ba925bf7..2dd91decfa99f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -530,12 +530,13 @@ case class SessionWindowStateStoreRestoreExec( override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") + private val stateManager = StreamingSessionWindowStateManager.createStateManager( keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") child.execute().mapPartitionsWithReadStateStore( getStateInfo, @@ -558,8 +559,8 @@ case class SessionWindowStateStoreRestoreExec( keyWithoutSessionExpressions, sessionExpression, child.output).map { row => - numOutputRows += 1 - row + numOutputRows += 1 + row } } } @@ -573,11 +574,7 @@ case class SessionWindowStateStoreRestoreExec( } override def requiredChildDistribution: Seq[Distribution] = { - if (keyWithoutSessionExpressions.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil - } + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = { @@ -592,7 +589,7 @@ case class SessionWindowStateStoreRestoreExec( * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. */ case class SessionWindowStateStoreSaveExec( - keyExpressions: Seq[Attribute], + keyWithoutSessionExpressions: Seq[Attribute], sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, @@ -601,9 +598,7 @@ case class SessionWindowStateStoreSaveExec( child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { - private val keyWithoutSessionExpressions = keyExpressions.filterNot { p => - p.semanticEquals(sessionExpression) - } + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions private val stateManager = StreamingSessionWindowStateManager.createStateManager( keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) @@ -624,6 +619,7 @@ case class SessionWindowStateStoreSaveExec( Some(session.streams.stateStoreCoordinator)) { case (store, iter) => val numOutputRows = longMetric("numOutputRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") val commitTimeMs = longMetric("commitTimeMs") @@ -632,9 +628,8 @@ case class SessionWindowStateStoreSaveExec( // Update and output all rows in the StateStore. case Some(Complete) => allUpdatesTimeMs += timeTakenMs { - putToStore(iter, store, false) + putToStore(iter, store) } - allRemovalsTimeMs += 0 commitTimeMs += timeTakenMs { stateManager.commit(store) } @@ -648,7 +643,9 @@ case class SessionWindowStateStoreSaveExec( // Assumption: watermark predicates must be non-empty if append mode is allowed case Some(Append) => allUpdatesTimeMs += timeTakenMs { - putToStore(iter, store, true) + val filteredIter = applyRemovingRowsOlderThanWatermark(iter, + watermarkPredicateForData.get) + putToStore(filteredIter, store) } val removalStartTimeNs = System.nanoTime @@ -661,6 +658,7 @@ case class SessionWindowStateStoreSaveExec( finished = true null } else { + numRemovedStateRows += 1 numOutputRows += 1 removedIter.next() } @@ -670,17 +668,25 @@ case class SessionWindowStateStoreSaveExec( allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) commitTimeMs += timeTakenMs { store.commit() } setStoreMetrics(store) + setOperatorMetrics() } } case Some(Update) => - val iterPutToStore = iteratorPutToStore(iter, store, true, true) + val baseIterator = watermarkPredicateForData match { + case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) + case None => iter + } + val iterPutToStore = iteratorPutToStore(baseIterator, store, + returnOnlyUpdatedRows = true) new NextIterator[InternalRow] { private val updatesStartTimeNs = System.nanoTime override protected def getNext(): InternalRow = { if (iterPutToStore.hasNext) { - iterPutToStore.next() + val row = iterPutToStore.next() + numOutputRows += 1 + row } else { finished = true null @@ -695,16 +701,18 @@ case class SessionWindowStateStoreSaveExec( val removedIter = stateManager.removeByValueCondition( store, watermarkPredicateForData.get.eval) while (removedIter.hasNext) { + numRemovedStateRows += 1 removedIter.next() } } } commitTimeMs += timeTakenMs { store.commit() } setStoreMetrics(store) + setOperatorMetrics() } } - case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode") + case _ => throw QueryExecutionErrors.invalidStreamingOutputModeError(outputMode) } } } @@ -714,11 +722,7 @@ case class SessionWindowStateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - if (keyExpressions.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil - } + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { @@ -728,17 +732,11 @@ case class SessionWindowStateStoreSaveExec( } private def iteratorPutToStore( - baseIter: Iterator[InternalRow], + iter: Iterator[InternalRow], store: StateStore, - needFilter: Boolean, returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = { val numUpdatedStateRows = longMetric("numUpdatedStateRows") val numRemovedStateRows = longMetric("numRemovedStateRows") - val iter = if (needFilter) { - baseIter.filter(row => !watermarkPredicateForData.get.eval(row)) - } else { - baseIter - } new NextIterator[InternalRow] { var curKey: UnsafeRow = null @@ -790,11 +788,8 @@ case class SessionWindowStateStoreSaveExec( } } - private def putToStore( - baseIter: Iterator[InternalRow], - store: StateStore, - needFilter: Boolean): Unit = { - val iterPutToStore = iteratorPutToStore(baseIter, store, needFilter, false) + private def putToStore(baseIter: Iterator[InternalRow], store: StateStore): Unit = { + val iterPutToStore = iteratorPutToStore(baseIter, store, returnOnlyUpdatedRows = false) while (iterPutToStore.hasNext) { iterPutToStore.next() }