Skip to content

Commit

Permalink
DRILL-6236:Batch sizing for hash join
Browse files Browse the repository at this point in the history
  • Loading branch information
ppadma committed May 30, 2018
1 parent 75f8d70 commit 45d2897
Show file tree
Hide file tree
Showing 12 changed files with 480 additions and 102 deletions.
Expand Up @@ -57,16 +57,19 @@
import org.apache.drill.exec.physical.impl.spill.SpillSet;
import org.apache.drill.exec.record.AbstractBinaryRecordBatch;
import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.JoinBatchMemoryManager;
import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.vector.FixedWidthVector;
import org.apache.drill.exec.vector.IntVector;
import org.apache.drill.exec.vector.ValueVector;
import org.apache.drill.exec.vector.VariableWidthVector;
import org.apache.drill.exec.vector.complex.AbstractContainerVector;
import org.apache.calcite.rel.core.JoinRelType;

import static org.apache.drill.exec.record.JoinBatchMemoryManager.LEFT_INDEX;
import static org.apache.drill.exec.record.JoinBatchMemoryManager.RIGHT_INDEX;

/**
* This class implements the runtime execution for the Hash-Join operator
* supporting INNER, LEFT OUTER, RIGHT OUTER, and FULL OUTER joins
Expand Down Expand Up @@ -95,11 +98,6 @@ public class HashJoinBatch extends AbstractBinaryRecordBatch<HashJoinPOP> {
*/
private int RECORDS_PER_BATCH; // internal batches

/**
* The maximum number of records in each outgoing batch.
*/
private static final int TARGET_RECORDS_PER_BATCH = 4000;

// Join type, INNER, LEFT, RIGHT or OUTER
private final JoinRelType joinType;

Expand Down Expand Up @@ -172,7 +170,8 @@ public static class HJSpilledPartition {
public String outerSpillFile;
int cycleNum;
int origPartn;
int prevOrigPartn; }
int prevOrigPartn;
}

