Skip to content
Permalink
Browse files
DRILL-8211: Replace deprecated RelNode.getChildExps with Project.getP…
…rojects (#2535)
  • Loading branch information
vvysotskyi committed May 12, 2022
1 parent 3dc4f25 commit c95777439c79f4cd936cefa1ee683d83878e40c9
Showing 15 changed files with 46 additions and 68 deletions.
@@ -17,7 +17,6 @@
*/
package org.apache.drill.exec.store.mapr.db;

import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelTrait;
@@ -38,6 +37,7 @@
import org.apache.drill.exec.util.Utilities;

import java.util.List;
import java.util.stream.Collectors;

/**
* Push a physical Project into Scan. Currently, this rule is only doing projection pushdown for MapRDB-JSON tables
@@ -46,7 +46,6 @@
* planning phase.
*/
public abstract class MapRDBPushProjectIntoScan extends StoragePluginOptimizerRule {
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MapRDBPushProjectIntoScan.class);

private MapRDBPushProjectIntoScan(RelOptRuleOperand operand, String description) {
super(operand, description);
@@ -99,10 +98,9 @@ protected void doPushProjectIntoGroupScan(RelOptRuleCall call,
groupScan.clone(columnInfo.getFields()),
columnInfo.createNewRowType(project.getInput().getCluster().getTypeFactory()), scan.getTable());

List<RexNode> newProjects = Lists.newArrayList();
for (RexNode n : project.getChildExps()) {
newProjects.add(n.accept(columnInfo.getInputReWriter()));
}
List<RexNode> newProjects = project.getProjects().stream()
.map(n -> n.accept(columnInfo.getInputReWriter()))
.collect(Collectors.toList());

final ProjectPrel newProj =
new ProjectPrel(project.getCluster(),
@@ -65,15 +65,15 @@ public RelNode convert(RelNode relNode) {

// check for literals only without input exprs
DrillRelOptUtil.InputRefVisitor collectRefs = new DrillRelOptUtil.InputRefVisitor();
project.getChildExps().forEach(exp -> exp.accept(collectRefs));
project.getProjects().forEach(exp -> exp.accept(collectRefs));

if (!collectRefs.getInputRefs().isEmpty()) {
for (RelDataTypeField relDataTypeField : rowType.getFieldList()) {
innerProjections.add(project.getCluster().getRexBuilder().makeInputRef(project.getInput(), relDataTypeField.getIndex()));
}
}

boolean allExprsInputRefs = project.getChildExps().stream().allMatch(rexNode -> rexNode instanceof RexInputRef);
boolean allExprsInputRefs = project.getProjects().stream().allMatch(rexNode -> rexNode instanceof RexInputRef);
if (collectRefs.getInputRefs().isEmpty() || allExprsInputRefs) {
return CalciteUtils.createProject(traitSet,
convert(project.getInput(), out), project.getProjects(), project.getRowType());
@@ -444,12 +444,12 @@ public RexNode go(RexNode rex) {
}
}

@SuppressWarnings("deprecation")
public static boolean isProjectFlatten(RelNode project) {
public static boolean isProjectFlatten(RelNode relNode) {

assert project instanceof Project : "Rel is NOT an instance of project!";
assert relNode instanceof Project : "Rel is NOT an instance of Project";

for (RexNode rex : project.getChildExps()) {
Project project = (Project) relNode;
for (RexNode rex : project.getProjects()) {
if (rex instanceof RexCall) {
RexCall function = (RexCall) rex;
String functionName = function.getOperator().getName();
@@ -39,6 +39,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.apache.drill.exec.planner.logical.FieldsReWriterUtil.DesiredField;
import static org.apache.drill.exec.planner.logical.FieldsReWriterUtil.FieldsReWriter;
@@ -99,10 +100,9 @@ public void onMatch(RelOptRuleCall call) {
// re-write projects
Map<RexNode, Integer> fieldMapper = createFieldMapper(itemStarFields.values(), scanRel.getRowType().getFieldCount());
FieldsReWriter fieldsReWriter = new FieldsReWriter(fieldMapper);
List<RexNode> newProjects = new ArrayList<>();
for (RexNode node : projectRel.getChildExps()) {
newProjects.add(node.accept(fieldsReWriter));
}
List<RexNode> newProjects = projectRel.getProjects().stream()
.map(node -> node.accept(fieldsReWriter))
.collect(Collectors.toList());

DrillProjectRel newProject = new DrillProjectRel(
projectRel.getCluster(),
@@ -46,7 +46,7 @@
*/
public class DrillMergeProjectRule extends RelOptRule {

private FunctionImplementationRegistry functionRegistry;
private final FunctionImplementationRegistry functionRegistry;
private final boolean force;

public static DrillMergeProjectRule getInstance(boolean force, ProjectFactory pFactory,
@@ -71,15 +71,11 @@ public boolean matches(RelOptRuleCall call) {
Project bottomProject = call.rel(1);

// We have a complex output type do not fire the merge project rule
if (checkComplexOutput(topProject) || checkComplexOutput(bottomProject)) {
return false;
}

return true;
return !checkComplexOutput(topProject) && !checkComplexOutput(bottomProject);
}

private boolean checkComplexOutput(Project project) {
for (RexNode expr: project.getChildExps()) {
for (RexNode expr: project.getProjects()) {
if (expr instanceof RexCall) {
if (functionRegistry.isFunctionComplexOutput(((RexCall) expr).getOperator().getName())) {
return true;
@@ -54,7 +54,7 @@ public void onMatch(RelOptRuleCall call) {
return;
}
DrillRelOptUtil.InputRefVisitor collectRefs = new DrillRelOptUtil.InputRefVisitor();
for (RexNode exp: origProj.getChildExps()) {
for (RexNode exp: origProj.getProjects()) {
exp.accept(collectRefs);
}

@@ -78,7 +78,7 @@ public void onMatch(RelOptRuleCall call) {

if (!trivial) {
Map<Integer, Integer> mapWithoutCorr = buildMapWithoutCorrColumn(corr, correlationIndex);
List<RexNode> outputExprs = DrillRelOptUtil.transformExprs(origProj.getCluster().getRexBuilder(), origProj.getChildExps(), mapWithoutCorr);
List<RexNode> outputExprs = DrillRelOptUtil.transformExprs(origProj.getCluster().getRexBuilder(), origProj.getProjects(), mapWithoutCorr);

relNode = new DrillProjectRel(origProj.getCluster(),
left.getTraitSet().plus(DrillRel.DRILL_LOGICAL),
@@ -89,14 +89,12 @@ public void onMatch(RelOptRuleCall call) {

private Map<Integer, Integer> buildMapWithoutCorrColumn(RelNode corr, int correlationIndex) {
int index = 0;
Map<Integer, Integer> result = new HashMap();
for (int i=0;i<corr.getRowType().getFieldList().size();i++) {
if (i == correlationIndex) {
continue;
} else {
Map<Integer, Integer> result = new HashMap<>();
for (int i = 0; i < corr.getRowType().getFieldList().size(); i++) {
if (i != correlationIndex) {
result.put(i, index++);
}
}
return result;
}
}
}
@@ -36,8 +36,8 @@
import org.apache.drill.exec.util.Utilities;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
* When table support project push down, rule can be applied to reduce number of read columns
@@ -110,10 +110,9 @@ public void onMatch(RelOptRuleCall call) {

TableScan newScan = createScan(scan, projectPushInfo);

List<RexNode> newProjects = new ArrayList<>();
for (RexNode n : project.getChildExps()) {
newProjects.add(n.accept(projectPushInfo.getInputReWriter()));
}
List<RexNode> newProjects = project.getProjects().stream()
.map(n -> n.accept(projectPushInfo.getInputReWriter()))
.collect(Collectors.toList());

Project newProject =
createProject(project, newScan, newProjects);
@@ -463,8 +463,9 @@ private static boolean isRowKeyColumn(int index, RelNode rel) {
}
// If no exprs present in projection the column index remains the same in the child.
// Otherwise, the column index is the `RexInputRef` index.
if (curRel != null && curRel instanceof DrillProjectRel) {
List<RexNode> childExprs = curRel.getChildExps();
if (curRel instanceof DrillProjectRel) {
DrillProjectRel projectRel = (DrillProjectRel) curRel;
List<RexNode> childExprs = projectRel.getProjects();
if (childExprs != null && childExprs.size() > 0) {
if (childExprs.get(curIndex) instanceof RexInputRef) {
curIndex = ((RexInputRef) childExprs.get(curIndex)).getIndex();
@@ -58,9 +58,9 @@ public class PreProcessLogicalRel extends RelShuttleImpl {

private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(PreProcessLogicalRel.class);

private RelDataTypeFactory factory;
private DrillOperatorTable table;
private UnsupportedOperatorCollector unsupportedOperatorCollector;
private final RelDataTypeFactory factory;
private final DrillOperatorTable table;
private final UnsupportedOperatorCollector unsupportedOperatorCollector;
private final UnwrappingExpressionVisitor unwrappingExpressionVisitor;

public static PreProcessLogicalRel createVisitor(RelDataTypeFactory factory, DrillOperatorTable table, RexBuilder rexBuilder) {
@@ -78,7 +78,7 @@ private PreProcessLogicalRel(RelDataTypeFactory factory, DrillOperatorTable tabl
@Override
public RelNode visit(LogicalProject project) {
final List<RexNode> projExpr = Lists.newArrayList();
for(RexNode rexNode : project.getChildExps()) {
for(RexNode rexNode : project.getProjects()) {
projExpr.add(rexNode.accept(unwrappingExpressionVisitor));
}

@@ -90,7 +90,7 @@ public RelNode visit(LogicalProject project) {
List<RexNode> exprList = new ArrayList<>();
boolean rewrite = false;

for (RexNode rex : project.getChildExps()) {
for (RexNode rex : project.getProjects()) {
RexNode newExpr = rex;
if (rex instanceof RexCall) {
RexCall function = (RexCall) rex;
@@ -28,7 +28,6 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.util.Pair;
@@ -209,7 +208,6 @@ public RelWriter itemIf(String term, Object value, boolean condition) {
return this;
}

@SuppressWarnings("deprecation")
public RelWriter done(RelNode node) {
int i = 0;
if (values.size() > 0 && values.get(0).left.equals("subset")) {
@@ -219,11 +217,7 @@ public RelWriter done(RelNode node) {
assert values.get(i).right == input;
++i;
}
for (RexNode expr : node.getChildExps()) {
assert values.get(i).right == expr;
++i;
}
final List<Pair<String, Object>> valuesCopy =
List<Pair<String, Object>> valuesCopy =
ImmutableList.copyOf(values);
values.clear();
explain_(node, valuesCopy);
@@ -70,7 +70,7 @@ public Prel visitProject(ProjectPrel project, Object unused) throws RelConversio
List<RelDataTypeField> relDataTypes = new ArrayList<>();
int i = 0;
RexNode flatttenExpr = null;
for (RexNode rex : project.getChildExps()) {
for (RexNode rex : project.getProjects()) {
RexNode newExpr = rex;
if (rex instanceof RexCall) {
RexCall function = (RexCall) rex;
@@ -99,7 +99,7 @@ public Prel visitProject(ProjectPrel project, Object unused) throws RelConversio
}

Prel child = ((Prel)project.getInput()).accept(this, null);
if (child == project.getInput() && exprList.equals(project.getChildExps())) {
if (child == project.getInput() && exprList.equals(project.getProjects())) {
return project;
}
return (Prel) project.copy(project.getTraitSet(), child, exprList, new RelRecordType(relDataTypes));
@@ -87,15 +87,15 @@ public Prel visitProject(final ProjectPrel project, Object unused) throws RelCon
final int lastRexInput = lastColumnReferenced + 1;
RexVisitorComplexExprSplitter exprSplitter = new RexVisitorComplexExprSplitter(funcReg, rexBuilder, lastRexInput);
int i = 0;
for (RexNode rex : newProject.getChildExps()) {
for (RexNode rex : newProject.getProjects()) {
RelDataTypeField originField = projectFields.get(i++);
RexNode splitRex = rex.accept(exprSplitter);
origRelDataTypes.add(originField);
exprList.add(splitRex);
}

final List<RexNode> complexExprs = exprSplitter.getComplexExprs();
if (complexExprs.size() == 1 && findTopComplexFunc(newProject.getChildExps()).size() == 1) {
if (complexExprs.size() == 1 && findTopComplexFunc(newProject.getProjects()).size() == 1) {
return newProject;
}

@@ -132,8 +132,7 @@ public Prel visitProject(final ProjectPrel project, Object unused) throws RelCon
relDataTypes.add(new RelDataTypeFieldImpl(getExprName(exprIndex), allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));

RelRecordType childProjectType = new RelRecordType(relDataTypes);
ProjectPrel childProject = new ProjectPrel(newProject.getCluster(), newProject.getTraitSet(), newInput, ImmutableList.copyOf(allExprs), childProjectType);
newInput = childProject;
newInput = new ProjectPrel(newProject.getCluster(), newProject.getTraitSet(), newInput, ImmutableList.copyOf(allExprs), childProjectType);
}

allExprs.set(allExprs.size() - 1,
@@ -118,7 +118,7 @@ public RelNode visit(Uncollect uncollect) {

Project project = (Project) uncollect.getInput();
// If project below uncollect contains only field references, no need to rewrite it
List<RexNode> projectChildExps = project.getChildExps();
List<RexNode> projectChildExps = project.getProjects();
assert projectChildExps.size() == 1 : "Uncollect does not support multiple expressions";

RexNode projectExpr = projectChildExps.iterator().next();
@@ -78,7 +78,7 @@ public boolean matches(RelOptRuleCall call) {
try {

final LogicalProject project = call.rel(0);
for (RexNode node : project.getChildExps()) {
for (RexNode node : project.getProjects()) {
if (!checkedExpressions.get(node)) {
return false;
}
@@ -108,15 +108,8 @@ public RelNode convert(RelNode rel) {
@Override
public boolean matches(RelOptRuleCall call) {
try {

final LogicalFilter filter = call.rel(0);
for (RexNode node : filter.getChildExps()) {
if (!checkedExpressions.get(node)) {
return false;
}
}
return true;

LogicalFilter filter = call.rel(0);
return checkedExpressions.get(filter.getCondition());
} catch (ExecutionException e) {
throw new IllegalStateException("Failure while trying to evaluate push down.", e);
}
@@ -71,7 +71,7 @@ public RelNode convert(RelNode rel) {
return null;
}

List<RexNode> newProjects = project.getChildExps().stream()
List<RexNode> newProjects = project.getProjects().stream()
.map(n -> n.accept(projectPushInfo.getInputReWriter()))
.collect(Collectors.toList());

0 comments on commit c957774

Please sign in to comment.