Skip to content

Commit

Permalink
[FLINK-31118][table] Add ARRAY_UNION function.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuyongvs committed Apr 26, 2023
1 parent 0104427 commit 63b0da3
Show file tree
Hide file tree
Showing 10 changed files with 476 additions and 77 deletions.
155 changes: 79 additions & 76 deletions docs/data/sql_functions.yml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions flink-python/docs/reference/pyflink.table/expressions.rst
Expand Up @@ -230,6 +230,7 @@ advanced type helper functions
Expression.array_position
Expression.array_remove
Expression.array_reverse
Expression.array_union
Expression.map_keys
Expression.map_values

Expand Down
7 changes: 7 additions & 0 deletions flink-python/pyflink/table/expression.py
Expand Up @@ -1512,6 +1512,13 @@ def array_reverse(self) -> 'Expression':
"""
return _binary_op("arrayReverse")(self)

def array_union(self, array) -> 'Expression':
    """
    Computes the union of this array and the given array; the resulting array
    contains no duplicate elements.

    If either of the two arrays is null, the result is null as well.
    """
    union_op = _binary_op("arrayUnion")
    return union_op(self, array)

@property
def map_keys(self) -> 'Expression':
"""
Expand Down
Expand Up @@ -59,6 +59,7 @@
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_POSITION;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_REMOVE;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_REVERSE;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ARRAY_UNION;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ASCII;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.ASIN;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AT;
Expand Down Expand Up @@ -1396,6 +1397,16 @@ public OutType arrayReverse() {
return toApiSpecificExpression(unresolvedCall(ARRAY_REVERSE, toExpr()));
}

/**
 * Computes the union of this array and the given {@code array}; the resulting array contains no
 * duplicate elements.
 *
 * <p>If either of the two arrays is null, the result is null as well.
 */
public OutType arrayUnion(InType array) {
    final OutType unionExpression =
            toApiSpecificExpression(
                    unresolvedCall(ARRAY_UNION, toExpr(), objectToExpression(array)));
    return unionExpression;
}

/** Returns the keys of the map as an array. */
public OutType mapKeys() {
return toApiSpecificExpression(unresolvedCall(MAP_KEYS, toExpr()));
Expand Down
Expand Up @@ -69,6 +69,7 @@
import static org.apache.flink.table.types.inference.InputTypeStrategies.OUTPUT_IF_NULL;
import static org.apache.flink.table.types.inference.InputTypeStrategies.TYPE_LITERAL;
import static org.apache.flink.table.types.inference.InputTypeStrategies.and;
import static org.apache.flink.table.types.inference.InputTypeStrategies.commonArrayType;
import static org.apache.flink.table.types.inference.InputTypeStrategies.commonType;
import static org.apache.flink.table.types.inference.InputTypeStrategies.comparable;
import static org.apache.flink.table.types.inference.InputTypeStrategies.compositeSequence;
Expand Down Expand Up @@ -261,6 +262,16 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
"org.apache.flink.table.runtime.functions.scalar.ArrayReverseFunction")
.build();

/**
 * Definition of the built-in scalar function {@code ARRAY_UNION}.
 *
 * <p>The input strategy expects two arguments that have a common array type. NOTE(review): the
 * output strategy presumably yields that common type and makes it nullable if an argument is
 * nullable - confirm against {@code nullableIfArgs} and {@code COMMON}.
 */
public static final BuiltInFunctionDefinition ARRAY_UNION =
BuiltInFunctionDefinition.newBuilder()
.name("ARRAY_UNION")
.kind(SCALAR)
.inputTypeStrategy(commonArrayType(2))
.outputTypeStrategy(nullableIfArgs(COMMON))
.runtimeClass(
"org.apache.flink.table.runtime.functions.scalar.ArrayUnionFunction")
.build();

public static final BuiltInFunctionDefinition INTERNAL_REPLICATE_ROWS =
BuiltInFunctionDefinition.newBuilder()
.name("$REPLICATE_ROWS$1")
Expand Down
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.table.types.inference.strategies.AndArgumentTypeStrategy;
import org.apache.flink.table.types.inference.strategies.AnyArgumentTypeStrategy;
import org.apache.flink.table.types.inference.strategies.CommonArgumentTypeStrategy;
import org.apache.flink.table.types.inference.strategies.CommonArrayInputTypeStrategy;
import org.apache.flink.table.types.inference.strategies.CommonInputTypeStrategy;
import org.apache.flink.table.types.inference.strategies.ComparableTypeStrategy;
import org.apache.flink.table.types.inference.strategies.CompositeArgumentTypeStrategy;
Expand Down Expand Up @@ -347,6 +348,14 @@ public static InputTypeStrategy commonType(int count) {
return new CommonInputTypeStrategy(ConstantArgumentCount.of(count));
}

/**
 * Creates an {@link InputTypeStrategy} for calls with {@code count} arguments, all of which must
 * share a common array type.
 */
