Skip to content

Commit

Permalink
use MapView for the accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Shuyi Chen committed Sep 21, 2017
1 parent 1e3dadc commit f07216a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 25 deletions.
Expand Up @@ -22,44 +22,56 @@ import java.lang.{Iterable => JIterable}
import java.util
import java.util.function.BiFunction

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
import org.apache.flink.api.java.typeutils.{GenericTypeInfo, TupleTypeInfo}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils._
import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.dataview.MapViewTypeInfo
import org.apache.flink.table.functions.AggregateFunction

import scala.collection.JavaConverters._

/** The initial accumulator for the Collect aggregate function. */
class CollectAccumulator[E] extends JTuple1[util.Map[E, Integer]]
/**
  * Accumulator of the Collect aggregate function: wraps a [[MapView]] that maps each
  * collected element to an `Integer` count.
  *
  * @param f0 element-to-count view; `null` when built through the no-arg constructor
  * @tparam E type of the collected elements
  */
class CollectAccumulator[E](var f0: MapView[E, Integer]) {

  /** No-arg constructor; leaves `f0` as `null`. */
  def this() {
    this(null)
  }

  /** True if `a` is a [[CollectAccumulator]]; the element type is erased at runtime. */
  def canEqual(a: Any): Boolean = a.isInstanceOf[CollectAccumulator[_]]

  /** Two accumulators are equal when their `f0` views are equal (both `null` counts as equal). */
  override def equals(that: Any): Boolean =
    that match {
      // Match on the wildcard type: `E` cannot be checked at runtime anyway.
      case other: CollectAccumulator[_] => other.canEqual(this) && this.f0 == other.f0
      case _ => false
    }

  /** Derived from `f0` only so that it stays consistent with `equals`; 0 when `f0` is `null`. */
  override def hashCode(): Int = util.Objects.hashCode(f0)
}

