DRILL-6310: limit batch size for hash aggregate
closes #1324
ppadma authored and vvysotskyi committed Jul 2, 2018
1 parent 482a635 commit f1a3bd1
Showing 10 changed files with 557 additions and 163 deletions.
@@ -19,15 +19,19 @@


import java.io.IOException;
import java.util.List;
import java.util.Map;


import com.google.common.collect.Lists;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.expression.ErrorCollector;
import org.apache.drill.common.expression.ErrorCollectorImpl;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.FunctionHolderExpression;
import org.apache.drill.common.expression.IfExpression;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.common.map.CaseInsensitiveMap;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.compile.sig.GeneratorMapping;
import org.apache.drill.exec.compile.sig.MappingSet;
@@ -49,11 +53,14 @@
import org.apache.drill.exec.record.BatchSchema.SelectionVectorMode;
import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchMemoryManager;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.selection.SelectionVector2;
import org.apache.drill.exec.record.selection.SelectionVector4;
import org.apache.drill.exec.vector.AllocationHelper;
import org.apache.drill.exec.vector.FixedWidthVector;
import org.apache.drill.exec.vector.ValueVector;


import com.sun.codemodel.JExpr;
@@ -71,6 +78,12 @@ public class HashAggBatch extends AbstractRecordBatch<HashAggregate> {
private BatchSchema incomingSchema;
private boolean wasKilled;


private int numGroupByExprs, numAggrExprs;

// This map saves the mapping between outgoing column and incoming column.
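// Illustrative example (not part of this commit): for
//   SELECT dept_id, SUM(salary) FROM emp GROUP BY dept_id
// the map would hold {"dept_id" -> "dept_id", "EXPR$1" -> "salary"}, letting the
// memory manager look up incoming column sizes for outgoing columns.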
private Map<String, String> columnMapping;
private final HashAggMemoryManager hashAggMemoryManager;

private final GeneratorMapping UPDATE_AGGR_INSIDE =
    GeneratorMapping.create("setupInterior" /* setup method */, "updateAggrValuesInternal" /* eval method */,
        "resetValues" /* reset */, "cleanup" /* cleanup */);
@@ -84,6 +97,67 @@ public class HashAggBatch extends AbstractRecordBatch<HashAggregate> {
"htRowIdx" /* workspace index */, "incoming" /* read container */, "outgoing" /* write container */, "htRowIdx" /* workspace index */, "incoming" /* read container */, "outgoing" /* write container */,
"aggrValuesContainer" /* workspace container */, UPDATE_AGGR_INSIDE, UPDATE_AGGR_OUTSIDE, UPDATE_AGGR_INSIDE); "aggrValuesContainer" /* workspace container */, UPDATE_AGGR_INSIDE, UPDATE_AGGR_OUTSIDE, UPDATE_AGGR_INSIDE);


public int getOutputRowCount() {
return hashAggMemoryManager.getOutputRowCount();
}

public RecordBatchMemoryManager getRecordBatchMemoryManager() {
return hashAggMemoryManager;
}

private class HashAggMemoryManager extends RecordBatchMemoryManager {
private int valuesRowWidth = 0;

HashAggMemoryManager(int outputBatchSize) {
super(outputBatchSize);
}

@Override
public void update() {
// Get sizing information for the batch.
setRecordBatchSizer(new RecordBatchSizer(incoming));

int fieldId = 0;
int newOutgoingRowWidth = 0;
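// Group-by key vectors were added to the output container before the aggregate
// value vectors, so positions >= numGroupByExprs are the value (aggregate) columns.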
for (VectorWrapper<?> w : container) {
if (w.getValueVector() instanceof FixedWidthVector) {
newOutgoingRowWidth += ((FixedWidthVector) w.getValueVector()).getValueWidth();
if (fieldId >= numGroupByExprs) {
valuesRowWidth += ((FixedWidthVector) w.getValueVector()).getValueWidth();
}
} else {
int columnWidth;
if (columnMapping.get(w.getValueVector().getField().getName()) == null) {
columnWidth = TypeHelper.getSize(w.getField().getType());
} else {
RecordBatchSizer.ColumnSize columnSize = getRecordBatchSizer().getColumn(columnMapping.get(w.getValueVector().getField().getName()));
if (columnSize == null) {
columnWidth = TypeHelper.getSize(w.getField().getType());
} else {
columnWidth = columnSize.getAllocSizePerEntry();
}
}
newOutgoingRowWidth += columnWidth;
if (fieldId >= numGroupByExprs) {
valuesRowWidth += columnWidth;
}
}
fieldId++;
}

if (updateIfNeeded(newOutgoingRowWidth)) {
// There is an update to outgoing row width.
// Uncomment this if we want to adjust the batch row count of in-flight batches.
// To keep things simple, we are not doing this adjustment for now.
// aggregator.adjustOutputCount(getOutputBatchSize(), getOutgoingRowWidth(), newOutgoingRowWidth);
}
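// A rough sketch of what updateIfNeeded() is expected to do (the actual logic
// lives in RecordBatchMemoryManager, not in this diff): derive the outgoing
// batch row count from the configured batch size and the new row width, roughly
//   outputRowCount ~= min(MAX_ROW_COUNT, max(1, outputBatchSize / newOutgoingRowWidth))
// so wider rows mean fewer rows per output batch and bounded batch memory.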

updateIncomingStats();
if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, incoming: {}", getRecordBatchSizer());
}
}
}


public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentContext context) {
  super(popConfig, context);
@@ -103,6 +177,13 @@ public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentContext


  boolean allowed = oContext.getAllocator().setLenient();
  logger.debug("Config: Is allocator lenient? {}", allowed);

// Get the output batch size from config.
int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
hashAggMemoryManager = new HashAggMemoryManager(configuredBatchSize);
logger.debug("BATCH_STATS, configured output batch size: {}", configuredBatchSize);

columnMapping = CaseInsensitiveMap.newHashMap();
}


@Override
@@ -136,6 +217,9 @@ public void buildSchema() throws SchemaChangeException {
  for (VectorWrapper<?> w : container) {
    AllocationHelper.allocatePrecomputedChildCount(w.getValueVector(), 0, 0, 0);
  }
if (incoming.getRecordCount() > 0) {
hashAggMemoryManager.update();
}
}


@Override
@@ -239,8 +323,8 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,
  // top.saveCodeForDebugging(true);
  container.clear();


  numGroupByExprs = (popConfig.getGroupByExprs() != null) ? popConfig.getGroupByExprs().size() : 0;
  numAggrExprs = (popConfig.getAggrExprs() != null) ? popConfig.getAggrExprs().size() : 0;
  aggrExprs = new LogicalExpression[numAggrExprs];
  groupByOutFieldIds = new TypedFieldId[numGroupByExprs];
  aggrOutFieldIds = new TypedFieldId[numAggrExprs];
@@ -263,13 +347,13 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,


    // add this group-by vector to the output container
    groupByOutFieldIds[i] = container.add(vv);
columnMapping.put(outputField.getName(), ne.getExpr().toString().replace('`',' ').trim());
  }


  int extraNonNullColumns = 0; // each of SUM, MAX and MIN gets an extra bigint column
  for (i = 0; i < numAggrExprs; i++) {
    NamedExpression ne = popConfig.getAggrExprs().get(i);
    final LogicalExpression expr = ExpressionTreeMaterializer.materialize(ne.getExpr(), incoming, collector, context.getFunctionRegistry());


    if (expr instanceof IfExpression) {
      throw UserException.unsupportedError(new UnsupportedOperationException("Union type not supported in aggregate functions")).build(logger);
@@ -283,16 +367,28 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,
      continue;
    }


    final MaterializedField outputField = MaterializedField.create(ne.getRef().getAsNamePart().getName(), expr.getMajorType());
    @SuppressWarnings("resource") ValueVector vv = TypeHelper.getNewVector(outputField, oContext.getAllocator());
    aggrOutFieldIds[i] = container.add(vv);


    aggrExprs[i] = new ValueVectorWriteExpression(aggrOutFieldIds[i], expr, true);

if (expr instanceof FunctionHolderExpression) {
String funcName = ((FunctionHolderExpression) expr).getName();
if (funcName.equals("sum") || funcName.equals("max") || funcName.equals("min")) {
extraNonNullColumns++;
}
if (((FunctionCall) ne.getExpr()).args.get(0) instanceof SchemaPath) {
columnMapping.put(outputField.getName(), ((SchemaPath) ((FunctionCall) ne.getExpr()).args.get(0)).getAsNamePart().getName());
} else if (((FunctionCall) ne.getExpr()).args.get(0) instanceof FunctionCall) {
FunctionCall functionCall = (FunctionCall) ((FunctionCall) ne.getExpr()).args.get(0);
if (functionCall.args.get(0) instanceof SchemaPath) {
columnMapping.put(outputField.getName(), ((SchemaPath) functionCall.args.get(0)).getAsNamePart().getName());
}
}
} else {
columnMapping.put(outputField.getName(), ne.getRef().getAsNamePart().getName());
}
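    // Illustrative example (not part of this commit): for SUM(CAST(salary AS BIGINT)),
    // the outer call's first argument is itself a FunctionCall (the implicit cast),
    // whose first argument is the SchemaPath `salary`, so the outgoing aggregate
    // column is mapped back to the incoming `salary` column.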
  }


  setupUpdateAggrValues(cgInner);
@@ -345,11 +441,32 @@ private void setupGetIndex(ClassGenerator<HashAggregator> cg) {
  }
}


private void updateStats() {
stats.setLongStat(HashAggTemplate.Metric.INPUT_BATCH_COUNT, hashAggMemoryManager.getNumIncomingBatches());
stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_BATCH_BYTES, hashAggMemoryManager.getAvgInputBatchSize());
stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_ROW_BYTES, hashAggMemoryManager.getAvgInputRowWidth());
stats.setLongStat(HashAggTemplate.Metric.INPUT_RECORD_COUNT, hashAggMemoryManager.getTotalInputRecords());
stats.setLongStat(HashAggTemplate.Metric.OUTPUT_BATCH_COUNT, hashAggMemoryManager.getNumOutgoingBatches());
stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_BATCH_BYTES, hashAggMemoryManager.getAvgOutputBatchSize());
stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_ROW_BYTES, hashAggMemoryManager.getAvgOutputRowWidth());
stats.setLongStat(HashAggTemplate.Metric.OUTPUT_RECORD_COUNT, hashAggMemoryManager.getTotalOutputRecords());

if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, incoming aggregate: count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
hashAggMemoryManager.getNumIncomingBatches(), hashAggMemoryManager.getAvgInputBatchSize(),
hashAggMemoryManager.getAvgInputRowWidth(), hashAggMemoryManager.getTotalInputRecords());

logger.debug("BATCH_STATS, outgoing aggregate: count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
hashAggMemoryManager.getNumOutgoingBatches(), hashAggMemoryManager.getAvgOutputBatchSize(),
hashAggMemoryManager.getAvgOutputRowWidth(), hashAggMemoryManager.getTotalOutputRecords());
}
}
@Override
public void close() {
  if (aggregator != null) {
    aggregator.cleanup();
  }
updateStats();
  super.close();
}


