Merged
@@ -419,7 +419,7 @@ protected void addToSelectList(
* A special {@link SqlSelect} to capture the origin of a {@link SqlKind#EXPLICIT_TABLE} within
* TVF operands.
*/
-    private static class ExplicitTableSqlSelect extends SqlSelect {
+    static class ExplicitTableSqlSelect extends SqlSelect {

private final List<SqlIdentifier> descriptors;

@@ -19,14 +19,14 @@ package org.apache.flink.table.planner.calcite

import org.apache.flink.sql.parser.`type`.SqlMapTypeNameSpec
import org.apache.flink.table.api.ValidationException
-import org.apache.flink.table.planner.calcite.PreValidateReWriter.{newValidationError, notSupported}
-import org.apache.flink.table.planner.calcite.SqlRewriterUtils.{rewriteSqlCall, rewriteSqlSelect, rewriteSqlValues}
+import org.apache.flink.table.planner.calcite.FlinkCalciteSqlValidator.ExplicitTableSqlSelect
+import org.apache.flink.table.planner.calcite.SqlRewriterUtils.{rewriteSqlCall, rewriteSqlSelect, rewriteSqlValues, rewriteSqlWith}
import org.apache.flink.util.Preconditions.checkArgument

import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
import org.apache.calcite.runtime.{CalciteContextException, Resources}
import org.apache.calcite.sql.`type`.SqlTypeUtil
-import org.apache.calcite.sql.{SqlCall, SqlDataTypeSpec, SqlKind, SqlNode, SqlNodeList, SqlOrderBy, SqlSelect, SqlUtil}
+import org.apache.calcite.sql.{SqlBasicCall, SqlCall, SqlDataTypeSpec, SqlIdentifier, SqlKind, SqlNode, SqlNodeList, SqlOrderBy, SqlSelect, SqlUtil, SqlWith}
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.sql.validate.SqlValidatorException
@@ -54,6 +54,14 @@ class SqlRewriterUtils(validator: FlinkCalciteSqlValidator) {
rewriteSqlValues(svalues, targetRowType, assignedFields, targetPosition)
}

+  def rewriteWith(
+      sqlWith: SqlWith,
+      targetRowType: RelDataType,
+      assignedFields: util.LinkedHashMap[Integer, SqlNode],
+      targetPosition: util.List[Int]): SqlCall = {
+    rewriteSqlWith(validator, sqlWith, targetRowType, assignedFields, targetPosition)
+  }

def rewriteCall(
rewriterUtils: SqlRewriterUtils,
validator: FlinkCalciteSqlValidator,
@@ -130,8 +138,11 @@ object SqlRewriterUtils {
call.getKind match {
case SqlKind.SELECT =>
val sqlSelect = call.asInstanceOf[SqlSelect]

-        if (targetPosition.nonEmpty && sqlSelect.getSelectList.size() != targetPosition.size()) {
+        if (
+          targetPosition.nonEmpty && sqlSelect.getSelectList.size() != targetPosition.size()
+          && sqlSelect.getSelectList.count(
+            s => s.isInstanceOf[SqlIdentifier] && s.asInstanceOf[SqlIdentifier].isStar) == 0
+        ) {
throw newValidationError(call, RESOURCE.columnCountMismatch())
}
rewriterUtils.rewriteSelect(sqlSelect, targetRowType, assignedFields, targetPosition)
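Note: the widened guard above intentionally skips the eager column-count check when the select list still contains a star, since the true arity is only known after star expansion; the post-expansion check happens in reorderAndValidateForSelect further down. A minimal standalone sketch of the same star test (not part of the PR; it uses Calcite's default parser and an invented table name):

```scala
import org.apache.calcite.sql.{SqlIdentifier, SqlSelect}
import org.apache.calcite.sql.parser.SqlParser
import scala.collection.JavaConverters._

object StarGuardProbe {
  def main(args: Array[String]): Unit = {
    val select = SqlParser
      .create("SELECT *, 42 FROM t")
      .parseQuery()
      .asInstanceOf[SqlSelect]
    // Same predicate as the guard: count star identifiers in the select list.
    val stars = select.getSelectList.getList.asScala.count {
      case id: SqlIdentifier => id.isStar
      case _ => false
    }
    println(stars) // 1 -> the pre-expansion size check is skipped
  }
}
```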
@@ -157,13 +168,58 @@
operands.get(1).asInstanceOf[SqlNodeList],
operands.get(2),
operands.get(3))
-      // Not support:
-      // case SqlKind.WITH =>
-      // case SqlKind.EXPLICIT_TABLE =>
+      case SqlKind.EXPLICIT_TABLE =>
+        val operands = call.getOperandList
+        val expTable = new ExplicitTableSqlSelect(
+          operands.get(0).asInstanceOf[SqlIdentifier],
+          Collections.emptyList())
+        rewriterUtils.rewriteSelect(expTable, targetRowType, assignedFields, targetPosition)
+      case SqlKind.WITH =>
+        rewriterUtils.rewriteWith(
+          call.asInstanceOf[SqlWith],
+          targetRowType,
+          assignedFields,
+          targetPosition)
case _ => throw new ValidationException(unsupportedErrorMessage())
}
}
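Note: the match above dispatches on the SqlKind of the INSERT source. A quick, PR-independent probe with Calcite's standalone parser (default config, not Flink's extended parser; the commented kinds are what I would expect, not output taken from the PR):

```scala
import org.apache.calcite.sql.parser.SqlParser

object KindProbe {
  private def kindOf(sql: String) = SqlParser.create(sql).parseQuery().getKind

  def main(args: Array[String]): Unit = {
    println(kindOf("SELECT 1, 2"))                              // SELECT
    println(kindOf("VALUES (1, 2)"))                            // VALUES
    println(kindOf("TABLE t"))                                  // EXPLICIT_TABLE
    println(kindOf("WITH cte AS (SELECT 1) SELECT * FROM cte")) // WITH
    println(kindOf("SELECT 1 ORDER BY 1"))                      // ORDER_BY
  }
}
```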

+  def rewriteSqlWith(
+      validator: FlinkCalciteSqlValidator,
+      cte: SqlWith,
+      targetRowType: RelDataType,
+      assignedFields: util.LinkedHashMap[Integer, SqlNode],
+      targetPosition: util.List[Int]): SqlCall = {
+    // Expands the select list first in case there is a star(*).
+    // Validates the select first to register the where scope.
+    validator.validate(cte)
+    val selects = new util.ArrayList[SqlSelect]()
+    extractSelectsFromCte(cte.body.asInstanceOf[SqlCall], selects)
+
+    for (select <- selects) {
+      reorderAndValidateForSelect(validator, select, targetRowType, assignedFields, targetPosition)
+    }
+    cte
+  }

+  def extractSelectsFromCte(cte: SqlCall, selects: util.List[SqlSelect]): Unit = {
[Review comment – Contributor]: Nice!
I kinda don't like that selects is passed in instead of returned. I tried refactoring it, but it doesn't look much better.

+    cte match {
+      case select: SqlSelect =>
+        selects.add(select)
+        return
+      case _ =>
+    }
+    for (s <- cte.getOperandList) {
+      s match {
+        case select: SqlSelect =>
+          selects.add(select)
+        case call: SqlCall =>
+          extractSelectsFromCte(call, selects)
+        case _ =>
+      }
+    }
+  }
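Note: a rough sketch of the returned-list variant mentioned in the review thread above. The name collectSelects is invented and the code is untested; it assumes the same traversal rules as extractSelectsFromCte (a top-level SELECT is returned as-is, SELECT operands are collected without descending into them, and other calls are recursed into):

```scala
import org.apache.calcite.sql.{SqlCall, SqlSelect}
import scala.collection.JavaConverters._

object SelectCollector {
  // Mirrors extractSelectsFromCte's traversal, but returns the result
  // instead of filling an accumulator passed in by the caller.
  def collectSelects(node: SqlCall): Seq[SqlSelect] = node match {
    case select: SqlSelect => Seq(select)
    case call =>
      call.getOperandList.asScala.toSeq.flatMap {
        case select: SqlSelect => Seq(select)
        case inner: SqlCall => collectSelects(inner)
        case _ => Seq.empty[SqlSelect] // null or non-call operands are skipped
      }
  }
}
```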

def rewriteSqlSelect(
validator: FlinkCalciteSqlValidator,
select: SqlSelect,
@@ -173,69 +229,73 @@
    // Expands the select list first in case there is a star(*).
    // Validates the select first to register the where scope.
    validator.validate(select)
-    val sourceList = validator.expandStar(select.getSelectList, select, false).getList
+    reorderAndValidateForSelect(validator, select, targetRowType, assignedFields, targetPosition)
+    select
+  }

+  def rewriteSqlValues(
+      values: SqlCall,
+      targetRowType: RelDataType,
+      assignedFields: util.LinkedHashMap[Integer, SqlNode],
+      targetPosition: util.List[Int]): SqlCall = {
+    val fixedNodes = new util.ArrayList[SqlNode]
+    (0 until values.getOperandList.size()).foreach {
+      valueIdx =>
+        val value = values.getOperandList.get(valueIdx)
+        val valueAsList = if (value.getKind == SqlKind.ROW) {
+          value.asInstanceOf[SqlCall].getOperandList
+        } else {
+          Collections.singletonList(value)
+        }
+        val nodes = getReorderedNodes(targetRowType, assignedFields, targetPosition, valueAsList)
+        fixedNodes.add(SqlStdOperatorTable.ROW.createCall(value.getParserPosition, nodes))
+    }
+    SqlStdOperatorTable.VALUES.createCall(values.getParserPosition, fixedNodes)
+  }

+  private def getReorderedNodes(
[Review comment – Contributor (author)]: This mostly extracts the common logic into one place to reduce code duplication.

+      targetRowType: RelDataType,
+      assignedFields: util.LinkedHashMap[Integer, SqlNode],
+      targetPosition: util.List[Int],
+      valueAsList: util.List[SqlNode]): util.List[SqlNode] = {
    val currentNodes =
      if (targetPosition.isEmpty) {
-        new util.ArrayList[SqlNode](sourceList)
+        new util.ArrayList[SqlNode](valueAsList)
      } else {
-        reorder(new util.ArrayList[SqlNode](sourceList), targetPosition)
+        reorder(new util.ArrayList[SqlNode](valueAsList), targetPosition)
      }

-    val fixedNodes = new util.ArrayList[SqlNode]
+    val fieldNodes = new util.ArrayList[SqlNode]
    (0 until targetRowType.getFieldList.length).foreach {
      idx =>
        if (assignedFields.containsKey(idx)) {
-          fixedNodes.add(assignedFields.get(idx))
+          fieldNodes.add(assignedFields.get(idx))
        } else if (currentNodes.size() > 0) {
-          fixedNodes.add(currentNodes.remove(0))
+          fieldNodes.add(currentNodes.remove(0))
        }
    }
    // Although it is error case, we still append the old remaining
-    // projection nodes to new projection.
+    // value items to new item list.
    if (currentNodes.size > 0) {
-      fixedNodes.addAll(currentNodes)
+      fieldNodes.addAll(currentNodes)
    }
-    select.setSelectList(new SqlNodeList(fixedNodes, select.getSelectList.getParserPosition))
-    select
+    fieldNodes
  }
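Note: a tiny self-contained walk-through of the merge loop in getReorderedNodes, with plain strings standing in for SqlNode and an invented three-field target row type. Field 1 is pre-assigned (as it would be for a column omitted from the INSERT column list), and the remaining source items fill the other positions in order:

```scala
import java.util

object MergeSketch {
  def main(args: Array[String]): Unit = {
    // Target row type has fields at indexes 0..2; field 1 gets a default.
    val assigned = new util.LinkedHashMap[Integer, String]()
    assigned.put(1, "DEFAULT_b")
    val current = new util.ArrayList[String](util.Arrays.asList("vA", "vC"))

    val fieldNodes = new util.ArrayList[String]
    (0 until 3).foreach { idx =>
      if (assigned.containsKey(idx)) fieldNodes.add(assigned.get(idx))
      else if (current.size() > 0) fieldNodes.add(current.remove(0))
    }
    if (current.size() > 0) fieldNodes.addAll(current) // error case: keep leftovers
    println(fieldNodes) // [vA, DEFAULT_b, vC]
  }
}
```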

-  def rewriteSqlValues(
-      values: SqlCall,
+  private def reorderAndValidateForSelect(
+      validator: FlinkCalciteSqlValidator,
+      select: SqlSelect,
      targetRowType: RelDataType,
      assignedFields: util.LinkedHashMap[Integer, SqlNode],
-      targetPosition: util.List[Int]): SqlCall = {
-    val fixedNodes = new util.ArrayList[SqlNode]
-    (0 until values.getOperandList.size()).foreach {
-      valueIdx =>
-        val value = values.getOperandList.get(valueIdx)
-        val valueAsList = if (value.getKind == SqlKind.ROW) {
-          value.asInstanceOf[SqlCall].getOperandList
-        } else {
-          Collections.singletonList(value)
-        }
-        val currentNodes =
-          if (targetPosition.isEmpty) {
-            new util.ArrayList[SqlNode](valueAsList)
-          } else {
-            reorder(new util.ArrayList[SqlNode](valueAsList), targetPosition)
-          }
-        val fieldNodes = new util.ArrayList[SqlNode]
-        (0 until targetRowType.getFieldList.length).foreach {
-          fieldIdx =>
-            if (assignedFields.containsKey(fieldIdx)) {
-              fieldNodes.add(assignedFields.get(fieldIdx))
-            } else if (currentNodes.size() > 0) {
-              fieldNodes.add(currentNodes.remove(0))
-            }
-        }
-        // Although it is error case, we still append the old remaining
-        // value items to new item list.
-        if (currentNodes.size > 0) {
-          fieldNodes.addAll(currentNodes)
-        }
-        fixedNodes.add(SqlStdOperatorTable.ROW.createCall(value.getParserPosition, fieldNodes))
-    }
-    SqlStdOperatorTable.VALUES.createCall(values.getParserPosition, fixedNodes)
+      targetPosition: util.List[Int]): Unit = {
+    val sourceList = validator.expandStar(select.getSelectList, select, false).getList
+
+    if (targetPosition.nonEmpty && sourceList.size() != targetPosition.size()) {
+      throw newValidationError(select, RESOURCE.columnCountMismatch())
+    }
+
+    val nodes = getReorderedNodes(targetRowType, assignedFields, targetPosition, sourceList)
+    select.setSelectList(new SqlNodeList(nodes, select.getSelectList.getParserPosition))
  }

def newValidationError(
@@ -24,6 +24,8 @@
import org.apache.flink.table.planner.utils.PlannerMocks;

import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -40,6 +42,21 @@ class FlinkCalciteSqlValidatorTest {
Schema.newBuilder()
.column("a", DataTypes.INT())
.column("b", DataTypes.INT())
-                            .build());
+                            .build())
+            .registerTemporaryTable(
+                    "t2_copy",
+                    Schema.newBuilder()
+                            .column("a", DataTypes.INT())
+                            .column("b", DataTypes.INT())
+                            .build())
+            .registerTemporaryTable(
+                    "t_nested",
+                    Schema.newBuilder()
+                            .column(
+                                    "f",
+                                    DataTypes.ROW(
+                                            DataTypes.FIELD("a", DataTypes.INT()),
+                                            DataTypes.FIELD("b", DataTypes.INT())))
+                            .build());

@Test
@@ -50,74 +67,67 @@ void testUpsertInto() {
"UPSERT INTO statement is not supported. Please use INSERT INTO instead.");
}

-    @Test
-    void testInsertIntoShouldColumnMismatchWithValues() {
-        assertThatThrownBy(() -> plannerMocks.getParser().parse("INSERT INTO t2 (a,b) VALUES(1)"))
-                .isInstanceOf(ValidationException.class)
-                .hasMessageContaining(" Number of columns must match number of query columns");
-    }
-
-    @Test
-    void testInsertIntoShouldColumnMismatchWithSelect() {
-        assertThatThrownBy(() -> plannerMocks.getParser().parse("INSERT INTO t2 (a,b) SELECT 1"))
-                .isInstanceOf(ValidationException.class)
-                .hasMessageContaining(" Number of columns must match number of query columns");
-    }
-
-    @Test
-    void testInsertIntoShouldColumnMismatchWithLastValue() {
-        assertThatThrownBy(
-                        () ->
-                                plannerMocks
-                                        .getParser()
-                                        .parse("INSERT INTO t2 (a,b) VALUES (1,2), (3)"))
-                .isInstanceOf(ValidationException.class)
-                .hasMessageContaining(" Number of columns must match number of query columns");
-    }
-
-    @Test
-    void testInsertIntoShouldColumnMismatchWithFirstValue() {
-        assertThatThrownBy(
-                        () ->
-                                plannerMocks
-                                        .getParser()
-                                        .parse("INSERT INTO t2 (a,b) VALUES (1), (2,3)"))
-                .isInstanceOf(ValidationException.class)
-                .hasMessageContaining(" Number of columns must match number of query columns");
-    }
-
-    @Test
-    void testInsertIntoShouldColumnMismatchWithMultiFieldValues() {
-        assertThatThrownBy(
-                        () ->
-                                plannerMocks
-                                        .getParser()
-                                        .parse("INSERT INTO t2 (a,b) VALUES (1,2), (3,4,5)"))
+    @ParameterizedTest
+    @ValueSource(
+            strings = {
+                "INSERT INTO t2 (a, b) VALUES (1)",
+                "INSERT INTO t2 (a, b) VALUES (1, 2), (3)",
+                "INSERT INTO t2 (a, b) VALUES (1), (2, 3)",
+                "INSERT INTO t2 (a, b) VALUES (1, 2), (3, 4, 5)",
+                "INSERT INTO t2 (a, b) SELECT 1",
+                "INSERT INTO t2 (a, b) SELECT COALESCE(123, 456), LEAST(1, 2), GREATEST(3, 4, 5)",
+                "INSERT INTO t2 (a, b) SELECT * FROM t1",
+                "INSERT INTO t2 (a, b) SELECT *, *, * FROM t1",
+                "INSERT INTO t2 (a, b) SELECT *, 42 FROM t2_copy",
+                "INSERT INTO t2 (a, b) SELECT 42, * FROM t2_copy",
+                "INSERT INTO t2 (a, b) SELECT * FROM t_nested",
+                "INSERT INTO t2 (a, b) TABLE t_nested",
+                "INSERT INTO t2 (a, b) SELECT * FROM (TABLE t_nested)",
+                "INSERT INTO t2 (a, b) WITH cte AS (SELECT 1, 2, 3) SELECT * FROM cte",
+                "INSERT INTO t2 (a, b) WITH cte AS (SELECT * FROM t1, t2_copy) SELECT * FROM cte",
+                "INSERT INTO t2 (a, b) WITH cte1 AS (SELECT 1, 2), cte2 AS (SELECT 2, 1) SELECT * FROM cte1, cte2"
+            })
+    void testInvalidNumberOfColumnsWhileInsertInto(String sql) {
+        assertThatThrownBy(() -> plannerMocks.getParser().parse(sql))
.isInstanceOf(ValidationException.class)
.hasMessageContaining(" Number of columns must match number of query columns");
}

-    @Test
-    void testInsertIntoShouldNotColumnMismatchWithValues() {
-        assertDoesNotThrow(
-                () -> {
-                    plannerMocks.getParser().parse("INSERT INTO t2 (a,b) VALUES (1,2), (3,4)");
-                });
-    }
-
-    @Test
-    void testInsertIntoShouldNotColumnMismatchWithSelect() {
-        assertDoesNotThrow(
-                () -> {
-                    plannerMocks.getParser().parse("INSERT INTO t2 (a,b) Select 1, 2");
-                });
-    }
-
-    @Test
-    void testInsertIntoShouldNotColumnMismatchWithSingleColValues() {
+    @ParameterizedTest
+    @ValueSource(
+            strings = {
+                "INSERT INTO t2 (a, b) VALUES (1, 2), (3, 4)",
+                "INSERT INTO t2 (a) VALUES (1), (3)",
+                "INSERT INTO t2 (a, b) SELECT 1, 2",
+                "INSERT INTO t2 (a, b) SELECT LEAST(1, 2, 3), 2 * 2",
+                "INSERT INTO t2 (a, b) SELECT * FROM t2_copy",
+                "INSERT INTO t2 (a, b) SELECT *, * FROM t1",
+                "INSERT INTO t2 (a, b) SELECT *, 42 FROM t1",
+                "INSERT INTO t2 (a, b) SELECT 42, * FROM t1",
+                "INSERT INTO t2 (a, b) SELECT f.* FROM t_nested",
+                "INSERT INTO t2 (a, b) TABLE t2_copy",
+                "INSERT INTO t2 (a, b) WITH cte AS (SELECT 1, 2) SELECT * FROM cte",
+                "INSERT INTO t2 (a, b) WITH cte AS (SELECT * FROM t2_copy) SELECT * FROM cte",
+                "INSERT INTO t2 (a, b) WITH cte AS (SELECT t1.a, t2_copy.b FROM t1, t2_copy) SELECT * FROM cte",
+                "INSERT INTO t2 (a, b) WITH cte1 AS (SELECT 1), cte2 AS (SELECT 2) SELECT * FROM cte1, cte2",
+                "INSERT INTO t2 (a, b) "
+                        + "WITH cte1 AS (SELECT 1, 2), "
+                        + "cte2 AS (SELECT 2, 3) "
+                        + "SELECT * FROM cte1 UNION SELECT * FROM cte2",
[Review comment – Contributor]: There are failing unit tests with CTEs above. Do we want some UNIONs which do not work above? Definitely a non-blocking suggestion.

[Reply – Contributor (author)]: With the latest commit all tests pass, so it shouldn't be an issue.

[Reply – Contributor]: Sounds good!
"INSERT INTO t2 (a, b) "
+ "WITH cte1 AS (SELECT 1, 2), "
+ "cte2 AS (SELECT 2, 3), "
+ "cte3 AS (SELECT 3, 4), "
+ "cte4 AS (SELECT 4, 5) "
+ "SELECT * FROM cte1 "
+ "UNION SELECT * FROM cte2 "
+ "INTERSECT SELECT * FROM cte3 "
+ "UNION ALL SELECT * FROM cte4"
})
void validInsertIntoTest(final String sql) {
assertDoesNotThrow(
() -> {
plannerMocks.getParser().parse("INSERT INTO t2 (a) VALUES (1), (3)");
plannerMocks.getParser().parse(sql);
});
}
