Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Shuyi Chen committed Oct 4, 2017
1 parent 03a609a commit 48c7b4c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 68 deletions.
Expand Up @@ -44,12 +44,12 @@ class CollectAccumulator[E](var f0:MapView[E, Integer]) {
}
}

abstract class CollectAggFunction[E]
class CollectAggFunction[E](valueTypeInfo: TypeInformation[_])
extends AggregateFunction[util.Map[E, Integer], CollectAccumulator[E]] {

override def createAccumulator(): CollectAccumulator[E] = {
val acc = new CollectAccumulator[E](new MapView[E, Integer](
getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO))
valueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO))
acc
}

Expand Down Expand Up @@ -87,7 +87,7 @@ abstract class CollectAggFunction[E]
val pojoFields = new util.ArrayList[PojoField]
pojoFields.add(new PojoField(clazz.getDeclaredField("f0"),
new MapViewTypeInfo[E, Integer](
getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)))
valueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)))
new PojoTypeInfo[CollectAccumulator[E]](clazz, pojoFields)
}

Expand Down Expand Up @@ -118,42 +118,4 @@ abstract class CollectAggFunction[E]
}
}
}

def getValueTypeInfo: TypeInformation[_]
}

class IntCollectAggFunction extends CollectAggFunction[Int] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.INT_TYPE_INFO
}

class LongCollectAggFunction extends CollectAggFunction[Long] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.LONG_TYPE_INFO
}

class StringCollectAggFunction extends CollectAggFunction[String] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.STRING_TYPE_INFO
}

class ByteCollectAggFunction extends CollectAggFunction[Byte] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.BYTE_TYPE_INFO
}

class ShortCollectAggFunction extends CollectAggFunction[Short] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.SHORT_TYPE_INFO
}

class FloatCollectAggFunction extends CollectAggFunction[Float] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.FLOAT_TYPE_INFO
}

class DoubleCollectAggFunction extends CollectAggFunction[Double] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.DOUBLE_TYPE_INFO
}

class BooleanCollectAggFunction extends CollectAggFunction[Boolean] {
override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.BOOLEAN_TYPE_INFO
}

class ObjectCollectAggFunction extends CollectAggFunction[Object] {
override def getValueTypeInfo: TypeInformation[_] = new GenericTypeInfo[Object](classOf[Object])
}
}
Expand Up @@ -28,7 +28,7 @@ import org.apache.calcite.sql.{SqlAggFunction, SqlKind}
import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo}
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
Expand Down Expand Up @@ -1200,8 +1200,8 @@ object AggregateUtil {
} else {
aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray
}
val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
.getSqlTypeName
val relDataType = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType
val sqlTypeName = relDataType.getSqlTypeName
aggregateCall.getAggregation match {

case _: SqlSumAggFunction =>
Expand Down Expand Up @@ -1412,23 +1412,12 @@ object AggregateUtil {

case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
aggregates(index) = sqlTypeName match {
case TINYINT =>
new ByteCollectAggFunction
case SMALLINT =>
new ShortCollectAggFunction
case INTEGER =>
new IntCollectAggFunction
case BIGINT =>
new LongCollectAggFunction
case VARCHAR | CHAR =>
new StringCollectAggFunction
case FLOAT =>
new FloatCollectAggFunction
case DOUBLE =>
new DoubleCollectAggFunction
case TINYINT | SMALLINT | INTEGER | BIGINT | VARCHAR | CHAR | FLOAT | DOUBLE =>
new CollectAggFunction(FlinkTypeFactory.toTypeInfo(relDataType))
case _ =>
new ObjectCollectAggFunction
new CollectAggFunction(new GenericTypeInfo[Object](classOf[Object]))
}
accTypes(index) = aggregates(index).getAccumulatorType

case udagg: AggSqlFunction =>
aggregates(index) = udagg.getFunction
Expand Down
Expand Up @@ -20,6 +20,8 @@ package org.apache.flink.table.runtime.aggfunctions

import java.util

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.typeutils.GenericTypeInfo
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.aggfunctions._

Expand Down Expand Up @@ -49,7 +51,7 @@ class StringCollectAggFunctionTest

override def aggregator: AggregateFunction[
util.Map[String, Integer], CollectAccumulator[String]] =
new StringCollectAggFunction()
new CollectAggFunction(BasicTypeInfo.STRING_TYPE_INFO)

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand All @@ -73,7 +75,7 @@ class IntCollectAggFunctionTest
}

override def aggregator: AggregateFunction[util.Map[Int, Integer], CollectAccumulator[Int]] =
new IntCollectAggFunction()
new CollectAggFunction(BasicTypeInfo.INT_TYPE_INFO)

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand All @@ -97,7 +99,7 @@ class ByteCollectAggFunctionTest
}

override def aggregator: AggregateFunction[util.Map[Byte, Integer], CollectAccumulator[Byte]] =
new ByteCollectAggFunction()
new CollectAggFunction(BasicTypeInfo.BYTE_TYPE_INFO)

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand All @@ -122,7 +124,7 @@ class ShortCollectAggFunctionTest
}

override def aggregator: AggregateFunction[util.Map[Short, Integer], CollectAccumulator[Short]] =
new ShortCollectAggFunction()
new CollectAggFunction(BasicTypeInfo.SHORT_TYPE_INFO)

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand All @@ -146,7 +148,7 @@ class LongCollectAggFunctionTest
}

override def aggregator: AggregateFunction[util.Map[Long, Integer], CollectAccumulator[Long]] =
new LongCollectAggFunction()
new CollectAggFunction(BasicTypeInfo.LONG_TYPE_INFO)

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand All @@ -170,7 +172,7 @@ class FloatAggFunctionTest
}

override def aggregator: AggregateFunction[util.Map[Float, Integer], CollectAccumulator[Float]] =
new FloatCollectAggFunction()
new CollectAggFunction(BasicTypeInfo.FLOAT_TYPE_INFO)

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand All @@ -195,7 +197,7 @@ class DoubleAggFunctionTest

override def aggregator: AggregateFunction[
util.Map[Double, Integer], CollectAccumulator[Double]] =
new DoubleCollectAggFunction()
new CollectAggFunction(BasicTypeInfo.DOUBLE_TYPE_INFO)

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand All @@ -217,7 +219,7 @@ class ObjectCollectAggFunctionTest

override def aggregator: AggregateFunction[
util.Map[Object, Integer], CollectAccumulator[Object]] =
new ObjectCollectAggFunction()
new CollectAggFunction(new GenericTypeInfo[Object](classOf[Object]))

override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
}
Expand Down
Expand Up @@ -97,6 +97,7 @@ class SqlITCase extends StreamingWithStateTestBase {

val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
env.setStateBackend(getStateBackend)
StreamITCase.clear

val sqlQuery = "SELECT b, COLLECT(a) FROM MyTable GROUP BY b"
Expand All @@ -123,6 +124,7 @@ class SqlITCase extends StreamingWithStateTestBase {

val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
env.setStateBackend(getStateBackend)
StreamITCase.clear

val sqlQuery = "SELECT b, COLLECT(c) FROM MyTable GROUP BY b"
Expand Down

0 comments on commit 48c7b4c

Please sign in to comment.