From d597ab9c4bb7b6e8e2bab3f8b7be311541cef628 Mon Sep 17 00:00:00 2001 From: Shuyi Chen Date: Wed, 27 Sep 2017 17:01:40 -0700 Subject: [PATCH] fix comments --- docs/dev/table/sql.md | 4 +- .../flink/api/java/typeutils/MapTypeInfo.java | 2 +- .../api/java/typeutils/MultisetTypeInfo.java | 42 ----------------- .../org/apache/flink/table/api/Types.scala | 11 ++++- .../aggfunctions/CollectAggFunction.scala | 11 ++--- .../runtime/aggregate/AggregateUtil.scala | 45 +++++++++---------- .../aggfunctions/CollectAggFunctionTest.scala | 22 ++++----- 7 files changed, 52 insertions(+), 85 deletions(-) diff --git a/docs/dev/table/sql.md b/docs/dev/table/sql.md index 1fd16f817dc0f2..c3a16a64a15493 100644 --- a/docs/dev/table/sql.md +++ b/docs/dev/table/sql.md @@ -803,7 +803,7 @@ The SQL runtime is built on top of Flink's DataSet and DataStream APIs. Internal | `Types.PRIMITIVE_ARRAY`| `ARRAY` | e.g. `int[]` | | `Types.OBJECT_ARRAY` | `ARRAY` | e.g. `java.lang.Byte[]`| | `Types.MAP` | `MAP` | `java.util.HashMap` | -| `Types.MULTISET` | `MULTISET` | `java.util.HashMap` | +| `Types.MULTISET` | `MULTISET` | e.g. `java.util.HashMap` for a multiset of `String` | Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and array types (object or primitive arrays) can be fields of a row. @@ -2173,7 +2173,7 @@ VAR_SAMP(value) {% endhighlight %} -

Returns a multiset of the values.

+

Returns a multiset of the values. Null input values are ignored. Returns an empty multiset if only null values are added.

diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java index ca04e0cbc72d75..e9cd09dc217bb8 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java @@ -93,7 +93,7 @@ public int getArity() { @Override public int getTotalFields() { - return 2; + return 1; } @SuppressWarnings("unchecked") diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java index 5aadc006bca5f8..27fe70903edfba 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java @@ -22,8 +22,6 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; -import java.util.Map; - import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -36,7 +34,6 @@ public final class MultisetTypeInfo extends MapTypeInfo { private static final long serialVersionUID = 1L; - public MultisetTypeInfo(Class elementTypeClass) { super(elementTypeClass, Integer.class); } @@ -56,45 +53,6 @@ public TypeInformation getElementTypeInfo() { return getKeyTypeInfo(); } - // ------------------------------------------------------------------------ - // TypeInformation implementation - // ------------------------------------------------------------------------ - - @Override - public boolean isBasicType() { - return false; - } - - @Override - public boolean isTupleType() { - return false; - } - - @Override - public int getArity() { - return 0; - } - - @Override - public int getTotalFields() { - // similar as arrays, the multiset are "opaque" to the direct field addressing logic - // since the 
multiset's elements are not addressable, we do not expose them - return 1; - } - - @SuppressWarnings("unchecked") - @Override - public Class> getTypeClass() { - return (Class>)(Class)Map.class; - } - - @Override - public boolean isKeyType() { - return false; - } - - // ------------------------------------------------------------------------ - @Override public String toString() { return "Multiset<" + getKeyTypeInfo() + '>'; diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala index 2152b727fff5fa..100c22b368e3d2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.api import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, TypeInformation, Types => JTypes} -import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo} +import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo} import org.apache.flink.table.typeutils.TimeIntervalTypeInfo import org.apache.flink.types.Row @@ -110,4 +110,13 @@ object Types { def MAP(keyType: TypeInformation[_], valueType: TypeInformation[_]): TypeInformation[_] = { new MapTypeInfo(keyType, valueType) } + + /** + * Generates type information for a Multiset. + * + * @param elementType type of the elements of the multiset e.g. 
Types.STRING + */ + def MULTISET(elementType: TypeInformation[_]): TypeInformation[_] = { + new MultisetTypeInfo(elementType) + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala index fbf16427f6f2a5..75ae58c3a665d8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala @@ -20,7 +20,6 @@ package org.apache.flink.table.functions.aggfunctions import java.lang.{Iterable => JIterable} import java.util -import java.util.function.BiFunction import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.typeutils._ @@ -28,11 +27,12 @@ import org.apache.flink.table.api.dataview.MapView import org.apache.flink.table.dataview.MapViewTypeInfo import org.apache.flink.table.functions.AggregateFunction +import scala.collection.JavaConverters._ /** The initial accumulator for Collect aggregate function */ class CollectAccumulator[E](var f0:MapView[E, Integer]) { def this() { - this(null) + this(null) } def canEqual(a: Any) = a.isInstanceOf[CollectAccumulator[E]] @@ -55,8 +55,9 @@ abstract class CollectAggFunction[E] def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = { if (value != null) { - if (accumulator.f0.contains(value)) { - accumulator.f0.put(value, accumulator.f0.get(value) + 1) + val currVal = accumulator.f0.get(value) + if (currVal != null) { + accumulator.f0.put(value, currVal + 1) } else { accumulator.f0.put(value, 1) } @@ -73,7 +74,7 @@ abstract class CollectAggFunction[E] } map } else { - null.asInstanceOf[util.Map[E, Integer]] + Map[E, Integer]().asJava } } diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 27ee27e18ba9ac..f38985e7ab82aa 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -1410,33 +1410,32 @@ object AggregateUtil { case _: SqlCountAggFunction => aggregates(index) = new CountAggFunction + case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT => + aggregates(index) = sqlTypeName match { + case TINYINT => + new ByteCollectAggFunction + case SMALLINT => + new ShortCollectAggFunction + case INTEGER => + new IntCollectAggFunction + case BIGINT => + new LongCollectAggFunction + case VARCHAR | CHAR => + new StringCollectAggFunction + case FLOAT => + new FloatCollectAggFunction + case DOUBLE => + new DoubleCollectAggFunction + case _ => + new ObjectCollectAggFunction + } + case udagg: AggSqlFunction => aggregates(index) = udagg.getFunction accTypes(index) = udagg.accType - case other: SqlAggFunction => - if (other.getKind == SqlKind.COLLECT) { - aggregates(index) = sqlTypeName match { - case TINYINT => - new ByteCollectAggFunction - case SMALLINT => - new ShortCollectAggFunction - case INTEGER => - new IntCollectAggFunction - case BIGINT => - new LongCollectAggFunction - case VARCHAR | CHAR => - new StringCollectAggFunction - case FLOAT => - new FloatCollectAggFunction - case DOUBLE => - new DoubleCollectAggFunction - case _ => - new ObjectCollectAggFunction - } - } else { - throw new TableException(s"unsupported Function: '${other.getName}'") - } + case unSupported: SqlAggFunction => + throw new TableException(s"unsupported Function: '${unSupported.getName}'") } } diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala index 157c7cb0e5ef69..ea3f763e6e63a7 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala @@ -20,11 +20,11 @@ package org.apache.flink.table.runtime.aggfunctions import java.util -import com.google.common.collect.ImmutableMap -import org.apache.curator import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.functions.aggfunctions._ +import scala.collection.JavaConverters._ + /** * Test case for built-in collect aggregate functions */ @@ -32,7 +32,7 @@ class StringCollectAggFunctionTest extends AggFunctionTestBase[util.Map[String, Integer], CollectAccumulator[String]] { override def inputValueSets: Seq[Seq[_]] = Seq( - Seq("a", "a", "b", null, "c", null, "d", "e", null, "f", null), + Seq("a", "a", "b", null, "c", null, "d", "e", null, "f"), Seq(null, null, null, null, null, null) ) @@ -44,7 +44,7 @@ class StringCollectAggFunctionTest map.put("d", 1) map.put("e", 1) map.put("f", 1) - Seq(map, null) + Seq(map, Map[String, Integer]().asJava) } override def aggregator: AggregateFunction[ @@ -69,7 +69,7 @@ class IntCollectAggFunctionTest map.put(3, 1) map.put(4, 1) map.put(5, 1) - Seq(map, null) + Seq(map, Map[Int, Integer]().asJava) } override def aggregator: AggregateFunction[util.Map[Int, Integer], CollectAccumulator[Int]] = @@ -93,7 +93,7 @@ class ByteCollectAggFunctionTest map.put(3, 1) map.put(4, 1) map.put(5, 1) - Seq(map, null) + Seq(map, Map[Byte, Integer]().asJava) } override def aggregator: AggregateFunction[util.Map[Byte, Integer], CollectAccumulator[Byte]] = @@ -118,7 +118,7 @@ class 
ShortCollectAggFunctionTest map.put(3, 1) map.put(4, 1) map.put(5, 1) - Seq(map, null) + Seq(map, Map[Short, Integer]().asJava) } override def aggregator: AggregateFunction[util.Map[Short, Integer], CollectAccumulator[Short]] = @@ -142,7 +142,7 @@ class LongCollectAggFunctionTest map.put(3, 1) map.put(4, 1) map.put(5, 1) - Seq(map, null) + Seq(map, Map[Long, Integer]().asJava) } override def aggregator: AggregateFunction[util.Map[Long, Integer], CollectAccumulator[Long]] = @@ -166,7 +166,7 @@ class FloatAggFunctionTest map.put(3.2f, 1) map.put(4, 1) map.put(5, 1) - Seq(map, null) + Seq(map, Map[Float, Integer]().asJava) } override def aggregator: AggregateFunction[util.Map[Float, Integer], CollectAccumulator[Float]] = @@ -190,7 +190,7 @@ class DoubleAggFunctionTest map.put(3.2d, 1) map.put(4, 1) map.put(5, 1) - Seq(map, null) + Seq(map, Map[Double, Integer]().asJava) } override def aggregator: AggregateFunction[ @@ -212,7 +212,7 @@ class ObjectCollectAggFunctionTest val map = new util.HashMap[Object, Integer]() map.put(Tuple2(1, "a"), 2) map.put(Tuple2(2, "b"), 1) - Seq(map, null) + Seq(map, Map[Object, Integer]().asJava) } override def aggregator: AggregateFunction[