Skip to content

Commit

Permalink
[CALCITE-4491] Fix aggregations of nested window functions
Browse files Browse the repository at this point in the history
Dialects which do not support nested aggregations now also don't support
nested window functions.

The added test covers both cases:
- a dialect which supports nested aggregations -> OracleSqlDialect
- a dialect which does not support nested aggregations -> PostgreSqlDialect
  • Loading branch information
Dominik Labuda committed Feb 9, 2021
1 parent cfa37c3 commit d45691e
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
Expand Up @@ -62,6 +62,7 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOverOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlSelectKeyword;
import org.apache.calcite.sql.SqlSetOperator;
Expand Down Expand Up @@ -1694,9 +1695,10 @@ private boolean needNewSubQuery(RelNode rel, List<Clause> clauses,

if (rel instanceof Aggregate) {
final Aggregate agg = (Aggregate) rel;
final boolean hasNestedAgg = hasNestedAggregations(agg);
final boolean hasNestedAgg = hasNested(agg, SqlAggFunction.class);
final boolean hasNestedWindows = hasNested(agg, SqlOverOperator.class);
if (!dialect.supportsNestedAggregations()
&& hasNestedAgg) {
&& (hasNestedAgg || hasNestedWindows)) {
return true;
}

Expand All @@ -1709,7 +1711,7 @@ private boolean needNewSubQuery(RelNode rel, List<Clause> clauses,
return false;
}

private boolean hasNestedAggregations(Aggregate rel) {
private <T> boolean hasNested(Aggregate rel, Class<T> type) {
if (node instanceof SqlSelect) {
final SqlNodeList selectList = ((SqlSelect) node).getSelectList();
if (selectList != null) {
Expand All @@ -1723,7 +1725,7 @@ private boolean hasNestedAggregations(Aggregate rel) {
(SqlBasicCall) selectList.get(aggregatesArg);
for (SqlNode operand : call.getOperands()) {
if (operand instanceof SqlCall
&& ((SqlCall) operand).getOperator() instanceof SqlAggFunction) {
&& type.isInstance(((SqlCall) operand).getOperator())) {
return true;
}
}
Expand Down
Expand Up @@ -36,6 +36,8 @@
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rel.type.RelDataTypeSystemImpl;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.schema.SchemaPlus;
Expand Down Expand Up @@ -76,9 +78,11 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
Expand Down Expand Up @@ -714,6 +718,72 @@ private static String toSql(RelNode root, SqlDialect dialect) {
assertThat(toSql(root), isLinux(expectedSql));
}

@Test void testAggregatedWindowFunction() {
final RelBuilder builder = relBuilder();
final RelNode root = builder
.scan("EMP")
.project(
builder.field("SAL")
)
.project(
builder.alias(
builder.getRexBuilder().makeOver(
builder.getTypeFactory().createSqlType(SqlTypeName.INTEGER),
SqlStdOperatorTable.RANK,
new ArrayList<>(),
new ArrayList<>(),
ImmutableList.of(
new RexFieldCollation(builder.field("SAL"), ImmutableSet.of())
),
RexWindowBounds.UNBOUNDED_PRECEDING,
RexWindowBounds.UNBOUNDED_FOLLOWING,
true,
true,
false,
false,
false
),
"rank"
)
)
.as("tmp")
.aggregate(
builder.groupKey(),
builder.count(
true,
"cnt",
builder.field("tmp", "rank")
)
)
.filter(
builder.call(
SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
builder.field("cnt"),
builder.literal(10)
)
)
.build();

final String expectedPostgreSql =
"SELECT COUNT(DISTINCT \"rank\") AS \"cnt\"\n"
+ "FROM (SELECT RANK() OVER (ORDER BY \"SAL\") AS \"rank\"\n"
+ "FROM \"scott\".\"EMP\") AS \"t\"\n"
+ "HAVING COUNT(DISTINCT \"rank\") >= 10";
assertThat(
toSql(root, PostgresqlSqlDialect.DEFAULT),
isLinux(expectedPostgreSql)
);

final String expectedOracleSql =
"SELECT COUNT(DISTINCT RANK() OVER (ORDER BY \"SAL\")) \"cnt\"\n"
+ "FROM \"scott\".\"EMP\"\n"
+ "HAVING COUNT(DISTINCT RANK() OVER (ORDER BY \"SAL\")) >= 10";
assertThat(
toSql(root, OracleSqlDialect.DEFAULT),
isLinux(expectedOracleSql)
);
}

@Test void testSemiJoin() {
final RelBuilder builder = relBuilder();
final RelNode root = builder
Expand Down

0 comments on commit d45691e

Please sign in to comment.