Skip to content

Commit

Permalink
DRILL-6323: Lateral Join - Refactor BatchMemorySize to put outputBatc…
Browse files Browse the repository at this point in the history
…hSize in abstract class. Created a new JoinBatchMemoryManager to be shared across join record batches. Changed merge join to use AbstractBinaryRecordBatch instead of AbstractRecordBatch, and use JoinBatchMemoryManager
  • Loading branch information
Sorabh Hamirwasia authored and parthchandra committed Apr 18, 2018
1 parent 4b6f10e commit 7f19b50
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 184 deletions.
Expand Up @@ -23,13 +23,13 @@
import org.apache.drill.exec.physical.impl.broadcastsender.BroadcastSenderRootExec; import org.apache.drill.exec.physical.impl.broadcastsender.BroadcastSenderRootExec;
import org.apache.drill.exec.physical.impl.flatten.FlattenRecordBatch; import org.apache.drill.exec.physical.impl.flatten.FlattenRecordBatch;
import org.apache.drill.exec.physical.impl.join.HashJoinBatch; import org.apache.drill.exec.physical.impl.join.HashJoinBatch;
import org.apache.drill.exec.physical.impl.join.MergeJoinBatch;
import org.apache.drill.exec.physical.impl.mergereceiver.MergingRecordBatch; import org.apache.drill.exec.physical.impl.mergereceiver.MergingRecordBatch;
import org.apache.drill.exec.physical.impl.partitionsender.PartitionSenderRootExec; import org.apache.drill.exec.physical.impl.partitionsender.PartitionSenderRootExec;
import org.apache.drill.exec.physical.impl.unorderedreceiver.UnorderedReceiverBatch; import org.apache.drill.exec.physical.impl.unorderedreceiver.UnorderedReceiverBatch;
import org.apache.drill.exec.physical.impl.xsort.ExternalSortBatch; import org.apache.drill.exec.physical.impl.xsort.ExternalSortBatch;
import org.apache.drill.exec.store.parquet.columnreaders.ParquetRecordReader;
import org.apache.drill.exec.proto.UserBitShared.CoreOperatorType; import org.apache.drill.exec.proto.UserBitShared.CoreOperatorType;
import org.apache.drill.exec.record.JoinBatchMemoryManager;
import org.apache.drill.exec.store.parquet.columnreaders.ParquetRecordReader;


