Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-46979][SS] Add support for specifying key and value encoder separately and also for each col family in RocksDB state store provider #45038

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class StatefulProcessorHandleImpl(
override def getValueState[T](stateName: String): ValueState[T] = {
verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
"initialization is complete")
store.createColFamilyIfAbsent(stateName)
val resultState = new ValueStateImpl[T](store, stateName, keyEncoder)
resultState
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,20 @@ import org.apache.spark.sql.types._
* @param store - reference to the StateStore instance to be used for storing state
* @param stateName - name of logical state partition
* @param keyExprEnc - Spark SQL encoder for key
* @tparam K - data type of key
* @tparam S - data type of object that will be stored
*/
class ValueStateImpl[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {

private val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)

private val schemaForValueRow: StructType = new StructType().add("value", BinaryType)

store.createColFamilyIfAbsent(stateName, schemaForKeyRow, numColsPrefixKey = 0,
schemaForValueRow)

// TODO: validate places that are trying to encode the key and check if we can eliminate/
// add caching for some of these calls.
private def encodeKey(): UnsafeRow = {
Expand All @@ -52,14 +58,12 @@ class ValueStateImpl[S](
val keyByteArr = toRow
.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()

val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
val keyRow = keyEncoder(InternalRow(keyByteArr))
keyRow
}

private def encodeValue(value: S): UnsafeRow = {
val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable])
val valueEncoder = UnsafeProjection.create(schemaForValueRow)
val valueRow = valueEncoder(InternalRow(valueByteArr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with

override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId

override def createColFamilyIfAbsent(colFamilyName: String): Unit = {
override def createColFamilyIfAbsent(
colFamilyName: String,
keySchema: StructType,
numColsPrefixKey: Int,
valueSchema: StructType): Unit = {
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3193")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,35 @@ import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.unsafe.Platform

sealed trait RocksDBStateEncoder {
sealed trait RocksDBKeyStateEncoder {
def supportPrefixKeyScan: Boolean
def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
def extractPrefixKey(key: UnsafeRow): UnsafeRow

def encodeKey(row: UnsafeRow): Array[Byte]
def encodeValue(row: UnsafeRow): Array[Byte]

def decodeKey(keyBytes: Array[Byte]): UnsafeRow
}

sealed trait RocksDBValueStateEncoder {
def encodeValue(row: UnsafeRow): Array[Byte]
def decodeValue(valueBytes: Array[Byte]): UnsafeRow
def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
}

object RocksDBStateEncoder {
def getEncoder(
def getKeyEncoder(
keySchema: StructType,
valueSchema: StructType,
numColsPrefixKey: Int): RocksDBStateEncoder = {
numColsPrefixKey: Int): RocksDBKeyStateEncoder = {
if (numColsPrefixKey > 0) {
new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey)
} else {
new NoPrefixKeyStateEncoder(keySchema, valueSchema)
new NoPrefixKeyStateEncoder(keySchema)
}
}

/** Builds the value encoder for the given value schema; always a SingleValueStateEncoder. */
def getValueEncoder(valueSchema: StructType): RocksDBValueStateEncoder =
  new SingleValueStateEncoder(valueSchema)

/**
* Encode the UnsafeRow of N bytes as a N+1 byte array.
* @note This creates a new byte array and memcopies the UnsafeRow to the new array.
Expand Down Expand Up @@ -86,10 +90,15 @@ object RocksDBStateEncoder {
}
}

/**
* RocksDB Key Encoder for UnsafeRow that supports prefix scan
*
* @param keySchema - schema of the key to be encoded
* @param numColsPrefixKey - number of columns to be used for prefix key
*/
class PrefixKeyScanStateEncoder(
keySchema: StructType,
valueSchema: StructType,
numColsPrefixKey: Int) extends RocksDBStateEncoder {
numColsPrefixKey: Int) extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

Expand Down Expand Up @@ -120,8 +129,6 @@ class PrefixKeyScanStateEncoder(

// Reusable objects
private val joinedRowOnKey = new JoinedRow()
private val valueRow = new UnsafeRow(valueSchema.size)
private val rowTuple = new UnsafeRowPair()

override def encodeKey(row: UnsafeRow): Array[Byte] = {
val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
Expand All @@ -140,8 +147,6 @@ class PrefixKeyScanStateEncoder(
encodedBytes
}

override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)

override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
val prefixKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET)
val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
Expand All @@ -163,10 +168,6 @@ class PrefixKeyScanStateEncoder(
restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
}

override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
decodeToUnsafeRow(valueBytes, valueRow)
}

override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
prefixKeyProjection(key)
}
Expand All @@ -180,14 +181,12 @@ class PrefixKeyScanStateEncoder(
prefix
}

override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
}

override def supportPrefixKeyScan: Boolean = true
}

/**
* RocksDB Key Encoder for UnsafeRow that does not support prefix key scan.
*
* Encodes/decodes UnsafeRows to versioned byte arrays.
* It uses the first byte of the generated byte array to store the version that describes how the
* row is encoded in the rest of the byte array. Currently, the default version is 0,
Expand All @@ -197,20 +196,16 @@ class PrefixKeyScanStateEncoder(
* (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
* then the generated byte array will be N+1 bytes.
*/
class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
extends RocksDBStateEncoder {
class NoPrefixKeyStateEncoder(keySchema: StructType)
extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

// Reusable objects
private val keyRow = new UnsafeRow(keySchema.size)
private val valueRow = new UnsafeRow(valueSchema.size)
private val rowTuple = new UnsafeRowPair()

override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)

override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)

/**
* Decode byte array for a key to a UnsafeRow.
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
Expand All @@ -220,26 +215,6 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
decodeToUnsafeRow(keyBytes, keyRow)
}

/**
* Decode byte array for a value to a UnsafeRow.
*
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
decodeToUnsafeRow(valueBytes, valueRow)
}

/**
* Decode pair of key-value byte arrays in a pair of key-value UnsafeRows.
*
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
}

override def supportPrefixKeyScan: Boolean = false

override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
Expand All @@ -250,3 +225,36 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
throw new IllegalStateException("This encoder doesn't support prefix key!")
}
}

/**
 * RocksDB value encoder for UnsafeRow that only supports a single value.
 *
 * Rows are encoded to and decoded from versioned byte arrays: the first byte of the
 * generated array stores the version that describes how the row is laid out in the
 * remaining bytes. Currently, the default version is 0.
 *
 *    VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
 *    The bytes of the UnsafeRow are written unmodified starting from offset 1
 *    (offset 0 holds the version byte, whose value is 0). That is, an UnsafeRow of
 *    N bytes yields a byte array of N+1 bytes.
 *
 * @param valueSchema - schema of the value to be encoded
 */
class SingleValueStateEncoder(valueSchema: StructType)
  extends RocksDBValueStateEncoder {

  import RocksDBStateEncoder._

  // Reusable row object, sized to the value schema, that decodeValue passes to
  // decodeToUnsafeRow on every call.
  private val reusableValueRow = new UnsafeRow(valueSchema.size)

  /** Encodes the value row through the shared `encodeUnsafeRow` helper. */
  override def encodeValue(row: UnsafeRow): Array[Byte] = {
    encodeUnsafeRow(row)
  }

  /**
   * Decodes the byte array of a value into an UnsafeRow.
   *
   * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
   *       the given byte array.
   */
  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow =
    decodeToUnsafeRow(valueBytes, reusableValueRow)
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,26 @@ private[sql] class RocksDBStateStoreProvider

override def version: Long = lastVersion

override def createColFamilyIfAbsent(colFamilyName: String): Unit = {
override def createColFamilyIfAbsent(
colFamilyName: String,
keySchema: StructType,
numColsPrefixKey: Int,
valueSchema: StructType): Unit = {
verify(colFamilyName != StateStore.DEFAULT_COL_FAMILY_NAME,
s"Failed to create column family with reserved_name=$colFamilyName")
verify(useColumnFamilies, "Column families are not supported in this store")
rocksDB.createColFamilyIfAbsent(colFamilyName)
keyValueEncoderMap.putIfAbsent(colFamilyName,
(RocksDBStateEncoder.getKeyEncoder(keySchema, numColsPrefixKey),
RocksDBStateEncoder.getValueEncoder(valueSchema)))
}

override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
verify(key != null, "Key cannot be null")
val value = encoder.decodeValue(rocksDB.get(encoder.encodeKey(key), colFamilyName))
if (!isValidated && value != null) {
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
val value = kvEncoder._2.decodeValue(
rocksDB.get(kvEncoder._1.encodeKey(key), colFamilyName))
if (!isValidated && value != null && !useColumnFamilies) {
StateStoreProvider.validateStateRowFormat(
key, keySchema, value, valueSchema, storeConf)
isValidated = true
Expand All @@ -69,19 +79,25 @@ private[sql] class RocksDBStateStoreProvider
verify(state == UPDATING, "Cannot put after already committed or aborted")
verify(key != null, "Key cannot be null")
require(value != null, "Cannot put a null value")
rocksDB.put(encoder.encodeKey(key), encoder.encodeValue(value), colFamilyName)
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
rocksDB.put(kvEncoder._1.encodeKey(key),
kvEncoder._2.encodeValue(value), colFamilyName)
}

override def remove(key: UnsafeRow, colFamilyName: String): Unit = {
verify(state == UPDATING, "Cannot remove after already committed or aborted")
verify(key != null, "Key cannot be null")
rocksDB.remove(encoder.encodeKey(key), colFamilyName)
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName)
}

override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
rocksDB.iterator(colFamilyName).map { kv =>
val rowPair = encoder.decode(kv)
if (!isValidated && rowPair.value != null) {
val rowPair = new UnsafeRowPair()
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
kvEncoder._2.decodeValue(kv.value))
if (!isValidated && rowPair.value != null && !useColumnFamilies) {
StateStoreProvider.validateStateRowFormat(
rowPair.key, keySchema, rowPair.value, valueSchema, storeConf)
isValidated = true
Expand All @@ -92,10 +108,17 @@ private[sql] class RocksDBStateStoreProvider

override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
Iterator[UnsafeRowPair] = {
require(encoder.supportPrefixKeyScan, "Prefix scan requires setting prefix key!")

val prefix = encoder.encodePrefixKey(prefixKey)
rocksDB.prefixScan(prefix, colFamilyName).map(kv => encoder.decode(kv))
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
require(kvEncoder._1.supportPrefixKeyScan,
"Prefix scan requires setting prefix key!")

val prefix = kvEncoder._1.encodePrefixKey(prefixKey)
rocksDB.prefixScan(prefix, colFamilyName).map { kv =>
val rowPair = new UnsafeRowPair()
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
kvEncoder._2.decodeValue(kv.value))
rowPair
}
}

override def commit(): Long = synchronized {
Expand Down Expand Up @@ -191,8 +214,10 @@ private[sql] class RocksDBStateStoreProvider
def dbInstance(): RocksDB = rocksDB

/** Remove column family if exists */
override def removeColFamilyIfExists(colFamilyName: String): Unit = {
rocksDB.removeColFamilyIfExists(colFamilyName)
override def removeColFamilyIfExists(colFamilyName: String): Unit = {
verify(useColumnFamilies, "Column families are not supported in this store")
rocksDB.removeColFamilyIfExists(colFamilyName)
keyValueEncoderMap.remove(colFamilyName)
}
}

Expand All @@ -215,7 +240,9 @@ private[sql] class RocksDBStateStoreProvider
(keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
"greater than the number of columns for prefix key!")

this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, numColsPrefixKey)
keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Maybe microbenchmark could tell that this could regress for default column family only - map lookup with carefully crafted lock operation in every op, though I'd rather not concern before we see actual regression.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I didn't worry about it too much, given that the provider init likely happens once for long-lived queries, where we can keep using the same provider on the same executor across micro-batch executions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, what I meant is to look up concurrent map per "every op" to figure out encoder, for existing stateful operators - previously it was just a reference to the field. But ops is relatively very cheap compared to commit as of now, so let's see.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok - yeah, mainly didn't want to maintain 2 data structures for this. But if we find that it's more expensive, then we can just split out some of the logic for the default col family case.

(RocksDBStateEncoder.getKeyEncoder(keySchema, numColsPrefixKey),
RocksDBStateEncoder.getValueEncoder(valueSchema)))

rocksDB // lazy initialization
}
Expand Down Expand Up @@ -287,7 +314,8 @@ private[sql] class RocksDBStateStoreProvider
useColumnFamilies)
}

@volatile private var encoder: RocksDBStateEncoder = _
@volatile private var keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String,
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
(RocksDBKeyStateEncoder, RocksDBValueStateEncoder)]

private def verify(condition: => Boolean, msg: String): Unit = {
if (!condition) { throw new IllegalStateException(msg) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ trait StateStore extends ReadStateStore {
/**
* Create column family with given name, if absent.
*/
def createColFamilyIfAbsent(colFamilyName: String): Unit
def createColFamilyIfAbsent(
colFamilyName: String,
keySchema: StructType,
numColsPrefixKey: Int,
valueSchema: StructType): Unit

/**
* Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows
Expand Down