ParallelEdgeFixer.java
@@ -20,16 +20,23 @@

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;

import org.apache.calcite.util.Pair;
import org.apache.commons.collections4.ListValuedMap;
import org.apache.commons.collections4.multimap.ArrayListValuedHashMap;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.optimizer.graph.OperatorGraph;
@@ -104,7 +111,7 @@ private String sig(Pair<Operator<?>, Operator<?>> o1) {
}
}

-private void fixParallelEdges(OperatorGraph og) {
+private void fixParallelEdges(OperatorGraph og) throws SemanticException {

// Identify edge operators
ListValuedMap<Pair<Cluster, Cluster>, Pair<Operator<?>, Operator<?>>> edgeOperators =
@@ -157,7 +164,13 @@ private void removeOneEdge(List<Pair<Operator<?>, Operator<?>>> values) {
values.remove(toKeep);
}

-private boolean isParallelEdgeSupported(Pair<Operator<?>, Operator<?>> pair) {
+public boolean isParallelEdgeSupported(Pair<Operator<?>, Operator<?>> pair) {
+
+Operator<?> rs = pair.left;
+if (rs instanceof ReduceSinkOperator && !colMappingInverseKeys((ReduceSinkOperator) rs).isPresent()) {
+return false;
+}
+
Operator<?> child = pair.right;
if (child instanceof MapJoinOperator) {
return true;
@@ -171,7 +184,7 @@ private boolean isParallelEdgeSupported(Pair<Operator<?>, Operator<?>> pair) {
/**
* Fixes a parallel edge going into a mapjoin by introducing a concentrator RS.
*/
-private void fixParallelEdge(Operator<? extends OperatorDesc> p, Operator<?> o) {
+private void fixParallelEdge(Operator<? extends OperatorDesc> p, Operator<?> o) throws SemanticException {
LOG.info("Fixing parallel by adding a concentrator RS between {} -> {}", p, o);

ReduceSinkDesc conf = (ReduceSinkDesc) p.getConf();
@@ -199,16 +212,16 @@ private void fixParallelEdge(Operator<? extends OperatorDesc> p, Operator<?> o)

}
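Note: as a reading aid for the hunk above, the "concentrator" pattern splices a fresh node into an existing edge, so the two works end up connected by a single edge. A minimal sketch with a toy node type (`ToyNode` is hypothetical; the real code builds a ReduceSinkOperator plus the SEL produced by `buildSEL` below):

```java
import java.util.ArrayList;
import java.util.List;

// Toy DAG node, not Hive's Operator API: shows only the edge splice
// p -> o  becoming  p -> concentrator -> o.
class ToyNode {
  final String name;
  final List<ToyNode> parents = new ArrayList<>();
  final List<ToyNode> children = new ArrayList<>();

  ToyNode(String name) { this.name = name; }

  static ToyNode concentrate(ToyNode p, ToyNode o) {
    ToyNode c = new ToyNode("RS_concentrator");
    p.children.set(p.children.indexOf(o), c);   // rewire parent side
    c.parents.add(p);
    o.parents.set(o.parents.indexOf(p), c);     // rewire child side
    c.children.add(o);
    return c;
  }
}
```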

-private Operator<SelectDesc> buildSEL(Operator<? extends OperatorDesc> p, ReduceSinkDesc conf) {
+private Operator<SelectDesc> buildSEL(Operator<? extends OperatorDesc> p, ReduceSinkDesc conf)
+throws SemanticException {
List<ExprNodeDesc> colList = new ArrayList<>();
List<String> outputColumnNames = new ArrayList<>();
List<ColumnInfo> newColumns = new ArrayList<>();

-for (Entry<String, ExprNodeDesc> e : conf.getColumnExprMap().entrySet()) {
-
-String colName = e.getKey();
-ExprNodeDesc expr = e.getValue();
+Set<String> inverseKeys = colMappingInverseKeys((ReduceSinkOperator) p).get();
+for (String colName : inverseKeys) {
+
+ExprNodeDesc expr = conf.getColumnExprMap().get(colName);
ExprNodeDesc colRef = new ExprNodeColumnDesc(expr.getTypeInfo(), colName, colName, false);

colList.add(colRef);
@@ -227,7 +240,7 @@ private Operator<SelectDesc> buildSEL(Operator<? extends OperatorDesc> p, Reduce
return newSEL;
}

-private String extractColumnName(ExprNodeDesc expr) {
+private static String extractColumnName(ExprNodeDesc expr) throws SemanticException {
if (expr instanceof ExprNodeColumnDesc) {
ExprNodeColumnDesc exprNodeColumnDesc = (ExprNodeColumnDesc) expr;
return exprNodeColumnDesc.getColumn();
@@ -237,6 +250,20 @@ private String extractColumnName(ExprNodeDesc expr) {
ExprNodeConstantDesc exprNodeConstantDesc = (ExprNodeConstantDesc) expr;
return exprNodeConstantDesc.getFoldedFromCol();
}
-throw new RuntimeException("unexpected mapping expression!");
+throw new SemanticException("unexpected mapping expression!");
}

+public static Optional<Set<String>> colMappingInverseKeys(ReduceSinkOperator rs) {
+Map<String, String> ret = new HashMap<String, String>();
+Map<String, ExprNodeDesc> exprMap = rs.getColumnExprMap();
+try {
+for (Entry<String, ExprNodeDesc> e : exprMap.entrySet()) {
+ret.put(extractColumnName(e.getValue()), e.getKey());
+}
+return Optional.of(new TreeSet<>(ret.values()));
+} catch (SemanticException e) {
+return Optional.empty();
+}
+
+}
}
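Note: `colMappingInverseKeys` drives both the new guard in `isParallelEdgeSupported` and the predicate in SharedWorkOptimizer: it inverts the ReduceSink's output-column-to-source-expression map and yields empty when any expression is neither a plain column nor a folded constant. A simplified, self-contained sketch of the same contract (`String` stands in for `ExprNodeDesc`; all names here are illustrative, not Hive's):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;

// Illustrative stand-in: invert output-column -> source-column and return
// empty when a mapping cannot be resolved (the SemanticException path above).
class InverseKeysDemo {
  static Optional<Set<String>> inverseKeys(Map<String, String> columnExprMap) {
    Map<String, String> inverted = new HashMap<>();
    for (Map.Entry<String, String> e : columnExprMap.entrySet()) {
      String sourceCol = e.getValue();
      if (sourceCol == null) {      // stands in for an unresolvable expression
        return Optional.empty();
      }
      inverted.put(sourceCol, e.getKey());
    }
    return Optional.of(new TreeSet<>(inverted.values()));
  }

  public static void main(String[] args) {
    Map<String, String> ok = Map.of("KEY._col0", "id", "VALUE._col0", "name");
    System.out.println(inverseKeys(ok));      // Optional[[KEY._col0, VALUE._col0]]
    Map<String, String> bad = new HashMap<>(ok);
    bad.put("VALUE._col1", null);             // e.g. a computed expression
    System.out.println(inverseKeys(bad));     // Optional.empty
  }
}
```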
SharedWorkOptimizer.java
@@ -56,6 +56,7 @@
import org.apache.hadoop.hive.ql.optimizer.graph.OperatorGraph.Cluster;
import org.apache.hadoop.hive.ql.optimizer.graph.OperatorGraph.EdgeType;
import org.apache.hadoop.hive.ql.optimizer.graph.OperatorGraph.OpEdge;
+import org.apache.hadoop.hive.ql.optimizer.graph.OperatorGraph.OperatorEdgePredicate;
import org.apache.hadoop.hive.ql.parse.GenTezUtils;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
@@ -81,7 +82,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

-import com.google.common.base.Function;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
@@ -1805,11 +1805,11 @@ private static boolean validPreConditions(ParseContext pctx, SharedWorkOptimizer
// If we do, we cannot merge. The reason is that Tez currently does
// not support parallel edges, i.e., multiple edges from same work x
// into same work y.
-EdgePredicate edgePredicate;
+RelaxedVertexEdgePredicate edgePredicate;
if (pctx.getConf().getBoolVar(ConfVars.HIVE_SHARED_WORK_PARALLEL_EDGE_SUPPORT)) {
-edgePredicate = new EdgePredicate(EnumSet.<EdgeType> of(EdgeType.DPP, EdgeType.SEMIJOIN, EdgeType.BROADCAST));
+edgePredicate = new RelaxedVertexEdgePredicate(EnumSet.<EdgeType> of(EdgeType.DPP, EdgeType.SEMIJOIN, EdgeType.BROADCAST));
} else {
-edgePredicate = new EdgePredicate(EnumSet.<EdgeType> of(EdgeType.DPP));
+edgePredicate = new RelaxedVertexEdgePredicate(EnumSet.<EdgeType> of(EdgeType.DPP));
}

OperatorGraph og = new OperatorGraph(pctx);
@@ -1856,17 +1856,26 @@ private static boolean validPreConditions(ParseContext pctx, SharedWorkOptimizer
return true;
}

-static class EdgePredicate implements Function<OpEdge, Boolean> {
+static class RelaxedVertexEdgePredicate implements OperatorEdgePredicate {

-private EnumSet<EdgeType> nonTraverseableEdgeTypes;
+private EnumSet<EdgeType> traverseableEdgeTypes;

-public EdgePredicate(EnumSet<EdgeType> nonTraverseableEdgeTypes) {
-this.nonTraverseableEdgeTypes = nonTraverseableEdgeTypes;
+public RelaxedVertexEdgePredicate(EnumSet<EdgeType> nonTraverseableEdgeTypes) {
+this.traverseableEdgeTypes = nonTraverseableEdgeTypes;
}

@Override
-public Boolean apply(OpEdge input) {
-return !nonTraverseableEdgeTypes.contains(input.getEdgeType());
-}
+public boolean accept(Operator<?> s, Operator<?> t, OpEdge opEdge) {
+if (!traverseableEdgeTypes.contains(opEdge.getEdgeType())) {
+return true;
+}
+if (s instanceof ReduceSinkOperator) {
+ReduceSinkOperator rs = (ReduceSinkOperator) s;
+if (!ParallelEdgeFixer.colMappingInverseKeys(rs).isPresent()) {
+return true;
+}
+}
+return false;
+}

}
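Note: to make the `accept` logic above easier to scan: an edge always counts as traversable when its type is outside the configured set, and a DPP/SEMIJOIN/BROADCAST edge also counts when its source is a ReduceSink whose column mapping is not invertible, since (as the guard in `isParallelEdgeSupported` suggests) ParallelEdgeFixer could not repair such an edge later. A compact self-contained analog under those assumptions (toy enum and lambda, not the Hive types):

```java
import java.util.EnumSet;

// Toy analog of RelaxedVertexEdgePredicate: edge types outside the "fixable"
// set always count; fixable ones are skipped unless the source is flagged
// as non-invertible.
public class PredicateDemo {
  enum EdgeType { DPP, SEMIJOIN, BROADCAST, FLOW }

  interface EdgePredicate {
    boolean accept(boolean sourceInvertible, EdgeType type);
  }

  public static void main(String[] args) {
    EnumSet<EdgeType> fixable = EnumSet.of(EdgeType.DPP, EdgeType.SEMIJOIN, EdgeType.BROADCAST);
    EdgePredicate pred = (inv, type) -> !fixable.contains(type) || !inv;

    System.out.println(pred.accept(true, EdgeType.FLOW));       // true: plain edge always counts
    System.out.println(pred.accept(true, EdgeType.BROADCAST));  // false: the fixer can handle it
    System.out.println(pred.accept(false, EdgeType.BROADCAST)); // true: mapping not invertible
  }
}
```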
OperatorGraph.java
@@ -38,8 +38,6 @@
import org.apache.hadoop.hive.ql.parse.SemiJoinBranchInfo;
import org.apache.hadoop.hive.ql.plan.DynamicPruningEventDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
-
-import com.google.common.base.Function;
import com.google.common.collect.Sets;

/**
@@ -83,6 +81,11 @@ public EdgeType getEdgeType() {

}

+public static interface OperatorEdgePredicate {
+
+boolean accept(Operator<?> s, Operator<?> t, OpEdge opEdge);
+
+}

Map<Operator<?>, Cluster> nodeCluster = new HashMap<>();

@@ -105,31 +108,31 @@ protected void add(Operator<?> curr) {
members.add(curr);
}

-public Set<Cluster> parentClusters(Function<OpEdge, Boolean> traverseEdge) {
+public Set<Cluster> parentClusters(OperatorEdgePredicate traverseEdge) {
Set<Cluster> ret = new HashSet<Cluster>();
for (Operator<?> operator : members) {
for (Operator<? extends OperatorDesc> p : operator.getParentOperators()) {
if (members.contains(p)) {
continue;
}
Optional<OpEdge> e = g.getEdge(p, operator);
-if (traverseEdge.apply(e.get())) {
+if (traverseEdge.accept(p, operator, e.get())) {
ret.add(nodeCluster.get(p));
}
}
}
return ret;
}

-public Set<Cluster> childClusters(Function<OpEdge, Boolean> traverseEdge) {
+public Set<Cluster> childClusters(OperatorEdgePredicate traverseEdge) {
Set<Cluster> ret = new HashSet<Cluster>();
for (Operator<?> operator : members) {
for (Operator<? extends OperatorDesc> p : operator.getChildOperators()) {
if (members.contains(p)) {
continue;
}
Optional<OpEdge> e = g.getEdge(operator, p);
-if (traverseEdge.apply(e.get())) {
+if (traverseEdge.accept(operator, p, e.get())) {
ret.add(nodeCluster.get(p));
}
}
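Note: the new `OperatorEdgePredicate` passes both endpoint operators alongside the edge, which the old `Function<OpEdge, Boolean>` could not, so a predicate can inspect the source ReduceSink as shown above. A generic sketch of the filtered-neighbor traversal that `parentClusters` performs (toy graph types, not Hive's):

```java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Minimal analog of parentClusters(...) filtering neighbors through an
// edge predicate that sees source, target, and the edge itself.
public class TraversalDemo {
  interface EdgePredicate<N, E> {
    boolean accept(N source, N target, E edge);
  }

  static <N, E> Set<N> parents(N node, Map<N, Map<N, E>> incoming, EdgePredicate<N, E> pred) {
    Set<N> ret = new HashSet<>();
    Map<N, E> edges = incoming.getOrDefault(node, Map.of());
    for (Map.Entry<N, E> e : edges.entrySet()) {
      if (pred.accept(e.getKey(), node, e.getValue())) {
        ret.add(e.getKey());
      }
    }
    return ret;
  }

  public static void main(String[] args) {
    Map<String, Map<String, String>> in =
        Map.of("join", Map.of("rs1", "BROADCAST", "rs2", "FLOW"));
    System.out.println(parents("join", in, (s, t, e) -> !e.equals("BROADCAST"))); // [rs2]
  }
}
```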