/** /**
* Registry of operator metrics. * Registry of operator metrics.
Expand All @@ -52,7 +52,7 @@ public class OperatorMetricRegistry {
register(CoreOperatorType.EXTERNAL_SORT_VALUE, ExternalSortBatch.Metric.class); register(CoreOperatorType.EXTERNAL_SORT_VALUE, ExternalSortBatch.Metric.class);
register(CoreOperatorType.PARQUET_ROW_GROUP_SCAN_VALUE, ParquetRecordReader.Metric.class); register(CoreOperatorType.PARQUET_ROW_GROUP_SCAN_VALUE, ParquetRecordReader.Metric.class);
register(CoreOperatorType.FLATTEN_VALUE, FlattenRecordBatch.Metric.class); register(CoreOperatorType.FLATTEN_VALUE, FlattenRecordBatch.Metric.class);
register(CoreOperatorType.MERGE_JOIN_VALUE, MergeJoinBatch.Metric.class); register(CoreOperatorType.MERGE_JOIN_VALUE, JoinBatchMemoryManager.Metric.class);
} }


private static void register(final int operatorType, final Class<? extends MetricDef> metricDef) { private static void register(final int operatorType, final Class<? extends MetricDef> metricDef) {
Expand Down
Expand Up @@ -72,8 +72,7 @@ public class FlattenRecordBatch extends AbstractSingleRecordBatch<FlattenPOP> {
private boolean hasRemainder = false; private boolean hasRemainder = false;
private int remainderIndex = 0; private int remainderIndex = 0;
private int recordCount; private int recordCount;
private int outputBatchSize; private final FlattenMemoryManager flattenMemoryManager;
private final FlattenMemoryManager flattenMemoryManager = new FlattenMemoryManager();


private final Flattener.Monitor monitor = new Flattener.Monitor() { private final Flattener.Monitor monitor = new Flattener.Monitor() {
@Override @Override
Expand Down Expand Up @@ -118,6 +117,10 @@ public int metricId() {


private class FlattenMemoryManager extends RecordBatchMemoryManager { private class FlattenMemoryManager extends RecordBatchMemoryManager {


FlattenMemoryManager(int outputBatchSize) {
super(outputBatchSize);
}

@Override @Override
public void update() { public void update() {
// Get sizing information for the batch. // Get sizing information for the batch.
Expand All @@ -138,11 +141,13 @@ public void update() {
// Average rowWidth of single element in the flatten list. // Average rowWidth of single element in the flatten list.
// subtract the offset vector size from column data size. // subtract the offset vector size from column data size.
final int avgRowWidthSingleFlattenEntry = final int avgRowWidthSingleFlattenEntry =
RecordBatchSizer.safeDivide(columnSize.getTotalNetSize() - (OFFSET_VECTOR_WIDTH * columnSize.getValueCount()), columnSize.getElementCount()); RecordBatchSizer.safeDivide(columnSize.getTotalNetSize() - (getOffsetVectorWidth() * columnSize.getValueCount()),
columnSize.getElementCount());


// Average rowWidth of outgoing batch. // Average rowWidth of outgoing batch.
final int avgOutgoingRowWidth = avgRowWidthWithOutFlattenColumn + avgRowWidthSingleFlattenEntry; final int avgOutgoingRowWidth = avgRowWidthWithOutFlattenColumn + avgRowWidthSingleFlattenEntry;


final int outputBatchSize = getOutputBatchSize();
// Number of rows in outgoing batch // Number of rows in outgoing batch
setOutputRowCount(outputBatchSize, avgOutgoingRowWidth); setOutputRowCount(outputBatchSize, avgOutgoingRowWidth);


Expand All @@ -165,7 +170,8 @@ public FlattenRecordBatch(FlattenPOP pop, RecordBatch incoming, FragmentContext
super(pop, context, incoming); super(pop, context, incoming);


// get the output batch size from config. // get the output batch size from config.
outputBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR); int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
flattenMemoryManager = new FlattenMemoryManager(configuredBatchSize);
} }


@Override @Override
Expand Down
Expand Up @@ -17,12 +17,13 @@
*/ */
package org.apache.drill.exec.physical.impl.join; package org.apache.drill.exec.physical.impl.join;


import static org.apache.drill.exec.compile.sig.GeneratorMapping.GM; import com.google.common.base.Preconditions;

import java.io.IOException;
import java.util.List;

import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JConditional;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JVar;
import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.JoinRelType;
import org.apache.drill.common.expression.ErrorCollector; import org.apache.drill.common.expression.ErrorCollector;
import org.apache.drill.common.expression.ErrorCollectorImpl; import org.apache.drill.common.expression.ErrorCollectorImpl;
Expand All @@ -44,34 +45,31 @@
import org.apache.drill.exec.expr.ExpressionTreeMaterializer; import org.apache.drill.exec.expr.ExpressionTreeMaterializer;
import org.apache.drill.exec.expr.fn.FunctionGenerationHelper; import org.apache.drill.exec.expr.fn.FunctionGenerationHelper;
import org.apache.drill.exec.ops.FragmentContext; import org.apache.drill.exec.ops.FragmentContext;
import org.apache.drill.exec.ops.MetricDef;
import org.apache.drill.exec.physical.config.MergeJoinPOP; import org.apache.drill.exec.physical.config.MergeJoinPOP;
import org.apache.drill.exec.physical.impl.common.Comparator; import org.apache.drill.exec.physical.impl.common.Comparator;
import org.apache.drill.exec.record.RecordBatchSizer; import org.apache.drill.exec.record.AbstractBinaryRecordBatch;
import org.apache.drill.exec.record.BatchSchema; import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.JoinBatchMemoryManager;
import org.apache.drill.exec.record.MaterializedField; import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch; import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.RecordIterator; import org.apache.drill.exec.record.RecordIterator;
import org.apache.drill.exec.record.TypedFieldId; import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorAccessible; import org.apache.drill.exec.record.VectorAccessible;
import org.apache.drill.exec.record.VectorContainer; import org.apache.drill.exec.record.VectorContainer;
import org.apache.drill.exec.record.VectorWrapper; import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.AbstractRecordBatch;
import org.apache.drill.exec.record.RecordBatchMemoryManager;
import org.apache.drill.exec.vector.ValueVector; import org.apache.drill.exec.vector.ValueVector;
import org.apache.drill.exec.vector.complex.AbstractContainerVector; import org.apache.drill.exec.vector.complex.AbstractContainerVector;


