[CALCITE-4302] Improve cost propagation in volcano to avoid re-propagation (Botong Huang)

CALCITE-3330 changed the cost propagation in volcano from DFS to BFS.
However, there is still room for improvement: a subset can be updated
more than once in a single cost propagation pass. For instance, given
the edges A -> D and A -> B -> C -> D, when subset A has an update, BFS
can update subset D (and thus all subsets above/after D) twice, first
via A -> D and then via C -> D. We can improve on BFS by always popping
the relNode with the smallest cost from the queue, as in Dijkstra's
algorithm, so that whenever a relNode is popped from the queue, its
current best cost cannot decrease any further. As a result, every
subset is propagated at most once.
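
To make the idea concrete, here is a minimal, self-contained sketch of
Dijkstra-style cost propagation. It is an illustration only, not the
Calcite code: the string node names, the double cost model, and the
"cost + 1" parent-cost rule are all made up for the example.

import java.util.*;

public class CostPropagationSketch {
  /** Propagates an improved cost for {@code improved} to its (transitive)
   * parents, always finalizing the cheapest pending node first. */
  static void propagate(Map<String, List<String>> parents,
      Map<String, Double> bestCost, String improved, double newCost) {
    Map<String, Double> pending = new HashMap<>();
    PriorityQueue<String> heap =
        new PriorityQueue<>(Comparator.comparingDouble(pending::get));
    pending.put(improved, newCost);
    heap.offer(improved);

    while (!heap.isEmpty()) {
      String node = heap.poll(); // cheapest pending node: final, as in Dijkstra
      double cost = pending.get(node);
      if (cost >= bestCost.getOrDefault(node, Double.MAX_VALUE)) {
        continue; // not an improvement, nothing to propagate
      }
      bestCost.put(node, cost);
      for (String parent : parents.getOrDefault(node, List.of())) {
        double parentCost = cost + 1; // stand-in for recomputing the parent's cost
        Double existing = pending.get(parent);
        if (existing == null || parentCost < existing) {
          pending.put(parent, parentCost);
          if (existing != null) {
            heap.remove(parent); // cost decreased: re-offer to fix heap order
          }
          heap.offer(parent);
        }
      }
    }
  }

  public static void main(String[] args) {
    // A -> D and A -> B -> C -> D: BFS could update D twice; here D
    // (pending cost 2.0) pops before C (3.0), so D is finalized exactly once.
    Map<String, List<String>> parents = Map.of(
        "A", List.of("B", "D"), "B", List.of("C"), "C", List.of("D"));
    Map<String, Double> best = new HashMap<>();
    propagate(parents, best, "A", 1.0);
    System.out.println(best); // best costs: A=1.0, B=2.0, C=3.0, D=2.0
  }
}

Popping in cost order is what guarantees each node is finalized, and
hence each subset propagated, at most once.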

close apache#2187
hbtoo authored and XuQianJin-Stars committed Jul 14, 2021
1 parent 71d69da commit d56abce
Showing 5 changed files with 110 additions and 126 deletions.
27 changes: 6 additions & 21 deletions core/src/main/java/org/apache/calcite/plan/volcano/RelSet.java
@@ -27,7 +27,6 @@
import org.apache.calcite.rel.convert.Converter;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Spool;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.trace.CalciteTrace;

@@ -37,9 +36,7 @@

import java.util.ArrayList;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

@@ -353,13 +350,12 @@ void mergeWith(
LOGGER.trace("Merge set#{} into set#{}", otherSet.id, id);
otherSet.equivalentSet = this;
RelOptCluster cluster = rel.getCluster();
RelMetadataQuery mq = cluster.getMetadataQuery();

// remove from table
boolean existed = planner.allSets.remove(otherSet);
assert existed : "merging with a dead otherSet";

Map<RelSubset, RelNode> changedSubsets = new IdentityHashMap<>();
Set<RelNode> changedRels = new HashSet<>();

// merge subsets
for (RelSubset otherSubset : otherSet.subsets) {
@@ -386,7 +382,7 @@ void mergeWith(

// collect RelSubset instances, whose best should be changed
if (otherSubset.bestCost.isLt(subset.bestCost)) {
changedSubsets.put(subset, otherSubset.best);
changedRels.add(otherSubset.best);
}
}

@@ -410,17 +406,10 @@ void mergeWith(
// Has another set merged with this?
assert equivalentSet == null;

// calls propagateCostImprovements() for RelSubset instances,
// whose best should be changed to check whether that
// subset's parents get cheaper.
Set<RelSubset> activeSet = new HashSet<>();
for (Map.Entry<RelSubset, RelNode> subsetBestPair : changedSubsets.entrySet()) {
RelSubset relSubset = subsetBestPair.getKey();
relSubset.propagateCostImprovements(
planner, mq, subsetBestPair.getValue(),
activeSet);
// propagate the new best information from changed relNodes.
for (RelNode rel : changedRels) {
planner.propagateCostImprovements(rel);
}
assert activeSet.isEmpty();

// Update all rels which have a child in the other set, to reflect the
// fact that the child has been renamed.
@@ -441,12 +430,8 @@ void mergeWith(

// Make sure the cost changes as a result of merging are propagated.
for (RelNode parentRel : getParentRels()) {
final RelSubset parentSubset = planner.getSubset(parentRel);
parentSubset.propagateCostImprovements(
planner, mq, parentRel,
activeSet);
planner.propagateCostImprovements(parentRel);
}
assert activeSet.isEmpty();
assert equivalentSet == null;

// Each of the relations in the old set now has new parents, so
core/src/main/java/org/apache/calcite/plan/volcano/RelSubset.java
@@ -45,15 +45,13 @@

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -159,7 +157,7 @@ public class RelSubset extends AbstractRelNode {
* <ol>
* <li>If there are no subsuming subsets, the subset is initially empty.</li>
* <li>After creation, {@code best} and {@code bestCost} are maintained
* incrementally by {@link #propagateCostImprovements0} and
* incrementally by {@link VolcanoPlanner#propagateCostImprovements} and
* {@link RelSet#mergeWith(VolcanoPlanner, RelSet)}.</li>
* </ol>
*/
@@ -375,76 +373,6 @@ RelNode buildCheapestPlan(VolcanoPlanner planner) {
return cheapest;
}

/**
* Checks whether a relexp has made its subset cheaper, and if so,
* propagates the new cost to parent rel nodes in a breadth-first manner.
*
* @param planner Planner
* @param mq Metadata query
* @param rel Relational expression whose cost has improved
* @param activeSet Set of active subsets, for cycle detection
*/
void propagateCostImprovements(VolcanoPlanner planner, RelMetadataQuery mq,
RelNode rel, Set<RelSubset> activeSet) {
Queue<Pair<RelSubset, RelNode>> propagationQueue = new ArrayDeque<>();
for (RelSubset subset : set.subsets) {
if (rel.getTraitSet().satisfies(subset.traitSet)) {
propagationQueue.offer(Pair.of(subset, rel));
}
}

while (!propagationQueue.isEmpty()) {
Pair<RelSubset, RelNode> p = propagationQueue.poll();
p.left.propagateCostImprovements0(planner, mq, p.right, activeSet, propagationQueue);
}
}

void propagateCostImprovements0(VolcanoPlanner planner, RelMetadataQuery mq,
RelNode rel, Set<RelSubset> activeSet,
Queue<Pair<RelSubset, RelNode>> propagationQueue) {
++timestamp;

if (!activeSet.add(this)) {
// This subset is already in the chain being propagated to. This
// means that the graph is cyclic, and therefore the cost of this
// relational expression - not this subset - must be infinite.
LOGGER.trace("cyclic: {}", this);
return;
}
try {
RelOptCost cost = planner.getCost(rel, mq);

// Update subset best cost when we find a cheaper rel or the current
// best's cost is changed
if (cost.isLt(bestCost)) {
LOGGER.trace("Subset cost changed: subset [{}] cost was {} now {}",
this, bestCost, cost);

bestCost = cost;
best = rel;
upperBound = bestCost;
// since best was changed, cached metadata for this subset should be removed
mq.clearCache(this);

// Propagate cost change to parents
for (RelNode parent : getParents()) {
// removes parent cached metadata since its input was changed
mq.clearCache(parent);
final RelSubset parentSubset = planner.getSubset(parent);

// parent subset will clear its cache in propagateCostImprovements0 method itself
for (RelSubset subset : parentSubset.set.subsets) {
if (parent.getTraitSet().satisfies(subset.traitSet)) {
propagationQueue.offer(Pair.of(subset, parent));
}
}
}
}
} finally {
activeSet.remove(this);
}
}

@Override public void collectVariablesUsed(Set<CorrelationId> variableSet) {
variableSet.addAll(set.variablesUsed);
}
core/src/main/java/org/apache/calcite/plan/volcano/VolcanoPlanner.java
@@ -73,6 +73,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -884,13 +885,9 @@ void rename(RelNode rel)
final RelSubset equivSubset = getSubset(equivRel);
for (RelSubset s : subset.set.subsets) {
if (s.best == rel) {
Set<RelSubset> activeSet = new HashSet<>();
s.best = equivRel;

// Propagate cost improvement since this potentially would change the subset's best cost
s.propagateCostImprovements(
this, equivRel.getCluster().getMetadataQuery(),
equivRel, activeSet);
propagateCostImprovements(equivRel);
}
}

@@ -906,6 +903,67 @@ void rename(RelNode rel) {
}
}

/**
* Checks whether a relexp has made any subset cheaper, and if so,
* propagates the new cost to parent rel nodes.
*
* @param rel Relational expression whose cost has improved
*/
void propagateCostImprovements(RelNode rel) {
RelMetadataQuery mq = rel.getCluster().getMetadataQuery();
Map<RelNode, RelOptCost> propagateRels = new HashMap<>();
PriorityQueue<RelNode> propagateHeap = new PriorityQueue<>((o1, o2) -> {
RelOptCost c1 = propagateRels.get(o1);
RelOptCost c2 = propagateRels.get(o2);
if (c1.equals(c2)) {
return 0;
} else if (c1.isLt(c2)) {
return -1;
}
return 1;
});
propagateRels.put(rel, getCost(rel, mq));
propagateHeap.offer(rel);

while (!propagateHeap.isEmpty()) {
RelNode relNode = propagateHeap.poll();
RelOptCost cost = propagateRels.get(relNode);

for (RelSubset subset : getSet(relNode).subsets) {
if (!relNode.getTraitSet().satisfies(subset.getTraitSet())) {
continue;
}
if (!cost.isLt(subset.bestCost)) {
continue;
}
// Update subset best cost when we find a cheaper rel or the current
// best's cost is changed
subset.timestamp++;
LOGGER.trace("Subset cost changed: subset [{}] cost was {} now {}",
subset, subset.bestCost, cost);

subset.bestCost = cost;
subset.best = relNode;
// since best was changed, cached metadata for this subset should be removed
mq.clearCache(subset);

for (RelNode parent : subset.getParents()) {
mq.clearCache(parent);
RelOptCost newCost = getCost(parent, mq);
RelOptCost existingCost = propagateRels.get(parent);
if (existingCost == null || newCost.isLt(existingCost)) {
propagateRels.put(parent, newCost);
if (existingCost != null) {
// Cost reduced, force the heap to adjust its ordering
propagateHeap.remove(parent);
}
propagateHeap.offer(parent);
}
}
}
}
}
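
An editorial note on the PriorityQueue handling above (not part of the
patch): java.util.PriorityQueue has no decrease-key operation and never
re-examines elements already queued, so when a parent's cost drops the
entry must be removed and re-offered, as the method does. The toy class
below, with made-up names, shows the stale ordering that would result
otherwise.

import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class DecreaseKeyDemo {
  public static void main(String[] args) {
    Map<String, Integer> cost = new HashMap<>();
    PriorityQueue<String> heap =
        new PriorityQueue<>((a, b) -> cost.get(a) - cost.get(b));
    cost.put("x", 10);
    cost.put("y", 5);
    heap.offer("x");
    heap.offer("y");
    cost.put("x", 1);                // "x" is now cheapest, but the heap doesn't know
    System.out.println(heap.peek()); // prints "y": stale ordering
    heap.remove("x");                // remove and re-offer to restore heap order
    heap.offer("x");
    System.out.println(heap.peek()); // prints "x"
  }
}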

/**
* Registers a {@link RelNode}, which has already been registered, in a new
* {@link RelSet}.
@@ -1263,9 +1321,8 @@ private RelSubset addRelToSet(RelNode rel, RelSet set) {
// 100. We think this happens because the back-links to parents are
// not established. So, give the subset another chance to figure out
// its cost.
final RelMetadataQuery mq = rel.getCluster().getMetadataQuery();
try {
subset.propagateCostImprovements(this, mq, rel, new HashSet<>());
propagateCostImprovements(rel);
} catch (CyclicMetadataException e) {
// ignore
}
29 changes: 18 additions & 11 deletions druid/src/test/java/org/apache/calcite/test/DruidAdapter2IT.java
@@ -264,8 +264,8 @@ private CalciteAssert.AssertQuery sql(String sql) {
}

@Test void testSortLimit() {
final String explain = "PLAN=EnumerableInterpreter\n"
+ " BindableSort(sort0=[$1], sort1=[$0], dir0=[ASC], dir1=[DESC], offset=[2], fetch=[3])\n"
final String explain = "PLAN=EnumerableLimit(offset=[2], fetch=[3])\n"
+ " EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], "
+ "intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], projects=[[$39, $30]], "
+ "groups=[{0, 1}], aggs=[[]], sort0=[1], sort1=[0], dir0=[ASC], dir1=[DESC])";
@@ -914,8 +914,8 @@ private void checkGroupBySingleSortLimit(boolean approx) {
+ " \"timestamp\" < '1997-09-01 00:00:00'\n"
+ "group by \"state_province\", floor(\"timestamp\" to DAY)\n"
+ "order by s desc limit 6";
final String explain = "PLAN=EnumerableInterpreter\n"
+ " BindableProject(S=[$2], M=[$3], P=[$0])\n"
final String explain = "PLAN=EnumerableCalc(expr#0..3=[{inputs}], S=[$t2], M=[$t3], P=[$t0])\n"
+ " EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], "
+ "intervals=[[1997-01-01T00:00:00.000Z/1997-09-01T00:00:00.000Z]], projects=[[$30, FLOOR"
+ "($0, FLAG(DAY)), $89]], groups=[{0, 1}], aggs=[[SUM($2), MAX($2)]], sort0=[2], "
@@ -955,7 +955,9 @@ private void checkGroupBySingleSortLimit(boolean approx) {
+ "from \"foodmart\"\n"
+ "group by \"state_province\", \"city\"\n"
+ "order by c desc limit 2";
final String explain = "BindableProject(C=[$2], state_province=[$0], city=[$1])\n"
final String explain = "PLAN=EnumerableCalc(expr#0..2=[{inputs}], C=[$t2], "
+ "state_province=[$t0], city=[$t1])\n"
+ " EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000Z/2992-01-10T00:00:00.000Z]], projects=[[$30, $29]], groups=[{0, 1}], aggs=[[COUNT()]], sort0=[2], dir0=[DESC], fetch=[2])";
sql(sql)
.returnsOrdered("C=7394; state_province=WA; city=Spokane",
@@ -3390,7 +3392,8 @@ private void testCountWithApproxDistinct(boolean approx, String sql, String expe
+ "Group by \"timestamp\" order by s LIMIT 2";
sql(sql)
.returnsOrdered("S=-15918.020000000002\nS=-14115.959999999988")
.explainContains("BindableProject(S=[$1])\n"
.explainContains("PLAN=EnumerableCalc(expr#0..1=[{inputs}], S=[$t1])\n"
+ " EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000Z/"
+ "2992-01-10T00:00:00.000Z]], projects=[[$0, *(-($90), 2)]], groups=[{0}], "
+ "aggs=[[SUM($1)]], sort0=[1], dir0=[ASC], fetch=[2])")
@@ -3533,8 +3536,9 @@ private void testCountWithApproxDistinct(boolean approx, String sql, String expe
CalciteAssert.AssertQuery q = sql(sql)
.queryContains(
new DruidChecker("\"queryType\":\"groupBy\"", extract_year, extract_expression))
.explainContains("PLAN=EnumerableInterpreter\n"
+ " BindableProject(QR_TIMESTAMP_OK=[$0], SUM_STORE_SALES=[$2], YR_TIMESTAMP_OK=[$1])\n"
.explainContains("PLAN=EnumerableCalc(expr#0..2=[{inputs}], QR_TIMESTAMP_OK=[$t0], "
+ "SUM_STORE_SALES=[$t2], YR_TIMESTAMP_OK=[$t1])\n"
+ " EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000Z/"
+ "2992-01-10T00:00:00.000Z]], projects=[[+(/(-(EXTRACT(FLAG(MONTH), $0), 1), 3), 1), "
+ "EXTRACT(FLAG(YEAR), $0), $90]], groups=[{0, 1}], aggs=[[SUM($2)]], fetch=[1])");
@@ -3569,8 +3573,9 @@ private void testCountWithApproxDistinct(boolean approx, String sql, String expe
+ " CAST(SUBSTRING(CAST(\"foodmart\".\"timestamp\" AS VARCHAR) from 12 for 2 ) AS INT),"
+ " MINUTE(\"foodmart\".\"timestamp\"), EXTRACT(HOUR FROM \"timestamp\")) LIMIT 1";
CalciteAssert.AssertQuery q = sql(sql)
.explainContains("BindableProject(HR_T_TIMESTAMP_OK=[$0], MI_T_TIMESTAMP_OK=[$1], "
+ "SUM_T_OTHER_OK=[$3], HR_T_TIMESTAMP_OK2=[$2])\n"
.explainContains("PLAN=EnumerableCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}], "
+ "SUM_T_OTHER_OK=[$t3], HR_T_TIMESTAMP_OK2=[$t2])\n"
+ " EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000Z/"
+ "2992-01-10T00:00:00.000Z]], projects=[[CAST(SUBSTRING(CAST($0):VARCHAR"
+ " "
@@ -3686,7 +3691,9 @@ private void testCountWithApproxDistinct(boolean approx, String sql, String expe
+ "SUM(\"store_sales\") as S1, SUM(\"store_sales\") as S2 FROM " + FOODMART_TABLE
+ " GROUP BY \"product_id\" ORDER BY prod_id2 LIMIT 1";
CalciteAssert.AssertQuery q = sql(sql)
.explainContains("BindableProject(PROD_ID1=[$0], PROD_ID2=[$0], S1=[$1], S2=[$1])\n"
.explainContains("PLAN=EnumerableCalc(expr#0..1=[{inputs}], PROD_ID1=[$t0], "
+ "PROD_ID2=[$t0], S1=[$t1], S2=[$t1])\n"
+ " EnumerableInterpreter\n"
+ " DruidQuery(table=[[foodmart, foodmart]], intervals=[[1900-01-09T00:00:00.000Z/"
+ "2992-01-10T00:00:00.000Z]], projects=[[$1, $90]], groups=[{0}], aggs=[[SUM($1)]], "
+ "sort0=[0], dir0=[ASC], fetch=[1])")
