diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index aa789af6f812f..12f8cffb6774a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -311,7 +311,7 @@ object AggUtils { val saved = StateStoreSaveExec( groupingAttributes, - stateId = None, + stateInfo = None, outputMode = None, eventTimeWatermark = None, partialMerged2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2d82fcf4da6e9..81bc93e7ebcf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.UUID + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -117,7 +119,8 @@ case class ExplainCommand( // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. new IncrementalExecution( - sparkSession, logicalPlan, OutputMode.Append(), "", 0, OffsetSeqMetadata(0, 0)) + sparkSession, logicalPlan, OutputMode.Append(), "", + UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 2aad8701a4eca..9dcac33b4107c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec( } child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, groupingAttributes.toStructType, stateAttributes.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 622e049630db2..ab89dc6b705d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging @@ -36,6 +37,7 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, + val runId: 
UUID, val currentBatchId: Long, offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { @@ -69,7 +71,13 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private val operatorId = new AtomicInteger(0) + private val statefulOperatorId = new AtomicInteger(0) + + /** Get the state info of the next stateful operator */ + private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo( + checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + } /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -78,35 +86,28 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - + val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, - Some(stateId), + Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, - Some(stateId), + Some(aggStateInfo), child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StreamingDeduplicateExec( keys, child, - Some(stateId), + Some(nextStatefulOperationStateInfo), Some(offsetSeqMetadata.batchWatermarkMs)) case m: FlatMapGroupsWithStateExec => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) m.copy( - stateId = Some(stateId), + stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 74f0f509bbf85..06bdec8b06407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -652,6 +652,7 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), + runId, currentBatchId, offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 67d86daf10812..bae7a15165e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null - override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId override def get(key: UnsafeRow): UnsafeRow = { mapToUpdate.get(key) @@ -177,7 +177,7 @@ private[state] class 
HDFSBackedStateStoreProvider extends StateStoreProvider wit /** * Whether all updates have been committed */ - override private[streaming] def hasCommitted: Boolean = { + override def hasCommitted: Boolean = { state == COMMITTED } @@ -205,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): Unit = { - this.stateStoreId = stateStoreId + this.stateStoreId_ = stateStoreId this.keySchema = keySchema this.valueSchema = valueSchema this.storeConf = storeConf @@ -213,7 +213,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit fs.mkdirs(baseDir) } - override def id: StateStoreId = stateStoreId + override def stateStoreId: StateStoreId = stateStoreId_ /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { @@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def toString(): String = { - s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[" + + s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" } /* Internal fields and methods */ - @volatile private var stateStoreId: StateStoreId = _ + @volatile private var stateStoreId_ : StateStoreId = _ @volatile private var keySchema: StructType = _ @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ private lazy val loadedMaps = new mutable.HashMap[Long, MapType] - private lazy val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fs = baseDir.getFileSystem(hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 29c456f86e1ed..a94ff8a7ebd1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID import java.util.concurrent.{ScheduledFuture, TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -24,14 +25,14 @@ import scala.collection.mutable import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} - /** * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific * version of state data, and such instances are created through a [[StateStoreProvider]]. 
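The baseDir change above delegates the on-disk layout to StateStoreId.storeCheckpointLocation(), which is added later in this patch in StateStore.scala. A minimal, dependency-free sketch of that layout rule, using java.nio.file.Path instead of Hadoop's Path and made-up *Sketch names purely for illustration; it shows how default-named stores keep the old <root>/<operatorId>/<partitionId> layout while named stores get an extra level:

import java.nio.file.{Path, Paths}

// Simplified stand-in for the StateStoreId case class added in this patch; the real one
// lives in org.apache.spark.sql.execution.streaming.state and returns a Hadoop Path.
case class StateStoreIdSketch(
    checkpointRootLocation: String,
    operatorId: Long,
    partitionId: Int,
    storeName: String = "default") {

  def storeCheckpointLocation(): Path = {
    if (storeName == "default") {
      // Layout used before store names existed (Spark <= 2.2), kept for the default store
      // so old checkpoints stay readable: <root>/<operatorId>/<partitionId>
      Paths.get(checkpointRootLocation, operatorId.toString, partitionId.toString)
    } else {
      // Named stores get one extra level: <root>/<operatorId>/<partitionId>/<storeName>
      Paths.get(checkpointRootLocation, operatorId.toString, partitionId.toString, storeName)
    }
  }
}

object CheckpointLayoutDemo extends App {
  println(StateStoreIdSketch("/tmp/chk/state", 0, 3).storeCheckpointLocation())          // /tmp/chk/state/0/3
  println(StateStoreIdSketch("/tmp/chk/state", 0, 3, "left").storeCheckpointLocation())  // /tmp/chk/state/0/3/left
}
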
@@ -99,7 +100,7 @@ trait StateStore { /** * Whether all updates have been committed */ - private[streaming] def hasCommitted: Boolean + def hasCommitted: Boolean } @@ -147,7 +148,7 @@ trait StateStoreProvider { * Return the id of the StateStores this provider will generate. * Should be the same as the one passed in init(). */ - def id: StateStoreId + def stateStoreId: StateStoreId /** Called when the provider instance is unloaded from the executor */ def close(): Unit @@ -179,13 +180,46 @@ object StateStoreProvider { } } +/** + * Unique identifier for a provider, used to identify when providers can be reused. + * Note that `queryRunId` is used to uniquely identify a provider, so that the same provider + * instance is not reused across query restarts. + */ +case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID) -/** Unique identifier for a bunch of keyed state data. */ +/** + * Unique identifier for a bunch of keyed state data. + * @param checkpointRootLocation Root directory where all the state data of a query is stored + * @param operatorId Unique id of a stateful operator + * @param partitionId Index of the partition of an operator's state data + * @param storeName Optional, name of the store. Each partition can optionally use multiple state + * stores, but they have to be identified by distinct names. + */ case class StateStoreId( - checkpointLocation: String, + checkpointRootLocation: String, operatorId: Long, partitionId: Int, - name: String = "") + storeName: String = StateStoreId.DEFAULT_STORE_NAME) { + + /** + * Checkpoint directory to be used by a single state store, identified uniquely by the tuple + * (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should + * use this path for saving state data, as this ensures that distinct stores will write to + * different locations. + */ + def storeCheckpointLocation(): Path = { + if (storeName == StateStoreId.DEFAULT_STORE_NAME) { + // For reading state store data that was generated before store names were used (Spark <= 2.2) + new Path(checkpointRootLocation, s"$operatorId/$partitionId") + } else { + new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName") + } + } +} + +object StateStoreId { + val DEFAULT_STORE_NAME = "default" +} /** Mutable, and reusable class for representing a pair of UnsafeRows. */ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { @@ -211,7 +245,7 @@ object StateStore extends Logging { val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 @GuardedBy("loadedProviders") - private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() /** * Runs the `task` periodically and automatically cancels it if there is an exception. `onError` @@ -253,7 +287,7 @@ object StateStore extends Logging { /** Get or create a store associated with the id. 
*/ def get( - storeId: StateStoreId, + storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -264,24 +298,24 @@ object StateStore extends Logging { val storeProvider = loadedProviders.synchronized { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( - storeId, + storeProviderId, StateStoreProvider.instantiate( - storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) ) - reportActiveStoreInstance(storeId) + reportActiveStoreInstance(storeProviderId) provider } storeProvider.getStore(version) } /** Unload a state store provider */ - def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId).foreach(_.close()) + def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeProviderId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ - def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { - loadedProviders.contains(storeId) + def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeProviderId) } def isMaintenanceRunning: Boolean = loadedProviders.synchronized { @@ -340,21 +374,21 @@ object StateStore extends Logging { } } - private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { + private def reportActiveStoreInstance(storeProviderId: StateStoreProviderId): Unit = { if (SparkEnv.get != null) { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) - logDebug(s"Reported that the loaded instance $storeId is active") + coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, executorId)) + logInfo(s"Reported that the loaded instance $storeProviderId is active") } } - private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { + private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = { if (SparkEnv.get != null) { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = - coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verified whether the loaded instance $storeId is active: $verified") + coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false) + logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified") verified } else { false @@ -364,12 +398,21 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - if (_coordRef == null) { + logInfo("Env is not null") + val isDriver = + env.executorId == SparkContext.DRIVER_IDENTIFIER || + env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in _coordRef may have become inactive + // as SparkContext + SparkEnv may have been restarted. Hence, when running in the driver, + // always recreate the reference. 
+ if (isDriver || _coordRef == null) { + logInfo("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } - logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { + logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d0f81887e62d1..3884f5e6ce766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.collection.mutable import org.apache.spark.SparkEnv @@ -29,16 +31,19 @@ import org.apache.spark.util.RpcUtils private sealed trait StateStoreCoordinatorMessage extends Serializable /** Classes representing messages */ -private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) +private case class ReportActiveInstance( + storeId: StateStoreProviderId, + host: String, + executorId: String) extends StateStoreCoordinatorMessage -private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) +private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executorId: String) extends StateStoreCoordinatorMessage -private case class GetLocation(storeId: StateStoreId) +private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage -private case class DeactivateInstances(checkpointLocation: String) +private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -80,25 +85,27 @@ object StateStoreCoordinatorRef extends Logging { class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( - storeId: StateStoreId, + stateStoreProviderId: StateStoreProviderId, host: String, executorId: String): Unit = { - rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) + rpcEndpointRef.send(ReportActiveInstance(stateStoreProviderId, host, executorId)) } /** Verify whether the given executor has the active instance of a state store */ - private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { - rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, executorId)) + private[state] def verifyIfInstanceActive( + stateStoreProviderId: StateStoreProviderId, + executorId: String): Boolean = { + rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, executorId)) } /** Get the location of the state store */ - private[state] def getLocation(storeId: StateStoreId): Option[String] = { - rpcEndpointRef.askSync[Option[String]](GetLocation(storeId)) + private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = { + rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId)) } - /** Deactivate instances related to a set of operator */ - private[state] def deactivateInstances(storeRootLocation: String): Unit = { - rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation)) + /** Deactivate instances related to a query */ + private[sql] def 
deactivateInstances(runId: UUID): Unit = { + rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } private[state] def stop(): Unit = { @@ -113,7 +120,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => @@ -135,11 +142,11 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) logDebug(s"Got location of the state store $id: $executorId") context.reply(executorId) - case DeactivateInstances(checkpointLocation) => + case DeactivateInstances(runId) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq + instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove - logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + logDebug(s"Deactivating instances related to run id $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index b744c25dc97a8..01d8e75980993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} @@ -34,8 +36,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, @@ -52,16 +54,25 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. 
+ */ override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) - storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + val stateStoreProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) + val storeProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + store = StateStore.get( - storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 228fe86d59940..a0086e251f9c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.TaskContext @@ -32,20 +34,14 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo, keySchema, valueSchema, indexOrdinal, @@ -56,10 +52,7 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -79,10 +72,10 @@ package object state { new StateStoreRDD( dataRDD, wrappedF, - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo.checkpointLocation, + stateInfo.queryRunId, + stateInfo.operatorId, + stateInfo.storeVersion, keySchema, valueSchema, indexOrdinal, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3e57f3fbada32..c5722466a33af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.TimeUnit._ import org.apache.spark.rdd.RDD @@ -36,20 +37,22 @@ import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ -case class OperatorStateId( +case class StatefulOperatorStateInfo( checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - batchId: Long) + storeVersion: Long) /** - * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should - * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + * An operator that reads or writes state from the [[StateStore]]. + * The [[StatefulOperatorStateInfo]] should be filled in by `prepareForExecution` in + * [[IncrementalExecution]]. 
*/ trait StatefulOperator extends SparkPlan { - def stateId: Option[OperatorStateId] + def stateInfo: Option[StatefulOperatorStateInfo] - protected def getStateId: OperatorStateId = attachTree(this) { - stateId.getOrElse { + protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { + stateInfo.getOrElse { throw new IllegalStateException("State location not present for execution") } } @@ -140,7 +143,7 @@ trait WatermarkSupport extends UnaryExecNode { */ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], child: SparkPlan) extends UnaryExecNode with StateStoreReader { @@ -148,10 +151,7 @@ case class StateStoreRestoreExec( val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeName = "default", - storeVersion = getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -177,7 +177,7 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) @@ -189,10 +189,7 @@ case class StateStoreSaveExec( "Incorrect planning in IncrementalExecution, outputMode has not been set") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -319,7 +316,7 @@ case class StateStoreSaveExec( case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], child: SparkPlan, - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, eventTimeWatermark: Option[Long] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -331,10 +328,7 @@ case class StreamingDeduplicateExec( metrics // force lazy init at driver child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 002c45413b4c2..48b0ea20e5da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -332,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } awaitTerminationLock.notifyAll() } + stateStoreCoordinator.deactivateInstances(terminatedQuery.runId) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a7e32626264cc..9a7595eee7bd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,11 +17,17 @@ package 
org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count +import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("report, verify, getLocation") { withCoordinatorRef(sc) { coordinatorRef => - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.getLocation(id) === None) @@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("make inactive") { withCoordinatorRef(sc) { coordinatorRef => - val id1 = StateStoreId("x", 0, 0) - val id2 = StateStoreId("y", 1, 0) - val id3 = StateStoreId("x", 0, 1) + val runId1 = UUID.randomUUID + val runId2 = UUID.randomUUID + val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1) + val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2) + val id3 = StateStoreProviderId(StateStoreId("x", 0, 1), runId1) val host = "hostX" val exec = "exec1" @@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) } - coordinatorRef.deactivateInstances("x") + coordinatorRef.deactivateInstances(runId1) assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) @@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, exec).toString)) assert(coordinatorRef.getLocation(id3) === None) - coordinatorRef.deactivateInstances("y") + coordinatorRef.deactivateInstances(runId2) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) assert(coordinatorRef.getLocation(id2) === None) } @@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorRef(sc) { coordRef1 => val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) coordRef1.reportActiveInstance(id, "hostX", "exec1") @@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } } + + test("query stop deactivates related store providers") { + var coordRef: StateStoreCoordinatorRef = null + try { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + import spark.implicits._ + coordRef = spark.streams.stateStoreCoordinator + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + + // Start a query and run a batch to load state stores + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + 
.queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Verify state store has been loaded + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId) + assert(coordRef.getLocation(providerId).nonEmpty) + + // Stop and verify whether the stores are deactivated in the coordinator + query.stop() + assert(coordRef.getLocation(providerId).isEmpty) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + if (coordRef != null) coordRef.stop() + StateStore.stop() + } + } } object StateStoreCoordinatorSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 4a1a089af54c2..defb9ed63a881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.nio.file.Files +import java.util.UUID import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.LocalSparkSession._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -57,16 +56,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val opId = 0 - val rdd1 = - makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) @@ -76,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( @@ -85,7 +81,8 @@ class 
StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = storeVersion), + keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -132,17 +129,17 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -150,22 +147,25 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("preferred locations using StateStoreCoordinator") { quietly { + val queryRunId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") + val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId) + val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId) + coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1") + coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2") - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === + require( + coordinatorRef.getLocation(storeProviderId1) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( - increment) + sqlContext, operatorStateInfo(path, queryRunId = queryRunId), + keySchema, valueSchema, None)(increment) require(rdd.partitions.length === 2) assert( @@ -192,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) + 
sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -210,6 +210,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } + private def operatorStateInfo( + path: String, + queryRunId: UUID = UUID.randomUUID, + version: Int = 0): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + } + private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index af2b9f1c11fb6..c2087ec219e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -33,8 +34,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -143,7 +147,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] provider.getStore(0).commit() // Verify we don't leak temp files - val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true).asScala.filter(_.getName.startsWith("temp-")) assert(tempFiles.isEmpty) } @@ -183,7 +187,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("StateStore.get") { quietly { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() @@ -243,18 +247,18 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] .set("spark.rpc.numRetries", "1") val opId = 0 val dir = newDir() - val storeId = StateStoreId(dir, opId, 0) + val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val 
hadoopConf = new Configuration() - val provider = newStoreProvider(storeId) + val provider = newStoreProvider(storeProviderId.storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get(storeId, keySchema, valueSchema, None, + val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() @@ -274,7 +278,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] eventually(timeout(timeoutDuration)) { // Store should have been reported to the coordinator - assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + assert(coordinatorRef.getLocation(storeProviderId).nonEmpty, + "active instance was not reported") // Background maintenance should clean up and generate snapshots assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") @@ -295,35 +300,35 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") } - // If driver decides to deactivate all instances of the store, then this instance - // should be unloaded - coordinatorRef.deactivateInstances(dir) + // If driver decides to deactivate all stores related to a query run, + // then this instance should be unloaded + coordinatorRef.deactivateInstances(storeProviderId.queryRunId) eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) // If some other executor loads the store, then this instance should be unloaded - coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec") eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) } } // Verify if instance is unloaded if SparkContext is stopped eventually(timeout(timeoutDuration)) { require(SparkEnv.get === null) - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) assert(!StateStore.isMaintenanceRunning) } } @@ -344,7 +349,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-18416: do not create temp delta file until the store is updated") { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() val deltaFileDir = new File(s"$dir/0/0/") @@ -408,12 +413,60 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(numDeltaFiles === 3) } + test("SPARK-21145: Restarted queries create new provider instances") { + try { + val checkpointLocation = 
Utils.createTempDir().getAbsoluteFile + val spark = SparkSession.builder().master("local[2]").getOrCreate() + SparkSession.setActiveSession(spark) + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + import spark.implicits._ + val inputData = MemoryStream[Int] + + def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) + // stateful query + val query = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + require(query.lastProgress != null) // at least one batch processed after start + val loadedProvidersMethod = + PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders) + val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod() + val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq } + query.stop() + loadedProviders + } + + val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders() + require(loadedProvidersAfterRun1.length === 1) + + val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders() + assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs + + // Both providers should have the same StateStoreId, but they should be different objects + assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId) + assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1)) + + } finally { + SparkSession.getActiveSession.foreach { spark => + spark.streams.active.foreach(_.stop()) + spark.stop() + } + } + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { - newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation) + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation) } override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { @@ -423,7 +476,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] override def getData( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = newStoreProvider(provider.id) + val reloadedProvider = newStoreProvider(provider.stateStoreId) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 4ede4fd9a035e..86c3a35a59c13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -777,7 +777,7 @@ class TestStateStoreProvider extends StateStoreProvider { throw new Exception("Successfully instantiated") } - override def id: StateStoreId = null + override def stateStoreId: StateStoreId = null override def close(): Unit = { } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 2a4039cc5831a..b2c42eef88f6d 100--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -26,9 +26,8 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import org.scalatest.Assertions +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, Timeouts} -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * A framework for implementing tests for streaming queries and sources. @@ -67,7 +67,12 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. */ -trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() // stop the state store maintenance thread and unload store providers + } /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds
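Taken together, the patch keys the executor-side loadedProviders map by StateStoreProviderId, i.e. the pair (StateStoreId, queryRunId), instead of by StateStoreId alone, and StreamingQueryManager now deactivates instances by runId when a query terminates. The effect checked by the SPARK-21145 test above is that a restarted run never reuses a provider object left over from the previous run, even though the StateStoreId (and hence the checkpoint directory) is identical. A standalone sketch of that keying behaviour, using simplified stand-in classes (IdSketch, ProviderIdSketch, ProviderSketch are made-up names, not the real Spark types):

import java.util.UUID
import scala.collection.mutable

// Simplified stand-ins mirroring the identifiers added in this patch.
case class IdSketch(checkpointRootLocation: String, operatorId: Long, partitionId: Int)
case class ProviderIdSketch(storeId: IdSketch, queryRunId: UUID)
class ProviderSketch  // placeholder for a loaded state store provider instance

object ProviderReuseDemo extends App {
  val loadedProviders = mutable.HashMap[ProviderIdSketch, ProviderSketch]()
  val storeId = IdSketch("/tmp/chk/state", operatorId = 0, partitionId = 0)

  // First run of the query: one provider gets loaded for the store.
  val run1 = ProviderIdSketch(storeId, UUID.randomUUID)
  loadedProviders.getOrElseUpdate(run1, new ProviderSketch)

  // Restarted run: same StateStoreId, new runId, so the lookup misses and a fresh
  // provider is created rather than reusing the instance from run 1.
  val run2 = ProviderIdSketch(storeId, UUID.randomUUID)
  loadedProviders.getOrElseUpdate(run2, new ProviderSketch)

  assert(run1.storeId == run2.storeId)   // same store identity across runs
  assert(loadedProviders.size == 2)      // but two distinct provider entries
  println(s"providers loaded: ${loadedProviders.size}")
}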