@@ -47,7 +47,8 @@ case class StateMetadataTableEntry(
     minBatchId: Long,
     maxBatchId: Long,
     operatorPropertiesJson: String,
-    numColsPrefixKey: Int) {
+    numColsPrefixKey: Int,
+    stateSchemaFilePath: Option[String]) {
   def toRow(): InternalRow = {
     new GenericInternalRow(
       Array[Any](operatorId,
@@ -215,6 +216,8 @@ class StateMetadataPartitionReader(
     }
   }
 
+  // From v2, we also need to populate the operatorProperties and stateSchemaFilePath fields
+  // for use with the state data source reader.
   private[sql] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
     allOperatorStateMetadata.flatMap { operatorStateMetadata =>
       require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2)
@@ -228,7 +231,8 @@
             if (batchIds.nonEmpty) batchIds.head else -1,
             if (batchIds.nonEmpty) batchIds.last else -1,
             null,
-            stateStoreMetadata.numColsPrefixKey
+            stateStoreMetadata.numColsPrefixKey,
+            None
           )
         }
       case v2: OperatorStateMetadataV2 =>
@@ -240,7 +244,8 @@
             if (batchIds.nonEmpty) batchIds.head else -1,
             if (batchIds.nonEmpty) batchIds.last else -1,
             v2.operatorPropertiesJson,
-            -1 // numColsPrefixKey is not available in OperatorStateMetadataV2
+            -1, // numColsPrefixKey is not available in OperatorStateMetadataV2
+            Some(stateStoreMetadata.stateSchemaFilePath)
           )
         }
     }
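Reviewer note: a minimal sketch of how the newly populated fields could be inspected end to end, assuming the "state-metadata" data source that this reader backs and a placeholder checkpointDir; whether stateSchemaFilePath is exposed as a visible column or kept internal to the state data source reader is an assumption here:

    // Read the operator metadata table back from a checkpoint (sketch).
    val metadataDf = spark.read
      .format("state-metadata")
      .load(checkpointDir)  // checkpointDir: placeholder checkpoint root
    // v1 operators carry None for the schema file path, v2 operators Some(path).
    metadataDf.show()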

@@ -104,12 +104,17 @@ object OperatorStateMetadataUtils extends Logging {
 
   private implicit val formats: Formats = Serialization.formats(NoTypeHints)
 
-  def readMetadata(inputStream: FSDataInputStream): Option[OperatorStateMetadata] = {
+  def readMetadata(
+      inputStream: FSDataInputStream,
+      expectedVersion: Int): Option[OperatorStateMetadata] = {
     val inputReader =
       new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
     try {
       val versionStr = inputReader.readLine()
       val version = MetadataVersionUtil.validateVersion(versionStr, 2)
+      if (version != expectedVersion) {
+        throw new IllegalArgumentException(s"Expected version $expectedVersion, but found $version")
+      }
       Some(deserialize(version, inputReader))
     } finally {
       inputStream.close()
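A short usage sketch of the stricter contract (fm and metadataFilePath as in the reader classes below):

    // Callers now declare which metadata version they expect; a mismatch fails
    // fast with IllegalArgumentException instead of silently deserializing the
    // wrong format.
    val inputStream = fm.open(metadataFilePath)
    OperatorStateMetadataUtils.readMetadata(inputStream, expectedVersion = 1)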
@@ -214,7 +219,7 @@ object OperatorStateMetadataV2 {
     .classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass)
 
   def metadataDirPath(stateCheckpointPath: Path): Path =
-    new Path(new Path(new Path(stateCheckpointPath, "_metadata"), "metadata"), "v2")
+    new Path(new Path(stateCheckpointPath, "_metadata"), "v2")
 
   def metadataFilePath(stateCheckpointPath: Path, currentBatchId: Long): Path =
     new Path(metadataDirPath(stateCheckpointPath), currentBatchId.toString)
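The layout change is easiest to see on a concrete path; a sketch with a placeholder state directory:

    // Before this change: /ckpt/state/0/_metadata/metadata/v2/<batchId>
    val stateDir = new Path("/ckpt/state/0")  // placeholder path
    OperatorStateMetadataV2.metadataDirPath(stateDir)      // /ckpt/state/0/_metadata/v2
    OperatorStateMetadataV2.metadataFilePath(stateDir, 5L) // /ckpt/state/0/_metadata/v2/5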
@@ -259,7 +264,7 @@ class OperatorStateMetadataV1Reader(
 
   def read(): Option[OperatorStateMetadata] = {
     val inputStream = fm.open(metadataFilePath)
-    OperatorStateMetadataUtils.readMetadata(inputStream)
+    OperatorStateMetadataUtils.readMetadata(inputStream, version)
   }
 }
 
@@ -310,6 +315,6 @@ class OperatorStateMetadataV2Reader(
     val metadataFilePath = OperatorStateMetadataV2.metadataFilePath(
       stateCheckpointPath, lastBatchId)
     val inputStream = fm.open(metadataFilePath)
-    OperatorStateMetadataUtils.readMetadata(inputStream)
+    OperatorStateMetadataUtils.readMetadata(inputStream, version)
   }
 }

@@ -24,7 +24,8 @@ import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceUnspecifiedRequiredOption, StateSourceOptions}
 import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MemoryStream}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{OutputMode, RunningCountStatefulProcessor, StreamTest, TimeMode}
 import org.apache.spark.sql.streaming.OutputMode.{Complete, Update}
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -38,13 +39,30 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
   private def checkOperatorStateMetadata(
       checkpointDir: String,
       operatorId: Int,
-      expectedMetadata: OperatorStateMetadataV1): Unit = {
+      expectedMetadata: OperatorStateMetadata,
+      expectedVersion: Int = 1): Unit = {
     val statePath = new Path(checkpointDir, s"state/$operatorId")
-    val operatorMetadata = new OperatorStateMetadataV1Reader(statePath, hadoopConf).read()
-      .asInstanceOf[Option[OperatorStateMetadataV1]]
+    val operatorMetadata = OperatorStateMetadataReader.createReader(statePath,
+      hadoopConf, expectedVersion).read()
     assert(operatorMetadata.isDefined)
-    assert(operatorMetadata.get.operatorInfo == expectedMetadata.operatorInfo &&
-      operatorMetadata.get.stateStoreInfo.sameElements(expectedMetadata.stateStoreInfo))
+    assert(operatorMetadata.get.version == expectedVersion)
+
+    if (expectedVersion == 1) {
+      val operatorMetadataV1 = operatorMetadata.get.asInstanceOf[OperatorStateMetadataV1]
+      val expectedMetadataV1 = expectedMetadata.asInstanceOf[OperatorStateMetadataV1]
+      assert(operatorMetadataV1.operatorInfo == expectedMetadata.operatorInfo &&
+        operatorMetadataV1.stateStoreInfo.sameElements(expectedMetadataV1.stateStoreInfo))
+    } else {
+      val operatorMetadataV2 = operatorMetadata.get.asInstanceOf[OperatorStateMetadataV2]
+      val expectedMetadataV2 = expectedMetadata.asInstanceOf[OperatorStateMetadataV2]
+      assert(operatorMetadataV2.operatorInfo == expectedMetadataV2.operatorInfo)
+      assert(operatorMetadataV2.operatorPropertiesJson.nonEmpty)
+      val stateStoreInfo = operatorMetadataV2.stateStoreInfo.head
+      val expectedStateStoreInfo = expectedMetadataV2.stateStoreInfo.head
+      assert(stateStoreInfo.stateSchemaFilePath.nonEmpty)
+      assert(stateStoreInfo.storeName == expectedStateStoreInfo.storeName)
+      assert(stateStoreInfo.numPartitions == expectedStateStoreInfo.numPartitions)
+    }
   }
 
   test("Serialize and deserialize stateful operator metadata") {
@@ -89,6 +107,35 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
     }
   }
 
+  test("Stateful operator metadata for streaming transformWithState") {
+    withTempDir { checkpointDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key -> numShufflePartitions.toString) {
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.toString),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+      }
+
+      // Assign some placeholder values to the state store metadata since they are generated
+      // dynamically by the operator.
+      val expectedMetadata = OperatorStateMetadataV2(OperatorInfoV1(0, "transformWithStateExec"),
+        Array(StateStoreMetadataV2("default", 0, numShufflePartitions, checkpointDir.toString)),
+        "")
+      checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata, 2)
+    }
+  }
+
   test("Stateful operator metadata for streaming join") {
     withTempDir { checkpointDir =>
       val input1 = MemoryStream[Int]