
Commit

fix comments
Shuyi Chen committed Sep 29, 2017
1 parent f07216a commit 1741f10
Showing 7 changed files with 52 additions and 85 deletions.
4 changes: 2 additions & 2 deletions docs/dev/table/sql.md
@@ -746,7 +746,7 @@ The SQL runtime is built on top of Flink's DataSet and DataStream APIs. Internal
| `Types.PRIMITIVE_ARRAY`| `ARRAY` | e.g. `int[]` |
| `Types.OBJECT_ARRAY` | `ARRAY` | e.g. `java.lang.Byte[]`|
| `Types.MAP` | `MAP` | `java.util.HashMap` |
| `Types.MULTISET` | `MULTISET` | `java.util.HashMap` |
| `Types.MULTISET` | `MULTISET` | e.g. `java.util.HashMap<String, Integer>` for a multiset of `String` |


Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and array types (object or primitive arrays) can be fields of a row.
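A minimal sketch of how the new mapping surfaces in code, using only the `MultisetTypeInfo` API visible later in this commit: a multiset of `String` is backed by a `java.util.Map` from element to occurrence count.

```scala
import org.apache.flink.api.java.typeutils.MultisetTypeInfo

// A multiset of String is represented as java.util.Map[String, Integer];
// the Integer value holds the number of occurrences of each element.
val multisetType = new MultisetTypeInfo[String](classOf[String])

println(multisetType)                    // Multiset<String>
println(multisetType.getElementTypeInfo) // String
```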
@@ -2116,7 +2116,7 @@ VAR_SAMP(value)
{% endhighlight %}
</td>
<td>
<p>Returns a multiset of the <i>value</i>s.</p>
<p>Returns a multiset of the <i>value</i>s. Null input <i>value</i>s are ignored; if only null values are added, an empty multiset is returned.</p>
</td>
</tr>
</tbody>
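A sketch of the documented behavior, driven through the `StringCollectAggFunction` touched later in this commit (assuming its accumulator can be exercised directly, as the test base below does): null inputs are skipped, and an all-null input yields an empty multiset rather than null.

```scala
val fn = new StringCollectAggFunction
val acc = fn.createAccumulator()

fn.accumulate(acc, "a")
fn.accumulate(acc, "a")
fn.accumulate(acc, null) // ignored, per the docs above

fn.getValue(acc)                    // {a=2}: the null contributed nothing
fn.getValue(fn.createAccumulator()) // empty map, not null
```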
@@ -93,7 +93,7 @@ public int getArity() {

@Override
public int getTotalFields() {
return 2;
return 1;
}

@SuppressWarnings("unchecked")
@@ -22,8 +22,6 @@
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;

import java.util.Map;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
@@ -36,7 +34,6 @@ public final class MultisetTypeInfo<T> extends MapTypeInfo<T, Integer> {

private static final long serialVersionUID = 1L;


public MultisetTypeInfo(Class<T> elementTypeClass) {
super(elementTypeClass, Integer.class);
}
@@ -56,45 +53,6 @@ public TypeInformation<T> getElementTypeInfo() {
return getKeyTypeInfo();
}

// ------------------------------------------------------------------------
// TypeInformation implementation
// ------------------------------------------------------------------------

@Override
public boolean isBasicType() {
return false;
}

@Override
public boolean isTupleType() {
return false;
}

@Override
public int getArity() {
return 0;
}

@Override
public int getTotalFields() {
// similar to arrays, multisets are "opaque" to the direct field addressing logic
// since the multiset's elements are not addressable, we do not expose them
return 1;
}

@SuppressWarnings("unchecked")
@Override
public Class<Map<T, Integer>> getTypeClass() {
return (Class<Map<T, Integer>>)(Class<?>)Map.class;
}

@Override
public boolean isKeyType() {
return false;
}

// ------------------------------------------------------------------------

@Override
public String toString() {
return "Multiset<" + getKeyTypeInfo() + '>';
@@ -18,7 +18,7 @@
package org.apache.flink.table.api

import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, TypeInformation, Types => JTypes}
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo}
import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo}
import org.apache.flink.table.typeutils.TimeIntervalTypeInfo
import org.apache.flink.types.Row

@@ -110,4 +110,13 @@ object Types {
def MAP(keyType: TypeInformation[_], valueType: TypeInformation[_]): TypeInformation[_] = {
new MapTypeInfo(keyType, valueType)
}

/**
* Generates type information for a Multiset.
*
* @param elementType type of the elements of the multiset e.g. Types.STRING
*/
def MULTISET(elementType: TypeInformation[_]): TypeInformation[_] = {
new MultisetTypeInfo(elementType)
}
}
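Usage is symmetric with the existing `MAP` factory above; a one-line sketch:

```scala
import org.apache.flink.table.api.Types

// Only the element type is supplied; the per-element count is always Integer.
val stringMultiset = Types.MULTISET(Types.STRING)
```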
@@ -20,19 +20,19 @@ package org.apache.flink.table.functions.aggfunctions

import java.lang.{Iterable => JIterable}
import java.util
import java.util.function.BiFunction

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils._
import org.apache.flink.table.api.dataview.MapView
import org.apache.flink.table.dataview.MapViewTypeInfo
import org.apache.flink.table.functions.AggregateFunction

import scala.collection.JavaConverters._

/** The initial accumulator for Collect aggregate function */
class CollectAccumulator[E](var f0:MapView[E, Integer]) {
def this() {
this(null)
this(null)
}

def canEqual(a: Any) = a.isInstanceOf[CollectAccumulator[E]]
@@ -55,8 +55,9 @@ abstract class CollectAggFunction[E]

def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = {
if (value != null) {
if (accumulator.f0.contains(value)) {
accumulator.f0.put(value, accumulator.f0.get(value) + 1)
val currVal = accumulator.f0.get(value)
if (currVal != null) {
accumulator.f0.put(value, currVal + 1)
} else {
accumulator.f0.put(value, 1)
}
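The rewrite trades the `contains`-then-`get` pair for a single `get` plus null check, halving the state accesses per non-null input. The same pattern with a plain `java.util.HashMap` (a sketch; `MapView.get` likewise returns null for absent keys):

```scala
import java.util

def add[E](counts: util.Map[E, Integer], value: E): Unit = {
  val currVal = counts.get(value) // one lookup instead of contains + get
  if (currVal != null) counts.put(value, currVal + 1)
  else counts.put(value, 1)
}

val m = new util.HashMap[String, Integer]()
add(m, "a"); add(m, "a"); add(m, "b")
println(m) // e.g. {a=2, b=1}
```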
@@ -73,7 +74,7 @@
}
map
} else {
null.asInstanceOf[util.Map[E, Integer]]
Map[E, Integer]().asJava
}
}
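Returning an empty map instead of null keeps `getValue` total: callers can iterate or merge the result without a null guard, matching the documented COLLECT semantics. The conversion needs only the `JavaConverters` import added at the top of this file:

```scala
import java.util
import scala.collection.JavaConverters._

// An empty immutable Scala map converted to java.util.Map is the
// neutral multiset: safe to iterate, merge, or compare by content.
val empty: util.Map[String, Integer] = Map[String, Integer]().asJava
assert(empty.isEmpty)
```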

@@ -1410,33 +1410,32 @@ object AggregateUtil {
case _: SqlCountAggFunction =>
aggregates(index) = new CountAggFunction

case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT =>
aggregates(index) = sqlTypeName match {
case TINYINT =>
new ByteCollectAggFunction
case SMALLINT =>
new ShortCollectAggFunction
case INTEGER =>
new IntCollectAggFunction
case BIGINT =>
new LongCollectAggFunction
case VARCHAR | CHAR =>
new StringCollectAggFunction
case FLOAT =>
new FloatCollectAggFunction
case DOUBLE =>
new DoubleCollectAggFunction
case _ =>
new ObjectCollectAggFunction
}

case udagg: AggSqlFunction =>
aggregates(index) = udagg.getFunction
accTypes(index) = udagg.accType

case other: SqlAggFunction =>
if (other.getKind == SqlKind.COLLECT) {
aggregates(index) = sqlTypeName match {
case TINYINT =>
new ByteCollectAggFunction
case SMALLINT =>
new ShortCollectAggFunction
case INTEGER =>
new IntCollectAggFunction
case BIGINT =>
new LongCollectAggFunction
case VARCHAR | CHAR =>
new StringCollectAggFunction
case FLOAT =>
new FloatCollectAggFunction
case DOUBLE =>
new DoubleCollectAggFunction
case _ =>
new ObjectCollectAggFunction
}
} else {
throw new TableException(s"unsupported Function: '${other.getName}'")
}
case unSupported: SqlAggFunction =>
throw new TableException(s"unsupported Function: '${unSupported.getName}'")
}
}
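Moving COLLECT into its own guarded case relies on Scala's match semantics: when a guard evaluates to false, matching falls through to the next case, so the old if/else inside a catch-all is unnecessary. A self-contained sketch with a hypothetical `Agg` trait standing in for `SqlAggFunction`:

```scala
// Hypothetical stand-in for SqlAggFunction, for illustration only.
trait Agg { def kind: String; def name: String }

def pick(f: Agg): String = f match {
  case collect: Agg if collect.kind == "COLLECT" => "CollectAggFunction"
  // Reached whenever the guard above fails.
  case unSupported: Agg =>
    throw new IllegalArgumentException(s"unsupported Function: '${unSupported.name}'")
}
```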

@@ -20,19 +20,19 @@ package org.apache.flink.table.runtime.aggfunctions

import java.util

import com.google.common.collect.ImmutableMap
import org.apache.curator
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.aggfunctions._

import scala.collection.JavaConverters._

/**
* Test case for built-in collect aggregate functions
*/
class StringCollectAggFunctionTest
extends AggFunctionTestBase[util.Map[String, Integer], CollectAccumulator[String]] {

override def inputValueSets: Seq[Seq[_]] = Seq(
Seq("a", "a", "b", null, "c", null, "d", "e", null, "f", null),
Seq("a", "a", "b", null, "c", null, "d", "e", null, "f"),
Seq(null, null, null, null, null, null)
)

@@ -44,7 +44,7 @@ class StringCollectAggFunctionTest
map.put("d", 1)
map.put("e", 1)
map.put("f", 1)
Seq(map, null)
Seq(map, Map[String, Integer]().asJava)
}
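All the test expectations below switch from null to an empty `java.util.Map` built via `JavaConverters`. The assertions rely on content-based equality, which the `java.util.Map` contract guarantees across implementations; a sketch:

```scala
import java.util
import scala.collection.JavaConverters._

val expected: util.Map[String, Integer] = Map[String, Integer]().asJava
val actual: util.Map[String, Integer] = new util.HashMap[String, Integer]()

// java.util.Map.equals compares entry sets, so the wrapped empty Scala
// map and an empty HashMap returned by getValue compare equal.
assert(expected == actual)
```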

override def aggregator: AggregateFunction[
@@ -69,7 +69,7 @@ class IntCollectAggFunctionTest
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, null)
Seq(map, Map[Int, Integer]().asJava)
}

override def aggregator: AggregateFunction[util.Map[Int, Integer], CollectAccumulator[Int]] =
@@ -93,7 +93,7 @@ class ByteCollectAggFunctionTest
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, null)
Seq(map, Map[Byte, Integer]().asJava)
}

override def aggregator: AggregateFunction[util.Map[Byte, Integer], CollectAccumulator[Byte]] =
@@ -118,7 +118,7 @@ class ShortCollectAggFunctionTest
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, null)
Seq(map, Map[Short, Integer]().asJava)
}

override def aggregator: AggregateFunction[util.Map[Short, Integer], CollectAccumulator[Short]] =
@@ -142,7 +142,7 @@ class LongCollectAggFunctionTest
map.put(3, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, null)
Seq(map, Map[Long, Integer]().asJava)
}

override def aggregator: AggregateFunction[util.Map[Long, Integer], CollectAccumulator[Long]] =
@@ -166,7 +166,7 @@ class FloatAggFunctionTest
map.put(3.2f, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, null)
Seq(map, Map[Float, Integer]().asJava)
}

override def aggregator: AggregateFunction[util.Map[Float, Integer], CollectAccumulator[Float]] =
@@ -190,7 +190,7 @@ class DoubleAggFunctionTest
map.put(3.2d, 1)
map.put(4, 1)
map.put(5, 1)
Seq(map, null)
Seq(map, Map[Double, Integer]().asJava)
}

override def aggregator: AggregateFunction[
@@ -212,7 +212,7 @@ class ObjectCollectAggFunctionTest
val map = new util.HashMap[Object, Integer]()
map.put(Tuple2(1, "a"), 2)
map.put(Tuple2(2, "b"), 1)
Seq(map, null)
Seq(map, Map[Object, Integer]().asJava)
}

override def aggregator: AggregateFunction[