DRILL-6310: limit batch size for hash aggregate
closes #1324
ppadma authored and vvysotskyi committed Jul 2, 2018
1 parent 482a635 commit f1a3bd12c2fd3ad525642933ace2d7b9cedaacc9
Showing 10 changed files with 557 additions and 163 deletions.
@@ -19,15 +19,19 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;

import com.google.common.collect.Lists;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.expression.ErrorCollector;
import org.apache.drill.common.expression.ErrorCollectorImpl;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.FunctionHolderExpression;
import org.apache.drill.common.expression.IfExpression;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.common.map.CaseInsensitiveMap;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.compile.sig.GeneratorMapping;
import org.apache.drill.exec.compile.sig.MappingSet;
@@ -49,11 +53,14 @@
import org.apache.drill.exec.record.BatchSchema.SelectionVectorMode;
import org.apache.drill.exec.record.MaterializedField;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.RecordBatchMemoryManager;
import org.apache.drill.exec.record.RecordBatchSizer;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.record.selection.SelectionVector2;
import org.apache.drill.exec.record.selection.SelectionVector4;
import org.apache.drill.exec.vector.AllocationHelper;
import org.apache.drill.exec.vector.FixedWidthVector;
import org.apache.drill.exec.vector.ValueVector;

import com.sun.codemodel.JExpr;
@@ -71,6 +78,12 @@ public class HashAggBatch extends AbstractRecordBatch<HashAggregate> {
private BatchSchema incomingSchema;
private boolean wasKilled;

private int numGroupByExprs, numAggrExprs;

// Maps each outgoing column name to the name of the incoming column it is derived from.
private Map<String, String> columnMapping;
private final HashAggMemoryManager hashAggMemoryManager;

private final GeneratorMapping UPDATE_AGGR_INSIDE =
GeneratorMapping.create("setupInterior" /* setup method */, "updateAggrValuesInternal" /* eval method */,
"resetValues" /* reset */, "cleanup" /* cleanup */);
@@ -84,6 +97,67 @@ public class HashAggBatch extends AbstractRecordBatch<HashAggregate> {
"htRowIdx" /* workspace index */, "incoming" /* read container */, "outgoing" /* write container */,
"aggrValuesContainer" /* workspace container */, UPDATE_AGGR_INSIDE, UPDATE_AGGR_OUTSIDE, UPDATE_AGGR_INSIDE);

public int getOutputRowCount() {
return hashAggMemoryManager.getOutputRowCount();
}

public RecordBatchMemoryManager getRecordBatchMemoryManager() {
return hashAggMemoryManager;
}

private class HashAggMemoryManager extends RecordBatchMemoryManager {
private int valuesRowWidth = 0;

HashAggMemoryManager(int outputBatchSize) {
super(outputBatchSize);
}

@Override
public void update() {
// Get sizing information for the batch.
setRecordBatchSizer(new RecordBatchSizer(incoming));
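// The RecordBatchSizer walks the incoming batch's vectors and records
// per-column allocation sizes; the loop below uses those measurements to
// estimate the width of the variable-width columns.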

int fieldId = 0;
int newOutgoingRowWidth = 0;
for (VectorWrapper<?> w : container) {
if (w.getValueVector() instanceof FixedWidthVector) {
newOutgoingRowWidth += ((FixedWidthVector) w.getValueVector()).getValueWidth();
if (fieldId >= numGroupByExprs) {
valuesRowWidth += ((FixedWidthVector) w.getValueVector()).getValueWidth();
}
} else {
int columnWidth;
if (columnMapping.get(w.getValueVector().getField().getName()) == null) {
columnWidth = TypeHelper.getSize(w.getField().getType());
} else {
RecordBatchSizer.ColumnSize columnSize = getRecordBatchSizer().getColumn(columnMapping.get(w.getValueVector().getField().getName()));
if (columnSize == null) {
columnWidth = TypeHelper.getSize(w.getField().getType());
} else {
columnWidth = columnSize.getAllocSizePerEntry();
}
}
newOutgoingRowWidth += columnWidth;
if (fieldId >= numGroupByExprs) {
valuesRowWidth += columnWidth;
}
}
fieldId++;
}
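// At this point newOutgoingRowWidth is the estimated width of a full outgoing
// row, and valuesRowWidth covers only the non-key (aggregate value) columns.
// Hypothetical example: a BIGINT group-by key with a SUM over a BIGINT column
// gives roughly 8 (key) + 8 (sum) = 16 bytes per row.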

if (updateIfNeeded(newOutgoingRowWidth)) {
// There is an update to the outgoing row width.
// Uncomment this if we want to adjust the batch row count of in-flight batches.
// To keep things simple, we do not make this adjustment for now.
// aggregator.adjustOutputCount(getOutputBatchSize(), getOutgoingRowWidth(), newOutgoingRowWidth);
}
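// updateIfNeeded() (inherited from RecordBatchMemoryManager) stores the new
// width and, when it changed, recomputes how many rows fit in the configured
// output batch size.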

updateIncomingStats();
if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, incoming: {}", getRecordBatchSizer());
}
}
}

public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentContext context) {
super(popConfig, context);
@@ -103,6 +177,13 @@ public HashAggBatch(HashAggregate popConfig, RecordBatch incoming, FragmentConte

boolean allowed = oContext.getAllocator().setLenient();
logger.debug("Config: Is allocator lenient? {}", allowed);

// get the output batch size from config.
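// ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR reads the
// drill.exec.memory.operator.output_batch_size boot option (16 MB by default).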
int configuredBatchSize = (int) context.getOptions().getOption(ExecConstants.OUTPUT_BATCH_SIZE_VALIDATOR);
hashAggMemoryManager = new HashAggMemoryManager(configuredBatchSize);
logger.debug("BATCH_STATS, configured output batch size: {}", configuredBatchSize);

columnMapping = CaseInsensitiveMap.newHashMap();
}

@Override
@@ -136,6 +217,9 @@ public void buildSchema() throws SchemaChangeException {
for (VectorWrapper<?> w : container) {
AllocationHelper.allocatePrecomputedChildCount(w.getValueVector(), 0, 0, 0);
}
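// Seed the memory manager from the first batch only if it actually carries
// rows; an empty schema-only batch has nothing to size.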
if (incoming.getRecordCount() > 0) {
hashAggMemoryManager.update();
}
}

@Override
@@ -239,8 +323,8 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,
// top.saveCodeForDebugging(true);
container.clear();

-int numGroupByExprs = (popConfig.getGroupByExprs() != null) ? popConfig.getGroupByExprs().size() : 0;
-int numAggrExprs = (popConfig.getAggrExprs() != null) ? popConfig.getAggrExprs().size() : 0;
+numGroupByExprs = (popConfig.getGroupByExprs() != null) ? popConfig.getGroupByExprs().size() : 0;
+numAggrExprs = (popConfig.getAggrExprs() != null) ? popConfig.getAggrExprs().size() : 0;
aggrExprs = new LogicalExpression[numAggrExprs];
groupByOutFieldIds = new TypedFieldId[numGroupByExprs];
aggrOutFieldIds = new TypedFieldId[numAggrExprs];
@@ -263,13 +347,13 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,

// add this group-by vector to the output container
groupByOutFieldIds[i] = container.add(vv);
columnMapping.put(outputField.getName(), ne.getExpr().toString().replace('`',' ').trim());
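// The replace('`', ' ').trim() above strips backticks from the expression, so
// a hypothetical group-by on `l_orderkey` maps the outgoing key column to the
// incoming column name "l_orderkey".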
}

int extraNonNullColumns = 0; // each of SUM, MAX and MIN gets an extra bigint column
for (i = 0; i < numAggrExprs; i++) {
NamedExpression ne = popConfig.getAggrExprs().get(i);
-final LogicalExpression expr =
-ExpressionTreeMaterializer.materialize(ne.getExpr(), incoming, collector, context.getFunctionRegistry());
+final LogicalExpression expr = ExpressionTreeMaterializer.materialize(ne.getExpr(), incoming, collector, context.getFunctionRegistry());

if (expr instanceof IfExpression) {
throw UserException.unsupportedError(new UnsupportedOperationException("Union type not supported in aggregate functions")).build(logger);
@@ -283,16 +367,28 @@ private HashAggregator createAggregatorInternal() throws SchemaChangeException,
continue;
}

-if ( expr instanceof FunctionHolderExpression ) {
-String funcName = ((FunctionHolderExpression) expr).getName();
-if ( funcName.equals("sum") || funcName.equals("max") || funcName.equals("min") ) {extraNonNullColumns++;}
-}
final MaterializedField outputField = MaterializedField.create(ne.getRef().getAsNamePart().getName(), expr.getMajorType());
-@SuppressWarnings("resource")
-ValueVector vv = TypeHelper.getNewVector(outputField, oContext.getAllocator());
+@SuppressWarnings("resource") ValueVector vv = TypeHelper.getNewVector(outputField, oContext.getAllocator());
aggrOutFieldIds[i] = container.add(vv);

aggrExprs[i] = new ValueVectorWriteExpression(aggrOutFieldIds[i], expr, true);

if (expr instanceof FunctionHolderExpression) {
String funcName = ((FunctionHolderExpression) expr).getName();
if (funcName.equals("sum") || funcName.equals("max") || funcName.equals("min")) {
extraNonNullColumns++;
}
if (((FunctionCall) ne.getExpr()).args.get(0) instanceof SchemaPath) {
columnMapping.put(outputField.getName(), ((SchemaPath) ((FunctionCall) ne.getExpr()).args.get(0)).getAsNamePart().getName());
} else if (((FunctionCall) ne.getExpr()).args.get(0) instanceof FunctionCall) {
FunctionCall functionCall = (FunctionCall) ((FunctionCall) ne.getExpr()).args.get(0);
if (functionCall.args.get(0) instanceof SchemaPath) {
columnMapping.put(outputField.getName(), ((SchemaPath) functionCall.args.get(0)).getAsNamePart().getName());
}
}
} else {
columnMapping.put(outputField.getName(), ne.getRef().getAsNamePart().getName());
}
}
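// columnMapping now lets the memory manager resolve each outgoing column back
// to the incoming column whose size RecordBatchSizer measured: aggregates map
// through their first SchemaPath argument (sum(`l_quantity`) -> "l_quantity",
// hypothetically), and plain copies fall back to the expression's reference name.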

setupUpdateAggrValues(cgInner);
@@ -345,11 +441,32 @@ private void setupGetIndex(ClassGenerator<HashAggregator> cg) {
}
}
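// updateStats() publishes the memory manager's running averages as operator
// metrics, so per-batch sizing shows up in the query profile.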

private void updateStats() {
stats.setLongStat(HashAggTemplate.Metric.INPUT_BATCH_COUNT, hashAggMemoryManager.getNumIncomingBatches());
stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_BATCH_BYTES, hashAggMemoryManager.getAvgInputBatchSize());
stats.setLongStat(HashAggTemplate.Metric.AVG_INPUT_ROW_BYTES, hashAggMemoryManager.getAvgInputRowWidth());
stats.setLongStat(HashAggTemplate.Metric.INPUT_RECORD_COUNT, hashAggMemoryManager.getTotalInputRecords());
stats.setLongStat(HashAggTemplate.Metric.OUTPUT_BATCH_COUNT, hashAggMemoryManager.getNumOutgoingBatches());
stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_BATCH_BYTES, hashAggMemoryManager.getAvgOutputBatchSize());
stats.setLongStat(HashAggTemplate.Metric.AVG_OUTPUT_ROW_BYTES, hashAggMemoryManager.getAvgOutputRowWidth());
stats.setLongStat(HashAggTemplate.Metric.OUTPUT_RECORD_COUNT, hashAggMemoryManager.getTotalOutputRecords());

if (logger.isDebugEnabled()) {
logger.debug("BATCH_STATS, incoming aggregate: count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
hashAggMemoryManager.getNumIncomingBatches(), hashAggMemoryManager.getAvgInputBatchSize(),
hashAggMemoryManager.getAvgInputRowWidth(), hashAggMemoryManager.getTotalInputRecords());

logger.debug("BATCH_STATS, outgoing aggregate: count : {}, avg bytes : {}, avg row bytes : {}, record count : {}",
hashAggMemoryManager.getNumOutgoingBatches(), hashAggMemoryManager.getAvgOutputBatchSize(),
hashAggMemoryManager.getAvgOutputRowWidth(), hashAggMemoryManager.getTotalOutputRecords());
}
}

@Override
public void close() {
if (aggregator != null) {
aggregator.cleanup();
}
updateStats();
super.close();
}
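
For context, the sizing arithmetic this commit relies on boils down to the sketch below (a minimal illustration, not the actual RecordBatchMemoryManager code; the method name and the 65535 row cap here are assumptions for illustration):

// Derive how many rows fit into one outgoing batch from the configured
// output batch size and the estimated outgoing row width.
static int computeOutputRowCount(int outputBatchSizeBytes, int rowWidthBytes) {
  int width = Math.max(rowWidthBytes, 1);      // guard against divide-by-zero
  int count = outputBatchSizeBytes / width;    // rows that fit in the byte budget
  return Math.min(Math.max(count, 1), 65535);  // clamp to a sane per-batch row range
}

For example, a 16 MB budget with rows estimated at 1 KB yields 16384 rows per outgoing batch, while very narrow rows are capped at the row limit rather than allowed to grow the batch row count unboundedly.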
