Skip to content

Commit

Permalink
[CALCITE-4057] Implement trait propagation for EnumerableBatchNestedL…
Browse files Browse the repository at this point in the history
…oopJoin (Rui Wang).
  • Loading branch information
amaliujia committed Jun 12, 2020
1 parent bb22d47 commit 87fcf30
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.plan.DeriveMode;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
Expand All @@ -37,6 +38,7 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;

import com.google.common.collect.ImmutableList;

Expand Down Expand Up @@ -88,6 +90,27 @@ public static EnumerableBatchNestedLoopJoin create(
joinType);
}

@Override public Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
final RelTraitSet required) {
return EnumerableTraitsUtils.passThroughTraitsForJoin(
required, joinType, getLeft().getRowType().getFieldCount(), traitSet);
}

@Override public Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(
final RelTraitSet childTraits, final int childId) {
return EnumerableTraitsUtils.deriveTraitsForJoin(
childTraits, childId, joinType, traitSet, right.getTraitSet()
);
}

@Override public DeriveMode getDeriveMode() {
if (joinType == JoinRelType.FULL || joinType == JoinRelType.RIGHT) {
return DeriveMode.PROHIBITED;
}

return DeriveMode.LEFT_FIRST;
}

@Override public EnumerableBatchNestedLoopJoin copy(RelTraitSet traitSet,
RexNode condition, RelNode left, RelNode right, JoinRelType joinType,
boolean semiJoinDone) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
final List<RexNode> exps = Lists.transform(program.getProjectList(),
program::expandLocalRef);

return EnumTraitsUtils.passThroughTraitsForProject(required, exps,
return EnumerableTraitsUtils.passThroughTraitsForProject(required, exps,
input.getRowType(), input.getCluster().getTypeFactory(), traitSet);
}

Expand All @@ -281,7 +281,7 @@ public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
final List<RexNode> exps = Lists.transform(program.getProjectList(),
program::expandLocalRef);

