Skip to content

Commit

Permalink
[enhancement](Nereids) change aggregate and join stats calc algorithm (
Browse files Browse the repository at this point in the history
…#12447)

The original statistics derivation algorithm relies on NDV and other column statistics, but we cannot get these stats in a production environment.
This PR changes these operators' stats calculation algorithms to use a DEFAULT RATIO constant instead of column statistics.
We should revisit these algorithms once column stats are available in the production environment.
  • Loading branch information
morrySnow committed Sep 8, 2022
1 parent b4f0f39 commit d2a23a4
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,66 @@

import org.apache.doris.common.CheckedMath;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.StatsDeriveResult;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Estimate hash join stats.
* TODO: Update other props in the ColumnStats properly.
*/
public class JoinEstimation {

private static final double DEFAULT_JOIN_RATIO = 10.0;

/**
* Do estimate.
* // TODO: we have no column stats here, so just use a fixed ratio to compute the row count.
*/
public static StatsDeriveResult estimate(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) {
JoinType joinType = join.getJoinType();
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats);
statsDeriveResult.merge(rightStats);
// TODO: normalize join hashConjuncts.
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
List<Expression> normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast)
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet()))
.collect(Collectors.toList());
long rowCount = -1;
if (joinType.isSemiOrAntiJoin()) {
rowCount = getSemiJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
} else if (joinType.isInnerJoin() || joinType.isOuterJoin()) {
rowCount = getJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
} else if (joinType.isCrossJoin()) {
// List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
// List<Expression> normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast)
// .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet()))
// .collect(Collectors.toList());

long rowCount;
if (joinType == JoinType.LEFT_SEMI_JOIN || joinType == JoinType.LEFT_ANTI_JOIN) {
rowCount = Math.round(leftStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1;
} else if (joinType == JoinType.RIGHT_SEMI_JOIN || joinType == JoinType.RIGHT_ANTI_JOIN) {
rowCount = Math.round(rightStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1;
} else if (joinType == JoinType.INNER_JOIN) {
long childRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
rowCount = Math.round(childRowCount / DEFAULT_JOIN_RATIO) + 1;
} else if (joinType == JoinType.LEFT_OUTER_JOIN) {
rowCount = leftStats.getRowCount();
} else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
rowCount = rightStats.getRowCount();
} else if (joinType == JoinType.CROSS_JOIN) {
rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(),
rightStats.getRowCount());
} else {
throw new RuntimeException("joinType is not supported");
}

StatsDeriveResult statsDeriveResult = new StatsDeriveResult(rowCount, Maps.newHashMap());
if (joinType.isRemainLeftJoin()) {
statsDeriveResult.merge(leftStats);
}
if (joinType.isRemainRightJoin()) {
statsDeriveResult.merge(rightStats);
}
statsDeriveResult.setRowCount(rowCount);
return statsDeriveResult;
}
Expand All @@ -78,7 +93,7 @@ private static Expression removeCast(Expression parent) {
// TODO: If the condition of Join Plan could any expression in addition to EqualTo type,
// we should handle that properly.
private static long getSemiJoinRowCount(StatsDeriveResult leftStats, StatsDeriveResult rightStats,
List<Expression> eqConjunctList, JoinType joinType) {
List<Expression> hashConjuncts, JoinType joinType) {
long rowCount;
if (JoinType.RIGHT_SEMI_JOIN.equals(joinType) || JoinType.RIGHT_ANTI_JOIN.equals(joinType)) {
if (rightStats.getRowCount() == -1) {
Expand All @@ -94,10 +109,11 @@ private static long getSemiJoinRowCount(StatsDeriveResult leftStats, StatsDerive
Map<Slot, ColumnStats> leftSlotToColStats = leftStats.getSlotToColumnStats();
Map<Slot, ColumnStats> rightSlotToColStats = rightStats.getSlotToColumnStats();
double minSelectivity = 1.0;
for (Expression eqJoinPredicate : eqConjunctList) {
long lhsNdv = leftSlotToColStats.get(removeCast(eqJoinPredicate.child(0))).getNdv();
for (Expression hashConjunct : hashConjuncts) {
// TODO: we have no column stats here, so just use a fixed ratio to compute the row count.
long lhsNdv = leftSlotToColStats.get(removeCast(hashConjunct.child(0))).getNdv();
lhsNdv = Math.min(lhsNdv, leftStats.getRowCount());
long rhsNdv = rightSlotToColStats.get(removeCast(eqJoinPredicate.child(1))).getNdv();
long rhsNdv = rightSlotToColStats.get(removeCast(hashConjunct.child(1))).getNdv();
rhsNdv = Math.min(rhsNdv, rightStats.getRowCount());
// Skip conjuncts with unknown NDV on either side.
if (lhsNdv == -1 || rhsNdv == -1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand Down Expand Up @@ -83,6 +82,8 @@
*/
public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void> {

private static final int DEFAULT_AGGREGATE_RATIO = 1000;

private final GroupExpression groupExpression;

private StatsCalculator(GroupExpression groupExpression) {
Expand Down Expand Up @@ -163,7 +164,7 @@ public StatsDeriveResult visitLogicalJoin(LogicalJoin<? extends Plan, ? extends
@Override
public StatsDeriveResult visitLogicalAssertNumRows(
LogicalAssertNumRows<? extends Plan> assertNumRows, Void context) {
return groupExpression.getCopyOfChildStats(0);
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows());
}

@Override
Expand Down Expand Up @@ -235,7 +236,13 @@ public StatsDeriveResult visitPhysicalDistribute(PhysicalDistribute<? extends Pl
@Override
public StatsDeriveResult visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
Void context) {
return groupExpression.getCopyOfChildStats(0);
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows());
}

/**
 * Derives the stats of an assert-num-rows node from its first child.
 *
 * <p>Takes a copy of the stats of child 0 and applies a row-count limit of 1 to it.
 *
 * @param desiredNumOfRows the row count the assertion expects; NOTE(review): currently
 *                         unused, the limit applied is always the constant 1
 * @return the copied child stats after the limit of 1 has been applied
 */
private StatsDeriveResult computeAssertNumRows(long desiredNumOfRows) {
    StatsDeriveResult childStatsCopy = groupExpression.getCopyOfChildStats(0);
    // The limit is the literal 1, regardless of desiredNumOfRows.
    childStatsCopy.updateRowCountByLimit(1);
    return childStatsCopy;
}

private StatsDeriveResult computeFilter(Filter filter) {
Expand Down Expand Up @@ -301,22 +308,21 @@ private StatsDeriveResult computeLimit(Limit limit) {
}

private StatsDeriveResult computeAggregate(Aggregate aggregate) {
List<Expression> groupByExpressions = aggregate.getGroupByExpressions();
// TODO: we have no column stats here, so just use a fixed ratio to compute the row count.
// List<Expression> groupByExpressions = aggregate.getGroupByExpressions();
StatsDeriveResult childStats = groupExpression.getCopyOfChildStats(0);
Map<Slot, ColumnStats> childSlotToColumnStats = childStats.getSlotToColumnStats();
long resultSetCount = 1;
for (Expression groupByExpression : groupByExpressions) {
Set<Slot> slots = groupByExpression.getInputSlots();
// TODO: Support more complex group expr.
// For example:
// select max(col1+col3) from t1 group by col1+col3;
if (slots.size() != 1) {
continue;
}
Slot slotReference = slots.iterator().next();
ColumnStats columnStats = childSlotToColumnStats.get(slotReference);
resultSetCount *= columnStats.getNdv();
// Map<Slot, ColumnStats> childSlotToColumnStats = childStats.getSlotToColumnStats();
// long resultSetCount = groupByExpressions.stream()
// .flatMap(expr -> expr.getInputSlots().stream())
// .filter(childSlotToColumnStats::containsKey)
// .map(childSlotToColumnStats::get)
// .map(ColumnStats::getNdv)
// .reduce(1L, (a, b) -> a * b);
long resultSetCount = childStats.getRowCount() / DEFAULT_AGGREGATE_RATIO;
if (resultSetCount <= 0) {
resultSetCount = 1L;
}

Map<Slot, ColumnStats> slotToColumnStats = Maps.newHashMap();
List<NamedExpression> outputExpressions = aggregate.getOutputExpressions();
// TODO: 1. Estimate the output unit size by the type of corresponding AggregateFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ public final boolean isOuterJoin() {
return this == LEFT_OUTER_JOIN || this == RIGHT_OUTER_JOIN || this == FULL_OUTER_JOIN;
}

/**
 * Tells whether this join type is anything other than a right semi or right anti join.
 *
 * @return {@code false} for {@code RIGHT_SEMI_JOIN} and {@code RIGHT_ANTI_JOIN},
 *         {@code true} for every other join type
 */
public final boolean isRemainLeftJoin() {
    boolean rightSemiOrAnti = this == RIGHT_SEMI_JOIN || this == RIGHT_ANTI_JOIN;
    return !rightSemiOrAnti;
}

/**
 * Tells whether this join type is anything other than a left semi or left anti join.
 *
 * @return {@code false} for {@code LEFT_SEMI_JOIN} and {@code LEFT_ANTI_JOIN},
 *         {@code true} for every other join type
 */
public final boolean isRemainRightJoin() {
    boolean leftSemiOrAnti = this == LEFT_SEMI_JOIN || this == LEFT_ANTI_JOIN;
    return !leftSemiOrAnti;
}

/**
 * Tells whether this join type appears as a key in {@code joinSwapMap}.
 *
 * <p>NOTE(review): presumably the map pairs a join type with its swapped counterpart;
 * confirm against the map's definition, which is not visible here.
 *
 * @return {@code true} if {@code joinSwapMap} contains this join type as a key
 */
public final boolean isSwapJoinType() {
return joinSwapMap.containsKey(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void testExecute() throws Exception {
}
StatsDeriveResult statistics = cascadesContext.getMemo().getRoot().getStatistics();
Assertions.assertNotNull(statistics);
Assertions.assertEquals(10, statistics.getRowCount());
Assertions.assertEquals(1, statistics.getRowCount());
}

private LogicalOlapScan constructOlapSCan() {
Expand Down

0 comments on commit d2a23a4

Please sign in to comment.