Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2242,7 +2242,8 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
BuiltInFunctionDefinition.newBuilder()
.name("in")
.kind(SCALAR)
.outputTypeStrategy(TypeStrategies.MISSING)
.inputTypeStrategy(SpecificInputTypeStrategies.IN)
.outputTypeStrategy(nullableIfArgs(explicit(DataTypes.BOOLEAN())))
.build();

public static final BuiltInFunctionDefinition CAST =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,9 @@
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.logical.DistinctType;
import org.apache.flink.table.types.logical.LegacyTypeInformationType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeFamily;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.RawType;
import org.apache.flink.table.types.logical.StructuredType;
import org.apache.flink.table.types.logical.StructuredType.StructuredComparison;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
import org.apache.flink.util.Preconditions;

import java.util.Collections;
Expand All @@ -49,7 +44,7 @@
* with itself (e.g. for aggregations).
*
* <p>For the rules which types are comparable with which types see {@link
* #areComparable(LogicalType, LogicalType)}.
* LogicalTypeChecks#areComparable(LogicalType, LogicalType, StructuredComparison)}.
*/
@Internal
public final class ComparableTypeStrategy implements InputTypeStrategy {
Expand Down Expand Up @@ -78,7 +73,7 @@ public Optional<List<DataType>> inferInputTypes(
final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();
if (argumentDataTypes.size() == 1) {
final LogicalType argType = argumentDataTypes.get(0).getLogicalType();
if (!areComparable(argType, argType)) {
if (!LogicalTypeChecks.areComparable(argType, argType, requiredComparison)) {
return callContext.fail(
throwOnFailure,
"Type '%s' should support %s comparison with itself.",
Expand All @@ -90,7 +85,7 @@ public Optional<List<DataType>> inferInputTypes(
final LogicalType firstType = argumentDataTypes.get(i).getLogicalType();
final LogicalType secondType = argumentDataTypes.get(i + 1).getLogicalType();

if (!areComparable(firstType, secondType)) {
if (!LogicalTypeChecks.areComparable(firstType, secondType, requiredComparison)) {
return callContext.fail(
throwOnFailure,
"All types in a comparison should support %s comparison with each other. "
Expand All @@ -111,120 +106,9 @@ private String comparisonToString() {
: "both 'EQUALS' and 'ORDER'";
}

private boolean areComparable(LogicalType firstType, LogicalType secondType) {
return areComparableWithNormalizedNullability(firstType.copy(true), secondType.copy(true));
}

private boolean areComparableWithNormalizedNullability(
LogicalType firstType, LogicalType secondType) {
// A hack to support legacy types. To be removed when we drop the legacy types.
if (firstType instanceof LegacyTypeInformationType
|| secondType instanceof LegacyTypeInformationType) {
return true;
}

// everything is comparable with null, it should return null in that case
if (firstType.is(LogicalTypeRoot.NULL) || secondType.is(LogicalTypeRoot.NULL)) {
return true;
}

if (firstType.getTypeRoot() == secondType.getTypeRoot()) {
return areTypesOfSameRootComparable(firstType, secondType);
}

if (firstType.is(LogicalTypeFamily.NUMERIC) && secondType.is(LogicalTypeFamily.NUMERIC)) {
return true;
}

// DATE + ALL TIMESTAMPS
if (firstType.is(LogicalTypeFamily.DATETIME) && secondType.is(LogicalTypeFamily.DATETIME)) {
return true;
}

// VARCHAR + CHAR (we do not compare collations here)
if (firstType.is(LogicalTypeFamily.CHARACTER_STRING)
&& secondType.is(LogicalTypeFamily.CHARACTER_STRING)) {
return true;
}

// VARBINARY + BINARY
if (firstType.is(LogicalTypeFamily.BINARY_STRING)
&& secondType.is(LogicalTypeFamily.BINARY_STRING)) {
return true;
}

return false;
}

private boolean areTypesOfSameRootComparable(LogicalType firstType, LogicalType secondType) {
switch (firstType.getTypeRoot()) {
case ARRAY:
case MULTISET:
case MAP:
case ROW:
return areConstructedTypesComparable(firstType, secondType);
case DISTINCT_TYPE:
return areDistinctTypesComparable(firstType, secondType);
case STRUCTURED_TYPE:
return areStructuredTypesComparable(firstType, secondType);
case RAW:
return areRawTypesComparable(firstType, secondType);
default:
return true;
}
}

private boolean areRawTypesComparable(LogicalType firstType, LogicalType secondType) {
return firstType.equals(secondType)
&& Comparable.class.isAssignableFrom(
((RawType<?>) firstType).getOriginatingClass());
}

private boolean areDistinctTypesComparable(LogicalType firstType, LogicalType secondType) {
DistinctType firstDistinctType = (DistinctType) firstType;
DistinctType secondDistinctType = (DistinctType) secondType;
return firstType.equals(secondType)
&& areComparable(
firstDistinctType.getSourceType(), secondDistinctType.getSourceType());
}

private boolean areStructuredTypesComparable(LogicalType firstType, LogicalType secondType) {
return firstType.equals(secondType) && hasRequiredComparison((StructuredType) firstType);
}

private boolean areConstructedTypesComparable(LogicalType firstType, LogicalType secondType) {
List<LogicalType> firstChildren = firstType.getChildren();
List<LogicalType> secondChildren = secondType.getChildren();

if (firstChildren.size() != secondChildren.size()) {
return false;
}

for (int i = 0; i < firstChildren.size(); i++) {
if (!areComparable(firstChildren.get(i), secondChildren.get(i))) {
return false;
}
}

return true;
}

@Override
public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
return Collections.singletonList(
Signature.of(Signature.Argument.ofGroupVarying("COMPARABLE")));
}

private Boolean hasRequiredComparison(StructuredType structuredType) {
switch (requiredComparison) {
case EQUALS:
return structuredType.getComparison().isEquality();
case FULL:
return structuredType.getComparison().isComparison();
case NONE:
default:
// this is not important, required comparison will never be NONE
return true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ public final class SpecificInputTypeStrategies {
public static final InputTypeStrategy TWO_EQUALS_COMPARABLE =
comparable(ConstantArgumentCount.of(2), StructuredType.StructuredComparison.EQUALS);

/** Type strategy specific for {@link BuiltInFunctionDefinitions#IN}. */
public static final InputTypeStrategy IN = new SubQueryInputTypeStrategy();

private SpecificInputTypeStrategies() {
// no instantiation
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.types.inference.strategies;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.StructuredType;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
import org.apache.flink.table.types.utils.TypeConversions;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/** {@link InputTypeStrategy} for {@link BuiltInFunctionDefinitions#IN}. */
@Internal
public class SubQueryInputTypeStrategy implements InputTypeStrategy {
@Override
public ArgumentCount getArgumentCount() {
return ConstantArgumentCount.from(2);
}

@Override
public Optional<List<DataType>> inferInputTypes(
CallContext callContext, boolean throwOnFailure) {
final LogicalType rightType;
final DataType leftType = callContext.getArgumentDataTypes().get(0);
if (callContext.getArgumentDataTypes().size() > 2) {
final Optional<LogicalType> commonType =
LogicalTypeMerging.findCommonType(
callContext.getArgumentDataTypes().stream()
.map(DataType::getLogicalType)
.collect(Collectors.toList()));
if (!commonType.isPresent()) {
return callContext.fail(
throwOnFailure, "Could not find a common type of the sublist.");
}
rightType = commonType.get();
} else {
rightType = callContext.getArgumentDataTypes().get(1).getLogicalType();
}

// check if the types are comparable, if the types are not comparable, check if it is not
// a sub-query case like SELECT a IN (SELECT b FROM table1). We check if the result of the
// rightType is of a ROW type with a single column, and if that column is comparable with
// left type
if (!LogicalTypeChecks.areComparable(
leftType.getLogicalType(),
rightType,
StructuredType.StructuredComparison.EQUALS)
&& !isComparableWithSubQuery(leftType.getLogicalType(), rightType)) {
return callContext.fail(
throwOnFailure,
"Types on the right side of IN operator (%s) are not comparable with %s.",
rightType,
leftType.getLogicalType());
}

return Optional.of(
Stream.concat(
Stream.of(leftType),
IntStream.range(1, callContext.getArgumentDataTypes().size())
.mapToObj(
i ->
TypeConversions.fromLogicalToDataType(
rightType)))
.collect(Collectors.toList()));
}

private static boolean isComparableWithSubQuery(LogicalType left, LogicalType right) {
if (right.is(LogicalTypeRoot.ROW) && right.getChildren().size() == 1) {
final RowType rowType = (RowType) right;
return LogicalTypeChecks.areComparable(
left, rowType.getTypeAt(0), StructuredType.StructuredComparison.EQUALS);
}
return false;
}

@Override
public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
return Arrays.asList(
Signature.of(
Signature.Argument.ofGroup("COMPARABLE"),
Signature.Argument.ofGroupVarying("COMPARABLE")),
Signature.of(
Signature.Argument.ofGroup("COMPARABLE"),
Signature.Argument.ofGroup("SUBQUERY")));
}
}
Loading