import com.google.common.base.Preconditions; import java.io.IOException;
import com.sun.codemodel.JClass; import java.util.List;
import com.sun.codemodel.JConditional;
import com.sun.codemodel.JExpr; import static org.apache.drill.exec.compile.sig.GeneratorMapping.GM;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JVar;


/** /**
* A join operator merges two sorted streams using record iterator. * A join operator merges two sorted streams using record iterator.
*/ */
public class MergeJoinBatch extends AbstractRecordBatch<MergeJoinPOP> { public class MergeJoinBatch extends AbstractBinaryRecordBatch<MergeJoinPOP> {


private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MergeJoinBatch.class); private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MergeJoinBatch.class);


Expand All @@ -96,50 +94,21 @@ public class MergeJoinBatch extends AbstractRecordBatch<MergeJoinPOP> {
GM("doSetup", "doSetup", null, null), GM("doSetup", "doSetup", null, null),
GM("doSetup", "doCompare", null, null)); GM("doSetup", "doCompare", null, null));


private final RecordBatch left;
private final RecordBatch right;
private final RecordIterator leftIterator; private final RecordIterator leftIterator;
private final RecordIterator rightIterator; private final RecordIterator rightIterator;
private final JoinStatus status; private final JoinStatus status;
private final List<JoinCondition> conditions; private final List<JoinCondition> conditions;
private final List<Comparator> comparators; private final List<Comparator> comparators;
private final JoinRelType joinType; private final JoinRelType joinType;
private JoinWorker worker; private JoinWorker worker;
private final int outputBatchSize;


private static final String LEFT_INPUT = "LEFT INPUT"; private static final String LEFT_INPUT = "LEFT INPUT";
private static final String RIGHT_INPUT = "RIGHT INPUT"; private static final String RIGHT_INPUT = "RIGHT INPUT";


