From 48c7b4c1dc2f67bb62f1828a351de740d4dd5d95 Mon Sep 17 00:00:00 2001 From: Shuyi Chen Date: Wed, 4 Oct 2017 15:28:42 -0700 Subject: [PATCH] fix comments --- .../aggfunctions/CollectAggFunction.scala | 46 ++----------------- .../runtime/aggregate/AggregateUtil.scala | 25 +++------- .../aggfunctions/CollectAggFunctionTest.scala | 18 ++++---- .../table/runtime/stream/sql/SqlITCase.scala | 2 + 4 files changed, 23 insertions(+), 68 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala index 75ae58c3a665d8..364145ead85a8f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala @@ -44,12 +44,12 @@ class CollectAccumulator[E](var f0:MapView[E, Integer]) { } } -abstract class CollectAggFunction[E] +class CollectAggFunction[E](valueTypeInfo: TypeInformation[_]) extends AggregateFunction[util.Map[E, Integer], CollectAccumulator[E]] { override def createAccumulator(): CollectAccumulator[E] = { val acc = new CollectAccumulator[E](new MapView[E, Integer]( - getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)) + valueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO)) acc } @@ -87,7 +87,7 @@ abstract class CollectAggFunction[E] val pojoFields = new util.ArrayList[PojoField] pojoFields.add(new PojoField(clazz.getDeclaredField("f0"), new MapViewTypeInfo[E, Integer]( - getValueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO))) + valueTypeInfo.asInstanceOf[TypeInformation[E]], BasicTypeInfo.INT_TYPE_INFO))) new PojoTypeInfo[CollectAccumulator[E]](clazz, pojoFields) } @@ -118,42 +118,4 @@ abstract class CollectAggFunction[E] } } } - - def getValueTypeInfo: TypeInformation[_] -} - -class IntCollectAggFunction extends CollectAggFunction[Int] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.INT_TYPE_INFO -} - -class LongCollectAggFunction extends CollectAggFunction[Long] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.LONG_TYPE_INFO -} - -class StringCollectAggFunction extends CollectAggFunction[String] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.STRING_TYPE_INFO -} - -class ByteCollectAggFunction extends CollectAggFunction[Byte] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.BYTE_TYPE_INFO -} - -class ShortCollectAggFunction extends CollectAggFunction[Short] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.SHORT_TYPE_INFO -} - -class FloatCollectAggFunction extends CollectAggFunction[Float] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.FLOAT_TYPE_INFO -} - -class DoubleCollectAggFunction extends CollectAggFunction[Double] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.DOUBLE_TYPE_INFO -} - -class BooleanCollectAggFunction extends CollectAggFunction[Boolean] { - override def getValueTypeInfo: TypeInformation[_] = BasicTypeInfo.BOOLEAN_TYPE_INFO -} - -class ObjectCollectAggFunction extends CollectAggFunction[Object] { - override def getValueTypeInfo: TypeInformation[_] = new GenericTypeInfo[Object](classOf[Object]) -} +} \ No newline at end of file diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index f38985e7ab82aa..40808e39d8f74e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -28,7 +28,7 @@ import org.apache.calcite.sql.{SqlAggFunction, SqlKind} import org.apache.flink.api.common.functions.{MapFunction, RichGroupReduceFunction, AggregateFunction => DataStreamAggFunction, _} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.Tuple -import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo} import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction} import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} @@ -1200,8 +1200,8 @@ object AggregateUtil { } else { aggFieldIndexes(index) = argList.asScala.map(i => i.intValue).toArray } - val sqlTypeName = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType - .getSqlTypeName + val relDataType = inputType.getFieldList.get(aggFieldIndexes(index)(0)).getType + val sqlTypeName = relDataType.getSqlTypeName aggregateCall.getAggregation match { case _: SqlSumAggFunction => @@ -1412,23 +1412,12 @@ object AggregateUtil { case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT => aggregates(index) = sqlTypeName match { - case TINYINT => - new ByteCollectAggFunction - case SMALLINT => - new ShortCollectAggFunction - case INTEGER => - new IntCollectAggFunction - case BIGINT => - new LongCollectAggFunction - case VARCHAR | CHAR => - new StringCollectAggFunction - case FLOAT => - new FloatCollectAggFunction - case DOUBLE => - new DoubleCollectAggFunction + case TINYINT | SMALLINT | INTEGER | BIGINT | VARCHAR | CHAR | FLOAT | DOUBLE => + new CollectAggFunction(FlinkTypeFactory.toTypeInfo(relDataType)) case _ => - new ObjectCollectAggFunction + new CollectAggFunction(new GenericTypeInfo[Object](classOf[Object])) } + accTypes(index) = aggregates(index).getAccumulatorType case udagg: AggSqlFunction => aggregates(index) = udagg.getFunction diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala index ea3f763e6e63a7..f85cb70a56257a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala @@ -20,6 +20,8 @@ package org.apache.flink.table.runtime.aggfunctions import java.util +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.java.typeutils.GenericTypeInfo import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.functions.aggfunctions._ @@ -49,7 +51,7 @@ class StringCollectAggFunctionTest override def aggregator: AggregateFunction[ util.Map[String, Integer], CollectAccumulator[String]] = - new StringCollectAggFunction() + new CollectAggFunction(BasicTypeInfo.STRING_TYPE_INFO) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } @@ -73,7 +75,7 @@ class IntCollectAggFunctionTest } override def aggregator: AggregateFunction[util.Map[Int, Integer], CollectAccumulator[Int]] = - new IntCollectAggFunction() + new CollectAggFunction(BasicTypeInfo.INT_TYPE_INFO) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } @@ -97,7 +99,7 @@ class ByteCollectAggFunctionTest } override def aggregator: AggregateFunction[util.Map[Byte, Integer], CollectAccumulator[Byte]] = - new ByteCollectAggFunction() + new CollectAggFunction(BasicTypeInfo.BYTE_TYPE_INFO) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } @@ -122,7 +124,7 @@ class ShortCollectAggFunctionTest } override def aggregator: AggregateFunction[util.Map[Short, Integer], CollectAccumulator[Short]] = - new ShortCollectAggFunction() + new CollectAggFunction(BasicTypeInfo.SHORT_TYPE_INFO) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } @@ -146,7 +148,7 @@ class LongCollectAggFunctionTest } override def aggregator: AggregateFunction[util.Map[Long, Integer], CollectAccumulator[Long]] = - new LongCollectAggFunction() + new CollectAggFunction(BasicTypeInfo.LONG_TYPE_INFO) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } @@ -170,7 +172,7 @@ class FloatAggFunctionTest } override def aggregator: AggregateFunction[util.Map[Float, Integer], CollectAccumulator[Float]] = - new FloatCollectAggFunction() + new CollectAggFunction(BasicTypeInfo.FLOAT_TYPE_INFO) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } @@ -195,7 +197,7 @@ class DoubleAggFunctionTest override def aggregator: AggregateFunction[ util.Map[Double, Integer], CollectAccumulator[Double]] = - new DoubleCollectAggFunction() + new CollectAggFunction(BasicTypeInfo.DOUBLE_TYPE_INFO) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } @@ -217,7 +219,7 @@ class ObjectCollectAggFunctionTest override def aggregator: AggregateFunction[ util.Map[Object, Integer], CollectAccumulator[Object]] = - new ObjectCollectAggFunction() + new CollectAggFunction(new GenericTypeInfo[Object](classOf[Object])) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala index 4128ee8198e385..32e37243a9fca8 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala @@ -97,6 +97,7 @@ class SqlITCase extends StreamingWithStateTestBase { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStateBackend(getStateBackend) StreamITCase.clear val sqlQuery = "SELECT b, COLLECT(a) FROM MyTable GROUP BY b" @@ -123,6 +124,7 @@ class SqlITCase extends StreamingWithStateTestBase { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStateBackend(getStateBackend) StreamITCase.clear val sqlQuery = "SELECT b, COLLECT(c) FROM MyTable GROUP BY b"