public static InputTypeStrategy commonArrayType(int count) {
    final InputTypeStrategy strategy =
            new CommonArrayInputTypeStrategy(ConstantArgumentCount.of(count));
    return strategy;
}

// --------------------------------------------------------------------------------------------

private InputTypeStrategies() {
Expand Down
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.types.inference.strategies;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.Signature.Argument;
import org.apache.flink.table.types.logical.LegacyTypeInformationType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
import org.apache.flink.table.types.utils.TypeConversions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * An {@link InputTypeStrategy} that requires every argument to be an array and expects all of
 * them to share one common array type.
 */
@Internal
public final class CommonArrayInputTypeStrategy implements InputTypeStrategy {

    private static final Argument COMMON_ARGUMENT = Argument.ofGroup("COMMON");

    private final ArgumentCount argumentCount;

    public CommonArrayInputTypeStrategy(ArgumentCount argumentCount) {
        this.argumentCount = argumentCount;
    }

    @Override
    public ArgumentCount getArgumentCount() {
        return argumentCount;
    }

    /**
     * Infers the expected argument types of a call.
     *
     * <p>The inference fails if an argument is not an array or if no common type exists for the
     * arguments. If a legacy type is involved, the argument types are returned as they are.
     * Otherwise every argument is expected to be of the common type.
     */
    @Override
    public Optional<List<DataType>> inferInputTypes(
            CallContext callContext, boolean throwOnFailure) {
        final List<DataType> actualDataTypes = callContext.getArgumentDataTypes();
        final List<LogicalType> actualTypes =
                actualDataTypes.stream().map(DataType::getLogicalType).collect(Collectors.toList());

        // Guard: a single non-array argument invalidates the whole call.
        final boolean hasNonArrayArgument =
                actualTypes.stream().anyMatch(type -> !type.is(LogicalTypeRoot.ARRAY));
        if (hasNonArrayArgument) {
            return callContext.fail(throwOnFailure, "All arguments requires to be a ARRAY type");
        }

        // Guard: with a legacy type among the arguments no common type is derived.
        for (LogicalType type : actualTypes) {
            if (type instanceof LegacyTypeInformationType) {
                return Optional.of(actualDataTypes);
            }
        }

        final Optional<LogicalType> commonType = LogicalTypeMerging.findCommonType(actualTypes);
        if (commonType.isPresent()) {
            final DataType commonDataType =
                    TypeConversions.fromLogicalToDataType(commonType.get());
            // Every argument position is expected to have the same (common) data type.
            return Optional.of(Collections.nCopies(actualTypes.size(), commonDataType));
        }

        return callContext.fail(
                throwOnFailure,
                "Could not find a common type for arguments: %s",
                actualDataTypes);
    }

    /**
     * Lists the expected signatures.
     *
     * <p>With an upper bound on the argument count, one signature is produced per count between
     * the minimum (0 if absent) and the maximum. Without an upper bound, a single signature with
     * the mandatory arguments followed by a varying group is produced.
     */
    @Override
    public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
        final Optional<Integer> lowerBound = argumentCount.getMinCount();
        final Optional<Integer> upperBound = argumentCount.getMaxCount();

        final int mandatoryArguments = lowerBound.orElse(0);

        if (!upperBound.isPresent()) {
            final List<Argument> varyingArguments =
                    new ArrayList<>(Collections.nCopies(mandatoryArguments, COMMON_ARGUMENT));
            varyingArguments.add(Argument.ofGroupVarying("COMMON"));
            return Collections.singletonList(Signature.of(varyingArguments));
        }

        final int exclusiveEnd = upperBound.get() + 1;
        return IntStream.range(mandatoryArguments, exclusiveEnd)
                .mapToObj(count -> Collections.nCopies(count, COMMON_ARGUMENT))
                .map(arguments -> Signature.of(arguments))
                .collect(Collectors.toList());
    }
}
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.types.inference.strategies;

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.InputTypeStrategiesTestBase;

import java.util.stream.Stream;

/** Test specifications for the {@link CommonArrayInputTypeStrategy}. */
class CommonArrayInputTypeStrategyTest extends InputTypeStrategiesTestBase {

    @Override
    protected Stream<TestSpec> testData() {
        // ARRAY<INT> and ARRAY<DOUBLE NOT NULL> NOT NULL are both expected as ARRAY<DOUBLE>.
        final TestSpec commonTypeFound =
                TestSpec.forStrategy(InputTypeStrategies.commonArrayType(2))
                        .expectSignature("f(<COMMON>, <COMMON>)")
                        .calledWithArgumentTypes(
                                DataTypes.ARRAY(DataTypes.INT()),
                                DataTypes.ARRAY(DataTypes.DOUBLE().notNull()).notNull())
                        .expectArgumentTypes(
                                DataTypes.ARRAY(DataTypes.DOUBLE()),
                                DataTypes.ARRAY(DataTypes.DOUBLE()));

        // A non-array argument is rejected.
        final TestSpec nonArrayArgument =
                TestSpec.forStrategy(
                                "Strategy fails if not all of the argument types are ARRAY",
                                InputTypeStrategies.commonArrayType(2))
                        .calledWithArgumentTypes(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT()))
                        .expectErrorMessage("All arguments requires to be a ARRAY type");

