[FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction #7253

Closed · wants to merge 4 commits
@@ -112,10 +112,12 @@ class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
   def retract(acc: CollectAccumulator[E], value: E): Unit = {
     if (value != null) {
       val count = acc.map.get(value)
-      if (count == 1) {
-        acc.map.remove(value)
-      } else {
-        acc.map.put(value, count - 1)
+      if (count != null) {
+        if (count == 1) {
+          acc.map.remove(value)
+        } else {
+          acc.map.put(value, count - 1)
+        }
       }
     }
   }
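For context, a minimal standalone sketch of the guarded retract logic above (hypothetical names; a plain java.util.HashMap stands in for the accumulator's MapView): a retraction for a value whose count is already gone, e.g. after state TTL cleanup, is now silently ignored instead of failing on an unboxed null.

    // Hypothetical sketch, not the PR's code: shows why the null guard matters.
    import java.lang.{Integer => JInt}
    import java.util.{HashMap => JHashMap}

    object RetractSketch {
      private val map = new JHashMap[String, JInt]()

      def add(value: String): Unit = {
        val count = map.get(value)
        map.put(value, if (count == null) 1 else count + 1)
      }

      def retract(value: String): Unit = {
        if (value != null) {
          val count = map.get(value)
          // the guard added by this PR: get() returns null when the entry
          // was already removed (e.g. cleaned up by the state backend)
          if (count != null) {
            if (count == 1) map.remove(value)
            else map.put(value, count - 1)
          }
          // a retraction for a missing entry is simply dropped
        }
      }
    }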
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.runtime.harness

import java.lang.{Integer => JInt}
import java.util.concurrent.ConcurrentLinkedQueue

import org.apache.flink.api.common.time.Time
import org.apache.flink.api.scala._
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.dataview.StateMapView
import org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction
import org.apache.flink.table.runtime.harness.HarnessTestBase.TestStreamQueryConfig
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.types.Row
import org.junit.Assert.assertTrue
import org.junit.Test

import scala.collection.JavaConverters._
import scala.collection.mutable

class AggFunctionHarnessTest extends HarnessTestBase {
  private val queryConfig = new TestStreamQueryConfig(Time.seconds(0), Time.seconds(0))

  @Test
  def testCollectAggregate(): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    val data = new mutable.MutableList[(JInt, String)]
    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b)
    tEnv.registerTable("T", t)
    val sqlQuery = tEnv.sqlQuery(
      s"""
         |SELECT
         |  b, collect(a)
         |FROM (
         |  SELECT a, b
         |  FROM T
         |  GROUP BY a, b
         |) GROUP BY b
         |""".stripMargin)

    val testHarness = createHarnessTester[String, CRow, CRow](
      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")

    testHarness.setStateBackend(getStateBackend)
    testHarness.open()

    val operator = getOperator(testHarness)
    val state = getState(
      operator,
      "function",
      classOf[GroupAggProcessFunction],
      "acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]]
    assertTrue(state.isInstanceOf[StateMapView[_, _]])
    assertTrue(operator.getKeyedStateBackend.isInstanceOf[RocksDBKeyedStateBackend[_]])

    val expectedOutput = new ConcurrentLinkedQueue[Object]()

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 1).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "bbb"), 1))
    expectedOutput.add(new StreamRecord(CRow("bbb", Map(1 -> 1).asJava), 1))

@walterddr (Contributor) commented on Dec 7, 2018:
Can we add something like:

    // do a snapshot & close
    State snapshot = testHarness.snapshot(0L, 0L);
    testHarness.close();
    // reopen and restore
    testHarness = createTestHarness(operator);
    testHarness.setup();
    testHarness.initializeState(snapshot);
    testHarness.open();

This will catch some of the weird serialization/deserialization problems as well.

The author (Contributor) replied:

Good idea. What about adding this kind of test to the operator tests instead, such as GroupAggregateHarnessTest, JoinHarnessTest, etc.? I think it's more useful as an operator test.
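For reference, a rough Scala version of the suggested round trip (a sketch only: `operator`, `keySelector` and `keyType` are assumed to be in scope here, which is not the case in the test as written):

    // sketch of the suggested snapshot-and-restore cycle
    val snapshot = testHarness.snapshot(0L, 0L)   // checkpoint the operator state
    testHarness.close()

    // reopen a fresh harness over the same operator and restore the snapshot
    val restoredHarness = createHarnessTester(operator, keySelector, keyType)
    restoredHarness.setup()
    restoredHarness.initializeState(snapshot)
    restoredHarness.open()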

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 1).asJava), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(2: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 2).asJava), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2, 2 -> 1).asJava), 1))

    // remove some state: state may be cleaned up by the state backend
    // if not accessed within the TTL
    operator.setCurrentKey(Row.of("aaa"))
    state.remove(2)

    // retract after state has been cleaned up
    testHarness.processElement(new StreamRecord(CRow(false, 2: JInt, "aaa"), 1))

    val result = testHarness.getOutput

    verify(expectedOutput, result)

    testHarness.close()
  }
}
@@ -25,27 +25,27 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.streaming.api.operators.{AbstractUdfStreamOperator, OneInputStreamOperator}
import org.apache.flink.streaming.api.scala.DataStream
import org.apache.flink.streaming.api.transformations._
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
import org.apache.flink.table.api.dataview.DataView
import org.apache.flink.table.api.{StreamQueryConfig, Types}
import org.apache.flink.table.codegen.GeneratedAggregationsFunction
import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.runtime.aggregate.GeneratedAggregations
import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase
import org.apache.flink.table.utils.EncodingUtils
import org.junit.Rule
import org.junit.rules.ExpectedException

class HarnessTestBase {
// used for accurate exception information checking.
val expectedException = ExpectedException.none()
import _root_.scala.collection.JavaConversions._

@Rule
def thrown = expectedException
class HarnessTestBase extends StreamingWithStateTestBase {

val longMinWithRetractAggFunction: String =
EncodingUtils.encodeObjectToString(new LongMinWithRetractAggFunction)
@@ -491,13 +491,83 @@ class HarnessTestBase {
    distinctCountFuncName,
    distinctCountAggCode)

  def createHarnessTester[KEY, IN, OUT](
      dataStream: DataStream[_],
      prefixOperatorName: String)
A Member commented:
How about adding aggFieldAlias: String = "" to resolve scenarios where the GROUP BY is included in multiple UNION clauses, e.g.:

      (SELECT b, max(a) as maxA FROM T GROUP BY b)
       UNION (
         SELECT b, min(a) as minA FROM (
          SELECT a, b FROM T GROUP BY a, b
         ) GROUP BY b
       )

We would then use the method as follows: createHarnessTester(xx, "groupBy", "minA").
I didn't find a case where I had to test it this way; it's just a suggested enhancement. What do you think?
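A hypothetical sketch of the suggested overload (not part of this PR; the alias-matching body is left abstract):

    def createHarnessTester[KEY, IN, OUT](
        dataStream: DataStream[_],
        prefixOperatorName: String,
        aggFieldAlias: String = "")
      : KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
      // when aggFieldAlias is non-empty, it could be matched against the
      // candidate transformation's output field names to pick the right
      // "groupBy" operator among the UNION branches
      ???
    }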

The author (Contributor) replied:

createHarnessTester will be used not only in agg-related tests but also in other harness tests, such as stream join tests, temporal join tests, sort tests, etc., so an aggFieldAlias field seems a little weird from my point of view. What about adding it when we really need it? By then we may have a better idea of what such a field should look like. Thoughts?

The Member replied:

Hmm, if this tester will be used in join tests, we should add a TwoInputTransformation check to the extractExpectedTransformation logic, and then we'd also need some name parameter (maybe not named aggFieldAlias), e.g. for multiple-join scenarios.

The author replied:

Yes, you're right. Will do that when updating the join-related harness tests.
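Such a check might look like an extra case clause in the extractExpectedTransformation match below (a sketch under the assumption that TwoInputTransformation's getInput1/getInput2 accessors are available, as in this Flink version; this clause is not part of the PR):

    // hypothetical: recurse into both inputs of a two-input (join) node
    case two: TwoInputTransformation[_, _, _] =>
      if (two.getName.startsWith(prefixOperatorName)) {
        two
      } else {
        extractFromInputs(two.getInput1, two.getInput2)
      }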

    : KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {

    val transformation = extractExpectedTransformation(
      dataStream.javaStream.getTransformation,
      prefixOperatorName).asInstanceOf[OneInputTransformation[_, _]]
    if (transformation == null) {
      throw new Exception("Can not find the expected transformation")
    }

    val processOperator = transformation.getOperator.asInstanceOf[OneInputStreamOperator[IN, OUT]]
    val keySelector = transformation.getStateKeySelector.asInstanceOf[KeySelector[IN, KEY]]
    val keyType = transformation.getStateKeyType.asInstanceOf[TypeInformation[KEY]]

    createHarnessTester(processOperator, keySelector, keyType)
      .asInstanceOf[KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT]]
  }

  private def extractExpectedTransformation(
      transformation: StreamTransformation[_],
      prefixOperatorName: String): StreamTransformation[_] = {
    def extractFromInputs(inputs: StreamTransformation[_]*): StreamTransformation[_] = {
      for (input <- inputs) {
        val t = extractExpectedTransformation(input, prefixOperatorName)
        if (t != null) {
          return t
        }
      }
      null
    }

    transformation match {
      case one: OneInputTransformation[_, _] =>
        if (one.getName.startsWith(prefixOperatorName)) {
          one
        } else {
          extractExpectedTransformation(one.getInput, prefixOperatorName)
        }
      case union: UnionTransformation[_] => extractFromInputs(union.getInputs.toSeq: _*)
      case p: PartitionTransformation[_] => extractFromInputs(p.getInput)
      case _: SourceTransformation[_] => null
      case _ => throw new UnsupportedOperationException("This should not happen.")
    }
  }

  def getState(
      operator: AbstractUdfStreamOperator[_, _],
      funcName: String,
      funcClass: Class[_],
      stateFieldName: String): DataView = {
    val function = funcClass.getDeclaredField(funcName)
    function.setAccessible(true)
    val generatedAggregation =
      function.get(operator.getUserFunction).asInstanceOf[GeneratedAggregations]
    val cls = generatedAggregation.getClass
    val stateField = cls.getDeclaredField(stateFieldName)
    stateField.setAccessible(true)
    stateField.get(generatedAggregation).asInstanceOf[DataView]
  }

  def createHarnessTester[IN, OUT, KEY](
      operator: OneInputStreamOperator[IN, OUT],
      keySelector: KeySelector[IN, KEY],
      keyType: TypeInformation[KEY]): KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
    new KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT](operator, keySelector, keyType)
  }

  def getOperator(testHarness: OneInputStreamOperatorTestHarness[_, _])
    : AbstractUdfStreamOperator[_, _] = {
    val operatorField = classOf[OneInputStreamOperatorTestHarness[_, _]]
      .getDeclaredField("oneInputOperator")
    operatorField.setAccessible(true)
    operatorField.get(testHarness).asInstanceOf[AbstractUdfStreamOperator[_, _]]
  }

  def verify(expected: JQueue[Object], actual: JQueue[Object]): Unit = {
    verify(expected, actual, new RowResultSortComparator)
  }