[FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction #7253
@@ -0,0 +1,110 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.harness

import java.lang.{Integer => JInt}
import java.util.concurrent.ConcurrentLinkedQueue

import org.apache.flink.api.common.time.Time
import org.apache.flink.api.scala._
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.dataview.StateMapView
import org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction
import org.apache.flink.table.runtime.harness.HarnessTestBase.TestStreamQueryConfig
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.types.Row
import org.junit.Assert.assertTrue
import org.junit.Test

import scala.collection.JavaConverters._
import scala.collection.mutable

class AggFunctionHarnessTest extends HarnessTestBase {
  private val queryConfig = new TestStreamQueryConfig(Time.seconds(0), Time.seconds(0))

  @Test
  def testCollectAggregate(): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    val data = new mutable.MutableList[(JInt, String)]
    val t = env.fromCollection(data).toTable(tEnv, 'a, 'b)
    tEnv.registerTable("T", t)
    val sqlQuery = tEnv.sqlQuery(
      s"""
         |SELECT
         |  b, collect(a)
         |FROM (
         |  SELECT a, b
         |  FROM T
         |  GROUP BY a, b
         |) GROUP BY b
         |""".stripMargin)

    val testHarness = createHarnessTester[String, CRow, CRow](
      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")

    testHarness.setStateBackend(getStateBackend)
    testHarness.open()
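
    // pull the generated accumulator's MapView out of the operator via
    // reflection, so the test can inspect (and later mutate) the backing state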
    val operator = getOperator(testHarness)
    val state = getState(
      operator,
      "function",
      classOf[GroupAggProcessFunction],
      "acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]]
    assertTrue(state.isInstanceOf[StateMapView[_, _]])
    assertTrue(operator.getKeyedStateBackend.isInstanceOf[RocksDBKeyedStateBackend[_]])

    val expectedOutput = new ConcurrentLinkedQueue[Object]()
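
    // retract stream semantics: an update to a group that already has a
    // result first emits a retraction (CRow(false, ...)) of the old result
    // row, then the new result row; a brand-new group emits only the new row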
    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 1).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "bbb"), 1))
    expectedOutput.add(new StreamRecord(CRow("bbb", Map(1 -> 1).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 1).asJava), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(2: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow(false, "aaa", Map(1 -> 2).asJava), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2, 2 -> 1).asJava), 1))

    // remove some state: state may be cleaned up by the state backend
    // if not accessed beyond ttl time
    operator.setCurrentKey(Row.of("aaa"))
    state.remove(2)

    // retract after state has been cleaned up
    testHarness.processElement(new StreamRecord(CRow(false, 2: JInt, "aaa"), 1))
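    // nothing is added to expectedOutput here: the retraction refers to an
    // entry that was just removed, so it must be absorbed without output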

    val result = testHarness.getOutput

    verify(expectedOutput, result)

    testHarness.close()
  }
}
@@ -25,27 +25,27 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRIN
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.operators.{AbstractUdfStreamOperator, OneInputStreamOperator}
 import org.apache.flink.streaming.api.scala.DataStream
 import org.apache.flink.streaming.api.transformations._
 import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
-import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness, TestHarnessUtil}
+import org.apache.flink.table.api.dataview.DataView
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
 import org.apache.flink.table.codegen.GeneratedAggregationsFunction
 import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
 import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
+import org.apache.flink.table.runtime.aggregate.GeneratedAggregations
 import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
 import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase
 import org.apache.flink.table.utils.EncodingUtils
-import org.junit.Rule
-import org.junit.rules.ExpectedException
+import _root_.scala.collection.JavaConversions._

-class HarnessTestBase {
-  // used for accurate exception information checking.
-  val expectedException = ExpectedException.none()
-
-  @Rule
-  def thrown = expectedException
+class HarnessTestBase extends StreamingWithStateTestBase {

   val longMinWithRetractAggFunction: String =
     EncodingUtils.encodeObjectToString(new LongMinWithRetractAggFunction)

@@ -491,13 +491,83 @@ class HarnessTestBase {
    distinctCountFuncName,
    distinctCountAggCode)

  def createHarnessTester[KEY, IN, OUT](
      dataStream: DataStream[_],
      prefixOperatorName: String)

[Review thread]
Reviewer: How about adding the following, and then using this method as follows:
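(The snippet the reviewer posted was lost in extraction. The sketch below is a hypothetical reconstruction based only on the reply's mention of an aggFieldAlias field; all names and the exact shape are assumptions, not the reviewer's actual code.)

  // hypothetical reconstruction: a variant of createHarnessTester that also
  // resolves the accumulator state field via an extra aggFieldAlias parameter
  // and returns it alongside the opened harness
  def createAggHarnessTester[KEY, IN, OUT](
      dataStream: DataStream[_],
      prefixOperatorName: String,
      aggFieldAlias: String)
    : (KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT], DataView) = {
    val harness = createHarnessTester[KEY, IN, OUT](dataStream, prefixOperatorName)
    harness.setStateBackend(getStateBackend)
    // the generated aggregations (and their state fields) exist only after
    // the operator is opened, so open before resolving the state field
    harness.open()
    val state = getState(
      getOperator(harness), "function",
      classOf[GroupAggProcessFunction], aggFieldAlias)
    (harness, state)
  }

  // usage in AggFunctionHarnessTest would then shrink to:
  // val (testHarness, state) = createAggHarnessTester[String, CRow, CRow](
  //   sqlQuery.toRetractStream[Row](queryConfig), "groupBy", "acc0_map_dataview")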
Author: createHarnessTester will be used not only in agg-related tests but also in other harness tests, such as stream join tests, temporal join tests, sort tests, etc., so the field aggFieldAlias seems a little weird from my point of view. What about adding it when we really need it? At that time we may have a better idea of what such a field should look like. Thoughts?

Reviewer: Hmm, if this tester will be used in join tests, we should add a TwoInputTransformation check in extractExpectedTransformation.

Author: Yes, you're right. Will do that when updating the join-related harness tests.
[End thread]
||
: KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = { | ||

    val transformation = extractExpectedTransformation(
      dataStream.javaStream.getTransformation,
      prefixOperatorName).asInstanceOf[OneInputTransformation[_, _]]
    if (transformation == null) {
      throw new Exception("Can not find the expected transformation")
    }

    val processOperator = transformation.getOperator.asInstanceOf[OneInputStreamOperator[IN, OUT]]
    val keySelector = transformation.getStateKeySelector.asInstanceOf[KeySelector[IN, KEY]]
    val keyType = transformation.getStateKeyType.asInstanceOf[TypeInformation[KEY]]

    createHarnessTester(processOperator, keySelector, keyType)
      .asInstanceOf[KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT]]
  }
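
  // depth-first search upstream through the transformation graph for the
  // first one-input transformation whose name starts with the given prefix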
  private def extractExpectedTransformation(
      transformation: StreamTransformation[_],
      prefixOperatorName: String): StreamTransformation[_] = {
    def extractFromInputs(inputs: StreamTransformation[_]*): StreamTransformation[_] = {
      for (input <- inputs) {
        val t = extractExpectedTransformation(input, prefixOperatorName)
        if (t != null) {
          return t
        }
      }
      null
    }

    transformation match {
      case one: OneInputTransformation[_, _] =>
        if (one.getName.startsWith(prefixOperatorName)) {
          one
        } else {
          extractExpectedTransformation(one.getInput, prefixOperatorName)
        }
      case union: UnionTransformation[_] => extractFromInputs(union.getInputs.toSeq: _*)
      case p: PartitionTransformation[_] => extractFromInputs(p.getInput)
      case _: SourceTransformation[_] => null
      case _ => throw new UnsupportedOperationException("This should not happen.")
    }
  }
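
  // reflectively reach through the operator's user function and the generated
  // aggregations instance to grab a DataView state field by name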
  def getState(
      operator: AbstractUdfStreamOperator[_, _],
      funcName: String,
      funcClass: Class[_],
      stateFieldName: String): DataView = {
    val function = funcClass.getDeclaredField(funcName)
    function.setAccessible(true)
    val generatedAggregation =
      function.get(operator.getUserFunction).asInstanceOf[GeneratedAggregations]
    val cls = generatedAggregation.getClass
    val stateField = cls.getDeclaredField(stateFieldName)
    stateField.setAccessible(true)
    stateField.get(generatedAggregation).asInstanceOf[DataView]
  }

  def createHarnessTester[IN, OUT, KEY](
      operator: OneInputStreamOperator[IN, OUT],
      keySelector: KeySelector[IN, KEY],
      keyType: TypeInformation[KEY]): KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
    new KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT](operator, keySelector, keyType)
  }

  def getOperator(testHarness: OneInputStreamOperatorTestHarness[_, _])
    : AbstractUdfStreamOperator[_, _] = {
    val operatorField = classOf[OneInputStreamOperatorTestHarness[_, _]]
      .getDeclaredField("oneInputOperator")
    operatorField.setAccessible(true)
    operatorField.get(testHarness).asInstanceOf[AbstractUdfStreamOperator[_, _]]
  }

  def verify(expected: JQueue[Object], actual: JQueue[Object]): Unit = {
    verify(expected, actual, new RowResultSortComparator)
  }

[Review thread]
Reviewer: Can we add something like the following? This will catch some of the weird serialization/deserialization problems as well.
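(The suggested snippet did not survive extraction. Below is a minimal sketch of the kind of check this could mean, assuming the idea is a snapshot/restore round trip; snapshot and initializeState are the test-harness API, while testHarness, sqlQuery, queryConfig and the helpers are the ones defined above. How it would actually be wired into these tests is an assumption.)

    // hedged sketch: snapshot mid-test, then restore a fresh harness from the
    // snapshot, forcing accumulator state through ser/de on the backend
    val snapshot = testHarness.snapshot(0L, 0L)
    testHarness.close()

    val restoredHarness = createHarnessTester[String, CRow, CRow](
      sqlQuery.toRetractStream[Row](queryConfig), "groupBy")
    restoredHarness.setStateBackend(getStateBackend)
    restoredHarness.initializeState(snapshot)
    restoredHarness.open()
    // keep feeding elements; output must match an uninterrupted run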
Author: Good idea. What about adding this kind of test to the operator tests, such as GroupAggregateHarnessTest, JoinHarnessTest, etc.? I think it is more useful for the operator tests.
[End thread]