Skip to content

Commit

Permalink
DRILL-6323: Lateral Join - Refactor BatchMemorySize to put outputBatc…
Browse files Browse the repository at this point in the history
…hSize in abstract class. Created a new JoinBatchMemoryManager to be shared across join record batches. Changed merge join to use AbstractBinaryRecordBatch instead of AbstractRecordBatch, and use JoinBatchMemoryManager
  • Loading branch information
Sorabh Hamirwasia authored and parthchandra committed Apr 18, 2018
1 parent 4b6f10e commit 7f19b50
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 184 deletions.
Expand Up @@ -23,13 +23,13 @@
import org.apache.drill.exec.physical.impl.broadcastsender.BroadcastSenderRootExec;
import org.apache.drill.exec.physical.impl.flatten.FlattenRecordBatch;
import org.apache.drill.exec.physical.impl.join.HashJoinBatch;
import org.apache.drill.exec.physical.impl.join.MergeJoinBatch;
import org.apache.drill.exec.physical.impl.mergereceiver.MergingRecordBatch;
import org.apache.drill.exec.physical.impl.partitionsender.PartitionSenderRootExec;
import org.apache.drill.exec.physical.impl.unorderedreceiver.UnorderedReceiverBatch;
import org.apache.drill.exec.physical.impl.xsort.ExternalSortBatch;
import org.apache.drill.exec.store.parquet.columnreaders.ParquetRecordReader;
import org.apache.drill.exec.proto.UserBitShared.CoreOperatorType;
import org.apache.drill.exec.record.JoinBatchMemoryManager;
import org.apache.drill.exec.store.parquet.columnreaders.ParquetRecordReader;

/**
* Registry of operator metrics.
Expand All @@ -52,7 +52,7 @@ public class OperatorMetricRegistry {
register(CoreOperatorType.EXTERNAL_SORT_VALUE, ExternalSortBatch.Metric.class);
register(CoreOperatorType.PARQUET_ROW_GROUP_SCAN_VALUE, ParquetRecordReader.Metric.class);
register(CoreOperatorType.FLATTEN_VALUE, FlattenRecordBatch.Metric.class);
register(CoreOperatorType.MERGE_JOIN_VALUE, MergeJoinBatch.Metric.class);
register(CoreOperatorType.MERGE_JOIN_VALUE, JoinBatchMemoryManager.Metric.class);
}

private static void register(final int operatorType, final Class<? extends MetricDef> metricDef) {
Expand Down
Expand Up @@ -72,8 +72,7 @@ public class FlattenRecordBatch extends AbstractSingleRecordBatch<FlattenPOP> {
private boolean hasRemainder = false;
private int remainderIndex = 0;
private int recordCount;
private int outputBatchSize;
private final FlattenMemoryManager flattenMemoryManager = new FlattenMemoryManager();
private final FlattenMemoryManager flattenMemoryManager;

private final Flattener.Monitor monitor = new Flattener.Monitor() {
@Override
Expand Down Expand Up @@ -118,6 +117,10 @@ public int metricId() {

private class FlattenMemoryManager extends RecordBatchMemoryManager {

FlattenMemoryManager(int outputBatchSize) {
super(outputBatchSize);
}

@Override
public void update() {
// Get sizing information for the batch.
Expand All @@ -138,11 +141,13 @@ public void update() {
// Average rowWidth of single element in the flatten list.
// subtract the offset vector size from column data size.
final int avgRowWidthSingleFlattenEntry =
RecordBatchSizer.safeDivide(columnSize.getTotalNetSize() - (OFFSET_VECTOR_WIDTH * columnSize.getValueCount()), columnSize.getElementCount());
RecordBatchSizer.safeDivide(columnSize.getTotalNetSize() - (getOffsetVectorWidth() * columnSize.getValueCount()),
columnSize.getElementCount());

// Average rowWidth of outgoing batch.
final int avgOutgoingRowWidth = avgRowWidthWithOutFlattenColumn + avgRowWidthSingleFlattenEntry;

final int outputBatchSize = getOutputBatchSize();
// Number of rows in outgoing batch
setOutputRowCount(outputBatchSize, avgOutgoingRowWidth);

Expand All @@ -165,7 +170,8 @@ public FlattenRecordBatch(FlattenPOP pop, RecordBatch incoming, FragmentContext
super(pop, context, incoming);

// get the output batch size from config.
outputBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
flattenMemoryManager = new FlattenMemoryManager(configuredBatchSize);
}

@Override
Expand Down
Expand Up @@ -17,12 +17,13 @@
*/
package org.apache.drill.exec.physical.impl.join;

