Skip to content

Commit

Permalink
DRILL-6798: Planner changes to support semi-join.
Browse files Browse the repository at this point in the history
  • Loading branch information
HanumathRao authored and Boaz Ben-Zvi committed Nov 1, 2018
1 parent 5fff1d8 commit 71809ca
Show file tree
Hide file tree
Showing 20 changed files with 553 additions and 61 deletions.
Expand Up @@ -358,15 +358,18 @@ static RuleSet getDrillBasicRules(OptimizerRulesContext optimizerRulesContext) {
* We have to create another copy of the ruleset with the context dependent elements;
* this cannot be reused across queries.
*/
final ImmutableSet<RelOptRule> basicRules = ImmutableSet.<RelOptRule>builder()
ImmutableSet.Builder<RelOptRule> basicRules = ImmutableSet.<RelOptRule>builder()
.addAll(staticRuleSet)
.add(
DrillMergeProjectRule.getInstance(true, RelFactories.DEFAULT_PROJECT_FACTORY,
optimizerRulesContext.getFunctionRegistry())
)
.build();
);
if (optimizerRulesContext.getPlannerSettings().isHashJoinEnabled() &&
optimizerRulesContext.getPlannerSettings().isSemiJoinEnabled()) {
basicRules.add(RuleInstance.SEMI_JOIN_PROJECT_RULE);
}

return RuleSets.ofList(basicRules);
return RuleSets.ofList(basicRules.build());
}

