DRILL-2858: Refactor hash expression construction in InsertLocalExchangeVisitor and PrelUtil into one place
vkorukanti committed Apr 27, 2015
1 parent a8c96f6 commit 6878bfd
Showing 8 changed files with 186 additions and 76 deletions.
@@ -44,6 +44,7 @@
 import org.apache.drill.exec.memory.BufferAllocator;
 import org.apache.drill.exec.ops.FragmentContext;
 import org.apache.drill.exec.physical.impl.join.JoinUtils;
+import org.apache.drill.exec.planner.physical.HashPrelUtil;
 import org.apache.drill.exec.planner.physical.PrelUtil;
 import org.apache.drill.exec.record.MaterializedField;
 import org.apache.drill.exec.record.RecordBatch;

@@ -313,7 +314,7 @@ private void setupGetHash(ClassGenerator<HashTable> cg, MappingSet incomingMappi
  * aggregate. For join we need to hash everything as double (both for distribution and for comparison) but
  * for aggregation we can avoid the penalty of casting to double
  */
-    LogicalExpression hashExpression = PrelUtil.getHashExpression(Arrays.asList(keyExprs),
+    LogicalExpression hashExpression = HashPrelUtil.getHashExpression(Arrays.asList(keyExprs),
         incomingProbe != null ? true : false);
 final LogicalExpression materializedExpr = ExpressionTreeMaterializer.materializeAndCheckErrors(hashExpression, batch, context.getFunctionRegistry());
 HoldingContainer hash = cg.addExpr(materializedExpr);
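(Illustrative aside, not part of the commit.) The comment above is the crux of the hashAsDouble flag: a join may pair keys of different numeric types across its two inputs, so both sides must widen to double before hashing, while aggregation keys always agree in type. A minimal sketch, reusing the keyExprs array from setupGetHash:

// Hedged sketch only; `keyExprs` is the key-expression array from setupGetHash.
List<LogicalExpression> keys = Arrays.asList(keyExprs);

// Join: an INT 1 and a FLOAT8 1.0 on opposite sides must hash identically,
// so every key is hashed with hash64AsDouble.
LogicalExpression joinHash = HashPrelUtil.getHashExpression(keys, true /* hashAsDouble */);

// Aggregation: a key always compares against the same type, so plain hash64
// skips the cast-to-double cost.
LogicalExpression aggHash = HashPrelUtil.getHashExpression(keys, false);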
@@ -0,0 +1,138 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.physical;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.exec.planner.physical.DrillDistributionTrait.DistributionField;

import java.util.ArrayList;
import java.util.List;

/**
* Contains utility methods for creating hash expressions, either for distribution (in PartitionSender) or for the HashTable.
*/
public class HashPrelUtil {

public static final String HASH_EXPR_NAME = "E_X_P_R_H_A_S_H_F_I_E_L_D";

/**
 * Interface for creating hash expressions over different expression types.
 * @param <T> type of the expression the helper produces; currently either {@link RexNode} or {@link LogicalExpression}.
 */
public interface HashExpressionCreatorHelper<T> {
T createCall(String funcName, List<T> inputFields);
}

/**
* Implementation of {@link HashExpressionCreatorHelper} for {@link LogicalExpression} type.
*/
public static final HashExpressionCreatorHelper<LogicalExpression> HASH_HELPER_LOGICALEXPRESSION =
new HashExpressionCreatorHelper<LogicalExpression>() {
@Override
public LogicalExpression createCall(String funcName, List<LogicalExpression> inputFields) {
return new FunctionCall(funcName, inputFields, ExpressionPosition.UNKNOWN);
}
};

private static final String HASH64_FUNCTION_NAME = "hash64";
private static final String HASH64_DOUBLE_FUNCTION_NAME = "hash64AsDouble";
private static final String CAST_TO_INT_FUNCTION_NAME = "castInt";

/**
* Create hash based partition expression based on the given distribution fields.
*
* @param distFields Field list based on which the distribution partition expression is constructed.
* @param helper Implementation of {@link HashExpressionCreatorHelper}
* which is used to create function expressions.
* @param <T> Input and output expression type.
* Currently it could be either {@link RexNode} or {@link LogicalExpression}
* @return the hash based partition expression.
*/
public static <T> T createHashBasedPartitionExpression(
List<T> distFields,
HashExpressionCreatorHelper<T> helper) {
return createHashExpression(distFields, helper, true /*for distribution always hash as double*/);
}

/**
* Create hash expression based on the given input fields.
*
* @param inputExprs Expression list based on which the hash expression is constructed.
* @param helper Implementation of {@link HashExpressionCreatorHelper}
* which is used to create function expressions.
* @param hashAsDouble Whether to use the hash as double function or regular hash64 function.
* @param <T> Input and output expression type.
* Currently it could be either {@link RexNode} or {@link LogicalExpression}
* @return the constructed hash expression.
*/
public static <T> T createHashExpression(
List<T> inputExprs,
HashExpressionCreatorHelper<T> helper,
boolean hashAsDouble) {

assert inputExprs.size() > 0;

final String functionName = hashAsDouble ? HASH64_DOUBLE_FUNCTION_NAME : HASH64_FUNCTION_NAME;

T func = helper.createCall(functionName, ImmutableList.of(inputExprs.get(0)));
for (int i = 1; i < inputExprs.size(); i++) {
func = helper.createCall(functionName, ImmutableList.of(inputExprs.get(i), func));
}

return helper.createCall(CAST_TO_INT_FUNCTION_NAME, ImmutableList.of(func));
}

/**
 * Return a hash expression of the form: (int) hash(fieldN, hash(fieldN-1, ... hash(field1))).
 */
public static LogicalExpression getHashExpression(List<LogicalExpression> fields, boolean hashAsDouble) {
return createHashExpression(fields, HASH_HELPER_LOGICALEXPRESSION, hashAsDouble);
}


/**
* Create a distribution hash expression.
*
* @param fields Distribution fields
* @param rowType Row type
* @return a hash expression over the distribution fields, or a reference to an existing hash field.
*/
public static LogicalExpression getHashExpression(List<DistributionField> fields, RelDataType rowType) {
assert fields.size() > 0;

final List<String> childFields = rowType.getFieldNames();

// If the input already contains the hash field, there is no need to compute the hash again.
if (childFields.contains(HASH_EXPR_NAME)) {
return new FieldReference(HASH_EXPR_NAME);
}

final List<LogicalExpression> expressions = new ArrayList<LogicalExpression>(childFields.size());
for (int i = 0; i < fields.size(); i++) {
expressions.add(new FieldReference(childFields.get(fields.get(i).getFieldId()), ExpressionPosition.UNKNOWN));
}

return createHashBasedPartitionExpression(expressions, HASH_HELPER_LOGICALEXPRESSION);
}
}
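To make the shape of the generated expression concrete, here is a small illustrative helper (not part of the commit) that renders the call tree as a string. It shows that createHashExpression folds the fields left to right, so the last field ends up outermost and the innermost call carries no seed argument:

// Hedged sketch only; a String-based helper that prints the tree instead of
// building real LogicalExpressions or RexNodes.
HashPrelUtil.HashExpressionCreatorHelper<String> printer =
    new HashPrelUtil.HashExpressionCreatorHelper<String>() {
      @Override
      public String createCall(String funcName, List<String> inputFields) {
        StringBuilder sb = new StringBuilder(funcName).append('(');
        for (int i = 0; i < inputFields.size(); i++) {
          if (i > 0) { sb.append(", "); }
          sb.append(inputFields.get(i));
        }
        return sb.append(')').toString();
      }
    };

String rendered = HashPrelUtil.createHashExpression(java.util.Arrays.asList("a", "b", "c"), printer, false);
// rendered == "castInt(hash64(c, hash64(b, hash64(a))))"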
@@ -84,7 +84,7 @@ public PhysicalOperator getPhysicalOperator(PhysicalPlanCreator creator) throws
   }

   HashToMergeExchange g = new HashToMergeExchange(childPOP,
-      PrelUtil.getHashExpression(this.distFields, getInput().getRowType()),
+      HashPrelUtil.getHashExpression(this.distFields, getInput().getRowType()),
       PrelUtil.getOrdering(this.collation, getInput().getRowType()));
   return creator.addMetadata(this, g);
@@ -95,7 +95,7 @@ public PhysicalOperator getPhysicalOperator(PhysicalPlanCreator creator) throws
   }

   // TODO - refactor to different exchange name
-  HashToRandomExchange g = new HashToRandomExchange(childPOP, PrelUtil.getHashExpression(this.fields, getInput().getRowType()));
+  HashToRandomExchange g = new HashToRandomExchange(childPOP, HashPrelUtil.getHashExpression(this.fields, getInput().getRowType()));
   return creator.addMetadata(this, g);
 }
@@ -23,8 +23,6 @@
 import java.util.List;
 import java.util.Set;

-import org.apache.drill.common.expression.CastExpression;
 import org.apache.drill.common.expression.ExpressionPosition;
 import org.apache.drill.common.expression.FieldReference;
 import org.apache.drill.common.expression.FunctionCall;
@@ -34,8 +32,6 @@
 import org.apache.drill.common.expression.PathSegment.NameSegment;
 import org.apache.drill.common.expression.SchemaPath;
 import org.apache.drill.common.logical.data.Order.Ordering;
-import org.apache.drill.common.types.TypeProtos.MinorType;
-import org.apache.drill.common.types.Types;
 import org.apache.drill.exec.planner.physical.DrillDistributionTrait.DistributionField;
 import org.apache.drill.exec.record.BatchSchema.SelectionVectorMode;
@@ -64,10 +60,6 @@
 public class PrelUtil {

-  public static final String HASH_EXPR_NAME = "E_X_P_R_H_A_S_H_F_I_E_L_D";
-  private static final String HASH64_FUNCTION_NAME = "hash64";
-  private static final String HASH64_DOUBLE_FUNCTION_NAME = "hash64AsDouble";
-
   public static List<Ordering> getOrdering(RelCollation collation, RelDataType rowType) {
     List<Ordering> orderExpr = Lists.newArrayList();
@@ -81,40 +73,6 @@ public static List<Ordering> getOrdering(RelCollation collation, RelDataType row
     return orderExpr;
   }

-  /*
-   * Return a hash expression : (int) hash(field1, hash(field2, hash(field3, 0)));
-   */
-  public static LogicalExpression getHashExpression(List<LogicalExpression> fields, boolean hashAsDouble){
-    assert fields.size() > 0;
-
-    String functionName = hashAsDouble ? HASH64_DOUBLE_FUNCTION_NAME : HASH64_FUNCTION_NAME;
-    FunctionCall func = new FunctionCall(functionName, ImmutableList.of(fields.get(0)), ExpressionPosition.UNKNOWN);
-    for (int i = 1; i<fields.size(); i++) {
-      func = new FunctionCall(functionName, ImmutableList.of(fields.get(i), func), ExpressionPosition.UNKNOWN);
-    }
-
-    return new CastExpression(func, Types.required(MinorType.INT), ExpressionPosition.UNKNOWN);
-  }
-
-  public static LogicalExpression getHashExpression(List<DistributionField> fields, RelDataType rowType) {
-    assert fields.size() > 0;
-
-    final List<String> childFields = rowType.getFieldNames();
-
-    // If we already included a field with hash - no need to calculate hash further down
-    if ( childFields.contains(HASH_EXPR_NAME)) {
-      return new FieldReference(HASH_EXPR_NAME);
-    }
-
-    final List<LogicalExpression> expressions = new ArrayList<LogicalExpression>(childFields.size());
-    for(int i =0; i < fields.size(); i++){
-      expressions.add(new FieldReference(childFields.get(fields.get(i).getFieldId()), ExpressionPosition.UNKNOWN));
-    }
-
-    // for distribution always hash as double
-    return getHashExpression(expressions, true);
-  }

   public static Iterator<Prel> iter(RelNode... nodes) {
     return (Iterator<Prel>) (Object) Arrays.asList(nodes).iterator();
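(Illustrative aside.) One subtle change rides along with the move: the removed PrelUtil code expressed the final narrowing to int as a CastExpression node, while HashPrelUtil emits a call to the castInt function. Modeling the cast as an ordinary function call is what lets one HashExpressionCreatorHelper abstraction serve both LogicalExpression and RexNode targets. A hedged comparison, where func stands for the nested hash64 chain:

// Before the refactor: an explicit cast node wrapped around the hash chain.
LogicalExpression oldStyle =
    new CastExpression(func, Types.required(MinorType.INT), ExpressionPosition.UNKNOWN);

// After the refactor: the cast is itself a FunctionCall, so a helper that only
// knows how to build function calls can produce it for either expression type.
LogicalExpression newStyle =
    new FunctionCall("castInt", ImmutableList.<LogicalExpression>of(func), ExpressionPosition.UNKNOWN);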
@@ -49,7 +49,7 @@ public PhysicalOperator getPhysicalOperator(PhysicalPlanCreator creator) throws
   PhysicalOperator childPOP = child.getPhysicalOperator(creator);

-  UnorderedDeMuxExchange p = new UnorderedDeMuxExchange(childPOP, PrelUtil.getHashExpression(this.fields, getInput().getRowType()));
+  UnorderedDeMuxExchange p = new UnorderedDeMuxExchange(childPOP, HashPrelUtil.getHashExpression(this.fields, getInput().getRowType()));
   return creator.addMetadata(this, p);
 }
@@ -21,6 +21,8 @@
 import org.apache.drill.common.types.TypeProtos.MajorType;
 import org.apache.drill.exec.planner.physical.ExchangePrel;
+import org.apache.drill.exec.planner.physical.HashPrelUtil;
+import org.apache.drill.exec.planner.physical.HashPrelUtil.HashExpressionCreatorHelper;
 import org.apache.drill.exec.planner.physical.HashToRandomExchangePrel;
 import org.apache.drill.exec.planner.physical.PlannerSettings;
 import org.apache.drill.exec.planner.physical.Prel;
@@ -42,13 +44,25 @@
 import java.util.List;

 public class InsertLocalExchangeVisitor extends BasePrelVisitor<Prel, Void, RuntimeException> {
-  private static final DrillSqlOperator SQL_OP_HASH64_WITH_NO_SEED = new DrillSqlOperator("hash64", 1, MajorType.getDefaultInstance(), true);
-  private static final DrillSqlOperator SQL_OP_HASH64_WITH_SEED = new DrillSqlOperator("hash64", 2, MajorType.getDefaultInstance(), true);
-  private static final DrillSqlOperator SQL_OP_CAST_INT = new DrillSqlOperator("castINT", 1, MajorType.getDefaultInstance(), true);
-
   private final boolean isMuxEnabled;
   private final boolean isDeMuxEnabled;

+  public static class RexNodeBasedHashExpressionCreatorHelper implements HashExpressionCreatorHelper<RexNode> {
+    private final RexBuilder rexBuilder;
+
+    public RexNodeBasedHashExpressionCreatorHelper(RexBuilder rexBuilder) {
+      this.rexBuilder = rexBuilder;
+    }
+
+    @Override
+    public RexNode createCall(String funcName, List<RexNode> inputFields) {
+      final DrillSqlOperator op =
+          new DrillSqlOperator(funcName, inputFields.size(), MajorType.getDefaultInstance(), true);
+      return rexBuilder.makeCall(op, inputFields);
+    }
+  }
+
   public static Prel insertLocalExchanges(Prel prel, OptionManager options) {
     boolean isMuxEnabled = options.getOption(PlannerSettings.MUX_EXCHANGE.getOptionName()).bool_val;
     boolean isDeMuxEnabled = options.getOption(PlannerSettings.DEMUX_EXCHANGE.getOptionName()).bool_val;
@@ -77,39 +91,36 @@ public Prel visitExchange(ExchangePrel prel, Void value) throws RuntimeException
     Prel newPrel = child;

-    HashToRandomExchangePrel hashPrel = (HashToRandomExchangePrel) prel;
+    final HashToRandomExchangePrel hashPrel = (HashToRandomExchangePrel) prel;
     final List<String> childFields = child.getRowType().getFieldNames();
-    List <RexNode> removeUpdatedExpr = Lists.newArrayList();

-    if ( isMuxEnabled ) {
+    List <RexNode> removeUpdatedExpr = null;
+
+    if (isMuxEnabled) {
       // Insert Project Operator with new column that will be a hash for HashToRandomExchange fields
-      List<DistributionField> fields = hashPrel.getFields();
-      List<String> outputFieldNames = Lists.newArrayList(childFields);
+      final List<DistributionField> distFields = hashPrel.getFields();
+      final List<String> outputFieldNames = Lists.newArrayList(childFields);
       final RexBuilder rexBuilder = prel.getCluster().getRexBuilder();
       final List<RelDataTypeField> childRowTypeFields = child.getRowType().getFieldList();

-      // First field has no seed argument for hash64 function.
-      final int firstFieldId = fields.get(0).getFieldId();
-      RexNode firstFieldInputRef = rexBuilder.makeInputRef(childRowTypeFields.get(firstFieldId).getType(), firstFieldId);
-      RexNode hashExpr = rexBuilder.makeCall(SQL_OP_HASH64_WITH_NO_SEED, firstFieldInputRef);
-
-      for (int i=1; i<fields.size(); i++) {
-        final int fieldId = fields.get(i).getFieldId();
-        RexNode inputRef = rexBuilder.makeInputRef(childRowTypeFields.get(fieldId).getType(), fieldId);
-        hashExpr = rexBuilder.makeCall(SQL_OP_HASH64_WITH_SEED, inputRef, hashExpr);
+      final HashExpressionCreatorHelper<RexNode> hashHelper = new RexNodeBasedHashExpressionCreatorHelper(rexBuilder);
+      final List<RexNode> distFieldRefs = Lists.newArrayListWithExpectedSize(distFields.size());
+      for(int i=0; i<distFields.size(); i++) {
+        final int fieldId = distFields.get(i).getFieldId();
+        distFieldRefs.add(rexBuilder.makeInputRef(childRowTypeFields.get(fieldId).getType(), fieldId));
       }

-      hashExpr = rexBuilder.makeCall(SQL_OP_CAST_INT, hashExpr);
-
-      List <RexNode> updatedExpr = Lists.newArrayList();
+      final List <RexNode> updatedExpr = Lists.newArrayListWithExpectedSize(childRowTypeFields.size());
+      removeUpdatedExpr = Lists.newArrayListWithExpectedSize(childRowTypeFields.size());
       for ( RelDataTypeField field : childRowTypeFields) {
         RexNode rex = rexBuilder.makeInputRef(field.getType(), field.getIndex());
         updatedExpr.add(rex);
         removeUpdatedExpr.add(rex);
       }
-      outputFieldNames.add(PrelUtil.HASH_EXPR_NAME);

-      updatedExpr.add(hashExpr);
+      outputFieldNames.add(HashPrelUtil.HASH_EXPR_NAME);
+      updatedExpr.add(HashPrelUtil.createHashBasedPartitionExpression(distFieldRefs, hashHelper));

       RelDataType rowType = RexUtil.createStructType(prel.getCluster().getTypeFactory(), updatedExpr, outputFieldNames);

       ProjectPrel addColumnprojectPrel = new ProjectPrel(child.getCluster(), child.getTraitSet(), child, updatedExpr, rowType);
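(Illustrative aside.) The deleted comment "First field has no seed argument for hash64 function" still describes the generated expression; the seeding is simply handled inside HashPrelUtil's fold now. A hedged sketch of what the helper produces for two distribution-key input refs k0 and k1 (hypothetical RexNodes built with rexBuilder.makeInputRef):

RexNode h0 = hashHelper.createCall("hash64AsDouble", ImmutableList.of(k0));     // first key: no seed
RexNode h1 = hashHelper.createCall("hash64AsDouble", ImmutableList.of(k1, h0)); // previous hash becomes the seed
RexNode hashColumn = hashHelper.createCall("castInt", ImmutableList.of(h1));

// HashPrelUtil.createHashBasedPartitionExpression(ImmutableList.of(k0, k1), hashHelper)
// performs exactly this fold; the result is appended to the project as the
// E_X_P_R_H_A_S_H_F_I_E_L_D column.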
