Skip to content

Commit

Permalink
[CALCITE-6300] Function MAP_VALUES/MAP_KEYS gives exception when mapV…
Browse files Browse the repository at this point in the history
…auleType and mapKeyType not equals map Biggest mapKeytype or mapValueType
  • Loading branch information
caicancai committed May 5, 2024
1 parent f0dc2b0 commit b1d9628
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,47 @@ private static RelDataType mapReturnType(SqlOperatorBinding opBinding) {
false);
}

private static RelDataType mapKeyReturnType(SqlOperatorBinding opBinding) {
Pair<RelDataType, RelDataType> type = getAndAdjustComponentTypes(opBinding);
return SqlTypeUtil.createArrayType(
opBinding.getTypeFactory(),
requireNonNull(type.left, "inferred key type"),

Check failure on line 1170 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11)

[Task :core:compileJava] [dereference.of.nullable] dereference of possibly-null reference type requireNonNull(type.left, "inferred key type"), ^

Check failure on line 1170 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11), oldest Guava

[Task :core:compileJava] [dereference.of.nullable] dereference of possibly-null reference type requireNonNull(type.left, "inferred key type"), ^
false);
}

@SuppressWarnings("argument.type.incompatible")
private static RelDataType mapValueReturnType(SqlOperatorBinding opBinding) {
Pair<RelDataType, RelDataType> type = getAndAdjustComponentTypes(opBinding);
return SqlTypeUtil.createArrayType(
opBinding.getTypeFactory(),
requireNonNull(type.right, "inferred value type"),

Check failure on line 1179 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11)

[Task :core:compileJava] [dereference.of.nullable] dereference of possibly-null reference type requireNonNull(type.right, "inferred value type"), ^

Check failure on line 1179 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11), oldest Guava

[Task :core:compileJava] [dereference.of.nullable] dereference of possibly-null reference type requireNonNull(type.right, "inferred value type"), ^
false);
}

private static @Nullable Pair<RelDataType, RelDataType> getAndAdjustComponentTypes(
SqlOperatorBinding opBinding) {
List<RelDataType> operandType = new ArrayList<>();

RelDataType keyType = opBinding.collectOperandTypes().get(0).getKeyType();
RelDataType valueType = opBinding.collectOperandTypes().get(0).getValueType();

requireNonNull(keyType, () -> "keyType of " + keyType);
requireNonNull(valueType, () -> "valuetype left of " + valueType);
operandType.add(keyType);
operandType.add(valueType);
Pair<@Nullable RelDataType, @Nullable RelDataType> type =
getComponentTypes(
opBinding.getTypeFactory(), operandType);

requireNonNull(type.left, () -> "type left of " + type.left);
requireNonNull(type.right, () -> "type right of " + type.right);
if (type.left.getSqlTypeName() != SqlTypeName.UNKNOWN
&& type.right.getSqlTypeName() != SqlTypeName.UNKNOWN) {
SqlValidatorUtil.adjustTypeForMapFunctionConstructor(type, opBinding);

Check failure on line 1202 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11)

[Task :core:compileJava] [argument.type.incompatible] incompatible argument for parameter componentType of adjustTypeForMapFunctionConstructor. SqlValidatorUtil.adjustTypeForMapFunctionConstructor(type, opBinding); ^ found : @initialized @nonnull Pair<@initialized @nullable RelDataType, @initialized @nullable RelDataType>

Check failure on line 1202 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11), oldest Guava

[Task :core:compileJava] [argument.type.incompatible] incompatible argument for parameter componentType of adjustTypeForMapFunctionConstructor. SqlValidatorUtil.adjustTypeForMapFunctionConstructor(type, opBinding); ^ found : @initialized @nonnull Pair<@initialized @nullable RelDataType, @initialized @nullable RelDataType>
}
return type;

Check failure on line 1204 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11)

[Task :core:compileJava] [return.type.incompatible] incompatible types in return. return type; ^ type of expression: @initialized @nonnull Pair<@initialized @nullable RelDataType, @initialized @nullable RelDataType>

Check failure on line 1204 in core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java

View workflow job for this annotation

GitHub Actions / CheckerFramework (JDK 11), oldest Guava

[Task :core:compileJava] [return.type.incompatible] incompatible types in return. return type; ^ type of expression: @initialized @nonnull Pair<@initialized @nullable RelDataType, @initialized @nullable RelDataType>
}