/**
Expand Down Expand Up @@ -474,7 +477,6 @@ static RuleSet getJoinPermRules(OptimizerRulesContext optimizerRulesContext) {
static RuleSet getPhysicalRules(OptimizerRulesContext optimizerRulesContext) {
final List<RelOptRule> ruleList = new ArrayList<>();
final PlannerSettings ps = optimizerRulesContext.getPlannerSettings();

ruleList.add(ConvertCountToDirectScan.AGG_ON_PROJ_ON_SCAN);
ruleList.add(ConvertCountToDirectScan.AGG_ON_SCAN);
ruleList.add(SortConvertPrule.INSTANCE);
Expand Down Expand Up @@ -509,9 +511,14 @@ static RuleSet getPhysicalRules(OptimizerRulesContext optimizerRulesContext) {

if (ps.isHashJoinEnabled()) {
ruleList.add(HashJoinPrule.DIST_INSTANCE);

if (ps.isSemiJoinEnabled()) {
ruleList.add(HashJoinPrule.SEMI_DIST_INSTANCE);
}
if(ps.isBroadcastJoinEnabled()){
ruleList.add(HashJoinPrule.BROADCAST_INSTANCE);
if (ps.isSemiJoinEnabled()) {
ruleList.add(HashJoinPrule.SEMI_BROADCAST_INSTANCE);
}
}
}

Expand All @@ -521,7 +528,6 @@ static RuleSet getPhysicalRules(OptimizerRulesContext optimizerRulesContext) {
if(ps.isBroadcastJoinEnabled()){
ruleList.add(MergeJoinPrule.BROADCAST_INSTANCE);
}

}

// NLJ plans consist of broadcasting the right child, hence we need
Expand Down
Expand Up @@ -18,8 +18,11 @@
package org.apache.drill.exec.planner;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.volcano.AbstractConverter;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.calcite.rel.logical.LogicalJoin;
Expand All @@ -39,12 +42,13 @@
import org.apache.calcite.rel.rules.ProjectToWindowRule;
import org.apache.calcite.rel.rules.ProjectWindowTransposeRule;
import org.apache.calcite.rel.rules.ReduceExpressionsRule;
import org.apache.calcite.rel.rules.SemiJoinRule;
import org.apache.calcite.rel.rules.SortRemoveRule;
import org.apache.calcite.rel.rules.SubQueryRemoveRule;
import org.apache.calcite.rel.rules.UnionToDistinctRule;
import org.apache.drill.exec.planner.logical.DrillConditions;
import org.apache.drill.exec.planner.logical.DrillRelFactories;

import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;
/**
* Contains rule instances which use custom RelBuilder.
*/
Expand All @@ -58,6 +62,15 @@ public interface RuleInstance {
new UnionToDistinctRule(LogicalUnion.class,
DrillRelFactories.LOGICAL_BUILDER);

/**
 * Instance of Calcite's {@link SemiJoinRule.ProjectToSemiJoinRule} bound to the
 * {@link Project}, {@link Join} and {@link Aggregate} operand classes and to
 * {@link DrillRelFactories#LOGICAL_BUILDER}. The rule declines to fire for joins
 * whose condition is trivially true or trivially false.
 */
SemiJoinRule SEMI_JOIN_PROJECT_RULE = new SemiJoinRule.ProjectToSemiJoinRule(Project.class, Join.class, Aggregate.class,
    DrillRelFactories.LOGICAL_BUILDER, "DrillSemiJoinRule:project") {
  /**
   * Restricts the rule to joins that carry a non-trivial condition.
   *
   * @param call rule call whose operand at ordinal 1 is expected to be the join
   * @return {@code false} when the join condition is always true or always false,
   *         {@code true} otherwise
   * @throws IllegalArgumentException if the operand at ordinal 1 is not a {@link Join}
   */
  @Override
  public boolean matches(RelOptRuleCall call) {
    // Sanity check on the operand layout declared in the constructor above.
    Preconditions.checkArgument(call.rel(1) instanceof Join);
    final Join join = call.rel(1);
    // Same short-circuit order as the original "!(alwaysTrue || alwaysFalse)".
    return !join.getCondition().isAlwaysTrue() && !join.getCondition().isAlwaysFalse();
  }
};

JoinPushExpressionsRule JOIN_PUSH_EXPRESSIONS_RULE =
new JoinPushExpressionsRule(Join.class,
DrillRelFactories.LOGICAL_BUILDER);
Expand Down
Expand Up @@ -29,6 +29,7 @@
import org.apache.drill.exec.physical.impl.join.JoinUtils;
import org.apache.drill.exec.physical.impl.join.JoinUtils.JoinCategory;
import org.apache.drill.exec.planner.cost.DrillCostBase.DrillCostFactory;
import org.apache.drill.exec.planner.logical.DrillJoin;
import org.apache.drill.exec.planner.physical.PrelUtil;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
Expand All @@ -45,7 +46,7 @@
/**
* Base class for logical and physical Joins implemented in Drill.
*/
public abstract class DrillJoinRelBase extends Join implements DrillRelNode {
public abstract class DrillJoinRelBase extends Join implements DrillJoin {
protected List<Integer> leftKeys = Lists.newArrayList();
protected List<Integer> rightKeys = Lists.newArrayList();

Expand Down
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.logical;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexNode;
import org.apache.drill.exec.planner.common.DrillRelNode;
import java.util.List;

/**
* Interface which needs to be implemented by all the join relation expressions.
*/
public interface DrillJoin extends DrillRelNode {

  /** @return the left input relational expression of the join */
  RelNode getLeft();

  /** @return the right input relational expression of the join */
  RelNode getRight();

  /** @return the join condition of the join relation */
  RexNode getCondition();

  /** @return the join type of the join operation */
  JoinRelType getJoinType();

  /** @return the columns of the left table that are part of the join condition */
  List<Integer> getLeftKeys();

  /** @return the columns of the right table that are part of the join condition */
  List<Integer> getRightKeys();
}
Expand Up @@ -104,7 +104,7 @@ public LogicalOperator implement(DrillImplementor implementor) {
* @return
*/
private LogicalOperator implementInput(DrillImplementor implementor, int i, int offset, RelNode input) {
return implementInput(implementor, i, offset, input, this);
return implementInput(implementor, i, offset, input, this, this.getRowType().getFieldNames());
}

/**
Expand All @@ -118,12 +118,12 @@ private LogicalOperator implementInput(DrillImplementor implementor, int i, int
* @return
*/
public static LogicalOperator implementInput(DrillImplementor implementor, int i, int offset,
RelNode input, DrillRel currentNode) {
RelNode input, DrillRel currentNode,
List<String> parentFields) {
final LogicalOperator inputOp = implementor.visitChild(currentNode, i, input);
assert uniqueFieldNames(input.getRowType());
final List<String> fields = currentNode.getRowType().getFieldNames();
final List<String> inputFields = input.getRowType().getFieldNames();
final List<String> outputFields = fields.subList(offset, offset + inputFields.size());
final List<String> outputFields = parentFields.subList(offset, offset + inputFields.size());
if (!outputFields.equals(inputFields)) {
// Ensure that input field names are the same as output field names.
// If there are duplicate field names on left and right, fields will get
Expand Down
Expand Up @@ -28,6 +28,7 @@
import org.apache.drill.common.logical.data.LogicalOperator;
import org.apache.drill.exec.planner.common.DrillLateralJoinRelBase;

import java.util.ArrayList;
import java.util.List;


Expand All @@ -48,12 +49,14 @@ public Correlate copy(RelTraitSet traitSet,

@Override
public LogicalOperator implement(DrillImplementor implementor) {
final List<String> fields = getRowType().getFieldNames();
List<String> fields = new ArrayList<>();
fields.addAll(getInput(0).getRowType().getFieldNames());
fields.addAll(getInput(1).getRowType().getFieldNames());
assert DrillJoinRel.isUnique(fields);
final int leftCount = getInputSize(0);

final LogicalOperator leftOp = DrillJoinRel.implementInput(implementor, 0, 0, left, this);
final LogicalOperator rightOp = DrillJoinRel.implementInput(implementor, 1, leftCount, right, this);
final LogicalOperator leftOp = DrillJoinRel.implementInput(implementor, 0, 0, left, this, fields);
final LogicalOperator rightOp = DrillJoinRel.implementInput(implementor, 1, leftCount, right, this, fields);

return new LateralJoin(leftOp, rightOp);
}
Expand Down
Expand Up @@ -22,6 +22,7 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
Expand All @@ -39,7 +40,6 @@
import static org.apache.calcite.rel.core.RelFactories.DEFAULT_JOIN_FACTORY;
import static org.apache.calcite.rel.core.RelFactories.DEFAULT_MATCH_FACTORY;
import static org.apache.calcite.rel.core.RelFactories.DEFAULT_PROJECT_FACTORY;
import static org.apache.calcite.rel.core.RelFactories.DEFAULT_SEMI_JOIN_FACTORY;
import static org.apache.calcite.rel.core.RelFactories.DEFAULT_SET_OP_FACTORY;
import static org.apache.calcite.rel.core.RelFactories.DEFAULT_SORT_FACTORY;
import static org.apache.calcite.rel.core.RelFactories.DEFAULT_TABLE_SCAN_FACTORY;
Expand All @@ -60,6 +60,17 @@ public class DrillRelFactories {
public static final RelFactories.JoinFactory DRILL_LOGICAL_JOIN_FACTORY = new DrillJoinFactoryImpl();

public static final RelFactories.AggregateFactory DRILL_LOGICAL_AGGREGATE_FACTORY = new DrillAggregateFactoryImpl();

public static final RelFactories.SemiJoinFactory DRILL_SEMI_JOIN_FACTORY = new SemiJoinFactoryImpl();

private static class SemiJoinFactoryImpl implements RelFactories.SemiJoinFactory {
public RelNode createSemiJoin(RelNode left, RelNode right,
RexNode condition) {
final JoinInfo joinInfo = JoinInfo.of(left, right, condition);
return DrillSemiJoinRel.create(left, right,
condition, joinInfo.leftKeys, joinInfo.rightKeys);
}
}
/**
* A {@link RelBuilderFactory} that creates a {@link DrillRelBuilder} that will
* create logical relational expressions for everything.
Expand All @@ -69,7 +80,7 @@ public class DrillRelFactories {
Contexts.of(DEFAULT_PROJECT_FACTORY,
DEFAULT_FILTER_FACTORY,
DEFAULT_JOIN_FACTORY,
DEFAULT_SEMI_JOIN_FACTORY,
DRILL_SEMI_JOIN_FACTORY,
DEFAULT_SORT_FACTORY,
DEFAULT_AGGREGATE_FACTORY,
DEFAULT_MATCH_FACTORY,
Expand Down
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.logical;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.SemiJoin;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.logical.data.Join;
import org.apache.drill.common.logical.data.JoinCondition;
import org.apache.drill.common.logical.data.LogicalOperator;
import org.apache.drill.common.logical.data.LogicalSemiJoin;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;

import java.util.ArrayList;
import java.util.List;

/**
 * Semi-join relational expression in Drill's logical convention.
 */
public class DrillSemiJoinRel extends SemiJoin implements DrillJoin, DrillRel {

  /**
   * Creates a semi-join node; all arguments are passed straight to the
   * {@link SemiJoin} constructor.
   *
   * @param cluster   cluster the node belongs to
   * @param traitSet  trait set of the node
   * @param left      left input
   * @param right     right input
   * @param condition join condition
   * @param leftKeys  key column ordinals on the left input
   * @param rightKeys key column ordinals on the right input
   */
  public DrillSemiJoinRel(
      RelOptCluster cluster,
      RelTraitSet traitSet,
      RelNode left,
      RelNode right,
      RexNode condition,
      ImmutableIntList leftKeys,
      ImmutableIntList rightKeys) {
    super(cluster,
        traitSet,
        left,
        right,
        condition,
        leftKeys,
        rightKeys);
  }

  /**
   * Factory method that builds a {@link DrillSemiJoinRel} in the left input's
   * cluster, using a trait set made of {@link DrillRel#DRILL_LOGICAL}.
   *
   * @param left      left input
   * @param right     right input
   * @param condition join condition
   * @param leftKeys  key column ordinals on the left input
   * @param rightKeys key column ordinals on the right input
   * @return the new semi-join node
   */
  public static SemiJoin create(RelNode left, RelNode right, RexNode condition,
      ImmutableIntList leftKeys, ImmutableIntList rightKeys) {
    final RelOptCluster cluster = left.getCluster();
    return new DrillSemiJoinRel(cluster, cluster.traitSetOf(DrillRel.DRILL_LOGICAL), left,
        right, condition, leftKeys, rightKeys);
  }

  /**
   * Copies this node with new inputs and condition. The key lists are
   * recomputed from the supplied condition rather than reused from this node.
   *
   * @param traitSet     trait set of the copy
   * @param condition    join condition of the copy; must be an equi-join condition
   * @param left         left input of the copy
   * @param right        right input of the copy
   * @param joinType     must be {@link JoinRelType#INNER}
   * @param semiJoinDone not used by this implementation
   * @return the copied semi-join node
   * @throws IllegalArgumentException if {@code joinType} is not {@code INNER} or the
   *         condition is not an equi-join condition
   */
  @Override
  public SemiJoin copy(RelTraitSet traitSet, RexNode condition,
      RelNode left, RelNode right, JoinRelType joinType, boolean semiJoinDone) {
    Preconditions.checkArgument(joinType == JoinRelType.INNER);
    final JoinInfo joinInfo = JoinInfo.of(left, right, condition);
    Preconditions.checkArgument(joinInfo.isEqui());
    return new DrillSemiJoinRel(getCluster(), traitSet, left, right, condition,
        joinInfo.leftKeys, joinInfo.rightKeys);
  }

  /**
   * Converts this node into a {@link LogicalSemiJoin} with one equality
   * {@link JoinCondition} per (left key, right key) pair.
   *
   * @param implementor implementor used to convert the two inputs
   * @return the logical semi-join operator
   * @throws IllegalArgumentException if the concatenated input field names are not unique
   */
  @Override
  public LogicalOperator implement(DrillImplementor implementor) {
    // Field names are taken from the two inputs (left first, then right), not from
    // this node's own row type.
    List<String> fields = new ArrayList<>();
    fields.addAll(getInput(0).getRowType().getFieldNames());
    fields.addAll(getInput(1).getRowType().getFieldNames());
    Preconditions.checkArgument(DrillJoinRel.isUnique(fields));
    final int leftCount = left.getRowType().getFieldCount();
    final List<String> leftFields = fields.subList(0, leftCount);
    final List<String> rightFields = fields.subList(leftCount, leftCount + right.getRowType().getFieldCount());

    final LogicalOperator leftOp = DrillJoinRel.implementInput(implementor, 0, 0, left, this, fields);
    final LogicalOperator rightOp = DrillJoinRel.implementInput(implementor, 1, leftCount, right, this, fields);

    // Key ordinals index into each input's own field list.
    List<JoinCondition> conditions = Lists.newArrayList();
    for (Pair<Integer, Integer> pair : Pair.zip(leftKeys, rightKeys)) {
      conditions.add(new JoinCondition(DrillJoinRel.EQUALITY_CONDITION,
          new FieldReference(leftFields.get(pair.left)), new FieldReference(rightFields.get(pair.right))));
    }

    return new LogicalSemiJoin(leftOp, rightOp, conditions, joinType);
  }
}

0 comments on commit 71809ca

Please sign in to comment.