import static org.apache.drill.exec.compile.sig.GeneratorMapping.GM;

import java.io.IOException;
import java.util.List;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JConditional;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JVar;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.drill.common.expression.ErrorCollector;
import org.apache.drill.common.expression.ErrorCollectorImpl;
Expand All @@ -44,34 +45,31 @@
import org.apache.drill.exec.expr.ExpressionTreeMaterializer;
import org.apache.drill.exec.expr.fn.FunctionGenerationHelper;
import org.apache.drill.exec.ops.FragmentContext;
import org.apache.drill.exec.ops.MetricDef;
import org.apache.drill.exec.physical.config.MergeJoinPOP;
import org.apache.drill.exec.physical.impl.common.Comparator;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.AbstractBinaryRecordBatch;
import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.JoinBatchMemoryManager;
import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.RecordIterator;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorAccessible;
import org.apache.drill.exec.record.VectorContainer;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.AbstractRecordBatch;
import org.apache.drill.exec.record.RecordBatchMemoryManager;
import org.apache.drill.exec.vector.ValueVector;
import org.apache.drill.exec.vector.complex.AbstractContainerVector;

import com.google.common.base.Preconditions;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JConditional;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JMod;
import com.sun.codemodel.JVar;
import java.io.IOException;
import java.util.List;

import static org.apache.drill.exec.compile.sig.GeneratorMapping.GM;