private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes(
RelDataTypeFactory typeFactory,
List<RelDataType> argTypes) {
Expand Down Expand Up @@ -1514,14 +1555,14 @@ private static RelDataType deriveTypeMapConcat(SqlOperatorBinding opBinding) {
@LibraryOperator(libraries = {SPARK})
public static final SqlFunction MAP_KEYS =
SqlBasicFunction.create(SqlKind.MAP_KEYS,
ReturnTypes.TO_MAP_KEYS_NULLABLE,
SqlLibraryOperators::mapKeyReturnType,
OperandTypes.MAP);

/** The "MAP_VALUES(map)" function. */
@LibraryOperator(libraries = {SPARK})
public static final SqlFunction MAP_VALUES =
SqlBasicFunction.create(SqlKind.MAP_VALUES,
ReturnTypes.TO_MAP_VALUES_NULLABLE,
SqlLibraryOperators::mapValueReturnType,
OperandTypes.MAP);

/** The "MAP_CONTAINS_KEY(map, key)" function. */
Expand Down
24 changes: 0 additions & 24 deletions core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -706,30 +706,6 @@ public static SqlCall stripSeparator(SqlCall call) {
public static final SqlReturnTypeInference TO_MAP_ENTRIES_NULLABLE =
TO_MAP_ENTRIES.andThen(SqlTypeTransforms.TO_NULLABLE);

/**
* Returns a ARRAY type.
*
* <p>For example, given {@code (INTEGER, DATE) MAP}, returns
* {@code INTEGER ARRAY}.
*/
public static final SqlReturnTypeInference TO_MAP_KEYS =
ARG0.andThen(SqlTypeTransforms.TO_MAP_KEYS);

public static final SqlReturnTypeInference TO_MAP_KEYS_NULLABLE =
TO_MAP_KEYS.andThen(SqlTypeTransforms.TO_NULLABLE);

/**
* Returns a ARRAY type.
*
* <p>For example, given {@code (INTEGER, DATE) MAP}, returns
* {@code DATE ARRAY}.
*/
public static final SqlReturnTypeInference TO_MAP_VALUES =
ARG0.andThen(SqlTypeTransforms.TO_MAP_VALUES);

public static final SqlReturnTypeInference TO_MAP_VALUES_NULLABLE =
TO_MAP_VALUES.andThen(SqlTypeTransforms.TO_NULLABLE);

/**
* Type-inference strategy that always returns GEOMETRY.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,23 @@ public static void adjustTypeForMapConstructor(
}
}

/**
* When the map key or value does not equal the map component key type or value type,
* make explicit casting.
*
* @param componentType derived map pair component type
* @param opBinding description of call
*/
public static void adjustTypeForMapFunctionConstructor(
Pair<RelDataType, RelDataType> componentType, SqlOperatorBinding opBinding) {
if (opBinding instanceof SqlCallBinding) {
requireNonNull(componentType.getKey(), "map key type");
requireNonNull(componentType.getValue(), "map value type");
adjustTypeForMapFunctionConstructor(
componentType.getKey(), componentType.getValue(), (SqlCallBinding) opBinding);
}
}

/**
* Adjusts the types for operands in a SqlCallBinding during the construction of a sql collection
* type such as Array or Map. This method iterates from the operands of a {@link SqlCall}
Expand Down Expand Up @@ -1427,6 +1444,54 @@ private static void adjustTypeForMultisetConstructor(
}
}

/**
* Adjusts the types of specified operands in a map operation to match a given target type.
* This is particularly useful in the context of SQL operations involving array functions,
* where it's necessary to ensure that all operands have consistent types for the operation
* to be valid.
*
* <p>This method operates on the assumption that the operands to be adjusted are part of a
* {@link SqlCall}, which is bound within a {@link SqlOperatorBinding}. The operands to be
* cast are identified by their indexes within the {@code operands} list of the {@link SqlCall}.
* The method performs a dynamic check to determine if an operand is a basic call to a map.
* If so, it casts each element within the map to the target type.
* Otherwise, it casts the operand itself to the target type.
*
* <p>Example usage: For an operation like
* {@code map_values(map('foo', 1, 'bar', cast(1 as double)))},
* if map's value targetType is double, this method would ensure that the value of the
* first map are cast to double.
*
* @param evenType the {@link RelDataType} to which the operands at even positions should be cast
* @param oddType the {@link RelDataType} to which the operands at odd positions should be cast
* @param sqlCallBinding the {@link SqlCallBinding} containing the operands to be adjusted
*/
private static void adjustTypeForMapFunctionConstructor(
RelDataType evenType, RelDataType oddType, SqlCallBinding sqlCallBinding) {
SqlCall call = sqlCallBinding.getCall();
List<SqlNode> operands = ((SqlBasicCall) call.getOperandList().get(0)).getOperandList();
RelDataType operandTypes;
List<SqlNode> operandsmap = new ArrayList<>();
RelDataType elementType;
for (int i = 0; i < operands.size(); i++) {
if (i % 2 == 0) {
elementType = evenType;
operandTypes = sqlCallBinding.collectOperandTypes().get(0).getKeyType();
} else {
elementType = oddType;
operandTypes = sqlCallBinding.collectOperandTypes().get(0).getValueType();
}
requireNonNull(operandTypes, "operandType of" + operandTypes);

if (!operandTypes.equalsSansFieldNames(elementType)) {
operandsmap.add(i, castTo(operands.get(i), elementType));
} else {
operandsmap.add(i, operands.get(i));
}
}
call.setOperand(0, castMapTo(operandsmap));
}

/**
* Creates a CAST operation to cast a given {@link SqlNode} to a specified {@link RelDataType}.
* This method uses the {@link SqlStdOperatorTable#CAST} operator to create a new {@link SqlCall}
Expand Down Expand Up @@ -1468,6 +1533,21 @@ private static SqlNode castArrayElementTo(SqlNode node, RelDataType type) {
return node;
}

/**
* Wraps the given list of {@link SqlNode} elements into a MAP value constructor.
* This method creates a new {@link SqlCall} node using the {@link SqlStdOperatorTable#MAP_VALUE_CONSTRUCTOR}
* operator, representing a MAP value constructor with the provided list of nodes as its operands.
*
* @param node the {@link SqlNode} which is to be cast
* @return a new {@link SqlNode} representing the CAST operation
*/
private static SqlNode castMapTo(List<SqlNode> node) {
SqlNodeList operandList = new SqlNodeList(SqlParserPos.ZERO);
operandList.addAll(node);
SqlCall mapNode = SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall(operandList);
return mapNode;
}

//~ Inner Classes ----------------------------------------------------------

/**
Expand Down
24 changes: 24 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7443,6 +7443,17 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
"CHAR(3) NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_keys(map('foo', 1, null, 2))", "[foo, null]",
"CHAR(3) ARRAY NOT NULL");

f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, 2, 2))", "[1, 2]",
"INTEGER NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(2 as double), 2))", "[1.0, 2.0]",
"DOUBLE NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(2 as float), 2))", "[1.0, 2.0]",
"FLOAT NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(null as float), 2))", "[1.0, null]",
"FLOAT ARRAY NOT NULL");
f1.checkScalar("map_keys(map(cast(1 as tinyint), 1, cast(null as double), 2))", "[1.0, null]",
"DOUBLE ARRAY NOT NULL");
}

/** Tests {@code MAP_VALUES} function from Spark. */
Expand All @@ -7469,6 +7480,19 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
"INTEGER NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_values(map('foo', 1, 'bar', cast(null as integer)))", "[1, null]",
"INTEGER ARRAY NOT NULL");

f1.checkScalar("map_values(map('foo', null))", "[null]",
"NULL ARRAY NOT NULL");
f1.checkScalar("map_values(map('foo', 1, 'bar', cast(1 as tinyint)))", "[1, 1]",
"INTEGER NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_values(map('foo', 1, 'bar', cast(1 as double)))", "[1.0, 1.0]",
"DOUBLE NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_values(map('foo', 1, 'bar', cast(1 as float)))", "[1.0, 1.0]",
"FLOAT NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_values(map('foo', 1, 'bar', cast(null as float)))", "[1.0, null]",
"FLOAT ARRAY NOT NULL");
f1.checkScalar("map_values(map('foo', 1, 'bar', cast(null as double)))", "[1.0, null]",
"DOUBLE ARRAY NOT NULL");
}

/** Test case for
Expand Down

0 comments on commit b1d9628

Please sign in to comment.