Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up AggregationFunctionContext and use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions #5364

Merged
merged 2 commits into from
May 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -456,13 +456,9 @@ private void handleCaseSensitivity(BrokerRequest brokerRequest) {
for (AggregationInfo info : brokerRequest.getAggregationsInfo()) {
if (!info.getAggregationType().equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
// Always read from backward compatible api in AggregationFunctionUtils.
List<String> expressions = AggregationFunctionUtils.getAggregationExpressions(info);

List<String> newExpressions = new ArrayList<>(expressions.size());
for (String expression : expressions) {
newExpressions.add(fixColumnNameCase(actualTableName, expression));
}
info.setExpressions(newExpressions);
List<String> arguments = AggregationFunctionUtils.getArguments(info);
arguments.replaceAll(e -> fixColumnNameCase(actualTableName, e));
info.setExpressions(arguments);
}
}
if (brokerRequest.isSetGroupBy()) {
Expand Down Expand Up @@ -720,11 +716,10 @@ static void validateRequest(BrokerRequest brokerRequest, int queryResponseLimit)
throw new UnsupportedOperationException("DISTINCT with GROUP BY is currently not supported");
}
if (brokerRequest.isSetOrderBy()) {
List<String> columns = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo);
Set<String> set = new HashSet<>(columns);
Set<String> expressionSet = new HashSet<>(AggregationFunctionUtils.getArguments(aggregationInfo));
List<SelectionSort> orderByColumns = brokerRequest.getOrderBy();
for (SelectionSort selectionSort : orderByColumns) {
if (!set.contains(selectionSort.getColumn())) {
if (!expressionSet.contains(selectionSort.getColumn())) {
throw new UnsupportedOperationException(
"ORDER By should be only on some/all of the columns passed as arguments to DISTINCT");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,17 @@ public static AggregationFunctionType getAggregationFunctionType(String function
String upperCaseFunctionName = functionName.toUpperCase();
if (upperCaseFunctionName.startsWith("PERCENTILE")) {
String remainingFunctionName = upperCaseFunctionName.substring(10);
if (remainingFunctionName.matches("\\d+")) {
if (remainingFunctionName.isEmpty() || remainingFunctionName.matches("\\d+")) {
return PERCENTILE;
} else if (remainingFunctionName.matches("EST\\d+")) {
} else if (remainingFunctionName.equals("EST") || remainingFunctionName.matches("EST\\d+")) {
return PERCENTILEEST;
} else if (remainingFunctionName.matches("TDIGEST\\d+")) {
} else if (remainingFunctionName.equals("TDIGEST") || remainingFunctionName.matches("TDIGEST\\d+")) {
return PERCENTILETDIGEST;
} else if (remainingFunctionName.matches("\\d+MV")) {
} else if (remainingFunctionName.equals("MV") || remainingFunctionName.matches("\\d+MV")) {
return PERCENTILEMV;
} else if (remainingFunctionName.matches("EST\\d+MV")) {
} else if (remainingFunctionName.equals("ESTMV") || remainingFunctionName.matches("EST\\d+MV")) {
return PERCENTILEESTMV;
} else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
} else if (remainingFunctionName.equals("TDIGESTMV") || remainingFunctionName.matches("TDIGEST\\d+MV")) {
return PERCENTILETDIGESTMV;
} else {
throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,21 @@
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nonnull;
import org.apache.pinot.spi.utils.EqualityUtils;
import javax.annotation.Nullable;
import org.apache.pinot.pql.parsers.Pql2Compiler;
import org.apache.pinot.pql.parsers.pql2.ast.AstNode;
import org.apache.pinot.pql.parsers.pql2.ast.FunctionCallAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.LiteralAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.StringLiteralAstNode;
import org.apache.pinot.spi.utils.EqualityUtils;


/**
* Class for representing expression trees for transforms.
* <ul>
* <li>A TransformExpressionTree node has either transform function or a column name, or a literal.</li>
* <li>Leaf nodes either have column name or literal, whereas non-leaf nodes have transform function.</li>
* <li>Transform function in non-leaf nodes is applied to its children nodes.</li>
* <li>Transform function is applied to its children.</li>
* </ul>
*/
public class TransformExpressionTree {
Expand Down Expand Up @@ -66,10 +65,9 @@ public static String getStandardExpression(AstNode astNode) {
} else if (astNode instanceof FunctionCallAstNode) {
// UDF expression
return standardizeExpression(((FunctionCallAstNode) astNode).getExpression());
} else if (astNode instanceof StringLiteralAstNode) {
// Treat string as column name
// NOTE: this is for backward-compatibility
return ((StringLiteralAstNode) astNode).getText();
} else if (astNode instanceof LiteralAstNode) {
// Literal
return ((LiteralAstNode) astNode).getValueAsString();
} else {
throw new IllegalStateException("Cannot get standard expression from " + astNode.getClass().getSimpleName());
}
Expand Down Expand Up @@ -106,6 +104,13 @@ public TransformExpressionTree(AstNode root) {
}
}

/**
 * Builds a node directly from its individual parts.
 *
 * @param expressionType Expression type assigned to this node
 * @param value Value assigned to this node
 * @param children Children of this node, may be {@code null}; the given list is stored as-is without copying
 */
public TransformExpressionTree(ExpressionType expressionType, String value,
    @Nullable List<TransformExpressionTree> children) {
  this._children = children;
  this._value = value;
  this._expressionType = expressionType;
}

/**
* Returns the expression type of the node, which can be one of the following:
* <ul>
Expand Down Expand Up @@ -168,7 +173,7 @@ public boolean isLiteral() {
*
* @param columns Output columns
*/
public void getColumns(@Nonnull Set<String> columns) {
public void getColumns(Set<String> columns) {
if (_expressionType == ExpressionType.IDENTIFIER) {
columns.add(_value);
} else if (_children != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ private AggregationInfo buildAggregationInfo(Function function) {
private String getColumnExpression(Expression functionParam) {
switch (functionParam.getType()) {
case LITERAL:
return functionParam.getLiteral().getStringValue();
return functionParam.getLiteral().getFieldValue().toString();
case IDENTIFIER:
return functionParam.getIdentifier().getName();
case FUNCTION:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.pinot.common.request.Selection;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataTable;
import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.util.QueryOptions;
Expand Down Expand Up @@ -101,9 +100,8 @@ public static DataTable buildEmptyDataTable(BrokerRequest brokerRequest)
}

// Aggregation query.
AggregationFunctionContext[] aggregationFunctionContexts =
AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
int numAggregations = aggregationFunctionContexts.length;
AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
int numAggregations = aggregationFunctions.length;
if (brokerRequest.isSetGroupBy()) {
// Aggregation group-by query.

Expand All @@ -121,9 +119,9 @@ public static DataTable buildEmptyDataTable(BrokerRequest brokerRequest)
columnDataTypes[index] = DataSchema.ColumnDataType.STRING;
index++;
}
for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
columnNames[index] = aggregationFunctionContext.getResultColumnName();
AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
for (AggregationFunction aggregationFunction : aggregationFunctions) {
// NOTE: Use AggregationFunction.getResultColumnName() for SQL format response
columnNames[index] = aggregationFunction.getResultColumnName();
columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
index++;
}
Expand All @@ -137,9 +135,10 @@ public static DataTable buildEmptyDataTable(BrokerRequest brokerRequest)

// Build the data table.
DataTableBuilder dataTableBuilder = new DataTableBuilder(new DataSchema(columnNames, columnDataTypes));
for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
for (AggregationFunction aggregationFunction : aggregationFunctions) {
dataTableBuilder.startRow();
dataTableBuilder.setColumn(0, aggregationFunctionContext.getAggregationColumnName());
// NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for PQL format response
dataTableBuilder.setColumn(0, aggregationFunction.getColumnName());
dataTableBuilder.setColumn(1, Collections.emptyMap());
dataTableBuilder.finishRow();
}
Expand All @@ -152,9 +151,9 @@ public static DataTable buildEmptyDataTable(BrokerRequest brokerRequest)
DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numAggregations];
Object[] aggregationResults = new Object[numAggregations];
for (int i = 0; i < numAggregations; i++) {
AggregationFunctionContext aggregationFunctionContext = aggregationFunctionContexts[i];
aggregationColumnNames[i] = aggregationFunctionContext.getAggregationColumnName();
AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
AggregationFunction aggregationFunction = aggregationFunctions[i];
// NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for aggregation only query
aggregationColumnNames[i] = aggregationFunction.getColumnName();
columnDataTypes[i] = aggregationFunction.getIntermediateResultColumnType();
aggregationResults[i] =
aggregationFunction.extractAggregationResult(aggregationFunction.createAggregationResultHolder());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,20 @@
import java.util.Iterator;
import java.util.List;
import org.apache.commons.collections.CollectionUtils;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;


/**
* Base abstract implementation of Table
*/
public abstract class BaseTable implements Table {

final AggregationFunction[] _aggregationFunctions;
final int _numAggregations;
// TODO: After fixing the DistinctTable logic, make it final
protected DataSchema _dataSchema;
final int _numColumns;
protected final int _numColumns;
protected final AggregationFunction[] _aggregationFunctions;
protected final int _numAggregations;

// the capacity we need to trim to
protected int _capacity;
Expand All @@ -46,34 +44,27 @@ public abstract class BaseTable implements Table {
protected boolean _isOrderBy;
protected TableResizer _tableResizer;

private final List<AggregationInfo> _aggregationInfos;

/**
* Initializes the variables and comparators needed for the table
*/
public BaseTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
public BaseTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
int capacity) {
_dataSchema = dataSchema;
_numColumns = dataSchema.size();

_numAggregations = aggregationInfos.size();
_aggregationFunctions = new AggregationFunction[_numAggregations];
for (int i = 0; i < _numAggregations; i++) {
_aggregationFunctions[i] =
AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfos.get(i)).getAggregationFunction();
}

_aggregationInfos = aggregationInfos;
_aggregationFunctions = aggregationFunctions;
_numAggregations = aggregationFunctions.length;
addCapacityAndOrderByInfo(orderBy, capacity);
}

protected void addCapacityAndOrderByInfo(List<SelectionSort> orderBy, int capacity) {
_isOrderBy = CollectionUtils.isNotEmpty(orderBy);
if (_isOrderBy) {
_tableResizer = new TableResizer(_dataSchema, _aggregationInfos, orderBy);
_tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, orderBy);

// TODO: tune these numbers and come up with a better formula (github ISSUE-4801)
// Based on the capacity and maxCapacity, the resizer will smartly choose to evict/retain records from the PQ
if (capacity <= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
if (capacity
<= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
_maxCapacity = 1_000_000;
} else { // Capacity is large, make buffer only slightly bigger. Make PQ of records to evict, during resize
_maxCapacity = (int) (capacity * 1.2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -52,13 +52,13 @@ public class ConcurrentIndexedTable extends IndexedTable {
/**
* Initializes the data structures needed for this Table
* @param dataSchema data schema of the record's keys and values
* @param aggregationInfos aggregation infos for the aggregations in record's values
* @param aggregationFunctions aggregation functions for the record's values
* @param orderBy list of {@link SelectionSort} defining the order by
* @param capacity the capacity of the table
*/
public ConcurrentIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
int capacity) {
super(dataSchema, aggregationInfos, orderBy, capacity);
public ConcurrentIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
List<SelectionSort> orderBy, int capacity) {
super(dataSchema, aggregationFunctions, orderBy, capacity);

_lookupMap = new ConcurrentHashMap<>();
_readWriteLock = new ReentrantReadWriteLock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

import java.util.Arrays;
import java.util.List;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;


/**
Expand All @@ -36,8 +36,9 @@ public abstract class IndexedTable extends BaseTable {
/**
* Initializes the variables and comparators needed for the table
*/
IndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
super(dataSchema, aggregationInfos, orderBy, capacity);
IndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
int capacity) {
super(dataSchema, aggregationFunctions, orderBy, capacity);

_numKeyColumns = dataSchema.size() - _numAggregations;
_keyExtractor = new KeyExtractor(_numKeyColumns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import java.util.List;
import java.util.Map;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -48,13 +48,13 @@ public class SimpleIndexedTable extends IndexedTable {
/**
* Initializes the data structures needed for this Table
* @param dataSchema data schema of the record's keys and values
* @param aggregationInfos aggregation infos for the aggregations in record'd values
* @param aggregationFunctions aggregation functions for the record's values
* @param orderBy list of {@link SelectionSort} defining the order by
* @param capacity the capacity of the table
*/
public SimpleIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
int capacity) {
super(dataSchema, aggregationInfos, orderBy, capacity);
public SimpleIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
List<SelectionSort> orderBy, int capacity) {
super(dataSchema, aggregationFunctions, orderBy, capacity);

_lookupMap = new HashMap<>();
}
Expand Down