diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
index 0024ef1a5cae8..4972ec152b584 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
@@ -47,7 +47,8 @@ case class StateMetadataTableEntry(
     minBatchId: Long,
     maxBatchId: Long,
     operatorPropertiesJson: String,
-    numColsPrefixKey: Int) {
+    numColsPrefixKey: Int,
+    stateSchemaFilePath: Option[String]) {
   def toRow(): InternalRow = {
     new GenericInternalRow(
       Array[Any](operatorId,
@@ -215,6 +216,8 @@ class StateMetadataPartitionReader(
     }
   }
 
+  // From v2, we also need to populate the operatorProperties and stateSchemaFilePath fields
+  // for use with the state data source reader
   private[sql] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
     allOperatorStateMetadata.flatMap { operatorStateMetadata =>
       require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2)
@@ -228,7 +231,8 @@ class StateMetadataPartitionReader(
             if (batchIds.nonEmpty) batchIds.head else -1,
             if (batchIds.nonEmpty) batchIds.last else -1,
             null,
-            stateStoreMetadata.numColsPrefixKey
+            stateStoreMetadata.numColsPrefixKey,
+            None
           )
         }
       case v2: OperatorStateMetadataV2 =>
@@ -240,7 +244,8 @@ class StateMetadataPartitionReader(
             if (batchIds.nonEmpty) batchIds.head else -1,
             if (batchIds.nonEmpty) batchIds.last else -1,
             v2.operatorPropertiesJson,
-            -1 // numColsPrefixKey is not available in OperatorStateMetadataV2
+            -1, // numColsPrefixKey is not available in OperatorStateMetadataV2
+            Some(stateStoreMetadata.stateSchemaFilePath)
           )
         }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
index df3de5d9ceab6..e7b63272f38d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -104,12 +104,17 @@ object OperatorStateMetadataUtils extends Logging {
 
   private implicit val formats: Formats = Serialization.formats(NoTypeHints)
 
-  def readMetadata(inputStream: FSDataInputStream): Option[OperatorStateMetadata] = {
+  def readMetadata(
+      inputStream: FSDataInputStream,
+      expectedVersion: Int): Option[OperatorStateMetadata] = {
     val inputReader =
       new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
     try {
       val versionStr = inputReader.readLine()
       val version = MetadataVersionUtil.validateVersion(versionStr, 2)
+      if (version != expectedVersion) {
+        throw new IllegalArgumentException(s"Expected version $expectedVersion, but found $version")
+      }
       Some(deserialize(version, inputReader))
     } finally {
       inputStream.close()
@@ -214,7 +219,7 @@ object OperatorStateMetadataV2 {
     .classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass)
 
   def metadataDirPath(stateCheckpointPath: Path): Path =
-    new Path(new Path(new Path(stateCheckpointPath, "_metadata"), "metadata"), "v2")
+    new Path(new Path(stateCheckpointPath, "_metadata"), "v2")
 
   def metadataFilePath(stateCheckpointPath: Path, currentBatchId: Long): Path =
     new Path(metadataDirPath(stateCheckpointPath), currentBatchId.toString)
@@ -259,7 +264,7 @@ class OperatorStateMetadataV1Reader(
 
   def read(): Option[OperatorStateMetadata] = {
     val inputStream = fm.open(metadataFilePath)
-    OperatorStateMetadataUtils.readMetadata(inputStream)
+    OperatorStateMetadataUtils.readMetadata(inputStream, version)
   }
 }
 
@@ -310,6 +315,6 @@ class OperatorStateMetadataV2Reader(
     val metadataFilePath = OperatorStateMetadataV2.metadataFilePath(
       stateCheckpointPath, lastBatchId)
     val inputStream = fm.open(metadataFilePath)
-    OperatorStateMetadataUtils.readMetadata(inputStream)
+    OperatorStateMetadataUtils.readMetadata(inputStream, version)
   }
 }
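The net effect of the reader changes above is that version expectations are now enforced at read time rather than trusted implicitly. A minimal usage sketch, assuming a `statePath` and `hadoopConf` in scope as in the test suite below (the requested version of 2 is illustrative, not part of this patch):

    // Resolve a reader for the metadata version we expect, then read.
    // readMetadata() now throws IllegalArgumentException if the metadata
    // file on disk was written with a different version than requested.
    val reader = OperatorStateMetadataReader.createReader(statePath, hadoopConf, 2)
    val metadata: Option[OperatorStateMetadata] = reader.read()
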
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
index 65d32b474708a..b1f3a5233752a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceUnspecifiedRequiredOption, StateSourceOptions}
 import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MemoryStream}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{OutputMode, RunningCountStatefulProcessor, StreamTest, TimeMode}
 import org.apache.spark.sql.streaming.OutputMode.{Complete, Update}
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -38,13 +39,30 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
   private def checkOperatorStateMetadata(
       checkpointDir: String,
       operatorId: Int,
-      expectedMetadata: OperatorStateMetadataV1): Unit = {
+      expectedMetadata: OperatorStateMetadata,
+      expectedVersion: Int = 1): Unit = {
     val statePath = new Path(checkpointDir, s"state/$operatorId")
-    val operatorMetadata = new OperatorStateMetadataV1Reader(statePath, hadoopConf).read()
-      .asInstanceOf[Option[OperatorStateMetadataV1]]
+    val operatorMetadata = OperatorStateMetadataReader.createReader(statePath,
+      hadoopConf, expectedVersion).read()
     assert(operatorMetadata.isDefined)
-    assert(operatorMetadata.get.operatorInfo == expectedMetadata.operatorInfo &&
-      operatorMetadata.get.stateStoreInfo.sameElements(expectedMetadata.stateStoreInfo))
+    assert(operatorMetadata.get.version == expectedVersion)
+
+    if (expectedVersion == 1) {
+      val operatorMetadataV1 = operatorMetadata.get.asInstanceOf[OperatorStateMetadataV1]
+      val expectedMetadataV1 = expectedMetadata.asInstanceOf[OperatorStateMetadataV1]
+      assert(operatorMetadataV1.operatorInfo == expectedMetadataV1.operatorInfo &&
+        operatorMetadataV1.stateStoreInfo.sameElements(expectedMetadataV1.stateStoreInfo))
+    } else {
+      val operatorMetadataV2 = operatorMetadata.get.asInstanceOf[OperatorStateMetadataV2]
+      val expectedMetadataV2 = expectedMetadata.asInstanceOf[OperatorStateMetadataV2]
+      assert(operatorMetadataV2.operatorInfo == expectedMetadataV2.operatorInfo)
+      assert(operatorMetadataV2.operatorPropertiesJson.nonEmpty)
+      val stateStoreInfo = operatorMetadataV2.stateStoreInfo.head
+      val expectedStateStoreInfo = expectedMetadataV2.stateStoreInfo.head
+      assert(stateStoreInfo.stateSchemaFilePath.nonEmpty)
+      assert(stateStoreInfo.storeName == expectedStateStoreInfo.storeName)
+      assert(stateStoreInfo.numPartitions == expectedStateStoreInfo.numPartitions)
+    }
   }
 
   test("Serialize and deserialize stateful operator metadata") {
@@ -89,6 +107,35 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
     }
   }
 
+  test("Stateful operator metadata for streaming transformWithState") {
+    withTempDir { checkpointDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key -> numShufflePartitions.toString) {
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.toString),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+      }
+
+      // Assign some placeholder values to the state store metadata since they are generated
+      // dynamically by the operator.
+      val expectedMetadata = OperatorStateMetadataV2(OperatorInfoV1(0, "transformWithStateExec"),
+        Array(StateStoreMetadataV2("default", 0, numShufflePartitions, checkpointDir.toString)),
+        "")
+      checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata, 2)
+    }
+  }
+
   test("Stateful operator metadata for streaming join") {
     withTempDir { checkpointDir =>
       val input1 = MemoryStream[Int]
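End to end, the new `stateSchemaFilePath` field flows from `OperatorStateMetadataV2` into `StateMetadataTableEntry`, where the state data source reader can pick it up. A minimal sketch of inspecting the resulting metadata via the state metadata source, which is registered under the short name `state-metadata` (the checkpoint path is illustrative, not from this patch):

    // Read operator metadata for all operators in a checkpoint. For v2
    // operators such as transformWithState, entries also carry the operator
    // properties JSON and, internally, the state schema file path; for v1
    // operators stateSchemaFilePath is populated as None.
    val metadataDf = spark.read
      .format("state-metadata")
      .load("/tmp/streaming-checkpoint")
    metadataDf.show(truncate = false)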