From d2a23a4cf984218faa2efe6139c03fb3c8714477 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Fri, 9 Sep 2022 01:00:07 +0800 Subject: [PATCH] [enhancement](Nereids) change aggregate and join stats calc algorithm (#12447) The original statistics derivation algorithm relies on NDV and other column statistics, but these statistics are not available in production environments. This PR changes the stats calculation for the aggregate and join operators to use fixed default ratio constants (DEFAULT_AGGREGATE_RATIO and DEFAULT_JOIN_RATIO) instead of column statistics. These algorithms should be revisited once column statistics become available in production environments. --- .../doris/nereids/stats/JoinEstimation.java | 54 +++--- .../doris/nereids/stats/StatsCalculator.java | 40 +++-- .../doris/nereids/trees/plans/JoinType.java | 8 + .../jobs/cascades/DeriveStatsJobTest.java | 2 +- .../nereids/stats/StatsCalculatorTest.java | 157 +++++++++--------- 5 files changed, 142 insertions(+), 119 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index 161194456e61c9..9650526e39087e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -19,21 +19,19 @@ import org.apache.doris.common.CheckedMath; import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.algebra.Join; -import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.statistics.ColumnStats; import org.apache.doris.statistics.StatsDeriveResult; import com.google.common.base.Preconditions; 
+import com.google.common.collect.Maps; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * Estimate hash join stats. @@ -41,29 +39,46 @@ */ public class JoinEstimation { + private static final double DEFAULT_JOIN_RATIO = 10.0; + /** * Do estimate. + * // TODO: since we have no column stats here. just use a fix ratio to compute the row count. */ public static StatsDeriveResult estimate(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) { JoinType joinType = join.getJoinType(); - StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats); - statsDeriveResult.merge(rightStats); // TODO: normalize join hashConjuncts. - List hashJoinConjuncts = join.getHashJoinConjuncts(); - List normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast) - .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet())) - .collect(Collectors.toList()); - long rowCount = -1; - if (joinType.isSemiOrAntiJoin()) { - rowCount = getSemiJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType); - } else if (joinType.isInnerJoin() || joinType.isOuterJoin()) { - rowCount = getJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType); - } else if (joinType.isCrossJoin()) { + // List hashJoinConjuncts = join.getHashJoinConjuncts(); + // List normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast) + // .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet())) + // .collect(Collectors.toList()); + + long rowCount; + if (joinType == JoinType.LEFT_SEMI_JOIN || joinType == JoinType.LEFT_ANTI_JOIN) { + rowCount = Math.round(leftStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1; + } else if (joinType == JoinType.RIGHT_SEMI_JOIN || joinType == JoinType.RIGHT_ANTI_JOIN) { + rowCount = Math.round(rightStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1; + } else if (joinType == JoinType.INNER_JOIN) { + long childRowCount = 
Math.max(leftStats.getRowCount(), rightStats.getRowCount()); + rowCount = Math.round(childRowCount / DEFAULT_JOIN_RATIO) + 1; + } else if (joinType == JoinType.LEFT_OUTER_JOIN) { + rowCount = leftStats.getRowCount(); + } else if (joinType == JoinType.RIGHT_OUTER_JOIN) { + rowCount = rightStats.getRowCount(); + } else if (joinType == JoinType.CROSS_JOIN) { rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(), rightStats.getRowCount()); } else { throw new RuntimeException("joinType is not supported"); } + + StatsDeriveResult statsDeriveResult = new StatsDeriveResult(rowCount, Maps.newHashMap()); + if (joinType.isRemainLeftJoin()) { + statsDeriveResult.merge(leftStats); + } + if (joinType.isRemainRightJoin()) { + statsDeriveResult.merge(rightStats); + } statsDeriveResult.setRowCount(rowCount); return statsDeriveResult; } @@ -78,7 +93,7 @@ private static Expression removeCast(Expression parent) { // TODO: If the condition of Join Plan could any expression in addition to EqualTo type, // we should handle that properly. private static long getSemiJoinRowCount(StatsDeriveResult leftStats, StatsDeriveResult rightStats, - List eqConjunctList, JoinType joinType) { + List hashConjuncts, JoinType joinType) { long rowCount; if (JoinType.RIGHT_SEMI_JOIN.equals(joinType) || JoinType.RIGHT_ANTI_JOIN.equals(joinType)) { if (rightStats.getRowCount() == -1) { @@ -94,10 +109,11 @@ private static long getSemiJoinRowCount(StatsDeriveResult leftStats, StatsDerive Map leftSlotToColStats = leftStats.getSlotToColumnStats(); Map rightSlotToColStats = rightStats.getSlotToColumnStats(); double minSelectivity = 1.0; - for (Expression eqJoinPredicate : eqConjunctList) { - long lhsNdv = leftSlotToColStats.get(removeCast(eqJoinPredicate.child(0))).getNdv(); + for (Expression hashConjunct : hashConjuncts) { + // TODO: since we have no column stats here. just use a fix ratio to compute the row count. 
+ long lhsNdv = leftSlotToColStats.get(removeCast(hashConjunct.child(0))).getNdv(); lhsNdv = Math.min(lhsNdv, leftStats.getRowCount()); - long rhsNdv = rightSlotToColStats.get(removeCast(eqJoinPredicate.child(1))).getNdv(); + long rhsNdv = rightSlotToColStats.get(removeCast(hashConjunct.child(1))).getNdv(); rhsNdv = Math.min(rhsNdv, rightStats.getRowCount()); // Skip conjuncts with unknown NDV on either side. if (lhsNdv == -1 || rhsNdv == -1) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 6f16d06f04e9d9..ba25d62ce9408f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -23,7 +23,6 @@ import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Pair; import org.apache.doris.nereids.memo.GroupExpression; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; @@ -83,6 +82,8 @@ */ public class StatsCalculator extends DefaultPlanVisitor { + private static final int DEFAULT_AGGREGATE_RATIO = 1000; + private final GroupExpression groupExpression; private StatsCalculator(GroupExpression groupExpression) { @@ -163,7 +164,7 @@ public StatsDeriveResult visitLogicalJoin(LogicalJoin assertNumRows, Void context) { - return groupExpression.getCopyOfChildStats(0); + return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows()); } @Override @@ -235,7 +236,13 @@ public StatsDeriveResult visitPhysicalDistribute(PhysicalDistribute assertNumRows, Void context) { - return groupExpression.getCopyOfChildStats(0); + return 
computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows()); + } + + private StatsDeriveResult computeAssertNumRows(long desiredNumOfRows) { + StatsDeriveResult statsDeriveResult = groupExpression.getCopyOfChildStats(0); + statsDeriveResult.updateRowCountByLimit(1); + return statsDeriveResult; } private StatsDeriveResult computeFilter(Filter filter) { @@ -301,22 +308,21 @@ private StatsDeriveResult computeLimit(Limit limit) { } private StatsDeriveResult computeAggregate(Aggregate aggregate) { - List groupByExpressions = aggregate.getGroupByExpressions(); + // TODO: since we have no column stats here. just use a fix ratio to compute the row count. + // List groupByExpressions = aggregate.getGroupByExpressions(); StatsDeriveResult childStats = groupExpression.getCopyOfChildStats(0); - Map childSlotToColumnStats = childStats.getSlotToColumnStats(); - long resultSetCount = 1; - for (Expression groupByExpression : groupByExpressions) { - Set slots = groupByExpression.getInputSlots(); - // TODO: Support more complex group expr. - // For example: - // select max(col1+col3) from t1 group by col1+col3; - if (slots.size() != 1) { - continue; - } - Slot slotReference = slots.iterator().next(); - ColumnStats columnStats = childSlotToColumnStats.get(slotReference); - resultSetCount *= columnStats.getNdv(); + // Map childSlotToColumnStats = childStats.getSlotToColumnStats(); + // long resultSetCount = groupByExpressions.stream() + // .flatMap(expr -> expr.getInputSlots().stream()) + // .filter(childSlotToColumnStats::containsKey) + // .map(childSlotToColumnStats::get) + // .map(ColumnStats::getNdv) + // .reduce(1L, (a, b) -> a * b); + long resultSetCount = childStats.getRowCount() / DEFAULT_AGGREGATE_RATIO; + if (resultSetCount <= 0) { + resultSetCount = 1L; } + Map slotToColumnStats = Maps.newHashMap(); List outputExpressions = aggregate.getOutputExpressions(); // TODO: 1. 
Estimate the output unit size by the type of corresponding AggregateFunction diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java index 9badf47586126d..9b071b8986d9fc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java @@ -112,6 +112,14 @@ public final boolean isOuterJoin() { return this == LEFT_OUTER_JOIN || this == RIGHT_OUTER_JOIN || this == FULL_OUTER_JOIN; } + public final boolean isRemainLeftJoin() { + return this != RIGHT_SEMI_JOIN && this != RIGHT_ANTI_JOIN; + } + + public final boolean isRemainRightJoin() { + return this != LEFT_SEMI_JOIN && this != LEFT_ANTI_JOIN; + } + public final boolean isSwapJoinType() { return joinSwapMap.containsKey(this); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java index 3b5172ed2c12fe..3a6f9ab4af0276 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java @@ -74,7 +74,7 @@ public void testExecute() throws Exception { } StatsDeriveResult statistics = cascadesContext.getMemo().getRoot().getStatistics(); Assertions.assertNotNull(statistics); - Assertions.assertEquals(10, statistics.getRowCount()); + Assertions.assertEquals(1, statistics.getRowCount()); } private LogicalOlapScan constructOlapSCan() { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java index 0f47477f4cc93e..15501dced2ad49 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java +++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java @@ -22,21 +22,14 @@ import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.Sum; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; @@ -51,14 +44,12 @@ import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import mockit.Expectations; import mockit.Mocked; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -74,43 +65,44 @@ public class StatsCalculatorTest { @Mocked StatisticsManager statisticsManager; - @Test - public void testAgg() { - List qualifier = new ArrayList<>(); - 
qualifier.add("test"); - qualifier.add("t"); - SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); - SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); - ColumnStats columnStats1 = new ColumnStats(); - columnStats1.setNdv(10); - columnStats1.setNumNulls(5); - ColumnStats columnStats2 = new ColumnStats(); - columnStats2.setNdv(20); - columnStats1.setNumNulls(10); - Map slotColumnStatsMap = new HashMap<>(); - slotColumnStatsMap.put(slot1, columnStats1); - slotColumnStatsMap.put(slot2, columnStats2); - List groupByExprList = new ArrayList<>(); - groupByExprList.add(slot1); - AggregateFunction sum = new Sum(slot2); - StatsDeriveResult childStats = new StatsDeriveResult(20, slotColumnStatsMap); - Alias alias = new Alias(sum, "a"); - Group childGroup = new Group(); - childGroup.setLogicalProperties(new LogicalProperties(new Supplier>() { - @Override - public List get() { - return Collections.emptyList(); - } - })); - GroupPlan groupPlan = new GroupPlan(childGroup); - childGroup.setStatistics(childStats); - LogicalAggregate logicalAggregate = new LogicalAggregate(groupByExprList, Arrays.asList(alias), groupPlan); - GroupExpression groupExpression = new GroupExpression(logicalAggregate, Arrays.asList(childGroup)); - Group ownerGroup = new Group(); - groupExpression.setOwnerGroup(ownerGroup); - StatsCalculator.estimate(groupExpression); - Assertions.assertEquals(groupExpression.getOwnerGroup().getStatistics().getRowCount(), 10); - } + // TODO: temporary disable this test, until we could get column stats + // @Test + // public void testAgg() { + // List qualifier = new ArrayList<>(); + // qualifier.add("test"); + // qualifier.add("t"); + // SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); + // SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); + // ColumnStats columnStats1 = new ColumnStats(); + // columnStats1.setNdv(10); + // 
columnStats1.setNumNulls(5); + // ColumnStats columnStats2 = new ColumnStats(); + // columnStats2.setNdv(20); + // columnStats1.setNumNulls(10); + // Map slotColumnStatsMap = new HashMap<>(); + // slotColumnStatsMap.put(slot1, columnStats1); + // slotColumnStatsMap.put(slot2, columnStats2); + // List groupByExprList = new ArrayList<>(); + // groupByExprList.add(slot1); + // AggregateFunction sum = new Sum(slot2); + // StatsDeriveResult childStats = new StatsDeriveResult(20, slotColumnStatsMap); + // Alias alias = new Alias(sum, "a"); + // Group childGroup = new Group(); + // childGroup.setLogicalProperties(new LogicalProperties(new Supplier>() { + // @Override + // public List get() { + // return Collections.emptyList(); + // } + // })); + // GroupPlan groupPlan = new GroupPlan(childGroup); + // childGroup.setStatistics(childStats); + // LogicalAggregate logicalAggregate = new LogicalAggregate(groupByExprList, Arrays.asList(alias), groupPlan); + // GroupExpression groupExpression = new GroupExpression(logicalAggregate, Arrays.asList(childGroup)); + // Group ownerGroup = new Group(); + // groupExpression.setOwnerGroup(ownerGroup); + // StatsCalculator.estimate(groupExpression); + // Assertions.assertEquals(groupExpression.getOwnerGroup().getStatistics().getRowCount(), 10); + // } @Test public void testFilter() { @@ -164,42 +156,43 @@ public List get() { ownerGroupOr.getStatistics().getRowCount(), 0.001); } - @Test - public void testHashJoin() { - List qualifier = ImmutableList.of("test", "t"); - SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); - SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); - ColumnStats columnStats1 = new ColumnStats(); - columnStats1.setNdv(10); - columnStats1.setNumNulls(5); - ColumnStats columnStats2 = new ColumnStats(); - columnStats2.setNdv(20); - columnStats1.setNumNulls(10); - Map slotColumnStatsMap1 = new HashMap<>(); - slotColumnStatsMap1.put(slot1, columnStats1); 
- - Map slotColumnStatsMap2 = new HashMap<>(); - slotColumnStatsMap2.put(slot2, columnStats2); - - final long leftRowCount = 5000; - StatsDeriveResult leftStats = new StatsDeriveResult(leftRowCount, slotColumnStatsMap1); - - final long rightRowCount = 10000; - StatsDeriveResult rightStats = new StatsDeriveResult(rightRowCount, slotColumnStatsMap2); - - EqualTo equalTo = new EqualTo(slot1, slot2); - - LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t", 0); - LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(0, "t", 0); - LogicalJoin fakeSemiJoin = new LogicalJoin<>( - JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); - LogicalJoin fakeInnerJoin = new LogicalJoin<>( - JoinType.INNER_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); - StatsDeriveResult semiJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeSemiJoin); - Assertions.assertEquals(leftRowCount, semiJoinStats.getRowCount()); - StatsDeriveResult innerJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeInnerJoin); - Assertions.assertEquals(2500000, innerJoinStats.getRowCount()); - } + // TODO: temporary disable this test, until we could get column stats + // @Test + // public void testHashJoin() { + // List qualifier = ImmutableList.of("test", "t"); + // SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); + // SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); + // ColumnStats columnStats1 = new ColumnStats(); + // columnStats1.setNdv(10); + // columnStats1.setNumNulls(5); + // ColumnStats columnStats2 = new ColumnStats(); + // columnStats2.setNdv(20); + // columnStats1.setNumNulls(10); + // Map slotColumnStatsMap1 = new HashMap<>(); + // slotColumnStatsMap1.put(slot1, columnStats1); + // + // Map slotColumnStatsMap2 = new HashMap<>(); + // slotColumnStatsMap2.put(slot2, columnStats2); + // + // final long leftRowCount = 5000; + // 
StatsDeriveResult leftStats = new StatsDeriveResult(leftRowCount, slotColumnStatsMap1); + // + // final long rightRowCount = 10000; + // StatsDeriveResult rightStats = new StatsDeriveResult(rightRowCount, slotColumnStatsMap2); + // + // EqualTo equalTo = new EqualTo(slot1, slot2); + // + // LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t", 0); + // LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(0, "t", 0); + // LogicalJoin fakeSemiJoin = new LogicalJoin<>( + // JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); + // LogicalJoin fakeInnerJoin = new LogicalJoin<>( + // JoinType.INNER_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); + // StatsDeriveResult semiJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeSemiJoin); + // Assertions.assertEquals(leftRowCount, semiJoinStats.getRowCount()); + // StatsDeriveResult innerJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeInnerJoin); + // Assertions.assertEquals(2500000, innerJoinStats.getRowCount()); + // } @Test public void testOlapScan() {