From 7bcb0db2e0c4a8286def1d34e29e60a5539a5205 Mon Sep 17 00:00:00 2001 From: Shuyi Chen Date: Wed, 23 Aug 2017 17:54:10 -0700 Subject: [PATCH] add MultiSetTypeInfo; add Collect SQL feature --- flink-core/pom.xml | 7 + .../typeutils/base/MultisetSerializer.java | 195 ++++++++++++++++++ .../api/java/typeutils/MultisetTypeInfo.java | 142 +++++++++++++ .../base/MultisetSerializerTest.java | 71 +++++++ .../java/typeutils/MultisetTypeInfoTest.java | 38 ++++ .../table/calcite/FlinkTypeFactory.scala | 24 ++- .../aggfunctions/CollectAggFunction.scala | 101 +++++++++ .../plan/schema/MultisetRelDataType.scala | 50 +++++ .../runtime/aggregate/AggregateUtil.scala | 26 ++- .../table/validate/FunctionCatalog.scala | 1 + .../aggfunctions/CollectAggFunctionTest.scala | 169 +++++++++++++++ .../runtime/batch/sql/AggregateITCase.scala | 29 +++ .../table/runtime/stream/sql/SqlITCase.scala | 26 +++ 13 files changed, 876 insertions(+), 3 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MultisetSerializer.java create mode 100644 flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java create mode 100644 flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MultisetSerializerTest.java create mode 100644 flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala diff --git a/flink-core/pom.xml b/flink-core/pom.xml index 7039e48bf8bc0..1a68cb4f7580b 100644 --- a/flink-core/pom.xml +++ b/flink-core/pom.xml @@ -80,6 +80,13 @@ under the License. 
+ + + org.apache.commons + commons-collections4 + 4.1 + + org.apache.avro diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MultisetSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MultisetSerializer.java new file mode 100644 index 0000000000000..cd10b17b09830 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MultisetSerializer.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.typeutils.base; + +import org.apache.commons.collections4.multiset.AbstractMultiSet; +import org.apache.commons.collections4.multiset.HashMultiSet; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.*; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import java.io.IOException; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A serializer for Multisets. The serializer relies on an element serializer + * for the serialization of the Multiset's elements. + * + *