        // ARRAY<INT> and ARRAY<STRING> have no common type.
        final TestSpec noCommonType =
                TestSpec.forStrategy(
                                "Strategy fails if can not find a common type",
                                InputTypeStrategies.commonArrayType(2))
                        .calledWithArgumentTypes(
                                DataTypes.ARRAY(DataTypes.INT()),
                                DataTypes.ARRAY(DataTypes.STRING()))
                        .expectErrorMessage(
                                "Could not find a common type for arguments: [ARRAY<INT>, ARRAY<STRING>]");

        return Stream.of(commonTypeFound, nonArrayArgument, noCommonType);
    }
}
Expand Up @@ -42,7 +42,8 @@ Stream<TestSetSpec> getTestSetSpecs() {
arrayDistinctTestCases(),
arrayPositionTestCases(),
arrayRemoveTestCases(),
arrayReverseTestCases())
arrayReverseTestCases(),
arrayUnionTestCases())
.flatMap(s -> s);
}

Expand Down Expand Up @@ -415,4 +416,67 @@ private Stream<TestSetSpec> arrayReverseTestCases() {
DataTypes.ARRAY(
DataTypes.ROW(DataTypes.BOOLEAN(), DataTypes.DATE()))));
}

/**
 * Test cases for {@code ARRAY_UNION}, exercised through both the Table API expression and the
 * SQL string.
 *
 * <p>The test expressions below reference the data as fields f0 to f3: f0 is an ARRAY<INT> with
 * values, f1 is an ARRAY<INT> that is null, f2 is an ARRAY<ROW<BOOLEAN, DATE>> and f3 is a plain
 * INT used for the invalid-signature checks.
 */
private Stream<TestSetSpec> arrayUnionTestCases() {
return Stream.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.ARRAY_UNION)
.onFieldsWithData(
new Integer[] {1, 2, null},
null,
new Row[] {
Row.of(true, LocalDate.of(2022, 4, 20)),
Row.of(true, LocalDate.of(1990, 10, 14)),
null
},
1)
.andDataTypes(
DataTypes.ARRAY(DataTypes.INT()),
DataTypes.ARRAY(DataTypes.INT()),
DataTypes.ARRAY(
DataTypes.ROW(DataTypes.BOOLEAN(), DataTypes.DATE())),
DataTypes.INT())
// ARRAY<INT>: 1 and NULL occur in both arrays and are expected only once in the result
.testResult(
$("f0").arrayUnion(new Integer[] {1, null, 4}),
"ARRAY_UNION(f0, ARRAY[1, NULL, 4])",
new Integer[] {1, 2, null, 4},
DataTypes.ARRAY(DataTypes.INT()))
// Disabled test case (ARRAY<INT> united with an array of DOUBLE literals) because of
// an insert cast bug, see https://issues.apache.org/jira/browse/CALCITE-5674.
// .testResult(
// $("f0").arrayUnion(array(1.0d, null,
// 4.0d)),
// "ARRAY_UNION(f0, ARRAY[1.0E0, NULL,
// 4.0E0])",
// new Double[] {1.0d, 2.0d, null, 4.0d},
// DataTypes.ARRAY(DataTypes.DOUBLE()))
// ARRAY<INT> of null value: the expected result is null
.testResult(
$("f1").arrayUnion(new Integer[] {1, null, 4}),
"ARRAY_UNION(f1, ARRAY[1, NULL, 4])",
null,
DataTypes.ARRAY(DataTypes.INT()))
// ARRAY<ROW<BOOLEAN, DATE>>: both elements of the argument already occur in f2, so the
// expected result has the same elements as f2
.testResult(
$("f2").arrayUnion(
new Row[] {
null, Row.of(true, LocalDate.of(1990, 10, 14)),
}),
"ARRAY_UNION(f2, ARRAY[NULL, (TRUE, DATE '1990-10-14')])",
new Row[] {
Row.of(true, LocalDate.of(2022, 4, 20)),
Row.of(true, LocalDate.of(1990, 10, 14)),
null
},
DataTypes.ARRAY(
DataTypes.ROW(DataTypes.BOOLEAN(), DataTypes.DATE())))
// invalid signatures: neither f3 (INT) nor the BOOLEAN literal is an array
.testSqlValidationError(
"ARRAY_UNION(f3, TRUE)",
"Invalid input arguments. Expected signatures are:\n"
+ "ARRAY_UNION(<COMMON>, <COMMON>)")
.testTableApiValidationError(
$("f3").arrayUnion(true),
"Invalid input arguments. Expected signatures are:\n"
+ "ARRAY_UNION(<COMMON>, <COMMON>)"));
}
}

0 comments on commit 63b0da3

Please sign in to comment.