/**
* A join operator merges two sorted streams using record iterator.
*/
public class MergeJoinBatch extends AbstractRecordBatch<MergeJoinPOP> {
public class MergeJoinBatch extends AbstractBinaryRecordBatch<MergeJoinPOP> {

private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MergeJoinBatch.class);

Expand All @@ -96,50 +94,21 @@ public class MergeJoinBatch extends AbstractRecordBatch<MergeJoinPOP> {
GM("doSetup", "doSetup", null, null),
GM("doSetup", "doCompare", null, null));

private final RecordBatch left;
private final RecordBatch right;
private final RecordIterator leftIterator;
private final RecordIterator rightIterator;
private final JoinStatus status;
private final List<JoinCondition> conditions;
private final List<Comparator> comparators;
private final JoinRelType joinType;
private JoinWorker worker;
private final int outputBatchSize;

private static final String LEFT_INPUT = "LEFT INPUT";
private static final String RIGHT_INPUT = "RIGHT INPUT";

private static final int numInputs = 2;
private static final int LEFT_INDEX = 0;
private static final int RIGHT_INDEX = 1;

public enum Metric implements MetricDef {
LEFT_INPUT_BATCH_COUNT,
LEFT_AVG_INPUT_BATCH_BYTES,
LEFT_AVG_INPUT_ROW_BYTES,
LEFT_INPUT_RECORD_COUNT,
RIGHT_INPUT_BATCH_COUNT,
RIGHT_AVG_INPUT_BATCH_BYTES,
RIGHT_AVG_INPUT_ROW_BYTES,
RIGHT_INPUT_RECORD_COUNT,
OUTPUT_BATCH_COUNT,
AVG_OUTPUT_BATCH_BYTES,
AVG_OUTPUT_ROW_BYTES,
OUTPUT_RECORD_COUNT;
private class MergeJoinMemoryManager extends JoinBatchMemoryManager {

@Override
public int metricId() {
return ordinal();
}
}

private class MergeJoinMemoryManager extends RecordBatchMemoryManager {
private int leftRowWidth;
private int rightRowWidth;

public MergeJoinMemoryManager() {
super(numInputs);
MergeJoinMemoryManager(int outputBatchSize, RecordBatch leftBatch, RecordBatch rightBatch) {
super(outputBatchSize, leftBatch, rightBatch);
}

/**
Expand All @@ -152,73 +121,22 @@ public MergeJoinMemoryManager() {
*/
@Override
public void update(int inputIndex) {
switch(inputIndex) {
case LEFT_INDEX:
setRecordBatchSizer(inputIndex, new RecordBatchSizer(left));
leftRowWidth = getRecordBatchSizer(inputIndex).netRowWidth();
logger.debug("left incoming batch size : {}", getRecordBatchSizer(inputIndex));
break;
case RIGHT_INDEX:
setRecordBatchSizer(inputIndex, new RecordBatchSizer(right));
rightRowWidth = getRecordBatchSizer(inputIndex).netRowWidth();
logger.debug("right incoming batch size : {}", getRecordBatchSizer(inputIndex));
default:
break;
}

updateIncomingStats(inputIndex);
final int newOutgoingRowWidth = leftRowWidth + rightRowWidth;

// If outgoing row width is 0, just return. This is possible for empty batches or
// when first set of batches come with OK_NEW_SCHEMA and no data.
if (newOutgoingRowWidth == 0) {
return;
}

// update the value to be used for next batch(es)
setOutputRowCount(outputBatchSize, newOutgoingRowWidth);

// Adjust for the current batch.
// calculate memory used so far based on previous outgoing row width and how many rows we already processed.
final long memoryUsed = status.getOutPosition() * getOutgoingRowWidth();
// This is the remaining memory.
final long remainingMemory = Math.max(outputBatchSize - memoryUsed, 0);
// These are number of rows we can fit in remaining memory based on new outgoing row width.
final int numOutputRowsRemaining = RecordBatchSizer.safeDivide(remainingMemory, newOutgoingRowWidth);

status.setTargetOutputRowCount(adjustOutputRowCount(status.getOutPosition() + numOutputRowsRemaining));
setOutgoingRowWidth(newOutgoingRowWidth);

logger.debug("output batch size : {}, avg outgoing rowWidth : {}, output rowCount : {}",
outputBatchSize, getOutgoingRowWidth(), getOutputRowCount());
}

@Override
public RecordBatchSizer.ColumnSize getColumnSize(String name) {
RecordBatchSizer leftSizer = getRecordBatchSizer(LEFT_INDEX);
RecordBatchSizer rightSizer = getRecordBatchSizer(RIGHT_INDEX);

if (leftSizer != null && leftSizer.getColumn(name) != null) {
return leftSizer.getColumn(name);
}
return rightSizer == null ? null : rightSizer.getColumn(name);
status.setTargetOutputRowCount(super.update(inputIndex, status.getOutPosition()));
}
}

private final MergeJoinMemoryManager mergeJoinMemoryManager = new MergeJoinMemoryManager();

protected MergeJoinBatch(MergeJoinPOP popConfig, FragmentContext context, RecordBatch left, RecordBatch right) throws OutOfMemoryException {
super(popConfig, context, true);
super(popConfig, context, true, left, right);

outputBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
// Instantiate the batch memory manager
final int outputBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
batchMemoryManager = new MergeJoinMemoryManager(outputBatchSize, left, right);

if (popConfig.getConditions().size() == 0) {
throw new UnsupportedOperationException("Merge Join currently does not support cartesian join. This join operator was configured with 0 conditions");
}
this.left = left;
this.leftIterator = new RecordIterator(left, this, oContext, 0, false, mergeJoinMemoryManager);
this.right = right;
this.rightIterator = new RecordIterator(right, this, oContext, 1, mergeJoinMemoryManager);
this.leftIterator = new RecordIterator(left, this, oContext, 0, false, batchMemoryManager);
this.rightIterator = new RecordIterator(right, this, oContext, 1, batchMemoryManager);
this.joinType = popConfig.getJoinType();
this.status = new JoinStatus(leftIterator, rightIterator, this);
this.conditions = popConfig.getConditions();
Expand All @@ -242,21 +160,10 @@ public int getRecordCount() {
public void buildSchema() {
// initialize iterators
status.initialize();

final IterOutcome leftOutcome = status.getLeftStatus();
final IterOutcome rightOutcome = status.getRightStatus();
if (leftOutcome == IterOutcome.STOP || rightOutcome == IterOutcome.STOP) {
state = BatchState.STOP;
return;
}

if (leftOutcome == IterOutcome.OUT_OF_MEMORY || rightOutcome == IterOutcome.OUT_OF_MEMORY) {
state = BatchState.OUT_OF_MEMORY;
return;
}

if (leftOutcome == IterOutcome.NONE && rightOutcome == IterOutcome.NONE) {
state = BatchState.DONE;
if (!verifyOutcomeToSetBatchState(leftOutcome, rightOutcome)) {
return;
}

Expand All @@ -274,12 +181,12 @@ public IterOutcome innerNext() {
case BATCH_RETURNED:
allocateBatch(false);
status.resetOutputPos();
status.setTargetOutputRowCount(mergeJoinMemoryManager.getOutputRowCount());
status.setTargetOutputRowCount(batchMemoryManager.getOutputRowCount());
break;
case SCHEMA_CHANGED:
allocateBatch(true);
status.resetOutputPos();
status.setTargetOutputRowCount(mergeJoinMemoryManager.getOutputRowCount());
status.setTargetOutputRowCount(batchMemoryManager.getOutputRowCount());
break;
case NO_MORE_DATA:
status.resetOutputPos();
Expand Down Expand Up @@ -359,12 +266,12 @@ private void setRecordCountInContainer() {
Preconditions.checkArgument(!vw.isHyper());
vw.getValueVector().getMutator().setValueCount(getRecordCount());
}
mergeJoinMemoryManager.updateOutgoingStats(getRecordCount());
batchMemoryManager.updateOutgoingStats(getRecordCount());
}

@Override
public void close() {
updateStats();
updateBatchMemoryManagerStats();
super.close();
leftIterator.close();
rightIterator.close();
Expand Down Expand Up @@ -542,9 +449,9 @@ private void allocateBatch(boolean newSchema) {

// Allocate memory for the vectors.
// This will iteratively allocate memory for all nested columns underneath.
int outputRowCount = mergeJoinMemoryManager.getOutputRowCount();
int outputRowCount = batchMemoryManager.getOutputRowCount();
for (VectorWrapper w : container) {
RecordBatchSizer.ColumnSize colSize = mergeJoinMemoryManager.getColumnSize(w.getField().getName());
RecordBatchSizer.ColumnSize colSize = batchMemoryManager.getColumnSize(w.getField().getName());
colSize.allocateVector(w.getValueVector(), outputRowCount);
}

Expand Down Expand Up @@ -610,33 +517,4 @@ private LogicalExpression materializeExpression(LogicalExpression expression, It
}
return materializedExpr;
}

private void updateStats() {
stats.setLongStat(MergeJoinBatch.Metric.LEFT_INPUT_BATCH_COUNT, mergeJoinMemoryManager.getNumIncomingBatches(LEFT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.LEFT_AVG_INPUT_BATCH_BYTES, mergeJoinMemoryManager.getAvgInputBatchSize(LEFT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.LEFT_AVG_INPUT_ROW_BYTES, mergeJoinMemoryManager.getAvgInputRowWidth(LEFT_INDEX));
stats.setLongStat(Metric.LEFT_INPUT_RECORD_COUNT, mergeJoinMemoryManager.getTotalInputRecords(LEFT_INDEX));

stats.setLongStat(MergeJoinBatch.Metric.RIGHT_INPUT_BATCH_COUNT, mergeJoinMemoryManager.getNumIncomingBatches(RIGHT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.RIGHT_AVG_INPUT_BATCH_BYTES, mergeJoinMemoryManager.getAvgInputBatchSize(RIGHT_INDEX));
stats.setLongStat(MergeJoinBatch.Metric.RIGHT_AVG_INPUT_ROW_BYTES, mergeJoinMemoryManager.getAvgInputRowWidth(RIGHT_INDEX));
stats.setLongStat(Metric.RIGHT_INPUT_RECORD_COUNT, mergeJoinMemoryManager.getTotalInputRecords(RIGHT_INDEX));

stats.setLongStat(MergeJoinBatch.Metric.OUTPUT_BATCH_COUNT, mergeJoinMemoryManager.getNumOutgoingBatches());
stats.setLongStat(MergeJoinBatch.Metric.AVG_OUTPUT_BATCH_BYTES, mergeJoinMemoryManager.getAvgOutputBatchSize());
stats.setLongStat(MergeJoinBatch.Metric.AVG_OUTPUT_ROW_BYTES, mergeJoinMemoryManager.getAvgOutputRowWidth());
stats.setLongStat(MergeJoinBatch.Metric.OUTPUT_RECORD_COUNT, mergeJoinMemoryManager.getTotalOutputRecords());

logger.debug("left input: batch count : {}, avg batch bytes : {}, avg row bytes : {}, record count : {}",
mergeJoinMemoryManager.getNumIncomingBatches(LEFT_INDEX), mergeJoinMemoryManager.getAvgInputBatchSize(LEFT_INDEX),
mergeJoinMemoryManager.getAvgInputRowWidth(LEFT_INDEX), mergeJoinMemoryManager.getTotalInputRecords(LEFT_INDEX));

logger.debug("right input: batch count : {}, avg batch bytes : {}, avg row bytes : {}, record count : {}",
mergeJoinMemoryManager.getNumIncomingBatches(RIGHT_INDEX), mergeJoinMemoryManager.getAvgInputBatchSize(RIGHT_INDEX),
mergeJoinMemoryManager.getAvgInputRowWidth(RIGHT_INDEX), mergeJoinMemoryManager.getTotalInputRecords(RIGHT_INDEX));

logger.debug("output: batch count : {}, avg batch bytes : {}, avg row bytes : {}, record count : {}",
mergeJoinMemoryManager.getNumOutgoingBatches(), mergeJoinMemoryManager.getAvgOutputBatchSize(),
mergeJoinMemoryManager.getAvgOutputRowWidth(), mergeJoinMemoryManager.getTotalOutputRecords());
}
}

0 comments on commit 7f19b50

Please sign in to comment.