From 67692437ea39d56314930d57f2620040c135f08b Mon Sep 17 00:00:00 2001 From: bvarghese1 Date: Mon, 15 May 2023 20:52:54 -0700 Subject: [PATCH] [FLINK-31663][table] Add ARRAY_EXCEPT function --- docs/data/sql_functions.yml | 3 + .../reference/pyflink.table/expressions.rst | 1 + flink-python/pyflink/table/expression.py | 7 + .../table/api/internal/BaseExpressions.java | 11 ++ .../functions/BuiltInFunctionDefinitions.java | 10 ++ .../functions/CollectionFunctionsITCase.java | 144 +++++++++++++++++- .../functions/scalar/ArrayExceptFunction.java | 110 +++++++++++++ 7 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/scalar/ArrayExceptFunction.java diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index dc83282363820..5988811aedc5a 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -646,6 +646,9 @@ collection: - sql: ARRAY_UNION(array1, array2) table: haystack.arrayUnion(array) description: Returns an array of the elements in the union of array1 and array2, without duplicates. If any of the array is null, the function will return null. + - sql: ARRAY_EXCEPT(array1, array2) + table: arrayOne.arrayExcept(arrayTwo) + description: Returns an array of the elements in array1 but not in array2, without duplicates. If array1 is null, the function will return null. - sql: MAP_KEYS(map) table: MAP.mapKeys() description: Returns the keys of the map as array. No order guaranteed. diff --git a/flink-python/docs/reference/pyflink.table/expressions.rst b/flink-python/docs/reference/pyflink.table/expressions.rst index 095bdc3f5baec..76591a011c4bb 100644 --- a/flink-python/docs/reference/pyflink.table/expressions.rst +++ b/flink-python/docs/reference/pyflink.table/expressions.rst @@ -231,6 +231,7 @@ advanced type helper functions Expression.array_remove Expression.array_reverse Expression.array_union + Expression.array_except Expression.map_keys Expression.map_values diff --git a/flink-python/pyflink/table/expression.py b/flink-python/pyflink/table/expression.py index f435a203f3888..681817b0d2951 100644 --- a/flink-python/pyflink/table/expression.py +++ b/flink-python/pyflink/table/expression.py @@ -1519,6 +1519,13 @@ def array_union(self, array) -> 'Expression': """ return _binary_op("arrayUnion")(self, array) + def array_except(self, array) -> 'Expression': + """ + Returns an array of the elements in array1 but not in array2, without duplicates. + If array1 is null, the function will return null. + """ + return _binary_op("arrayExcept")(self, array) + @property def map_keys(self) -> 'Expression': """ diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java index f1f0000bf19ad..d7bf4de0b0c6a 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/BaseExpressions.java @@ -56,6 +56,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_CONTAINS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_DISTINCT; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_ELEMENT; +import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_EXCEPT; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_POSITION; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_REMOVE; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_REVERSE; @@ -1407,6 +1408,16 @@ public OutType arrayUnion(InType array) { unresolvedCall(ARRAY_UNION, toExpr(), objectToExpression(array))); } + /** + * Return an array of the elements in array1 but not in array2, without duplicates + * + *

If array1 is null, the function will return null. + */ + public OutType arrayExcept(InType array) { + return toApiSpecificExpression( + unresolvedCall(ARRAY_EXCEPT, toExpr(), objectToExpression(array))); + } + /** Returns the keys of the map as an array. */ public OutType mapKeys() { return toApiSpecificExpression(unresolvedCall(MAP_KEYS, toExpr())); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index f9163ba86db9e..b5225f34f6b68 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -272,6 +272,16 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) "org.apache.flink.table.runtime.functions.scalar.ArrayUnionFunction") .build(); + public static final BuiltInFunctionDefinition ARRAY_EXCEPT = + BuiltInFunctionDefinition.newBuilder() + .name("ARRAY_EXCEPT") + .kind(SCALAR) + .inputTypeStrategy(commonArrayType(2)) + .outputTypeStrategy(nullableIfArgs(COMMON)) + .runtimeClass( + "org.apache.flink.table.runtime.functions.scalar.ArrayExceptFunction") + .build(); + public static final BuiltInFunctionDefinition INTERNAL_REPLICATE_ROWS = BuiltInFunctionDefinition.newBuilder() .name("$REPLICATE_ROWS$1") diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CollectionFunctionsITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CollectionFunctionsITCase.java index 64eac0bb2c2e7..e6becdd8e0227 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CollectionFunctionsITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CollectionFunctionsITCase.java @@ -43,7 +43,8 @@ Stream getTestSetSpecs() { arrayPositionTestCases(), arrayRemoveTestCases(), arrayReverseTestCases(), - arrayUnionTestCases()) + arrayUnionTestCases(), + arrayExceptTestCases()) .flatMap(s -> s); } @@ -479,4 +480,145 @@ private Stream arrayUnionTestCases() { "Invalid input arguments. Expected signatures are:\n" + "ARRAY_UNION(, )")); } + + private Stream arrayExceptTestCases() { + return Stream.of( + TestSetSpec.forFunction(BuiltInFunctionDefinitions.ARRAY_EXCEPT) + .onFieldsWithData( + new Integer[] {1, 2, 2}, + null, + new Row[] { + Row.of(true, LocalDate.of(2022, 4, 20)), + Row.of(true, LocalDate.of(1990, 10, 14)), + null + }, + new Integer[] {null, null, 1}, + new Integer[][] { + new Integer[] {1, null, 3}, new Integer[] {0}, new Integer[] {1} + }, + new Map[] { + CollectionUtil.map(entry(1, "a"), entry(2, "b")), + CollectionUtil.map(entry(3, "c"), entry(4, "d")), + null + }) + .andDataTypes( + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY( + DataTypes.ROW(DataTypes.BOOLEAN(), DataTypes.DATE())), + DataTypes.ARRAY(DataTypes.INT()), + DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.INT())), + DataTypes.ARRAY(DataTypes.MAP(DataTypes.INT(), DataTypes.STRING()))) + // ARRAY + .testResult( + $("f0").arrayExcept(new Integer[] {2, 3, 4}), + "ARRAY_EXCEPT(f0, ARRAY[2,3,4])", + new Integer[] {1}, + DataTypes.ARRAY(DataTypes.INT()).nullable()) + .testResult( + $("f0").arrayExcept(new Integer[] {1}), + "ARRAY_EXCEPT(f0, ARRAY[1])", + new Integer[] {2}, + DataTypes.ARRAY(DataTypes.INT()).nullable()) + .testResult( + $("f0").arrayExcept(new Integer[] {42}), + "ARRAY_EXCEPT(f0, ARRAY[42])", + new Integer[] {1, 2}, + DataTypes.ARRAY(DataTypes.INT()).nullable()) + // arrayTwo is NULL + .testResult( + $("f0").arrayExcept( + lit(null, DataTypes.ARRAY(DataTypes.INT())) + .cast(DataTypes.ARRAY(DataTypes.INT()))), + "ARRAY_EXCEPT(f0, CAST(NULL AS ARRAY))", + new Integer[] {1, 2}, + DataTypes.ARRAY(DataTypes.INT()).nullable()) + // arrayTwo contains null elements + .testResult( + $("f0").arrayExcept(new Integer[] {null, 2}), + "ARRAY_EXCEPT(f0, ARRAY[null, 2])", + new Integer[] {1}, + DataTypes.ARRAY(DataTypes.INT()).nullable()) + // arrayOne is NULL + .testResult( + $("f1").arrayExcept(new Integer[] {1, 2, 3}), + "ARRAY_EXCEPT(f1, ARRAY[1,2,3])", + null, + DataTypes.ARRAY(DataTypes.INT()).nullable()) + // arrayOne contains null elements + .testResult( + $("f3").arrayExcept(new Integer[] {null, 42}), + "ARRAY_EXCEPT(f3, ARRAY[null, 42])", + new Integer[] {1}, + DataTypes.ARRAY(DataTypes.INT()).nullable()) + // ARRAY> + .testResult( + $("f2").arrayExcept( + new Row[] { + Row.of(true, LocalDate.of(1990, 10, 14)) + }), + "ARRAY_EXCEPT(f2, ARRAY[(TRUE, DATE '1990-10-14')])", + new Row[] {Row.of(true, LocalDate.of(2022, 4, 20)), null}, + DataTypes.ARRAY( + DataTypes.ROW( + DataTypes.BOOLEAN(), DataTypes.DATE())) + .nullable()) + .testResult( + $("f2").arrayExcept( + lit( + null, + DataTypes.ARRAY( + DataTypes.ROW( + DataTypes.BOOLEAN(), + DataTypes.DATE()))) + .cast( + DataTypes.ARRAY( + DataTypes.ROW( + DataTypes.BOOLEAN(), + DataTypes + .DATE())))), + "ARRAY_EXCEPT(f2, CAST(NULL AS ARRAY>))", + new Row[] { + Row.of(true, LocalDate.of(2022, 4, 20)), + Row.of(true, LocalDate.of(1990, 10, 14)), + null, + }, + DataTypes.ARRAY( + DataTypes.ROW( + DataTypes.BOOLEAN(), DataTypes.DATE())) + .nullable()) + // ARRAY> + .testResult( + $("f4").arrayExcept(new Integer[][] {new Integer[] {0}}), + "ARRAY_EXCEPT(f4, ARRAY[ARRAY[0]])", + new Integer[][] {new Integer[] {1, null, 3}, new Integer[] {1}}, + DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.INT()).nullable())) + // ARRAY> with NULL elements + .testResult( + $("f5").arrayExcept( + new Map[] { + CollectionUtil.map(entry(3, "c"), entry(4, "d")) + }), + "ARRAY_EXCEPT(f5, ARRAY[MAP[3, 'c', 4, 'd']])", + new Map[] {CollectionUtil.map(entry(1, "a"), entry(2, "b")), null}, + DataTypes.ARRAY(DataTypes.MAP(DataTypes.INT(), DataTypes.STRING())) + .nullable()) + // Invalid signatures + .testSqlValidationError( + "ARRAY_EXCEPT(f0, TRUE)", + "Invalid input arguments. Expected signatures are:\n" + + "ARRAY_EXCEPT(, )") + .testTableApiValidationError( + $("f0").arrayExcept(true), + "Invalid input arguments. Expected signatures are:\n" + + "ARRAY_EXCEPT(, )") + .testSqlValidationError( + "ARRAY_EXCEPT(f0, ARRAY['hi', 'there'])", + "Invalid input arguments. Expected signatures are:\n" + + "ARRAY_EXCEPT(, )") + .testTableApiValidationError( + $("f0").arrayExcept(new String[] {"hi", "there"}), + "Invalid input arguments. Expected signatures are:\n" + + "ARRAY_EXCEPT(, )")); + } } diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/scalar/ArrayExceptFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/scalar/ArrayExceptFunction.java new file mode 100644 index 0000000000000..fdaeec8823226 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/scalar/ArrayExceptFunction.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.scalar; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.SpecializedFunction; +import org.apache.flink.table.types.CollectionDataType; +import org.apache.flink.table.types.DataType; +import org.apache.flink.util.FlinkRuntimeException; + +import javax.annotation.Nullable; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.table.api.Expressions.$; + +/** Implementation of {@link BuiltInFunctionDefinitions#ARRAY_EXCEPT}. */ +@Internal +public class ArrayExceptFunction extends BuiltInScalarFunction { + private final ArrayData.ElementGetter elementGetter; + private final SpecializedFunction.ExpressionEvaluator containsEvaluator; + private transient MethodHandle containsHandle; + + public ArrayExceptFunction(SpecializedFunction.SpecializedContext context) { + super(BuiltInFunctionDefinitions.ARRAY_EXCEPT, context); + final DataType arrayElementDataType = + ((CollectionDataType) context.getCallContext().getArgumentDataTypes().get(0)) + .getElementDataType(); + final DataType arrayDataType = context.getCallContext().getArgumentDataTypes().get(0); + elementGetter = ArrayData.createElementGetter(arrayElementDataType.getLogicalType()); + containsEvaluator = + context.createEvaluator( + $("array").arrayContains($("element")), + DataTypes.BOOLEAN(), + DataTypes.FIELD("array", arrayDataType.notNull().toInternal()), + DataTypes.FIELD("element", arrayElementDataType.notNull().toInternal())); + } + + @Override + public void open(FunctionContext context) throws Exception { + containsHandle = containsEvaluator.open(context); + } + + public @Nullable ArrayData eval(ArrayData arrayOne, ArrayData arrayTwo) { + try { + if (arrayOne == null) { + return null; + } + + boolean isNullPresent = false; + if (arrayTwo != null) { + for (int pos = 0; pos < arrayTwo.size(); pos++) { + final Object element = elementGetter.getElementOrNull(arrayTwo, pos); + if (element == null) { + isNullPresent = true; + break; + } + } + } + + List list = new ArrayList(); + Set seen = new HashSet<>(); + for (int pos = 0; pos < arrayOne.size(); pos++) { + final Object element = elementGetter.getElementOrNull(arrayOne, pos); + if ((arrayTwo == null && !seen.contains(element)) + || (element == null && !isNullPresent) + || (element != null + && !seen.contains(element) + && !(boolean) containsHandle.invoke(arrayTwo, element))) { + list.add(element); + } + seen.add(element); + } + + return new GenericArrayData(list.toArray()); + } catch (Throwable t) { + throw new FlinkRuntimeException(t); + } + } + + @Override + public void close() throws Exception { + containsEvaluator.close(); + } +}