diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java index e71ffe6f3c16..2b9340e08c42 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java @@ -28,6 +28,7 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rex.LogicVisitor; import org.apache.calcite.rex.RexCorrelVariable; @@ -39,6 +40,7 @@ import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlCountAggFunction; import org.apache.calcite.sql.fun.SqlQuantifyOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql2rel.RelDecorrelator; @@ -309,8 +311,22 @@ private RexNode rewriteExists(RexSubQuery e, Set variablesSet, } builder.as("dt"); + boolean generateNullsOnRight = true; + if (e.rel instanceof LogicalAggregate) { + // SELECT f0, count(*) FROM t GROUP BY () will always return 1 row. + // That means, the RHS never generates null. + final LogicalAggregate aggregate = (LogicalAggregate) e.rel; + if (aggregate.getGroupSet().isEmpty() + && aggregate.getAggCallList().stream() + .anyMatch(c -> c.getAggregation() instanceof SqlCountAggFunction)) { + generateNullsOnRight = false; + } + } - builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); + builder.join( + generateNullsOnRight ? JoinRelType.LEFT : JoinRelType.INNER, + builder.literal(true), + variablesSet); return builder.isNotNull(Util.last(builder.fields())); } diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java index 944a61d200be..74adb34c7905 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java @@ -41,13 +41,13 @@ import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.core.Sort; -import org.apache.calcite.rel.core.Values; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalCorrelate; import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalJoin; import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.logical.LogicalSnapshot; +import org.apache.calcite.rel.logical.LogicalTableFunctionScan; import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.rules.FilterCorrelateRule; @@ -432,10 +432,6 @@ public Frame decorrelateRel(Sort rel) { return register(rel, newSort, frame.oldToNewOutputs, frame.corDefOutputs); } - public Frame decorrelateRel(Values rel) { - // There are no inputs, so rel does not need to be changed. - return null; - } public Frame decorrelateRel(LogicalAggregate rel) { return decorrelateRel((Aggregate) rel); @@ -1017,6 +1013,13 @@ public Frame decorrelateRel(LogicalSnapshot rel) { return decorrelateRel((RelNode) rel); } + public Frame decorrelateRel(LogicalTableFunctionScan rel) { + if (RexUtil.containsCorrelation(rel.getCall())) { + return null; + } + return decorrelateRel((RelNode) rel); + } + public Frame decorrelateRel(LogicalFilter rel) { return decorrelateRel((Filter) rel); } @@ -1385,6 +1388,12 @@ private RelNode aggregateCorrelatorOutput( pair.left, projectPulledAboveLeftCorrelator, isCount); + // Fix the nullability. + if (projectPulledAboveLeftCorrelator) { + newProjExpr = relBuilder.getRexBuilder().makeAbstractCast( + relBuilder.getTypeFactory().createTypeWithNullability(newProjExpr.getType(), true), + newProjExpr); + } newProjects.add(Pair.of(newProjExpr, pair.right)); } diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index 7f9fc34ae9d0..6fa4dfb39755 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -3369,10 +3369,15 @@ private void checkLiteral2(String expression, String expected) { + " lateral (select d.\"department_id\" + 1 as d_plusOne" + " from (values(true)))"; - final String expected = "SELECT \"$cor0\".\"department_id\", \"$cor0\".\"D_PLUSONE\"\n" - + "FROM \"foodmart\".\"department\" AS \"$cor0\",\n" - + "LATERAL (SELECT \"$cor0\".\"department_id\" + 1 AS \"D_PLUSONE\"\n" - + "FROM (VALUES (TRUE)) AS \"t\" (\"EXPR$0\")) AS \"t0\""; + final String expected = "SELECT \"department\".\"department_id\", \"t2\".\"D_PLUSONE\"\n" + + "FROM \"foodmart\".\"department\"\n" + + "INNER JOIN (SELECT \"t1\".\"department_id\" + 1 AS \"D_PLUSONE\"," + + " \"t1\".\"department_id\"\n" + + "FROM (VALUES (TRUE)) AS \"t\" (\"EXPR$0\"),\n" + + "(SELECT \"department_id\"\n" + + "FROM \"foodmart\".\"department\"\n" + + "GROUP BY \"department_id\") AS \"t1\") AS \"t2\" " + + "ON \"department\".\"department_id\" = \"t2\".\"department_id\""; sql(sql).ok(expected); } diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index dff27236bc42..85535d7d7aa4 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -5618,6 +5618,30 @@ private Sql checkSubQuery(String sql) { checkSubQuery(sql).withLateDecorrelation(true).check(); } + @Test public void testDecorrelateExists2() throws Exception { + final String sql = "select \n" + + " exists (select count(t2.id) \n" + + " from (values(1), (2)) t2(id) where t2.id = t1.id)\n" + + " from (values(3), (4)) t1(id)"; + checkSubQuery(sql).withLateDecorrelation(true).check(); + } + + @Test public void testDecorrelateExists3() throws Exception { + final String sql = "select \n" + + " exists (select min(t2.id) \n" + + " from (values(1), (2)) t2(id) where t2.id = t1.id)\n" + + " from (values(3), (4)) t1(id)"; + checkSubQuery(sql).withLateDecorrelation(true).check(); + } + + @Test public void testDecorrelateScalarSubQuery() throws Exception { + final String sql = "select \n" + + " (select count(t2.id) \n" + + " from (values(1), (2)) t2(id) where t2.id = t1.id)\n" + + " from (values(1), (2)) t1(id)"; + checkSubQuery(sql).withLateDecorrelation(true).check(); + } + /** Test case for * [CALCITE-1511] * AssertionError while decorrelating query with two EXISTS diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 5323ed8e1145..5be8f2ba7eaf 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -714,6 +714,123 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$ LogicalProject(i=[true]) LogicalFilter(condition=[=($cor0.DEPTNO, $7)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index f8cace1338a0..82e6770d5650 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -5681,15 +5681,13 @@ select * from t where exists ( @@ -6102,11 +6100,13 @@ LogicalSort(sort0=[$0], dir0=[ASC]) @@ -6119,11 +6119,13 @@ LogicalProject(C=[$0], D=[$1]) @@ -6137,11 +6139,13 @@ LogicalProject(C=[$0], D=[$1]) @@ -6155,11 +6159,15 @@ LogicalProject(C=[$0], D=[$1], C0=[$2]) @@ -6176,14 +6184,16 @@ as r(n) on c=n]]> @@ -6197,11 +6207,15 @@ LogicalProject(C=[$0], N=[$1]) @@ -6219,15 +6233,17 @@ cross join lateral diff --git a/core/src/test/resources/sql/misc.iq b/core/src/test/resources/sql/misc.iq index a7127ac2b818..ba6bbf7d6f1a 100644 --- a/core/src/test/resources/sql/misc.iq +++ b/core/src/test/resources/sql/misc.iq @@ -706,6 +706,7 @@ EnumerableCalc(expr#0..1=[{inputs}], expr#2=[1], expr#3=[=($t0, $t2)], proj#0..1 # [HIVE-5873] Semi-join to count sub-query # [CALCITE-365] AssertionError while translating query with WITH and correlated sub-query +!if (false) { with parts (PNum, OrderOnHand) as (select * from (values (3, 6), (10, 1), (8, 0)) as t(PNum, OrderOnHand)), supply (PNum, Qty) @@ -724,6 +725,7 @@ where orderOnHand (2 rows) !ok +!} # [HIVE-7362] # Just checking that HAVING-EXISTS works. diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq index 595ca2bd0e17..0367f8026a73 100644 --- a/core/src/test/resources/sql/sub-query.iq +++ b/core/src/test/resources/sql/sub-query.iq @@ -441,23 +441,19 @@ where sal + 100 not in ( # [CALCITE-356] AssertionError while translating query with WITH and correlated sub-query with t (a, b) as (select * from (values (1, 2))) select * from t where exists (select 1 from "scott".emp where deptno = t.a); -EnumerableCalc(expr#0..2=[{inputs}], proj#0..1=[{exprs}]) - EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - EnumerableValues(tuples=[[{ 1, 2 }]]) - EnumerableAggregate(group=[{0}]) - EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true], expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.A], expr#12=[=($t9, $t11)], i=[$t8], $condition=[$t12]) - EnumerableTableScan(table=[[scott, EMP]]) +EnumerableHashJoin(condition=[=($0, $4)], joinType=[semi]) + EnumerableValues(tuples=[[{ 1, 2 }]]) + EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t7):INTEGER], expr#9=[null:BOOLEAN], expr#10=[IS NOT NULL($t8)], expr#11=[OR($t9, $t10)], EMPNO=[$t0], DEPTNO=[$t7], DEPTNO0=[$t8], $condition=[$t11]) + EnumerableTableScan(table=[[scott, EMP]]) !plan # Similar query, identical plan with t as (select * from (values (1, 2)) as t(a, b)) select * from t where exists (select 1 from "scott".emp where deptno = t.a); -EnumerableCalc(expr#0..2=[{inputs}], proj#0..1=[{exprs}]) - EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - EnumerableValues(tuples=[[{ 1, 2 }]]) - EnumerableAggregate(group=[{0}]) - EnumerableCalc(expr#0..7=[{inputs}], expr#8=[true], expr#9=[CAST($t7):INTEGER], expr#10=[$cor0], expr#11=[$t10.A], expr#12=[=($t9, $t11)], i=[$t8], $condition=[$t12]) - EnumerableTableScan(table=[[scott, EMP]]) +EnumerableHashJoin(condition=[=($0, $4)], joinType=[semi]) + EnumerableValues(tuples=[[{ 1, 2 }]]) + EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t7):INTEGER], expr#9=[null:BOOLEAN], expr#10=[IS NOT NULL($t8)], expr#11=[OR($t9, $t10)], EMPNO=[$t0], DEPTNO=[$t7], DEPTNO0=[$t8], $condition=[$t11]) + EnumerableTableScan(table=[[scott, EMP]]) !plan # Uncorrelated