Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
Expand All @@ -41,6 +42,7 @@
import org.apache.calcite.rex.RexWindow;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlBinaryOperator;
import org.apache.calcite.sql.SqlCall;
Expand Down Expand Up @@ -92,6 +94,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -1013,6 +1016,12 @@ public Builder builder(RelNode rel, Clause... clauses) {
needNew = true;
}
}
if (rel instanceof LogicalAggregate
&& !dialect.supportsNestedAggregations()
&& hasNestedAggregations((LogicalAggregate) rel)) {
needNew = true;
}

SqlSelect select;
Expressions.FluentList<Clause> clauseList = Expressions.list();
if (needNew) {
Expand Down Expand Up @@ -1056,6 +1065,32 @@ public SqlNode field(int ordinal) {
needNew ? null : aliases);
}

private boolean hasNestedAggregations(LogicalAggregate rel) {
List<AggregateCall> aggCallList = rel.getAggCallList();
HashSet<Integer> aggregatesArgs = new HashSet<>();
for (AggregateCall aggregateCall: aggCallList) {
aggregatesArgs.addAll(aggregateCall.getArgList());
}
for (Integer aggregatesArg : aggregatesArgs) {
if (!(node instanceof SqlSelect)) {
continue;
}
SqlNode selectNode = ((SqlSelect) node).getSelectList().get(aggregatesArg);
if (!(selectNode instanceof SqlBasicCall)) {
continue;
}
for (SqlNode operand : ((SqlBasicCall) selectNode).getOperands()) {
if (operand instanceof SqlCall) {
final SqlOperator operator = ((SqlCall) operand).getOperator();
if (operator instanceof SqlAggFunction) {
return true;
}
}
}
}
return false;
}

// make private?
public Clause maxClause() {
Clause maxClause = null;
Expand Down
15 changes: 15 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/SqlDialect.java
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,21 @@ public boolean supportsOffsetFetch() {
}
}

/**
* Returns whether the dialect supports nested aggregations, for instance
* {@code SELECT SUM(SUM(1)) }.
*/
public boolean supportsNestedAggregations() {
switch (databaseProduct) {
case MYSQL:
case VERTICA:
case POSTGRESQL:
return false;
default:
return true;
}
}

/** Returns how NULL values are sorted if an ORDER BY item does not contain
* NULLS ASCENDING or NULLS DESCENDING. */
public NullCollation getNullCollation() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,40 @@ private void checkLiteral(String s) {
sql(sql).ok(expected);
}

@Test public void testDialectsLackingSupportForNestedAggregationsShouldUseSubSelectInstead() {
final String query = "select\n"
+ " SUM(\"net_weight1\") as \"net_weight_converted\"\n"
+ " from ("
+ " select\n"
+ " SUM(\"net_weight\") as \"net_weight1\"\n"
+ " from \"foodmart\".\"product\"\n"
+ " group by \"product_id\")";
final String expectedOracle = "SELECT SUM(SUM(\"net_weight\")) \"net_weight_converted\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "GROUP BY \"product_id\"";
final String expectedMySQL = "SELECT SUM(`net_weight1`) AS `net_weight_converted`\n"
+ "FROM (SELECT SUM(`net_weight`) AS `net_weight1`\n"
+ "FROM `foodmart`.`product`\n"
+ "GROUP BY `product_id`) AS `t1`";
final String expectedVertica = "SELECT SUM(\"net_weight1\") AS \"net_weight_converted\"\n"
+ "FROM (SELECT SUM(\"net_weight\") AS \"net_weight1\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "GROUP BY \"product_id\") AS \"t1\"";
final String expectedPostgresql = "SELECT SUM(\"net_weight1\") AS \"net_weight_converted\"\n"
+ "FROM (SELECT SUM(\"net_weight\") AS \"net_weight1\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "GROUP BY \"product_id\") AS \"t1\"";
sql(query)
.dialect(DatabaseProduct.ORACLE.getDialect())
.ok(expectedOracle)
.dialect(DatabaseProduct.MYSQL.getDialect())
.ok(expectedMySQL)
.dialect(DatabaseProduct.VERTICA.getDialect())
.ok(expectedVertica)
.dialect(DatabaseProduct.POSTGRESQL.getDialect())
.ok(expectedPostgresql);
}

/** Fluid interface to run tests. */
private static class Sql {
private CalciteAssert.SchemaSpec schemaSpec;
Expand Down