Skip to content

Commit

Permalink
[FLINK-31102][table] Add ARRAY_REMOVE function.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuyongvs committed Mar 10, 2023
1 parent 84d000c commit 2e33d70
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docs/data/sql_functions.yml
Expand Up @@ -617,6 +617,9 @@ collection:
- sql: ARRAY_DISTINCT(haystack)
table: haystack.arrayDistinct()
description: Returns an array with unique elements. If the array itself is null, the function will return null. Keeps ordering of elements.
- sql: ARRAY_REMOVE(haystack, needle)
table: haystack.arrayRemove(needle)
description: Remove all elements that equal to element from array. If the array itself is null, the function will return null. Keeps ordering of elements.

json:
- sql: IS JSON [ { VALUE | SCALAR | ARRAY | OBJECT } ]
Expand Down
1 change: 1 addition & 0 deletions flink-python/docs/reference/pyflink.table/expressions.rst
Expand Up @@ -226,6 +226,7 @@ advanced type helper functions
Expression.element
Expression.array_contains
Expression.array_distinct
Expression.array_remove


time definition functions
Expand Down
7 changes: 7 additions & 0 deletions flink-python/pyflink/table/expression.py
Expand Up @@ -1487,6 +1487,13 @@ def array_distinct(self) -> 'Expression':
"""
return _binary_op("arrayDistinct")(self)

def array_remove(self, needle) -> 'Expression':
"""
Remove all elements that equal to element from array.
If the array itself is null, the function will return null. Keeps ordering of elements.
"""
return _binary_op("arrayRemove")(self, needle)

# ---------------------------- time definition functions -----------------------------

@property
Expand Down
Expand Up @@ -56,6 +56,7 @@
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_CONTAINS;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_DISTINCT;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_ELEMENT;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_REMOVE;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ASCII;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ASIN;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AT;
Expand Down Expand Up @@ -1359,6 +1360,16 @@ public OutType arrayDistinct() {
return toApiSpecificExpression(unresolvedCall(ARRAY_DISTINCT, toExpr()));
}

/**
* Remove all elements that equal to element from array.
*
* <p>If the array itself is null, the function will return null. Keeps ordering of elements.
*/
public OutType arrayRemove(InType needle) {
return toApiSpecificExpression(
unresolvedCall(ARRAY_REMOVE, toExpr(), objectToExpression(needle)));
}

// Time definition

/**
Expand Down
Expand Up @@ -178,6 +178,21 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
.runtimeClass(
"org.apache.flink.table.runtime.functions.scalar.ArrayDistinctFunction")
.build();

public static final BuiltInFunctionDefinition ARRAY_REMOVE =
BuiltInFunctionDefinition.newBuilder()
.name("ARRAY_REMOVE")
.kind(SCALAR)
.inputTypeStrategy(
sequence(
Arrays.asList("haystack", "needle"),
Arrays.asList(
logical(LogicalTypeRoot.ARRAY), ARRAY_ELEMENT_ARG)))
.outputTypeStrategy(nullableIfArgs(ConstantArgumentCount.of(0), argument(0)))
.runtimeClass(
"org.apache.flink.table.runtime.functions.scalar.ArrayRemoveFunction")
.build();

public static final BuiltInFunctionDefinition INTERNAL_REPLICATE_ROWS =
BuiltInFunctionDefinition.newBuilder()
.name("$REPLICATE_ROWS$1")
Expand Down
Expand Up @@ -21,13 +21,16 @@
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.types.Row;
import org.apache.flink.util.CollectionUtil;

import java.time.LocalDate;
import java.util.Map;
import java.util.stream.Stream;

import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.lit;
import static org.apache.flink.table.api.Expressions.row;
import static org.apache.flink.util.CollectionUtil.entry;

/** Tests for {@link BuiltInFunctionDefinitions} around arrays. */
class CollectionFunctionsITCase extends BuiltInFunctionTestBase {
Expand Down Expand Up @@ -178,6 +181,137 @@ Stream<TestSetSpec> getTestSetSpecs() {
null
},
DataTypes.ARRAY(
DataTypes.ROW(DataTypes.BOOLEAN(), DataTypes.DATE()))));
DataTypes.ROW(DataTypes.BOOLEAN(), DataTypes.DATE()))),
TestSetSpec.forFunction(BuiltInFunctionDefinitions.ARRAY_REMOVE)
.onFieldsWithData(
new Integer[] {1, 2, 2},
null,
new String[] {"Hello", "World"},
new Row[] {
Row.of(true, LocalDate.of(2022, 4, 20)),
Row.of(true, LocalDate.of(1990, 10, 14)),
null
},
new Integer[] {null, null, 1},
new Integer[][] {
new Integer[] {1, null, 3}, new Integer[] {0}, new Integer[] {1}
},
new Map[] {
CollectionUtil.map(entry(1, "a"), entry(2, "b")),
CollectionUtil.map(entry(3, "c"), entry(4, "d")),
null
})
.andDataTypes(
DataTypes.ARRAY(DataTypes.INT()),
DataTypes.ARRAY(DataTypes.INT()),
DataTypes.ARRAY(DataTypes.STRING()).notNull(),
DataTypes.ARRAY(
DataTypes.ROW(DataTypes.BOOLEAN(), DataTypes.DATE())),
DataTypes.ARRAY(DataTypes.INT()),
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.INT())),
DataTypes.ARRAY(DataTypes.MAP(DataTypes.INT(), DataTypes.STRING())))
// ARRAY<INT>
.testResult(
$("f0").arrayRemove(2),
"ARRAY_REMOVE(f0, 2)",
new Integer[] {1},
DataTypes.ARRAY(DataTypes.INT()).nullable())
.testResult(
$("f0").arrayRemove(42),
"ARRAY_REMOVE(f0, 42)",
new Integer[] {1, 2, 2},
DataTypes.ARRAY(DataTypes.INT()).nullable())
.testResult(
$("f0").arrayRemove(
lit(null, DataTypes.SMALLINT())
.cast(DataTypes.INT())),
"ARRAY_REMOVE(f0, cast(NULL AS INT))",
new Integer[] {1, 2, 2},
DataTypes.ARRAY(DataTypes.INT()).nullable())
// ARRAY<INT> of null value
.testResult(
$("f1").arrayRemove(12),
"ARRAY_REMOVE(f1, 12)",
null,
DataTypes.ARRAY(DataTypes.INT()).nullable())
.testResult(
$("f1").arrayRemove(null),
"ARRAY_REMOVE(f1, NULL)",
null,
DataTypes.ARRAY(DataTypes.INT()).nullable())
// ARRAY<STRING> NOT NULL
.testResult(
$("f2").arrayRemove("Hello"),
"ARRAY_REMOVE(f2, 'Hello')",
new String[] {"World"},
DataTypes.ARRAY(DataTypes.STRING()).notNull())
.testResult(
$("f2").arrayRemove(
lit(null, DataTypes.STRING())
.cast(DataTypes.STRING())),
"ARRAY_REMOVE(f2, cast(NULL AS VARCHAR))",
new String[] {"Hello", "World"},
DataTypes.ARRAY(DataTypes.STRING()).notNull())
// ARRAY<ROW<BOOLEAN, DATE>>
.testResult(
$("f3").arrayRemove(row(true, LocalDate.of(1990, 10, 14))),
"ARRAY_REMOVE(f3, (TRUE, DATE '1990-10-14'))",
new Row[] {Row.of(true, LocalDate.of(2022, 4, 20)), null},
DataTypes.ARRAY(
DataTypes.ROW(
DataTypes.BOOLEAN(), DataTypes.DATE()))
.nullable())
.testResult(
$("f3").arrayRemove(row(false, LocalDate.of(1990, 10, 14))),
"ARRAY_REMOVE(f3, (FALSE, DATE '1990-10-14'))",
new Row[] {
Row.of(true, LocalDate.of(2022, 4, 20)),
Row.of(true, LocalDate.of(1990, 10, 14)),
null
},
DataTypes.ARRAY(
DataTypes.ROW(
DataTypes.BOOLEAN(), DataTypes.DATE()))
.nullable())
.testResult(
$("f3").arrayRemove(null),
"ARRAY_REMOVE(f3, null)",
new Row[] {
Row.of(true, LocalDate.of(2022, 4, 20)),
Row.of(true, LocalDate.of(1990, 10, 14)),
},
DataTypes.ARRAY(
DataTypes.ROW(
DataTypes.BOOLEAN(), DataTypes.DATE()))
.nullable())
// ARRAY<INT> with null elements
.testResult(
$("f4").arrayRemove(null),
"ARRAY_REMOVE(f4, NULL)",
new Integer[] {1},
DataTypes.ARRAY(DataTypes.INT()).nullable())
// ARRAY<ARRAY<INT>>
.testResult(
$("f5").arrayRemove(new Integer[] {0}),
"ARRAY_REMOVE(f5, array[0])",
new Integer[][] {new Integer[] {1, null, 3}, new Integer[] {1}},
DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.INT()).nullable()))
// ARRAY<Map<INT, STRING>> with null elements
.testResult(
$("f6").arrayRemove(
CollectionUtil.map(entry(3, "c"), entry(4, "d"))),
"ARRAY_REMOVE(f6, MAP[3, 'c', 4, 'd'])",
new Map[] {CollectionUtil.map(entry(1, "a"), entry(2, "b")), null},
DataTypes.ARRAY(DataTypes.MAP(DataTypes.INT(), DataTypes.STRING()))
.nullable())
// invalid signatures
.testSqlValidationError(
"ARRAY_REMOVE(f0, TRUE)",
"Invalid input arguments. Expected signatures are:\n"
+ "ARRAY_REMOVE(haystack <ARRAY>, needle <ARRAY ELEMENT>)")
.testTableApiValidationError(
$("f0").arrayRemove(true),
"Invalid input arguments. Expected signatures are:\n"
+ "ARRAY_REMOVE(haystack <ARRAY>, needle <ARRAY ELEMENT>)"));
}
}
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.runtime.functions.scalar;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.data.ArrayData;
import org.apache.flink.table.data.GenericArrayData;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.SpecializedFunction;
import org.apache.flink.table.types.CollectionDataType;
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.FlinkRuntimeException;

import javax.annotation.Nullable;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.List;

import static org.apache.flink.table.api.Expressions.$;

/** Implementation of {@link BuiltInFunctionDefinitions#ARRAY_REMOVE}. */
@Internal
public class ArrayRemoveFunction extends BuiltInScalarFunction {
private final ArrayData.ElementGetter elementGetter;
private final SpecializedFunction.ExpressionEvaluator equalityEvaluator;
private transient MethodHandle equalityHandle;

public ArrayRemoveFunction(SpecializedFunction.SpecializedContext context) {
super(BuiltInFunctionDefinitions.ARRAY_REMOVE, context);
final DataType elementDataType =
((CollectionDataType) context.getCallContext().getArgumentDataTypes().get(0))
.getElementDataType();
final DataType needleDataType = context.getCallContext().getArgumentDataTypes().get(1);
elementGetter = ArrayData.createElementGetter(elementDataType.getLogicalType());
equalityEvaluator =
context.createEvaluator(
$("element").isEqual($("needle")),
DataTypes.BOOLEAN(),
DataTypes.FIELD("element", elementDataType.notNull().toInternal()),
DataTypes.FIELD("needle", needleDataType.notNull().toInternal()));
}

@Override
public void open(FunctionContext context) throws Exception {
equalityHandle = equalityEvaluator.open(context);
}

public @Nullable ArrayData eval(ArrayData haystack, Object needle) {
try {
if (haystack == null) {
return null;
}

List list = new ArrayList();
final int size = haystack.size();
for (int pos = 0; pos < size; pos++) {
final Object element = elementGetter.getElementOrNull(haystack, pos);
if ((element == null && needle != null)
|| (element != null && needle == null)
|| (element != null
&& needle != null
&& !(boolean) equalityHandle.invoke(element, needle))) {
list.add(element);
}
}
return new GenericArrayData(list.toArray());
} catch (Throwable t) {
throw new FlinkRuntimeException(t);
}
}

@Override
public void close() throws Exception {
equalityEvaluator.close();
}
}

0 comments on commit 2e33d70

Please sign in to comment.