diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java index 285cfe44bdad..3b987fcfacbf 100644 --- a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java +++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java @@ -23,6 +23,7 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexCall; @@ -41,6 +42,7 @@ import org.apache.calcite.rex.RexWindow; import org.apache.calcite.rex.RexWindowBound; import org.apache.calcite.sql.JoinType; +import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlBinaryOperator; import org.apache.calcite.sql.SqlCall; @@ -92,6 +94,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -1013,6 +1016,12 @@ public Builder builder(RelNode rel, Clause... clauses) { needNew = true; } } + if (rel instanceof LogicalAggregate + && !dialect.supportsNestedAggregations() + && hasNestedAggregations((LogicalAggregate) rel)) { + needNew = true; + } + SqlSelect select; Expressions.FluentList clauseList = Expressions.list(); if (needNew) { @@ -1056,6 +1065,32 @@ public SqlNode field(int ordinal) { needNew ? null : aliases); } + private boolean hasNestedAggregations(LogicalAggregate rel) { + List aggCallList = rel.getAggCallList(); + HashSet aggregatesArgs = new HashSet<>(); + for (AggregateCall aggregateCall: aggCallList) { + aggregatesArgs.addAll(aggregateCall.getArgList()); + } + for (Integer aggregatesArg : aggregatesArgs) { + if (!(node instanceof SqlSelect)) { + continue; + } + SqlNode selectNode = ((SqlSelect) node).getSelectList().get(aggregatesArg); + if (!(selectNode instanceof SqlBasicCall)) { + continue; + } + for (SqlNode operand : ((SqlBasicCall) selectNode).getOperands()) { + if (operand instanceof SqlCall) { + final SqlOperator operator = ((SqlCall) operand).getOperator(); + if (operator instanceof SqlAggFunction) { + return true; + } + } + } + } + return false; + } + // make private? public Clause maxClause() { Clause maxClause = null; diff --git a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java index 9eab2b742a21..14e604556a9a 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java @@ -552,6 +552,21 @@ public boolean supportsOffsetFetch() { } } + /** + * Returns whether the dialect supports nested aggregations, for instance + * {@code SELECT SUM(SUM(1)) }. + */ + public boolean supportsNestedAggregations() { + switch (databaseProduct) { + case MYSQL: + case VERTICA: + case POSTGRESQL: + return false; + default: + return true; + } + } + /** Returns how NULL values are sorted if an ORDER BY item does not contain * NULLS ASCENDING or NULLS DESCENDING. */ public NullCollation getNullCollation() { diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index 15725ece4cbf..02f5a1af206e 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -1898,6 +1898,40 @@ private void checkLiteral(String s) { sql(sql).ok(expected); } + @Test public void testDialectsLackingSupportForNestedAggregationsShouldUseSubSelectInstead() { + final String query = "select\n" + + " SUM(\"net_weight1\") as \"net_weight_converted\"\n" + + " from (" + + " select\n" + + " SUM(\"net_weight\") as \"net_weight1\"\n" + + " from \"foodmart\".\"product\"\n" + + " group by \"product_id\")"; + final String expectedOracle = "SELECT SUM(SUM(\"net_weight\")) \"net_weight_converted\"\n" + + "FROM \"foodmart\".\"product\"\n" + + "GROUP BY \"product_id\""; + final String expectedMySQL = "SELECT SUM(`net_weight1`) AS `net_weight_converted`\n" + + "FROM (SELECT SUM(`net_weight`) AS `net_weight1`\n" + + "FROM `foodmart`.`product`\n" + + "GROUP BY `product_id`) AS `t1`"; + final String expectedVertica = "SELECT SUM(\"net_weight1\") AS \"net_weight_converted\"\n" + + "FROM (SELECT SUM(\"net_weight\") AS \"net_weight1\"\n" + + "FROM \"foodmart\".\"product\"\n" + + "GROUP BY \"product_id\") AS \"t1\""; + final String expectedPostgresql = "SELECT SUM(\"net_weight1\") AS \"net_weight_converted\"\n" + + "FROM (SELECT SUM(\"net_weight\") AS \"net_weight1\"\n" + + "FROM \"foodmart\".\"product\"\n" + + "GROUP BY \"product_id\") AS \"t1\""; + sql(query) + .dialect(DatabaseProduct.ORACLE.getDialect()) + .ok(expectedOracle) + .dialect(DatabaseProduct.MYSQL.getDialect()) + .ok(expectedMySQL) + .dialect(DatabaseProduct.VERTICA.getDialect()) + .ok(expectedVertica) + .dialect(DatabaseProduct.POSTGRESQL.getDialect()) + .ok(expectedPostgresql); + } + /** Fluid interface to run tests. */ private static class Sql { private CalciteAssert.SchemaSpec schemaSpec;