/**
* Queue of spilled partitions to process.
Expand All @@ -181,7 +180,6 @@ public static class HJSpilledPartition {
private HJSpilledPartition spilledInners[]; // for the outer to find the partition

public enum Metric implements MetricDef {

NUM_BUCKETS,
NUM_ENTRIES,
NUM_RESIZING,
Expand All @@ -190,8 +188,19 @@ public enum Metric implements MetricDef {
SPILLED_PARTITIONS, // number of original partitions spilled to disk
SPILL_MB, // Number of MB of data spilled to disk. This amount is first written,
// then later re-read. So, disk I/O is twice this amount.
SPILL_CYCLE // 0 - no spill, 1 - spill, 2 - SECONDARY, 3 - TERTIARY
;
SPILL_CYCLE, // 0 - no spill, 1 - spill, 2 - SECONDARY, 3 - TERTIARY
LEFT_INPUT_BATCH_COUNT,
LEFT_AVG_INPUT_BATCH_BYTES,
LEFT_AVG_INPUT_ROW_BYTES,
LEFT_INPUT_RECORD_COUNT,
RIGHT_INPUT_BATCH_COUNT,
RIGHT_AVG_INPUT_BATCH_BYTES,
RIGHT_AVG_INPUT_ROW_BYTES,
RIGHT_INPUT_RECORD_COUNT,
OUTPUT_BATCH_COUNT,
AVG_OUTPUT_BATCH_BYTES,
AVG_OUTPUT_ROW_BYTES,
OUTPUT_RECORD_COUNT;

// duplicate for hash ag

Expand Down Expand Up @@ -221,19 +230,23 @@ protected void buildSchema() throws SchemaChangeException {
throw new SchemaChangeException(e);
}

// Build the container schema and set the counts
for (final VectorWrapper<?> w : container) {
w.getValueVector().allocateNew();
}
container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
container.setRecordCount(outputRecords);
}

@Override
protected boolean prefetchFirstBatchFromBothSides() {
leftUpstream = sniffNonEmptyBatch(0, left);
rightUpstream = sniffNonEmptyBatch(1, right);

// For build side, use aggregate i.e. average row width across batches
batchMemoryManager.update(LEFT_INDEX, 0);
batchMemoryManager.update(RIGHT_INDEX, 0, true);

if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, incoming left:\n {}", batchMemoryManager.getRecordBatchSizer(LEFT_INDEX));
logger.debug("BATCH_STATS, incoming right:\n {}", batchMemoryManager.getRecordBatchSizer(RIGHT_INDEX));
}

if (leftUpstream == IterOutcome.STOP || rightUpstream == IterOutcome.STOP) {
state = BatchState.STOP;
return false;
Expand Down Expand Up @@ -333,10 +346,21 @@ public IterOutcome innerNext() {
joinType != JoinRelType.INNER) { // or if this is a left/full outer join

// Allocate the memory for the vectors in the output container
allocateVectors();
batchMemoryManager.allocateVectors(container);
hashJoinProbe.setTargetOutputCount(batchMemoryManager.getOutputRowCount());

outputRecords = hashJoinProbe.probeAndProject();

for (final VectorWrapper<?> v : container) {
v.getValueVector().getMutator().setValueCount(outputRecords);
}
container.setRecordCount(outputRecords);

batchMemoryManager.updateOutgoingStats(outputRecords);
if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, outgoing:\n {}", new RecordBatchSizer(this));
}

/* We are here because of one the following
* 1. Completed processing of all the records and we are done
* 2. We've filled up the outgoing batch to the maximum and we need to return upstream
Expand All @@ -347,10 +371,6 @@ public IterOutcome innerNext() {
state = BatchState.NOT_FIRST;
}

for (final VectorWrapper<?> v : container) {
v.getValueVector().getMutator().setValueCount(outputRecords);
}

return IterOutcome.OK;
}

Expand Down Expand Up @@ -557,7 +577,8 @@ private HashJoinMemoryCalculator.BuildSidePartitioning partitionNumTuning(
RECORDS_PER_BATCH,
maxBatchSize,
maxBatchSize,
TARGET_RECORDS_PER_BATCH,
batchMemoryManager.getOutputRowCount(),
batchMemoryManager.getOutputBatchSize(),
HashTable.DEFAULT_LOAD_FACTOR);

disableSpilling(null);
Expand Down Expand Up @@ -628,7 +649,8 @@ public void executeBuildPhase() throws SchemaChangeException {
RECORDS_PER_BATCH,
maxBatchSize,
maxBatchSize,
TARGET_RECORDS_PER_BATCH,
batchMemoryManager.getOutputRowCount(),
batchMemoryManager.getOutputBatchSize(),
HashTable.DEFAULT_LOAD_FACTOR);

if (firstCycle && doMemoryCalculation) {
Expand Down Expand Up @@ -665,6 +687,7 @@ public void executeBuildPhase() throws SchemaChangeException {
for (HashPartition partn : partitions) { partn.updateBatches(); }
// Fall through
case OK:
batchMemoryManager.update(buildBatch, RIGHT_INDEX, 0, true);
// Special treatment (when no spill, and single partition) -- use the incoming vectors as they are (no row copy)
if ( numPartitions == 1 ) {
partitions[0].appendBatch(buildBatch);
Expand Down Expand Up @@ -803,22 +826,6 @@ private void setupOutputContainerSchema() {

}

private void allocateVectors() {
for (final VectorWrapper<?> vectorWrapper : container) {
ValueVector valueVector = vectorWrapper.getValueVector();

if (valueVector instanceof FixedWidthVector) {
((FixedWidthVector) valueVector).allocateNew(TARGET_RECORDS_PER_BATCH);
} else if (valueVector instanceof VariableWidthVector) {
((VariableWidthVector) valueVector).allocateNew(8 * TARGET_RECORDS_PER_BATCH, TARGET_RECORDS_PER_BATCH);
} else {
valueVector.allocateNew();
}
}

container.setRecordCount(0); // reset container's counter back to zero records
}

// (After the inner side was read whole) - Has that inner partition spilled
public boolean isSpilledInner(int part) {
if ( spilledInners == null ) { return false; } // empty inner
Expand Down Expand Up @@ -879,6 +886,10 @@ public HashJoinBatch(HashJoinPOP popConfig, FragmentContext context,

// Create empty partitions (in the ctor - covers the case where right side is empty)
partitions = new HashPartition[0];

// get the output batch size from config.
int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
batchMemoryManager = new JoinBatchMemoryManager(configuredBatchSize, left, right);
}

/**
Expand Down Expand Up @@ -966,13 +977,49 @@ public void killIncoming(boolean sendUpstream) {
buildBatch.kill(sendUpstream);
}

public void updateMetrics() {
stats.setLongStat(HashJoinBatch.Metric.LEFT_INPUT_BATCH_COUNT, batchMemoryManager.getNumIncomingBatches(LEFT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.LEFT_AVG_INPUT_BATCH_BYTES, batchMemoryManager.getAvgInputBatchSize(LEFT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.LEFT_AVG_INPUT_ROW_BYTES, batchMemoryManager.getAvgInputRowWidth(LEFT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.LEFT_INPUT_RECORD_COUNT, batchMemoryManager.getTotalInputRecords(LEFT_INDEX));

stats.setLongStat(HashJoinBatch.Metric.RIGHT_INPUT_BATCH_COUNT, batchMemoryManager.getNumIncomingBatches(RIGHT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.RIGHT_AVG_INPUT_BATCH_BYTES, batchMemoryManager.getAvgInputBatchSize(RIGHT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.RIGHT_AVG_INPUT_ROW_BYTES, batchMemoryManager.getAvgInputRowWidth(RIGHT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.RIGHT_INPUT_RECORD_COUNT, batchMemoryManager.getTotalInputRecords(RIGHT_INDEX));

stats.setLongStat(HashJoinBatch.Metric.OUTPUT_BATCH_COUNT, batchMemoryManager.getNumOutgoingBatches());
stats.setLongStat(HashJoinBatch.Metric.AVG_OUTPUT_BATCH_BYTES, batchMemoryManager.getAvgOutputBatchSize());
stats.setLongStat(HashJoinBatch.Metric.AVG_OUTPUT_ROW_BYTES, batchMemoryManager.getAvgOutputRowWidth());
stats.setLongStat(HashJoinBatch.Metric.OUTPUT_RECORD_COUNT, batchMemoryManager.getTotalOutputRecords());
}

@Override
public void close() {
if ( cycleNum > 0 ) { // spilling happened
// In case closing due to cancellation, BaseRootExec.close() does not close the open
// SpilledRecordBatch "scanners" as it only knows about the original left/right ops.
killIncoming(false);
}

updateMetrics();

logger.debug("BATCH_STATS, incoming aggregate left: batch count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
batchMemoryManager.getNumIncomingBatches(JoinBatchMemoryManager.LEFT_INDEX),
batchMemoryManager.getAvgInputBatchSize(JoinBatchMemoryManager.LEFT_INDEX),
batchMemoryManager.getAvgInputRowWidth(JoinBatchMemoryManager.LEFT_INDEX),
batchMemoryManager.getTotalInputRecords(JoinBatchMemoryManager.LEFT_INDEX));

logger.debug("BATCH_STATS, incoming aggregate right: batch count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
batchMemoryManager.getNumIncomingBatches(JoinBatchMemoryManager.RIGHT_INDEX),
batchMemoryManager.getAvgInputBatchSize(JoinBatchMemoryManager.RIGHT_INDEX),
batchMemoryManager.getAvgInputRowWidth(JoinBatchMemoryManager.RIGHT_INDEX),
batchMemoryManager.getTotalInputRecords(JoinBatchMemoryManager.RIGHT_INDEX));

logger.debug("BATCH_STATS, outgoing aggregate: batch count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
batchMemoryManager.getNumOutgoingBatches(), batchMemoryManager.getAvgOutputBatchSize(),
batchMemoryManager.getAvgOutputRowWidth(), batchMemoryManager.getTotalOutputRecords());

this.cleanup();
super.close();
}
Expand Down
Expand Up @@ -77,6 +77,7 @@ public void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
this.initialPartitions = initialPartitions;
}
Expand Down
Expand Up @@ -100,6 +100,7 @@ void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor);

void setPartitionStatSet(PartitionStatSet partitionStatSet);
Expand Down
Expand Up @@ -142,6 +142,7 @@ public void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
this.initialPartitions = initialPartitions;
}
Expand Down Expand Up @@ -203,7 +204,7 @@ public HashJoinState getState() {
* <h1>Life Cycle</h1>
* <p>
* <ul>
* <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, RecordBatch, RecordBatch, Set, long, int, int, int, int, int, int, double)}.
* <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, RecordBatch, RecordBatch, Set, long, int, int, int, int, int, int, int, double)}.
* This will initialize the StateCalculate with the additional information it needs.</li>
* <li><b>Step 1:</b> Call {@link #getNumPartitions()} to see the number of partitions that fit in memory.</li>
* <li><b>Step 2:</b> Call {@link #shouldSpill()} To determine if spilling needs to occurr.</li>
Expand Down Expand Up @@ -233,9 +234,7 @@ public static class BuildSidePartitioningImpl implements BuildSidePartitioning {
private int partitions;
private int recordsPerPartitionBatchBuild;
private int recordsPerPartitionBatchProbe;
private int outputBatchNumRecords;
private Map<String, Long> buildValueSizes;
private Map<String, Long> probeValueSizes;
private int outputBatchSize;
private Map<String, Long> keySizes;
private boolean autoTune;
private boolean reserveHash;
Expand Down Expand Up @@ -273,6 +272,7 @@ public void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
Preconditions.checkNotNull(buildSideBatch);
Preconditions.checkNotNull(probeSideBatch);
Expand Down Expand Up @@ -300,8 +300,6 @@ public void initialize(boolean autoTune,

initialize(autoTune,
reserveHash,
buildValueSizes,
probeValueSizes,
keySizes,
memoryAvailable,
initialPartitions,
Expand All @@ -313,7 +311,7 @@ public void initialize(boolean autoTune,
recordsPerPartitionBatchProbe,
maxBatchNumRecordsBuild,
maxBatchNumRecordsProbe,
outputBatchNumRecords,
outputBatchSize,
loadFactor);
}

Expand Down Expand Up @@ -352,8 +350,6 @@ public static long getBatchSizeEstimate(final RecordBatch recordBatch) {
@VisibleForTesting
protected void initialize(boolean autoTune,
boolean reserveHash,
CaseInsensitiveMap<Long> buildValueSizes,
CaseInsensitiveMap<Long> probeValueSizes,
CaseInsensitiveMap<Long> keySizes,
long memoryAvailable,
int initialPartitions,
Expand All @@ -365,7 +361,7 @@ protected void initialize(boolean autoTune,
int recordsPerPartitionBatchProbe,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
Preconditions.checkState(!firstInitialized);
Preconditions.checkArgument(initialPartitions >= 1);
Expand All @@ -374,8 +370,6 @@ protected void initialize(boolean autoTune,
this.loadFactor = loadFactor;
this.autoTune = autoTune;
this.reserveHash = reserveHash;
this.buildValueSizes = Preconditions.checkNotNull(buildValueSizes);
this.probeValueSizes = Preconditions.checkNotNull(probeValueSizes);
this.keySizes = Preconditions.checkNotNull(keySizes);
this.memoryAvailable = memoryAvailable;
this.buildBatchSize = buildBatchSize;
Expand All @@ -387,7 +381,7 @@ protected void initialize(boolean autoTune,
this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
this.maxBatchNumRecordsBuild = maxBatchNumRecordsBuild;
this.maxBatchNumRecordsProbe = maxBatchNumRecordsProbe;
this.outputBatchNumRecords = outputBatchNumRecords;
this.outputBatchSize = outputBatchSize;

calculateMemoryUsage();

Expand Down Expand Up @@ -448,8 +442,7 @@ private void calculateMemoryUsage()
safetyFactor,
reserveHash);

maxOutputBatchSize = computeMaxOutputBatchSize(buildValueSizes, probeValueSizes, keySizes,
outputBatchNumRecords, safetyFactor, fragmentationFactor);
maxOutputBatchSize = (long) (outputBatchSize * fragmentationFactor * safetyFactor);

long probeReservedMemory;

Expand Down Expand Up @@ -519,18 +512,6 @@ private void calculateMemoryUsage()
}
}

public static long computeMaxOutputBatchSize(Map<String, Long> buildValueSizes,
Map<String, Long> probeValueSizes,
Map<String, Long> keySizes,
int outputNumRecords,
double safetyFactor,
double fragmentationFactor) {
long outputSize = HashTableSizeCalculatorConservativeImpl.computeVectorSizes(keySizes, outputNumRecords, safetyFactor)
+ HashTableSizeCalculatorConservativeImpl.computeVectorSizes(buildValueSizes, outputNumRecords, safetyFactor)
+ HashTableSizeCalculatorConservativeImpl.computeVectorSizes(probeValueSizes, outputNumRecords, safetyFactor);
return RecordBatchSizer.multiplyByFactor(outputSize, fragmentationFactor);
}

@Override
public boolean shouldSpill() {
Preconditions.checkState(initialized);
Expand Down
Expand Up @@ -42,4 +42,6 @@ enum ProbeState {
void setupHashJoinProbe(RecordBatch probeBatch, HashJoinBatch outgoing, JoinRelType joinRelType, RecordBatch.IterOutcome leftStartState, HashPartition[] partitions, int cycleNum, VectorContainer container, HashJoinBatch.HJSpilledPartition[] spilledInners, boolean buildSideIsEmpty, int numPartitions, int rightHVColPosition);
int probeAndProject() throws SchemaChangeException;
void changeToFinalProbeState();
void setTargetOutputCount(int targetOutputCount);
int getOutputCount();
}

0 comments on commit 45d2897

Please sign in to comment.