Skip to content

Commit

Permalink
[CALCITE-4491] Fix aggregations of nested window functions
Browse files Browse the repository at this point in the history
Dialects which do not support nested aggregations now also don't support
nested window functions.

The added test covers both cases:
- a dialect which supports nested aggregations -> OracleSqlDialect
- a dialect which does not support nested aggregations -> PostgreSqlDialect
  • Loading branch information
Dominik Labuda committed Feb 9, 2021
1 parent 99aa01a commit 6c5156d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 5 deletions.
Expand Up @@ -64,6 +64,7 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOverOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlSelectKeyword;
import org.apache.calcite.sql.SqlSetOperator;
Expand Down Expand Up @@ -1803,9 +1804,10 @@ private boolean needNewSubQuery(

if (rel instanceof Aggregate) {
final Aggregate agg = (Aggregate) rel;
final boolean hasNestedAgg = hasNestedAggregations(agg);
final boolean hasNestedAgg = hasNested(agg, SqlAggFunction.class);
final boolean hasNestedWindows = hasNested(agg, SqlOverOperator.class);
if (!dialect.supportsNestedAggregations()
&& hasNestedAgg) {
&& (hasNestedAgg || hasNestedWindows)) {
return true;
}

Expand All @@ -1818,9 +1820,10 @@ private boolean needNewSubQuery(
return false;
}

private boolean hasNestedAggregations(
private <T> boolean hasNested(
@UnknownInitialization Result this,
Aggregate rel) {
Aggregate rel,
Class<T> type) {
if (node instanceof SqlSelect) {
final SqlNodeList selectList = ((SqlSelect) node).getSelectList();
if (selectList != null) {
Expand All @@ -1834,7 +1837,7 @@ private boolean hasNestedAggregations(
(SqlBasicCall) selectList.get(aggregatesArg);
for (SqlNode operand : call.getOperands()) {
if (operand instanceof SqlCall
&& ((SqlCall) operand).getOperator() instanceof SqlAggFunction) {
&& type.isInstance(((SqlCall) operand).getOperator())) {
return true;
}
}
Expand Down
Expand Up @@ -36,6 +36,8 @@
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rel.type.RelDataTypeSystemImpl;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.schema.SchemaPlus;
Expand Down Expand Up @@ -84,6 +86,7 @@

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -821,6 +824,74 @@ private static String toSql(RelNode root, SqlDialect dialect,
assertThat(toSql(root), isLinux(expectedSql));
}

@Test void testAggregatedWindowFunction() {
final RelBuilder builder = relBuilder();
final RelNode root = builder
.scan("EMP")
.project(
builder.field("SAL")
)
.project(
builder.alias(
builder.getRexBuilder().makeOver(
builder.getTypeFactory().createSqlType(SqlTypeName.INTEGER),
SqlStdOperatorTable.RANK,
new ArrayList<>(),
new ArrayList<>(),
ImmutableList.of(
new RexFieldCollation(builder.field("SAL"), ImmutableSet.of())
),
RexWindowBounds.UNBOUNDED_PRECEDING,
RexWindowBounds.UNBOUNDED_FOLLOWING,
true,
true,
false,
false,
false
),
"rank"
)
)
.as("tmp")
.aggregate(
builder.groupKey(),
builder.count(
true,
"cnt",
builder.field("tmp", "rank")
)
)
.filter(
builder.call(
SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
builder.field("cnt"),
builder.literal(10)
)
)
.build();

// Database not supporting nested aggregations
final String expectedPostgreSql =
"SELECT COUNT(DISTINCT \"rank\") AS \"cnt\"\n"
+ "FROM (SELECT RANK() OVER (ORDER BY \"SAL\") AS \"rank\"\n"
+ "FROM \"scott\".\"EMP\") AS \"t\"\n"
+ "HAVING COUNT(DISTINCT \"rank\") >= 10";
assertThat(
toSql(root, PostgresqlSqlDialect.DEFAULT),
isLinux(expectedPostgreSql)
);

// Database supporting nested aggregations
final String expectedOracleSql =
"SELECT COUNT(DISTINCT RANK() OVER (ORDER BY \"SAL\")) \"cnt\"\n"
+ "FROM \"scott\".\"EMP\"\n"
+ "HAVING COUNT(DISTINCT RANK() OVER (ORDER BY \"SAL\")) >= 10";
assertThat(
toSql(root, OracleSqlDialect.DEFAULT),
isLinux(expectedOracleSql)
);
}

@Test void testSemiJoin() {
final RelBuilder builder = relBuilder();
final RelNode root = builder
Expand Down

0 comments on commit 6c5156d

Please sign in to comment.