Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRILL-6236: batch sizing for hash join #1227

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -57,16 +57,19 @@
import org.apache.drill.exec.physical.impl.spill.SpillSet;
import org.apache.drill.exec.record.AbstractBinaryRecordBatch;
import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.JoinBatchMemoryManager;
import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.vector.FixedWidthVector;
import org.apache.drill.exec.vector.IntVector;
import org.apache.drill.exec.vector.ValueVector;
import org.apache.drill.exec.vector.VariableWidthVector;
import org.apache.drill.exec.vector.complex.AbstractContainerVector;
import org.apache.calcite.rel.core.JoinRelType;

import static org.apache.drill.exec.record.JoinBatchMemoryManager.LEFT_INDEX;
import static org.apache.drill.exec.record.JoinBatchMemoryManager.RIGHT_INDEX;

/**
* This class implements the runtime execution for the Hash-Join operator
* supporting INNER, LEFT OUTER, RIGHT OUTER, and FULL OUTER joins
Expand Down Expand Up @@ -95,11 +98,6 @@ public class HashJoinBatch extends AbstractBinaryRecordBatch<HashJoinPOP> {
*/
private int RECORDS_PER_BATCH; // internal batches

/**
* The maximum number of records in each outgoing batch.
*/
private static final int TARGET_RECORDS_PER_BATCH = 4000;

// Join type, INNER, LEFT, RIGHT or OUTER
private final JoinRelType joinType;

Expand Down Expand Up @@ -172,7 +170,8 @@ public static class HJSpilledPartition {
public String outerSpillFile;
int cycleNum;
int origPartn;
int prevOrigPartn; }
int prevOrigPartn;
}

