Custom Calcite Rule to remove redundant references (#16402)
Custom Calcite rule mimicking AggregateProjectMergeRule, extended to support expressions;
the stock Calcite rule returns null in such cases.
In addition, this rule removes redundant references.
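
For illustration, a hedged sketch (not part of the patch) of the plan shape the rule targets: an Aggregate whose Project input computes the same expression twice. The table "t", the column "dim1", and the pre-configured RelBuilder are hypothetical.

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;

public class RedundantProjectSketch
{
  // Builds Aggregate(Project(scan)) where the Project emits UPPER(dim1) twice.
  // Stock AggregateProjectMergeRule bails out because the projects are expressions,
  // while the new rule rewrites the duplicates into a single reference.
  static RelNode buildAggregateOverRedundantProject(RelBuilder relBuilder)
  {
    return relBuilder
        .scan("t")
        .project(
            relBuilder.call(SqlStdOperatorTable.UPPER, relBuilder.field("dim1")),
            relBuilder.call(SqlStdOperatorTable.UPPER, relBuilder.field("dim1")))
        .aggregate(relBuilder.groupKey(0, 1), relBuilder.count(false, "cnt"))
        .build();
  }
}
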
sreemanamala committed May 14, 2024
1 parent 760e449 commit b8dd747
Showing 6 changed files with 264 additions and 21 deletions.
@@ -102,6 +102,20 @@ public GroupingAggregatorFactory(
)
{
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions");
// (Long.SIZE - 1) is just a sanity check. In practice, it will be just a few dimensions. This limit
// also makes sure that values are always positive.
Preconditions.checkArgument(
groupings.size() < Long.SIZE,
"Number of dimensions %s is more than supported %s",
groupings.size(),
Long.SIZE - 1
);
Preconditions.checkArgument(
groupings.stream().distinct().count() == groupings.size(),
"Encountered same dimension more than once in groupings"
);

this.name = name;
this.groupings = groupings;
this.keyDimensions = keyDimensions;
@@ -254,15 +268,6 @@ public byte[] getCacheKey()
*/
private long groupingId(List<String> groupings, @Nullable Set<String> keyDimensions)
{
Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions");
// (Long.SIZE - 1) is just a sanity check. In practice, it will be just few dimensions. This limit
// also makes sure that values are always positive.
Preconditions.checkArgument(
groupings.size() < Long.SIZE,
"Number of dimensions %s is more than supported %s",
groupings.size(),
Long.SIZE - 1
);
long temp = 0L;
for (String groupingDimension : groupings) {
temp = temp << 1;
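
For reference, a minimal standalone sketch (not from this patch) of the bit-mask idea behind groupingId(): one bit per grouping dimension, following standard SQL GROUPING() semantics in which a bit is 1 when the dimension is absent from the current grouping set. The loop body above is truncated, so the key-membership condition here is an assumption.

import java.util.List;
import java.util.Set;

public class GroupingIdSketch
{
  // Shifts in one bit per dimension, most recently visited dimension in the lowest bit;
  // with fewer than Long.SIZE dimensions the result stays positive.
  static long groupingId(List<String> allGroupings, Set<String> keyDimensions)
  {
    long id = 0L;
    for (String dimension : allGroupings) {
      id = id << 1;
      if (!keyDimensions.contains(dimension)) {
        id = id | 1L; // assumed: 1 means the dimension is not part of this grouping set
      }
    }
    return id;
  }

  public static void main(String[] args)
  {
    // GROUP BY GROUPING SETS ((a, b), (a)) over dimensions (a, b):
    System.out.println(groupingId(List.of("a", "b"), Set.of("a", "b"))); // prints 0
    System.out.println(groupingId(List.of("a", "b"), Set.of("a")));      // prints 1
  }
}
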
@@ -131,6 +131,14 @@ public void testFactory_highNumberOfGroupingDimensions()
));
makeFactory(new String[Long.SIZE], null);
}

@Test
public void testWithDuplicateGroupings()
{
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Encountered same dimension more than once in groupings");
makeFactory(new String[]{"a", "a"}, null);
}
}

@RunWith(Parameterized.class)
@@ -90,7 +90,19 @@ public Aggregation toDruidAggregation(
}
}
}
AggregatorFactory factory = new GroupingAggregatorFactory(name, arguments);
AggregatorFactory factory;
try {
factory = new GroupingAggregatorFactory(name, arguments);
}
catch (Exception e) {
plannerContext.setPlanningError(
"Initialisation of Grouping Aggregator Factory in case of [%s] threw [%s]",
aggregateCall,
e.getMessage()
);
return null;
}

return Aggregation.create(factory);
}

@@ -66,6 +66,7 @@
import org.apache.druid.sql.calcite.rule.ReverseLookupRule;
import org.apache.druid.sql.calcite.rule.RewriteFirstValueLastValueRule;
import org.apache.druid.sql.calcite.rule.SortCollapseRule;
import org.apache.druid.sql.calcite.rule.logical.DruidAggregateRemoveRedundancyRule;
import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules;
import org.apache.druid.sql.calcite.run.EngineFeature;

@@ -496,6 +497,7 @@ public List<RelOptRule> baseRuleSet(final PlannerContext plannerContext)
rules.add(FilterJoinExcludePushToChildRule.FILTER_ON_JOIN_EXCLUDE_PUSH_TO_CHILD);
rules.add(SortCollapseRule.instance());
rules.add(ProjectAggregatePruneUnusedCallRule.instance());
rules.add(DruidAggregateRemoveRedundancyRule.instance());

return rules.build();
}
@@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.sql.calcite.rule.logical;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Aggregate.Group;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mappings;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
* Planner rule that recognizes an {@link Aggregate}
* on top of a {@link Project} and, if possible,
* aggregates through the project or removes the project.
* <p>
* This is an updated version of {@link org.apache.calcite.rel.rules.AggregateProjectMergeRule}
* that is able to handle expressions.
*/
@Value.Enclosing
public class DruidAggregateRemoveRedundancyRule
extends RelOptRule
implements TransformationRule
{

/**
* Singleton instance of DruidAggregateRemoveRedundancyRule.
*/
private static final DruidAggregateRemoveRedundancyRule INSTANCE = new DruidAggregateRemoveRedundancyRule();

private DruidAggregateRemoveRedundancyRule()
{
super(operand(Aggregate.class, operand(Project.class, any())));
}

public static DruidAggregateRemoveRedundancyRule instance()
{
return INSTANCE;
}

@Override
public void onMatch(RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);
RelNode x = apply(call, aggregate, project);
if (x != null) {
call.transformTo(x);
call.getPlanner().prune(aggregate);
}
}

public static @Nullable RelNode apply(RelOptRuleCall call, Aggregate aggregate, Project project)
{
final Set<Integer> interestingFields = RelOptUtil.getAllFields(aggregate);
if (interestingFields.isEmpty()) {
return null;
}
final Map<Integer, Integer> map = new HashMap<>();
final Map<RexNode, Integer> assignedRefForExpr = new HashMap<>();
List<RexNode> newRexNodes = new ArrayList<>();
for (int source : interestingFields) {
final RexNode rex = project.getProjects().get(source);
if (!assignedRefForExpr.containsKey(rex)) {
RexNode newNode = new RexInputRef(source, rex.getType());
assignedRefForExpr.put(rex, newRexNodes.size());
newRexNodes.add(newNode);
}
map.put(source, assignedRefForExpr.get(rex));
}

if (newRexNodes.size() == project.getProjects().size()) {
return null;
}

final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
ImmutableList<ImmutableBitSet> newGroupingSets = null;
if (aggregate.getGroupType() != Group.SIMPLE) {
newGroupingSets =
ImmutableBitSet.ORDERING.immutableSortedCopy(
Sets.newTreeSet(ImmutableBitSet.permute(aggregate.getGroupSets(), map)));
}

final ImmutableList.Builder<AggregateCall> aggCalls = ImmutableList.builder();
final int sourceCount = aggregate.getInput().getRowType().getFieldCount();
final int targetCount = newRexNodes.size();
final Mappings.TargetMapping targetMapping = Mappings.target(map, sourceCount, targetCount);
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
aggCalls.add(aggregateCall.transform(targetMapping));
}

final RelBuilder relBuilder = call.builder();
relBuilder.push(project);
relBuilder.project(newRexNodes);

final Aggregate newAggregate =
aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
newGroupSet, newGroupingSets, aggCalls.build()
);
relBuilder.push(newAggregate);

final List<Integer> newKeys =
Util.transform(
aggregate.getGroupSet().asList(),
key -> Objects.requireNonNull(
map.get(key),
() -> "no value found for key " + key + " in " + map
)
);

// Add a project if the group set is not in the same order or
// contains duplicates.
if (!newKeys.equals(newGroupSet.asList())) {
final List<Integer> posList = new ArrayList<>();
for (int newKey : newKeys) {
posList.add(newGroupSet.indexOf(newKey));
}
for (int i = newAggregate.getGroupCount();
i < newAggregate.getRowType().getFieldCount(); i++) {
posList.add(i);
}
relBuilder.project(relBuilder.fields(posList));
}

return relBuilder.build();
}
}
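
For context, a minimal sketch of exercising the new rule in isolation with Calcite's HepPlanner. This is illustrative only; in Druid the rule is registered through the planner's baseRuleSet shown above. The input RelNode is assumed to be an Aggregate on top of a Project built elsewhere.

import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.druid.sql.calcite.rule.logical.DruidAggregateRemoveRedundancyRule;

public class RemoveRedundancySketch
{
  // Applies DruidAggregateRemoveRedundancyRule wherever an Aggregate sits on a Project,
  // collapsing duplicate project expressions into a single input reference.
  static RelNode removeRedundantReferences(RelNode aggregateOverProject)
  {
    HepProgram program = new HepProgramBuilder()
        .addRuleInstance(DruidAggregateRemoveRedundancyRule.instance())
        .build();
    HepPlanner planner = new HepPlanner(program);
    planner.setRoot(aggregateOverProject);
    return planner.findBestExp();
  }
}
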
@@ -8788,8 +8788,8 @@ public void testQueryWithSelectProjectAndIdentityProjectDoesNotRename()
)
.setDimensions(
dimensions(
new DefaultDimensionSpec("dim1", "d0", ColumnType.STRING),
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
new DefaultDimensionSpec("dim1", "d1", ColumnType.STRING)
)
)
.setAggregatorSpecs(
@@ -8832,9 +8832,9 @@ public void testQueryWithSelectProjectAndIdentityProjectDoesNotRename()
new FilteredAggregatorFactory(
new CountAggregatorFactory("_a1"),
and(
notNull("d0"),
notNull("d1"),
equality("a1", 0L, ColumnType.LONG),
expressionFilter("\"d1\"")
expressionFilter("\"d0\"")
)
)
)
@@ -12938,8 +12938,7 @@ public void testRepeatedIdenticalVirtualExpressionGrouping()
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
.setDimensions(
dimensions(
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG),
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
@@ -15680,10 +15679,63 @@ public void testStringOperationsNullableInference()
.build()
)
).expectedResults(
ResultMatchMode.RELAX_NULLS,
ImmutableList.of(
new Object[]{null, null, null}
)
);
NullHandling.sqlCompatible() ? ImmutableList.of(
new Object[]{null, null, null}
) : ImmutableList.of(
new Object[]{false, false, ""}
)
).run();
}

@SqlTestFrameworkConfig.NumMergeBuffers(4)
@Test
public void testGroupingSetsWithAggregateCase()
{
cannotVectorize();
msqIncompatible();
final Map<String, Object> queryContext = ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, false,
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT, true
);
testBuilder()
.sql(
"SELECT\n"
+ " TIME_FLOOR(\"__time\", 'PT1H') ,\n"
+ " COUNT(DISTINCT \"page\") ,\n"
+ " COUNT(DISTINCT CASE WHEN \"channel\" = '#it.wikipedia' THEN \"user\" END), \n"
+ " COUNT(DISTINCT \"user\") FILTER (WHERE \"channel\" = '#it.wikipedia'), "
+ " COUNT(DISTINCT \"user\") \n"
+ "FROM \"wikipedia\"\n"
+ "GROUP BY 1"
)
.queryContext(queryContext)
.expectedResults(
ImmutableList.of(
new Object[]{1442016000000L, 264L, 5L, 5L, 149L},
new Object[]{1442019600000L, 1090L, 14L, 14L, 506L},
new Object[]{1442023200000L, 1045L, 10L, 10L, 459L},
new Object[]{1442026800000L, 766L, 10L, 10L, 427L},
new Object[]{1442030400000L, 781L, 6L, 6L, 427L},
new Object[]{1442034000000L, 1223L, 10L, 10L, 448L},
new Object[]{1442037600000L, 2092L, 13L, 13L, 498L},
new Object[]{1442041200000L, 2181L, 21L, 21L, 574L},
new Object[]{1442044800000L, 1552L, 36L, 36L, 707L},
new Object[]{1442048400000L, 1624L, 44L, 44L, 770L},
new Object[]{1442052000000L, 1710L, 37L, 37L, 785L},
new Object[]{1442055600000L, 1532L, 40L, 40L, 799L},
new Object[]{1442059200000L, 1633L, 45L, 45L, 855L},
new Object[]{1442062800000L, 1958L, 44L, 44L, 905L},
new Object[]{1442066400000L, 1779L, 48L, 48L, 886L},
new Object[]{1442070000000L, 1868L, 37L, 37L, 949L},
new Object[]{1442073600000L, 1846L, 50L, 50L, 969L},
new Object[]{1442077200000L, 2168L, 38L, 38L, 941L},
new Object[]{1442080800000L, 2043L, 40L, 40L, 925L},
new Object[]{1442084400000L, 1924L, 32L, 32L, 930L},
new Object[]{1442088000000L, 1736L, 31L, 31L, 882L},
new Object[]{1442091600000L, 1672L, 40L, 40L, 861L},
new Object[]{1442095200000L, 1504L, 28L, 28L, 716L},
new Object[]{1442098800000L, 1407L, 20L, 20L, 631L}
)
).run();
}
}
