diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index b871bd58ac54e..9692db6994b97 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -1059,6 +1059,13 @@ aggregate: Divides the rows for each window partition into `n` buckets ranging from 1 to at most `n`. If the number of rows in the window partition doesn't divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket. For example, with 6 rows and 4 buckets, the bucket values would be as follows: 1 1 2 2 3 4 + - sql: ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ]) + table: FIELD.arrayAgg + description: | + By default or with keyword `ALL`, returns an array that concatenates the input rows + and returns `NULL` if there are no input rows. Use `DISTINCT` to keep only one unique instance of each value. + By default `NULL` values are respected; use `IGNORE NULLS` to skip `NULL` values. + The `ORDER BY` clause is currently not supported. - sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | diff --git a/docs/data/sql_functions_zh.yml b/docs/data/sql_functions_zh.yml index 13bdaec40e654..b0ba79f12b3e1 100644 --- a/docs/data/sql_functions_zh.yml +++ b/docs/data/sql_functions_zh.yml @@ -1181,6 +1181,12 @@ aggregate: 将窗口分区中的所有数据按照顺序划分为 n 个分组,返回分配给各行数据的分组编号(从 1 开始,最大为 n)。 如果不能均匀划分为 n 个分组,则剩余值从第 1 个分组开始,为每一分组分配一个。 比如某个窗口分区有 6 行数据,划分为 4 个分组,则各行的分组编号为:1,1,2,2,3,4。 + - sql: ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ]) + table: FIELD.arrayAgg + description: | + 默认情况下或使用关键字 `ALL`,返回输入行中表达式所组成的数组,并且如果没有输入行,则返回 `NULL`。使用 `DISTINCT` 则对所有值去重后计算。 + 默认情况下 `NULL` 值不会被忽略,使用 `IGNORE NULLS` 忽略 `NULL` 值。 + 目前尚不支持 `ORDER BY` 子句。 - sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | diff --git a/flink-python/docs/reference/pyflink.table/expressions.rst b/flink-python/docs/reference/pyflink.table/expressions.rst index 908a6ceda5aa5..5c7ea97df032c 100644 --- a/flink-python/docs/reference/pyflink.table/expressions.rst +++ b/flink-python/docs/reference/pyflink.table/expressions.rst @@ -138,6 +138,7 @@ arithmetic functions Expression.var_pop Expression.var_samp Expression.collect + Expression.array_agg Expression.alias Expression.cast Expression.try_cast diff --git a/flink-python/pyflink/table/expression.py b/flink-python/pyflink/table/expression.py index cb72ba40b21bc..4272f1724cb3a 100644 --- a/flink-python/pyflink/table/expression.py +++ b/flink-python/pyflink/table/expression.py @@ -832,6 +832,10 @@ def var_samp(self) -> 'Expression': def collect(self) -> 'Expression': return _unary_op("collect")(self) + @property + def array_agg(self) -> 'Expression': + return _unary_op("arrayAgg")(self) + def alias(self, name: str, *extra_names: str) -> 'Expression[T]': """ Specifies a name for an expression i.e. a field. 
diff --git a/flink-python/pyflink/table/tests/test_expression.py b/flink-python/pyflink/table/tests/test_expression.py index 589d089496199..d187ef8347ee3 100644 --- a/flink-python/pyflink/table/tests/test_expression.py +++ b/flink-python/pyflink/table/tests/test_expression.py @@ -114,6 +114,7 @@ def test_expression(self): self.assertEqual('varPop(a)', str(expr1.var_pop)) self.assertEqual('varSamp(a)', str(expr1.var_samp)) self.assertEqual('collect(a)', str(expr1.collect)) + self.assertEqual('ARRAY_AGG(a)', str(expr1.array_agg)) self.assertEqual("as(a, 'a', 'b', 'c')", str(expr1.alias('a', 'b', 'c'))) self.assertEqual('cast(a, INT)', str(expr1.cast(DataTypes.INT()))) self.assertEqual('asc(a)', str(expr1.asc)) diff --git a/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd index 9b82c2d18f02b..9153be2a08b00 100644 --- a/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd +++ b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd @@ -201,6 +201,7 @@ "AFTER" "ALWAYS" "APPLY" + "ARRAY_AGG" "ASC" "ASSERTION" "ASSIGNMENT" diff --git a/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java index 87ffcfdd3cbc8..ef1750a89fb9c 100644 --- a/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java +++ b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java @@ -63,12 +63,23 @@ void testDescribeCatalog() { sql("desc catalog a").ok("DESCRIBE CATALOG `A`"); } - // ignore test methods that we don't support - // BEGIN - // ARRAY_AGG - @Disabled @Test - void testArrayAgg() {} + void testArrayAgg() { + sql("select\n" + + " array_agg(ename respect nulls order by deptno, ename) as c1,\n" + + " array_agg(ename order by deptno, ename desc) as c2,\n" + + " array_agg(distinct ename) as c3,\n" + + " array_agg(ename) as c4\n" + + "from emp group by gender") + .ok( + "SELECT" + + " ARRAY_AGG(`ENAME` ORDER BY `DEPTNO`, `ENAME`) RESPECT NULLS AS `C1`," + + " ARRAY_AGG(`ENAME` ORDER BY `DEPTNO`, `ENAME` DESC) AS `C2`," + + " ARRAY_AGG(DISTINCT `ENAME`) AS `C3`," + + " ARRAY_AGG(`ENAME`) AS `C4`\n" + + "FROM `EMP`\n" + + "GROUP BY `GENDER`"); + } // DESCRIBE SCHEMA @Disabled diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java index cdc108d36726c..36da6d54969c2 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java @@ -53,6 +53,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ABS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ACOS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND; +import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_AGG; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_CONCAT; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_CONTAINS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_DISTINCT; @@ -527,6 +528,11 @@ public OutType collect() { return 
toApiSpecificExpression(unresolvedCall(COLLECT, toExpr())); } + /** Returns the array aggregate of the given expression. */ + public OutType arrayAgg() { + return toApiSpecificExpression(unresolvedCall(ARRAY_AGG, toExpr())); + } + /** * Returns a new value being cast to {@code toType}. A cast error throws an exception and fails * the job. When performing a cast operation that may fail, like {@link DataTypes#STRING()} to diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 93bb9d3690c9b..704e8efac3559 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -745,6 +745,13 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) .outputTypeStrategy(argument(0)) .build(); + public static final BuiltInFunctionDefinition ARRAY_AGG = + BuiltInFunctionDefinition.newBuilder() + .name("ARRAY_AGG") + .kind(AGGREGATE) + .outputTypeStrategy(nullableIfArgs(SpecificTypeStrategies.ARRAY)) + .build(); + // -------------------------------------------------------------------------------------------- // String functions // -------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java index 69844be1c1505..24cde572c7eb7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java @@ -84,6 +84,8 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor getAuxiliaryFunctions() { public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; + /** + * Use Flink's own definition instead of {@link SqlLibraryOperators#ARRAY_AGG}, because Flink + * returns a nullable ARRAY type. An ORDER BY clause such as ARRAY_AGG(x ORDER BY x, y) is + * not supported yet, because the row data cannot be obtained inside the aggregate function. 
+ */ + public static final SqlAggFunction ARRAY_AGG = + SqlBasicAggFunction.create( + SqlKind.ARRAY_AGG, + ReturnTypes.cascade( + ReturnTypes.TO_ARRAY, SqlTypeTransforms.TO_NULLABLE), + OperandTypes.ANY) + .withFunctionType(SqlFunctionCategory.SYSTEM) + .withSyntax(SqlSyntax.FUNCTION) + .withAllowsNullTreatment(true); + // ARRAY OPERATORS public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = new SqlArrayConstructor(); public static final SqlOperator ELEMENT = SqlStdOperatorTable.ELEMENT; @@ -1154,6 +1172,8 @@ public List getAuxiliaryFunctions() { // SPECIAL OPERATORS public static final SqlOperator MULTISET_VALUE = SqlStdOperatorTable.MULTISET_VALUE; public static final SqlOperator ROW = SqlStdOperatorTable.ROW; + public static final SqlOperator IGNORE_NULLS = SqlStdOperatorTable.IGNORE_NULLS; + public static final SqlOperator RESPECT_NULLS = SqlStdOperatorTable.RESPECT_NULLS; public static final SqlOperator OVERLAPS = SqlStdOperatorTable.OVERLAPS; public static final SqlOperator LITERAL_CHAIN = SqlStdOperatorTable.LITERAL_CHAIN; public static final SqlOperator BETWEEN = SqlStdOperatorTable.BETWEEN; diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index 861f537f6c24c..4ecd43638638f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -146,6 +146,9 @@ class AggFunctionFactory( case a: SqlAggFunction if a.getKind == SqlKind.COLLECT => createCollectAggFunction(argTypes) + case a: SqlAggFunction if a.getKind == SqlKind.ARRAY_AGG => + createArrayAggFunction(argTypes, call.ignoreNulls) + case fn: SqlAggFunction if fn.getKind == SqlKind.JSON_OBJECTAGG => val onNull = fn.asInstanceOf[SqlJsonObjectAggAggFunction].getNullClause new JsonObjectAggFunction(argTypes, onNull == SqlJsonConstructorNullClause.ABSENT_ON_NULL) @@ -620,4 +623,10 @@ class AggFunctionFactory( private def createCollectAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { new CollectAggFunction(argTypes(0)) } + + private def createArrayAggFunction( + types: Array[LogicalType], + ignoreNulls: Boolean): UserDefinedFunction = { + new ArrayAggFunction(types(0), ignoreNulls) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index aa16c4128d32d..66850f7f9d04d 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -849,7 +849,7 @@ object AggregateUtil extends Enumeration { call.getAggregation, false, false, - false, + call.ignoreNulls, call.getArgList, -1, // remove filterArg null, diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java new file mode 100644 index 0000000000000..bc849f4438913 --- /dev/null +++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions; + +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.types.Row; + +import java.util.Arrays; +import java.util.Collections; +import java.util.stream.Stream; + +import static org.apache.flink.table.api.DataTypes.ARRAY; +import static org.apache.flink.table.api.DataTypes.INT; +import static org.apache.flink.table.api.DataTypes.ROW; +import static org.apache.flink.table.api.DataTypes.STRING; +import static org.apache.flink.table.api.Expressions.$; +import static org.apache.flink.types.RowKind.DELETE; +import static org.apache.flink.types.RowKind.INSERT; +import static org.apache.flink.types.RowKind.UPDATE_AFTER; +import static org.apache.flink.types.RowKind.UPDATE_BEFORE; + +/** Tests for built-in ARRAY_AGG aggregation functions. 
*/ +class ArrayAggFunctionITCase extends BuiltInAggregateFunctionTestBase { + + @Override + Stream getTestCaseSpecs() { + return Stream.of( + TestSpec.forFunction(BuiltInFunctionDefinitions.ARRAY_AGG) + .withDescription("ARRAY changelog stream aggregation") + .withSource( + ROW(STRING(), INT()), + Arrays.asList( + Row.ofKind(INSERT, "A", 1), + Row.ofKind(INSERT, "A", 2), + Row.ofKind(INSERT, "B", 2), + Row.ofKind(INSERT, "B", 2), + Row.ofKind(INSERT, "B", 3), + Row.ofKind(INSERT, "C", 3), + Row.ofKind(INSERT, "C", null), + Row.ofKind(DELETE, "C", null), + Row.ofKind(INSERT, "D", null), + Row.ofKind(INSERT, "E", 4), + Row.ofKind(INSERT, "E", 5), + Row.ofKind(DELETE, "E", 5), + Row.ofKind(UPDATE_BEFORE, "E", 4), + Row.ofKind(UPDATE_AFTER, "E", 6))) + .testResult( + source -> + "SELECT f0, array_agg(f1) FROM " + source + " GROUP BY f0", + TableApiAggSpec.groupBySelect( + Collections.singletonList($("f0")), + $("f0"), + $("f1").arrayAgg()), + ROW(STRING(), ARRAY(INT())), + ROW(STRING(), ARRAY(INT())), + Arrays.asList( + Row.of("A", new Integer[] {1, 2}), + Row.of("B", new Integer[] {2, 2, 3}), + Row.of("C", new Integer[] {3}), + Row.of("D", new Integer[] {null}), + Row.of("E", new Integer[] {6}))) + .testSqlResult( + source -> + "SELECT f0, array_agg(DISTINCT f1 IGNORE NULLS) FROM " + + source + + " GROUP BY f0", + ROW(STRING(), ARRAY(INT())), + Arrays.asList( + Row.of("A", new Integer[] {1, 2}), + Row.of("B", new Integer[] {2, 3}), + Row.of("C", new Integer[] {3}), + Row.of("D", null), + Row.of("E", new Integer[] {6})))); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java new file mode 100644 index 0000000000000..0fc10ee48d7c1 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.functions.aggfunctions; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.DecimalDataUtils; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.table.runtime.functions.aggregate.ArrayAggFunction; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.BooleanType; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.DoubleType; +import org.apache.flink.table.types.logical.FloatType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.TinyIntType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.testutils.serialization.types.ShortType; +import org.apache.flink.types.RowKind; + +import org.junit.jupiter.api.Nested; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; + +/** Test case for built-in ARRAY_AGG with retraction aggregate function. */ +final class ArrayAggFunctionTest { + + /** Test for {@link TinyIntType}. */ + @Nested + final class ByteArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Byte getValue(String v) { + return Byte.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.TINYINT().getLogicalType(), true); + } + } + + /** Test for {@link ShortType}. */ + @Nested + final class ShortArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Short getValue(String v) { + return Short.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.SMALLINT().getLogicalType(), true); + } + } + + /** Test for {@link IntType}. */ + @Nested + final class IntArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Integer getValue(String v) { + return Integer.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.INT().getLogicalType(), true); + } + } + + /** Test for {@link BigIntType}. */ + @Nested + final class LongArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Long getValue(String v) { + return Long.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.BIGINT().getLogicalType(), true); + } + } + + /** Test for {@link FloatType}. */ + @Nested + final class FloatArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Float getValue(String v) { + return Float.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.FLOAT().getLogicalType(), true); + } + } + + /** Test for {@link DoubleType}. 
*/ + @Nested + final class DoubleArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Double getValue(String v) { + return Double.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.DOUBLE().getLogicalType(), true); + } + } + + /** Test for {@link BooleanType}. */ + @Nested + final class BooleanArrayAggTest extends ArrayAggFunctionTestBase { + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList(false, false, false), + Arrays.asList(true, true, true), + Arrays.asList(true, false, null, true, false, true, null), + Arrays.asList(null, null, null), + Arrays.asList(null, true)); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData(new Object[] {false, false, false}), + new GenericArrayData(new Object[] {true, true, true}), + new GenericArrayData(new Object[] {true, false, true, false, true}), + null, + new GenericArrayData(new Object[] {true})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.BOOLEAN().getLogicalType(), true); + } + } + + /** Test for {@link DecimalType}. */ + @Nested + final class DecimalArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { + + private final int precision = 20; + private final int scale = 6; + + @Override + protected DecimalData getValue(String v) { + return DecimalDataUtils.castFrom(v, precision, scale); + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue("1"), + getValue("1000.000001"), + getValue("-1"), + getValue("-999.998999"), + null, + getValue("0"), + getValue("-999.999"), + null, + getValue("999.999")), + Arrays.asList(null, null, null, null, null), + Arrays.asList(null, getValue("0"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue("1"), + getValue("1000.000001"), + getValue("-1"), + getValue("-999.998999"), + getValue("0"), + getValue("-999.999"), + getValue("999.999") + }), + null, + new GenericArrayData(new Object[] {getValue("0")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>( + DataTypes.DECIMAL(precision, scale).getLogicalType(), true); + } + } + + /** Test for {@link VarCharType}. 
*/ + @Nested + final class StringArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + StringData.fromString("abc"), + StringData.fromString("def"), + StringData.fromString("ghi"), + null, + StringData.fromString("jkl"), + null, + StringData.fromString("zzz")), + Arrays.asList(null, null), + Arrays.asList(null, StringData.fromString("a")), + Arrays.asList(StringData.fromString("x"), null, StringData.fromString("e"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + StringData.fromString("abc"), + StringData.fromString("def"), + StringData.fromString("ghi"), + StringData.fromString("jkl"), + StringData.fromString("zzz") + }), + null, + new GenericArrayData(new Object[] {StringData.fromString("a")}), + new GenericArrayData( + new Object[] {StringData.fromString("x"), StringData.fromString("e")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.STRING().getLogicalType(), true); + } + } + + /** Test for {@link RowType}. */ + @Nested + final class RowDArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + private RowData getValue(Integer f0, String f1) { + GenericRowData rowData = new GenericRowData(RowKind.INSERT, 2); + rowData.setField(0, f0); + rowData.setField(1, f1 == null ? null : StringData.fromString(f1)); + return rowData; + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue(0, "abc"), + getValue(1, "def"), + getValue(2, "ghi"), + null, + getValue(3, "jkl"), + null, + getValue(4, "zzz")), + Arrays.asList(null, null), + Arrays.asList(null, getValue(null, "a")), + Arrays.asList(getValue(5, null), null, getValue(null, "e"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue(0, "abc"), + getValue(1, "def"), + getValue(2, "ghi"), + getValue(3, "jkl"), + getValue(4, "zzz") + }), + null, + new GenericArrayData(new Object[] {getValue(null, "a")}), + new GenericArrayData(new Object[] {getValue(5, null), getValue(null, "e")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>( + DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()).getLogicalType(), true); + } + } + + /** Test for {@link ArrayType}. */ + @Nested + final class ArrayArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + private ArrayData getValue(Integer... elements) { + return new GenericArrayData(elements); + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue(0, 1, 2), + getValue(1, null), + getValue(5, 3, 4, 5), + null, + getValue(6, null, 7)), + Arrays.asList(null, null)); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue(0, 1, 2), + getValue(1, null), + getValue(5, 3, 4, 5), + getValue(6, null, 7) + }), + null); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.ARRAY(DataTypes.INT()).getLogicalType(), true); + } + } + + /** Test base for {@link ArrayAggFunction}. 
*/ + abstract static class ArrayAggFunctionTestBase + extends AggFunctionTestBase> { + + @Override + protected Class getAccClass() { + return ArrayAggFunction.ArrayAggAccumulator.class; + } + + @Override + protected Method getAccumulateFunc() throws NoSuchMethodException { + return getAggregator().getClass().getMethod("accumulate", getAccClass(), Object.class); + } + + @Override + protected Method getRetractFunc() throws NoSuchMethodException { + return getAggregator().getClass().getMethod("retract", getAccClass(), Object.class); + } + } + + /** Test base for {@link ArrayAggFunction} with number types. */ + abstract static class NumberArrayAggFunctionTestBase extends ArrayAggFunctionTestBase { + + protected abstract T getValue(String v); + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList(getValue("1"), null, getValue("-99"), getValue("3"), null), + Arrays.asList(null, null, null, null), + Arrays.asList(null, getValue("10"), null, getValue("3"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] {getValue("1"), getValue("-99"), getValue("3")}), + null, + new GenericArrayData(new Object[] {getValue("10"), getValue("3")})); + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java new file mode 100644 index 0000000000000..126b483ea11ee --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.aggregate; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.typeutils.LinkedListSerializer; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; + +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; + +/** Built-in ARRAY_AGG aggregate function. 
*/ +@Internal +public final class ArrayAggFunction + extends BuiltInAggregateFunction> { + + private static final long serialVersionUID = -5860934997657147836L; + + private final transient DataType elementDataType; + + private final boolean ignoreNulls; + + public ArrayAggFunction(LogicalType elementType, boolean ignoreNulls) { + this.elementDataType = toInternalDataType(elementType); + this.ignoreNulls = ignoreNulls; + } + + // -------------------------------------------------------------------------------------------- + // Planning + // -------------------------------------------------------------------------------------------- + + @Override + public List getArgumentDataTypes() { + return Collections.singletonList(elementDataType); + } + + @Override + public DataType getAccumulatorDataType() { + DataType linkedListType = getLinkedListType(); + return DataTypes.STRUCTURED( + ArrayAggAccumulator.class, + DataTypes.FIELD("list", linkedListType), + DataTypes.FIELD("retractList", linkedListType)); + } + + @Override + public DataType getOutputDataType() { + return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private DataType getLinkedListType() { + TypeSerializer serializer = InternalSerializers.create(elementDataType.getLogicalType()); + return DataTypes.RAW( + LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer)); + } + + // -------------------------------------------------------------------------------------------- + // Runtime + // -------------------------------------------------------------------------------------------- + + /** Accumulator for ARRAY_AGG with retraction. */ + public static class ArrayAggAccumulator { + public LinkedList list; + public LinkedList retractList; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayAggAccumulator that = (ArrayAggAccumulator) o; + return Objects.equals(list, that.list) && Objects.equals(retractList, that.retractList); + } + + @Override + public int hashCode() { + return Objects.hash(list, retractList); + } + } + + @Override + public ArrayAggAccumulator createAccumulator() { + final ArrayAggAccumulator acc = new ArrayAggAccumulator<>(); + acc.list = new LinkedList<>(); + acc.retractList = new LinkedList<>(); + return acc; + } + + public void accumulate(ArrayAggAccumulator acc, T value) throws Exception { + if (value != null || !ignoreNulls) { + acc.list.add(value); + } + } + + public void retract(ArrayAggAccumulator acc, T value) throws Exception { + if (value != null || !ignoreNulls) { + if (!acc.list.remove(value)) { + acc.retractList.add(value); + } + } + } + + public void merge(ArrayAggAccumulator acc, Iterable> its) + throws Exception { + List newRetractBuffer = new LinkedList<>(); + for (ArrayAggAccumulator otherAcc : its) { + if (!otherAcc.list.iterator().hasNext() && !otherAcc.retractList.iterator().hasNext()) { + // otherAcc is empty, skip it + continue; + } + acc.list.addAll(otherAcc.list); + acc.retractList.addAll(otherAcc.retractList); + } + + for (T element : acc.retractList) { + if (!acc.list.remove(element)) { + newRetractBuffer.add(element); + } + } + + acc.retractList.clear(); + acc.retractList.addAll(newRetractBuffer); + } + + @Override + public ArrayData getValue(ArrayAggAccumulator acc) { + try { + List accList = acc.list; + if (accList == null || accList.isEmpty()) { + // array_agg returns null rather than an empty array 
when there are no input rows. + return null; + } else { + return new GenericArrayData(accList.toArray()); + } + } catch (Exception e) { + throw new FlinkRuntimeException(e); + } + } + + public void resetAccumulator(ArrayAggAccumulator acc) { + acc.list.clear(); + acc.retractList.clear(); + } +}
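
For reference, a short usage sketch of the new function. The table `orders(product STRING, quantity INT)` and its column names are illustrative only and not part of this patch; the queries follow the semantics documented in docs/data/sql_functions.yml and exercised in ArrayAggFunctionITCase:

    -- One array per group; NULL values are kept because RESPECT NULLS is the default.
    SELECT product, ARRAY_AGG(quantity) AS quantities
    FROM orders
    GROUP BY product;

    -- Duplicates and NULLs are dropped; a group whose values are all NULL yields NULL
    -- (see group "D" in ArrayAggFunctionITCase).
    SELECT product, ARRAY_AGG(DISTINCT quantity IGNORE NULLS) AS distinct_quantities
    FROM orders
    GROUP BY product;

The equivalent Table API call added in BaseExpressions.java is $("quantity").arrayAgg() (documented as FIELD.arrayAgg), exposed in PyFlink as the Expression.array_agg property.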