Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public Edge(LogicalJoin join, int index, BitSet leftChildEdges, BitSet rightChil
this.subTreeNodes = subTreeNodes;
}

public LogicalJoin getJoin() {
public LogicalJoin<? extends Plan, ? extends Plan> getJoin() {
return join;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.stats.JoinEstimation;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
Expand All @@ -41,6 +42,7 @@
import java.util.List;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -73,6 +75,8 @@ public class GraphSimplifier {
private final Stack<SimplificationStep> appliedSteps = new Stack<>();
private final Stack<SimplificationStep> unAppliedSteps = new Stack<>();

private final Set<Edge> validEdges;

/**
* Create a graph simplifier
*
Expand All @@ -89,6 +93,23 @@ public GraphSimplifier(HyperGraph graph) {
cacheStats.put(node.getNodeMap(), node.getGroup().getStatistics());
cacheCost.put(node.getNodeMap(), Cost.withRowCount(node.getRowCount()));
}
validEdges = graph.getEdges().stream()
.filter(e -> {
for (Slot slot : e.getJoin().getConditionSlot()) {
boolean contains = false;
for (int nodeIdx : LongBitmap.getIterator(e.getReferenceNodes())) {
if (graph.getNode(nodeIdx).getOutput().contains(slot)) {
contains = true;
break;
}
}
if (!contains) {
return false;
}
}
return true;
})
.collect(Collectors.toSet());
circleDetector = new CircleDetector(edgeSize);

// init first simplification step
Expand Down Expand Up @@ -240,6 +261,13 @@ private void applyStepsWithNum(int num) {
}
}

public @Nullable Pair<Long, Long> getLastAppliedSteps() {
if (appliedSteps.isEmpty()) {
return null;
}
return Pair.of(appliedSteps.peek().newLeft, appliedSteps.peek().newRight);
}

/**
* Process all neighbors and try to make simplification step
* Note that when a given ordering is less advantageous and dropped out,
Expand Down Expand Up @@ -308,8 +336,10 @@ private void updatePriorityQueue(int index) {
private Optional<SimplificationStep> makeSimplificationStep(int edgeIndex1, int edgeIndex2) {
Edge edge1 = graph.getEdge(edgeIndex1);
Edge edge2 = graph.getEdge(edgeIndex2);
if (edge1.isSub(edge2) || edge2.isSub(edge1) || circleDetector.checkCircleWithEdge(edgeIndex1, edgeIndex2)
|| circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)) {
if (edge1.isSub(edge2) || edge2.isSub(edge1)
|| circleDetector.checkCircleWithEdge(edgeIndex1, edgeIndex2)
|| circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)
|| !validEdges.contains(edge1) || !validEdges.contains(edge2)) {
return Optional.empty();
}
long left1 = edge1.getLeftExtendedNodes();
Expand Down Expand Up @@ -346,6 +376,7 @@ private Optional<SimplificationStep> makeSimplificationStep(int edgeIndex1, int
} else {
return Optional.empty();
}

// edge1 is not the neighborhood of edge2
SimplificationStep simplificationStep = orderJoin(edge1Before2, edge2Before1, edgeIndex1, edgeIndex2);
return Optional.of(simplificationStep);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
* HyperGraph Node.
Expand Down Expand Up @@ -83,4 +85,8 @@ public double getRowCount() {
public Group getGroup() {
return group;
}

public Set<Slot> getOutput() {
return group.getLogicalExpression().getPlan().getOutputSet();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,55 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph;

import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.HyperGraphBuilder;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.statistics.Statistics;

import com.google.common.collect.Sets;
import org.apache.hadoop.util.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashMap;

class GraphSimplifierTest {
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);

@Test
void testComplexProject() {
Alias alias1 = new Alias(scan1.getOutput().get(0), "p1");
LogicalPlan project1 = new LogicalPlanBuilder(scan1)
.projectExprs(Lists.newArrayList(alias1)).build();
Alias alias2 = new Alias(scan2.getOutput().get(0), "p2");
LogicalPlan project2 = new LogicalPlanBuilder(scan2)
.projectExprs(Lists.newArrayList(alias2)).build();
Alias alias3 = new Alias(scan3.getOutput().get(0), "p3");
LogicalPlan project3 = new LogicalPlanBuilder(scan3)
.projectExprs(Lists.newArrayList(alias3)).build();
LogicalPlan join = new LogicalPlanBuilder(project1)
.join(project2, JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo(alias1.toSlot(), alias2.toSlot())), new ArrayList<>())
.join(project3, JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo(alias2.toSlot(), alias3.toSlot())), new ArrayList<>())
.build();
HyperGraph hyperGraph = HyperGraphBuilder.buildHyperGraphFromPlan(join);
for (Node node : hyperGraph.getNodes()) {
node.getGroup().setStatistics(new Statistics(1, new HashMap<>()));
}
GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph);
while (graphSimplifier.applySimplificationStep()) {
}
Assertions.assertNull(graphSimplifier.getLastAppliedSteps());
}

@Test
void testStarQuery() {
// t1
Expand All @@ -48,6 +88,7 @@ void testStarQuery() {
SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(counter, hyperGraph);
subgraphEnumerator.enumerate();
for (int count : counter.getAllCount().values()) {
System.out.println(count);
Assertions.assertTrue(count < 10);
}
Assertions.assertTrue(graphSimplifier.isTotalOrder());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,17 @@ private HyperGraph buildHyperGraph(Plan plan) {
return hyperGraph;
}

public static HyperGraph buildHyperGraphFromPlan(Plan plan) {
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(MemoTestUtils.createConnectContext(),
plan);
JoinOrderJob joinOrderJob = new JoinOrderJob(cascadesContext.getMemo().getRoot(),
cascadesContext.getCurrentJobContext());
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
HyperGraph hyperGraph = new HyperGraph();
joinOrderJob.buildGraph(cascadesContext.getMemo().getRoot(), hyperGraph);
return hyperGraph;
}

private void injectRowcount(Group group) {
if (!group.isValidJoinGroup()) {
LogicalOlapScan scanPlan = (LogicalOlapScan) group.getLogicalExpression().getPlan();
Expand Down