Skip to content

Commit

Permalink
Handling a no-arg function in query parsing and expression tree (#5375)
Browse files Browse the repository at this point in the history
* Handling a no-arg function in query parsing and expression tree

* Addressing comments and added tests that uncovered few more places where we assumed functions have arguments
  • Loading branch information
kishoreg committed May 15, 2020
1 parent 6f840db commit 2e39a2e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ public TransformExpressionTree(AstNode root) {
_expressionType = ExpressionType.FUNCTION;
_value = ((FunctionCallAstNode) root).getName().toLowerCase();
_children = new ArrayList<>();
for (AstNode child : root.getChildren()) {
_children.add(new TransformExpressionTree(child));
if(root.hasChildren()) {
for (AstNode child : root.getChildren()) {
_children.add(new TransformExpressionTree(child));
}
}
} else if (root instanceof IdentifierAstNode) {
_expressionType = ExpressionType.IDENTIFIER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class CalciteSqlParser {

private static final Logger LOGGER = LoggerFactory.getLogger(CalciteSqlParser.class);
Expand Down Expand Up @@ -115,12 +116,15 @@ private static void validateSelectionClause(Map<Identifier, Expression> aliasMap

private static void matchIdentifierInAliasMap(Expression selectExpr, Set<String> aliasKeys)
throws SqlCompilationException {
if (selectExpr.getFunctionCall() != null) {
if (selectExpr.getFunctionCall().getOperator().equalsIgnoreCase(SqlKind.AS.toString())) {
matchIdentifierInAliasMap(selectExpr.getFunctionCall().getOperands().get(0), aliasKeys);
Function functionCall = selectExpr.getFunctionCall();
if (functionCall != null) {
if (functionCall.getOperator().equalsIgnoreCase(SqlKind.AS.toString())) {
matchIdentifierInAliasMap(functionCall.getOperands().get(0), aliasKeys);
} else {
for (Expression operand : selectExpr.getFunctionCall().getOperands()) {
matchIdentifierInAliasMap(operand, aliasKeys);
if (functionCall.getOperandsSize() > 0) {
for (Expression operand : functionCall.getOperands()) {
matchIdentifierInAliasMap(operand, aliasKeys);
}
}
}
}
Expand Down Expand Up @@ -169,16 +173,19 @@ private static void validateGroupByClause(PinotQuery pinotQuery)
}

private static boolean isAggregateExpression(Expression expression) {
if (expression.getFunctionCall() != null) {
String operator = expression.getFunctionCall().getOperator();
Function functionCall = expression.getFunctionCall();
if (functionCall != null) {
String operator = functionCall.getOperator();
try {
AggregationFunctionType.getAggregationFunctionType(operator);
return true;
} catch (IllegalArgumentException e) {
}
for (Expression operand : expression.getFunctionCall().getOperands()) {
if (isAggregateExpression(operand)) {
return true;
if (functionCall.getOperandsSize() > 0) {
for (Expression operand : functionCall.getOperands()) {
if (isAggregateExpression(operand)) {
return true;
}
}
}
}
Expand Down Expand Up @@ -291,7 +298,7 @@ private static PinotQuery compileCalciteSqlToPinotQuery(String sql) {
throw new RuntimeException(
"Unable to convert SqlNode: " + sqlNode + " to PinotQuery. Unknown node type: " + sqlNode.getKind());
}
queryReWrite(pinotQuery);
queryRewrite(pinotQuery);
return pinotQuery;
}

Expand All @@ -307,7 +314,7 @@ private static SqlParser getSqlParser(String sql) {
return SqlParser.create(sql, parserBuilder.build());
}

private static void queryReWrite(PinotQuery pinotQuery) {
private static void queryRewrite(PinotQuery pinotQuery) {
// Update Predicate Comparison
if (pinotQuery.isSetFilterExpression()) {
Expression filterExpression = pinotQuery.getFilterExpression();
Expand Down Expand Up @@ -359,10 +366,11 @@ private static Expression updateComparisonPredicate(Expression expression) {
comparisonFunction.getFunctionCall().setOperands(exprList);
return comparisonFunction;
default:
List<Expression> operands = functionCall.getOperands();
List<Expression> newOperands = new ArrayList<>();
for (int i = 0; i < operands.size(); i++) {
newOperands.add(updateComparisonPredicate(operands.get(i)));
int operandsSize = functionCall.getOperandsSize();
for (int i = 0; i < operandsSize; i++) {
Expression operand = functionCall.getOperands().get(i);
newOperands.add(updateComparisonPredicate(operand));
}
functionCall.setOperands(newOperands);
}
Expand Down Expand Up @@ -433,7 +441,7 @@ private static void applyAlias(Map<Identifier, Expression> aliasMap, Expression
expression.setType(aliasExpression.getType()).setIdentifier(aliasExpression.getIdentifier())
.setFunctionCall(aliasExpression.getFunctionCall()).setLiteral(aliasExpression.getLiteral());
}
if (expression.getFunctionCall() != null) {
if (expression.getFunctionCall() != null && expression.getFunctionCall().getOperandsSize() > 0) {
for (Expression operand : expression.getFunctionCall().getOperands()) {
applyAlias(aliasMap, operand);
}
Expand Down Expand Up @@ -590,8 +598,9 @@ private static Expression toExpression(SqlNode node) {
if (funcSqlNode.getOperator().getKind() == SqlKind.OTHER_FUNCTION) {
funcName = funcSqlNode.getOperator().getName();
}
if (funcName.equalsIgnoreCase(SqlKind.COUNT.toString()) && (funcSqlNode.getFunctionQuantifier() != null) && funcSqlNode
.getFunctionQuantifier().toValue().equalsIgnoreCase(AggregationFunctionType.DISTINCT.getName())) {
if (funcName.equalsIgnoreCase(SqlKind.COUNT.toString()) && (funcSqlNode.getFunctionQuantifier() != null)
&& funcSqlNode.getFunctionQuantifier().toValue()
.equalsIgnoreCase(AggregationFunctionType.DISTINCT.getName())) {
funcName = AggregationFunctionType.DISTINCTCOUNT.getName();
}
final Expression funcExpr = RequestUtils.getFunctionExpression(funcName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ public void testUpperCase() {
Assert.assertTrue(equalsWithStandardExpressionTree(TransformExpressionTree.compileToExpressionTree(expression)));
}

@Test
public void testNoArgFunction() {
String expression = "now()";
TransformExpressionTree expressionTree = TransformExpressionTree.compileToExpressionTree(expression);
Assert.assertEquals(expressionTree.isFunction(), true);
Assert.assertEquals(expressionTree.getValue(), "now");
Assert.assertEquals(expressionTree.getChildren().size(), 0);
}

private static boolean equalsWithStandardExpressionTree(TransformExpressionTree expressionTree) {
return expressionTree.hashCode() == STANDARD_EXPRESSION_TREE.hashCode() && expressionTree
.equals(STANDARD_EXPRESSION_TREE) && expressionTree.toString().equals(STANDARD_EXPRESSION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1467,4 +1467,21 @@ public void testDistinctCountRewrite() {
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getIdentifier().getName(),
"distinct_bar");
}
}

@Test
public void testNoArgFunction() {
String query = "SELECT now() FROM foo ";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "now");

query = "SELECT a FROM foo where time_col > now()";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Function greaterThan = pinotQuery.getFilterExpression().getFunctionCall();
Function minus = greaterThan.getOperands().get(0).getFunctionCall();
Assert.assertEquals(minus.getOperands().get(1).getFunctionCall().getOperator(), "now");

query = "SELECT sum(a), now() FROM foo group by now()";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getGroupByList().get(0).getFunctionCall().getOperator(), "now");
}
}

0 comments on commit 2e39a2e

Please sign in to comment.