The serialization format for the Multiset is as follows: four bytes for the size of the Multiset, + * followed by the serialized representation of each element. + * + * @param <T> The type of element in the Multiset. + */ +@Internal +public final class MultisetSerializer extends TypeSerializer> { + + private static final long serialVersionUID = 12L; + + /** The serializer for the elements of the Multiset */ + private final TypeSerializer elementSerializer; + + /** + * Creates a Multiset serializer that uses the given serializer to serialize the Multiset's elements. + * + * @param elementSerializer The serializer for the elements of the Multiset + */ + public MultisetSerializer(TypeSerializer elementSerializer) { + this.elementSerializer = checkNotNull(elementSerializer); + } + + // ------------------------------------------------------------------------ + // MultisetSerializer specific properties + // ------------------------------------------------------------------------ + + /** + * Gets the serializer for the elements of the Multiset. + * @return The serializer for the elements of the Multiset + */ + public TypeSerializer getElementSerializer() { + return elementSerializer; + } + + // ------------------------------------------------------------------------ + // Type Serializer implementation + // ------------------------------------------------------------------------ + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer> duplicate() { + TypeSerializer duplicateElement = elementSerializer.duplicate(); + return duplicateElement == elementSerializer ? 
+ this : new MultisetSerializer<>(duplicateElement); + } + + @Override + public AbstractMultiSet createInstance() { + return new HashMultiSet<>(); + } + + @Override + public AbstractMultiSet copy(AbstractMultiSet from) { + return new HashMultiSet<>(from); + } + + @Override + public AbstractMultiSet copy(AbstractMultiSet from, AbstractMultiSet reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; // var length + } + + @Override + public void serialize(AbstractMultiSet multiSet, DataOutputView target) throws IOException { + final int size = multiSet.size(); + target.writeInt(size); + + for (T element : multiSet) { + elementSerializer.serialize(element, target); + } + } + + @Override + public AbstractMultiSet deserialize(DataInputView source) throws IOException { + final int size = source.readInt(); + final AbstractMultiSet multiSet = new HashMultiSet<>(); + for (int i = 0; i < size; i++) { + multiSet.add(elementSerializer.deserialize(source)); + } + return multiSet; + } + + @Override + public AbstractMultiSet deserialize(AbstractMultiSet reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + // copy number of elements + final int num = source.readInt(); + target.writeInt(num); + for (int i = 0; i < num; i++) { + elementSerializer.copy(source, target); + } + } + + // -------------------------------------------------------------------- + + @Override + public boolean equals(Object obj) { + return obj == this || + (obj != null && obj.getClass() == getClass() && + elementSerializer.equals(((MultisetSerializer) obj).elementSerializer)); + } + + @Override + public boolean canEqual(Object obj) { + return true; + } + + @Override + public int hashCode() { + return elementSerializer.hashCode(); + } + + // -------------------------------------------------------------------------------------------- + // Serializer 
configuration snapshotting & compatibility + // -------------------------------------------------------------------------------------------- + + @Override + public CollectionSerializerConfigSnapshot snapshotConfiguration() { + return new CollectionSerializerConfigSnapshot<>(elementSerializer); + } + + @Override + public CompatibilityResult> ensureCompatibility( + TypeSerializerConfigSnapshot configSnapshot) { + if (configSnapshot instanceof CollectionSerializerConfigSnapshot) { + Tuple2, TypeSerializerConfigSnapshot> previousElemSerializerAndConfig = + ((CollectionSerializerConfigSnapshot) configSnapshot).getSingleNestedSerializerAndConfig(); + + CompatibilityResult compatResult = CompatibilityUtil.resolveCompatibilityResult( + previousElemSerializerAndConfig.f0, + UnloadableDummyTypeSerializer.class, + previousElemSerializerAndConfig.f1, + elementSerializer); + + if (!compatResult.isRequiresMigration()) { + return CompatibilityResult.compatible(); + } else if (compatResult.getConvertDeserializer() != null) { + return CompatibilityResult.requiresMigration( + new MultisetSerializer<>( + new TypeDeserializerAdapter<>(compatResult.getConvertDeserializer()))); + } + } + + return CompatibilityResult.requiresMigration(); + } +} diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java new file mode 100644 index 0000000000000..f861372b520cd --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MultisetTypeInfo.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.java.typeutils; + +import org.apache.commons.collections4.multiset.AbstractMultiSet; +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.MultisetSerializer; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A {@link TypeInformation} for the Multiset types of the Java API. + * + * @param <T> The type of the elements in the Multiset. 
+ */ +@PublicEvolving +public final class MultisetTypeInfo extends TypeInformation> { + + private static final long serialVersionUID = 1L; + + private final TypeInformation elementTypeInfo; + + + public MultisetTypeInfo(Class elementTypeClass) { + this.elementTypeInfo = of(checkNotNull(elementTypeClass, "elementTypeClass")); + } + + public MultisetTypeInfo(TypeInformation elementTypeInfo) { + this.elementTypeInfo = checkNotNull(elementTypeInfo, "elementTypeInfo"); + } + + // ------------------------------------------------------------------------ + // MultisetTypeInfo specific properties + // ------------------------------------------------------------------------ + + /** + * Gets the type information for the elements contained in the Multiset + */ + public TypeInformation getElementTypeInfo() { + return elementTypeInfo; + } + + // ------------------------------------------------------------------------ + // TypeInformation implementation + // ------------------------------------------------------------------------ + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 0; + } + + @Override + public int getTotalFields() { + // similar as arrays, the multiset are "opaque" to the direct field addressing logic + // since the multiset's elements are not addressable, we do not expose them + return 1; + } + + @SuppressWarnings("unchecked") + @Override + public Class> getTypeClass() { + return (Class>)(Class)AbstractMultiSet.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer> createSerializer(ExecutionConfig config) { + TypeSerializer elementTypeSerializer = elementTypeInfo.createSerializer(config); + return new MultisetSerializer<>(elementTypeSerializer); + } + + // ------------------------------------------------------------------------ + + @Override + public String toString() { + 
return "Multiset<" + elementTypeInfo + '>'; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + else if (obj instanceof MultisetTypeInfo) { + final MultisetTypeInfo other = (MultisetTypeInfo) obj; + return other.canEqual(this) && elementTypeInfo.equals(other.elementTypeInfo); + } else { + return false; + } + } + + @Override + public int hashCode() { + return 31 * elementTypeInfo.hashCode() + 1; + } + + @Override + public boolean canEqual(Object obj) { + return obj != null && obj.getClass() == getClass(); + } + + @SuppressWarnings("unchecked") + @PublicEvolving + public static MultisetTypeInfo getInfoFor(TypeInformation componentInfo) { + checkNotNull(componentInfo); + + return new MultisetTypeInfo<>(componentInfo); + } +} diff --git a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MultisetSerializerTest.java b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MultisetSerializerTest.java new file mode 100644 index 0000000000000..6f2646a44573f --- /dev/null +++ b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MultisetSerializerTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.api.common.typeutils.base; + +import org.apache.commons.collections4.multiset.AbstractMultiSet; +import org.apache.commons.collections4.multiset.HashMultiSet; +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import java.util.Random; + +/** + * Test for {@link MultisetSerializer}. + */ +public class MultisetSerializerTest extends SerializerTestBase> { + + @Override + protected TypeSerializer> createSerializer() { + return new MultisetSerializer(LongSerializer.INSTANCE); + } + + @Override + protected int getLength() { + return -1; + } + + @SuppressWarnings("unchecked") + @Override + protected Class> getTypeClass() { + return (Class>) (Class) AbstractMultiSet.class; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + @Override + protected AbstractMultiSet[] getTestData() { + final Random rnd = new Random(123654789); + + // empty Multisets + final AbstractMultiSet set1 = new HashMultiSet<>(); + + // single element Multisets + final AbstractMultiSet set2 = new HashMultiSet<>(); + set2.add(12345L); + + // larger Multisets + final AbstractMultiSet set3 = new HashMultiSet<>(); + for (int i = 0; i < rnd.nextInt(200); i++) { + set3.add(rnd.nextLong()); + } + + return (AbstractMultiSet[]) new AbstractMultiSet[]{ + set1, set2, set3 + }; + } +} diff --git a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java new file mode 100644 index 0000000000000..395f4cef47a49 --- /dev/null +++ b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/MultisetTypeInfoTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.java.typeutils; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeutils.TypeInformationTestBase; + +/** + * Test for {@link MultisetTypeInfo}. + */ +public class MultisetTypeInfoTest extends TypeInformationTestBase> { + + @Override + protected MultisetTypeInfo[] getTestData() { + return new MultisetTypeInfo[] { + new MultisetTypeInfo<>(BasicTypeInfo.STRING_TYPE_INFO), + new MultisetTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO), + new MultisetTypeInfo<>(Long.class) + }; + } +} + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala index dbefe203e9601..ae46b984b2f25 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala @@ -29,7 +29,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo._ import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.ValueTypeInfo._ -import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo} +import 
org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo, RowTypeInfo} import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkTypeFactory.typeInfoToSqlTypeName import org.apache.flink.table.plan.schema._ @@ -154,6 +154,13 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp createTypeFromTypeInfo(mp.getValueTypeInfo, isNullable = true), isNullable) + case mts: MultisetTypeInfo[_] => + new MultisetRelDataType( + mts, + createTypeFromTypeInfo(mts.getElementTypeInfo, isNullable = true), + isNullable + ) + case ti: TypeInformation[_] => new GenericRelDataType( ti, @@ -236,6 +243,14 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp canonize(relType) } + override def createMultisetType(elementType: RelDataType, maxCardinality: Long): RelDataType = { + val relType = new MultisetRelDataType( + MultisetTypeInfo.getInfoFor(FlinkTypeFactory.toTypeInfo(elementType)), + elementType, + isNullable = false) + canonize(relType) + } + override def createTypeWithNullability( relDataType: RelDataType, isNullable: Boolean): RelDataType = { @@ -257,6 +272,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp case map: MapRelDataType => new MapRelDataType(map.typeInfo, map.keyType, map.valueType, isNullable) + case multiSet: MultisetRelDataType => + new MultisetRelDataType(multiSet.typeInfo, multiSet.getComponentType, isNullable) + case generic: GenericRelDataType => new GenericRelDataType(generic.typeInfo, isNullable, typeSystem) @@ -396,6 +414,10 @@ object FlinkTypeFactory { val mapRelDataType = relDataType.asInstanceOf[MapRelDataType] mapRelDataType.typeInfo + case MULTISET if relDataType.isInstanceOf[MultisetRelDataType] => + val multisetRelDataType = relDataType.asInstanceOf[MultisetRelDataType] + multisetRelDataType.typeInfo + case _@t => throw TableException(s"Type is not supported: $t") } diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala new file mode 100644 index 0000000000000..8efe4a94890df --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CollectAggFunction.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions.aggfunctions + +import java.lang.{Iterable => JIterable} + +import org.apache.commons.collections4.multiset.{AbstractMultiSet, HashMultiSet} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1} +import org.apache.flink.api.java.typeutils.{GenericTypeInfo, TupleTypeInfo} +import org.apache.flink.table.functions.AggregateFunction + +/** The initial accumulator for Collect aggregate function */ +class CollectAccumulator[E] extends JTuple1[AbstractMultiSet[E]] + +abstract class CollectAggFunction[E] + extends AggregateFunction[AbstractMultiSet[E], CollectAccumulator[E]] { + + override def createAccumulator(): CollectAccumulator[E] = { + val acc = new CollectAccumulator[E]() + acc.f0 = new HashMultiSet() + acc + } + + def accumulate(accumulator: CollectAccumulator[E], value: E): Unit = { + if (value != null) { + accumulator.f0.add(value) + } + } + + override def getValue(accumulator: CollectAccumulator[E]): AbstractMultiSet[E] = { + if (accumulator.f0.size() > 0) { + new HashMultiSet(accumulator.f0) + } else { + null.asInstanceOf[AbstractMultiSet[E]] + } + } + + def resetAccumulator(acc: CollectAccumulator[E]): Unit = { + acc.f0.clear() + } + + override def getAccumulatorType: TypeInformation[CollectAccumulator[E]] = { + new TupleTypeInfo( + classOf[CollectAccumulator[E]], + new GenericTypeInfo[AbstractMultiSet[E]](classOf[AbstractMultiSet[E]])) + } + + def merge(acc: CollectAccumulator[E], its: JIterable[CollectAccumulator[E]]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + acc.f0.addAll(iter.next().f0) + } + } + + def retract(acc: CollectAccumulator[E], value: Any): Unit = { + if (value != null) { + acc.f0.remove(value) + } + } +} + +class IntCollectAggFunction extends CollectAggFunction[Int] { +} + +class LongCollectAggFunction extends CollectAggFunction[Long] { +} + +class StringCollectAggFunction extends CollectAggFunction[String] { 
+} + +class ByteCollectAggFunction extends CollectAggFunction[Byte] { +} + +class ShortCollectAggFunction extends CollectAggFunction[Short] { +} + +class FloatCollectAggFunction extends CollectAggFunction[Float] { +} + +class DoubleCollectAggFunction extends CollectAggFunction[Double] { +} + +class BooleanCollectAggFunction extends CollectAggFunction[Boolean] { +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala new file mode 100644 index 0000000000000..3153c75141d06 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/MultisetRelDataType.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.schema + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.sql.`type`.MultisetSqlType +import org.apache.flink.api.common.typeinfo.TypeInformation + +class MultisetRelDataType( + val typeInfo: TypeInformation[_], + elementType: RelDataType, + isNullable: Boolean) + extends MultisetSqlType( + elementType, + isNullable) { + + override def toString = s"MULTISET($typeInfo)" + + def canEqual(other: Any): Boolean = other.isInstanceOf[MultisetRelDataType] + + override def equals(other: Any): Boolean = other match { + case that: MultisetRelDataType => + super.equals(that) && + (that canEqual this) && + typeInfo == that.typeInfo && + isNullable == that.isNullable + case _ => false + } + + override def hashCode(): Int = { + typeInfo.hashCode() + } + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index c9f98e31bd3be..4ab400e6d4632 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -1394,8 +1394,30 @@ object AggregateUtil { aggregates(index) = udagg.getFunction accTypes(index) = udagg.accType - case unSupported: SqlAggFunction => - throw new TableException(s"unsupported Function: '${unSupported.getName}'") + case other: SqlAggFunction => + if (other.getKind == SqlKind.COLLECT) { + aggregates(index) = sqlTypeName match { + case TINYINT => + new ByteCollectAggFunction + case SMALLINT => + new ShortCollectAggFunction + case INTEGER => + new IntCollectAggFunction + case BIGINT => + new LongCollectAggFunction + case VARCHAR | CHAR => + new StringCollectAggFunction + case FLOAT => + new FloatCollectAggFunction + case DOUBLE => + new DoubleCollectAggFunction + 
case _ => + throw new TableException( + s"unsupported Sql type for Collect AggFunction: '${sqlTypeName.getName}'") + } + } else { + throw new TableException(s"unsupported Function: '${other.getName}'") + } } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 55dbe4c55552d..9c0d6e936824d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -335,6 +335,7 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable { SqlStdOperatorTable.SUM, SqlStdOperatorTable.SUM0, SqlStdOperatorTable.COUNT, + SqlStdOperatorTable.COLLECT, SqlStdOperatorTable.MIN, SqlStdOperatorTable.MAX, SqlStdOperatorTable.AVG, diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala new file mode 100644 index 0000000000000..4eae2478bc5a1 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CollectAggFunctionTest.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.aggfunctions
+
+import com.google.common.collect.ImmutableList
+import org.apache.commons.collections4.multiset.{AbstractMultiSet, HashMultiSet}
+import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.aggfunctions._
+
+/**
+  * Tests for the built-in COLLECT aggregate functions, one test class per element type.
+  */
+class StringCollectAggFunctionTest
+  extends AggFunctionTestBase[AbstractMultiSet[String], CollectAccumulator[String]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq("a", "a", "b", null, "c", null, "d", "e", null, "f"),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[AbstractMultiSet[String]] = {
+    val set1 = new HashMultiSet[String]()
+    set1.addAll(ImmutableList.of("a", "a", "b", "c", "d", "e", "f"))
+    Seq(set1, null) // nulls are dropped; the all-null input expects null
+  }
+
+  override def aggregator: AggregateFunction[AbstractMultiSet[String], CollectAccumulator[String]] =
+    new StringCollectAggFunction()
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class IntCollectAggFunctionTest
+  extends AggFunctionTestBase[AbstractMultiSet[Int], CollectAccumulator[Int]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1, 1, 2, null, 3, null, 4, 5, null, 6),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[AbstractMultiSet[Int]] = {
+    val set1 = new HashMultiSet[Int]()
+    set1.addAll(ImmutableList.of(1, 1, 2, 3, 4, 5, 6))
+    Seq(set1, null) // nulls are dropped; the all-null input expects null
+  }
+
+  override def aggregator: AggregateFunction[AbstractMultiSet[Int], CollectAccumulator[Int]] =
+    new IntCollectAggFunction()
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ByteCollectAggFunctionTest
+  extends AggFunctionTestBase[AbstractMultiSet[Byte], CollectAccumulator[Byte]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1.toByte, 1.toByte, 2.toByte, null, 3.toByte, null, 4.toByte, 5.toByte, null, 6.toByte),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[AbstractMultiSet[Byte]] = {
+    val set1 = new HashMultiSet[Byte]()
+    set1.addAll(ImmutableList.of(1, 1, 2, 3, 4, 5, 6))
+    Seq(set1, null) // nulls are dropped; the all-null input expects null
+  }
+
+  override def aggregator: AggregateFunction[AbstractMultiSet[Byte], CollectAccumulator[Byte]] =
+    new ByteCollectAggFunction()
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class ShortCollectAggFunctionTest
+  extends AggFunctionTestBase[AbstractMultiSet[Short], CollectAccumulator[Short]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1.toShort, 1.toShort, 2.toShort, null,
+      3.toShort, null, 4.toShort, 5.toShort, null, 6.toShort),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[AbstractMultiSet[Short]] = {
+    val set1 = new HashMultiSet[Short]()
+    set1.addAll(ImmutableList.of(1, 1, 2, 3, 4, 5, 6))
+    Seq(set1, null) // nulls are dropped; the all-null input expects null
+  }
+
+  override def aggregator: AggregateFunction[AbstractMultiSet[Short], CollectAccumulator[Short]] =
+    new ShortCollectAggFunction()
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class LongCollectAggFunctionTest
+  extends AggFunctionTestBase[AbstractMultiSet[Long], CollectAccumulator[Long]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1L, 1L, 2L, null, 3L, null, 4L, 5L, null, 6L),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[AbstractMultiSet[Long]] = {
+    val set1 = new HashMultiSet[Long]()
+    set1.addAll(ImmutableList.of(1, 1, 2, 3, 4, 5, 6))
+    Seq(set1, null) // nulls are dropped; the all-null input expects null
+  }
+
+  override def aggregator: AggregateFunction[AbstractMultiSet[Long], CollectAccumulator[Long]] =
+    new LongCollectAggFunction()
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class FloatCollectAggFunctionTest
+  extends AggFunctionTestBase[AbstractMultiSet[Float], CollectAccumulator[Float]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1f, 1f, 2f, null, 3.2f, null, 4f, 5f, null, 6f),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[AbstractMultiSet[Float]] = {
+    val set1 = new HashMultiSet[Float]()
+    set1.addAll(ImmutableList.of(1f, 1f, 2f, 3.2f, 4f, 5f, 6f))
+    Seq(set1, null) // nulls are dropped; the all-null input expects null
+  }
+
+  override def aggregator: AggregateFunction[AbstractMultiSet[Float], CollectAccumulator[Float]] =
+    new FloatCollectAggFunction()
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
+class DoubleCollectAggFunctionTest
+  extends AggFunctionTestBase[AbstractMultiSet[Double], CollectAccumulator[Double]] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq(1d, 1d, 2d, null, 3.2d, null, 4d, 5d, null, 6d),
+    Seq(null, null, null, null, null, null)
+  )
+
+  override def expectedResults: Seq[AbstractMultiSet[Double]] = {
+    val set1 = new HashMultiSet[Double]()
+    set1.addAll(ImmutableList.of(1, 1, 2, 3.2, 4, 5, 6))
+    Seq(set1, null) // nulls are dropped; the all-null input expects null
+  }
+
+  override def aggregator: AggregateFunction[AbstractMultiSet[Double], CollectAccumulator[Double]] =
+    new DoubleCollectAggFunction()
+
+  override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
+}
+
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
index 39b83710c5d37..46b5f481997ce 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala
@@ -328,6 +328,35 @@ class AggregateITCase(
     TestBaseUtils.compareResultAsText(result.asJava, expected)
   }
 
+  @Test
+  def testTumbleWindowAggregateWithCollect(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val sqlQuery =
+      "SELECT b, COLLECT(b) " +
+        "FROM T " +
+        "GROUP BY b, TUMBLE(ts, INTERVAL '3' SECOND)"
+
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+      // create timestamps
+      .map(x => (x._1, x._2, x._3, new Timestamp(x._1 * 1000)))
+    tEnv.registerDataSet("T", ds, 'a, 'b, 'c, 'ts)
+
+    val result = tEnv.sql(sqlQuery).toDataSet[Row].collect()
+    val expected = Seq( // each row reads "b,[element:count, ...]"
+      "1,[1:1]",
+      "2,[2:2]",
+      "3,[3:3]",
+      "4,[4:1]", "4,[4:3]",
+      "5,[5:2]", "5,[5:3]",
+      "6,[6:3]", "6,[6:3]"
+    ).mkString("\n")
+
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
   @Test
   def testHopWindowAggregate(): Unit = {
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index d2f9a9a73cb15..30baa17bb3413 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -85,6 +85,32 @@ class SqlITCase extends StreamingWithStateTestBase {
     assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
   }
 
+  @Test
+  def testUnboundedGroupByCollect(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val sqlQuery = "SELECT b, COLLECT(a) FROM MyTable GROUP BY b"
+
+    val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    tEnv.registerTable("MyTable", t)
+
+    val result = tEnv.sql(sqlQuery).toRetractStream[Row]
+    result.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+    env.execute()
+
+    val expected = List( // each row reads "b,[element:count, ...]"
+      "1,[1:1]",
+      "2,[2:1, 3:1]",
+      "3,[4:1, 5:1, 6:1]",
+      "4,[7:1, 8:1, 9:1, 10:1]",
+      "5,[11:1, 12:1, 13:1, 14:1, 15:1]",
+      "6,[16:1, 17:1, 18:1, 19:1, 20:1, 21:1]")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
   /** test selection **/
   @Test
   def testSelectExpressionFromTable(): Unit = {