-
Notifications
You must be signed in to change notification settings - Fork 13.9k
[FLINK-36266][table] Insert into as select * behaves incorrect #25316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bb673d5
76d4cf7
72d057b
5cb6268
895110c
0d9565d
292703b
06378c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,14 +19,14 @@ package org.apache.flink.table.planner.calcite | |
|
|
||
| import org.apache.flink.sql.parser.`type`.SqlMapTypeNameSpec | ||
| import org.apache.flink.table.api.ValidationException | ||
| import org.apache.flink.table.planner.calcite.PreValidateReWriter.{newValidationError, notSupported} | ||
| import org.apache.flink.table.planner.calcite.SqlRewriterUtils.{rewriteSqlCall, rewriteSqlSelect, rewriteSqlValues} | ||
| import org.apache.flink.table.planner.calcite.FlinkCalciteSqlValidator.ExplicitTableSqlSelect | ||
| import org.apache.flink.table.planner.calcite.SqlRewriterUtils.{rewriteSqlCall, rewriteSqlSelect, rewriteSqlValues, rewriteSqlWith} | ||
| import org.apache.flink.util.Preconditions.checkArgument | ||
|
|
||
| import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} | ||
| import org.apache.calcite.runtime.{CalciteContextException, Resources} | ||
| import org.apache.calcite.sql.`type`.SqlTypeUtil | ||
| import org.apache.calcite.sql.{SqlCall, SqlDataTypeSpec, SqlKind, SqlNode, SqlNodeList, SqlOrderBy, SqlSelect, SqlUtil} | ||
| import org.apache.calcite.sql.{SqlBasicCall, SqlCall, SqlDataTypeSpec, SqlIdentifier, SqlKind, SqlNode, SqlNodeList, SqlOrderBy, SqlSelect, SqlUtil, SqlWith} | ||
| import org.apache.calcite.sql.fun.SqlStdOperatorTable | ||
| import org.apache.calcite.sql.parser.SqlParserPos | ||
| import org.apache.calcite.sql.validate.SqlValidatorException | ||
|
|
@@ -54,6 +54,14 @@ class SqlRewriterUtils(validator: FlinkCalciteSqlValidator) { | |
| rewriteSqlValues(svalues, targetRowType, assignedFields, targetPosition) | ||
| } | ||
|
|
||
| def rewriteWith( | ||
| sqlWith: SqlWith, | ||
| targetRowType: RelDataType, | ||
| assignedFields: util.LinkedHashMap[Integer, SqlNode], | ||
| targetPosition: util.List[Int]): SqlCall = { | ||
| rewriteSqlWith(validator, sqlWith, targetRowType, assignedFields, targetPosition) | ||
| } | ||
|
|
||
| def rewriteCall( | ||
| rewriterUtils: SqlRewriterUtils, | ||
| validator: FlinkCalciteSqlValidator, | ||
|
|
@@ -130,8 +138,11 @@ object SqlRewriterUtils { | |
| call.getKind match { | ||
| case SqlKind.SELECT => | ||
| val sqlSelect = call.asInstanceOf[SqlSelect] | ||
|
|
||
| if (targetPosition.nonEmpty && sqlSelect.getSelectList.size() != targetPosition.size()) { | ||
| if ( | ||
| targetPosition.nonEmpty && sqlSelect.getSelectList.size() != targetPosition.size() | ||
| && sqlSelect.getSelectList.count( | ||
| s => s.isInstanceOf[SqlIdentifier] && s.asInstanceOf[SqlIdentifier].isStar) == 0 | ||
| ) { | ||
| throw newValidationError(call, RESOURCE.columnCountMismatch()) | ||
| } | ||
| rewriterUtils.rewriteSelect(sqlSelect, targetRowType, assignedFields, targetPosition) | ||
|
|
@@ -157,13 +168,58 @@ object SqlRewriterUtils { | |
| operands.get(1).asInstanceOf[SqlNodeList], | ||
| operands.get(2), | ||
| operands.get(3)) | ||
| // Not support: | ||
| // case SqlKind.WITH => | ||
| // case SqlKind.EXPLICIT_TABLE => | ||
| case SqlKind.EXPLICIT_TABLE => | ||
| val operands = call.getOperandList | ||
| val expTable = new ExplicitTableSqlSelect( | ||
| operands.get(0).asInstanceOf[SqlIdentifier], | ||
| Collections.emptyList()) | ||
| rewriterUtils.rewriteSelect(expTable, targetRowType, assignedFields, targetPosition) | ||
| case SqlKind.WITH => | ||
| rewriterUtils.rewriteWith( | ||
| call.asInstanceOf[SqlWith], | ||
| targetRowType, | ||
| assignedFields, | ||
| targetPosition) | ||
| case _ => throw new ValidationException(unsupportedErrorMessage()) | ||
| } | ||
| } | ||
|
|
||
| def rewriteSqlWith( | ||
| validator: FlinkCalciteSqlValidator, | ||
| cte: SqlWith, | ||
| targetRowType: RelDataType, | ||
| assignedFields: util.LinkedHashMap[Integer, SqlNode], | ||
| targetPosition: util.List[Int]): SqlCall = { | ||
| // Expands the select list first in case there is a star(*). | ||
| // Validates the select first to register the where scope. | ||
| validator.validate(cte) | ||
| val selects = new util.ArrayList[SqlSelect]() | ||
| extractSelectsFromCte(cte.body.asInstanceOf[SqlCall], selects) | ||
|
|
||
| for (select <- selects) { | ||
| reorderAndValidateForSelect(validator, select, targetRowType, assignedFields, targetPosition) | ||
| } | ||
| cte | ||
| } | ||
|
|
||
| def extractSelectsFromCte(cte: SqlCall, selects: util.List[SqlSelect]): Unit = { | ||
| cte match { | ||
| case select: SqlSelect => | ||
| selects.add(select) | ||
| return | ||
| case _ => | ||
| } | ||
| for (s <- cte.getOperandList) { | ||
| s match { | ||
| case select: SqlSelect => | ||
| selects.add(select) | ||
| case call: SqlCall => | ||
| extractSelectsFromCte(call, selects) | ||
| case _ => | ||
| } | ||
| } | ||
| } | ||
|
|
||
| def rewriteSqlSelect( | ||
| validator: FlinkCalciteSqlValidator, | ||
| select: SqlSelect, | ||
|
|
@@ -173,69 +229,73 @@ object SqlRewriterUtils { | |
| // Expands the select list first in case there is a star(*). | ||
| // Validates the select first to register the where scope. | ||
| validator.validate(select) | ||
| val sourceList = validator.expandStar(select.getSelectList, select, false).getList | ||
| reorderAndValidateForSelect(validator, select, targetRowType, assignedFields, targetPosition) | ||
| select | ||
| } | ||
|
|
||
| def rewriteSqlValues( | ||
| values: SqlCall, | ||
| targetRowType: RelDataType, | ||
| assignedFields: util.LinkedHashMap[Integer, SqlNode], | ||
| targetPosition: util.List[Int]): SqlCall = { | ||
| val fixedNodes = new util.ArrayList[SqlNode] | ||
| (0 until values.getOperandList.size()).foreach { | ||
| valueIdx => | ||
| val value = values.getOperandList.get(valueIdx) | ||
| val valueAsList = if (value.getKind == SqlKind.ROW) { | ||
| value.asInstanceOf[SqlCall].getOperandList | ||
| } else { | ||
| Collections.singletonList(value) | ||
| } | ||
| val nodes = getReorderedNodes(targetRowType, assignedFields, targetPosition, valueAsList) | ||
| fixedNodes.add(SqlStdOperatorTable.ROW.createCall(value.getParserPosition, nodes)) | ||
| } | ||
| SqlStdOperatorTable.VALUES.createCall(values.getParserPosition, fixedNodes) | ||
| } | ||
|
|
||
| private def getReorderedNodes( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here I am mostly extracting the common logic into one place to reduce code duplication |
||
| targetRowType: RelDataType, | ||
| assignedFields: util.LinkedHashMap[Integer, SqlNode], | ||
| targetPosition: util.List[Int], | ||
| valueAsList: util.List[SqlNode]): util.List[SqlNode] = { | ||
| val currentNodes = | ||
| if (targetPosition.isEmpty) { | ||
| new util.ArrayList[SqlNode](sourceList) | ||
| new util.ArrayList[SqlNode](valueAsList) | ||
| } else { | ||
| reorder(new util.ArrayList[SqlNode](sourceList), targetPosition) | ||
| reorder(new util.ArrayList[SqlNode](valueAsList), targetPosition) | ||
| } | ||
|
|
||
| val fieldNodes = new util.ArrayList[SqlNode] | ||
| (0 until targetRowType.getFieldList.length).foreach { | ||
| idx => | ||
| if (assignedFields.containsKey(idx)) { | ||
| fixedNodes.add(assignedFields.get(idx)) | ||
| fieldNodes.add(assignedFields.get(idx)) | ||
| } else if (currentNodes.size() > 0) { | ||
| fixedNodes.add(currentNodes.remove(0)) | ||
| fieldNodes.add(currentNodes.remove(0)) | ||
| } | ||
| } | ||
| // Although it is error case, we still append the old remaining | ||
| // projection nodes to new projection. | ||
| // value items to new item list. | ||
| if (currentNodes.size > 0) { | ||
| fixedNodes.addAll(currentNodes) | ||
| fieldNodes.addAll(currentNodes) | ||
| } | ||
| select.setSelectList(new SqlNodeList(fixedNodes, select.getSelectList.getParserPosition)) | ||
| select | ||
| fieldNodes | ||
| } | ||
|
|
||
| def rewriteSqlValues( | ||
| values: SqlCall, | ||
| private def reorderAndValidateForSelect( | ||
| validator: FlinkCalciteSqlValidator, | ||
| select: SqlSelect, | ||
| targetRowType: RelDataType, | ||
| assignedFields: util.LinkedHashMap[Integer, SqlNode], | ||
| targetPosition: util.List[Int]): SqlCall = { | ||
| val fixedNodes = new util.ArrayList[SqlNode] | ||
| (0 until values.getOperandList.size()).foreach { | ||
| valueIdx => | ||
| val value = values.getOperandList.get(valueIdx) | ||
| val valueAsList = if (value.getKind == SqlKind.ROW) { | ||
| value.asInstanceOf[SqlCall].getOperandList | ||
| } else { | ||
| Collections.singletonList(value) | ||
| } | ||
| val currentNodes = | ||
| if (targetPosition.isEmpty) { | ||
| new util.ArrayList[SqlNode](valueAsList) | ||
| } else { | ||
| reorder(new util.ArrayList[SqlNode](valueAsList), targetPosition) | ||
| } | ||
| val fieldNodes = new util.ArrayList[SqlNode] | ||
| (0 until targetRowType.getFieldList.length).foreach { | ||
| fieldIdx => | ||
| if (assignedFields.containsKey(fieldIdx)) { | ||
| fieldNodes.add(assignedFields.get(fieldIdx)) | ||
| } else if (currentNodes.size() > 0) { | ||
| fieldNodes.add(currentNodes.remove(0)) | ||
| } | ||
| } | ||
| // Although it is error case, we still append the old remaining | ||
| // value items to new item list. | ||
| if (currentNodes.size > 0) { | ||
| fieldNodes.addAll(currentNodes) | ||
| } | ||
| fixedNodes.add(SqlStdOperatorTable.ROW.createCall(value.getParserPosition, fieldNodes)) | ||
| targetPosition: util.List[Int]): Unit = { | ||
| val sourceList = validator.expandStar(select.getSelectList, select, false).getList | ||
|
|
||
| if (targetPosition.nonEmpty && sourceList.size() != targetPosition.size()) { | ||
| throw newValidationError(select, RESOURCE.columnCountMismatch()) | ||
| } | ||
| SqlStdOperatorTable.VALUES.createCall(values.getParserPosition, fixedNodes) | ||
|
|
||
| val nodes = getReorderedNodes(targetRowType, assignedFields, targetPosition, sourceList) | ||
| select.setSelectList(new SqlNodeList(nodes, select.getSelectList.getParserPosition)) | ||
| } | ||
|
|
||
| def newValidationError( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,8 @@ | |
| import org.apache.flink.table.planner.utils.PlannerMocks; | ||
|
|
||
| import org.junit.jupiter.api.Test; | ||
| import org.junit.jupiter.params.ParameterizedTest; | ||
| import org.junit.jupiter.params.provider.ValueSource; | ||
|
|
||
| import static org.assertj.core.api.Assertions.assertThatThrownBy; | ||
| import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; | ||
|
|
@@ -40,6 +42,21 @@ class FlinkCalciteSqlValidatorTest { | |
| Schema.newBuilder() | ||
| .column("a", DataTypes.INT()) | ||
| .column("b", DataTypes.INT()) | ||
| .build()) | ||
| .registerTemporaryTable( | ||
| "t2_copy", | ||
| Schema.newBuilder() | ||
| .column("a", DataTypes.INT()) | ||
| .column("b", DataTypes.INT()) | ||
| .build()) | ||
| .registerTemporaryTable( | ||
| "t_nested", | ||
| Schema.newBuilder() | ||
| .column( | ||
| "f", | ||
| DataTypes.ROW( | ||
| DataTypes.FIELD("a", DataTypes.INT()), | ||
| DataTypes.FIELD("b", DataTypes.INT()))) | ||
| .build()); | ||
|
|
||
| @Test | ||
|
|
@@ -50,74 +67,67 @@ void testUpsertInto() { | |
| "UPSERT INTO statement is not supported. Please use INSERT INTO instead."); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldColumnMismatchWithValues() { | ||
| assertThatThrownBy(() -> plannerMocks.getParser().parse("INSERT INTO t2 (a,b) VALUES(1)")) | ||
| .isInstanceOf(ValidationException.class) | ||
| .hasMessageContaining(" Number of columns must match number of query columns"); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldColumnMismatchWithSelect() { | ||
| assertThatThrownBy(() -> plannerMocks.getParser().parse("INSERT INTO t2 (a,b) SELECT 1")) | ||
| .isInstanceOf(ValidationException.class) | ||
| .hasMessageContaining(" Number of columns must match number of query columns"); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldColumnMismatchWithLastValue() { | ||
| assertThatThrownBy( | ||
| () -> | ||
| plannerMocks | ||
| .getParser() | ||
| .parse("INSERT INTO t2 (a,b) VALUES (1,2), (3)")) | ||
| .isInstanceOf(ValidationException.class) | ||
| .hasMessageContaining(" Number of columns must match number of query columns"); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldColumnMismatchWithFirstValue() { | ||
| assertThatThrownBy( | ||
| () -> | ||
| plannerMocks | ||
| .getParser() | ||
| .parse("INSERT INTO t2 (a,b) VALUES (1), (2,3)")) | ||
| .isInstanceOf(ValidationException.class) | ||
| .hasMessageContaining(" Number of columns must match number of query columns"); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldColumnMismatchWithMultiFieldValues() { | ||
| assertThatThrownBy( | ||
| () -> | ||
| plannerMocks | ||
| .getParser() | ||
| .parse("INSERT INTO t2 (a,b) VALUES (1,2), (3,4,5)")) | ||
| @ParameterizedTest | ||
| @ValueSource( | ||
| strings = { | ||
| "INSERT INTO t2 (a, b) VALUES (1)", | ||
| "INSERT INTO t2 (a, b) VALUES (1, 2), (3)", | ||
| "INSERT INTO t2 (a, b) VALUES (1), (2, 3)", | ||
| "INSERT INTO t2 (a, b) VALUES (1, 2), (3, 4, 5)", | ||
| "INSERT INTO t2 (a, b) SELECT 1", | ||
| "INSERT INTO t2 (a, b) SELECT COALESCE(123, 456), LEAST(1, 2), GREATEST(3, 4, 5)", | ||
| "INSERT INTO t2 (a, b) SELECT * FROM t1", | ||
| "INSERT INTO t2 (a, b) SELECT *, *, * FROM t1", | ||
| "INSERT INTO t2 (a, b) SELECT *, 42 FROM t2_copy", | ||
| "INSERT INTO t2 (a, b) SELECT 42, * FROM t2_copy", | ||
| "INSERT INTO t2 (a, b) SELECT * FROM t_nested", | ||
| "INSERT INTO t2 (a, b) TABLE t_nested", | ||
| "INSERT INTO t2 (a, b) SELECT * FROM (TABLE t_nested)", | ||
| "INSERT INTO t2 (a, b) WITH cte AS (SELECT 1, 2, 3) SELECT * FROM cte", | ||
| "INSERT INTO t2 (a, b) WITH cte AS (SELECT * FROM t1, t2_copy) SELECT * FROM cte", | ||
| "INSERT INTO t2 (a, b) WITH cte1 AS (SELECT 1, 2), cte2 AS (SELECT 2, 1) SELECT * FROM cte1, cte2" | ||
| }) | ||
| void testInvalidNumberOfColumnsWhileInsertInto(String sql) { | ||
| assertThatThrownBy(() -> plannerMocks.getParser().parse(sql)) | ||
| .isInstanceOf(ValidationException.class) | ||
| .hasMessageContaining(" Number of columns must match number of query columns"); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldNotColumnMismatchWithValues() { | ||
| assertDoesNotThrow( | ||
| () -> { | ||
| plannerMocks.getParser().parse("INSERT INTO t2 (a,b) VALUES (1,2), (3,4)"); | ||
| }); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldNotColumnMismatchWithSelect() { | ||
| assertDoesNotThrow( | ||
| () -> { | ||
| plannerMocks.getParser().parse("INSERT INTO t2 (a,b) Select 1, 2"); | ||
| }); | ||
| } | ||
|
|
||
| @Test | ||
| void testInsertIntoShouldNotColumnMismatchWithSingleColValues() { | ||
| @ParameterizedTest | ||
| @ValueSource( | ||
| strings = { | ||
| "INSERT INTO t2 (a, b) VALUES (1, 2), (3, 4)", | ||
| "INSERT INTO t2 (a) VALUES (1), (3)", | ||
| "INSERT INTO t2 (a, b) SELECT 1, 2", | ||
| "INSERT INTO t2 (a, b) SELECT LEAST(1, 2, 3), 2 * 2", | ||
| "INSERT INTO t2 (a, b) SELECT * FROM t2_copy", | ||
| "INSERT INTO t2 (a, b) SELECT *, * FROM t1", | ||
| "INSERT INTO t2 (a, b) SELECT *, 42 FROM t1", | ||
| "INSERT INTO t2 (a, b) SELECT 42, * FROM t1", | ||
| "INSERT INTO t2 (a, b) SELECT f.* FROM t_nested", | ||
| "INSERT INTO t2 (a, b) TABLE t2_copy", | ||
| "INSERT INTO t2 (a, b) WITH cte AS (SELECT 1, 2) SELECT * FROM cte", | ||
| "INSERT INTO t2 (a, b) WITH cte AS (SELECT * FROM t2_copy) SELECT * FROM cte", | ||
| "INSERT INTO t2 (a, b) WITH cte AS (SELECT t1.a, t2_copy.b FROM t1, t2_copy) SELECT * FROM cte", | ||
| "INSERT INTO t2 (a, b) WITH cte1 AS (SELECT 1), cte2 AS (SELECT 2) SELECT * FROM cte1, cte2", | ||
| "INSERT INTO t2 (a, b) " | ||
| + "WITH cte1 AS (SELECT 1, 2), " | ||
| + "cte2 AS (SELECT 2, 3) " | ||
| + "SELECT * FROM cte1 UNION SELECT * FROM cte2", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There are failing unit tests with CTEs above. Do we want to add some UNIONs that do not work above? Definitely a non-blocking suggestion.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With the latest commit, all tests passed.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good! |
||
| "INSERT INTO t2 (a, b) " | ||
| + "WITH cte1 AS (SELECT 1, 2), " | ||
| + "cte2 AS (SELECT 2, 3), " | ||
| + "cte3 AS (SELECT 3, 4), " | ||
| + "cte4 AS (SELECT 4, 5) " | ||
| + "SELECT * FROM cte1 " | ||
| + "UNION SELECT * FROM cte2 " | ||
| + "INTERSECT SELECT * FROM cte3 " | ||
| + "UNION ALL SELECT * FROM cte4" | ||
| }) | ||
| void validInsertIntoTest(final String sql) { | ||
| assertDoesNotThrow( | ||
| () -> { | ||
| plannerMocks.getParser().parse("INSERT INTO t2 (a) VALUES (1), (3)"); | ||
| plannerMocks.getParser().parse(sql); | ||
| }); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
I kinda don't like that `selects` is passed in instead of returned. I tried refactoring it, but it doesn't look much better.