Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,28 @@
import java.util.Set;

/**
* Count(*)
* Count(col)
* TODO: distinct | just push one level
* Support Pushdown Count(*)/Count(col).
* Count(col) -> Sum( cnt * cntStar )
* Count(*) -> Sum( leftCntStar * rightCntStar )
* <p>
* Related paper "Eager aggregation and lazy aggregation".
* <pre>
* aggregate: count(x)
* |
* join
* | \
* | *
* (x)
* ->
* aggregate: Sum( cnt * cntStar )
* |
* join
* | \
* | aggregate: count(*) as cntStar
* aggregate: count(x) as cnt
* </pre>
* Note: when Count(*) is present, the group-by list must not be empty; otherwise the rule does not apply.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update this comments as we discussed

*/
public class PushdownCountThroughJoin implements RewriteRuleFactory {
@Override
Expand All @@ -57,7 +77,8 @@ public List<Rule> buildRules() {
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Count && f.child(0) instanceof Slot);
.allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child(), ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN),
Expand All @@ -69,7 +90,8 @@ public List<Rule> buildRules() {
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Count && f.child(0) instanceof Slot);
.allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child().child(), agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN)
Expand All @@ -83,23 +105,23 @@ private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> agg,

List<Count> leftCounts = new ArrayList<>();
List<Count> rightCounts = new ArrayList<>();
List<Count> countStars = new ArrayList<>();
for (AggregateFunction f : agg.getAggregateFunctions()) {
Count count = (Count) f;
if (count.isCountStar()) {
// TODO: handle Count(*)
return null;
}
Slot slot = (Slot) count.child(0);
if (leftOutput.contains(slot)) {
leftCounts.add(count);
} else if (rightOutput.contains(slot)) {
rightCounts.add(count);
countStars.add(count);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
Slot slot = (Slot) count.child(0);
if (leftOutput.contains(slot)) {
leftCounts.add(count);
} else if (rightOutput.contains(slot)) {
rightCounts.add(count);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
}

// TODO: empty GroupBy
Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
for (Expression e : agg.getGroupByExpressions()) {
Expand All @@ -112,6 +134,11 @@ private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> agg,
return null;
}
}

if (!countStars.isEmpty() && leftGroupBy.isEmpty() && rightGroupBy.isEmpty()) {
return null;
}

join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
if (leftOutput.contains(slot)) {
leftGroupBy.add(slot);
Expand All @@ -133,7 +160,7 @@ private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> agg,
leftCntSlotToOutput.put((Slot) func.child(0), alias);
leftCntAggOutputBuilder.add(alias);
});
if (!rightCounts.isEmpty()) {
if (!rightCounts.isEmpty() || !countStars.isEmpty()) {
leftCnt = new Count().alias("leftCntStar");
leftCntAggOutputBuilder.add(leftCnt);
}
Expand All @@ -150,7 +177,7 @@ private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> agg,
rightCntAggOutputBuilder.add(alias);
});

if (!leftCounts.isEmpty()) {
if (!leftCounts.isEmpty() || !countStars.isEmpty()) {
rightCnt = new Count().alias("rightCntStar");
rightCntAggOutputBuilder.add(rightCnt);
}
Expand All @@ -160,22 +187,31 @@ private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> agg,
Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg);

