Skip to content

Commit

Permalink
DRILL-6546: Allow unnest function with nested columns and complex exp…
Browse files Browse the repository at this point in the history
…ressions

Fix loss of projected names in right side of correlate when single field is projected
  • Loading branch information
vvysotskyi authored and arina-ielchiieva committed Jul 4, 2018
1 parent 62aadda commit cacca92
Show file tree
Hide file tree
Showing 10 changed files with 362 additions and 110 deletions.
Expand Up @@ -37,6 +37,7 @@
import org.apache.drill.exec.planner.logical.DrillJoinRule; import org.apache.drill.exec.planner.logical.DrillJoinRule;
import org.apache.drill.exec.planner.logical.DrillLimitRule; import org.apache.drill.exec.planner.logical.DrillLimitRule;
import org.apache.drill.exec.planner.logical.DrillMergeProjectRule; import org.apache.drill.exec.planner.logical.DrillMergeProjectRule;
import org.apache.drill.exec.planner.logical.ProjectComplexRexNodeCorrelateTransposeRule;
import org.apache.drill.exec.planner.logical.DrillProjectLateralJoinTransposeRule; import org.apache.drill.exec.planner.logical.DrillProjectLateralJoinTransposeRule;
import org.apache.drill.exec.planner.logical.DrillProjectPushIntoLateralJoinRule; import org.apache.drill.exec.planner.logical.DrillProjectPushIntoLateralJoinRule;
import org.apache.drill.exec.planner.logical.DrillProjectRule; import org.apache.drill.exec.planner.logical.DrillProjectRule;
Expand Down Expand Up @@ -311,6 +312,8 @@ static RuleSet getDrillUserConfigurableLogicalRules(OptimizerRulesContext optimi
RuleInstance.PROJECT_WINDOW_TRANSPOSE_RULE, RuleInstance.PROJECT_WINDOW_TRANSPOSE_RULE,
DrillPushProjectIntoScanRule.INSTANCE, DrillPushProjectIntoScanRule.INSTANCE,


ProjectComplexRexNodeCorrelateTransposeRule.INSTANCE,

/* /*
Convert from Calcite Logical to Drill Logical Rules. Convert from Calcite Logical to Drill Logical Rules.
*/ */
Expand Down
Expand Up @@ -73,7 +73,7 @@ protected RelDataType deriveRowType() {
return constructRowType(SqlValidatorUtil.deriveJoinRowType(left.getRowType(), return constructRowType(SqlValidatorUtil.deriveJoinRowType(left.getRowType(),
right.getRowType(), joinType.toJoinType(), right.getRowType(), joinType.toJoinType(),
getCluster().getTypeFactory(), null, getCluster().getTypeFactory(), null,
ImmutableList.<RelDataTypeField>of())); ImmutableList.of()));
case ANTI: case ANTI:
case SEMI: case SEMI:
return constructRowType(left.getRowType()); return constructRowType(left.getRowType());
Expand All @@ -82,12 +82,19 @@ protected RelDataType deriveRowType() {
} }
} }


public int getInputSize(int offset, RelNode input) { /**
if (this.excludeCorrelateColumn && * Returns number of fields in {@link RelDataType} for
offset == 0) { * input rel node with specified ordinal considering value of
return input.getRowType().getFieldList().size() - 1; * {@code excludeCorrelateColumn}.
*
* @param ordinal ordinal of input rel node
* @return number of fields in input's {@link RelDataType}
*/
public int getInputSize(int ordinal) {
if (this.excludeCorrelateColumn && ordinal == 0) {
return getInput(ordinal).getRowType().getFieldList().size() - 1;
} }
return input.getRowType().getFieldList().size(); return getInput(ordinal).getRowType().getFieldList().size();
} }


