Skip to content

Commit

Permalink
[CALCITE-6340] RelBuilder drops traits when aggregating over duplicat…
Browse files Browse the repository at this point in the history
…e projected fields

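Calculate the post-pruning RelTraitSet of the projection by applying the pruning target mapping with RelTraitSet#apply(Mappings.TargetMapping), instead of resetting it to the cluster's default trait set.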

Co-authored-by: Alessandro Solimando <alessandro.solimando@gmail.com>
  • Loading branch information
jduo and asolimando committed May 16, 2024
1 parent f854ef5 commit 0b2fd0b
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 2 deletions.
7 changes: 5 additions & 2 deletions core/src/main/java/org/apache/calcite/tools/RelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -2502,9 +2502,12 @@ private RelBuilder pruneAggregateInputFieldsAndDeduplicateAggCalls(
newProjects.add(project.getProjects().get(i));
builder.add(project.getRowType().getFieldList().get(i));
}

// This currently does not apply mappings correctly to the RelCollation due to
// https://issues.apache.org/jira/browse/CALCITE-6391
r =
project.copy(cluster.traitSet(), project.getInput(), newProjects,
builder.build());
project.copy(project.getTraitSet().apply(targetMapping), project.getInput(),
newProjects, builder.build());
} else {
groupSetAfterPruning = groupSet;
groupSetsAfterPruning = groupSets;
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/org/apache/calcite/util/Bug.java
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ public abstract class Bug {
* Fix failing quidem tests for FORMAT in CAST</a> is fixed. */
public static final boolean CALCITE_6375_FIXED = false;

/** Whether
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6391">[CALCITE-6391]
 * Apply mapping to RelCompositeTrait does not apply it to wrapped traits</a>
 * is fixed.
 *
 * <p>Code that checks trait sets produced by applying a mapping can branch on
 * this flag: strict comparison of the whole trait set when it is {@code true},
 * a relaxed comparison while it is {@code false}. */
public static final boolean CALCITE_6391_FIXED = false;

/**
* Use this to flag temporary code.
*/
Expand Down
219 changes: 219 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Correlate;
Expand Down Expand Up @@ -78,6 +79,7 @@
import org.apache.calcite.tools.RelRunners;
import org.apache.calcite.tools.RuleSet;
import org.apache.calcite.tools.RuleSets;
import org.apache.calcite.util.Bug;
import org.apache.calcite.util.Holder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
Expand Down Expand Up @@ -1450,6 +1452,223 @@ private RexNode caseCall(RelBuilder b, RexNode ref, RexNode... nodes) {
assertThat(root, hasTree(expected));
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6340">[CALCITE-6340]
 * RelBuilder drops traits when aggregating over duplicate projected fields</a>.
 *
 * <p>The input is sorted on fields [1 DESC NULLS LAST, 0] and projected as
 * (0, 1, 0, 1); the aggregate groups on field 0 only, so no collation is
 * expected on the aggregate's input, but the convention must be retained.
 *
 * <p>While {@link Bug#CALCITE_6391_FIXED} is {@code false}, only the trait at
 * index 0 of the trait set is compared with the expectation. */
@Test void testPruneProjectInputOfAggregatePreservesConventionAndCollationsWhenEmpty() {
  final RelBuilder builder = createBuilder(config -> config.withPruneInputOfAggregate(true));

  // Logical plan: a two-key sort below a projection that repeats both fields.
  final RelNode logicalPlan = builder
      .scan("EMP")
      .sort(builder.nullsLast(builder.desc(builder.field(1))),
          builder.field(0))
      .project(builder.alias(builder.field(0), "a"),
          builder.alias(builder.field(1), "b"),
          builder.alias(builder.field(0), "c"),
          builder.alias(builder.field(1), "d"))
      .build();

  // Turn the logical plan into a physical plan so that a convention is set.
  final RelTraitSet desiredTraits = builder.getCluster().traitSet()
      .replace(EnumerableConvention.INSTANCE);
  final RuleSet prepareRules =
      RuleSets.ofList(EnumerableRules.ENUMERABLE_PROJECT_RULE,
          EnumerableRules.ENUMERABLE_SORT_RULE,
          EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
  final RelNode physicalPlan =
      Programs.of(prepareRules)
          .run(logicalPlan.getCluster().getPlanner(), logicalPlan, desiredTraits,
              ImmutableList.of(), ImmutableList.of());

  // Collations are lost as the sort is on columns [1, 0] but we group on 0;
  // the convention stays.
  final RelNode aggregate = builder.push(physicalPlan)
      .aggregate(builder.groupKey(0),
          builder.aggregateCall(SqlStdOperatorTable.SUM, builder.field(0)))
      .build();

  final RelTraitSet expectedTraitSet = builder.getCluster().traitSet()
      .replace(EnumerableConvention.INSTANCE);
  final RelTraitSet actualTraitSet = aggregate.getInput(0).getTraitSet();

  // Check the convention on the traits actually produced by the builder: the
  // expected trait set contains it by construction, so asserting on that set
  // could never fail.
  assertTrue(actualTraitSet.contains(EnumerableConvention.INSTANCE));

  if (Bug.CALCITE_6391_FIXED) {
    assertThat(actualTraitSet, is(expectedTraitSet));
  } else {
    assertThat(actualTraitSet.get(0), is(expectedTraitSet.get(0)));
  }
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6340">[CALCITE-6340]
 * RelBuilder drops traits when aggregating over duplicate projected fields</a>.
 *
 * <p>The input is sorted on a single key (field 1, descending, nulls last)
 * and projected as (0, 1, 0, 1); the aggregate groups on field 1. The
 * expected trait set of the aggregate's input holds the enumerable convention
 * and a descending, nulls-last collation on field 0.
 *
 * <p>While {@link Bug#CALCITE_6391_FIXED} is {@code false}, only the trait at
 * index 0 of the trait set is compared with the expectation. */
@Test void testPruneProjectInputOfAggregatePreservesConventionAndSingletonCollation() {
  final RelBuilder b = createBuilder(c -> c.withPruneInputOfAggregate(true));

  // Logical plan: single-key sort below a projection that repeats both fields.
  b.scan("EMP");
  b.sort(b.nullsLast(b.desc(b.field(1))));
  b.project(b.alias(b.field(0), "a"),
      b.alias(b.field(1), "b"),
      b.alias(b.field(0), "c"),
      b.alias(b.field(1), "d"));
  final RelNode logicalPlan = b.build();

  // Turn the logical plan into a physical plan so that a convention can be set.
  final RelTraitSet enumerableTraits =
      b.getCluster().traitSet().replace(EnumerableConvention.INSTANCE);
  final RuleSet conversionRules =
      RuleSets.ofList(EnumerableRules.ENUMERABLE_PROJECT_RULE,
          EnumerableRules.ENUMERABLE_SORT_RULE,
          EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
  final Program toEnumerable = Programs.of(conversionRules);
  final RelNode physicalPlan =
      toEnumerable.run(logicalPlan.getCluster().getPlanner(), logicalPlan,
          enumerableTraits, ImmutableList.of(), ImmutableList.of());

  // Aggregate on field 1 of the physical plan.
  b.push(physicalPlan);
  b.aggregate(b.groupKey(1),
      b.aggregateCall(SqlStdOperatorTable.SUM, b.field(1)));
  final RelNode aggregate = b.build();

  final RelFieldCollation descNullsLast =
      new RelFieldCollation(0, RelFieldCollation.Direction.DESCENDING,
          RelFieldCollation.NullDirection.LAST);
  final RelTraitSet expectedTraitSet = b.getCluster().traitSet()
      .replace(EnumerableConvention.INSTANCE)
      .replace(RelCollations.of(descNullsLast));

  final RelTraitSet actualTraitSet = aggregate.getInput(0).getTraitSet();
  if (!Bug.CALCITE_6391_FIXED) {
    assertThat(actualTraitSet.get(0), is(expectedTraitSet.get(0)));
  } else {
    assertThat(actualTraitSet, is(expectedTraitSet));
  }
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6340">[CALCITE-6340]
 * RelBuilder drops traits when aggregating over duplicate projected fields</a>.
 *
 * <p>The input is sorted on two keys (field 1 descending with nulls last,
 * then field 0) and projected as (0, 1, 0, 1); the aggregate groups on
 * field 1. The expected trait set of the aggregate's input holds the
 * enumerable convention and a descending, nulls-last collation on field 0.
 *
 * <p>While {@link Bug#CALCITE_6391_FIXED} is {@code false}, only the trait at
 * index 0 of the trait set is compared with the expectation. */
@Test void testPruneProjectInputOfAggregatePreservesConventionAndCompositeCollation() {
  final RelBuilder relBuilder =
      createBuilder(cfg -> cfg.withPruneInputOfAggregate(true));

  // Logical plan: two-key sort below a projection that repeats both fields.
  final RelNode sortedProject = relBuilder
      .scan("EMP")
      .sort(relBuilder.nullsLast(relBuilder.desc(relBuilder.field(1))),
          relBuilder.field(0))
      .project(relBuilder.alias(relBuilder.field(0), "a"),
          relBuilder.alias(relBuilder.field(1), "b"),
          relBuilder.alias(relBuilder.field(0), "c"),
          relBuilder.alias(relBuilder.field(1), "d"))
      .build();

  // Turn the logical plan into a physical plan so that a convention can be set.
  final RelTraitSet targetTraits = relBuilder.getCluster().traitSet()
      .replace(EnumerableConvention.INSTANCE);
  final Program physicalProgram =
      Programs.of(
          RuleSets.ofList(EnumerableRules.ENUMERABLE_PROJECT_RULE,
              EnumerableRules.ENUMERABLE_SORT_RULE,
              EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE));
  final RelNode enumerablePlan =
      physicalProgram.run(sortedProject.getCluster().getPlanner(), sortedProject,
          targetTraits, ImmutableList.of(), ImmutableList.of());

  // Aggregate on field 1 of the physical plan.
  relBuilder.push(enumerablePlan);
  final RelNode aggregate = relBuilder
      .aggregate(relBuilder.groupKey(1),
          relBuilder.aggregateCall(SqlStdOperatorTable.SUM, relBuilder.field(1)))
      .build();

  final RelTraitSet expectedTraitSet = relBuilder.getCluster().traitSet()
      .replace(EnumerableConvention.INSTANCE)
      .replace(
          RelCollations.of(
              new RelFieldCollation(0, RelFieldCollation.Direction.DESCENDING,
                  RelFieldCollation.NullDirection.LAST)));

  final RelNode aggregateInput = aggregate.getInput(0);
  if (Bug.CALCITE_6391_FIXED) {
    assertThat(aggregateInput.getTraitSet(), is(expectedTraitSet));
  } else {
    assertThat(aggregateInput.getTraitSet().get(0), is(expectedTraitSet.get(0)));
  }
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6340">[CALCITE-6340]
 * RelBuilder drops traits when aggregating over duplicate projected fields</a>.
 *
 * <p>The input is an unsorted projection (0, 0, 1) that is converted to the
 * enumerable convention and then copied with a broadcast distribution; the
 * aggregate groups on field 0. The aggregate's input must carry exactly the
 * enumerable convention plus the broadcast distribution. */
@Test void testPruneProjectInputOfAggregatePreservesConventionAndDistribution() {
  final RelBuilder b = createBuilder(c -> c.withPruneInputOfAggregate(true));

  // Logical plan: a projection that repeats field 0.
  b.scan("EMP");
  b.project(b.alias(b.field(0), "a"),
      b.alias(b.field(0), "b"),
      b.alias(b.field(1), "c"));
  final RelNode logicalPlan = b.build();

  // Turn the logical plan into a physical plan so that a distribution can be set.
  final RelTraitSet enumerableTraits =
      b.getCluster().traitSet().replace(EnumerableConvention.INSTANCE);
  final RuleSet conversionRules =
      RuleSets.ofList(EnumerableRules.ENUMERABLE_PROJECT_RULE,
          EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
  final RelNode physicalPlan =
      Programs.of(conversionRules)
          .run(logicalPlan.getCluster().getPlanner(), logicalPlan, enumerableTraits,
              ImmutableList.of(), ImmutableList.of());

  // Setting the distribution drops the collations.
  final RelTraitSet broadcastTraits =
      enumerableTraits.plus(RelDistributions.BROADCAST_DISTRIBUTED);
  final RelNode broadcastPlan =
      physicalPlan.copy(broadcastTraits, physicalPlan.getInputs());

  // Aggregate on field 0 of the distributed plan.
  b.push(broadcastPlan);
  b.aggregate(b.groupKey(0),
      b.aggregateCall(SqlStdOperatorTable.SUM, b.field(0)));
  final RelNode aggregate = b.build();

  final RelTraitSet expectedTraitSet = b.getCluster().traitSet()
      .replace(EnumerableConvention.INSTANCE)
      .plus(RelDistributions.BROADCAST_DISTRIBUTED);

  final RelTraitSet actualTraitSet = aggregate.getInput(0).getTraitSet();
  assertThat(actualTraitSet, is(expectedTraitSet));
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6340">[CALCITE-6340]
 * RelBuilder drops traits when aggregating over duplicate projected fields</a>.
 *
 * <p>Builds the whole plan with the builder alone: a scan of DEPT, adoption
 * of the enumerable convention, a projection that repeats field 0, and an
 * aggregate grouping on field 0. The expected trait set of the aggregate's
 * input holds the enumerable convention and a collation on field 0.
 *
 * <p>While {@link Bug#CALCITE_6391_FIXED} is {@code false}, only the trait at
 * index 0 of the trait set is compared with the expectation. */
@Test void testPruneProjectInputOfAggregatePreservesConvention() {
  final RelBuilder b = createBuilder(c -> c.withPruneInputOfAggregate(true));

  // Each step is issued as its own statement; field references are resolved
  // against the builder state left by the previous step.
  b.scan("DEPT");
  b.adoptConvention(EnumerableConvention.INSTANCE);
  b.project(b.alias(b.field(0), "a"),
      b.alias(b.field(0), "b"));
  b.aggregate(b.groupKey(0),
      b.aggregateCall(SqlStdOperatorTable.SUM, b.field(0)));
  final RelNode node = b.build();

  final RelFieldCollation firstField = new RelFieldCollation(0);
  final RelTraitSet expectedTraitSet = b.getCluster().traitSet()
      .replace(EnumerableConvention.INSTANCE)
      .replace(RelCollations.of(firstField));

  final RelTraitSet actualTraitSet = node.getInput(0).getTraitSet();
  if (!Bug.CALCITE_6391_FIXED) {
    assertThat(actualTraitSet.get(0), is(expectedTraitSet.get(0)));
  } else {
    assertThat(actualTraitSet, is(expectedTraitSet));
  }
}

private RelNode buildRelWithDuplicateAggregates(
UnaryOperator<RelBuilder.Config> transform,
int... groupFieldOrdinals) {
Expand Down

0 comments on commit 0b2fd0b

Please sign in to comment.