From b1d9628d75159b90525acfe4d39219c5ec9a2c88 Mon Sep 17 00:00:00 2001 From: caicancai <2356672992@qq.com> Date: Wed, 13 Mar 2024 22:25:53 +0800 Subject: [PATCH] [CALCITE-6300] Function MAP_VALUES/MAP_KEYS gives exception when mapVauleType and mapKeyType not equals map Biggest mapKeytype or mapValueType --- .../calcite/sql/fun/SqlLibraryOperators.java | 45 ++++++++++- .../apache/calcite/sql/type/ReturnTypes.java | 24 ------ .../sql/validate/SqlValidatorUtil.java | 80 +++++++++++++++++++ .../apache/calcite/test/SqlOperatorTest.java | 24 ++++++ 4 files changed, 147 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 0c700d9de32..f5d6a9201c1 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -1163,6 +1163,47 @@ private static RelDataType mapReturnType(SqlOperatorBinding opBinding) { false); } + private static RelDataType mapKeyReturnType(SqlOperatorBinding opBinding) { + Pair type = getAndAdjustComponentTypes(opBinding); + return SqlTypeUtil.createArrayType( + opBinding.getTypeFactory(), + requireNonNull(type.left, "inferred key type"), + false); + } + + @SuppressWarnings("argument.type.incompatible") + private static RelDataType mapValueReturnType(SqlOperatorBinding opBinding) { + Pair type = getAndAdjustComponentTypes(opBinding); + return SqlTypeUtil.createArrayType( + opBinding.getTypeFactory(), + requireNonNull(type.right, "inferred value type"), + false); + } + + private static @Nullable Pair getAndAdjustComponentTypes( + SqlOperatorBinding opBinding) { + List operandType = new ArrayList<>(); + + RelDataType keyType = opBinding.collectOperandTypes().get(0).getKeyType(); + RelDataType valueType = opBinding.collectOperandTypes().get(0).getValueType(); + + requireNonNull(keyType, () -> "keyType of " + keyType); + requireNonNull(valueType, () -> "valuetype left of " + valueType); + operandType.add(keyType); + operandType.add(valueType); + Pair<@Nullable RelDataType, @Nullable RelDataType> type = + getComponentTypes( + opBinding.getTypeFactory(), operandType); + + requireNonNull(type.left, () -> "type left of " + type.left); + requireNonNull(type.right, () -> "type right of " + type.right); + if (type.left.getSqlTypeName() != SqlTypeName.UNKNOWN + && type.right.getSqlTypeName() != SqlTypeName.UNKNOWN) { + SqlValidatorUtil.adjustTypeForMapFunctionConstructor(type, opBinding); + } + return type; + } + private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes( RelDataTypeFactory typeFactory, List argTypes) { @@ -1514,14 +1555,14 @@ private static RelDataType deriveTypeMapConcat(SqlOperatorBinding opBinding) { @LibraryOperator(libraries = {SPARK}) public static final SqlFunction MAP_KEYS = SqlBasicFunction.create(SqlKind.MAP_KEYS, - ReturnTypes.TO_MAP_KEYS_NULLABLE, + SqlLibraryOperators::mapKeyReturnType, OperandTypes.MAP); /** The "MAP_VALUES(map)" function. */ @LibraryOperator(libraries = {SPARK}) public static final SqlFunction MAP_VALUES = SqlBasicFunction.create(SqlKind.MAP_VALUES, - ReturnTypes.TO_MAP_VALUES_NULLABLE, + SqlLibraryOperators::mapValueReturnType, OperandTypes.MAP); /** The "MAP_CONTAINS_KEY(map, key)" function. */ diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index 386ed96fb91..f1e6fcbe978 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -706,30 +706,6 @@ public static SqlCall stripSeparator(SqlCall call) { public static final SqlReturnTypeInference TO_MAP_ENTRIES_NULLABLE = TO_MAP_ENTRIES.andThen(SqlTypeTransforms.TO_NULLABLE); - /** - * Returns a ARRAY type. - * - *

For example, given {@code (INTEGER, DATE) MAP}, returns - * {@code INTEGER ARRAY}. - */ - public static final SqlReturnTypeInference TO_MAP_KEYS = - ARG0.andThen(SqlTypeTransforms.TO_MAP_KEYS); - - public static final SqlReturnTypeInference TO_MAP_KEYS_NULLABLE = - TO_MAP_KEYS.andThen(SqlTypeTransforms.TO_NULLABLE); - - /** - * Returns a ARRAY type. - * - *

For example, given {@code (INTEGER, DATE) MAP}, returns - * {@code DATE ARRAY}. - */ - public static final SqlReturnTypeInference TO_MAP_VALUES = - ARG0.andThen(SqlTypeTransforms.TO_MAP_VALUES); - - public static final SqlReturnTypeInference TO_MAP_VALUES_NULLABLE = - TO_MAP_VALUES.andThen(SqlTypeTransforms.TO_NULLABLE); - /** * Type-inference strategy that always returns GEOMETRY. */ diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java index 9aeec2da20f..787aa20e017 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java @@ -1389,6 +1389,23 @@ public static void adjustTypeForMapConstructor( } } + /** + * When the map key or value does not equal the map component key type or value type, + * make explicit casting. + * + * @param componentType derived map pair component type + * @param opBinding description of call + */ + public static void adjustTypeForMapFunctionConstructor( + Pair componentType, SqlOperatorBinding opBinding) { + if (opBinding instanceof SqlCallBinding) { + requireNonNull(componentType.getKey(), "map key type"); + requireNonNull(componentType.getValue(), "map value type"); + adjustTypeForMapFunctionConstructor( + componentType.getKey(), componentType.getValue(), (SqlCallBinding) opBinding); + } + } + /** * Adjusts the types for operands in a SqlCallBinding during the construction of a sql collection * type such as Array or Map. This method iterates from the operands of a {@link SqlCall} @@ -1427,6 +1444,54 @@ private static void adjustTypeForMultisetConstructor( } } + /** + * Adjusts the types of specified operands in a map operation to match a given target type. + * This is particularly useful in the context of SQL operations involving array functions, + * where it's necessary to ensure that all operands have consistent types for the operation + * to be valid. + * + *

This method operates on the assumption that the operands to be adjusted are part of a + * {@link SqlCall}, which is bound within a {@link SqlOperatorBinding}. The operands to be + * cast are identified by their indexes within the {@code operands} list of the {@link SqlCall}. + * The method performs a dynamic check to determine if an operand is a basic call to a map. + * If so, it casts each element within the map to the target type. + * Otherwise, it casts the operand itself to the target type. + * + *

Example usage: For an operation like + * {@code map_values(map('foo', 1, 'bar', cast(1 as double)))}, + * if map's value targetType is double, this method would ensure that the value of the + * first map are cast to double. + * + * @param evenType the {@link RelDataType} to which the operands at even positions should be cast + * @param oddType the {@link RelDataType} to which the operands at odd positions should be cast + * @param sqlCallBinding the {@link SqlCallBinding} containing the operands to be adjusted + */ + private static void adjustTypeForMapFunctionConstructor( + RelDataType evenType, RelDataType oddType, SqlCallBinding sqlCallBinding) { + SqlCall call = sqlCallBinding.getCall(); + List operands = ((SqlBasicCall) call.getOperandList().get(0)).getOperandList(); + RelDataType operandTypes; + List operandsmap = new ArrayList<>(); + RelDataType elementType; + for (int i = 0; i < operands.size(); i++) { + if (i % 2 == 0) { + elementType = evenType; + operandTypes = sqlCallBinding.collectOperandTypes().get(0).getKeyType(); + } else { + elementType = oddType; + operandTypes = sqlCallBinding.collectOperandTypes().get(0).getValueType(); + } + requireNonNull(operandTypes, "operandType of" + operandTypes); + + if (!operandTypes.equalsSansFieldNames(elementType)) { + operandsmap.add(i, castTo(operands.get(i), elementType)); + } else { + operandsmap.add(i, operands.get(i)); + } + } + call.setOperand(0, castMapTo(operandsmap)); + } + /** * Creates a CAST operation to cast a given {@link SqlNode} to a specified {@link RelDataType}. * This method uses the {@link SqlStdOperatorTable#CAST} operator to create a new {@link SqlCall} @@ -1468,6 +1533,21 @@ private static SqlNode castArrayElementTo(SqlNode node, RelDataType type) { return node; } + /** + * Wraps the given list of {@link SqlNode} elements into a MAP value constructor. + * This method creates a new {@link SqlCall} node using the {@link SqlStdOperatorTable#MAP_VALUE_CONSTRUCTOR} + * operator, representing a MAP value constructor with the provided list of nodes as its operands. + * + * @param node the {@link SqlNode} which is to be cast + * @return a new {@link SqlNode} representing the CAST operation + */ + private static SqlNode castMapTo(List node) { + SqlNodeList operandList = new SqlNodeList(SqlParserPos.ZERO); + operandList.addAll(node); + SqlCall mapNode = SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall(operandList); + return mapNode; + } + //~ Inner Classes ---------------------------------------------------------- /** diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 01f308adfe0..7107dcf8665 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -7443,6 +7443,17 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { "CHAR(3) NOT NULL ARRAY NOT NULL"); f1.checkScalar("map_keys(map('foo', 1, null, 2))", "[foo, null]", "CHAR(3) ARRAY NOT NULL"); + + f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, 2, 2))", "[1, 2]", + "INTEGER NOT NULL ARRAY NOT NULL"); + f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(2 as double), 2))", "[1.0, 2.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(2 as float), 2))", "[1.0, 2.0]", + "FLOAT NOT NULL ARRAY NOT NULL"); + f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(null as float), 2))", "[1.0, null]", + "FLOAT ARRAY NOT NULL"); + f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(null as double), 2))", "[1.0, null]", + "DOUBLE ARRAY NOT NULL"); } /** Tests {@code MAP_VALUES} function from Spark. */ @@ -7469,6 +7480,19 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { "INTEGER NOT NULL ARRAY NOT NULL"); f1.checkScalar("map_values(map('foo', 1, 'bar', cast(null as integer)))", "[1, null]", "INTEGER ARRAY NOT NULL"); + + f1.checkScalar("map_values(map('foo', null))", "[null]", + "NULL ARRAY NOT NULL"); + f1.checkScalar("map_values(map('foo', 1, 'bar', cast(1 as tinyint)))", "[1, 1]", + "INTEGER NOT NULL ARRAY NOT NULL"); + f1.checkScalar("map_values(map('foo', 1, 'bar', cast(1 as double)))", "[1.0, 1.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f1.checkScalar("map_values(map('foo', 1, 'bar', cast(1 as float)))", "[1.0, 1.0]", + "FLOAT NOT NULL ARRAY NOT NULL"); + f1.checkScalar("map_values(map('foo', 1, 'bar', cast(null as float)))", "[1.0, null]", + "FLOAT ARRAY NOT NULL"); + f1.checkScalar("map_values(map('foo', 1, 'bar', cast(null as double)))", "[1.0, null]", + "DOUBLE ARRAY NOT NULL"); } /** Test case for