diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala index 5cbf48a9e9b816..fbf16427f6f2a5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala @@ -22,35 +22,41 @@ import java.lang.{Iterable => JIterable} import java.util import java.util.function.BiFunction -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1} -import org.apache.flink.api.java.typeutils.{GenericTypeInfo, TupleTypeInfo} +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils._ +import org.apache.flink.table.api.dataview.MapView +import org.apache.flink.table.dataview.MapViewTypeInfo import org.apache.flink.table.functions.AggregateFunction -import scala.collection.JavaConverters._ /** The initial accumulator for Collect aggregate function */ -class CollectAccumulator[E] extends JTuple1[util.Map[E, Integer]] +class CollectAccumulator[E](var f0:MapView[E, Integer]) { + def this() { + this(null) + } + + def canEqual(a: Any) = a.isInstanceOf[CollectAccumulator[E]] + + override def equals(that: Any): Boolean = + that match { + case that: CollectAccumulator[E] => that.canEqual(this) && this.f0 == that.f0 + case _ => false + } +} abstract class CollectAggFunction[E] extends AggregateFunction[util.Map[E, Integer], CollectAccumulator[E]] { - @transient - private lazy val addFunction = new BiFunction[Integer, Integer, Integer] { - override def apply(t: Integer, u: Integer): Integer = t + u - } - override def createAccumulator(): CollectAccumulator[E] = { - val acc = new CollectAccumulator[E]() - acc.f0 = new util.HashMap[E, Integer]() + val acc = new CollectAccumulator[E](new MapView[E, Integer]( + getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)) acc } def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = { if (value != null) { - if (accumulator.f0.containsKey(value)) { - val add = (x: Integer, y: Integer) => x + y - accumulator.f0.merge(value, 1, addFunction) + if (accumulator.f0.contains(value)) { + accumulator.f0.put(value, accumulator.f0.get(value) + 1) } else { accumulator.f0.put(value, 1) } @@ -58,8 +64,14 @@ abstract class CollectAggFunction[E] } override def getValue(accumulator: CollectAccumulator[E]): util.Map[E, Integer] = { - if (accumulator.f0.size() > 0) { - new util.HashMap(accumulator.f0) + val iterator = accumulator.f0.iterator + if (iterator.hasNext) { + val map = new util.HashMap[E, Integer]() + while (iterator.hasNext) { + val entry = iterator.next() + map.put(entry.getKey, entry.getValue) + } + map } else { null.asInstanceOf[util.Map[E, Integer]] } @@ -70,52 +82,77 @@ abstract class CollectAggFunction[E] } override def getAccumulatorType: TypeInformation[CollectAccumulator[E]] = { - new TupleTypeInfo( - classOf[CollectAccumulator[E]], - new GenericTypeInfo[util.Map[E, Integer]](classOf[util.Map[E, Integer]])) + val clazz = classOf[CollectAccumulator[E]] + val pojoFields = new util.ArrayList[PojoField] + pojoFields.add(new PojoField(clazz.getDeclaredField("f0"), + new MapViewTypeInfo[E, Integer]( + getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO))) + new PojoTypeInfo[CollectAccumulator[E]](clazz, pojoFields) } def merge(acc: CollectAccumulator[E], its: JIterable[CollectAccumulator[E]]): Unit = { val iter = its.iterator() while (iter.hasNext) { - for ((k: E, v: Integer) <- iter.next().f0.asScala) { - acc.f0.merge(k, v, addFunction) + val mapViewIterator = iter.next().f0.iterator + while (mapViewIterator.hasNext) { + val entry = mapViewIterator.next() + val k = entry.getKey + val oldValue = acc.f0.get(k) + if (oldValue == null) { + acc.f0.put(k, entry.getValue) + } else { + acc.f0.put(k, entry.getValue + oldValue) + } } } } def retract(acc: CollectAccumulator[E], value: E): Unit = { if (value != null) { - if (0 == acc.f0.merge(value, -1, addFunction)) { + val count = acc.f0.get(value) + if (count == 1) { acc.f0.remove(value) + } else { + acc.f0.put(value, count - 1) } } } + + def getValueTypeInfo: TypeInformation[_] } class IntCollectAggFunction extends CollectAggFunction[Int] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.INT_TYPE_INFO } class LongCollectAggFunction extends CollectAggFunction[Long] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.LONG_TYPE_INFO } class StringCollectAggFunction extends CollectAggFunction[String] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.STRING_TYPE_INFO } class ByteCollectAggFunction extends CollectAggFunction[Byte] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.BYTE_TYPE_INFO } class ShortCollectAggFunction extends CollectAggFunction[Short] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.SHORT_TYPE_INFO } class FloatCollectAggFunction extends CollectAggFunction[Float] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.FLOAT_TYPE_INFO } class DoubleCollectAggFunction extends CollectAggFunction[Double] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.DOUBLE_TYPE_INFO } class BooleanCollectAggFunction extends CollectAggFunction[Boolean] { + override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.BOOLEAN_TYPE_INFO } class ObjectCollectAggFunction extends CollectAggFunction[Object] { + override def getValueTypeInfo: TypeInformation[_] = new GenericTypeInfo[Object](classOf[Object]) } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala index 9e6a54a1bbb52f..157c7cb0e5ef69 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala @@ -179,7 +179,7 @@ class DoubleAggFunctionTest extends AggFunctionTestBase[util.Map[Double, Integer], CollectAccumulator[Double]] { override def inputValueSets: Seq[Seq[_]] = Seq( - Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d, null), + Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d), Seq(null, null, null, null, null, null) ) @@ -200,7 +200,7 @@ class DoubleAggFunctionTest override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } -class ObjectAggFunctionTest +class ObjectCollectAggFunctionTest extends AggFunctionTestBase[util.Map[Object, Integer], CollectAccumulator[Object]] { override def inputValueSets: Seq[Seq[_]] = Seq(