[SPARK-45511][SS] State Data Source - Reader
### What changes were proposed in this pull request?

This PR proposes to introduce a baseline implementation of the state data source - reader.

The state data source is a new data source that enables reading and writing the state in an existing checkpoint with a batch query. Since the feature is implemented as a data source, it leverages the DataFrame API UX that most users are already familiar with.

The functionalities of the baseline implementation are as follows (a usage sketch appears after the list):

* Specify a state store instance via store name (default: DEFAULT)
* Specify a stateful operator via operator ID (default: 0)
* Specify a batch ID (default: last committed)
* Specify the source option `joinSide` to construct input rows from the state store for a stream-stream join
  * Users can still read one specific state store instance out of the four instances used in a stream-stream join, which would mostly be used for debugging Spark itself
  * When this option is set, the data source hides the internal column from the output.
* Provide a metadata column (`_partition_id`) so that users can identify the partition ID of each state row.
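
As a usage sketch (the checkpoint path, batch ID, and the commented-out option values below are illustrative, not taken from this PR), reading state with this data source could look like:

```scala
// Minimal sketch: assumes an active SparkSession `spark` and a completed streaming query
// whose checkpoint lives at /path/to/checkpoint.
val stateDf = spark.read
  .format("statestore")                // short name registered by this PR
  .option("operatorId", 0)             // default: 0
  .option("batchId", 10)               // default: last committed batch
  // .option("storeName", "default")   // default store name
  // .option("joinSide", "left")       // only for stream-stream join state
  .load("/path/to/checkpoint")

// Each output row carries the state as key/value structs, plus the _partition_id metadata column.
stateDf.select("key", "value", "_partition_id").show()
```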

### Why are the changes needed?

Please refer to the SPIP doc for rationale: https://docs.google.com/document/d/1_iVf_CIu2RZd3yWWF6KoRNlBiz5NbSIK0yThqG0EvPY/edit?usp=sharing

### Does this PR introduce _any_ user-facing change?

Yes, this PR adds a new data source (registered under the short name `statestore`).

### How was this patch tested?

New test suite.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43425 from HeartSaVioR/SPARK-45511.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
HeartSaVioR committed Nov 15, 2023
1 parent f8ccf20 commit 74a9c6c
Showing 15 changed files with 2,121 additions and 5 deletions.
@@ -248,6 +248,8 @@ object CheckConnectJvmClientCompatibility {

// RuntimeConfig
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig$"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.RuntimeConfig.sqlConf"),

// DataStreamWriter
ProblemFilters.exclude[MissingClassProblem](
@@ -28,4 +28,5 @@ org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat
org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider
org.apache.spark.sql.execution.datasources.v2.state.StateMetadataSource
org.apache.spark.sql.execution.datasources.v2.state.StateMetadataSource
org.apache.spark.sql.execution.datasources.v2.state.StateDataSource
@@ -31,7 +31,7 @@ import org.apache.spark.sql.internal.SQLConf
* @since 2.0.0
*/
@Stable
class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) {
class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) {

/**
* Sets the given Spark runtime configuration property.
@@ -0,0 +1,212 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import java.util
import java.util.UUID

import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{RuntimeConfig, SparkSession}
import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source.
*/
class StateDataSource extends TableProvider with DataSourceRegister {
private lazy val session: SparkSession = SparkSession.active

private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf()

override def shortName(): String = "statestore"

override def getTable(
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String]): Table = {
val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties)
val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
new StateTable(session, schema, sourceOptions, stateConf)
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
val sourceOptions = StateSourceOptions.apply(session, hadoopConf, options)

val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
try {
val (keySchema, valueSchema) = sourceOptions.joinSide match {
case JoinSideValues.left =>
StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
sourceOptions.operatorId, LeftSide)

case JoinSideValues.right =>
StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
sourceOptions.operatorId, RightSide)

case JoinSideValues.none =>
val storeId = new StateStoreId(stateCheckpointLocation.toString, sourceOptions.operatorId,
partitionId, sourceOptions.storeName)
val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf)
manager.readSchemaFile()
}

new StructType()
.add("key", keySchema)
.add("value", valueSchema)
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException("Failed to read the state schema. Either the file " +
s"does not exist, or the file is corrupted. options: $sourceOptions", e)
}
}

private def buildStateStoreConf(checkpointLocation: String, batchId: Long): StateStoreConf = {
val offsetLog = new OffsetSeqLog(session,
new Path(checkpointLocation, DIR_NAME_OFFSETS).toString)
offsetLog.get(batchId) match {
case Some(value) =>
val metadata = value.metadata.getOrElse(
throw new IllegalStateException(s"Metadata is not available for offset log for " +
s"$batchId, checkpoint location $checkpointLocation")
)

val clonedRuntimeConf = new RuntimeConfig(session.sessionState.conf.clone())
OffsetSeqMetadata.setSessionConf(metadata, clonedRuntimeConf)
StateStoreConf(clonedRuntimeConf.sqlConf)

case _ =>
throw new IllegalStateException(s"The offset log for $batchId does not exist, " +
s"checkpoint location $checkpointLocation")
}
}

override def supportsExternalMetadata(): Boolean = false
}

case class StateSourceOptions(
resolvedCpLocation: String,
batchId: Long,
operatorId: Int,
storeName: String,
joinSide: JoinSideValues) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
}

object StateSourceOptions extends DataSourceOptions {
val PATH = newOption("path")
val BATCH_ID = newOption("batchId")
val OPERATOR_ID = newOption("operatorId")
val STORE_NAME = newOption("storeName")
val JOIN_SIDE = newOption("joinSide")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
val left, right, none = Value
}

def apply(
sparkSession: SparkSession,
hadoopConf: Configuration,
properties: util.Map[String, String]): StateSourceOptions = {
apply(sparkSession, hadoopConf, new CaseInsensitiveStringMap(properties))
}

def apply(
sparkSession: SparkSession,
hadoopConf: Configuration,
options: CaseInsensitiveStringMap): StateSourceOptions = {
val checkpointLocation = Option(options.get(PATH)).orElse {
throw new IllegalArgumentException(s"'$PATH' must be specified.")
}.get

val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation)

val batchId = Option(options.get(BATCH_ID)).map(_.toLong).orElse {
Some(getLastCommittedBatch(sparkSession, resolvedCpLocation))
}.get

if (batchId < 0) {
throw new IllegalArgumentException(s"'$BATCH_ID' cannot be negative.")
}

val operatorId = Option(options.get(OPERATOR_ID)).map(_.toInt)
.orElse(Some(0)).get

if (operatorId < 0) {
throw new IllegalArgumentException(s"'$OPERATOR_ID' cannot be negative.")
}

val storeName = Option(options.get(STORE_NAME))
.map(_.trim)
.getOrElse(StateStoreId.DEFAULT_STORE_NAME)

if (storeName.isEmpty) {
throw new IllegalArgumentException(s"'$STORE_NAME' cannot be an empty string.")
}

val joinSide = try {
Option(options.get(JOIN_SIDE))
.map(JoinSideValues.withName).getOrElse(JoinSideValues.none)
} catch {
case _: NoSuchElementException =>
// convert to IllegalArgumentException
throw new IllegalArgumentException(s"Incorrect value of the option " +
s"'$JOIN_SIDE'. Valid values are ${JoinSideValues.values.mkString(",")}")
}

if (joinSide != JoinSideValues.none && storeName != StateStoreId.DEFAULT_STORE_NAME) {
throw new IllegalArgumentException(s"The options '$JOIN_SIDE' and " +
s"'$STORE_NAME' cannot be specified together. Please specify either one.")
}

StateSourceOptions(resolvedCpLocation, batchId, operatorId, storeName, joinSide)
}

private def resolvedCheckpointLocation(
hadoopConf: Configuration,
checkpointLocation: String): String = {
val checkpointPath = new Path(checkpointLocation)
val fs = checkpointPath.getFileSystem(hadoopConf)
checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString
}

private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = {
val commitLog = new CommitLog(session,
new Path(checkpointLocation, DIR_NAME_COMMITS).toString)
commitLog.getLatest() match {
case Some((lastId, _)) => lastId
case None => throw new IllegalStateException("No committed batch found, " +
s"checkpoint location: $checkpointLocation")
}
}
}
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2.state

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

/**
* An implementation of [[PartitionReaderFactory]] for State data source. This is used to support
* general read from a state store instance, rather than specific to the operator.
*/
class StatePartitionReaderFactory(
storeConf: StateStoreConf,
hadoopConf: SerializableConfiguration,
schema: StructType) extends PartitionReaderFactory {

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
new StatePartitionReader(storeConf, hadoopConf,
partition.asInstanceOf[StateStoreInputPartition], schema)
}
}

/**
* An implementation of [[PartitionReader]] for State data source. This is used to support
* general read from a state store instance, rather than specific to the operator.
*/
class StatePartitionReader(
storeConf: StateStoreConf,
hadoopConf: SerializableConfiguration,
partition: StateStoreInputPartition,
schema: StructType) extends PartitionReader[InternalRow] {

private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]

private lazy val store: ReadStateStore = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)

// TODO: This does not handle the case of session window aggregation; we don't have
// information on whether the state store uses prefix scan or not. We will have to add such
// information to determine the right encoder/decoder for the data.
StateStore.getReadOnly(stateStoreProviderId, keySchema, valueSchema,
numColsPrefixKey = 0, version = partition.sourceOptions.batchId + 1, storeConf = storeConf,
hadoopConf = hadoopConf.value)
}

private lazy val iter: Iterator[InternalRow] = {
store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value)))
}

private var current: InternalRow = _

override def next(): Boolean = {
if (iter.hasNext) {
current = iter.next()
true
} else {
current = null
false
}
}

private val joinedRow = new JoinedRow()

private def addMetadata(row: InternalRow): InternalRow = {
val metadataRow = new GenericInternalRow(
StateTable.METADATA_COLUMNS.map(_.name()).map {
case "_partition_id" => partition.partition.asInstanceOf[Any]
}.toArray
)
joinedRow.withLeft(row).withRight(metadataRow)
}

override def get(): InternalRow = addMetadata(current)

override def close(): Unit = {
current = null
store.abort()
}

private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
val row = new GenericInternalRow(2)
row.update(0, pair._1)
row.update(1, pair._2)
row
}
}