private static final int numInputs = 2; private class MergeJoinMemoryManager extends JoinBatchMemoryManager {
private static final int LEFT_INDEX = 0;
private static final int RIGHT_INDEX = 1;

public enum Metric implements MetricDef {
LEFT_INPUT_BATCH_COUNT,
LEFT_AVG_INPUT_BATCH_BYTES,
LEFT_AVG_INPUT_ROW_BYTES,
LEFT_INPUT_RECORD_COUNT,
RIGHT_INPUT_BATCH_COUNT,
RIGHT_AVG_INPUT_BATCH_BYTES,
RIGHT_AVG_INPUT_ROW_BYTES,
RIGHT_INPUT_RECORD_COUNT,
OUTPUT_BATCH_COUNT,
AVG_OUTPUT_BATCH_BYTES,
AVG_OUTPUT_ROW_BYTES,
OUTPUT_RECORD_COUNT;


@Override MergeJoinMemoryManager(int outputBatchSize, RecordBatch leftBatch, RecordBatch rightBatch) {
public int metricId() { super(outputBatchSize, leftBatch, rightBatch);
return ordinal();
}
}

private class MergeJoinMemoryManager extends RecordBatchMemoryManager {
private int leftRowWidth;
private int rightRowWidth;

public MergeJoinMemoryManager() {
super(numInputs);
} }


/** /**
Expand All @@ -152,73 +121,22 @@ public MergeJoinMemoryManager() {
*/ */
@Override @Override
public void update(int inputIndex) { public void update(int inputIndex) {
switch(inputIndex) { status.setTargetOutputRowCount(super.update(inputIndex, status.getOutPosition()));
case LEFT_INDEX:
setRecordBatchSizer(inputIndex, new RecordBatchSizer(left));
leftRowWidth = getRecordBatchSizer(inputIndex).netRowWidth();
logger.debug("left incoming batch size : {}", getRecordBatchSizer(inputIndex));
break;
case RIGHT_INDEX:
setRecordBatchSizer(inputIndex, new RecordBatchSizer(right));
rightRowWidth = getRecordBatchSizer(inputIndex).netRowWidth();
logger.debug("right incoming batch size : {}", getRecordBatchSizer(inputIndex));
default:
break;
}

updateIncomingStats(inputIndex);
final int newOutgoingRowWidth = leftRowWidth + rightRowWidth;

// If outgoing row width is 0, just return. This is possible for empty batches or
// when first set of batches come with OK_NEW_SCHEMA and no data.
if (newOutgoingRowWidth == 0) {
return;
}

// update the value to be used for next batch(es)
setOutputRowCount(outputBatchSize, newOutgoingRowWidth);

// Adjust for the current batch.
// calculate memory used so far based on previous outgoing row width and how many rows we already processed.
final long memoryUsed = status.getOutPosition() * getOutgoingRowWidth();
// This is the remaining memory.
final long remainingMemory = Math.max(outputBatchSize - memoryUsed, 0);
// These are number of rows we can fit in remaining memory based on new outgoing row width.
final int numOutputRowsRemaining = RecordBatchSizer.safeDivide(remainingMemory, newOutgoingRowWidth);

status.setTargetOutputRowCount(adjustOutputRowCount(status.getOutPosition() + numOutputRowsRemaining));
setOutgoingRowWidth(newOutgoingRowWidth);

logger.debug("output batch size : {}, avg outgoing rowWidth : {}, output rowCount : {}",
outputBatchSize, getOutgoingRowWidth(), getOutputRowCount());
}

@Override
public RecordBatchSizer.ColumnSize getColumnSize(String name) {
RecordBatchSizer leftSizer = getRecordBatchSizer(LEFT_INDEX);
RecordBatchSizer rightSizer = getRecordBatchSizer(RIGHT_INDEX);

if (leftSizer != null && leftSizer.getColumn(name) != null) {
return leftSizer.getColumn(name);
}
return rightSizer == null ? null : rightSizer.getColumn(name);
} }
} }


private final MergeJoinMemoryManager mergeJoinMemoryManager = new MergeJoinMemoryManager();

