Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-46979][SS] Add support for specifying key and value encoder separately and also for each col family in RocksDB state store provider #45038

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class StatefulProcessorHandleImpl(
/**
 * Creates a value state variable registered under the given state name.
 *
 * @param stateName - name of the state variable; also passed to the store as the
 *                    column family name
 * @tparam T - type parameter forwarded to ValueStateImpl
 * @return the newly constructed ValueStateImpl instance
 */
override def getValueState[T](stateName: String): ValueState[T] = {
// State variables may only be created while the handle is still in the CREATED state.
verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
"initialization is complete")
store.createColFamilyIfAbsent(stateName)
val resultState = new ValueStateImpl[T](store, stateName, keyEncoder)
resultState
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,20 @@ import org.apache.spark.sql.types._
* @param store - reference to the StateStore instance to be used for storing state
* @param stateName - name of logical state partition
* @param keyExprEnc - Spark SQL encoder for key
* @tparam K - data type of key
* @tparam S - data type of object that will be stored
*/
class ValueStateImpl[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {

private val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)

private val schemaForValueRow: StructType = new StructType().add("value", BinaryType)

store.createColFamilyIfAbsent(stateName, schemaForKeyRow, numColsPrefixKey = 0,
schemaForValueRow)

// TODO: validate places that are trying to encode the key and check if we can eliminate/
// add caching for some of these calls.
private def encodeKey(): UnsafeRow = {
Expand All @@ -52,14 +58,12 @@ class ValueStateImpl[S](
val keyByteArr = toRow
.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()

val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
val keyRow = keyEncoder(InternalRow(keyByteArr))
keyRow
}

private def encodeValue(value: S): UnsafeRow = {
val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable])
val valueEncoder = UnsafeProjection.create(schemaForValueRow)
val valueRow = valueEncoder(InternalRow(valueByteArr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with

override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId

override def createColFamilyIfAbsent(colFamilyName: String): Unit = {
override def createColFamilyIfAbsent(
colFamilyName: String,
keySchema: StructType,
numColsPrefixKey: Int,
valueSchema: StructType): Unit = {
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3193")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,35 @@ import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.unsafe.Platform

sealed trait RocksDBStateEncoder {
sealed trait RocksDBKeyStateEncoder {
def supportPrefixKeyScan: Boolean
def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
def extractPrefixKey(key: UnsafeRow): UnsafeRow

def encodeKey(row: UnsafeRow): Array[Byte]
def encodeValue(row: UnsafeRow): Array[Byte]

def decodeKey(keyBytes: Array[Byte]): UnsafeRow
}

/**
 * Contract for converting value rows to and from byte arrays.
 */
sealed trait RocksDBValueStateEncoder {
// Encodes a value UnsafeRow into a byte array.
def encodeValue(row: UnsafeRow): Array[Byte]
// Decodes a value byte array back into an UnsafeRow.
def decodeValue(valueBytes: Array[Byte]): UnsafeRow
// Decodes a pair of key/value byte arrays into a pair of UnsafeRows.
def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
}

object RocksDBStateEncoder {
def getEncoder(
def getKeyEncoder(
keySchema: StructType,
valueSchema: StructType,
numColsPrefixKey: Int): RocksDBStateEncoder = {
numColsPrefixKey: Int): RocksDBKeyStateEncoder = {
if (numColsPrefixKey > 0) {
new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey)
} else {
new NoPrefixKeyStateEncoder(keySchema, valueSchema)
new NoPrefixKeyStateEncoder(keySchema)
}
}

/**
 * Builds the encoder used for value rows.
 *
 * @param valueSchema - schema of the value rows to be encoded
 * @return a SingleValueStateEncoder created for the given schema
 */
def getValueEncoder(valueSchema: StructType): RocksDBValueStateEncoder =
  new SingleValueStateEncoder(valueSchema)

/**
* Encode the UnsafeRow of N bytes as a N+1 byte array.
* @note This creates a new byte array and memcopies the UnsafeRow to the new array.
Expand Down Expand Up @@ -86,10 +90,15 @@ object RocksDBStateEncoder {
}
}

/**
* RocksDB Key Encoder for UnsafeRow that supports prefix scan
*
* @param keySchema - schema of the key to be encoded
* @param numColsPrefixKey - number of columns to be used for prefix key
*/
class PrefixKeyScanStateEncoder(
keySchema: StructType,
valueSchema: StructType,
numColsPrefixKey: Int) extends RocksDBStateEncoder {
numColsPrefixKey: Int) extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

Expand Down Expand Up @@ -120,8 +129,6 @@ class PrefixKeyScanStateEncoder(

// Reusable objects
private val joinedRowOnKey = new JoinedRow()
private val valueRow = new UnsafeRow(valueSchema.size)
private val rowTuple = new UnsafeRowPair()

override def encodeKey(row: UnsafeRow): Array[Byte] = {
val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
Expand All @@ -140,8 +147,6 @@ class PrefixKeyScanStateEncoder(
encodedBytes
}

override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)

override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
val prefixKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET)
val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
Expand All @@ -163,10 +168,6 @@ class PrefixKeyScanStateEncoder(
restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
}

override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
decodeToUnsafeRow(valueBytes, valueRow)
}

override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
prefixKeyProjection(key)
}
Expand All @@ -180,14 +181,12 @@ class PrefixKeyScanStateEncoder(
prefix
}

override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
}

override def supportPrefixKeyScan: Boolean = true
}

/**
* RocksDB Key Encoder for UnsafeRow that does not support prefix key scan.
*
* Encodes/decodes UnsafeRows to versioned byte arrays.
* It uses the first byte of the generated byte array to store the version that describes how the
* row is encoded in the rest of the byte array. Currently, the default version is 0,
Expand All @@ -197,20 +196,16 @@ class PrefixKeyScanStateEncoder(
* (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
* then the generated byte array will be N+1 bytes.
*/
class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
extends RocksDBStateEncoder {
class NoPrefixKeyStateEncoder(keySchema: StructType)
extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

// Reusable objects
private val keyRow = new UnsafeRow(keySchema.size)
private val valueRow = new UnsafeRow(valueSchema.size)
private val rowTuple = new UnsafeRowPair()

override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)

override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)

/**
* Decode byte array for a key to a UnsafeRow.
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
Expand All @@ -220,26 +215,6 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
decodeToUnsafeRow(keyBytes, keyRow)
}

/**
* Decode byte array for a value to a UnsafeRow.
*
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
decodeToUnsafeRow(valueBytes, valueRow)
}

/**
* Decode pair of key-value byte arrays in a pair of key-value UnsafeRows.
*
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
}

override def supportPrefixKeyScan: Boolean = false

override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
Expand All @@ -250,3 +225,36 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
throw new IllegalStateException("This encoder doesn't support prefix key!")
}
}

/**
 * RocksDB value encoder for UnsafeRow that only supports a single value.
 *
 * Rows are encoded to and decoded from versioned byte arrays: the first byte of the
 * generated array stores the version that describes how the row is laid out in the
 * remainder of the array. Currently, the default version is 0.
 *
 * VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
 *    The bytes of the UnsafeRow are written unmodified starting from offset 1
 *    (offset 0 holds the version byte, whose value is 0). That is, an unsafe row of
 *    N bytes produces a byte array of N+1 bytes.
 *
 * @param valueSchema - schema of the value to be encoded
 */
class SingleValueStateEncoder(valueSchema: StructType)
  extends RocksDBValueStateEncoder {

  import RocksDBStateEncoder._

  // Single row instance that is handed back by every decodeValue call.
  private val reusableValueRow = new UnsafeRow(valueSchema.size)

  /** Encodes the given value row into its versioned byte-array form. */
  override def encodeValue(row: UnsafeRow): Array[Byte] = {
    val encodedBytes = encodeUnsafeRow(row)
    encodedBytes
  }

  /**
   * Decodes the byte array of a value into an UnsafeRow.
   *
   * @note The returned UnsafeRow is reused across calls and merely points to the
   *       given byte array.
   */
  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow =
    decodeToUnsafeRow(valueBytes, reusableValueRow)
}