[FLINK-11074] [table][tests] Enable harness tests with RocksdbStateBackend and add harness tests for CollectAggFunction
dianfu committed Dec 6, 2018
1 parent b3a378a commit 7c7eb80
Showing 3 changed files with 232 additions and 9 deletions.
CollectAggFunction.scala
@@ -112,10 +112,12 @@ class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
  def retract(acc: CollectAccumulator[E], value: E): Unit = {
    if (value != null) {
      val count = acc.map.get(value)
-     if (count == 1) {
-       acc.map.remove(value)
-     } else {
-       acc.map.put(value, count - 1)
+     if (count != null) {
+       if (count == 1) {
+         acc.map.remove(value)
+       } else {
+         acc.map.put(value, count - 1)
+       }
      }
    }
  }
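This guard is what the new harness test below exercises: with a state backend that expires idle entries, a retraction can arrive for a value whose map entry has already been cleaned up, so acc.map.get(value) returns null and the old else-branch would fail with a NullPointerException when count is unboxed in count - 1. A minimal self-contained sketch of the failure mode and the fix (a plain java.util.HashMap stands in for the accumulator's MapView; all names here are illustrative, not part of the commit):

import java.lang.{Integer => JInt}
import java.util.{HashMap => JHashMap}

object StaleRetractSketch extends App {
  // Stand-in for acc.map after the backend cleaned up the "aaa" entry.
  val map = new JHashMap[String, JInt]()

  val count: JInt = map.get("aaa") // entry already evicted -> null

  // Without the guard, `map.put("aaa", count - 1)` would unbox the null
  // count and throw a NullPointerException. With it, the stale retraction
  // is simply ignored.
  if (count != null) {
    if (count == 1) map.remove("aaa")
    else map.put("aaa", count - 1)
  }
}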
AggFunctionHarnessTest.scala
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.runtime.harness

import java.lang.{Integer => JInt, Long => JLong}
import java.util.concurrent.ConcurrentLinkedQueue

import org.apache.flink.api.common.time.Time
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.streaming.api.operators.{AbstractUdfStreamOperator, LegacyKeyedProcessOperator}
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.table.api.dataview.{DataView, MapView}
import org.apache.flink.table.runtime.aggregate.{GeneratedAggregations, GroupAggProcessFunction}
import org.apache.flink.table.runtime.harness.HarnessTestBase.{TestStreamQueryConfig, TupleRowKeySelector}
import org.apache.flink.table.runtime.types.CRow
import org.junit.Test

import scala.collection.JavaConverters._

class AggFunctionHarnessTest extends HarnessTestBase {
  protected var queryConfig =
    new TestStreamQueryConfig(Time.seconds(0), Time.seconds(0))

  @Test
  def testCollectAggregate(): Unit = {
    val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow](
      new GroupAggProcessFunction(
        genCollectAggFunction,
        collectAggregationStateType,
        false,
        queryConfig))

    val testHarness = createHarnessTester(
      processFunction,
      new TupleRowKeySelector[String](2),
      BasicTypeInfo.STRING_TYPE_INFO)
    testHarness.setStateBackend(getStateBackend)

    testHarness.open()

    val state = getState(processFunction, "mapView").asInstanceOf[MapView[JInt, JInt]]

    val expectedOutput = new ConcurrentLinkedQueue[Object]()

    testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 1).asJava), 1))
    testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb"), 1))
    expectedOutput.add(new StreamRecord(CRow("bbb", Map(1 -> 1).asJava), 1))

    testHarness.processElement(new StreamRecord(CRow(3L: JLong, 1: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2).asJava), 1))
    testHarness.processElement(new StreamRecord(CRow(4L: JLong, 2: JInt, "aaa"), 1))
    expectedOutput.add(new StreamRecord(CRow("aaa", Map(1 -> 2, 2 -> 1).asJava), 1))

    // Remove some state: state may be cleaned up by the state backend
    // if it has not been accessed for longer than the TTL.
    processFunction.setCurrentKey("aaa")
    state.remove(2)

    // Retract after the state has been cleaned up: this must neither throw
    // nor emit a new record, so no corresponding expected output is added.
    testHarness.processElement(new StreamRecord(CRow(false, 5L: JLong, 2: JInt, "aaa"), 1))

    val result = testHarness.getOutput

    verify(expectedOutput, result)

    testHarness.close()
  }

  // Reaches through the GroupAggProcessFunction into the generated
  // GeneratedAggregations instance and returns the named state field.
  private def getState(
      operator: AbstractUdfStreamOperator[_, _],
      stateFieldName: String): DataView = {
    val field = classOf[GroupAggProcessFunction].getDeclaredField("function")
    field.setAccessible(true)
    val generatedAggregation =
      field.get(operator.getUserFunction).asInstanceOf[GeneratedAggregations]
    generatedAggregation.getClass.getDeclaredField(stateFieldName)
      .get(generatedAggregation).asInstanceOf[DataView]
  }
}
HarnessTestBase.scala
@@ -25,20 +25,23 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRIN
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend
+import org.apache.flink.runtime.state.StateBackend
import org.apache.flink.streaming.api.operators.OneInputStreamOperator
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil}
import org.apache.flink.table.api.{StreamQueryConfig, Types}
import org.apache.flink.table.codegen.GeneratedAggregationsFunction
-import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction}
+import org.apache.flink.table.dataview.StateMapView
+import org.apache.flink.table.functions.aggfunctions._
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks}
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.table.utils.EncodingUtils
import org.junit.Rule
-import org.junit.rules.ExpectedException
+import org.junit.rules.{ExpectedException, TemporaryFolder}

class HarnessTestBase {
  // used for accurate exception information checking.
@@ -47,6 +50,19 @@ class HarnessTestBase {
  @Rule
  def thrown = expectedException

  val _tempFolder = new TemporaryFolder

  @Rule
  def tempFolder: TemporaryFolder = _tempFolder

  def getStateBackend: StateBackend = {
    val dbPath = tempFolder.newFolder().getAbsolutePath
    val checkpointPath = tempFolder.newFolder().toURI.toString
    val backend = new RocksDBStateBackend(checkpointPath)
    backend.setDbStoragePath(dbPath)
    backend
  }
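Any harness test can now run against RocksDB instead of the default heap backend simply by installing this backend before opening the harness, as AggFunctionHarnessTest above does. A minimal usage sketch (operator, keySelector and keyTypeInfo are hypothetical placeholders for test-specific values):

// Illustrative only: operator, keySelector and keyTypeInfo are placeholders.
val testHarness = createHarnessTester(operator, keySelector, keyTypeInfo)
testHarness.setStateBackend(getStateBackend) // keyed state now lives in RocksDB
testHarness.open()
// ... process elements and verify output ...
testHarness.close()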

  val longMinWithRetractAggFunction: String =
    EncodingUtils.encodeObjectToString(new LongMinWithRetractAggFunction)

@@ -56,6 +72,9 @@ class HarnessTestBase {
  val intSumWithRetractAggFunction: String =
    EncodingUtils.encodeObjectToString(new IntSumWithRetractAggFunction)

  val intCollectAggFunction: String =
    EncodingUtils.encodeObjectToString(new CollectAggFunction(Types.INT))

  val distinctCountAggFunction: String =
    EncodingUtils.encodeObjectToString(new CountAggFunction())

@@ -74,6 +93,9 @@ class HarnessTestBase {
  protected val sumAggregates: Array[AggregateFunction[_, _]] =
    Array(new IntSumWithRetractAggFunction).asInstanceOf[Array[AggregateFunction[_, _]]]

  protected val collectAggregates: Array[AggregateFunction[_, _]] =
    Array(new CollectAggFunction(Types.INT)).asInstanceOf[Array[AggregateFunction[_, _]]]

  protected val distinctCountAggregates: Array[AggregateFunction[_, _]] =
    Array(new CountAggFunction).asInstanceOf[Array[AggregateFunction[_, _]]]

@@ -83,15 +105,19 @@ class HarnessTestBase {
  protected val sumAggregationStateType: RowTypeInfo =
    new RowTypeInfo(sumAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)

  protected val collectAggregationStateType: RowTypeInfo =
    new RowTypeInfo(collectAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)

  protected val distinctCountAggregationStateType: RowTypeInfo =
    new RowTypeInfo(distinctCountAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*)

  protected val distinctCountDescriptor: String = EncodingUtils.encodeObjectToString(
-   new MapStateDescriptor("distinctAgg0", distinctCountAggregationStateType, Types.LONG))
+   new MapStateDescriptor("distinctAgg0", new RowTypeInfo(Types.INT), Types.LONG))

  protected val minMaxFuncName = "MinMaxAggregateHelper"
  protected val sumFuncName = "SumAggregationHelper"
  protected val distinctCountFuncName = "DistinctCountAggregationHelper"
  protected val collectFuncName = "CollectAggregationHelper"

  val minMaxCode: String =
    s"""
@@ -326,6 +352,105 @@ class HarnessTestBase {
       |}
       |""".stripMargin

  val collectAggCode: String =
    s"""
       |public final class $collectFuncName
       |    extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
       |
       |  transient org.apache.flink.table.functions.aggfunctions.CollectAggFunction collect = null;
       |  public transient org.apache.flink.table.dataview.StateMapView mapView = null;
       |  private java.lang.reflect.Field mapViewField = null;
       |
       |  public $collectFuncName() throws Exception {
       |    collect = (org.apache.flink.table.functions.aggfunctions.CollectAggFunction)
       |      ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject(
       |        "$intCollectAggFunction",
       |        ${classOf[UserDefinedFunction].getCanonicalName}.class);
       |
       |    // Grab the accumulator's map field so the state-backed view can be
       |    // swapped in reflectively before each operation.
       |    mapViewField = org.apache.flink.table.functions.aggfunctions.CollectAccumulator.class
       |      .getDeclaredField("map");
       |    mapViewField.setAccessible(true);
       |  }
       |
       |  public final void setAggregationResults(
       |      org.apache.flink.types.Row accs,
       |      org.apache.flink.types.Row output) throws Exception {
       |
       |    org.apache.flink.table.functions.AggregateFunction baseClass0 =
       |      (org.apache.flink.table.functions.AggregateFunction) collect;
       |
       |    org.apache.flink.table.functions.aggfunctions.CollectAccumulator acc =
       |      (org.apache.flink.table.functions.aggfunctions.CollectAccumulator) accs.getField(0);
       |    mapViewField.set(acc, mapView);
       |
       |    output.setField(1, baseClass0.getValue(acc));
       |  }
       |
       |  public final void accumulate(
       |      org.apache.flink.types.Row accs,
       |      org.apache.flink.types.Row input) throws Exception {
       |    org.apache.flink.table.functions.aggfunctions.CollectAccumulator acc =
       |      (org.apache.flink.table.functions.aggfunctions.CollectAccumulator) accs.getField(0);
       |    mapViewField.set(acc, mapView);
       |
       |    collect.accumulate(acc, (java.lang.Integer) input.getField(1));
       |  }
       |
       |  public final void retract(
       |      org.apache.flink.types.Row accs,
       |      org.apache.flink.types.Row input) throws Exception {
       |    org.apache.flink.table.functions.aggfunctions.CollectAccumulator acc =
       |      (org.apache.flink.table.functions.aggfunctions.CollectAccumulator) accs.getField(0);
       |    mapViewField.set(acc, mapView);
       |
       |    collect.retract(acc, (java.lang.Integer) input.getField(1));
       |  }
       |
       |  public final org.apache.flink.types.Row createAccumulators() throws Exception {
       |    org.apache.flink.types.Row accs = new org.apache.flink.types.Row(1);
       |    accs.setField(0, collect.createAccumulator());
       |    return accs;
       |  }
       |
       |  public final void setForwardedFields(
       |      org.apache.flink.types.Row input,
       |      org.apache.flink.types.Row output) {
       |    output.setField(0, input.getField(2));
       |  }
       |
       |  public final org.apache.flink.types.Row createOutputRow() {
       |    return new org.apache.flink.types.Row(2);
       |  }
       |
       |  public final org.apache.flink.types.Row mergeAccumulatorsPair(
       |      org.apache.flink.types.Row a, org.apache.flink.types.Row b) {
       |    return a;
       |  }
       |
       |  public final void resetAccumulator(
       |      org.apache.flink.types.Row accs) {
       |  }
       |
       |  public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) {
       |    // Back the map with keyed MapState so entries live in the state backend.
       |    mapView = new org.apache.flink.table.dataview.StateMapView(
       |      ctx.getMapState(
       |        new org.apache.flink.api.common.state.MapStateDescriptor(
       |          "collect",
       |          org.apache.flink.api.common.typeinfo.BasicTypeInfo.INT_TYPE_INFO,
       |          org.apache.flink.api.common.typeinfo.BasicTypeInfo.INT_TYPE_INFO
       |        )
       |      )
       |    );
       |  }
       |
       |  public void cleanup() {
       |  }
       |
       |  public void close() {
       |  }
       |}
       |""".stripMargin
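Note how this hand-written stand-in for the code-generated aggregations class wires its state: open() wraps a keyed MapState in a StateMapView, and accumulate, retract and setAggregationResults each inject that view into the accumulator's map field via reflection before delegating to CollectAggFunction. The collected map therefore lives in the configured state backend (RocksDB in these tests) rather than on the heap inside the accumulator, which is exactly the situation in which an entry can disappear between accumulate and retract.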

  val distinctCountAggCode: String =
    s"""
       |public final class $distinctCountFuncName
@@ -487,9 +612,10 @@ class HarnessTestBase {

  protected val genMinMaxAggFunction = GeneratedAggregationsFunction(minMaxFuncName, minMaxCode)
  protected val genSumAggFunction = GeneratedAggregationsFunction(sumFuncName, sumAggCode)
+ protected val genCollectAggFunction = GeneratedAggregationsFunction(
+   collectFuncName, collectAggCode)
  protected val genDistinctCountAggFunction = GeneratedAggregationsFunction(
-   distinctCountFuncName,
-   distinctCountAggCode)
+   distinctCountFuncName, distinctCountAggCode)

  def createHarnessTester[IN, OUT, KEY](
      operator: OneInputStreamOperator[IN, OUT],
