Clean up AggregationFunctionContext and use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions (#5364)

- Clean up all usages of AggregationFunctionContext to directly use AggregationFunction
- Construct the AggregationFunctions and group-by expressions in the planning phase and pass them to the Operator and Executor, saving the extra expression compilation
- Use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions
  - This saves the redundant string conversion and gives more efficient hashCode() and equals()
  - The keys of the blockValSetMap should be the same as AggregationFunction.getInputExpressions()
  - The only exception is CountAggregationFunction with Star-Tree, where blockValSetMap contains a single entry (column "*")
- Add a base implementation of AggregationFunction, BaseSingleExpressionAggregationFunction, for aggregation functions on a single expression
- For the PERCENTILE group of aggregation functions, support passing the percentile as the second argument (e.g. PERCENTILE(column, 99), PERCENTILETDIGEST(column, 90)); see the example queries after this list
- Enhance the Star-Tree Aggregation/Group-by Executors to handle the column name conversion so that AggregationFunctionColumnPair is transparent to the AggregationFunction
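
For example, both query forms below now resolve to the same TDIGEST-based percentile aggregation (illustrative queries against a hypothetical numeric column latencyMillis):

  SELECT PERCENTILETDIGEST90(latencyMillis) FROM myTable
  SELECT PERCENTILETDIGEST(latencyMillis, 90) FROM myTable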

BACKWARD-INCOMPATIBLE CHANGE:
The following APIs are changed in AggregationFunction (use TransformExpressionTree instead of String as the key of blockValSetMap):
void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
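
To illustrate the new contract, here is a minimal hypothetical single-expression SUM written against the updated signature. This is a sketch, not code from the commit: the _expression field stands in for whatever BaseSingleExpressionAggregationFunction exposes, and the package names are assumed from Pinot at the time of this change:

  import java.util.Map;
  import org.apache.pinot.common.request.transform.TransformExpressionTree;
  import org.apache.pinot.core.common.BlockValSet;
  import org.apache.pinot.core.query.aggregation.AggregationResultHolder;

  public class SketchSumAggregationFunction {
    private final TransformExpressionTree _expression;  // the single input expression

    public SketchSumAggregationFunction(TransformExpressionTree expression) {
      _expression = expression;
    }

    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
        Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
      // The map keys match getInputExpressions(), so the expression tree itself is the
      // lookup key: no String conversion, and cheaper hashCode()/equals() on lookups.
      double[] values = blockValSetMap.get(_expression).getDoubleValuesSV();
      double sum = aggregationResultHolder.getDoubleResult();
      for (int i = 0; i < length; i++) {
        sum += values[i];
      }
      aggregationResultHolder.setValue(sum);
    }
  }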
Jackie-Jiang committed May 12, 2020
1 parent 25bc1b5 commit 8b0089f
Showing 81 changed files with 1,430 additions and 1,992 deletions.
@@ -456,13 +456,9 @@ private void handleCaseSensitivity(BrokerRequest brokerRequest) {
for (AggregationInfo info : brokerRequest.getAggregationsInfo()) {
if (!info.getAggregationType().equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
// Always read from backward compatible api in AggregationFunctionUtils.
List<String> expressions = AggregationFunctionUtils.getAggregationExpressions(info);

List<String> newExpressions = new ArrayList<>(expressions.size());
for (String expression : expressions) {
newExpressions.add(fixColumnNameCase(actualTableName, expression));
}
info.setExpressions(newExpressions);
List<String> arguments = AggregationFunctionUtils.getArguments(info);
arguments.replaceAll(e -> fixColumnNameCase(actualTableName, e));
info.setExpressions(arguments);
}
}
if (brokerRequest.isSetGroupBy()) {
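
Aside: List.replaceAll (Java 8+) applies a UnaryOperator to every element in place, which is why the single call above is equivalent to the removed copy loop. A self-contained illustration, with toLowerCase standing in for fixColumnNameCase:

  import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.List;

  public class ReplaceAllDemo {
    public static void main(String[] args) {
      List<String> arguments = new ArrayList<>(Arrays.asList("ColA", "ColB"));
      arguments.replaceAll(e -> e.toLowerCase());  // in-place, no second list needed
      System.out.println(arguments);  // [cola, colb]
    }
  }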
@@ -720,11 +716,10 @@ static void validateRequest(BrokerRequest brokerRequest, int queryResponseLimit)
throw new UnsupportedOperationException("DISTINCT with GROUP BY is currently not supported");
}
if (brokerRequest.isSetOrderBy()) {
List<String> columns = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo);
Set<String> set = new HashSet<>(columns);
Set<String> expressionSet = new HashSet<>(AggregationFunctionUtils.getArguments(aggregationInfo));
List<SelectionSort> orderByColumns = brokerRequest.getOrderBy();
for (SelectionSort selectionSort : orderByColumns) {
if (!set.contains(selectionSort.getColumn())) {
if (!expressionSet.contains(selectionSort.getColumn())) {
throw new UnsupportedOperationException(
"ORDER By should be only on some/all of the columns passed as arguments to DISTINCT");
}
@@ -76,17 +76,17 @@ public static AggregationFunctionType getAggregationFunctionType(String function
String upperCaseFunctionName = functionName.toUpperCase();
if (upperCaseFunctionName.startsWith("PERCENTILE")) {
String remainingFunctionName = upperCaseFunctionName.substring(10);
if (remainingFunctionName.matches("\\d+")) {
if (remainingFunctionName.isEmpty() || remainingFunctionName.matches("\\d+")) {
return PERCENTILE;
} else if (remainingFunctionName.matches("EST\\d+")) {
} else if (remainingFunctionName.equals("EST") || remainingFunctionName.matches("EST\\d+")) {
return PERCENTILEEST;
} else if (remainingFunctionName.matches("TDIGEST\\d+")) {
} else if (remainingFunctionName.equals("TDIGEST") || remainingFunctionName.matches("TDIGEST\\d+")) {
return PERCENTILETDIGEST;
} else if (remainingFunctionName.matches("\\d+MV")) {
} else if (remainingFunctionName.equals("MV") || remainingFunctionName.matches("\\d+MV")) {
return PERCENTILEMV;
} else if (remainingFunctionName.matches("EST\\d+MV")) {
} else if (remainingFunctionName.equals("ESTMV") || remainingFunctionName.matches("EST\\d+MV")) {
return PERCENTILEESTMV;
} else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
} else if (remainingFunctionName.equals("TDIGESTMV") || remainingFunctionName.matches("TDIGEST\\d+MV")) {
return PERCENTILETDIGESTMV;
} else {
throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
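
The practical effect of the relaxed matching above is that bare names (no percentile embedded in the function name) now resolve, with the percentile supplied as the second argument instead. A minimal standalone demo of the same branching:

  public class PercentileNameDemo {
    public static void main(String[] args) {
      for (String name : new String[]{"PERCENTILE", "PERCENTILE99", "PERCENTILETDIGEST", "PERCENTILETDIGEST90"}) {
        String remaining = name.substring("PERCENTILE".length());
        if (remaining.isEmpty() || remaining.matches("\\d+")) {
          System.out.println(name + " -> PERCENTILE");        // "" hits the new isEmpty() branch
        } else if (remaining.equals("TDIGEST") || remaining.matches("TDIGEST\\d+")) {
          System.out.println(name + " -> PERCENTILETDIGEST"); // "TDIGEST" hits the new equals() branch
        }
      }
    }
  }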
@@ -22,22 +22,21 @@
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nonnull;
import org.apache.pinot.spi.utils.EqualityUtils;
import javax.annotation.Nullable;
import org.apache.pinot.pql.parsers.Pql2Compiler;
import org.apache.pinot.pql.parsers.pql2.ast.AstNode;
import org.apache.pinot.pql.parsers.pql2.ast.FunctionCallAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.LiteralAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.StringLiteralAstNode;
import org.apache.pinot.spi.utils.EqualityUtils;


/**
* Class for representing expression trees for transforms.
* <ul>
* <li>A TransformExpressionTree node has either transform function or a column name, or a literal.</li>
* <li>Leaf nodes either have column name or literal, whereas non-leaf nodes have transform function.</li>
* <li>Transform function in non-leaf nodes is applied to its children nodes.</li>
* <li>Transform function is applied to its children.</li>
* </ul>
*/
public class TransformExpressionTree {
@@ -66,10 +65,9 @@ public static String getStandardExpression(AstNode astNode) {
} else if (astNode instanceof FunctionCallAstNode) {
// UDF expression
return standardizeExpression(((FunctionCallAstNode) astNode).getExpression());
} else if (astNode instanceof StringLiteralAstNode) {
// Treat string as column name
// NOTE: this is for backward-compatibility
return ((StringLiteralAstNode) astNode).getText();
} else if (astNode instanceof LiteralAstNode) {
// Literal
return ((LiteralAstNode) astNode).getValueAsString();
} else {
throw new IllegalStateException("Cannot get standard expression from " + astNode.getClass().getSimpleName());
}
@@ -106,6 +104,13 @@ public TransformExpressionTree(AstNode root) {
}
}

public TransformExpressionTree(ExpressionType expressionType, String value,
@Nullable List<TransformExpressionTree> children) {
_expressionType = expressionType;
_value = value;
_children = children;
}

/**
* Returns the expression type of the node, which can be one of the following:
* <ul>
@@ -168,7 +173,7 @@ public boolean isLiteral() {
*
* @param columns Output columns
*/
public void getColumns(@Nonnull Set<String> columns) {
public void getColumns(Set<String> columns) {
if (_expressionType == ExpressionType.IDENTIFIER) {
columns.add(_value);
} else if (_children != null) {
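A hypothetical use of the new TransformExpressionTree constructor added above: building the tree for add(foo, '1') directly, without compiling an expression string through Pql2Compiler. The package name and ExpressionType enum values (FUNCTION, IDENTIFIER, LITERAL) are assumptions based on this class, not code from the commit:

  import java.util.Arrays;
  import org.apache.pinot.common.request.transform.TransformExpressionTree;
  import org.apache.pinot.common.request.transform.TransformExpressionTree.ExpressionType;

  public class ExpressionTreeDemo {
    public static void main(String[] args) {
      TransformExpressionTree foo = new TransformExpressionTree(ExpressionType.IDENTIFIER, "foo", null);
      TransformExpressionTree one = new TransformExpressionTree(ExpressionType.LITERAL, "1", null);
      TransformExpressionTree add =
          new TransformExpressionTree(ExpressionType.FUNCTION, "add", Arrays.asList(foo, one));
      System.out.println(add);  // assumed: toString() reconstructs the standardized expression
    }
  }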
@@ -226,7 +226,7 @@ private AggregationInfo buildAggregationInfo(Function function) {
private String getColumnExpression(Expression functionParam) {
switch (functionParam.getType()) {
case LITERAL:
return functionParam.getLiteral().getStringValue();
return functionParam.getLiteral().getFieldValue().toString();
case IDENTIFIER:
return functionParam.getIdentifier().getName();
case FUNCTION:
@@ -26,7 +26,6 @@
import org.apache.pinot.common.request.Selection;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataTable;
import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.util.QueryOptions;
@@ -101,9 +100,8 @@ public static DataTable buildEmptyDataTable(BrokerRequest brokerRequest)
}

// Aggregation query.
AggregationFunctionContext[] aggregationFunctionContexts =
AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
int numAggregations = aggregationFunctionContexts.length;
AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
int numAggregations = aggregationFunctions.length;
if (brokerRequest.isSetGroupBy()) {
// Aggregation group-by query.

@@ -121,9 +119,9 @@
columnDataTypes[index] = DataSchema.ColumnDataType.STRING;
index++;
}
for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
columnNames[index] = aggregationFunctionContext.getResultColumnName();
AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
for (AggregationFunction aggregationFunction : aggregationFunctions) {
// NOTE: Use AggregationFunction.getResultColumnName() for SQL format response
columnNames[index] = aggregationFunction.getResultColumnName();
columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
index++;
}
@@ -137,9 +135,10 @@ public static DataTable buildEmptyDataTable(BrokerRequest brokerRequest)

// Build the data table.
DataTableBuilder dataTableBuilder = new DataTableBuilder(new DataSchema(columnNames, columnDataTypes));
for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
for (AggregationFunction aggregationFunction : aggregationFunctions) {
dataTableBuilder.startRow();
dataTableBuilder.setColumn(0, aggregationFunctionContext.getAggregationColumnName());
// NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for PQL format response
dataTableBuilder.setColumn(0, aggregationFunction.getColumnName());
dataTableBuilder.setColumn(1, Collections.emptyMap());
dataTableBuilder.finishRow();
}
@@ -152,9 +151,9 @@ public static DataTable buildEmptyDataTable(BrokerRequest brokerRequest)
DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numAggregations];
Object[] aggregationResults = new Object[numAggregations];
for (int i = 0; i < numAggregations; i++) {
AggregationFunctionContext aggregationFunctionContext = aggregationFunctionContexts[i];
aggregationColumnNames[i] = aggregationFunctionContext.getAggregationColumnName();
AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
AggregationFunction aggregationFunction = aggregationFunctions[i];
// NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for aggregation only query
aggregationColumnNames[i] = aggregationFunction.getColumnName();
columnDataTypes[i] = aggregationFunction.getIntermediateResultColumnType();
aggregationResults[i] =
aggregationFunction.extractAggregationResult(aggregationFunction.createAggregationResultHolder());
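
The two naming schemes referenced in the NOTE comments above, illustrated for MAX(foo) (values follow Pinot's conventions and are shown for illustration, not taken from this commit):

  aggregationFunction.getColumnName();        // "max_foo"  (legacy PQL-format responses)
  aggregationFunction.getResultColumnName();  // "max(foo)" (SQL-format responses)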
@@ -21,22 +21,20 @@
import java.util.Iterator;
import java.util.List;
import org.apache.commons.collections.CollectionUtils;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;


/**
* Base abstract implementation of Table
*/
public abstract class BaseTable implements Table {

final AggregationFunction[] _aggregationFunctions;
final int _numAggregations;
// TODO: After fixing the DistinctTable logic, make it final
protected DataSchema _dataSchema;
final int _numColumns;
protected final int _numColumns;
protected final AggregationFunction[] _aggregationFunctions;
protected final int _numAggregations;

// the capacity we need to trim to
protected int _capacity;
@@ -46,34 +44,27 @@ public abstract class BaseTable implements Table {
protected boolean _isOrderBy;
protected TableResizer _tableResizer;

private final List<AggregationInfo> _aggregationInfos;

/**
* Initializes the variables and comparators needed for the table
*/
public BaseTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
public BaseTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
int capacity) {
_dataSchema = dataSchema;
_numColumns = dataSchema.size();

_numAggregations = aggregationInfos.size();
_aggregationFunctions = new AggregationFunction[_numAggregations];
for (int i = 0; i < _numAggregations; i++) {
_aggregationFunctions[i] =
AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfos.get(i)).getAggregationFunction();
}

_aggregationInfos = aggregationInfos;
_aggregationFunctions = aggregationFunctions;
_numAggregations = aggregationFunctions.length;
addCapacityAndOrderByInfo(orderBy, capacity);
}

protected void addCapacityAndOrderByInfo(List<SelectionSort> orderBy, int capacity) {
_isOrderBy = CollectionUtils.isNotEmpty(orderBy);
if (_isOrderBy) {
_tableResizer = new TableResizer(_dataSchema, _aggregationInfos, orderBy);
_tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, orderBy);

// TODO: tune these numbers and come up with a better formula (github ISSUE-4801)
// Based on the capacity and maxCapacity, the resizer will smartly choose to evict/retain recors from the PQ
if (capacity <= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
if (capacity
<= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
_maxCapacity = 1_000_000;
} else { // Capacity is large, make buffer only slightly bigger. Make PQ of records to evict, during resize
_maxCapacity = (int) (capacity * 1.2);
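
Concretely, the buffer sizing above works out as follows (derived from the code shown, for illustration):

  // capacity <= 100_000 -> _maxCapacity = 1_000_000              (PQ of records to retain)
  // capacity >  100_000 -> _maxCapacity = (int) (capacity * 1.2) (PQ of records to evict)
  // e.g. capacity = 50_000 -> buffer 1_000_000; capacity = 500_000 -> buffer 600_000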
@@ -27,9 +27,9 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -52,13 +52,13 @@ public class ConcurrentIndexedTable extends IndexedTable {
/**
* Initializes the data structures needed for this Table
* @param dataSchema data schema of the record's keys and values
* @param aggregationInfos aggregation infos for the aggregations in record's values
* @param aggregationFunctions aggregation functions for the record's values
* @param orderBy list of {@link SelectionSort} defining the order by
* @param capacity the capacity of the table
*/
public ConcurrentIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
int capacity) {
super(dataSchema, aggregationInfos, orderBy, capacity);
public ConcurrentIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
List<SelectionSort> orderBy, int capacity) {
super(dataSchema, aggregationFunctions, orderBy, capacity);

_lookupMap = new ConcurrentHashMap<>();
_readWriteLock = new ReentrantReadWriteLock();
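
A hypothetical caller sketch for the updated table constructors (dataSchema, aggregationFunctions, and orderBy assumed to be in scope, produced by the planner as described in the commit message):

  // The planner now builds the AggregationFunction[] once and passes it in;
  // the tables no longer re-derive the functions from List<AggregationInfo>.
  IndexedTable table =
      new ConcurrentIndexedTable(dataSchema, aggregationFunctions, orderBy, 5_000);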
@@ -20,9 +20,9 @@

import java.util.Arrays;
import java.util.List;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;


/**
@@ -36,8 +36,9 @@ public abstract class IndexedTable extends BaseTable {
/**
* Initializes the variables and comparators needed for the table
*/
IndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
super(dataSchema, aggregationInfos, orderBy, capacity);
IndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
int capacity) {
super(dataSchema, aggregationFunctions, orderBy, capacity);

_numKeyColumns = dataSchema.size() - _numAggregations;
_keyExtractor = new KeyExtractor(_numKeyColumns);
@@ -24,9 +24,9 @@
import java.util.List;
import java.util.Map;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -48,13 +48,13 @@ public class SimpleIndexedTable extends IndexedTable {
/**
* Initializes the data structures needed for this Table
* @param dataSchema data schema of the record's keys and values
* @param aggregationInfos aggregation infos for the aggregations in record'd values
* @param aggregationFunctions aggregation functions for the record's values
* @param orderBy list of {@link SelectionSort} defining the order by
* @param capacity the capacity of the table
*/
public SimpleIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
int capacity) {
super(dataSchema, aggregationInfos, orderBy, capacity);
public SimpleIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
List<SelectionSort> orderBy, int capacity) {
super(dataSchema, aggregationFunctions, orderBy, capacity);

_lookupMap = new HashMap<>();
}
