Add tests for StatefulOperatorsHelper as well
HeartSaVioR committed Jul 20, 2018
1 parent 63dfb5d commit e844636
Showing 3 changed files with 175 additions and 23 deletions.
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.state

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// A simple in-memory StateStore for tests, backed by a ConcurrentHashMap.
class MemoryStateStore extends StateStore() {
  import scala.collection.JavaConverters._
  private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]

  override def iterator(): Iterator[UnsafeRowPair] = {
    map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
  }

  override def get(key: UnsafeRow): UnsafeRow = map.get(key)

  override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
    // Copy the rows: callers may reuse the same UnsafeRow instance across calls.
    map.put(key.copy(), newValue.copy())
  }

  override def remove(key: UnsafeRow): Unit = {
    map.remove(key)
  }

  override def commit(): Long = version + 1

  override def abort(): Unit = {}

  override def id: StateStoreId = null

  override def version: Long = 0

  override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty)

  override def hasCommitted: Boolean = true
}
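
A minimal, hypothetical sketch (not part of this commit) of exercising MemoryStateStore directly. The row is built the same way TestMaterial below builds TEST_ROW; the schema and variable names here are illustrative only:

import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("c1", IntegerType, nullable = false)))
val toUnsafe = UnsafeProjection.create(schema)
val row = toUnsafe(new SpecificInternalRow(schema))
row.setInt(0, 42)

val store = new MemoryStateStore()
store.put(row, row)                 // put() copies both key and value
assert(store.iterator().size == 1)  // exactly one stored pair
assert(store.get(row) == row)       // UnsafeRow equality is byte-wise, so the copy matches
store.remove(row)
assert(store.get(row) == null)      // ConcurrentHashMap.get returns null on a miss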
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.state

import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class StatefulOperatorsHelperSuite extends StreamTest {
  import TestMaterial._

  test("StateManager v1 - get, put, iter") {
    val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 1)

    // In V1, the entire input row is stored as the value.
    testGetPutIterOnStateManager(stateManager, OUTPUT_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW, TEST_ROW)
  }

  // ============================ StateManagerImplV2 ============================
  test("StateManager v2 - get, put, iter") {
    val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 2)

    // In V2, only the value part of the input row (the input row minus its key columns) is
    // stored, so the stored value has no key part. State manager V2 still provides the same
    // output as V1 when a row is retrieved by key.
    testGetPutIterOnStateManager(stateManager, VALUES_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW,
      TEST_VALUE_ROW)
  }

  private def newStateManager(
      keysAttributes: Seq[Attribute],
      outputAttributes: Seq[Attribute],
      version: Int): StreamingAggregationStateManager = {
    StreamingAggregationStateManager.createStateManager(keysAttributes, outputAttributes, version)
  }

  private def testGetPutIterOnStateManager(
      stateManager: StreamingAggregationStateManager,
      expectedValueExpressions: Seq[Attribute],
      inputRow: UnsafeRow,
      expectedStateKey: UnsafeRow,
      expectedStateValue: UnsafeRow): Unit = {

    assert(stateManager.getValueExpressions === expectedValueExpressions)

    val memoryStateStore = new MemoryStateStore()
    stateManager.put(memoryStateStore, inputRow)

    assert(memoryStateStore.iterator().size === 1)

    val keyRow = stateManager.extractKey(inputRow)
    assert(keyRow === expectedStateKey)

    // Iterate over the state store and verify that the key and value are stored in the
    // expected format.
    val pair = memoryStateStore.iterator().next()
    assert(pair.key === keyRow)
    assert(pair.value === expectedStateValue)
    assert(stateManager.restoreOriginRow(pair) === inputRow)

    // Verify the stored value once more via get().
    assert(memoryStateStore.get(keyRow) === expectedStateValue)

    // Regardless of the format version, the state manager should return a row equal to the
    // input row.
    assert(inputRow === stateManager.get(memoryStateStore, keyRow))
  }
}

object TestMaterial {
  val KEYS: Seq[String] = Seq("key1", "key2")
  val VALUES: Seq[String] = Seq("sum(key1)", "sum(key2)")

  val OUTPUT_SCHEMA: StructType = StructType(
    KEYS.map(createIntegerField) ++ VALUES.map(createIntegerField))

  val OUTPUT_ATTRIBUTES: Seq[Attribute] = OUTPUT_SCHEMA.toAttributes
  val KEYS_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p =>
    KEYS.contains(p.name)
  }
  val VALUES_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p =>
    VALUES.contains(p.name)
  }

  val TEST_ROW: UnsafeRow = {
    val unsafeRowProjection = UnsafeProjection.create(OUTPUT_SCHEMA)
    val row = unsafeRowProjection(new SpecificInternalRow(OUTPUT_SCHEMA))
    (KEYS ++ VALUES).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) }
    row
  }

  val TEST_KEY_ROW: UnsafeRow = {
    val keyProjector = GenerateUnsafeProjection.generate(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES)
    keyProjector(TEST_ROW)
  }

  val TEST_VALUE_ROW: UnsafeRow = {
    val valueProjector = GenerateUnsafeProjection.generate(VALUES_ATTRIBUTES, OUTPUT_ATTRIBUTES)
    valueProjector(TEST_ROW)
  }

  private def createIntegerField(name: String): StructField = {
    StructField(name, IntegerType, nullable = false)
  }
}
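
For orientation, the fixture contents work out as follows; this is inferred from the definitions above (each int column is set to its ordinal), not spelled out in the commit:

// TEST_ROW       == [key1 = 0, key2 = 1, sum(key1) = 2, sum(key2) = 3]
// TEST_KEY_ROW   == [0, 1]   (projection onto KEYS_ATTRIBUTES)
// TEST_VALUE_ROW == [2, 3]   (projection onto VALUES_ATTRIBUTES)

So state manager V1 stores the full [0, 1, 2, 3] row as the state value, while V2 stores only [2, 3] and reattaches the key columns when the row is read back. Condensing the suite's helper into a linear, hypothetical V2 round trip (the same calls testGetPutIterOnStateManager makes, using the fixtures above):

val manager = StreamingAggregationStateManager.createStateManager(
  KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 2)       // format version 2
val store = new MemoryStateStore()

manager.put(store, TEST_ROW)                   // V2 stores only the value columns
val key = manager.extractKey(TEST_ROW)         // == TEST_KEY_ROW
assert(store.get(key) == TEST_VALUE_ROW)       // raw stored value has no key part
assert(manager.get(store, key) == TEST_ROW)    // the manager restores the full input row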
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming

import java.io.File
import java.sql.Date
-import java.util.concurrent.ConcurrentHashMap

import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfterAll
@@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -1286,27 +1285,6 @@ object FlatMapGroupsWithStateSuite {

  var failInTask = true

-  class MemoryStateStore extends StateStore() {
-    import scala.collection.JavaConverters._
-    private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
-
-    override def iterator(): Iterator[UnsafeRowPair] = {
-      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
-    }
-
-    override def get(key: UnsafeRow): UnsafeRow = map.get(key)
-    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
-      map.put(key.copy(), newValue.copy())
-    }
-    override def remove(key: UnsafeRow): Unit = { map.remove(key) }
-    override def commit(): Long = version + 1
-    override def abort(): Unit = { }
-    override def id: StateStoreId = null
-    override def version: Long = 0
-    override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty)
-    override def hasCommitted: Boolean = true
-  }

  def assertCanGetProcessingTime(predicate: => Boolean): Unit = {
    if (!predicate) throw new TestFailedException("Could not get processing time", 20)
  }