return EnumTraitsUtils.deriveTraitsForProject(childTraits, childId, exps,
return EnumerableTraitsUtils.deriveTraitsForProject(childTraits, childId, exps,
input.getRowType(), input.getCluster().getTypeFactory(), traitSet);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
import org.apache.calcite.plan.DeriveMode;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
Expand Down Expand Up @@ -89,41 +86,18 @@ public static EnumerableCorrelate create(

@Override public Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
final RelTraitSet required) {
final RelCollation collation = required.getCollation();
if (collation == null || collation == RelCollations.EMPTY) {
return null;
}

// EnumerableCorrelate traits passdown shall only pass through collation to left input.
// This is because for EnumerableCorrelate always uses left input as the outer loop,
// thus only left input can preserve ordering.

for (RelFieldCollation relFieldCollation : collation.getFieldCollations()) {
// If field collation belongs to right input: bail out.
if (relFieldCollation.getFieldIndex() >= getLeft().getRowType().getFieldCount()) {
return null;
}
}

final RelTraitSet passThroughTraitSet = traitSet.replace(collation);
return Pair.of(passThroughTraitSet,
ImmutableList.of(
passThroughTraitSet,
passThroughTraitSet.replace(RelCollations.EMPTY)));
return EnumerableTraitsUtils.passThroughTraitsForJoin(
required, joinType, left.getRowType().getFieldCount(), getTraitSet());
}

@Override public Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(
final RelTraitSet childTraits, final int childId) {
// should only derive traits (limited to collation for now) from left input.
assert childId == 0;

final RelCollation collation = childTraits.getCollation();
if (collation == null || collation == RelCollations.EMPTY) {
return null;
}

final RelTraitSet traits = traitSet.replace(collation);
return Pair.of(traits, ImmutableList.of(traits, right.getTraitSet()));
return EnumerableTraitsUtils.deriveTraitsForJoin(
childTraits, childId, joinType, traitSet, right.getTraitSet());
}

@Override public DeriveMode getDeriveMode() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelNodes;
import org.apache.calcite.rel.core.CorrelationId;
Expand Down Expand Up @@ -108,42 +105,15 @@ public static EnumerableHashJoin create(

@Override public Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
final RelTraitSet required) {
RelCollation collation = required.getCollation();
if (collation == null
|| collation == RelCollations.EMPTY
|| joinType == JoinRelType.FULL
|| joinType == JoinRelType.RIGHT) {
return null;
}

for (RelFieldCollation fc : collation.getFieldCollations()) {
// If field collation belongs to right input: cannot push down collation.
if (fc.getFieldIndex() >= getLeft().getRowType().getFieldCount()) {
return null;
}
}

RelTraitSet passthroughTraitSet = traitSet.replace(collation);
return Pair.of(passthroughTraitSet,
ImmutableList.of(
passthroughTraitSet,
passthroughTraitSet.replace(RelCollations.EMPTY)));
return EnumerableTraitsUtils.passThroughTraitsForJoin(
required, joinType, left.getRowType().getFieldCount(), getTraitSet());
}

@Override public Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(
final RelTraitSet childTraits, final int childId) {
// should only derive traits (limited to collation for now) from left join input.
assert childId == 0;

RelCollation collation = childTraits.getCollation();
if (collation == null || collation == RelCollations.EMPTY) {
return null;
}

RelTraitSet derivedTraits = getTraitSet().replace(collation);
return Pair.of(
derivedTraits,
ImmutableList.of(derivedTraits, right.getTraitSet()));
return EnumerableTraitsUtils.deriveTraitsForJoin(
childTraits, childId, joinType, getTraitSet(), right.getTraitSet());
}

@Override public DeriveMode getDeriveMode() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelNodes;
import org.apache.calcite.rel.core.CorrelationId;
Expand Down Expand Up @@ -126,50 +123,21 @@ public static EnumerableNestedLoopJoin create(

@Override public Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
final RelTraitSet required) {
RelCollation collation = required.getCollation();
if (collation == null
|| collation == RelCollations.EMPTY
|| joinType == JoinRelType.FULL
|| joinType == JoinRelType.RIGHT) {
return null;
}

// EnumerableNestedLoopJoin traits passdown shall only pass through collation to left input.
// It is because for EnumerableNestedLoopJoin always uses left input as the outer loop,
// thus only left input can preserve ordering.
// EnumerableNestedLoopJoin traits passdown shall only pass through collation to
// left input. It is because for EnumerableNestedLoopJoin always
// uses left input as the outer loop, thus only left input can preserve ordering.
// Push sort both to left and right inputs does not help right outer join. It's because in
// implementation, EnumerableNestedLoopJoin produces (null, right_unmatched) all together,
// which does not preserve ordering from right side.


for (RelFieldCollation fc : collation.getFieldCollations()) {
// If field collation belongs to right input: cannot push down collation.
if (fc.getFieldIndex() >= getLeft().getRowType().getFieldCount()) {
return null;
}
}

RelTraitSet passthroughTraitSet = traitSet.replace(collation);
return Pair.of(passthroughTraitSet,
ImmutableList.of(
passthroughTraitSet,
passthroughTraitSet.replace(RelCollations.EMPTY)));
return EnumerableTraitsUtils.passThroughTraitsForJoin(
required, joinType, getLeft().getRowType().getFieldCount(), traitSet);
}

@Override public Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(
final RelTraitSet childTraits, final int childId) {
// should only derive traits (limited to collation for now) from left join input.
assert childId == 0;

RelCollation collation = childTraits.getCollation();
if (collation == null || collation == RelCollations.EMPTY) {
return null;
}

RelTraitSet derivedTraits = getTraitSet().replace(collation);
return Pair.of(
derivedTraits,
ImmutableList.of(derivedTraits, right.getTraitSet()));
return EnumerableTraitsUtils.deriveTraitsForJoin(
childTraits, childId, joinType, traitSet, right.getTraitSet()
);
}

@Override public DeriveMode getDeriveMode() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ public Result implement(EnumerableRelImplementor implementor, Prefer pref) {

@Override public Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
RelTraitSet required) {
return EnumTraitsUtils.passThroughTraitsForProject(required, exps,
return EnumerableTraitsUtils.passThroughTraitsForProject(required, exps,
input.getRowType(), input.getCluster().getTypeFactory(), traitSet);
}

@Override public Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(
final RelTraitSet childTraits, final int childId) {
return EnumTraitsUtils.deriveTraitsForProject(childTraits, childId, exps,
return EnumerableTraitsUtils.deriveTraitsForProject(childTraits, childId, exps,
input.getRowType(), input.getCluster().getTypeFactory(), traitSet);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
Expand All @@ -47,9 +48,9 @@
* Utilities for traits propagation.
*/
@API(since = "1.24", status = API.Status.INTERNAL)
class EnumTraitsUtils {
class EnumerableTraitsUtils {

private EnumTraitsUtils() {}
private EnumerableTraitsUtils() {}

/**
* Determine whether there is mapping between project input and output fields.
Expand Down Expand Up @@ -148,4 +149,53 @@ static Pair<RelTraitSet, List<RelTraitSet>> deriveTraitsForProject(
return null;
}
}

// This function can be reused when a Join's traits pass-down shall only
// pass through collation to left input.
static Pair<RelTraitSet, List<RelTraitSet>> passThroughTraitsForJoin(
RelTraitSet required, JoinRelType joinType,
int leftInputFieldCount, RelTraitSet joinTraitSet) {
RelCollation collation = required.getCollation();
if (collation == null
|| collation == RelCollations.EMPTY
|| joinType == JoinRelType.FULL
|| joinType == JoinRelType.RIGHT) {
return null;
}

for (RelFieldCollation fc : collation.getFieldCollations()) {
// If field collation belongs to right input: cannot push down collation.
if (fc.getFieldIndex() >= leftInputFieldCount) {
return null;
}
}

RelTraitSet passthroughTraitSet = joinTraitSet.replace(collation);
return Pair.of(passthroughTraitSet,
ImmutableList.of(
passthroughTraitSet,
passthroughTraitSet.replace(RelCollations.EMPTY)));
}

// This function can be reused when a Join's traits derivation shall only
// derive collation from left input.
static Pair<RelTraitSet, List<RelTraitSet>> deriveTraitsForJoin(
RelTraitSet childTraits, int childId, JoinRelType joinType,
RelTraitSet joinTraitSet, RelTraitSet rightTraitSet) {
// should only derive traits (limited to collation for now) from left join input.
assert childId == 0;

RelCollation collation = childTraits.getCollation();
if (collation == null
|| collation == RelCollations.EMPTY
|| joinType == JoinRelType.FULL
|| joinType == JoinRelType.RIGHT) {
return null;
}

RelTraitSet derivedTraits = joinTraitSet.replace(collation);
return Pair.of(
derivedTraits,
ImmutableList.of(derivedTraits, rightTraitSet));
}
}

0 comments on commit 87fcf30

Please sign in to comment.