public RelDataType constructRowType(RelDataType inputRowType) { public RelDataType constructRowType(RelDataType inputRowType) {
Expand Down
Expand Up @@ -50,7 +50,7 @@ public Correlate copy(RelTraitSet traitSet,
public LogicalOperator implement(DrillImplementor implementor) { public LogicalOperator implement(DrillImplementor implementor) {
final List<String> fields = getRowType().getFieldNames(); final List<String> fields = getRowType().getFieldNames();
assert DrillJoinRel.isUnique(fields); assert DrillJoinRel.isUnique(fields);
final int leftCount = getInputSize(0,left); final int leftCount = getInputSize(0);


final LogicalOperator leftOp = DrillJoinRel.implementInput(implementor, 0, 0, left, this); final LogicalOperator leftOp = DrillJoinRel.implementInput(implementor, 0, 0, left, this);
final LogicalOperator rightOp = DrillJoinRel.implementInput(implementor, 1, leftCount, right, this); final LogicalOperator rightOp = DrillJoinRel.implementInput(implementor, 1, leftCount, right, this);
Expand Down
Expand Up @@ -24,6 +24,8 @@
import org.apache.calcite.rel.core.Uncollect; import org.apache.calcite.rel.core.Uncollect;
import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalValues; import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;


public class DrillUnnestRule extends RelOptRule { public class DrillUnnestRule extends RelOptRule {
public static final RelOptRule INSTANCE = new DrillUnnestRule(); public static final RelOptRule INSTANCE = new DrillUnnestRule();
Expand All @@ -38,11 +40,14 @@ private DrillUnnestRule() {
public void onMatch(RelOptRuleCall call) { public void onMatch(RelOptRuleCall call) {
final Uncollect uncollect = call.rel(0); final Uncollect uncollect = call.rel(0);
final LogicalProject project = call.rel(1); final LogicalProject project = call.rel(1);
final LogicalValues values = call.rel(2);


RexNode projectedNode = project.getProjects().iterator().next();
if (projectedNode.getKind() != SqlKind.FIELD_ACCESS) {
return;
}
final RelTraitSet traits = uncollect.getTraitSet().plus(DrillRel.DRILL_LOGICAL); final RelTraitSet traits = uncollect.getTraitSet().plus(DrillRel.DRILL_LOGICAL);
DrillUnnestRel unnest = new DrillUnnestRel(uncollect.getCluster(), traits, uncollect.getRowType(), DrillUnnestRel unnest = new DrillUnnestRel(uncollect.getCluster(),
project.getProjects().iterator().next()); traits, uncollect.getRowType(), projectedNode);
call.transformTo(unnest); call.transformTo(unnest);
} }
} }
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.logical;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Uncollect;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.trace.CalciteTrace;
import org.apache.drill.common.exceptions.UserException;

import java.util.ArrayList;
import java.util.List;

/**
* Rule that moves non-{@link RexFieldAccess} rex node from project below {@link Uncollect}
* to the left side of the {@link Correlate}.
*/
public class ProjectComplexRexNodeCorrelateTransposeRule extends RelOptRule {

public static final RelOptRule INSTANCE = new ProjectComplexRexNodeCorrelateTransposeRule();

public ProjectComplexRexNodeCorrelateTransposeRule() {
super(operand(LogicalCorrelate.class,
operand(RelNode.class, any()),
operand(Uncollect.class, operand(LogicalProject.class, any()))),
DrillRelFactories.LOGICAL_BUILDER,
"ProjectComplexRexNodeCorrelateTransposeRule");
}

@Override
public void onMatch(RelOptRuleCall call) {
final Correlate correlate = call.rel(0);
final Uncollect uncollect = call.rel(2);
final LogicalProject project = call.rel(3);

// uncollect requires project with single expression
RexNode projectedNode = project.getProjects().iterator().next();

// check that the expression is complex call
if (!(projectedNode instanceof RexFieldAccess)) {
RelBuilder builder = call.builder();
RexBuilder rexBuilder = builder.getRexBuilder();

builder.push(correlate.getLeft());

// creates project with complex expr on top of the left side
List<RexNode> leftProjExprs = new ArrayList<>();

String complexFieldName = correlate.getRowType().getFieldNames()
.get(correlate.getRowType().getFieldNames().size() - 1);

List<String> fieldNames = new ArrayList<>();
for (RelDataTypeField field : correlate.getLeft().getRowType().getFieldList()) {
leftProjExprs.add(rexBuilder.makeInputRef(correlate.getLeft(), field.getIndex()));
fieldNames.add(field.getName());
}
fieldNames.add(complexFieldName);
List<RexNode> topProjectExpressions = new ArrayList<>(leftProjExprs);

// adds complex expression with replaced correlation
// to the projected list from the left
leftProjExprs.add(projectedNode.accept(new RexFieldAccessReplacer(builder)));

RelNode leftProject = builder.project(leftProjExprs, fieldNames)
.build();

CorrelationId correlationId = correlate.getCluster().createCorrel();
RexCorrelVariable rexCorrel =
(RexCorrelVariable) rexBuilder.makeCorrel(
leftProject.getRowType(),
correlationId);
builder.push(project.getInput());
RelNode rightProject = builder.project(
ImmutableList.of(rexBuilder.makeFieldAccess(rexCorrel, leftProjExprs.size() - 1)),
ImmutableList.of(complexFieldName))
.build();

int requiredColumnsCount = correlate.getRequiredColumns().cardinality();
if (requiredColumnsCount != 1) {
throw UserException.planError()
.message("Required columns count for Correlate operator " +
"differs from the expected value:\n" +
"Expected columns count is %s, but actual is %s",
1, requiredColumnsCount)
.build(CalciteTrace.getPlannerTracer());
}

RelNode newUncollect = uncollect.copy(uncollect.getTraitSet(), rightProject);
Correlate newCorrelate = correlate.copy(uncollect.getTraitSet(), leftProject, newUncollect,
correlationId, ImmutableBitSet.of(leftProjExprs.size() - 1), correlate.getJoinType());
builder.push(newCorrelate);

switch(correlate.getJoinType()) {
case LEFT:
case INNER:
// adds field from the right input of correlate to the top project
topProjectExpressions.add(
rexBuilder.makeInputRef(newCorrelate, topProjectExpressions.size() + 1));
// fall through
case ANTI:
case SEMI:
builder.project(topProjectExpressions, correlate.getRowType().getFieldNames());
}

call.transformTo(builder.build());
}
}

/**
* Visitor for RexNode which replaces {@link RexFieldAccess}
* with a reference to the field used in {@link RexFieldAccess}.
*/
private static class RexFieldAccessReplacer extends RexShuttle {
private final RelBuilder builder;

public RexFieldAccessReplacer(RelBuilder builder) {
this.builder = builder;
}

@Override
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
return builder.field(fieldAccess.getField().getName());
}
}
}
Expand Up @@ -88,11 +88,12 @@ private SchemaPath getColumn() {
* Check to make sure that the fields of the inputs are the same as the output field names. * Check to make sure that the fields of the inputs are the same as the output field names.
* If not, insert a project renaming them. * If not, insert a project renaming them.
*/ */
public RelNode getLateralInput(int offset, RelNode input) { public RelNode getLateralInput(int ordinal, RelNode input) {
int offset = ordinal == 0 ? 0 : getInputSize(0);
Preconditions.checkArgument(DrillJoinRelBase.uniqueFieldNames(input.getRowType())); Preconditions.checkArgument(DrillJoinRelBase.uniqueFieldNames(input.getRowType()));
final List<String> fields = getRowType().getFieldNames(); final List<String> fields = getRowType().getFieldNames();
final List<String> inputFields = input.getRowType().getFieldNames(); final List<String> inputFields = input.getRowType().getFieldNames();
final List<String> outputFields = fields.subList(offset, offset + getInputSize(offset, input)); final List<String> outputFields = fields.subList(offset, offset + getInputSize(ordinal));
if (ListUtils.subtract(outputFields, inputFields).size() != 0) { if (ListUtils.subtract(outputFields, inputFields).size() != 0) {
// Ensure that input field names are the same as output field names. // Ensure that input field names are the same as output field names.
// If there are duplicate field names on left and right, fields will get // If there are duplicate field names on left and right, fields will get
Expand Down
Expand Up @@ -19,7 +19,6 @@


import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexNode;
import org.apache.drill.exec.planner.logical.DrillUnnestRel; import org.apache.drill.exec.planner.logical.DrillUnnestRel;
import org.apache.drill.exec.planner.logical.RelOptHelper; import org.apache.drill.exec.planner.logical.RelOptHelper;
Expand All @@ -34,10 +33,6 @@ private UnnestPrule() {
public void onMatch(RelOptRuleCall call) { public void onMatch(RelOptRuleCall call) {
final DrillUnnestRel unnest = call.rel(0); final DrillUnnestRel unnest = call.rel(0);
RexNode ref = unnest.getRef(); RexNode ref = unnest.getRef();
if (ref instanceof RexFieldAccess) {
final RexFieldAccess field = (RexFieldAccess)ref;
field.getField().getName();
}


UnnestPrel unnestPrel = new UnnestPrel(unnest.getCluster(), UnnestPrel unnestPrel = new UnnestPrel(unnest.getCluster(),
unnest.getTraitSet().plus(Prel.DRILL_PHYSICAL), unnest.getRowType(), ref); unnest.getTraitSet().plus(Prel.DRILL_PHYSICAL), unnest.getRowType(), ref);
Expand Down
Expand Up @@ -17,6 +17,7 @@
*/ */
package org.apache.drill.exec.planner.physical.visitor; package org.apache.drill.exec.planner.physical.visitor;


import java.util.ArrayList;
import java.util.List; import java.util.List;


import org.apache.drill.exec.planner.physical.JoinPrel; import org.apache.drill.exec.planner.physical.JoinPrel;
Expand Down Expand Up @@ -75,16 +76,11 @@ public Prel visitJoin(JoinPrel prel, Void value) throws RuntimeException {
public Prel visitLateral(LateralJoinPrel prel, Void value) throws RuntimeException { public Prel visitLateral(LateralJoinPrel prel, Void value) throws RuntimeException {


List<RelNode> children = getChildren(prel); List<RelNode> children = getChildren(prel);
List<RelNode> reNamedChildren = new ArrayList<>();


final int leftCount = prel.getInputSize(0,children.get(0)); for (int i = 0; i < children.size(); i++) {

reNamedChildren.add(prel.getLateralInput(i, children.get(i)));
List<RelNode> reNamedChildren = Lists.newArrayList(); }

RelNode left = prel.getLateralInput(0, children.get(0));
RelNode right = prel.getLateralInput(leftCount, children.get(1));

reNamedChildren.add(left);
reNamedChildren.add(right);


return preparePrel(prel, reNamedChildren); return preparePrel(prel, reNamedChildren);
} }
Expand Down

0 comments on commit cacca92

Please sign in to comment.