From c3d2839cb6fad3fcdcd820fb0e78cf9aa22aaa10 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Fri, 8 Sep 2023 21:45:37 +0800 Subject: [PATCH 1/8] [FLINK-21949][table] Support ARRAY_AGG aggregate function --- docs/data/sql_functions.yml | 6 + docs/data/sql_functions_zh.yml | 5 + .../reference/pyflink.table/expressions.rst | 1 + flink-python/pyflink/table/expression.py | 4 + .../pyflink/table/tests/test_expression.py | 1 + .../src/main/codegen/data/Parser.tdd | 1 + .../sql/parser/FlinkSqlParserImplTest.java | 21 +- .../table/api/internal/BaseExpressions.java | 6 + .../functions/BuiltInFunctionDefinitions.java | 7 + .../expressions/SqlAggFunctionVisitor.java | 2 + .../functions/sql/FlinkSqlOperatorTable.java | 16 + .../plan/utils/AggFunctionFactory.scala | 7 + .../functions/ArrayAggFunctionITCase.java | 91 ++++ .../aggfunctions/ArrayAggFunctionTest.java | 410 ++++++++++++++++++ .../functions/aggregate/ArrayAggFunction.java | 177 ++++++++ 15 files changed, 750 insertions(+), 5 deletions(-) create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java create mode 100644 flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index b871bd58ac54e..46b97aa2ef77c 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -1059,6 +1059,12 @@ aggregate: Divides the rows for each window partition into `n` buckets ranging from 1 to at most `n`. If the number of rows in the window partition doesn't divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket. 
For example, with 6 rows and 4 buckets, the bucket values would be as follows: 1 1 2 2 3 4 + - sql: ARRAY_AGG([ ALL | DISTINCT ] expression) + table: FIELD.arrayAgg + description: | + By default or with keyword ALL, returns an array that concatenates the input rows + and returns NULL if there are no input rows. + NULL values will be ignored. Use DISTINCT for one unique instance of each value. - sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | diff --git a/docs/data/sql_functions_zh.yml b/docs/data/sql_functions_zh.yml index 13bdaec40e654..987c2287cf398 100644 --- a/docs/data/sql_functions_zh.yml +++ b/docs/data/sql_functions_zh.yml @@ -1181,6 +1181,11 @@ aggregate: 将窗口分区中的所有数据按照顺序划分为 n 个分组,返回分配给各行数据的分组编号(从 1 开始,最大为 n)。 如果不能均匀划分为 n 个分组,则剩余值从第 1 个分组开始,为每一分组分配一个。 比如某个窗口分区有 6 行数据,划分为 4 个分组,则各行的分组编号为:1,1,2,2,3,4。 + - sql: ARRAY_AGG([ ALL | DISTINCT ] expression) + table: FIELD.arrayAgg + description: | + 默认情况下或使用关键字ALL,返回输入行中表达式所组成的数组,并且如果没有输入行,则返回 `NULL`。 + `NULL` 值将被忽略。使用 `DISTINCT` 则对所有值去重后计算。 - sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | diff --git a/flink-python/docs/reference/pyflink.table/expressions.rst b/flink-python/docs/reference/pyflink.table/expressions.rst index 908a6ceda5aa5..5c7ea97df032c 100644 --- a/flink-python/docs/reference/pyflink.table/expressions.rst +++ b/flink-python/docs/reference/pyflink.table/expressions.rst @@ -138,6 +138,7 @@ arithmetic functions Expression.var_pop Expression.var_samp Expression.collect + Expression.array_agg Expression.alias Expression.cast Expression.try_cast diff --git a/flink-python/pyflink/table/expression.py b/flink-python/pyflink/table/expression.py index cb72ba40b21bc..4272f1724cb3a 100644 --- a/flink-python/pyflink/table/expression.py +++ b/flink-python/pyflink/table/expression.py @@ -832,6 +832,10 @@ def 
var_samp(self) -> 'Expression': def collect(self) -> 'Expression': return _unary_op("collect")(self) + @property + def array_agg(self) -> 'Expression': + return _unary_op("arrayAgg")(self) + def alias(self, name: str, *extra_names: str) -> 'Expression[T]': """ Specifies a name for an expression i.e. a field. diff --git a/flink-python/pyflink/table/tests/test_expression.py b/flink-python/pyflink/table/tests/test_expression.py index 589d089496199..f8611d577c913 100644 --- a/flink-python/pyflink/table/tests/test_expression.py +++ b/flink-python/pyflink/table/tests/test_expression.py @@ -114,6 +114,7 @@ def test_expression(self): self.assertEqual('varPop(a)', str(expr1.var_pop)) self.assertEqual('varSamp(a)', str(expr1.var_samp)) self.assertEqual('collect(a)', str(expr1.collect)) + self.assertEqual('arrayAgg(a)', str(expr1.array_agg)) self.assertEqual("as(a, 'a', 'b', 'c')", str(expr1.alias('a', 'b', 'c'))) self.assertEqual('cast(a, INT)', str(expr1.cast(DataTypes.INT()))) self.assertEqual('asc(a)', str(expr1.asc)) diff --git a/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd index 9b82c2d18f02b..9153be2a08b00 100644 --- a/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd +++ b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd @@ -201,6 +201,7 @@ "AFTER" "ALWAYS" "APPLY" + "ARRAY_AGG" "ASC" "ASSERTION" "ASSIGNMENT" diff --git a/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java index 87ffcfdd3cbc8..ef1750a89fb9c 100644 --- a/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java +++ b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java @@ -63,12 +63,23 @@ void testDescribeCatalog() { sql("desc catalog a").ok("DESCRIBE CATALOG `A`"); } 
- // ignore test methods that we don't support - // BEGIN - // ARRAY_AGG - @Disabled @Test - void testArrayAgg() {} + void testArrayAgg() { + sql("select\n" + + " array_agg(ename respect nulls order by deptno, ename) as c1,\n" + + " array_agg(ename order by deptno, ename desc) as c2,\n" + + " array_agg(distinct ename) as c3,\n" + + " array_agg(ename) as c4\n" + + "from emp group by gender") + .ok( + "SELECT" + + " ARRAY_AGG(`ENAME` ORDER BY `DEPTNO`, `ENAME`) RESPECT NULLS AS `C1`," + + " ARRAY_AGG(`ENAME` ORDER BY `DEPTNO`, `ENAME` DESC) AS `C2`," + + " ARRAY_AGG(DISTINCT `ENAME`) AS `C3`," + + " ARRAY_AGG(`ENAME`) AS `C4`\n" + + "FROM `EMP`\n" + + "GROUP BY `GENDER`"); + } // DESCRIBE SCHEMA @Disabled diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java index cdc108d36726c..36da6d54969c2 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java @@ -53,6 +53,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ABS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ACOS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND; +import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_AGG; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_CONCAT; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_CONTAINS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_DISTINCT; @@ -527,6 +528,11 @@ public OutType collect() { return toApiSpecificExpression(unresolvedCall(COLLECT, toExpr())); } + /** Returns array aggregate of a given expression. 
*/ + public OutType arrayAgg() { + return toApiSpecificExpression(unresolvedCall(ARRAY_AGG, toExpr())); + } + /** * Returns a new value being cast to {@code toType}. A cast error throws an exception and fails * the job. When performing a cast operation that may fail, like {@link DataTypes#STRING()} to diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 93bb9d3690c9b..9761edc0be6b2 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -745,6 +745,13 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) .outputTypeStrategy(argument(0)) .build(); + public static final BuiltInFunctionDefinition ARRAY_AGG = + BuiltInFunctionDefinition.newBuilder() + .name("arrayAgg") + .kind(AGGREGATE) + .outputTypeStrategy(nullableIfArgs(SpecificTypeStrategies.ARRAY)) + .build(); + // -------------------------------------------------------------------------------------------- // String functions // -------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java index 69844be1c1505..24cde572c7eb7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java @@ -84,6 +84,8 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor 
getAuxiliaryFunctions() { public static final SqlAggFunction SINGLE_VALUE = SqlStdOperatorTable.SINGLE_VALUE; public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; + /** + * Use the definitions in Flink instead of {@link SqlLibraryOperators#ARRAY_AGG}, because we + * ignore nulls and return a nullable ARRAY type. An ORDER BY clause like + * ARRAY_AGG(x ORDER BY x, y) for aggregate functions is not supported yet, because the + * row data cannot be obtained inside the aggregate function. + */ + public static final SqlAggFunction ARRAY_AGG = + SqlBasicAggFunction.create( + SqlKind.ARRAY_AGG, + ReturnTypes.cascade( + ReturnTypes.TO_ARRAY, SqlTypeTransforms.TO_NULLABLE), + OperandTypes.ANY) + .withFunctionType(SqlFunctionCategory.SYSTEM) + .withSyntax(SqlSyntax.FUNCTION); // ARRAY OPERATORS public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = new SqlArrayConstructor(); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index 861f537f6c24c..c6c02673da227 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -146,6 +146,9 @@ class AggFunctionFactory( case a: SqlAggFunction if a.getKind == SqlKind.COLLECT => createCollectAggFunction(argTypes) + case a: SqlAggFunction if a.getKind == SqlKind.ARRAY_AGG => + createArrayAggFunction(argTypes) + case fn: SqlAggFunction if fn.getKind == SqlKind.JSON_OBJECTAGG => val onNull = fn.asInstanceOf[SqlJsonObjectAggAggFunction].getNullClause new JsonObjectAggFunction(argTypes, onNull == SqlJsonConstructorNullClause.ABSENT_ON_NULL) @@ -620,4 +623,8 @@ class AggFunctionFactory( private def 
createCollectAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { new CollectAggFunction(argTypes(0)) } + + private def createArrayAggFunction(types: Array[LogicalType]): UserDefinedFunction = { + new ArrayAggFunction(types(0)) + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java new file mode 100644 index 0000000000000..9bb4f17ed3376 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.functions; + +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.types.Row; + +import java.util.Arrays; +import java.util.stream.Stream; + +import static org.apache.flink.table.api.DataTypes.ARRAY; +import static org.apache.flink.table.api.DataTypes.INT; +import static org.apache.flink.table.api.DataTypes.ROW; +import static org.apache.flink.table.api.DataTypes.STRING; +import static org.apache.flink.table.api.Expressions.$; +import static org.apache.flink.types.RowKind.DELETE; +import static org.apache.flink.types.RowKind.INSERT; +import static org.apache.flink.types.RowKind.UPDATE_AFTER; +import static org.apache.flink.types.RowKind.UPDATE_BEFORE; + +/** Tests for built-in ARRAY_AGG aggregation functions. */ +class ArrayAggFunctionITCase extends BuiltInAggregateFunctionTestBase { + + @Override + Stream getTestCaseSpecs() { + return Stream.of( + TestSpec.forFunction(BuiltInFunctionDefinitions.ARRAY_AGG) + .withDescription("ARRAY changelog stream aggregation") + .withSource( + ROW(STRING(), INT()), + Arrays.asList( + Row.ofKind(INSERT, "A", 1), + Row.ofKind(INSERT, "A", 2), + Row.ofKind(INSERT, "B", 2), + Row.ofKind(INSERT, "B", 2), + Row.ofKind(INSERT, "B", 3), + Row.ofKind(INSERT, "C", 3), + Row.ofKind(INSERT, "C", null), + Row.ofKind(INSERT, "D", null), + Row.ofKind(INSERT, "E", 4), + Row.ofKind(INSERT, "E", 5), + Row.ofKind(DELETE, "E", 5), + Row.ofKind(UPDATE_BEFORE, "E", 4), + Row.ofKind(UPDATE_AFTER, "E", 6))) + .testResult( + source -> + "SELECT f0, array_agg(f1) FROM " + source + " GROUP BY f0", + source -> + source.groupBy($("f0")).select($("f0"), $("f1").arrayAgg()), + ROW(STRING(), ARRAY(INT())), + ROW(STRING(), ARRAY(INT())), + Arrays.asList( + Row.of("A", new Integer[] {1, 2}), + Row.of("B", new Integer[] {2, 2, 3}), + Row.of("C", new Integer[] {3}), + Row.of("D", null), + Row.of("E", new Integer[] {6}))) + .testResult( + source -> + "SELECT f0, array_agg(DISTINCT 
f1) FROM " + + source + + " GROUP BY f0", + source -> + source.groupBy($("f0")) + .select($("f0"), $("f1").arrayAgg().distinct()), + ROW(STRING(), ARRAY(INT())), + ROW(STRING(), ARRAY(INT())), + Arrays.asList( + Row.of("A", new Integer[] {1, 2}), + Row.of("B", new Integer[] {2, 3}), + Row.of("C", new Integer[] {3}), + Row.of("D", null), + Row.of("E", new Integer[] {6})))); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java new file mode 100644 index 0000000000000..339c0cc5d783b --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.functions.aggfunctions; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.DecimalDataUtils; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.table.runtime.functions.aggregate.ArrayAggFunction; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.BooleanType; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.DoubleType; +import org.apache.flink.table.types.logical.FloatType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.TinyIntType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.testutils.serialization.types.ShortType; +import org.apache.flink.types.RowKind; + +import org.junit.jupiter.api.Nested; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; + +/** Test case for built-in ARRAY_AGG with retraction aggregate function. */ +final class ArrayAggFunctionTest { + + /** Test for {@link TinyIntType}. */ + @Nested + final class ByteArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Byte getValue(String v) { + return Byte.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.TINYINT().getLogicalType()); + } + } + + /** Test for {@link ShortType}. 
*/ + @Nested + final class ShortArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Short getValue(String v) { + return Short.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.SMALLINT().getLogicalType()); + } + } + + /** Test for {@link IntType}. */ + @Nested + final class IntArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Integer getValue(String v) { + return Integer.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.INT().getLogicalType()); + } + } + + /** Test for {@link BigIntType}. */ + @Nested + final class LongArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Long getValue(String v) { + return Long.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.BIGINT().getLogicalType()); + } + } + + /** Test for {@link FloatType}. */ + @Nested + final class FloatArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Float getValue(String v) { + return Float.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.FLOAT().getLogicalType()); + } + } + + /** Test for {@link DoubleType}. */ + @Nested + final class DoubleArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Double getValue(String v) { + return Double.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.DOUBLE().getLogicalType()); + } + } + + /** Test for {@link BooleanType}. 
*/ + @Nested + final class BooleanArrayAggTest extends ArrayAggFunctionTestBase { + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList(false, false, false), + Arrays.asList(true, true, true), + Arrays.asList(true, false, null, true, false, true, null), + Arrays.asList(null, null, null), + Arrays.asList(null, true)); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData(new Object[] {false, false, false}), + new GenericArrayData(new Object[] {true, true, true}), + new GenericArrayData(new Object[] {true, false, true, false, true}), + null, + new GenericArrayData(new Object[] {true})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.BOOLEAN().getLogicalType()); + } + } + + /** Test for {@link DecimalType}. */ + @Nested + final class DecimalArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { + + private final int precision = 20; + private final int scale = 6; + + @Override + protected DecimalData getValue(String v) { + return DecimalDataUtils.castFrom(v, precision, scale); + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue("1"), + getValue("1000.000001"), + getValue("-1"), + getValue("-999.998999"), + null, + getValue("0"), + getValue("-999.999"), + null, + getValue("999.999")), + Arrays.asList(null, null, null, null, null), + Arrays.asList(null, getValue("0"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue("1"), + getValue("1000.000001"), + getValue("-1"), + getValue("-999.998999"), + getValue("0"), + getValue("-999.999"), + getValue("999.999") + }), + null, + new GenericArrayData(new Object[] {getValue("0")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.DECIMAL(precision, 
scale).getLogicalType()); + } + } + + /** Test for {@link VarCharType}. */ + @Nested + final class StringArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + StringData.fromString("abc"), + StringData.fromString("def"), + StringData.fromString("ghi"), + null, + StringData.fromString("jkl"), + null, + StringData.fromString("zzz")), + Arrays.asList(null, null), + Arrays.asList(null, StringData.fromString("a")), + Arrays.asList(StringData.fromString("x"), null, StringData.fromString("e"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + StringData.fromString("abc"), + StringData.fromString("def"), + StringData.fromString("ghi"), + StringData.fromString("jkl"), + StringData.fromString("zzz") + }), + null, + new GenericArrayData(new Object[] {StringData.fromString("a")}), + new GenericArrayData( + new Object[] {StringData.fromString("x"), StringData.fromString("e")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.STRING().getLogicalType()); + } + } + + /** Test for {@link RowType}. */ + @Nested + final class RowDArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + private RowData getValue(Integer f0, String f1) { + GenericRowData rowData = new GenericRowData(RowKind.INSERT, 2); + rowData.setField(0, f0); + rowData.setField(1, f1 == null ? 
null : StringData.fromString(f1)); + return rowData; + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue(0, "abc"), + getValue(1, "def"), + getValue(2, "ghi"), + null, + getValue(3, "jkl"), + null, + getValue(4, "zzz")), + Arrays.asList(null, null), + Arrays.asList(null, getValue(null, "a")), + Arrays.asList(getValue(5, null), null, getValue(null, "e"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue(0, "abc"), + getValue(1, "def"), + getValue(2, "ghi"), + getValue(3, "jkl"), + getValue(4, "zzz") + }), + null, + new GenericArrayData(new Object[] {getValue(null, "a")}), + new GenericArrayData(new Object[] {getValue(5, null), getValue(null, "e")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>( + DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()).getLogicalType()); + } + } + + /** Test for {@link ArrayType}. */ + @Nested + final class ArrayArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + private ArrayData getValue(Integer... elements) { + return new GenericArrayData(elements); + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue(0, 1, 2), + getValue(1, null), + getValue(5, 3, 4, 5), + null, + getValue(6, null, 7)), + Arrays.asList(null, null)); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue(0, 1, 2), + getValue(1, null), + getValue(5, 3, 4, 5), + getValue(6, null, 7) + }), + null); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.ARRAY(DataTypes.INT()).getLogicalType()); + } + } + + /** Test base for {@link ArrayAggFunction}. 
*/ + abstract static class ArrayAggFunctionTestBase + extends AggFunctionTestBase> { + + @Override + protected Class getAccClass() { + return ArrayAggFunction.ArrayAggAccumulator.class; + } + + @Override + protected Method getAccumulateFunc() throws NoSuchMethodException { + return getAggregator().getClass().getMethod("accumulate", getAccClass(), Object.class); + } + + @Override + protected Method getRetractFunc() throws NoSuchMethodException { + return getAggregator().getClass().getMethod("retract", getAccClass(), Object.class); + } + } + + /** Test base for {@link ArrayAggFunction} with number types. */ + abstract static class NumberArrayAggFunctionTestBase extends ArrayAggFunctionTestBase { + + protected abstract T getValue(String v); + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList(getValue("1"), null, getValue("-99"), getValue("3"), null), + Arrays.asList(null, null, null, null), + Arrays.asList(null, getValue("10"), null, getValue("3"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] {getValue("1"), getValue("-99"), getValue("3")}), + null, + new GenericArrayData(new Object[] {getValue("10"), getValue("3")})); + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java new file mode 100644 index 0000000000000..ece14ae9a98dc --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.aggregate; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.dataview.ListView; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; + +/** Built-in ARRAY_AGG aggregate function. 
*/ +@Internal +public final class ArrayAggFunction + extends BuiltInAggregateFunction> { + + private static final long serialVersionUID = -5860934997657147836L; + + private final transient DataType elementDataType; + + public ArrayAggFunction(LogicalType elementType) { + this.elementDataType = toInternalDataType(elementType); + } + + // -------------------------------------------------------------------------------------------- + // Planning + // -------------------------------------------------------------------------------------------- + + @Override + public List getArgumentDataTypes() { + return Collections.singletonList(elementDataType); + } + + @Override + public DataType getAccumulatorDataType() { + return DataTypes.STRUCTURED( + ArrayAggAccumulator.class, + DataTypes.FIELD("list", ListView.newListViewDataType(elementDataType.notNull())), + DataTypes.FIELD( + "retractList", ListView.newListViewDataType(elementDataType.notNull()))); + } + + @Override + public DataType getOutputDataType() { + return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); + } + + // -------------------------------------------------------------------------------------------- + // Runtime + // -------------------------------------------------------------------------------------------- + + /** Accumulator for ARRAY_AGG with retraction. 
*/ + public static class ArrayAggAccumulator { + public ListView list; + public ListView retractList; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayAggAccumulator that = (ArrayAggAccumulator) o; + return Objects.equals(list, that.list) && Objects.equals(retractList, that.retractList); + } + + @Override + public int hashCode() { + return Objects.hash(list, retractList); + } + } + + @Override + public ArrayAggAccumulator createAccumulator() { + final ArrayAggAccumulator acc = new ArrayAggAccumulator<>(); + acc.list = new ListView<>(); + acc.retractList = new ListView<>(); + return acc; + } + + public void accumulate(ArrayAggAccumulator acc, T value) throws Exception { + if (value != null) { + acc.list.add(value); + } + } + + public void retract(ArrayAggAccumulator acc, T value) throws Exception { + if (value != null) { + if (!acc.list.remove(value)) { + acc.retractList.add(value); + } + } + } + + public void merge(ArrayAggAccumulator acc, Iterable> its) + throws Exception { + for (ArrayAggAccumulator otherAcc : its) { + // merge list of acc and other + List buffer = new ArrayList<>(); + for (T element : acc.list.get()) { + buffer.add(element); + } + for (T element : otherAcc.list.get()) { + buffer.add(element); + } + // merge retract list of acc and other + List retractBuffer = new ArrayList<>(); + for (T element : acc.retractList.get()) { + retractBuffer.add(element); + } + for (T element : otherAcc.retractList.get()) { + retractBuffer.add(element); + } + + // merge list & retract list + List newRetractBuffer = new ArrayList<>(); + for (T element : retractBuffer) { + if (!buffer.remove(element)) { + newRetractBuffer.add(element); + } + } + + // update to acc + acc.list.clear(); + acc.list.addAll(buffer); + acc.retractList.clear(); + acc.retractList.addAll(newRetractBuffer); + } + } + + @Override + public ArrayData getValue(ArrayAggAccumulator acc) { 
+ try { + List accList = acc.list.getList(); + if (accList == null || accList.isEmpty()) { + // array_agg returns null rather than an empty array when there are no input rows. + return null; + } else { + return new GenericArrayData(accList.toArray()); + } + } catch (Exception e) { + throw new FlinkRuntimeException(e); + } + } + + public void resetAccumulator(ArrayAggAccumulator acc) { + acc.list.clear(); + acc.retractList.clear(); + } +} From d46b2f30d3c0a7c1d4c9928513ea9f7598291e28 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Fri, 15 Dec 2023 11:56:38 +0800 Subject: [PATCH 2/8] Rebase master and resolve compilation problems --- .../functions/BuiltInFunctionDefinitions.java | 1 + .../planner/functions/ArrayAggFunctionITCase.java | 14 +++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 9761edc0be6b2..2ab68e81cdf78 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -748,6 +748,7 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) public static final BuiltInFunctionDefinition ARRAY_AGG = BuiltInFunctionDefinition.newBuilder() .name("arrayAgg") + .sqlName("ARRAY_AGG") .kind(AGGREGATE) .outputTypeStrategy(nullableIfArgs(SpecificTypeStrategies.ARRAY)) .build(); diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java index 9bb4f17ed3376..e463bb52f39b9 100644 --- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java @@ -22,6 +22,7 @@ import org.apache.flink.types.Row; import java.util.Arrays; +import java.util.Collections; import java.util.stream.Stream; import static org.apache.flink.table.api.DataTypes.ARRAY; @@ -61,8 +62,10 @@ Stream getTestCaseSpecs() { .testResult( source -> "SELECT f0, array_agg(f1) FROM " + source + " GROUP BY f0", - source -> - source.groupBy($("f0")).select($("f0"), $("f1").arrayAgg()), + TableApiAggSpec.groupBySelect( + Collections.singletonList($("f0")), + $("f0"), + $("f1").arrayAgg()), ROW(STRING(), ARRAY(INT())), ROW(STRING(), ARRAY(INT())), Arrays.asList( @@ -76,9 +79,10 @@ Stream getTestCaseSpecs() { "SELECT f0, array_agg(DISTINCT f1) FROM " + source + " GROUP BY f0", - source -> - source.groupBy($("f0")) - .select($("f0"), $("f1").arrayAgg().distinct()), + TableApiAggSpec.groupBySelect( + Collections.singletonList($("f0")), + $("f0"), + $("f1").arrayAgg().distinct()), ROW(STRING(), ARRAY(INT())), ROW(STRING(), ARRAY(INT())), Arrays.asList( From 58bb06464f0ba7a66277661bf90a53bedbc82e90 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Mon, 18 Dec 2023 16:23:52 +0800 Subject: [PATCH 3/8] Use equalityEvaluator to compare values --- .../pyflink/table/tests/test_expression.py | 2 +- .../functions/BuiltInFunctionDefinitions.java | 5 +- .../catalog/FunctionCatalogOperatorTable.java | 17 +- .../planner/delegation/PlannerContext.java | 6 +- .../expressions/SqlAggFunctionVisitor.java | 9 +- .../bridging/BridgingSqlAggFunction.java | 27 +- .../functions/sql/FlinkSqlOperatorTable.java | 16 - .../plan/utils/AggFunctionFactory.scala | 7 - .../planner/plan/utils/AggregateUtil.scala | 28 +- .../aggfunctions/ArrayAggFunctionTest.java | 410 ------------------ .../functions/aggregate/ArrayAggFunction.java | 66 ++- 11 files 
changed, 97 insertions(+), 496 deletions(-) delete mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java diff --git a/flink-python/pyflink/table/tests/test_expression.py b/flink-python/pyflink/table/tests/test_expression.py index f8611d577c913..d187ef8347ee3 100644 --- a/flink-python/pyflink/table/tests/test_expression.py +++ b/flink-python/pyflink/table/tests/test_expression.py @@ -114,7 +114,7 @@ def test_expression(self): self.assertEqual('varPop(a)', str(expr1.var_pop)) self.assertEqual('varSamp(a)', str(expr1.var_samp)) self.assertEqual('collect(a)', str(expr1.collect)) - self.assertEqual('arrayAgg(a)', str(expr1.array_agg)) + self.assertEqual('ARRAY_AGG(a)', str(expr1.array_agg)) self.assertEqual("as(a, 'a', 'b', 'c')", str(expr1.alias('a', 'b', 'c'))) self.assertEqual('cast(a, INT)', str(expr1.cast(DataTypes.INT()))) self.assertEqual('asc(a)', str(expr1.asc)) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 2ab68e81cdf78..4dc81cb1ece08 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -747,10 +747,11 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) public static final BuiltInFunctionDefinition ARRAY_AGG = BuiltInFunctionDefinition.newBuilder() - .name("arrayAgg") - .sqlName("ARRAY_AGG") + .name("ARRAY_AGG") .kind(AGGREGATE) .outputTypeStrategy(nullableIfArgs(SpecificTypeStrategies.ARRAY)) + .runtimeClass( + "org.apache.flink.table.runtime.functions.aggregate.ArrayAggFunction") .build(); // -------------------------------------------------------------------------------------------- diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java index 7b148fdde38c1..630a047a5b450 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java @@ -32,6 +32,7 @@ import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.functions.ScalarFunctionDefinition; import org.apache.flink.table.functions.TableFunctionDefinition; +import org.apache.flink.table.planner.calcite.FlinkContext; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.calcite.RexFactory; import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction; @@ -60,20 +61,18 @@ @Internal public class FunctionCatalogOperatorTable implements SqlOperatorTable { + private final FlinkContext context; private final FunctionCatalog functionCatalog; private final DataTypeFactory dataTypeFactory; private final FlinkTypeFactory typeFactory; private final RexFactory rexFactory; - public FunctionCatalogOperatorTable( - FunctionCatalog functionCatalog, - DataTypeFactory dataTypeFactory, - FlinkTypeFactory typeFactory, - RexFactory rexFactory) { - this.functionCatalog = functionCatalog; - this.dataTypeFactory = dataTypeFactory; + public FunctionCatalogOperatorTable(FlinkContext context, FlinkTypeFactory typeFactory) { + this.context = context; + this.functionCatalog = context.getFunctionCatalog(); + this.dataTypeFactory = context.getCatalogManager().getDataTypeFactory(); this.typeFactory = typeFactory; - this.rexFactory = rexFactory; + this.rexFactory = context.getRexFactory(); } @Override @@ -153,7 +152,7 @@ private Optional 
convertToBridgingSqlFunction( || definition.getKind() == FunctionKind.TABLE_AGGREGATE) { function = BridgingSqlAggFunction.of( - dataTypeFactory, + context, typeFactory, SqlKind.OTHER_FUNCTION, resolvedFunction, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java index ff17bacf9017b..01e5948e29799 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java @@ -340,11 +340,7 @@ private SqlOperatorTable getSqlOperatorTable(CalciteConfig calciteConfig) { /** Returns builtin the operator table and external the operator for this environment. */ private SqlOperatorTable getBuiltinSqlOperatorTable() { return SqlOperatorTables.chain( - new FunctionCatalogOperatorTable( - context.getFunctionCatalog(), - context.getCatalogManager().getDataTypeFactory(), - typeFactory, - context.getRexFactory()), + new FunctionCatalogOperatorTable(context, typeFactory), FlinkSqlOperatorTable.instance(context.isBatchMode())); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java index 24cde572c7eb7..807e9acce9e05 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java @@ -32,6 +32,7 @@ import org.apache.flink.table.functions.FunctionRequirement; import org.apache.flink.table.functions.TableAggregateFunction; import 
org.apache.flink.table.functions.TableAggregateFunctionDefinition; +import org.apache.flink.table.planner.calcite.FlinkContext; import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction; import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable; import org.apache.flink.table.planner.functions.utils.AggSqlFunction; @@ -84,8 +85,6 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor paramTypes; private BridgingSqlAggFunction( + FlinkContext context, DataTypeFactory dataTypeFactory, FlinkTypeFactory typeFactory, SqlKind kind, @@ -84,6 +86,7 @@ private BridgingSqlAggFunction( createGroupOrderRequirement()); this.dataTypeFactory = dataTypeFactory; + this.context = context; this.typeFactory = typeFactory; this.resolvedFunction = resolvedFunction; this.typeInference = typeInference; @@ -93,7 +96,7 @@ private BridgingSqlAggFunction( /** * Creates an instance of a aggregating function (either a system or user-defined function). * - * @param dataTypeFactory used for creating {@link DataType} + * @param context used for accessing to flink context {@link FlinkContext} * @param typeFactory used for bridging to {@link RelDataType} * @param kind commonly used SQL standard function; use {@link SqlKind#OTHER_FUNCTION} if this * function cannot be mapped to a common function kind. 
@@ -101,7 +104,7 @@ private BridgingSqlAggFunction( * @param typeInference type inference logic */ public static BridgingSqlAggFunction of( - DataTypeFactory dataTypeFactory, + FlinkContext context, FlinkTypeFactory typeFactory, SqlKind kind, ContextResolvedFunction resolvedFunction, @@ -113,7 +116,12 @@ public static BridgingSqlAggFunction of( "Aggregating function kind expected."); return new BridgingSqlAggFunction( - dataTypeFactory, typeFactory, kind, resolvedFunction, typeInference); + context, + context.getCatalogManager().getDataTypeFactory(), + typeFactory, + kind, + resolvedFunction, + typeInference); } /** Creates an instance of a aggregate function during translation. */ @@ -124,18 +132,17 @@ public static BridgingSqlAggFunction of( final DataTypeFactory dataTypeFactory = context.getCatalogManager().getDataTypeFactory(); final TypeInference typeInference = resolvedFunction.getDefinition().getTypeInference(dataTypeFactory); - return of( - dataTypeFactory, - typeFactory, - SqlKind.OTHER_FUNCTION, - resolvedFunction, - typeInference); + return of(context, typeFactory, SqlKind.OTHER_FUNCTION, resolvedFunction, typeInference); } public DataTypeFactory getDataTypeFactory() { return dataTypeFactory; } + public FlinkContext getContext() { + return context; + } + public FlinkTypeFactory getTypeFactory() { return typeFactory; } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index 3cfaec64f6027..f43d1ff80bffe 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -37,8 +37,6 @@ import org.apache.calcite.sql.SqlPrefixOperator; import 
org.apache.calcite.sql.SqlSpecialOperator; import org.apache.calcite.sql.SqlSyntax; -import org.apache.calcite.sql.fun.SqlBasicAggFunction; -import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.OperandTypes; @@ -1141,20 +1139,6 @@ public List getAuxiliaryFunctions() { public static final SqlAggFunction SINGLE_VALUE = SqlStdOperatorTable.SINGLE_VALUE; public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; - /** - * Use the definitions in Flink instead of {@link SqlLibraryOperators#ARRAY_AGG}, because we - * ignore nulls and returns nullable ARRAY type. Order by clause like - * ARRAY_AGG(x ORDER BY x, y) for aggregate function is not supported yet, because the - * row data cannot be obtained inside the aggregate function. - */ - public static final SqlAggFunction ARRAY_AGG = - SqlBasicAggFunction.create( - SqlKind.ARRAY_AGG, - ReturnTypes.cascade( - ReturnTypes.TO_ARRAY, SqlTypeTransforms.TO_NULLABLE), - OperandTypes.ANY) - .withFunctionType(SqlFunctionCategory.SYSTEM) - .withSyntax(SqlSyntax.FUNCTION); // ARRAY OPERATORS public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = new SqlArrayConstructor(); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index c6c02673da227..861f537f6c24c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -146,9 +146,6 @@ class AggFunctionFactory( case a: SqlAggFunction if a.getKind == SqlKind.COLLECT => createCollectAggFunction(argTypes) - case a: SqlAggFunction if 
a.getKind == SqlKind.ARRAY_AGG => - createArrayAggFunction(argTypes) - case fn: SqlAggFunction if fn.getKind == SqlKind.JSON_OBJECTAGG => val onNull = fn.asInstanceOf[SqlJsonObjectAggAggFunction].getNullClause new JsonObjectAggFunction(argTypes, onNull == SqlJsonConstructorNullClause.ABSENT_ON_NULL) @@ -623,8 +620,4 @@ class AggFunctionFactory( private def createCollectAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { new CollectAggFunction(argTypes(0)) } - - private def createArrayAggFunction(types: Array[LogicalType]): UserDefinedFunction = { - new ArrayAggFunction(types(0)) - } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index aa16c4128d32d..f1ed32b7b6e19 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -25,7 +25,8 @@ import org.apache.flink.table.expressions._ import org.apache.flink.table.expressions.ExpressionUtils.extractValue import org.apache.flink.table.functions._ import org.apache.flink.table.planner.JLong -import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} +import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.DefaultExpressionEvaluatorFactory import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.functions.aggfunctions.{AvgAggFunction, CountAggFunction, Sum0AggFunction} import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction._ @@ -42,7 +43,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalR import org.apache.flink.table.planner.typeutils.DataViewUtils 
import org.apache.flink.table.planner.typeutils.LegacyDataViewUtils.useNullSerializerForStateViewFieldsFromAccType import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala -import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory +import org.apache.flink.table.planner.utils.ShortcutUtils.{unwrapContext, unwrapTypeFactory} import org.apache.flink.table.runtime.dataview.DataViewSpec import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction import org.apache.flink.table.runtime.groupwindow._ @@ -484,7 +485,7 @@ object AggregateUtil extends Enumeration { call, index, argIndexes, - factory.createAggFunction(call, index), + factory, isStateBackedDataViews, aggCallNeedRetractions(index)) } @@ -497,7 +498,7 @@ object AggregateUtil extends Enumeration { call: AggregateCall, index: Int, argIndexes: Array[Int], - udf: UserDefinedFunction, + factory: AggFunctionFactory, hasStateBackedDataViews: Boolean, needsRetraction: Boolean): AggregateInfo = call.getAggregation match { @@ -506,7 +507,7 @@ object AggregateUtil extends Enumeration { if (bridging.getDefinition.isInstanceOf[DeclarativeAggregateFunction]) { createAggregateInfoFromInternalFunction( call, - udf, + factory.createAggFunction(call, index), index, argIndexes, needsRetraction, @@ -526,14 +527,15 @@ object AggregateUtil extends Enumeration { call, index, argIndexes, - udf.asInstanceOf[ImperativeAggregateFunction[_, _]], + factory.createAggFunction(call, index).asInstanceOf[ImperativeAggregateFunction[_, _]], hasStateBackedDataViews, - needsRetraction) + needsRetraction + ) case _: SqlAggFunction => createAggregateInfoFromInternalFunction( call, - udf, + factory.createAggFunction(call, index), index, argIndexes, needsRetraction, @@ -551,6 +553,7 @@ object AggregateUtil extends Enumeration { val function = call.getAggregation.asInstanceOf[BridgingSqlAggFunction] val definition = function.getDefinition val dataTypeFactory = function.getDataTypeFactory + val 
context = function.getContext // not all information is available in the call context of aggregate functions at this location // e.g. literal information is lost because the aggregation is split into multiple operators @@ -568,14 +571,15 @@ object AggregateUtil extends Enumeration { call.getType) // create the final UDF for runtime + val classLoader = classOf[PlannerBase].getClassLoader + val tableConfig = context.getTableConfig val udf = UserDefinedFunctionHelper.createSpecializedFunction( function.getName, definition, callContext, - classOf[PlannerBase].getClassLoader, - // currently, aggregate functions have no access to FlinkContext - null, - null + classLoader, + tableConfig, + new DefaultExpressionEvaluatorFactory(tableConfig, classLoader, context.getRexFactory) ) val inference = udf.getTypeInference(dataTypeFactory) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java deleted file mode 100644 index 339c0cc5d783b..0000000000000 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java +++ /dev/null @@ -1,410 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.planner.functions.aggfunctions; - -import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.data.ArrayData; -import org.apache.flink.table.data.DecimalData; -import org.apache.flink.table.data.DecimalDataUtils; -import org.apache.flink.table.data.GenericArrayData; -import org.apache.flink.table.data.GenericRowData; -import org.apache.flink.table.data.RowData; -import org.apache.flink.table.data.StringData; -import org.apache.flink.table.functions.AggregateFunction; -import org.apache.flink.table.runtime.functions.aggregate.ArrayAggFunction; -import org.apache.flink.table.types.logical.ArrayType; -import org.apache.flink.table.types.logical.BigIntType; -import org.apache.flink.table.types.logical.BooleanType; -import org.apache.flink.table.types.logical.DecimalType; -import org.apache.flink.table.types.logical.DoubleType; -import org.apache.flink.table.types.logical.FloatType; -import org.apache.flink.table.types.logical.IntType; -import org.apache.flink.table.types.logical.RowType; -import org.apache.flink.table.types.logical.TinyIntType; -import org.apache.flink.table.types.logical.VarCharType; -import org.apache.flink.testutils.serialization.types.ShortType; -import org.apache.flink.types.RowKind; - -import org.junit.jupiter.api.Nested; - -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.List; - -/** Test case for built-in ARRAY_AGG with retraction aggregate function. */ -final class ArrayAggFunctionTest { - - /** Test for {@link TinyIntType}. 
*/ - @Nested - final class ByteArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { - - @Override - protected Byte getValue(String v) { - return Byte.valueOf(v); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.TINYINT().getLogicalType()); - } - } - - /** Test for {@link ShortType}. */ - @Nested - final class ShortArrayAggTest extends NumberArrayAggFunctionTestBase { - - @Override - protected Short getValue(String v) { - return Short.valueOf(v); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.SMALLINT().getLogicalType()); - } - } - - /** Test for {@link IntType}. */ - @Nested - final class IntArrayAggTest extends NumberArrayAggFunctionTestBase { - - @Override - protected Integer getValue(String v) { - return Integer.valueOf(v); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.INT().getLogicalType()); - } - } - - /** Test for {@link BigIntType}. */ - @Nested - final class LongArrayAggTest extends NumberArrayAggFunctionTestBase { - - @Override - protected Long getValue(String v) { - return Long.valueOf(v); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.BIGINT().getLogicalType()); - } - } - - /** Test for {@link FloatType}. */ - @Nested - final class FloatArrayAggTest extends NumberArrayAggFunctionTestBase { - - @Override - protected Float getValue(String v) { - return Float.valueOf(v); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.FLOAT().getLogicalType()); - } - } - - /** Test for {@link DoubleType}. 
*/ - @Nested - final class DoubleArrayAggTest extends NumberArrayAggFunctionTestBase { - - @Override - protected Double getValue(String v) { - return Double.valueOf(v); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.DOUBLE().getLogicalType()); - } - } - - /** Test for {@link BooleanType}. */ - @Nested - final class BooleanArrayAggTest extends ArrayAggFunctionTestBase { - - @Override - protected List> getInputValueSets() { - return Arrays.asList( - Arrays.asList(false, false, false), - Arrays.asList(true, true, true), - Arrays.asList(true, false, null, true, false, true, null), - Arrays.asList(null, null, null), - Arrays.asList(null, true)); - } - - @Override - protected List getExpectedResults() { - return Arrays.asList( - new GenericArrayData(new Object[] {false, false, false}), - new GenericArrayData(new Object[] {true, true, true}), - new GenericArrayData(new Object[] {true, false, true, false, true}), - null, - new GenericArrayData(new Object[] {true})); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.BOOLEAN().getLogicalType()); - } - } - - /** Test for {@link DecimalType}. 
*/ - @Nested - final class DecimalArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { - - private final int precision = 20; - private final int scale = 6; - - @Override - protected DecimalData getValue(String v) { - return DecimalDataUtils.castFrom(v, precision, scale); - } - - @Override - protected List> getInputValueSets() { - return Arrays.asList( - Arrays.asList( - getValue("1"), - getValue("1000.000001"), - getValue("-1"), - getValue("-999.998999"), - null, - getValue("0"), - getValue("-999.999"), - null, - getValue("999.999")), - Arrays.asList(null, null, null, null, null), - Arrays.asList(null, getValue("0"))); - } - - @Override - protected List getExpectedResults() { - return Arrays.asList( - new GenericArrayData( - new Object[] { - getValue("1"), - getValue("1000.000001"), - getValue("-1"), - getValue("-999.998999"), - getValue("0"), - getValue("-999.999"), - getValue("999.999") - }), - null, - new GenericArrayData(new Object[] {getValue("0")})); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.DECIMAL(precision, scale).getLogicalType()); - } - } - - /** Test for {@link VarCharType}. 
*/ - @Nested - final class StringArrayAggFunctionTest extends ArrayAggFunctionTestBase { - - @Override - protected List> getInputValueSets() { - return Arrays.asList( - Arrays.asList( - StringData.fromString("abc"), - StringData.fromString("def"), - StringData.fromString("ghi"), - null, - StringData.fromString("jkl"), - null, - StringData.fromString("zzz")), - Arrays.asList(null, null), - Arrays.asList(null, StringData.fromString("a")), - Arrays.asList(StringData.fromString("x"), null, StringData.fromString("e"))); - } - - @Override - protected List getExpectedResults() { - return Arrays.asList( - new GenericArrayData( - new Object[] { - StringData.fromString("abc"), - StringData.fromString("def"), - StringData.fromString("ghi"), - StringData.fromString("jkl"), - StringData.fromString("zzz") - }), - null, - new GenericArrayData(new Object[] {StringData.fromString("a")}), - new GenericArrayData( - new Object[] {StringData.fromString("x"), StringData.fromString("e")})); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.STRING().getLogicalType()); - } - } - - /** Test for {@link RowType}. */ - @Nested - final class RowDArrayAggFunctionTest extends ArrayAggFunctionTestBase { - - private RowData getValue(Integer f0, String f1) { - GenericRowData rowData = new GenericRowData(RowKind.INSERT, 2); - rowData.setField(0, f0); - rowData.setField(1, f1 == null ? 
null : StringData.fromString(f1)); - return rowData; - } - - @Override - protected List> getInputValueSets() { - return Arrays.asList( - Arrays.asList( - getValue(0, "abc"), - getValue(1, "def"), - getValue(2, "ghi"), - null, - getValue(3, "jkl"), - null, - getValue(4, "zzz")), - Arrays.asList(null, null), - Arrays.asList(null, getValue(null, "a")), - Arrays.asList(getValue(5, null), null, getValue(null, "e"))); - } - - @Override - protected List getExpectedResults() { - return Arrays.asList( - new GenericArrayData( - new Object[] { - getValue(0, "abc"), - getValue(1, "def"), - getValue(2, "ghi"), - getValue(3, "jkl"), - getValue(4, "zzz") - }), - null, - new GenericArrayData(new Object[] {getValue(null, "a")}), - new GenericArrayData(new Object[] {getValue(5, null), getValue(null, "e")})); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>( - DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()).getLogicalType()); - } - } - - /** Test for {@link ArrayType}. */ - @Nested - final class ArrayArrayAggFunctionTest extends ArrayAggFunctionTestBase { - - private ArrayData getValue(Integer... elements) { - return new GenericArrayData(elements); - } - - @Override - protected List> getInputValueSets() { - return Arrays.asList( - Arrays.asList( - getValue(0, 1, 2), - getValue(1, null), - getValue(5, 3, 4, 5), - null, - getValue(6, null, 7)), - Arrays.asList(null, null)); - } - - @Override - protected List getExpectedResults() { - return Arrays.asList( - new GenericArrayData( - new Object[] { - getValue(0, 1, 2), - getValue(1, null), - getValue(5, 3, 4, 5), - getValue(6, null, 7) - }), - null); - } - - @Override - protected AggregateFunction> - getAggregator() { - return new ArrayAggFunction<>(DataTypes.ARRAY(DataTypes.INT()).getLogicalType()); - } - } - - /** Test base for {@link ArrayAggFunction}. 
*/ - abstract static class ArrayAggFunctionTestBase - extends AggFunctionTestBase> { - - @Override - protected Class getAccClass() { - return ArrayAggFunction.ArrayAggAccumulator.class; - } - - @Override - protected Method getAccumulateFunc() throws NoSuchMethodException { - return getAggregator().getClass().getMethod("accumulate", getAccClass(), Object.class); - } - - @Override - protected Method getRetractFunc() throws NoSuchMethodException { - return getAggregator().getClass().getMethod("retract", getAccClass(), Object.class); - } - } - - /** Test base for {@link ArrayAggFunction} with number types. */ - abstract static class NumberArrayAggFunctionTestBase extends ArrayAggFunctionTestBase { - - protected abstract T getValue(String v); - - @Override - protected List> getInputValueSets() { - return Arrays.asList( - Arrays.asList(getValue("1"), null, getValue("-99"), getValue("3"), null), - Arrays.asList(null, null, null, null), - Arrays.asList(null, getValue("10"), null, getValue("3"))); - } - - @Override - protected List getExpectedResults() { - return Arrays.asList( - new GenericArrayData( - new Object[] {getValue("1"), getValue("-99"), getValue("3")}), - null, - new GenericArrayData(new Object[] {getValue("10"), getValue("3")})); - } - } -} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java index ece14ae9a98dc..2ded9ce94a33b 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java @@ -23,16 +23,21 @@ import org.apache.flink.table.api.dataview.ListView; import org.apache.flink.table.data.ArrayData; import org.apache.flink.table.data.GenericArrayData; +import 
org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.SpecializedFunction; import org.apache.flink.table.types.DataType; -import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.utils.DataTypeUtils; import org.apache.flink.util.FlinkRuntimeException; +import java.lang.invoke.MethodHandle; import java.util.ArrayList; -import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedList; import java.util.List; import java.util.Objects; -import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; +import static org.apache.flink.table.api.Expressions.$; /** Built-in ARRAY_AGG aggregate function. */ @Internal @@ -41,21 +46,34 @@ public final class ArrayAggFunction private static final long serialVersionUID = -5860934997657147836L; - private final transient DataType elementDataType; + private final SpecializedFunction.ExpressionEvaluator equalityEvaluator; - public ArrayAggFunction(LogicalType elementType) { - this.elementDataType = toInternalDataType(elementType); + private transient MethodHandle equalityHandle; + + private transient DataType elementDataType; + + public ArrayAggFunction(SpecializedFunction.SpecializedContext context) { + super(BuiltInFunctionDefinitions.ARRAY_AGG, context); + this.elementDataType = + DataTypeUtils.toInternalDataType( + context.getCallContext().getArgumentDataTypes().get(0)); + this.equalityEvaluator = + context.createEvaluator( + $("element1").isEqual($("element2")), + DataTypes.BOOLEAN(), + DataTypes.FIELD("element1", elementDataType.notNull().toInternal()), + DataTypes.FIELD("element2", elementDataType.notNull().toInternal())); + } + + @Override + public void open(FunctionContext context) throws Exception { + equalityHandle = equalityEvaluator.open(context); } // 
-------------------------------------------------------------------------------------------- // Planning // -------------------------------------------------------------------------------------------- - @Override - public List getArgumentDataTypes() { - return Collections.singletonList(elementDataType); - } - @Override public DataType getAccumulatorDataType() { return DataTypes.STRUCTURED( @@ -65,11 +83,6 @@ public DataType getAccumulatorDataType() { "retractList", ListView.newListViewDataType(elementDataType.notNull()))); } - @Override - public DataType getOutputDataType() { - return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); - } - // -------------------------------------------------------------------------------------------- // Runtime // -------------------------------------------------------------------------------------------- @@ -113,7 +126,7 @@ public void accumulate(ArrayAggAccumulator acc, T value) throws Exception { public void retract(ArrayAggAccumulator acc, T value) throws Exception { if (value != null) { - if (!acc.list.remove(value)) { + if (!remove(acc.list.get(), value)) { acc.retractList.add(value); } } @@ -123,7 +136,7 @@ public void merge(ArrayAggAccumulator acc, Iterable> i throws Exception { for (ArrayAggAccumulator otherAcc : its) { // merge list of acc and other - List buffer = new ArrayList<>(); + List buffer = new LinkedList<>(); for (T element : acc.list.get()) { buffer.add(element); } @@ -142,7 +155,7 @@ public void merge(ArrayAggAccumulator acc, Iterable> i // merge list & retract list List newRetractBuffer = new ArrayList<>(); for (T element : retractBuffer) { - if (!buffer.remove(element)) { + if (!remove(buffer, element)) { newRetractBuffer.add(element); } } @@ -155,6 +168,21 @@ public void merge(ArrayAggAccumulator acc, Iterable> i } } + private boolean remove(Iterable iterable, T value) { + try { + Iterator iterator = iterable.iterator(); + while (iterator.hasNext()) { + if ((boolean) 
equalityHandle.invoke(iterator.next(), value)) { + iterator.remove(); + return true; + } + } + return false; + } catch (Throwable t) { + throw new FlinkRuntimeException(t); + } + } + @Override public ArrayData getValue(ArrayAggAccumulator acc) { try { From ccabc46575b03f91136613f2cb683a4d815bcb99 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Tue, 16 Jan 2024 11:54:33 +0800 Subject: [PATCH 4/8] Revert changes about equalityEvaluator --- .../functions/BuiltInFunctionDefinitions.java | 2 - .../catalog/FunctionCatalogOperatorTable.java | 17 +- .../planner/delegation/PlannerContext.java | 6 +- .../expressions/SqlAggFunctionVisitor.java | 9 +- .../bridging/BridgingSqlAggFunction.java | 27 +- .../functions/sql/FlinkSqlOperatorTable.java | 17 + .../plan/utils/AggFunctionFactory.scala | 7 + .../planner/plan/utils/AggregateUtil.scala | 28 +- .../aggfunctions/ArrayAggFunctionTest.java | 410 ++++++++++++++++++ .../functions/aggregate/ArrayAggFunction.java | 66 +-- 10 files changed, 494 insertions(+), 95 deletions(-) create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 4dc81cb1ece08..704e8efac3559 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -750,8 +750,6 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) .name("ARRAY_AGG") .kind(AGGREGATE) .outputTypeStrategy(nullableIfArgs(SpecificTypeStrategies.ARRAY)) - .runtimeClass( - "org.apache.flink.table.runtime.functions.aggregate.ArrayAggFunction") .build(); // 
-------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java index 630a047a5b450..7b148fdde38c1 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java @@ -32,7 +32,6 @@ import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.functions.ScalarFunctionDefinition; import org.apache.flink.table.functions.TableFunctionDefinition; -import org.apache.flink.table.planner.calcite.FlinkContext; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.calcite.RexFactory; import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction; @@ -61,18 +60,20 @@ @Internal public class FunctionCatalogOperatorTable implements SqlOperatorTable { - private final FlinkContext context; private final FunctionCatalog functionCatalog; private final DataTypeFactory dataTypeFactory; private final FlinkTypeFactory typeFactory; private final RexFactory rexFactory; - public FunctionCatalogOperatorTable(FlinkContext context, FlinkTypeFactory typeFactory) { - this.context = context; - this.functionCatalog = context.getFunctionCatalog(); - this.dataTypeFactory = context.getCatalogManager().getDataTypeFactory(); + public FunctionCatalogOperatorTable( + FunctionCatalog functionCatalog, + DataTypeFactory dataTypeFactory, + FlinkTypeFactory typeFactory, + RexFactory rexFactory) { + this.functionCatalog = functionCatalog; + this.dataTypeFactory = dataTypeFactory; this.typeFactory = typeFactory; - this.rexFactory = context.getRexFactory(); + 
this.rexFactory = rexFactory; } @Override @@ -152,7 +153,7 @@ private Optional convertToBridgingSqlFunction( || definition.getKind() == FunctionKind.TABLE_AGGREGATE) { function = BridgingSqlAggFunction.of( - context, + dataTypeFactory, typeFactory, SqlKind.OTHER_FUNCTION, resolvedFunction, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java index 01e5948e29799..ff17bacf9017b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/delegation/PlannerContext.java @@ -340,7 +340,11 @@ private SqlOperatorTable getSqlOperatorTable(CalciteConfig calciteConfig) { /** Returns builtin the operator table and external the operator for this environment. */ private SqlOperatorTable getBuiltinSqlOperatorTable() { return SqlOperatorTables.chain( - new FunctionCatalogOperatorTable(context, typeFactory), + new FunctionCatalogOperatorTable( + context.getFunctionCatalog(), + context.getCatalogManager().getDataTypeFactory(), + typeFactory, + context.getRexFactory()), FlinkSqlOperatorTable.instance(context.isBatchMode())); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java index 807e9acce9e05..24cde572c7eb7 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java @@ -32,7 +32,6 @@ import org.apache.flink.table.functions.FunctionRequirement; import 
org.apache.flink.table.functions.TableAggregateFunction; import org.apache.flink.table.functions.TableAggregateFunctionDefinition; -import org.apache.flink.table.planner.calcite.FlinkContext; import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction; import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable; import org.apache.flink.table.planner.functions.utils.AggSqlFunction; @@ -85,6 +84,8 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor paramTypes; private BridgingSqlAggFunction( - FlinkContext context, DataTypeFactory dataTypeFactory, FlinkTypeFactory typeFactory, SqlKind kind, @@ -86,7 +84,6 @@ private BridgingSqlAggFunction( createGroupOrderRequirement()); this.dataTypeFactory = dataTypeFactory; - this.context = context; this.typeFactory = typeFactory; this.resolvedFunction = resolvedFunction; this.typeInference = typeInference; @@ -96,7 +93,7 @@ private BridgingSqlAggFunction( /** * Creates an instance of a aggregating function (either a system or user-defined function). * - * @param context used for accessing to flink context {@link FlinkContext} + * @param dataTypeFactory used for creating {@link DataType} * @param typeFactory used for bridging to {@link RelDataType} * @param kind commonly used SQL standard function; use {@link SqlKind#OTHER_FUNCTION} if this * function cannot be mapped to a common function kind. 
@@ -104,7 +101,7 @@ private BridgingSqlAggFunction( * @param typeInference type inference logic */ public static BridgingSqlAggFunction of( - FlinkContext context, + DataTypeFactory dataTypeFactory, FlinkTypeFactory typeFactory, SqlKind kind, ContextResolvedFunction resolvedFunction, @@ -116,12 +113,7 @@ public static BridgingSqlAggFunction of( "Aggregating function kind expected."); return new BridgingSqlAggFunction( - context, - context.getCatalogManager().getDataTypeFactory(), - typeFactory, - kind, - resolvedFunction, - typeInference); + dataTypeFactory, typeFactory, kind, resolvedFunction, typeInference); } /** Creates an instance of a aggregate function during translation. */ @@ -132,17 +124,18 @@ public static BridgingSqlAggFunction of( final DataTypeFactory dataTypeFactory = context.getCatalogManager().getDataTypeFactory(); final TypeInference typeInference = resolvedFunction.getDefinition().getTypeInference(dataTypeFactory); - return of(context, typeFactory, SqlKind.OTHER_FUNCTION, resolvedFunction, typeInference); + return of( + dataTypeFactory, + typeFactory, + SqlKind.OTHER_FUNCTION, + resolvedFunction, + typeInference); } public DataTypeFactory getDataTypeFactory() { return dataTypeFactory; } - public FlinkContext getContext() { - return context; - } - public FlinkTypeFactory getTypeFactory() { return typeFactory; } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index f43d1ff80bffe..d3091bfac188a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -37,6 +37,8 @@ import org.apache.calcite.sql.SqlPrefixOperator; import 
org.apache.calcite.sql.SqlSpecialOperator; import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.fun.SqlBasicAggFunction; +import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.OperandTypes; @@ -1140,6 +1142,21 @@ public List getAuxiliaryFunctions() { public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; + /** + * Use the definitions in Flink instead of {@link SqlLibraryOperators#ARRAY_AGG}, because we + * ignore nulls and returns nullable ARRAY type. Order by clause like + * ARRAY_AGG(x ORDER BY x, y) for aggregate function is not supported yet, because the + * row data cannot be obtained inside the aggregate function. + */ + public static final SqlAggFunction ARRAY_AGG = + SqlBasicAggFunction.create( + SqlKind.ARRAY_AGG, + ReturnTypes.cascade( + ReturnTypes.TO_ARRAY, SqlTypeTransforms.TO_NULLABLE), + OperandTypes.ANY) + .withFunctionType(SqlFunctionCategory.SYSTEM) + .withSyntax(SqlSyntax.FUNCTION); + // ARRAY OPERATORS public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = new SqlArrayConstructor(); public static final SqlOperator ELEMENT = SqlStdOperatorTable.ELEMENT; diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index 861f537f6c24c..c6c02673da227 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -146,6 +146,9 @@ class AggFunctionFactory( case a: SqlAggFunction if a.getKind == SqlKind.COLLECT => createCollectAggFunction(argTypes) + case a: SqlAggFunction if a.getKind 
== SqlKind.ARRAY_AGG => + createArrayAggFunction(argTypes) + case fn: SqlAggFunction if fn.getKind == SqlKind.JSON_OBJECTAGG => val onNull = fn.asInstanceOf[SqlJsonObjectAggAggFunction].getNullClause new JsonObjectAggFunction(argTypes, onNull == SqlJsonConstructorNullClause.ABSENT_ON_NULL) @@ -620,4 +623,8 @@ class AggFunctionFactory( private def createCollectAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { new CollectAggFunction(argTypes(0)) } + + private def createArrayAggFunction(types: Array[LogicalType]): UserDefinedFunction = { + new ArrayAggFunction(types(0)) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index f1ed32b7b6e19..aa16c4128d32d 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -25,8 +25,7 @@ import org.apache.flink.table.expressions._ import org.apache.flink.table.expressions.ExpressionUtils.extractValue import org.apache.flink.table.functions._ import org.apache.flink.table.planner.JLong -import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} -import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.DefaultExpressionEvaluatorFactory +import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.functions.aggfunctions.{AvgAggFunction, CountAggFunction, Sum0AggFunction} import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction._ @@ -43,7 +42,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalR import org.apache.flink.table.planner.typeutils.DataViewUtils import 
org.apache.flink.table.planner.typeutils.LegacyDataViewUtils.useNullSerializerForStateViewFieldsFromAccType import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala -import org.apache.flink.table.planner.utils.ShortcutUtils.{unwrapContext, unwrapTypeFactory} +import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory import org.apache.flink.table.runtime.dataview.DataViewSpec import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction import org.apache.flink.table.runtime.groupwindow._ @@ -485,7 +484,7 @@ object AggregateUtil extends Enumeration { call, index, argIndexes, - factory, + factory.createAggFunction(call, index), isStateBackedDataViews, aggCallNeedRetractions(index)) } @@ -498,7 +497,7 @@ object AggregateUtil extends Enumeration { call: AggregateCall, index: Int, argIndexes: Array[Int], - factory: AggFunctionFactory, + udf: UserDefinedFunction, hasStateBackedDataViews: Boolean, needsRetraction: Boolean): AggregateInfo = call.getAggregation match { @@ -507,7 +506,7 @@ object AggregateUtil extends Enumeration { if (bridging.getDefinition.isInstanceOf[DeclarativeAggregateFunction]) { createAggregateInfoFromInternalFunction( call, - factory.createAggFunction(call, index), + udf, index, argIndexes, needsRetraction, @@ -527,15 +526,14 @@ object AggregateUtil extends Enumeration { call, index, argIndexes, - factory.createAggFunction(call, index).asInstanceOf[ImperativeAggregateFunction[_, _]], + udf.asInstanceOf[ImperativeAggregateFunction[_, _]], hasStateBackedDataViews, - needsRetraction - ) + needsRetraction) case _: SqlAggFunction => createAggregateInfoFromInternalFunction( call, - factory.createAggFunction(call, index), + udf, index, argIndexes, needsRetraction, @@ -553,7 +551,6 @@ object AggregateUtil extends Enumeration { val function = call.getAggregation.asInstanceOf[BridgingSqlAggFunction] val definition = function.getDefinition val dataTypeFactory = function.getDataTypeFactory - val context = 
function.getContext // not all information is available in the call context of aggregate functions at this location // e.g. literal information is lost because the aggregation is split into multiple operators @@ -571,15 +568,14 @@ object AggregateUtil extends Enumeration { call.getType) // create the final UDF for runtime - val classLoader = classOf[PlannerBase].getClassLoader - val tableConfig = context.getTableConfig val udf = UserDefinedFunctionHelper.createSpecializedFunction( function.getName, definition, callContext, - classLoader, - tableConfig, - new DefaultExpressionEvaluatorFactory(tableConfig, classLoader, context.getRexFactory) + classOf[PlannerBase].getClassLoader, + // currently, aggregate functions have no access to FlinkContext + null, + null ) val inference = udf.getTypeInference(dataTypeFactory) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java new file mode 100644 index 0000000000000..339c0cc5d783b --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.aggfunctions; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.DecimalDataUtils; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.table.runtime.functions.aggregate.ArrayAggFunction; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.BooleanType; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.DoubleType; +import org.apache.flink.table.types.logical.FloatType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.TinyIntType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.testutils.serialization.types.ShortType; +import org.apache.flink.types.RowKind; + +import org.junit.jupiter.api.Nested; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; + +/** Test case for built-in ARRAY_AGG with retraction aggregate function. */ +final class ArrayAggFunctionTest { + + /** Test for {@link TinyIntType}. 
*/ + @Nested + final class ByteArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Byte getValue(String v) { + return Byte.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.TINYINT().getLogicalType()); + } + } + + /** Test for {@link ShortType}. */ + @Nested + final class ShortArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Short getValue(String v) { + return Short.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.SMALLINT().getLogicalType()); + } + } + + /** Test for {@link IntType}. */ + @Nested + final class IntArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Integer getValue(String v) { + return Integer.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.INT().getLogicalType()); + } + } + + /** Test for {@link BigIntType}. */ + @Nested + final class LongArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Long getValue(String v) { + return Long.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.BIGINT().getLogicalType()); + } + } + + /** Test for {@link FloatType}. */ + @Nested + final class FloatArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Float getValue(String v) { + return Float.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.FLOAT().getLogicalType()); + } + } + + /** Test for {@link DoubleType}. 
*/ + @Nested + final class DoubleArrayAggTest extends NumberArrayAggFunctionTestBase { + + @Override + protected Double getValue(String v) { + return Double.valueOf(v); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.DOUBLE().getLogicalType()); + } + } + + /** Test for {@link BooleanType}. */ + @Nested + final class BooleanArrayAggTest extends ArrayAggFunctionTestBase { + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList(false, false, false), + Arrays.asList(true, true, true), + Arrays.asList(true, false, null, true, false, true, null), + Arrays.asList(null, null, null), + Arrays.asList(null, true)); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData(new Object[] {false, false, false}), + new GenericArrayData(new Object[] {true, true, true}), + new GenericArrayData(new Object[] {true, false, true, false, true}), + null, + new GenericArrayData(new Object[] {true})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.BOOLEAN().getLogicalType()); + } + } + + /** Test for {@link DecimalType}. 
*/ + @Nested + final class DecimalArrayAggFunctionTest extends NumberArrayAggFunctionTestBase { + + private final int precision = 20; + private final int scale = 6; + + @Override + protected DecimalData getValue(String v) { + return DecimalDataUtils.castFrom(v, precision, scale); + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue("1"), + getValue("1000.000001"), + getValue("-1"), + getValue("-999.998999"), + null, + getValue("0"), + getValue("-999.999"), + null, + getValue("999.999")), + Arrays.asList(null, null, null, null, null), + Arrays.asList(null, getValue("0"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue("1"), + getValue("1000.000001"), + getValue("-1"), + getValue("-999.998999"), + getValue("0"), + getValue("-999.999"), + getValue("999.999") + }), + null, + new GenericArrayData(new Object[] {getValue("0")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.DECIMAL(precision, scale).getLogicalType()); + } + } + + /** Test for {@link VarCharType}. 
*/ + @Nested + final class StringArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + StringData.fromString("abc"), + StringData.fromString("def"), + StringData.fromString("ghi"), + null, + StringData.fromString("jkl"), + null, + StringData.fromString("zzz")), + Arrays.asList(null, null), + Arrays.asList(null, StringData.fromString("a")), + Arrays.asList(StringData.fromString("x"), null, StringData.fromString("e"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + StringData.fromString("abc"), + StringData.fromString("def"), + StringData.fromString("ghi"), + StringData.fromString("jkl"), + StringData.fromString("zzz") + }), + null, + new GenericArrayData(new Object[] {StringData.fromString("a")}), + new GenericArrayData( + new Object[] {StringData.fromString("x"), StringData.fromString("e")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.STRING().getLogicalType()); + } + } + + /** Test for {@link RowType}. */ + @Nested + final class RowDArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + private RowData getValue(Integer f0, String f1) { + GenericRowData rowData = new GenericRowData(RowKind.INSERT, 2); + rowData.setField(0, f0); + rowData.setField(1, f1 == null ? 
null : StringData.fromString(f1)); + return rowData; + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue(0, "abc"), + getValue(1, "def"), + getValue(2, "ghi"), + null, + getValue(3, "jkl"), + null, + getValue(4, "zzz")), + Arrays.asList(null, null), + Arrays.asList(null, getValue(null, "a")), + Arrays.asList(getValue(5, null), null, getValue(null, "e"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue(0, "abc"), + getValue(1, "def"), + getValue(2, "ghi"), + getValue(3, "jkl"), + getValue(4, "zzz") + }), + null, + new GenericArrayData(new Object[] {getValue(null, "a")}), + new GenericArrayData(new Object[] {getValue(5, null), getValue(null, "e")})); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>( + DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()).getLogicalType()); + } + } + + /** Test for {@link ArrayType}. */ + @Nested + final class ArrayArrayAggFunctionTest extends ArrayAggFunctionTestBase { + + private ArrayData getValue(Integer... elements) { + return new GenericArrayData(elements); + } + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList( + getValue(0, 1, 2), + getValue(1, null), + getValue(5, 3, 4, 5), + null, + getValue(6, null, 7)), + Arrays.asList(null, null)); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] { + getValue(0, 1, 2), + getValue(1, null), + getValue(5, 3, 4, 5), + getValue(6, null, 7) + }), + null); + } + + @Override + protected AggregateFunction> + getAggregator() { + return new ArrayAggFunction<>(DataTypes.ARRAY(DataTypes.INT()).getLogicalType()); + } + } + + /** Test base for {@link ArrayAggFunction}. 
*/ + abstract static class ArrayAggFunctionTestBase + extends AggFunctionTestBase> { + + @Override + protected Class getAccClass() { + return ArrayAggFunction.ArrayAggAccumulator.class; + } + + @Override + protected Method getAccumulateFunc() throws NoSuchMethodException { + return getAggregator().getClass().getMethod("accumulate", getAccClass(), Object.class); + } + + @Override + protected Method getRetractFunc() throws NoSuchMethodException { + return getAggregator().getClass().getMethod("retract", getAccClass(), Object.class); + } + } + + /** Test base for {@link ArrayAggFunction} with number types. */ + abstract static class NumberArrayAggFunctionTestBase extends ArrayAggFunctionTestBase { + + protected abstract T getValue(String v); + + @Override + protected List> getInputValueSets() { + return Arrays.asList( + Arrays.asList(getValue("1"), null, getValue("-99"), getValue("3"), null), + Arrays.asList(null, null, null, null), + Arrays.asList(null, getValue("10"), null, getValue("3"))); + } + + @Override + protected List getExpectedResults() { + return Arrays.asList( + new GenericArrayData( + new Object[] {getValue("1"), getValue("-99"), getValue("3")}), + null, + new GenericArrayData(new Object[] {getValue("10"), getValue("3")})); + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java index 2ded9ce94a33b..ece14ae9a98dc 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java @@ -23,21 +23,16 @@ import org.apache.flink.table.api.dataview.ListView; import org.apache.flink.table.data.ArrayData; import org.apache.flink.table.data.GenericArrayData; -import 
org.apache.flink.table.functions.BuiltInFunctionDefinitions; -import org.apache.flink.table.functions.FunctionContext; -import org.apache.flink.table.functions.SpecializedFunction; import org.apache.flink.table.types.DataType; -import org.apache.flink.table.types.utils.DataTypeUtils; +import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.util.FlinkRuntimeException; -import java.lang.invoke.MethodHandle; import java.util.ArrayList; -import java.util.Iterator; -import java.util.LinkedList; +import java.util.Collections; import java.util.List; import java.util.Objects; -import static org.apache.flink.table.api.Expressions.$; +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; /** Built-in ARRAY_AGG aggregate function. */ @Internal @@ -46,34 +41,21 @@ public final class ArrayAggFunction private static final long serialVersionUID = -5860934997657147836L; - private final SpecializedFunction.ExpressionEvaluator equalityEvaluator; + private final transient DataType elementDataType; - private transient MethodHandle equalityHandle; - - private transient DataType elementDataType; - - public ArrayAggFunction(SpecializedFunction.SpecializedContext context) { - super(BuiltInFunctionDefinitions.ARRAY_AGG, context); - this.elementDataType = - DataTypeUtils.toInternalDataType( - context.getCallContext().getArgumentDataTypes().get(0)); - this.equalityEvaluator = - context.createEvaluator( - $("element1").isEqual($("element2")), - DataTypes.BOOLEAN(), - DataTypes.FIELD("element1", elementDataType.notNull().toInternal()), - DataTypes.FIELD("element2", elementDataType.notNull().toInternal())); - } - - @Override - public void open(FunctionContext context) throws Exception { - equalityHandle = equalityEvaluator.open(context); + public ArrayAggFunction(LogicalType elementType) { + this.elementDataType = toInternalDataType(elementType); } // 
-------------------------------------------------------------------------------------------- // Planning // -------------------------------------------------------------------------------------------- + @Override + public List getArgumentDataTypes() { + return Collections.singletonList(elementDataType); + } + @Override public DataType getAccumulatorDataType() { return DataTypes.STRUCTURED( @@ -83,6 +65,11 @@ public DataType getAccumulatorDataType() { "retractList", ListView.newListViewDataType(elementDataType.notNull()))); } + @Override + public DataType getOutputDataType() { + return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); + } + // -------------------------------------------------------------------------------------------- // Runtime // -------------------------------------------------------------------------------------------- @@ -126,7 +113,7 @@ public void accumulate(ArrayAggAccumulator acc, T value) throws Exception { public void retract(ArrayAggAccumulator acc, T value) throws Exception { if (value != null) { - if (!remove(acc.list.get(), value)) { + if (!acc.list.remove(value)) { acc.retractList.add(value); } } @@ -136,7 +123,7 @@ public void merge(ArrayAggAccumulator acc, Iterable> i throws Exception { for (ArrayAggAccumulator otherAcc : its) { // merge list of acc and other - List buffer = new LinkedList<>(); + List buffer = new ArrayList<>(); for (T element : acc.list.get()) { buffer.add(element); } @@ -155,7 +142,7 @@ public void merge(ArrayAggAccumulator acc, Iterable> i // merge list & retract list List newRetractBuffer = new ArrayList<>(); for (T element : retractBuffer) { - if (!remove(buffer, element)) { + if (!buffer.remove(element)) { newRetractBuffer.add(element); } } @@ -168,21 +155,6 @@ public void merge(ArrayAggAccumulator acc, Iterable> i } } - private boolean remove(Iterable iterable, T value) { - try { - Iterator iterator = iterable.iterator(); - while (iterator.hasNext()) { - if ((boolean) 
equalityHandle.invoke(iterator.next(), value)) { - iterator.remove(); - return true; - } - } - return false; - } catch (Throwable t) { - throw new FlinkRuntimeException(t); - } - } - @Override public ArrayData getValue(ArrayAggAccumulator acc) { try { From 96c5c2c2b7b7c9ed908f5da3efdf29a6f1b63347 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Wed, 17 Jan 2024 10:17:52 +0800 Subject: [PATCH 5/8] Support IGNORE NULLS and RESPECT NULLS --- docs/data/sql_functions.yml | 8 ++-- docs/data/sql_functions_zh.yml | 6 +-- .../functions/sql/FlinkSqlOperatorTable.java | 5 +- .../plan/utils/AggFunctionFactory.scala | 8 ++-- .../planner/plan/utils/AggregateUtil.scala | 2 +- .../functions/ArrayAggFunctionITCase.java | 13 ++--- .../aggfunctions/ArrayAggFunctionTest.java | 23 ++++----- .../functions/aggregate/ArrayAggFunction.java | 47 +++++++++++++------ 8 files changed, 65 insertions(+), 47 deletions(-) diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index 46b97aa2ef77c..5933cd1877823 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -1059,12 +1059,12 @@ aggregate: Divides the rows for each window partition into `n` buckets ranging from 1 to at most `n`. If the number of rows in the window partition doesn't divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket. For example, with 6 rows and 4 buckets, the bucket values would be as follows: 1 1 2 2 3 4 - - sql: ARRAY_AGG([ ALL | DISTINCT ] expression) + - sql: ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ]) table: FIELD.arrayAgg description: | - By default or with keyword ALL, return an array that concatenates the input rows - and returns NULL if there are no input rows. - NULL values will be ignored. Use DISTINCT for one unique instance of each value. 
+ By default or with keyword `ALL` and, return an array that concatenates the input rows + and returns `NULL` if there are no input rows. Use `DISTINCT` for one unique instance of each value. + By default null values are respected, use `IGNORE NULLS` to skip null values. - sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | diff --git a/docs/data/sql_functions_zh.yml b/docs/data/sql_functions_zh.yml index 987c2287cf398..f3d8eb6399b70 100644 --- a/docs/data/sql_functions_zh.yml +++ b/docs/data/sql_functions_zh.yml @@ -1181,11 +1181,11 @@ aggregate: 将窗口分区中的所有数据按照顺序划分为 n 个分组,返回分配给各行数据的分组编号(从 1 开始,最大为 n)。 如果不能均匀划分为 n 个分组,则剩余值从第 1 个分组开始,为每一分组分配一个。 比如某个窗口分区有 6 行数据,划分为 4 个分组,则各行的分组编号为:1,1,2,2,3,4。 - - sql: ARRAY_AGG([ ALL | DISTINCT ] expression) + - sql: ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ]) table: FIELD.arrayAgg description: | - 默认情况下或使用关键字ALL,返回输入行中表达式所组成的数组,并且如果没有输入行,则返回 `NULL`。 - `NULL` 值将被忽略。使用 `DISTINCT` 则对所有值去重后计算。 + 默认情况下或使用关键字ALL,返回输入行中表达式所组成的数组,并且如果没有输入行,则返回 `NULL`。使用 `DISTINCT` 则对所有值去重后计算。 + 默认情况下`NULL` 值不会被忽略,使用 `IGNORE NULLS` 忽略 `NULL` 值。 - sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index d3091bfac188a..5d44029299880 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -1155,7 +1155,8 @@ public List getAuxiliaryFunctions() { ReturnTypes.TO_ARRAY, 
SqlTypeTransforms.TO_NULLABLE), OperandTypes.ANY) .withFunctionType(SqlFunctionCategory.SYSTEM) - .withSyntax(SqlSyntax.FUNCTION); + .withSyntax(SqlSyntax.FUNCTION) + .withAllowsNullTreatment(true); // ARRAY OPERATORS public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = new SqlArrayConstructor(); @@ -1171,6 +1172,8 @@ public List getAuxiliaryFunctions() { // SPECIAL OPERATORS public static final SqlOperator MULTISET_VALUE = SqlStdOperatorTable.MULTISET_VALUE; public static final SqlOperator ROW = SqlStdOperatorTable.ROW; + public static final SqlOperator IGNORE_NULLS = SqlStdOperatorTable.IGNORE_NULLS; + public static final SqlOperator RESPECT_NULLS = SqlStdOperatorTable.RESPECT_NULLS; public static final SqlOperator OVERLAPS = SqlStdOperatorTable.OVERLAPS; public static final SqlOperator LITERAL_CHAIN = SqlStdOperatorTable.LITERAL_CHAIN; public static final SqlOperator BETWEEN = SqlStdOperatorTable.BETWEEN; diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index c6c02673da227..4ecd43638638f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -147,7 +147,7 @@ class AggFunctionFactory( createCollectAggFunction(argTypes) case a: SqlAggFunction if a.getKind == SqlKind.ARRAY_AGG => - createArrayAggFunction(argTypes) + createArrayAggFunction(argTypes, call.ignoreNulls) case fn: SqlAggFunction if fn.getKind == SqlKind.JSON_OBJECTAGG => val onNull = fn.asInstanceOf[SqlJsonObjectAggAggFunction].getNullClause @@ -624,7 +624,9 @@ class AggFunctionFactory( new CollectAggFunction(argTypes(0)) } - private def createArrayAggFunction(types: Array[LogicalType]): UserDefinedFunction = { - new 
ArrayAggFunction(types(0)) + private def createArrayAggFunction( + types: Array[LogicalType], + ignoreNulls: Boolean): UserDefinedFunction = { + new ArrayAggFunction(types(0), ignoreNulls) } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index aa16c4128d32d..66850f7f9d04d 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -849,7 +849,7 @@ object AggregateUtil extends Enumeration { call.getAggregation, false, false, - false, + call.ignoreNulls, call.getArgList, -1, // remove filterArg null, diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java index e463bb52f39b9..f97b11f565982 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java @@ -71,19 +71,14 @@ Stream getTestCaseSpecs() { Arrays.asList( Row.of("A", new Integer[] {1, 2}), Row.of("B", new Integer[] {2, 2, 3}), - Row.of("C", new Integer[] {3}), - Row.of("D", null), + Row.of("C", new Integer[] {3, null}), + Row.of("D", new Integer[] {null}), Row.of("E", new Integer[] {6}))) - .testResult( + .testSqlResult( source -> - "SELECT f0, array_agg(DISTINCT f1) FROM " + "SELECT f0, array_agg(DISTINCT f1 IGNORE NULLS) FROM " + source + " GROUP BY f0", - TableApiAggSpec.groupBySelect( - Collections.singletonList($("f0")), - $("f0"), - $("f1").arrayAgg().distinct()), - 
ROW(STRING(), ARRAY(INT())), ROW(STRING(), ARRAY(INT())), Arrays.asList( Row.of("A", new Integer[] {1, 2}), diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java index 339c0cc5d783b..0fc10ee48d7c1 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/ArrayAggFunctionTest.java @@ -62,7 +62,7 @@ protected Byte getValue(String v) { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.TINYINT().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.TINYINT().getLogicalType(), true); } } @@ -78,7 +78,7 @@ protected Short getValue(String v) { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.SMALLINT().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.SMALLINT().getLogicalType(), true); } } @@ -94,7 +94,7 @@ protected Integer getValue(String v) { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.INT().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.INT().getLogicalType(), true); } } @@ -110,7 +110,7 @@ protected Long getValue(String v) { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.BIGINT().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.BIGINT().getLogicalType(), true); } } @@ -126,7 +126,7 @@ protected Float getValue(String v) { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.FLOAT().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.FLOAT().getLogicalType(), true); } } @@ 
-142,7 +142,7 @@ protected Double getValue(String v) { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.DOUBLE().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.DOUBLE().getLogicalType(), true); } } @@ -173,7 +173,7 @@ protected List getExpectedResults() { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.BOOLEAN().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.BOOLEAN().getLogicalType(), true); } } @@ -226,7 +226,8 @@ protected List getExpectedResults() { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.DECIMAL(precision, scale).getLogicalType()); + return new ArrayAggFunction<>( + DataTypes.DECIMAL(precision, scale).getLogicalType(), true); } } @@ -270,7 +271,7 @@ protected List getExpectedResults() { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.STRING().getLogicalType()); + return new ArrayAggFunction<>(DataTypes.STRING().getLogicalType(), true); } } @@ -321,7 +322,7 @@ protected List getExpectedResults() { protected AggregateFunction> getAggregator() { return new ArrayAggFunction<>( - DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()).getLogicalType()); + DataTypes.ROW(DataTypes.INT(), DataTypes.STRING()).getLogicalType(), true); } } @@ -361,7 +362,7 @@ protected List getExpectedResults() { @Override protected AggregateFunction> getAggregator() { - return new ArrayAggFunction<>(DataTypes.ARRAY(DataTypes.INT()).getLogicalType()); + return new ArrayAggFunction<>(DataTypes.ARRAY(DataTypes.INT()).getLogicalType(), true); } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java index ece14ae9a98dc..71364c48f8ee5 100644 --- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java @@ -19,16 +19,19 @@ package org.apache.flink.table.runtime.functions.aggregate; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.dataview.ListView; import org.apache.flink.table.data.ArrayData; import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.typeutils.LinkedListSerializer; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.util.FlinkRuntimeException; import java.util.ArrayList; import java.util.Collections; +import java.util.LinkedList; import java.util.List; import java.util.Objects; @@ -43,8 +46,11 @@ public final class ArrayAggFunction private final transient DataType elementDataType; - public ArrayAggFunction(LogicalType elementType) { + private final boolean ignoreNulls; + + public ArrayAggFunction(LogicalType elementType, boolean ignoreNulls) { this.elementDataType = toInternalDataType(elementType); + this.ignoreNulls = ignoreNulls; } // -------------------------------------------------------------------------------------------- @@ -58,11 +64,11 @@ public List getArgumentDataTypes() { @Override public DataType getAccumulatorDataType() { + DataType linkedListType = getLinkedListType(); return DataTypes.STRUCTURED( ArrayAggAccumulator.class, - DataTypes.FIELD("list", ListView.newListViewDataType(elementDataType.notNull())), - DataTypes.FIELD( - "retractList", ListView.newListViewDataType(elementDataType.notNull()))); + DataTypes.FIELD("list", linkedListType), + DataTypes.FIELD("retractList", 
linkedListType)); } @Override @@ -70,14 +76,21 @@ public DataType getOutputDataType() { return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); } + @SuppressWarnings({"unchecked", "rawtypes"}) + private DataType getLinkedListType() { + TypeSerializer serializer = InternalSerializers.create(elementDataType.getLogicalType()); + return DataTypes.RAW( + LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer)); + } + // -------------------------------------------------------------------------------------------- // Runtime // -------------------------------------------------------------------------------------------- /** Accumulator for ARRAY_AGG with retraction. */ public static class ArrayAggAccumulator { - public ListView list; - public ListView retractList; + public LinkedList list; + public LinkedList retractList; @Override public boolean equals(Object o) { @@ -100,13 +113,17 @@ public int hashCode() { @Override public ArrayAggAccumulator createAccumulator() { final ArrayAggAccumulator acc = new ArrayAggAccumulator<>(); - acc.list = new ListView<>(); - acc.retractList = new ListView<>(); + acc.list = new LinkedList<>(); + acc.retractList = new LinkedList<>(); return acc; } public void accumulate(ArrayAggAccumulator acc, T value) throws Exception { - if (value != null) { + if (value == null) { + if (!ignoreNulls) { + acc.list.add(null); + } + } else { acc.list.add(value); } } @@ -124,18 +141,18 @@ public void merge(ArrayAggAccumulator acc, Iterable> i for (ArrayAggAccumulator otherAcc : its) { // merge list of acc and other List buffer = new ArrayList<>(); - for (T element : acc.list.get()) { + for (T element : acc.list) { buffer.add(element); } - for (T element : otherAcc.list.get()) { + for (T element : otherAcc.list) { buffer.add(element); } // merge retract list of acc and other List retractBuffer = new ArrayList<>(); - for (T element : acc.retractList.get()) { + for (T element : acc.retractList) { retractBuffer.add(element); } - for 
(T element : otherAcc.retractList.get()) { + for (T element : otherAcc.retractList) { retractBuffer.add(element); } @@ -158,7 +175,7 @@ public void merge(ArrayAggAccumulator acc, Iterable> i @Override public ArrayData getValue(ArrayAggAccumulator acc) { try { - List accList = acc.list.getList(); + List accList = acc.list; if (accList == null || accList.isEmpty()) { // array_agg returns null rather than an empty array when there are no input rows. return null; From a7e114f594da23c094768c9b507ed70f4d83f176 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Thu, 18 Jan 2024 10:16:31 +0800 Subject: [PATCH 6/8] Address Sergey's comments --- docs/data/sql_functions.yml | 1 + docs/data/sql_functions_zh.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index 5933cd1877823..a44e839ca8411 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -1065,6 +1065,7 @@ aggregate: By default or with keyword `ALL` and, return an array that concatenates the input rows and returns `NULL` if there are no input rows. Use `DISTINCT` for one unique instance of each value. By default null values are respected, use `IGNORE NULLS` to skip null values. + The `ORDER BY` clause is currently not supported. 
- sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | diff --git a/docs/data/sql_functions_zh.yml b/docs/data/sql_functions_zh.yml index f3d8eb6399b70..b0ba79f12b3e1 100644 --- a/docs/data/sql_functions_zh.yml +++ b/docs/data/sql_functions_zh.yml @@ -1186,6 +1186,7 @@ aggregate: description: | 默认情况下或使用关键字ALL,返回输入行中表达式所组成的数组,并且如果没有输入行,则返回 `NULL`。使用 `DISTINCT` 则对所有值去重后计算。 默认情况下`NULL` 值不会被忽略,使用 `IGNORE NULLS` 忽略 `NULL` 值。 + 目前尚不支持 `ORDER BY` 子句。 - sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) description: | From 01051a2b9e2e55c504cf7d24794180706831cfb0 Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Thu, 18 Jan 2024 19:43:22 +0800 Subject: [PATCH 7/8] Address comments --- docs/data/sql_functions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index a44e839ca8411..9692db6994b97 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -1064,7 +1064,7 @@ aggregate: description: | By default or with keyword `ALL` and, return an array that concatenates the input rows and returns `NULL` if there are no input rows. Use `DISTINCT` for one unique instance of each value. - By default null values are respected, use `IGNORE NULLS` to skip null values. + By default `NULL` values are respected, use `IGNORE NULLS` to skip `NULL` values. The `ORDER BY` clause is currently not supported. 
- sql: JSON_OBJECTAGG([KEY] key VALUE value [ { NULL | ABSENT } ON NULL ]) table: jsonObjectAgg(JsonOnNull, keyExpression, valueExpression) From d898519a24a365ddde33c5e6811f4df99bbe9aef Mon Sep 17 00:00:00 2001 From: Jiabao Sun Date: Tue, 6 Feb 2024 16:06:40 +0800 Subject: [PATCH 8/8] Address Dawid's comments --- .../functions/sql/FlinkSqlOperatorTable.java | 6 +-- .../functions/ArrayAggFunctionITCase.java | 3 +- .../functions/aggregate/ArrayAggFunction.java | 49 ++++++------------- 3 files changed, 20 insertions(+), 38 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index 5d44029299880..628b6bb804f6d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -1144,9 +1144,9 @@ public List getAuxiliaryFunctions() { /** * Use the definitions in Flink instead of {@link SqlLibraryOperators#ARRAY_AGG}, because we - * ignore nulls and returns nullable ARRAY type. Order by clause like - * ARRAY_AGG(x ORDER BY x, y) for aggregate function is not supported yet, because the - * row data cannot be obtained inside the aggregate function. + * return nullable ARRAY type. Order by clause like ARRAY_AGG(x ORDER BY x, y) for + * aggregate function is not supported yet, because the row data cannot be obtained inside the + * aggregate function. 
*/ public static final SqlAggFunction ARRAY_AGG = SqlBasicAggFunction.create( diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java index f97b11f565982..bc849f4438913 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java @@ -53,6 +53,7 @@ Stream getTestCaseSpecs() { Row.ofKind(INSERT, "B", 3), Row.ofKind(INSERT, "C", 3), Row.ofKind(INSERT, "C", null), + Row.ofKind(DELETE, "C", null), Row.ofKind(INSERT, "D", null), Row.ofKind(INSERT, "E", 4), Row.ofKind(INSERT, "E", 5), @@ -71,7 +72,7 @@ Stream getTestCaseSpecs() { Arrays.asList( Row.of("A", new Integer[] {1, 2}), Row.of("B", new Integer[] {2, 2, 3}), - Row.of("C", new Integer[] {3, null}), + Row.of("C", new Integer[] {3}), Row.of("D", new Integer[] {null}), Row.of("E", new Integer[] {6}))) .testSqlResult( diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java index 71364c48f8ee5..126b483ea11ee 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java @@ -29,7 +29,6 @@ import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.util.FlinkRuntimeException; -import java.util.ArrayList; import java.util.Collections; import java.util.LinkedList; import java.util.List; @@ -119,17 +118,13 @@ public ArrayAggAccumulator 
createAccumulator() { } public void accumulate(ArrayAggAccumulator acc, T value) throws Exception { - if (value == null) { - if (!ignoreNulls) { - acc.list.add(null); - } - } else { + if (value != null || !ignoreNulls) { acc.list.add(value); } } public void retract(ArrayAggAccumulator acc, T value) throws Exception { - if (value != null) { + if (value != null || !ignoreNulls) { if (!acc.list.remove(value)) { acc.retractList.add(value); } @@ -138,38 +133,24 @@ public void retract(ArrayAggAccumulator acc, T value) throws Exception { public void merge(ArrayAggAccumulator acc, Iterable> its) throws Exception { + List newRetractBuffer = new LinkedList<>(); for (ArrayAggAccumulator otherAcc : its) { - // merge list of acc and other - List buffer = new ArrayList<>(); - for (T element : acc.list) { - buffer.add(element); - } - for (T element : otherAcc.list) { - buffer.add(element); - } - // merge retract list of acc and other - List retractBuffer = new ArrayList<>(); - for (T element : acc.retractList) { - retractBuffer.add(element); - } - for (T element : otherAcc.retractList) { - retractBuffer.add(element); + if (!otherAcc.list.iterator().hasNext() && !otherAcc.retractList.iterator().hasNext()) { + // otherAcc is empty, skip it + continue; } + acc.list.addAll(otherAcc.list); + acc.retractList.addAll(otherAcc.retractList); + } - // merge list & retract list - List newRetractBuffer = new ArrayList<>(); - for (T element : retractBuffer) { - if (!buffer.remove(element)) { - newRetractBuffer.add(element); - } + for (T element : acc.retractList) { + if (!acc.list.remove(element)) { + newRetractBuffer.add(element); } - - // update to acc - acc.list.clear(); - acc.list.addAll(buffer); - acc.retractList.clear(); - acc.retractList.addAll(newRetractBuffer); } + + acc.retractList.clear(); + acc.retractList.addAll(newRetractBuffer); } @Override