From d60ef3e67736c6cf366a486d0c57b13874c381bd Mon Sep 17 00:00:00 2001 From: shaoxuan-wang Date: Wed, 8 Mar 2017 23:10:28 +0800 Subject: [PATCH] [FLINK-5984] [table] add resetAccumulator method for AggregateFunction --- .../table/functions/AggregateFunction.scala | 7 +++++++ .../aggfunctions/AvgAggFunction.scala | 20 +++++++++++++++++++ .../aggfunctions/CountAggFunction.scala | 4 ++++ .../aggfunctions/MaxAggFunction.scala | 5 +++++ .../MaxAggFunctionWithRetract.scala | 5 +++++ .../aggfunctions/MinAggFunction.scala | 5 +++++ .../MinAggFunctionWithRetract.scala | 5 +++++ .../aggfunctions/SumAggFunction.scala | 14 +++++++++++-- .../SumWithRetractAggFunction.scala | 14 +++++++++++-- .../AggregateReduceCombineFunction.scala | 3 +-- .../AggregateReduceGroupFunction.scala | 3 +-- ...etSessionWindowAggregatePreProcessor.scala | 6 ++---- ...onWindowAggregateReduceGroupFunction.scala | 6 ++---- ...bleCountWindowAggReduceGroupFunction.scala | 3 +-- ...leTimeWindowAggReduceCombineFunction.scala | 3 +-- ...mbleTimeWindowAggReduceGroupFunction.scala | 3 +-- .../aggfunctions/AggFunctionTestBase.scala | 16 +++++++++++++++ 17 files changed, 100 insertions(+), 22 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala index 967d2ea1c2e2a..7d01cb76e638d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala @@ -81,6 +81,13 @@ abstract class AggregateFunction[T] extends UserDefinedFunction { */ def merge(accumulators: JList[Accumulator]): Accumulator + /** + * Reset the Accumulator for this [[AggregateFunction]] + * + * @param accumulator the accumulator which needs to be reset + */ + def resetAccumulator(accumulator: Accumulator): Unit + /** * Returns the [[TypeInformation]] of the accumulator. * This function is optional and can be implemented if the accumulator type cannot automatically diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala index dad4d7f12c13e..996bef0598a2b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala @@ -81,6 +81,11 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] { ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[IntegralAvgAccumulator].f0 = 0L + accumulator.asInstanceOf[IntegralAvgAccumulator].f1 = 0L + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new IntegralAvgAccumulator().getClass, @@ -176,6 +181,11 @@ abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] { ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[BigIntegralAvgAccumulator].f0 = BigInteger.ZERO + accumulator.asInstanceOf[BigIntegralAvgAccumulator].f1 = 0 + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new BigIntegralAvgAccumulator().getClass, @@ -257,6 +267,11 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] { ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[FloatingAvgAccumulator].f0 = 0 + accumulator.asInstanceOf[FloatingAvgAccumulator].f1 = 0L + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new FloatingAvgAccumulator().getClass, @@ -343,6 +358,11 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] { ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[DecimalAvgAccumulator].f0 = BigDecimal.ZERO + accumulator.asInstanceOf[DecimalAvgAccumulator].f1 = 0L + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new DecimalAvgAccumulator().getClass, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala index 8191a2f23b6a6..231337a117a51 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala @@ -64,6 +64,10 @@ class CountAggFunction extends AggregateFunction[Long] { new CountAccumulator } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[CountAccumulator].f0 = 0L + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala index 1a0a80b692b03..55e3e5f02a048 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala @@ -75,6 +75,11 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[MaxAccumulator[T]].f0 = getInitValue + accumulator.asInstanceOf[MaxAccumulator[T]].f1 = false + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MaxAccumulator[T].getClass, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala index 3d83121c9f3cd..eb6e7dca7e6b5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala @@ -132,6 +132,11 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T]) ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[MaxWithRetractAccumulator[T]].f0 = getInitValue + accumulator.asInstanceOf[MaxWithRetractAccumulator[T]].f1.clear() + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MaxWithRetractAccumulator[T].getClass, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala index 58a3c24496245..647388adcf098 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala @@ -75,6 +75,11 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFun ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[MinAccumulator[T]].f0 = getInitValue + accumulator.asInstanceOf[MinAccumulator[T]].f1 = false + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MinAccumulator[T].getClass, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala index a08dd256599b6..c9532863f2e6a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala @@ -132,6 +132,11 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T]) ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[MinWithRetractAccumulator[T]].f0 = getInitValue + accumulator.asInstanceOf[MinWithRetractAccumulator[T]].f1.clear() + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MinWithRetractAccumulator[T].getClass, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala index 6c4aba528588c..8ee986217e230 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala @@ -63,8 +63,8 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] { } override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = createAccumulator().asInstanceOf[SumAccumulator[T]] - var i: Int = 0 + val ret = accumulators.get(0).asInstanceOf[SumAccumulator[T]] + var i: Int = 1 while (i < accumulators.size()) { val a = accumulators.get(i).asInstanceOf[SumAccumulator[T]] if (a.f1) { @@ -76,6 +76,11 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] { ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[SumAccumulator[T]].f0 = numeric.zero + accumulator.asInstanceOf[SumAccumulator[T]].f1 = false + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new SumAccumulator).getClass, @@ -174,6 +179,11 @@ class DecimalSumAggFunction extends AggregateFunction[BigDecimal] { ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[DecimalSumAccumulator].f0 = BigDecimal.ZERO + accumulator.asInstanceOf[DecimalSumAccumulator].f1 = false + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new DecimalSumAccumulator).getClass, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala index ebcf184039307..928be1154ca9d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala @@ -72,8 +72,8 @@ abstract class SumWithRetractAggFunction[T: Numeric] extends AggregateFunction[T } override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = createAccumulator().asInstanceOf[SumWithRetractAccumulator[T]] - var i: Int = 0 + val ret = accumulators.get(0).asInstanceOf[SumWithRetractAccumulator[T]] + var i: Int = 1 while (i < accumulators.size()) { val a = accumulators.get(i).asInstanceOf[SumWithRetractAccumulator[T]] ret.f0 = numeric.plus(ret.f0, a.f0) @@ -83,6 +83,11 @@ abstract class SumWithRetractAggFunction[T: Numeric] extends AggregateFunction[T ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[SumWithRetractAccumulator[T]].f0 = numeric.zero + accumulator.asInstanceOf[SumWithRetractAccumulator[T]].f1 = 0L + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new SumWithRetractAccumulator).getClass, @@ -188,6 +193,11 @@ class DecimalSumWithRetractAggFunction extends AggregateFunction[BigDecimal] { ret } + override def resetAccumulator(accumulator: Accumulator): Unit = { + accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f0 = BigDecimal.ZERO + accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f1 = 0L + } + override def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new DecimalSumWithRetractAccumulator).getClass, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala index 6b95cb8e72718..376518a0736ab 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala @@ -74,8 +74,7 @@ class AggregateReduceCombineFunction( // reset first accumulator in merge list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } while (iterator.hasNext) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala index 2f75cd76c9ae0..cb69a36d14919 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala @@ -87,8 +87,7 @@ class AggregateReduceGroupFunction( // reset first accumulator in merge list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } while (iterator.hasNext) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala index a299c40e91da7..b006a267c5a82 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala @@ -111,8 +111,7 @@ class DataSetSessionWindowAggregatePreProcessor( // reset first accumulator in merge list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } val iterator = records.iterator() @@ -130,8 +129,7 @@ class DataSetSessionWindowAggregatePreProcessor( // reset first value of accumulator list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } } else { // set group keys to aggregateBuffer. diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala index ebef211efb362..10cbc92d3f3e9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregateReduceGroupFunction.scala @@ -106,8 +106,7 @@ class DataSetSessionWindowAggregateReduceGroupFunction( // reset first accumulator in merge list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } val iterator = records.iterator() @@ -126,8 +125,7 @@ class DataSetSessionWindowAggregateReduceGroupFunction( // reset first accumulator in list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } } else { // set group keys value to final output. diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala index 85df1d8865de6..96a73e406032a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala @@ -80,8 +80,7 @@ class DataSetTumbleCountWindowAggReduceGroupFunction( if (count == 0) { // reset first accumulator for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala index df8bed9ac91dd..29cd2283e584f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala @@ -71,8 +71,7 @@ class DataSetTumbleTimeWindowAggReduceCombineFunction( // reset first accumulator in merge list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } while (iterator.hasNext) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala index 7ce0bf149e703..8a63c40e17201 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala @@ -85,8 +85,7 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( // reset first accumulator in merge list for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).set(0, accumulator) + aggregates(i).resetAccumulator(accumulatorList(i).get(0)) } while (iterator.hasNext) { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala index 5c6f7c4caa08a..80fc9477ff951 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala @@ -98,6 +98,22 @@ abstract class AggFunctionTestBase[T] { } } + @Test + // test aggregate functions with resetAccumulator + def testResetAccumulator(): Unit = { + + if (ifMethodExistInFunction("resetAccumulator", aggregator)) { + // iterate over input sets + for ((vals, expected) <- inputValueSets.zip(expectedResults)) { + val accumulator = accumulateVals(vals) + aggregator.resetAccumulator(accumulator) + val expectedAccum = aggregator.createAccumulator() + //The accumulator after reset should be exactly same as the new accumulator + validateResult[Accumulator](expectedAccum, accumulator) + } + } + } + private def validateResult[T](expected: T, result: T): Unit = { (expected, result) match { case (e: DecimalSumWithRetractAccumulator, r: DecimalSumWithRetractAccumulator) =>