[SPARK-21145][SS] Added StateStoreProviderId with queryRunId to reload StateStoreProviders when query is restarted

## What changes were proposed in this pull request?
StateStoreProvider instances are loaded on demand in an executor when a query is started. When a query is restarted, the loaded provider instance is reused. There is a non-trivial chance that a task from the previous query run is still running while tasks of the restarted run have started, so for a stateful partition there may be two concurrent tasks operating on the same partition and therefore using the same provider instance. This can lead to inconsistent results and possibly random failures, as state store implementations are not designed to be thread-safe.

To fix this, I have introduced a `StateStoreProviderId` that uniquely identifies a provider loaded in an executor. It includes the query run id, ensuring that a restarted query forces the executor to load a new provider instance and preventing two concurrent tasks (from two different runs) from reusing the same provider instance.
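
In essence, executors now cache providers under a composite key that includes the run id. The following is a minimal sketch of that idea using hypothetical simplified types (`StoreId`, `ProviderId`, and `Provider` are illustrative stand-ins, not the actual Spark classes):

```scala
import java.util.UUID
import scala.collection.mutable

// Hypothetical simplified stand-ins for StateStoreId / StateStoreProviderId / StateStoreProvider.
case class StoreId(checkpointRoot: String, operatorId: Long, partitionId: Int)
case class ProviderId(storeId: StoreId, queryRunId: UUID)
class Provider // stand-in for a loaded state store provider

object ProviderCache {
  private val loaded = mutable.HashMap.empty[ProviderId, Provider]

  def getOrLoad(id: ProviderId): Provider = loaded.synchronized {
    // Same StoreId but a different queryRunId is a different key, so a restarted
    // run (which carries a new run id) loads a fresh provider instead of sharing
    // the instance that the previous run may still be using.
    loaded.getOrElseUpdate(id, new Provider)
  }
}
```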

Additional minor bug fixes:
- All state stores related to a query run are marked as deactivated in the `StateStoreCoordinator` so that the executors can unload them and free resources.
- Moved the code that determines the checkpoint directory of a state store from implementation-specific code (`HDFSBackedStateStoreProvider`) to implementation-agnostic code (`StateStoreId`), so that implementations do not accidentally get it wrong.
  - Also added the store name to the path, to support multiple stores per SQL operator partition (see the path sketch after this list).
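
For reference, the checkpoint-path layout established by the new `StateStoreId.storeCheckpointLocation()` is essentially the following (a standalone sketch of the logic added below, not the method itself):

```scala
import org.apache.hadoop.fs.Path

// Sketch of the layout: <checkpointRoot>/<operatorId>/<partitionId>[/<storeName>].
// The default store keeps the old two-level layout so data written before store
// names existed (Spark <= 2.2) stays readable.
def storeCheckpointLocation(
    checkpointRootLocation: String,
    operatorId: Long,
    partitionId: Int,
    storeName: String = "default"): Path = {
  if (storeName == "default") {
    new Path(checkpointRootLocation, s"$operatorId/$partitionId")
  } else {
    new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName")
  }
}
```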

*Note:* This change does not address the scenario where two tasks of the same run (e.g. speculative tasks) run concurrently in the same executor. The chance of this is very small, because ideally speculative tasks should never run in the same executor.

## How was this patch tested?
Existing unit tests + new unit test.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #18355 from tdas/SPARK-21145.
tdas committed Jun 23, 2017
1 parent b8a743b commit fe24634
Showing 17 changed files with 329 additions and 166 deletions.
@@ -311,7 +311,7 @@ object AggUtils {
val saved =
StateStoreSaveExec(
groupingAttributes,
stateId = None,
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
partialMerged2)
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.command

import java.util.UUID

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -117,7 +119,8 @@ case class ExplainCommand(
// This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
// output mode does not matter since there is no `Sink`.
new IncrementalExecution(
sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, OffsetSeqMetadata(0, 0))
sparkSession, logicalPlan, OutputMode.Append(), "<unknown>",
UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
} else {
sparkSession.sessionState.executePlan(logicalPlan)
}
@@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec(
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
stateId: Option[OperatorStateId],
stateInfo: Option[StatefulOperatorStateInfo],
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
@@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec(
}

child.execute().mapPartitionsWithStateStore[InternalRow](
getStateId.checkpointLocation,
getStateId.operatorId,
storeName = "default",
getStateId.batchId,
getStateInfo,
groupingAttributes.toStructType,
stateAttributes.toStructType,
indexOrdinal = None,
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.streaming

import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.Logging
@@ -36,6 +37,7 @@ class IncrementalExecution(
logicalPlan: LogicalPlan,
val outputMode: OutputMode,
val checkpointLocation: String,
val runId: UUID,
val currentBatchId: Long,
offsetSeqMetadata: OffsetSeqMetadata)
extends QueryExecution(sparkSession, logicalPlan) with Logging {
@@ -69,7 +71,13 @@ class IncrementalExecution(
* Records the current id for a given stateful operator in the query plan as the `state`
* preparation walks the query plan.
*/
private val operatorId = new AtomicInteger(0)
private val statefulOperatorId = new AtomicInteger(0)

/** Get the state info of the next stateful operator */
private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
StatefulOperatorStateInfo(
checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId)
}

/** Locates save/restore pairs surrounding aggregation. */
val state = new Rule[SparkPlan] {
@@ -78,35 +86,28 @@
case StateStoreSaveExec(keys, None, None, None,
UnaryExecNode(agg,
StateStoreRestoreExec(keys2, None, child))) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)

val aggStateInfo = nextStatefulOperationStateInfo
StateStoreSaveExec(
keys,
Some(stateId),
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
agg.withNewChildren(
StateStoreRestoreExec(
keys,
Some(stateId),
Some(aggStateInfo),
child) :: Nil))

case StreamingDeduplicateExec(keys, child, None, None) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)

StreamingDeduplicateExec(
keys,
child,
Some(stateId),
Some(nextStatefulOperationStateInfo),
Some(offsetSeqMetadata.batchWatermarkMs))

case m: FlatMapGroupsWithStateExec =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
m.copy(
stateId = Some(stateId),
stateInfo = Some(nextStatefulOperationStateInfo),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs))
}
@@ -652,6 +652,7 @@ class StreamExecution(
triggerLogicalPlan,
outputMode,
checkpointFile("state"),
runId,
currentBatchId,
offsetSeqMetadata)
lastExecution.executedPlan // Force the lazy generation of execution plan
@@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
@volatile private var state: STATE = UPDATING
@volatile private var finalDeltaFile: Path = null

override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId

override def get(key: UnsafeRow): UnsafeRow = {
mapToUpdate.get(key)
@@ -177,7 +177,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
/**
* Whether all updates have been committed
*/
override private[streaming] def hasCommitted: Boolean = {
override def hasCommitted: Boolean = {
state == COMMITTED
}

@@ -205,15 +205,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
indexOrdinal: Option[Int], // for sorting the data
storeConf: StateStoreConf,
hadoopConf: Configuration): Unit = {
this.stateStoreId = stateStoreId
this.stateStoreId_ = stateStoreId
this.keySchema = keySchema
this.valueSchema = valueSchema
this.storeConf = storeConf
this.hadoopConf = hadoopConf
fs.mkdirs(baseDir)
}

override def id: StateStoreId = stateStoreId
override def stateStoreId: StateStoreId = stateStoreId_

/** Do maintenance backing data files, including creating snapshots and cleaning up old files */
override def doMaintenance(): Unit = {
@@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
}

override def toString(): String = {
s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]"
s"HDFSStateStoreProvider[" +
s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]"
}

/* Internal fields and methods */

@volatile private var stateStoreId: StateStoreId = _
@volatile private var stateStoreId_ : StateStoreId = _
@volatile private var keySchema: StructType = _
@volatile private var valueSchema: StructType = _
@volatile private var storeConf: StateStoreConf = _
@volatile private var hadoopConf: Configuration = _

private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
private lazy val baseDir =
new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
private lazy val baseDir = stateStoreId.storeCheckpointLocation()
private lazy val fs = baseDir.getFileSystem(hadoopConf)
private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)

@@ -17,21 +17,22 @@

package org.apache.spark.sql.execution.streaming.state

import java.util.UUID
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkEnv
import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ThreadUtils, Utils}


/**
* Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific
* version of state data, and such instances are created through a [[StateStoreProvider]].
@@ -99,7 +100,7 @@ trait StateStore {
/**
* Whether all updates have been committed
*/
private[streaming] def hasCommitted: Boolean
def hasCommitted: Boolean
}


@@ -147,7 +148,7 @@ trait StateStoreProvider {
* Return the id of the StateStores this provider will generate.
* Should be the same as the one passed in init().
*/
def id: StateStoreId
def stateStoreId: StateStoreId

/** Called when the provider instance is unloaded from the executor */
def close(): Unit
@@ -179,13 +180,46 @@ object StateStoreProvider {
}
}

/**
* Unique identifier for a provider, used to identify when providers can be reused.
* Note that `queryRunId` is used to uniquely identify a provider, so that the same provider
* instance is not reused across query restarts.
*/
case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID)

/** Unique identifier for a bunch of keyed state data. */
/**
* Unique identifier for a bunch of keyed state data.
* @param checkpointRootLocation Root directory where all the state data of a query is stored
* @param operatorId Unique id of a stateful operator
* @param partitionId Index of the partition of an operator's state data
* @param storeName Optional, name of the store. Each partition can optionally use multiple state
* stores, but they have to be identified by distinct names.
*/
case class StateStoreId(
checkpointLocation: String,
checkpointRootLocation: String,
operatorId: Long,
partitionId: Int,
name: String = "")
storeName: String = StateStoreId.DEFAULT_STORE_NAME) {

/**
* Checkpoint directory to be used by a single state store, identified uniquely by the tuple
* (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should
* use this path for saving state data, as this ensures that distinct stores will write to
* different locations.
*/
def storeCheckpointLocation(): Path = {
if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
// For reading state store data that was generated before store names were used (Spark <= 2.2)
new Path(checkpointRootLocation, s"$operatorId/$partitionId")
} else {
new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName")
}
}
}

object StateStoreId {
val DEFAULT_STORE_NAME = "default"
}

/** Mutable, and reusable class for representing a pair of UnsafeRows. */
class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
@@ -211,7 +245,7 @@ object StateStore extends Logging {
val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60

@GuardedBy("loadedProviders")
private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]()

/**
* Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
@@ -253,7 +287,7 @@ object StateStore extends Logging {

/** Get or create a store associated with the id. */
def get(
storeId: StateStoreId,
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
indexOrdinal: Option[Int],
@@ -264,24 +298,24 @@ object StateStore extends Logging {
val storeProvider = loadedProviders.synchronized {
startMaintenanceIfNeeded()
val provider = loadedProviders.getOrElseUpdate(
storeId,
storeProviderId,
StateStoreProvider.instantiate(
storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
)
reportActiveStoreInstance(storeId)
reportActiveStoreInstance(storeProviderId)
provider
}
storeProvider.getStore(version)
}

/** Unload a state store provider */
def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
loadedProviders.remove(storeId).foreach(_.close())
def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized {
loadedProviders.remove(storeProviderId).foreach(_.close())
}

/** Whether a state store provider is loaded or not */
def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
loadedProviders.contains(storeId)
def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized {
loadedProviders.contains(storeProviderId)
}

def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
@@ -340,21 +374,21 @@
}
}

private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
private def reportActiveStoreInstance(storeProviderId: StateStoreProviderId): Unit = {
if (SparkEnv.get != null) {
val host = SparkEnv.get.blockManager.blockManagerId.host
val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
logDebug(s"Reported that the loaded instance $storeId is active")
coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, executorId))
logInfo(s"Reported that the loaded instance $storeProviderId is active")
}
}

private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = {
if (SparkEnv.get != null) {
val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
val verified =
coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
logDebug(s"Verified whether the loaded instance $storeId is active: $verified")
coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false)
logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified")
verified
} else {
false
@@ -364,12 +398,21 @@
private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
val env = SparkEnv.get
if (env != null) {
if (_coordRef == null) {
logInfo("Env is not null")
val isDriver =
env.executorId == SparkContext.DRIVER_IDENTIFIER ||
env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
// If running locally, then the coordinator reference in _coordRef may have become inactive
// as SparkContext + SparkEnv may have been restarted. Hence, when running in the driver,
// always recreate the reference.
if (isDriver || _coordRef == null) {
logInfo("Getting StateStoreCoordinatorRef")
_coordRef = StateStoreCoordinatorRef.forExecutor(env)
}
logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
Some(_coordRef)
} else {
logInfo("Env is null")
_coordRef = null
None
}