/**
* Queue of spilled partitions to process.
Expand All @@ -181,7 +180,6 @@ public static class HJSpilledPartition {
private HJSpilledPartition spilledInners[]; // for the outer to find the partition

public enum Metric implements MetricDef {

NUM_BUCKETS,
NUM_ENTRIES,
NUM_RESIZING,
Expand All @@ -190,8 +188,19 @@ public enum Metric implements MetricDef {
SPILLED_PARTITIONS, // number of original partitions spilled to disk
SPILL_MB, // Number of MB of data spilled to disk. This amount is first written,
// then later re-read. So, disk I/O is twice this amount.
SPILL_CYCLE // 0 - no spill, 1 - spill, 2 - SECONDARY, 3 - TERTIARY
;
SPILL_CYCLE, // 0 - no spill, 1 - spill, 2 - SECONDARY, 3 - TERTIARY
LEFT_INPUT_BATCH_COUNT,
LEFT_AVG_INPUT_BATCH_BYTES,
LEFT_AVG_INPUT_ROW_BYTES,
LEFT_INPUT_RECORD_COUNT,
RIGHT_INPUT_BATCH_COUNT,
RIGHT_AVG_INPUT_BATCH_BYTES,
RIGHT_AVG_INPUT_ROW_BYTES,
RIGHT_INPUT_RECORD_COUNT,
OUTPUT_BATCH_COUNT,
AVG_OUTPUT_BATCH_BYTES,
AVG_OUTPUT_ROW_BYTES,
OUTPUT_RECORD_COUNT;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The metrics are to be used also by customers; is this information relevant for them ? Is this too detailed (e.g., can be logged instead).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is relevant in the sense that they provide high level picture of amount of data being processed, memory usage etc. by each operator. This is also helpful when debugging trying to figure out what is going on.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting these metrics inside operator Metric class will not work. For joins these metrics were moved inside JoinBatchMemoryManager.Metric class since they are memory manager metrics. So when you call updateBatchMemoryManagerStats() it updates the operator stats but using ordinals from JoinBatchMemoryManager.Metric class. So the ordinal for LEFT_INPUT_BATCH_COUNT will be 0 not 4 (which is required).
I think we should improve our OperatorsMetricRegistry to register multiple Metric classes for an operator.


// duplicate for hash ag

Expand Down Expand Up @@ -221,19 +230,23 @@ protected void buildSchema() throws SchemaChangeException {
throw new SchemaChangeException(e);
}

// Build the container schema and set the counts
for (final VectorWrapper<?> w : container) {
w.getValueVector().allocateNew();
}
container.buildSchema(BatchSchema.SelectionVectorMode.NONE);
container.setRecordCount(outputRecords);
}

@Override
protected boolean prefetchFirstBatchFromBothSides() {
leftUpstream = sniffNonEmptyBatch(0, left);
rightUpstream = sniffNonEmptyBatch(1, right);

// For build side, use aggregate i.e. average row width across batches
batchMemoryManager.update(LEFT_INDEX, 0);
batchMemoryManager.update(RIGHT_INDEX, 0, true);

if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, incoming left:\n {}", batchMemoryManager.getRecordBatchSizer(LEFT_INDEX));
logger.debug("BATCH_STATS, incoming right:\n {}", batchMemoryManager.getRecordBatchSizer(RIGHT_INDEX));
}

if (leftUpstream == IterOutcome.STOP || rightUpstream == IterOutcome.STOP) {
state = BatchState.STOP;
return false;
Expand Down Expand Up @@ -333,10 +346,21 @@ public IterOutcome innerNext() {
joinType != JoinRelType.INNER) { // or if this is a left/full outer join

// Allocate the memory for the vectors in the output container
allocateVectors();
batchMemoryManager.allocateVectors(container);
hashJoinProbe.setTargetOutputCount(batchMemoryManager.getOutputRowCount());

outputRecords = hashJoinProbe.probeAndProject();

for (final VectorWrapper<?> v : container) {
v.getValueVector().getMutator().setValueCount(outputRecords);
}
container.setRecordCount(outputRecords);

batchMemoryManager.updateOutgoingStats(outputRecords);
if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, outgoing:\n {}", new RecordBatchSizer(this));
}

/* We are here because of one the following
* 1. Completed processing of all the records and we are done
* 2. We've filled up the outgoing batch to the maximum and we need to return upstream
Expand All @@ -347,10 +371,6 @@ public IterOutcome innerNext() {
state = BatchState.NOT_FIRST;
}

for (final VectorWrapper<?> v : container) {
v.getValueVector().getMutator().setValueCount(outputRecords);
}

return IterOutcome.OK;
}

Expand Down Expand Up @@ -557,7 +577,8 @@ private HashJoinMemoryCalculator.BuildSidePartitioning partitionNumTuning(
RECORDS_PER_BATCH,
maxBatchSize,
maxBatchSize,
TARGET_RECORDS_PER_BATCH,
batchMemoryManager.getOutputRowCount(),
batchMemoryManager.getOutputBatchSize(),
HashTable.DEFAULT_LOAD_FACTOR);

disableSpilling(null);
Expand Down Expand Up @@ -628,7 +649,8 @@ public void executeBuildPhase() throws SchemaChangeException {
RECORDS_PER_BATCH,
maxBatchSize,
maxBatchSize,
TARGET_RECORDS_PER_BATCH,
batchMemoryManager.getOutputRowCount(),
batchMemoryManager.getOutputBatchSize(),
HashTable.DEFAULT_LOAD_FACTOR);

if (firstCycle && doMemoryCalculation) {
Expand Down Expand Up @@ -665,6 +687,7 @@ public void executeBuildPhase() throws SchemaChangeException {
for (HashPartition partn : partitions) { partn.updateBatches(); }
// Fall through
case OK:
batchMemoryManager.update(buildBatch, RIGHT_INDEX, 0, true);
// Special treatment (when no spill, and single partition) -- use the incoming vectors as they are (no row copy)
if ( numPartitions == 1 ) {
partitions[0].appendBatch(buildBatch);
Expand Down Expand Up @@ -803,22 +826,6 @@ private void setupOutputContainerSchema() {

}

private void allocateVectors() {
for (final VectorWrapper<?> vectorWrapper : container) {
ValueVector valueVector = vectorWrapper.getValueVector();

if (valueVector instanceof FixedWidthVector) {
((FixedWidthVector) valueVector).allocateNew(TARGET_RECORDS_PER_BATCH);
} else if (valueVector instanceof VariableWidthVector) {
((VariableWidthVector) valueVector).allocateNew(8 * TARGET_RECORDS_PER_BATCH, TARGET_RECORDS_PER_BATCH);
} else {
valueVector.allocateNew();
}
}

container.setRecordCount(0); // reset container's counter back to zero records
}

// (After the inner side was read whole) - Has that inner partition spilled
public boolean isSpilledInner(int part) {
if ( spilledInners == null ) { return false; } // empty inner
Expand Down Expand Up @@ -879,6 +886,10 @@ public HashJoinBatch(HashJoinPOP popConfig, FragmentContext context,

// Create empty partitions (in the ctor - covers the case where right side is empty)
partitions = new HashPartition[0];

// get the output batch size from config.
int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
batchMemoryManager = new JoinBatchMemoryManager(configuredBatchSize, left, right);
}

/**
Expand Down Expand Up @@ -966,13 +977,49 @@ public void killIncoming(boolean sendUpstream) {
buildBatch.kill(sendUpstream);
}

public void updateMetrics() {
stats.setLongStat(HashJoinBatch.Metric.LEFT_INPUT_BATCH_COUNT, batchMemoryManager.getNumIncomingBatches(LEFT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.LEFT_AVG_INPUT_BATCH_BYTES, batchMemoryManager.getAvgInputBatchSize(LEFT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.LEFT_AVG_INPUT_ROW_BYTES, batchMemoryManager.getAvgInputRowWidth(LEFT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.LEFT_INPUT_RECORD_COUNT, batchMemoryManager.getTotalInputRecords(LEFT_INDEX));

stats.setLongStat(HashJoinBatch.Metric.RIGHT_INPUT_BATCH_COUNT, batchMemoryManager.getNumIncomingBatches(RIGHT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.RIGHT_AVG_INPUT_BATCH_BYTES, batchMemoryManager.getAvgInputBatchSize(RIGHT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.RIGHT_AVG_INPUT_ROW_BYTES, batchMemoryManager.getAvgInputRowWidth(RIGHT_INDEX));
stats.setLongStat(HashJoinBatch.Metric.RIGHT_INPUT_RECORD_COUNT, batchMemoryManager.getTotalInputRecords(RIGHT_INDEX));

stats.setLongStat(HashJoinBatch.Metric.OUTPUT_BATCH_COUNT, batchMemoryManager.getNumOutgoingBatches());
stats.setLongStat(HashJoinBatch.Metric.AVG_OUTPUT_BATCH_BYTES, batchMemoryManager.getAvgOutputBatchSize());
stats.setLongStat(HashJoinBatch.Metric.AVG_OUTPUT_ROW_BYTES, batchMemoryManager.getAvgOutputRowWidth());
stats.setLongStat(HashJoinBatch.Metric.OUTPUT_RECORD_COUNT, batchMemoryManager.getTotalOutputRecords());
}

@Override
public void close() {
if ( cycleNum > 0 ) { // spilling happened
// In case closing due to cancellation, BaseRootExec.close() does not close the open
// SpilledRecordBatch "scanners" as it only knows about the original left/right ops.
killIncoming(false);
}

updateMetrics();

logger.debug("BATCH_STATS, incoming aggregate left: batch count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
batchMemoryManager.getNumIncomingBatches(JoinBatchMemoryManager.LEFT_INDEX),
batchMemoryManager.getAvgInputBatchSize(JoinBatchMemoryManager.LEFT_INDEX),
batchMemoryManager.getAvgInputRowWidth(JoinBatchMemoryManager.LEFT_INDEX),
batchMemoryManager.getTotalInputRecords(JoinBatchMemoryManager.LEFT_INDEX));

logger.debug("BATCH_STATS, incoming aggregate right: batch count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
batchMemoryManager.getNumIncomingBatches(JoinBatchMemoryManager.RIGHT_INDEX),
batchMemoryManager.getAvgInputBatchSize(JoinBatchMemoryManager.RIGHT_INDEX),
batchMemoryManager.getAvgInputRowWidth(JoinBatchMemoryManager.RIGHT_INDEX),
batchMemoryManager.getTotalInputRecords(JoinBatchMemoryManager.RIGHT_INDEX));

logger.debug("BATCH_STATS, outgoing aggregate: batch count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
batchMemoryManager.getNumOutgoingBatches(), batchMemoryManager.getAvgOutputBatchSize(),
batchMemoryManager.getAvgOutputRowWidth(), batchMemoryManager.getTotalOutputRecords());

this.cleanup();
super.close();
}
Expand Down
Expand Up @@ -77,6 +77,7 @@ public void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
this.initialPartitions = initialPartitions;
}
Expand Down
Expand Up @@ -100,6 +100,7 @@ void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor);

void setPartitionStatSet(PartitionStatSet partitionStatSet);
Expand Down
Expand Up @@ -142,6 +142,7 @@ public void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
this.initialPartitions = initialPartitions;
}
Expand Down Expand Up @@ -203,7 +204,7 @@ public HashJoinState getState() {
* <h1>Life Cycle</h1>
* <p>
* <ul>
* <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, RecordBatch, RecordBatch, Set, long, int, int, int, int, int, int, double)}.
* <li><b>Step 0:</b> Call {@link #initialize(boolean, boolean, RecordBatch, RecordBatch, Set, long, int, int, int, int, int, int, int, double)}.
* This will initialize the StateCalculate with the additional information it needs.</li>
* <li><b>Step 1:</b> Call {@link #getNumPartitions()} to see the number of partitions that fit in memory.</li>
* <li><b>Step 2:</b> Call {@link #shouldSpill()} To determine if spilling needs to occurr.</li>
Expand Down Expand Up @@ -233,9 +234,7 @@ public static class BuildSidePartitioningImpl implements BuildSidePartitioning {
private int partitions;
private int recordsPerPartitionBatchBuild;
private int recordsPerPartitionBatchProbe;
private int outputBatchNumRecords;
private Map<String, Long> buildValueSizes;
private Map<String, Long> probeValueSizes;
private int outputBatchSize;
private Map<String, Long> keySizes;
private boolean autoTune;
private boolean reserveHash;
Expand Down Expand Up @@ -273,6 +272,7 @@ public void initialize(boolean autoTune,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
Preconditions.checkNotNull(buildSideBatch);
Preconditions.checkNotNull(probeSideBatch);
Expand Down Expand Up @@ -300,8 +300,6 @@ public void initialize(boolean autoTune,

initialize(autoTune,
reserveHash,
buildValueSizes,
probeValueSizes,
keySizes,
memoryAvailable,
initialPartitions,
Expand All @@ -313,7 +311,7 @@ public void initialize(boolean autoTune,
recordsPerPartitionBatchProbe,
maxBatchNumRecordsBuild,
maxBatchNumRecordsProbe,
outputBatchNumRecords,
outputBatchSize,
loadFactor);
}

Expand Down Expand Up @@ -352,8 +350,6 @@ public static long getBatchSizeEstimate(final RecordBatch recordBatch) {
@VisibleForTesting
protected void initialize(boolean autoTune,
boolean reserveHash,
CaseInsensitiveMap<Long> buildValueSizes,
CaseInsensitiveMap<Long> probeValueSizes,
CaseInsensitiveMap<Long> keySizes,
long memoryAvailable,
int initialPartitions,
Expand All @@ -365,7 +361,7 @@ protected void initialize(boolean autoTune,
int recordsPerPartitionBatchProbe,
int maxBatchNumRecordsBuild,
int maxBatchNumRecordsProbe,
int outputBatchNumRecords,
int outputBatchSize,
double loadFactor) {
Preconditions.checkState(!firstInitialized);
Preconditions.checkArgument(initialPartitions >= 1);
Expand All @@ -374,8 +370,6 @@ protected void initialize(boolean autoTune,
this.loadFactor = loadFactor;
this.autoTune = autoTune;
this.reserveHash = reserveHash;
this.buildValueSizes = Preconditions.checkNotNull(buildValueSizes);
this.probeValueSizes = Preconditions.checkNotNull(probeValueSizes);
this.keySizes = Preconditions.checkNotNull(keySizes);
this.memoryAvailable = memoryAvailable;
this.buildBatchSize = buildBatchSize;
Expand All @@ -387,7 +381,7 @@ protected void initialize(boolean autoTune,
this.recordsPerPartitionBatchProbe = recordsPerPartitionBatchProbe;
this.maxBatchNumRecordsBuild = maxBatchNumRecordsBuild;
this.maxBatchNumRecordsProbe = maxBatchNumRecordsProbe;
this.outputBatchNumRecords = outputBatchNumRecords;
this.outputBatchSize = outputBatchSize;

calculateMemoryUsage();

Expand Down Expand Up @@ -448,8 +442,7 @@ private void calculateMemoryUsage()
safetyFactor,
reserveHash);

maxOutputBatchSize = computeMaxOutputBatchSize(buildValueSizes, probeValueSizes, keySizes,
outputBatchNumRecords, safetyFactor, fragmentationFactor);
maxOutputBatchSize = (long) ((double)outputBatchSize * fragmentationFactor * safetyFactor);

long probeReservedMemory;

Expand Down Expand Up @@ -519,18 +512,6 @@ private void calculateMemoryUsage()
}
}

public static long computeMaxOutputBatchSize(Map<String, Long> buildValueSizes,
Map<String, Long> probeValueSizes,
Map<String, Long> keySizes,
int outputNumRecords,
double safetyFactor,
double fragmentationFactor) {
long outputSize = HashTableSizeCalculatorConservativeImpl.computeVectorSizes(keySizes, outputNumRecords, safetyFactor)
+ HashTableSizeCalculatorConservativeImpl.computeVectorSizes(buildValueSizes, outputNumRecords, safetyFactor)
+ HashTableSizeCalculatorConservativeImpl.computeVectorSizes(probeValueSizes, outputNumRecords, safetyFactor);
return RecordBatchSizer.multiplyByFactor(outputSize, fragmentationFactor);
}

@Override
public boolean shouldSpill() {
Preconditions.checkState(initialized);
Expand Down
Expand Up @@ -42,4 +42,6 @@ enum ProbeState {
void setupHashJoinProbe(RecordBatch probeBatch, HashJoinBatch outgoing, JoinRelType joinRelType, RecordBatch.IterOutcome leftStartState, HashPartition[] partitions, int cycleNum, VectorContainer container, HashJoinBatch.HJSpilledPartition[] spilledInners, boolean buildSideIsEmpty, int numPartitions, int rightHVColPosition);
int probeAndProject() throws SchemaChangeException;
void changeToFinalProbeState();
void setTargetOutputCount(int targetOutputCount);
int getOutputCount();
}