protected MergeJoinBatch(MergeJoinPOP popConfig, FragmentContext context, RecordBatch left, RecordBatch right) throws OutOfMemoryException { protected MergeJoinBatch(MergeJoinPOP popConfig, FragmentContext context, RecordBatch left, RecordBatch right) throws OutOfMemoryException {
super(popConfig, context, true); super(popConfig, context, true, left, right);


outputBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR); // Instantiate the batch memory manager
final int outputBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
batchMemoryManager = new MergeJoinMemoryManager(outputBatchSize, left, right);


if (popConfig.getConditions().size() == 0) { if (popConfig.getConditions().size() == 0) {
throw new UnsupportedOperationException("Merge Join currently does not support cartesian join. This join operator was configured with 0 conditions"); throw new UnsupportedOperationException("Merge Join currently does not support cartesian join. This join operator was configured with 0 conditions");
} }
this.left = left; this.leftIterator = new RecordIterator(left, this, oContext, 0, false, batchMemoryManager);
this.leftIterator = new RecordIterator(left, this, oContext, 0, false, mergeJoinMemoryManager); this.rightIterator = new RecordIterator(right, this, oContext, 1, batchMemoryManager);
this.right = right;
this.rightIterator = new RecordIterator(right, this, oContext, 1, mergeJoinMemoryManager);
this.joinType = popConfig.getJoinType(); this.joinType = popConfig.getJoinType();
this.status = new JoinStatus(leftIterator, rightIterator, this); this.status = new JoinStatus(leftIterator, rightIterator, this);
this.conditions = popConfig.getConditions(); this.conditions = popConfig.getConditions();
Expand All @@ -242,21 +160,10 @@ public int getRecordCount() {
public void buildSchema() { public void buildSchema() {
// initialize iterators // initialize iterators
status.initialize(); status.initialize();

final IterOutcome leftOutcome = status.getLeftStatus(); final IterOutcome leftOutcome = status.getLeftStatus();
final IterOutcome rightOutcome = status.getRightStatus(); final IterOutcome rightOutcome = status.getRightStatus();
if (leftOutcome == IterOutcome.STOP || rightOutcome == IterOutcome.STOP) {
state = BatchState.STOP;
return;
}

if (leftOutcome == IterOutcome.OUT_OF_MEMORY || rightOutcome == IterOutcome.OUT_OF_MEMORY) {
state = BatchState.OUT_OF_MEMORY;
return;
}


if (leftOutcome == IterOutcome.NONE && rightOutcome == IterOutcome.NONE) { if (!verifyOutcomeToSetBatchState(leftOutcome, rightOutcome)) {
state = BatchState.DONE;
return; return;
} }


Expand All @@ -274,12 +181,12 @@ public IterOutcome innerNext() {
case BATCH_RETURNED: case BATCH_RETURNED:
allocateBatch(false); allocateBatch(false);
status.resetOutputPos(); status.resetOutputPos();
status.setTargetOutputRowCount(mergeJoinMemoryManager.getOutputRowCount()); status.setTargetOutputRowCount(batchMemoryManager.getOutputRowCount());
break; break;
case SCHEMA_CHANGED: case SCHEMA_CHANGED:
allocateBatch(true); allocateBatch(true);
status.resetOutputPos(); status.resetOutputPos();
status.setTargetOutputRowCount(mergeJoinMemoryManager.getOutputRowCount()); status.setTargetOutputRowCount(batchMemoryManager.getOutputRowCount());
break; break;
case NO_MORE_DATA: case NO_MORE_DATA:
status.resetOutputPos(); status.resetOutputPos();
Expand Down Expand Up @@ -359,12 +266,12 @@ private void setRecordCountInContainer() {
Preconditions.checkArgument(!vw.isHyper()); Preconditions.checkArgument(!vw.isHyper());
vw.getValueVector().getMutator().setValueCount(getRecordCount()); vw.getValueVector().getMutator().setValueCount(getRecordCount());
} }
mergeJoinMemoryManager.updateOutgoingStats(getRecordCount()); batchMemoryManager.updateOutgoingStats(getRecordCount());
} }


@Override @Override
public void close() { public void close() {
updateStats(); updateBatchMemoryManagerStats();
super.close(); super.close();
leftIterator.close(); leftIterator.close();
rightIterator.close(); rightIterator.close();
Expand Down Expand Up @@ -542,9 +449,9 @@ private void allocateBatch(boolean newSchema) {


// Allocate memory for the vectors. // Allocate memory for the vectors.
// This will iteratively allocate memory for all nested columns underneath. // This will iteratively allocate memory for all nested columns underneath.
int outputRowCount = mergeJoinMemoryManager.getOutputRowCount(); int outputRowCount = batchMemoryManager.getOutputRowCount();
for (VectorWrapper w : container) { for (VectorWrapper w : container) {
RecordBatchSizer.ColumnSize colSize = mergeJoinMemoryManager.getColumnSize(w.getField().getName()); RecordBatchSizer.ColumnSize colSize = batchMemoryManager.getColumnSize(w.getField().getName());
colSize.allocateVector(w.getValueVector(), outputRowCount); colSize.allocateVector(w.getValueVector(), outputRowCount);
} }


Expand Down Expand Up @@ -610,33 +517,4 @@ private LogicalExpression materializeExpression(LogicalExpression expression, It
} }
return materializedExpr; return materializedExpr;
} }

private void updateStats() {
stats.setLongStat(MergeJoinBatch.Metric.LEFT_INPUT_BATCH_COUNT, mergeJoinMemoryManager.getNumIncomingBatches(LEFT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.LEFT_AVG_INPUT_BATCH_BYTES, mergeJoinMemoryManager.getAvgInputBatchSize(LEFT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.LEFT_AVG_INPUT_ROW_BYTES, mergeJoinMemoryManager.getAvgInputRowWidth(LEFT_INDEX));
stats.setLongStat(Metric.LEFT_INPUT_RECORD_COUNT, mergeJoinMemoryManager.getTotalInputRecords(LEFT_INDEX));

stats.setLongStat(MergeJoinBatch.Metric.RIGHT_INPUT_BATCH_COUNT, mergeJoinMemoryManager.getNumIncomingBatches(RIGHT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.RIGHT_AVG_INPUT_BATCH_BYTES, mergeJoinMemoryManager.getAvgInputBatchSize(RIGHT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.RIGHT_AVG_INPUT_ROW_BYTES, mergeJoinMemoryManager.getAvgInputRowWidth(RIGHT_INDEX));
stats.setLongStat(Metric.RIGHT_INPUT_RECORD_COUNT, mergeJoinMemoryManager.getTotalInputRecords(RIGHT_INDEX));

stats.setLongStat(MergeJoinBatch.Metric.OUTPUT_BATCH_COUNT, mergeJoinMemoryManager.getNumOutgoingBatches());
stats.setLongStat(MergeJoinBatch.Metric.AVG_OUTPUT_BATCH_BYTES, mergeJoinMemoryManager.getAvgOutputBatchSize());
stats.setLongStat(MergeJoinBatch.Metric.AVG_OUTPUT_ROW_BYTES, mergeJoinMemoryManager.getAvgOutputRowWidth());
stats.setLongStat(MergeJoinBatch.Metric.OUTPUT_RECORD_COUNT, mergeJoinMemoryManager.getTotalOutputRecords());

logger.debug("left input: batch count : {}, avg batch bytes : {}, avg row bytes : {}, record count : {}",
mergeJoinMemoryManager.getNumIncomingBatches(LEFT_INDEX), mergeJoinMemoryManager.getAvgInputBatchSize(LEFT_INDEX),
mergeJoinMemoryManager.getAvgInputRowWidth(LEFT_INDEX), mergeJoinMemoryManager.getTotalInputRecords(LEFT_INDEX));

logger.debug("right input: batch count : {}, avg batch bytes : {}, avg row bytes : {}, record count : {}",
mergeJoinMemoryManager.getNumIncomingBatches(RIGHT_INDEX), mergeJoinMemoryManager.getAvgInputBatchSize(RIGHT_INDEX),
mergeJoinMemoryManager.getAvgInputRowWidth(RIGHT_INDEX), mergeJoinMemoryManager.getTotalInputRecords(RIGHT_INDEX));

logger.debug("output: batch count : {}, avg batch bytes : {}, avg row bytes : {}, record count : {}",
mergeJoinMemoryManager.getNumOutgoingBatches(), mergeJoinMemoryManager.getAvgOutputBatchSize(),
mergeJoinMemoryManager.getAvgOutputRowWidth(), mergeJoinMemoryManager.getTotalOutputRecords());
}
} }

0 comments on commit 7f19b50

Please sign in to comment.