Skip to content

Commit

Permalink
rebase on mapstate
Browse files Browse the repository at this point in the history
  • Loading branch information
jingz-db committed Mar 13, 2024
1 parent 7b0c7c7 commit 8fbd501
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf

case class InputMapRow(key: String, action: String, value: (String, String))
Expand Down Expand Up @@ -79,7 +79,8 @@ class TestMapStateProcessor
* Class that adds integration tests for MapState types used in arbitrary stateful
* operators such as transformWithState.
*/
class TransformWithMapStateSuite extends StreamTest {
class TransformWithMapStateSuite extends StreamTest
with AlsoTestWithChangelogCheckpointingEnabled {
import testImplicits._

private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = {
Expand Down
Expand Up @@ -28,18 +28,23 @@ class StatefulProcessorWithInitialStateTestClass extends StatefulProcessorWithIn
String, InitInputRow, (String, String, Double), (String, Double)] {
@transient var _valState: ValueState[Double] = _
@transient var _listState: ListState[Double] = _
// TODO will add mapstate test after rebase
@transient var _mapState: MapState[Double, Int] = _

override def handleInitialState(
key: String,
initialState: (String, Double)): Unit = {
_valState.update(initialState._2)
_listState.appendValue(initialState._2)
val initStateVal = initialState._2
_valState.update(initStateVal)
_listState.appendValue(initStateVal)
// mapState acts as an occurrence counter
_mapState.updateValue(initStateVal, 1)
}

override def init(operatorOutputMode: OutputMode): Unit = {
_valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble)
_listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble)
_mapState = getHandle.getMapState[Double, Int](
"testMapInit", Encoders.scalaDouble, Encoders.scalaInt)
}

override def close(): Unit = {}
Expand All @@ -64,6 +69,18 @@ class StatefulProcessorWithInitialStateTestClass extends StatefulProcessorWithIn
_listState.appendValue(row.value)
} else if (row.action == "clearList") {
_listState.clear()
} else if (row.action == "getCount") {
val count =
if (!_mapState.containsKey(row.value)) 0
else _mapState.getValue(row.value)
output = (key, row.action, count.toDouble) :: output
} else if (row.action == "incCount") {
val count =
if (!_mapState.containsKey(row.value)) 0
else _mapState.getValue(row.value)
_mapState.updateValue(row.value, count + 1)
} else if (row.action == "clearCount") {
_mapState.removeKey(row.value)
}
}
output.iterator
Expand Down Expand Up @@ -130,10 +147,13 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
AddData(inputData, InitInputRow("non-exist", "getList", -1.0)),
CheckNewAnswer(),

// test remove
AddData(inputData, InitInputRow("k1", "remove", -1.0)),
AddData(inputData, InitInputRow("k1", "getOption", -1.0)),
CheckNewAnswer(("k1", "getOption", -1.0)),
AddData(inputData, InitInputRow("k1", "incCount", 37.0)),
AddData(inputData, InitInputRow("k2", "incCount", 40.0)),
AddData(inputData, InitInputRow("non-exist", "getCount", -1.0)),
CheckNewAnswer(("non-exist", "getCount", 0.0)),
AddData(inputData, InitInputRow("k2", "incCount", 40.0)),
AddData(inputData, InitInputRow("k2", "getCount", 40.0)),
CheckNewAnswer(("k2", "getCount", 2.0)),

// test every row in initial State is processed
AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
Expand All @@ -146,6 +166,11 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
AddData(inputData, InitInputRow("init_2", "getList", -1.0)),
CheckNewAnswer(("init_2", "getList", 100.0)),

AddData(inputData, InitInputRow("init_1", "getCount", 40.0)),
CheckNewAnswer(("init_1", "getCount", 1.0)),
AddData(inputData, InitInputRow("init_2", "getCount", 100.0)),
CheckNewAnswer(("init_2", "getCount", 1.0)),

// Update row with key in initial row will work
AddData(inputData, InitInputRow("init_1", "update", 50.0)),
AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
Expand All @@ -157,6 +182,20 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
AddData(inputData, InitInputRow("init_1", "appendList", 50.0)),
AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
CheckNewAnswer(("init_1", "getList", 50.0), ("init_1", "getList", 40.0)),

AddData(inputData, InitInputRow("init_1", "incCount", 40.0)),
AddData(inputData, InitInputRow("init_1", "getCount", 40.0)),
CheckNewAnswer(("init_1", "getCount", 2.0)),

// test remove
AddData(inputData, InitInputRow("k1", "remove", -1.0)),
AddData(inputData, InitInputRow("k1", "getOption", -1.0)),
CheckNewAnswer(("k1", "getOption", -1.0)),

AddData(inputData, InitInputRow("init_1", "clearCount", -1.0)),
AddData(inputData, InitInputRow("init_1", "getCount", -1.0)),
CheckNewAnswer(("init_1", "getCount", 0.0)),

AddData(inputData, InitInputRow("init_1", "clearList", -1.0)),
AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
CheckNewAnswer()
Expand Down

0 comments on commit 8fbd501

Please sign in to comment.