abstract class CollectAggFunction[E]
extends AggregateFunction[util.Map[E, Integer], CollectAccumulator[E]] {

// Sums two boxed integers; created lazily on first use and marked @transient.
@transient
private lazy val addFunction = new BiFunction[Integer, Integer, Integer] {
  override def apply(left: Integer, right: Integer): Integer = {
    val sum = left.intValue() + right.intValue()
    Integer.valueOf(sum)
  }
}

// Builds and returns a new accumulator.
// NOTE(review): this span is rendered from a commit diff and appears to contain both the
// removed lines (HashMap assigned to f0) and the added lines (MapView passed to the
// constructor); only one `val acc` definition can remain — confirm against the committed file.
override def createAccumulator(): CollectAccumulator[E] = {
val acc = new CollectAccumulator[E]()
acc.f0 = new util.HashMap[E, Integer]()
// MapView keyed by the element type reported by getValueTypeInfo, with Int-typed values.
val acc = new CollectAccumulator[E](new MapView[E, Integer](
getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO))
acc
}

// Records one occurrence of `value` in the accumulator; null values are ignored.
// NOTE(review): this span is rendered from a commit diff; the `containsKey`/`merge` lines and
// the `contains`/`put` lines appear to be the removed and added versions of the same logic,
// so the braces do not balance as shown — confirm against the committed file.
def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = {
if (value != null) {
if (accumulator.f0.containsKey(value)) {
// NOTE(review): `add` is never referenced in the lines visible here.
val add = (x: Integer, y: Integer) => x + y
accumulator.f0.merge(value, 1, addFunction)
// Already counted: increment the stored count; otherwise start the count at 1.
if (accumulator.f0.contains(value)) {
accumulator.f0.put(value, accumulator.f0.get(value) + 1)
} else {
accumulator.f0.put(value, 1)
}
}
}

override def getValue(accumulator: CollectAccumulator[E]): util.Map[E, Integer] = {
if (accumulator.f0.size() > 0) {
new util.HashMap(accumulator.f0)
val iterator = accumulator.f0.iterator
if (iterator.hasNext) {
val map = new util.HashMap[E, Integer]()
while (iterator.hasNext) {
val entry = iterator.next()
map.put(entry.getKey, entry.getValue)
}
map
} else {
null.asInstanceOf[util.Map[E, Integer]]
}
Expand All @@ -70,52 +82,77 @@ abstract class CollectAggFunction[E]
}

// Returns the type information describing CollectAccumulator[E].
// NOTE(review): this span is rendered from a commit diff; the TupleTypeInfo expression and the
// PojoTypeInfo construction appear to be the removed and added versions — confirm against
// the committed file.
override def getAccumulatorType: TypeInformation[CollectAccumulator[E]] = {
new TupleTypeInfo(
classOf[CollectAccumulator[E]],
new GenericTypeInfo[util.Map[E, Integer]](classOf[util.Map[E, Integer]]))
val clazz = classOf[CollectAccumulator[E]]
val pojoFields = new util.ArrayList[PojoField]
// Single field `f0`, looked up reflectively and typed as a MapView from the element type
// (taken from getValueTypeInfo) to Int.
pojoFields.add(new PojoField(clazz.getDeclaredField("f0"),
new MapViewTypeInfo[E, Integer](
getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)))
new PojoTypeInfo[CollectAccumulator[E]](clazz, pojoFields)
}

// Folds every accumulator in `its` into `acc` by summing the counts stored per key.
// NOTE(review): this span is rendered from a commit diff; the `for (...) <- ...asScala` loop
// calling `acc.f0.merge` and the iterator-based `while` loop appear to be the removed and
// added versions of the same step — confirm against the committed file.
def merge(acc: CollectAccumulator[E], its: JIterable[CollectAccumulator[E]]): Unit = {
val iter = its.iterator()
while (iter.hasNext) {
for ((k: E, v: Integer) <- iter.next().f0.asScala) {
acc.f0.merge(k, v, addFunction)
val mapViewIterator = iter.next().f0.iterator
while (mapViewIterator.hasNext) {
val entry = mapViewIterator.next()
val k = entry.getKey
val oldValue = acc.f0.get(k)
// A null lookup is treated as "no count for this key in `acc` yet".
if (oldValue == null) {
acc.f0.put(k, entry.getValue)
} else {
acc.f0.put(k, entry.getValue + oldValue)
}
}
}
}

// Removes one occurrence of `value` from the accumulator; null values are ignored.
// NOTE(review): this span is rendered from a commit diff; the `merge(value, -1, ...)` condition
// and the `get`/`put` lines appear to be the removed and added versions of the same logic,
// so the braces do not balance as shown — confirm against the committed file.
def retract(acc: CollectAccumulator[E], value: E): Unit = {
if (value != null) {
if (0 == acc.f0.merge(value, -1, addFunction)) {
// NOTE(review): `merge` in this class treats a null result of `f0.get` as "key absent"; if
// `value` was never accumulated, `count` would presumably be null here and `count - 1`
// would fail on unboxing — confirm whether callers guarantee a prior accumulate.
val count = acc.f0.get(value)
// Last occurrence: drop the key entirely; otherwise decrement the stored count.
if (count == 1) {
acc.f0.remove(value)
} else {
acc.f0.put(value, count - 1)
}
}
}

/**
  * Type information of the collected elements, supplied by each concrete subclass.
  * Within this class the result is cast to `TypeInformation[E]` and used as the key type
  * of the accumulator's map view.
  */
def getValueTypeInfo: TypeInformation[_]
}

/** [[CollectAggFunction]] over `Int` elements; reports `BasicTypeInfo.INT_TYPE_INFO`. */
class IntCollectAggFunction extends CollectAggFunction[Int] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.INT_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over `Long` elements; reports `BasicTypeInfo.LONG_TYPE_INFO`. */
class LongCollectAggFunction extends CollectAggFunction[Long] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.LONG_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over `String` elements; reports `BasicTypeInfo.STRING_TYPE_INFO`. */
class StringCollectAggFunction extends CollectAggFunction[String] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.STRING_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over `Byte` elements; reports `BasicTypeInfo.BYTE_TYPE_INFO`. */
class ByteCollectAggFunction extends CollectAggFunction[Byte] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.BYTE_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over `Short` elements; reports `BasicTypeInfo.SHORT_TYPE_INFO`. */
class ShortCollectAggFunction extends CollectAggFunction[Short] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.SHORT_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over `Float` elements; reports `BasicTypeInfo.FLOAT_TYPE_INFO`. */
class FloatCollectAggFunction extends CollectAggFunction[Float] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.FLOAT_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over `Double` elements; reports `BasicTypeInfo.DOUBLE_TYPE_INFO`. */
class DoubleCollectAggFunction extends CollectAggFunction[Double] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.DOUBLE_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over `Boolean` elements; reports `BasicTypeInfo.BOOLEAN_TYPE_INFO`. */
class BooleanCollectAggFunction extends CollectAggFunction[Boolean] {
  override def getValueTypeInfo: TypeInformation[_] = {
    BasicTypeInfo.BOOLEAN_TYPE_INFO
  }
}

/** [[CollectAggFunction]] over arbitrary `Object` elements; reports a `GenericTypeInfo[Object]`. */
class ObjectCollectAggFunction extends CollectAggFunction[Object] {
  override def getValueTypeInfo: TypeInformation[_] = {
    val elementClass = classOf[Object]
    new GenericTypeInfo[Object](elementClass)
  }
}
Expand Up @@ -179,7 +179,7 @@ class DoubleAggFunctionTest
extends AggFunctionTestBase[util.Map[Double, Integer], CollectAccumulator[Double]] {

// Input value sets for this test: rows mixing doubles with nulls, plus an all-null row.
// NOTE(review): this span is rendered from a commit diff; the first two rows look like the
// removed and added versions of the same entry (the second drops the trailing null) —
// confirm which one the committed file keeps.
override def inputValueSets: Seq[Seq[_]] = Seq(
Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d, null),
Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d),
Seq(null, null, null, null, null, null)
)

Expand All @@ -200,7 +200,7 @@ class DoubleAggFunctionTest
// Reflectively resolves the aggregator's `retract` method taking (accType, Any).
override def retractFunc = {
  val aggregatorClass = aggregator.getClass
  aggregatorClass.getMethod("retract", accType, classOf[Any])
}
}

class ObjectAggFunctionTest
class ObjectCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[Object, Integer], CollectAccumulator[Object]] {

override def inputValueSets: Seq[Seq[_]] = Seq(
Expand Down

0 comments on commit f07216a

Please sign in to comment.