diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelCrossType.java b/core/src/main/java/org/apache/calcite/rel/type/RelCrossType.java index 31445cc3d3a..f39f2b67737 100644 --- a/core/src/main/java/org/apache/calcite/rel/type/RelCrossType.java +++ b/core/src/main/java/org/apache/calcite/rel/type/RelCrossType.java @@ -31,7 +31,7 @@ public class RelCrossType extends RelDataTypeImpl { //~ Instance fields -------------------------------------------------------- - public final ImmutableList types; + private final ImmutableList types; //~ Constructors ----------------------------------------------------------- @@ -58,6 +58,15 @@ public RelCrossType( return false; } + /** + * Returns the contained types. + * + * @return data types. + */ + public List getTypes() { + return types; + } + @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { sb.append("CrossType("); for (Ord type : Ord.zip(types)) { diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java index d4715f2ea07..3f3b0904552 100644 --- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java +++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java @@ -449,11 +449,11 @@ private static List getFieldList(List types) { * Returns a list of all atomic types in a list. */ private static void getTypeList( - ImmutableList inTypes, + List inTypes, List flatTypes) { for (RelDataType inType : inTypes) { if (inType instanceof RelCrossType) { - getTypeList(((RelCrossType) inType).types, flatTypes); + getTypeList(((RelCrossType) inType).getTypes(), flatTypes); } else { flatTypes.add(inType); } @@ -470,7 +470,7 @@ private static void addFields( List fieldList) { if (type instanceof RelCrossType) { final RelCrossType crossType = (RelCrossType) type; - for (RelDataType type1 : crossType.types) { + for (RelDataType type1 : crossType.getTypes()) { addFields(type1, fieldList); } } else { diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index b6076cc786e..c1a479d038e 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -22,6 +22,7 @@ import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.type.DynamicRecordType; +import org.apache.calcite.rel.type.RelCrossType; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -3500,9 +3501,9 @@ protected void validateJoin(SqlJoin join, SqlValidatorScope scope) { (List) getCondition(join); // Parser ensures that using clause is not empty. - Preconditions.checkArgument(list.size() > 0, "Empty USING clause"); + Preconditions.checkArgument(!list.isEmpty(), "Empty USING clause"); for (SqlIdentifier id : list) { - validateCommonJoinColumn(id, left, right, scope); + validateCommonJoinColumn(id, left, right, scope, natural); } break; default: @@ -3521,7 +3522,7 @@ protected void validateJoin(SqlJoin join, SqlValidatorScope scope) { for (String name : deriveNaturalJoinColumnList(join)) { final SqlIdentifier id = new SqlIdentifier(name, join.isNaturalNode().getParserPosition()); - validateCommonJoinColumn(id, left, right, scope); + validateCommonJoinColumn(id, left, right, scope, natural); } } @@ -3585,16 +3586,17 @@ private void validateNoAggs(AggFinder aggFinder, SqlNode node, } } - /** Validates a column in a USING clause, or an inferred join key in a - * NATURAL join. */ + /** Validates a column in a USING clause, or an inferred join key in a NATURAL join. */ private void validateCommonJoinColumn(SqlIdentifier id, SqlNode left, - SqlNode right, SqlValidatorScope scope) { + SqlNode right, SqlValidatorScope scope, boolean natural) { if (id.names.size() != 1) { throw newValidationError(id, RESOURCE.columnNotFound(id.toString())); } - final RelDataType leftColType = validateCommonInputJoinColumn(id, left, scope); - final RelDataType rightColType = validateCommonInputJoinColumn(id, right, scope); + final RelDataType leftColType = natural + ? checkAndDeriveDataType(id, left) + : validateCommonInputJoinColumn(id, left, scope, natural); + final RelDataType rightColType = validateCommonInputJoinColumn(id, right, scope, natural); if (!SqlTypeUtil.isComparable(leftColType, rightColType)) { throw newValidationError(id, RESOURCE.naturalOrUsingColumnNotCompatible(id.getSimple(), @@ -3602,10 +3604,21 @@ private void validateCommonJoinColumn(SqlIdentifier id, SqlNode left, } } + private RelDataType checkAndDeriveDataType(SqlIdentifier id, SqlNode node) { + Preconditions.checkArgument(id.names.size() == 1); + String name = id.names.get(0); + SqlNameMatcher nameMatcher = getCatalogReader().nameMatcher(); + RelDataType rowType = getNamespaceOrThrow(node).getRowType(); + RelDataType colType = requireNonNull( + nameMatcher.field(rowType, name), + () -> "unable to find left field " + name + " in " + rowType).getType(); + return colType; + } + /** Validates a column in a USING clause, or an inferred join key in a * NATURAL join, in the left or right input to the join. */ private RelDataType validateCommonInputJoinColumn(SqlIdentifier id, - SqlNode leftOrRight, SqlValidatorScope scope) { + SqlNode leftOrRight, SqlValidatorScope scope, boolean natural) { Preconditions.checkArgument(id.names.size() == 1); final String name = id.names.get(0); final SqlValidatorNamespace namespace = getNamespaceOrThrow(leftOrRight); @@ -3615,9 +3628,17 @@ private RelDataType validateCommonInputJoinColumn(SqlIdentifier id, if (field == null) { throw newValidationError(id, RESOURCE.columnNotFound(name)); } - if (nameMatcher.frequency(rowType.getFieldNames(), name) > 1) { - throw newValidationError(id, - RESOURCE.columnInUsingNotUnique(name)); + Collection rowTypes; + if (!natural && rowType instanceof RelCrossType) { + final RelCrossType crossType = (RelCrossType) rowType; + rowTypes = new ArrayList<>(crossType.getTypes()); + } else { + rowTypes = Collections.singleton(rowType); + } + for (RelDataType rowType0 : rowTypes) { + if (nameMatcher.frequency(rowType0.getFieldNames(), name) > 1) { + throw newValidationError(id, RESOURCE.columnInUsingNotUnique(name)); + } } checkRollUpInUsing(id, leftOrRight, scope); return field.getType(); diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java index b981d12dc5c..be8d40ecc6a 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java @@ -63,6 +63,7 @@ import com.google.common.collect.Ordering; import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -6163,7 +6164,7 @@ private ImmutableList cube(ImmutableBitSet... sets) { /** Test case for * [CALCITE-5171] * NATURAL join and USING should fail if join columns are not unique. */ - @Test void testNaturalJoinDuplicateColumns() { + @Test void testJoinDuplicateColumns() { // NATURAL join and USING should fail if join columns are not unique final String message = "Column name 'DEPTNO' in NATURAL join or " + "USING clause is not unique on one side of join"; @@ -6179,20 +6180,60 @@ private ImmutableList cube(ImmutableBitSet... sets) { + " using (^deptno^)") .fails(message); + // Reversed query gives reversed error message + sql("select e.ename, d.name\n" + + "from (select ename, sal as deptno, deptno from emp) as e\n" + + "join dept as d\n" + + " using (^deptno^)") + .fails(message); + // Also with "*". (Proves that FROM is validated before SELECT.) sql("select *\n" + "from emp\n" + "left join (select deptno, name as deptno from dept)\n" + " using (^deptno^)") .fails(message); + } - // Reversed query gives reversed error message - sql("select e.ename, d.name\n" - + "from (select ename, sal as deptno, deptno from emp) as e\n" - + "join dept as d\n" - + " using (^deptno^)") + @Test @DisplayName("Natural join require input column uniqueness") + void testNaturalJoinRequireInputColumnUniqueness() { + final String message = "Column name 'DEPTNO' in NATURAL join or " + + "USING clause is not unique on one side of join"; + // Invalid. NATURAL JOIN eliminates duplicate columns from its output but + // requires input columns to be unique. + sql("select *\n" + + "from (emp as e cross join dept as d)\n" + + "^natural^ join\n" + + "(emp as e2 cross join dept as d2)") .fails(message); + } + + @Test @DisplayName("Should produce two DEPTNO columns") + void testReturnsCorrectRowTypeOnCombinedJoin() { + sql("select *\n" + + "from emp as e\n" + + "natural join dept as d\n" + + "join (select deptno as x, deptno from dept) as d2" + + " on d2.deptno = e.deptno") + .type("RecordType(" + + "INTEGER NOT NULL DEPTNO, " + + "INTEGER NOT NULL EMPNO, " + + "VARCHAR(20) NOT NULL ENAME, " + + "VARCHAR(10) NOT NULL JOB, " + + "INTEGER MGR, " + + "TIMESTAMP(0) NOT NULL HIREDATE, " + + "INTEGER NOT NULL SAL, " + + "INTEGER NOT NULL COMM, " + + "BOOLEAN NOT NULL SLACKER, " + + "VARCHAR(10) NOT NULL NAME, " + + "INTEGER NOT NULL X, " + + "INTEGER NOT NULL DEPTNO1) NOT NULL"); + } + /** Test case for + * [CALCITE-5171] + * NATURAL join and USING should fail if join columns are not unique. */ + @Test void testCorrectJoinDuplicateColumns() { // The error only occurs if the duplicate column is referenced. The // following query has a duplicate hiredate column. sql("select e.ename, d.name\n" @@ -6200,6 +6241,10 @@ private ImmutableList cube(ImmutableBitSet... sets) { + "join (select ename, sal as hiredate, deptno from emp) as e\n" + " using (deptno)") .ok(); + + // Previous join chain does not affect validation. + sql("select * from EMP natural join EMPNULLABLES natural join DEPT") + .ok(); } @Test void testNaturalEmptyKey() { @@ -6347,9 +6392,7 @@ private ImmutableList cube(ImmutableBitSet... sets) { + "from emp as e\n" + "join dept as d using (deptno)\n" + "join dept as d2 using (^deptno^)"; - final String expected = "Column name 'DEPTNO' in NATURAL join or " - + "USING clause is not unique on one side of join"; - sql(sql1).fails(expected); + sql(sql1).ok(); final String sql2 = "select *\n" + "from emp as e\n"