// top Sum agg
// count(slot) -> sum( count(slot) * cnt )
// count(slot) -> sum( count(slot) * cntStar )
// count(*) -> sum( leftCntStar * rightCntStar )
List<NamedExpression> newOutputExprs = new ArrayList<>();
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) {
Count oldTopCnt = (Count) ((Alias) ne).child();
Slot slot = (Slot) oldTopCnt.child(0);
if (leftCntSlotToOutput.containsKey(slot)) {
Preconditions.checkState(rightCnt != null);
Expression expr = new Sum(new Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
newOutputExprs.add((NamedExpression) ne.withChildren(expr));
} else if (rightCntSlotToOutput.containsKey(slot)) {
Preconditions.checkState(leftCnt != null);
Expression expr = new Sum(new Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
if (oldTopCnt.isCountStar()) {
Preconditions.checkState(rightCnt != null && leftCnt != null);
Expression expr = new Sum(new Multiply(leftCnt.toSlot(), rightCnt.toSlot()));
newOutputExprs.add((NamedExpression) ne.withChildren(expr));
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
Slot slot = (Slot) oldTopCnt.child(0);
if (leftCntSlotToOutput.containsKey(slot)) {
Preconditions.checkState(rightCnt != null);
Expression expr = new Sum(
new Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
newOutputExprs.add((NamedExpression) ne.withChildren(expr));
} else if (rightCntSlotToOutput.containsKey(slot)) {
Preconditions.checkState(leftCnt != null);
Expression expr = new Sum(
new Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
newOutputExprs.add((NamedExpression) ne.withChildren(expr));
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
} else {
newOutputExprs.add(ne);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.Set;

/**
* TODO: distinct
* Related paper "Eager aggregation and lazy aggregation".
* <pre>
* aggregate: Min/Max(x)
Expand Down Expand Up @@ -69,7 +70,8 @@ public List<Rule> buildRules() {
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot);
.allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
Expand All @@ -80,7 +82,9 @@ public List<Rule> buildRules() {
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f instanceof Max) && f.child(0) instanceof Slot);
.allMatch(
f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.Set;

/**
* TODO: distinct
* Related paper "Eager aggregation and lazy aggregation".
* <pre>
* aggregate: Sum(x)
Expand Down Expand Up @@ -69,7 +70,7 @@ public List<Rule> buildRules() {
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot);
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child(), ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
Expand All @@ -80,7 +81,7 @@ public List<Rule> buildRules() {
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Sum && f.child(0) instanceof Slot);
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,46 @@ void testMultiCount() {
.printlnTree();
}

@Test
void testSingleCountStar() {
    // Aggregate with a single count(*) and a non-empty group by (index 0),
    // placed on top of an inner join of scan1 and scan2.
    Alias countStarAlias = new Count().alias("countStar");
    LogicalPlan joinThenAgg = new LogicalPlanBuilder(scan1)
            .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
            .aggGroupUsingIndex(ImmutableList.of(0),
                    ImmutableList.of(scan1.getOutput().get(0), countStarAlias))
            .build();

    // Run the rule top-down and print the resulting tree; nothing is asserted here.
    PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext(), joinThenAgg);
    checker.applyTopDown(new PushdownCountThroughJoin()).printlnTree();
}

@Test
void testSingleCountStarEmptyGroupBy() {
    // Aggregate with a single count(*) and an empty group-by list,
    // placed on top of an inner join of scan1 and scan2.
    Alias countStarAlias = new Count().alias("countStar");
    LogicalPlan joinThenAgg = new LogicalPlanBuilder(scan1)
            .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
            .aggGroupUsingIndex(ImmutableList.of(), ImmutableList.of(countStarAlias))
            .build();

    // The rule shouldn't rewrite this plan. The tree is only printed, not asserted.
    PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext(), joinThenAgg);
    checker.applyTopDown(new PushdownCountThroughJoin()).printlnTree();
}

@Test
void testBothSideCountAndCountStar() {
    // Three counts in one aggregate: one over a scan1 slot, one over a scan2 slot,
    // and a count(*).
    Alias countOnLeft = new Count(scan1.getOutput().get(0)).alias("leftCnt");
    Alias countOnRight = new Count(scan2.getOutput().get(0)).alias("rightCnt");
    Alias countStarAlias = new Count().alias("countStar");

    // Group by index 0 on top of an inner join of scan1 and scan2.
    LogicalPlan joinThenAgg = new LogicalPlanBuilder(scan1)
            .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
            .aggGroupUsingIndex(ImmutableList.of(0),
                    ImmutableList.of(scan1.getOutput().get(0), countOnLeft, countOnRight, countStarAlias))
            .build();

    // Run the rule top-down and print the resulting tree; nothing is asserted here.
    PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext(), joinThenAgg);
    checker.applyTopDown(new PushdownCountThroughJoin()).printlnTree();
}
}