Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dianfu committed Dec 11, 2018
1 parent 0dfe579 commit 712c88a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
import org.apache.flink.api.common.time.Time
import org.apache.flink.api.scala._
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.dataview.{DataView, MapView}
import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.dataview.StateMapView
import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, GroupAggProcessFunction}
import org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction
import org.apache.flink.table.runtime.harness.HarnessTestBase.TestStreamQueryConfig
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.types.Row
Expand Down Expand Up @@ -70,7 +69,11 @@ class AggFunctionHarnessTest extends HarnessTestBase {
testHarness.open()

val operator = getOperator(testHarness)
val state = getState(operator, "acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]]
val state = getState(
operator,
"function",
classOf[GroupAggProcessFunction],
"acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]]
assertTrue(state.isInstanceOf[StateMapView[_, _]])
assertTrue(operator.getKeyedStateBackend.isInstanceOf[RocksDBKeyedStateBackend[_]])

Expand Down Expand Up @@ -103,17 +106,4 @@ class AggFunctionHarnessTest extends HarnessTestBase {

testHarness.close()
}

/**
 * Fetches the [[DataView]] held in the named field of the code-generated
 * aggregations object wrapped by the given operator's user function.
 *
 * Reflection is required twice: once to reach the private "function" field of
 * [[GroupAggProcessFunction]], and once more because the concrete
 * [[GeneratedAggregations]] subclass is generated at runtime and its state
 * fields are not statically known.
 *
 * @param operator       the harness operator wrapping a GroupAggProcessFunction
 * @param stateFieldName name of the state field on the generated aggregations class
 * @return the DataView stored in that field
 */
private def getState(
    operator: AbstractUdfStreamOperator[_, _],
    stateFieldName: String): DataView = {
  // GroupAggProcessFunction keeps the generated aggregations in its
  // private "function" field.
  val functionField = classOf[GroupAggProcessFunction].getDeclaredField("function")
  functionField.setAccessible(true)
  val aggregations = functionField
    .get(operator.getUserFunction)
    .asInstanceOf[GeneratedAggregations]
  // Look the requested state field up on the concrete generated class.
  val dataViewField = aggregations.getClass.getDeclaredField(stateFieldName)
  dataViewField.setAccessible(true)
  dataViewField.get(aggregations).asInstanceOf[DataView]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ import org.apache.flink.streaming.api.transformations._
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
import org.apache.flink.table.api.dataview.DataView
import org.apache.flink.table.api.{StreamQueryConfig, Types}
import org.apache.flink.table.codegen.GeneratedAggregationsFunction
import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.runtime.aggregate.GeneratedAggregations
import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase
Expand Down Expand Up @@ -85,7 +87,7 @@ class HarnessTestBase extends StreamingWithStateTestBase {
new RowTypeInfo(distinctCountAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)

protected val distinctCountDescriptor: String = EncodingUtils.encodeObjectToString(
new MapStateDescriptor("distinctAgg0", new RowTypeInfo(Types.INT), Types.LONG))
new MapStateDescriptor("distinctAgg0", distinctCountAggregationStateType, Types.LONG))

protected val minMaxFuncName = "MinMaxAggregateHelper"
protected val sumFuncName = "SumAggregationHelper"
Expand Down Expand Up @@ -529,19 +531,28 @@ class HarnessTestBase extends StreamingWithStateTestBase {
} else {
extractExpectedTransformation(one.getInput, prefixOperatorName)
}
case two: TwoInputTransformation[_, _, _] =>
if (two.getName.startsWith(prefixOperatorName)) {
two
} else {
extractFromInputs(two.getInput1, two.getInput2)
}
case union: UnionTransformation[_] => extractFromInputs(union.getInputs.toSeq: _*)
case p: PartitionTransformation[_] => extractFromInputs(p.getInput)
case _: SourceTransformation[_] => null
case _ => throw new Exception("This should not happen.")
case _ => throw new UnsupportedOperationException("This should not happen.")
}
}

/**
 * Fetches the [[DataView]] held in the named field of the code-generated
 * aggregations object wrapped by the given operator's user function.
 *
 * Reflection is required twice: once to reach the (typically private) field
 * `funcName` declared on `funcClass`, and once more because the concrete
 * [[GeneratedAggregations]] subclass is generated at runtime and its state
 * fields are not statically known.
 *
 * @param operator       the harness operator whose user function holds the aggregations
 * @param funcName       name of the field on `funcClass` holding the generated aggregations
 * @param funcClass      class declaring that field (e.g. a ProcessFunction subclass)
 * @param stateFieldName name of the state field on the generated aggregations class
 * @return the DataView stored in that field
 */
def getState(
    operator: AbstractUdfStreamOperator[_, _],
    funcName: String,
    funcClass: Class[_],
    stateFieldName: String): DataView = {
  // Reach through the declaring class to the wrapped generated aggregations.
  val wrapperField = funcClass.getDeclaredField(funcName)
  wrapperField.setAccessible(true)
  val aggregations = wrapperField
    .get(operator.getUserFunction)
    .asInstanceOf[GeneratedAggregations]
  // Look the requested state field up on the concrete generated class.
  val dataViewField = aggregations.getClass.getDeclaredField(stateFieldName)
  dataViewField.setAccessible(true)
  dataViewField.get(aggregations).asInstanceOf[DataView]
}

def createHarnessTester[IN, OUT, KEY](
operator: OneInputStreamOperator[IN, OUT],
keySelector: KeySelector[IN, KEY],
Expand All @@ -550,12 +561,11 @@ class HarnessTestBase extends StreamingWithStateTestBase {
}

/**
 * Extracts the wrapped [[AbstractUdfStreamOperator]] from a test harness.
 *
 * The harness does not expose its operator publicly, so its private
 * "oneInputOperator" field is read via reflection.
 *
 * NOTE(review): this span in the diff contained both the pre- and post-commit
 * copies of the signature and body lines; this is the reconstructed
 * post-commit version.
 *
 * @param testHarness the harness holding the operator under test
 * @return the operator cast to AbstractUdfStreamOperator
 */
def getOperator(testHarness: OneInputStreamOperatorTestHarness[_, _])
  : AbstractUdfStreamOperator[_, _] = {
  val operatorField = classOf[OneInputStreamOperatorTestHarness[_, _]]
    .getDeclaredField("oneInputOperator")
  operatorField.setAccessible(true)
  operatorField.get(testHarness).asInstanceOf[AbstractUdfStreamOperator[_, _]]
}

def verify(expected: JQueue[Object], actual: JQueue[Object]): Unit = {
Expand Down

0 comments on commit 712c88a

Please sign in to comment.