Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling a no-arg function in query parsing and expression tree #5375

Merged
merged 2 commits into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ public TransformExpressionTree(AstNode root) {
_expressionType = ExpressionType.FUNCTION;
_value = ((FunctionCallAstNode) root).getName().toLowerCase();
_children = new ArrayList<>();
for (AstNode child : root.getChildren()) {
_children.add(new TransformExpressionTree(child));
if(root.hasChildren()) {
for (AstNode child : root.getChildren()) {
_children.add(new TransformExpressionTree(child));
}
}
} else if (root instanceof IdentifierAstNode) {
_expressionType = ExpressionType.IDENTIFIER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class CalciteSqlParser {

private static final Logger LOGGER = LoggerFactory.getLogger(CalciteSqlParser.class);
Expand Down Expand Up @@ -115,12 +116,15 @@ private static void validateSelectionClause(Map<Identifier, Expression> aliasMap

private static void matchIdentifierInAliasMap(Expression selectExpr, Set<String> aliasKeys)
throws SqlCompilationException {
if (selectExpr.getFunctionCall() != null) {
if (selectExpr.getFunctionCall().getOperator().equalsIgnoreCase(SqlKind.AS.toString())) {
matchIdentifierInAliasMap(selectExpr.getFunctionCall().getOperands().get(0), aliasKeys);
Function functionCall = selectExpr.getFunctionCall();
if (functionCall != null) {
if (functionCall.getOperator().equalsIgnoreCase(SqlKind.AS.toString())) {
matchIdentifierInAliasMap(functionCall.getOperands().get(0), aliasKeys);
} else {
for (Expression operand : selectExpr.getFunctionCall().getOperands()) {
matchIdentifierInAliasMap(operand, aliasKeys);
if (functionCall.getOperandsSize() > 0) {
for (Expression operand : functionCall.getOperands()) {
matchIdentifierInAliasMap(operand, aliasKeys);
}
}
}
}
Expand Down Expand Up @@ -169,16 +173,19 @@ private static void validateGroupByClause(PinotQuery pinotQuery)
}

private static boolean isAggregateExpression(Expression expression) {
if (expression.getFunctionCall() != null) {
String operator = expression.getFunctionCall().getOperator();
Function functionCall = expression.getFunctionCall();
if (functionCall != null) {
String operator = functionCall.getOperator();
try {
AggregationFunctionType.getAggregationFunctionType(operator);
return true;
} catch (IllegalArgumentException e) {
}
for (Expression operand : expression.getFunctionCall().getOperands()) {
if (isAggregateExpression(operand)) {
return true;
if (functionCall.getOperandsSize() > 0) {
for (Expression operand : functionCall.getOperands()) {
if (isAggregateExpression(operand)) {
return true;
}
}
}
}
Expand Down Expand Up @@ -291,7 +298,7 @@ private static PinotQuery compileCalciteSqlToPinotQuery(String sql) {
throw new RuntimeException(
"Unable to convert SqlNode: " + sqlNode + " to PinotQuery. Unknown node type: " + sqlNode.getKind());
}
queryReWrite(pinotQuery);
queryRewrite(pinotQuery);
return pinotQuery;
}

Expand All @@ -307,7 +314,7 @@ private static SqlParser getSqlParser(String sql) {
return SqlParser.create(sql, parserBuilder.build());
}

private static void queryReWrite(PinotQuery pinotQuery) {
private static void queryRewrite(PinotQuery pinotQuery) {
// Update Predicate Comparison
if (pinotQuery.isSetFilterExpression()) {
Expression filterExpression = pinotQuery.getFilterExpression();
Expand Down Expand Up @@ -359,10 +366,11 @@ private static Expression updateComparisonPredicate(Expression expression) {
comparisonFunction.getFunctionCall().setOperands(exprList);
return comparisonFunction;
default:
List<Expression> operands = functionCall.getOperands();
List<Expression> newOperands = new ArrayList<>();
for (int i = 0; i < operands.size(); i++) {
newOperands.add(updateComparisonPredicate(operands.get(i)));
int operandsSize = functionCall.getOperandsSize();
for (int i = 0; i < operandsSize; i++) {
Expression operand = functionCall.getOperands().get(i);
newOperands.add(updateComparisonPredicate(operand));
}
functionCall.setOperands(newOperands);
}
Expand Down Expand Up @@ -433,7 +441,7 @@ private static void applyAlias(Map<Identifier, Expression> aliasMap, Expression
expression.setType(aliasExpression.getType()).setIdentifier(aliasExpression.getIdentifier())
.setFunctionCall(aliasExpression.getFunctionCall()).setLiteral(aliasExpression.getLiteral());
}
if (expression.getFunctionCall() != null) {
if (expression.getFunctionCall() != null && expression.getFunctionCall().getOperandsSize() > 0) {
for (Expression operand : expression.getFunctionCall().getOperands()) {
applyAlias(aliasMap, operand);
}
Expand Down Expand Up @@ -590,8 +598,9 @@ private static Expression toExpression(SqlNode node) {
if (funcSqlNode.getOperator().getKind() == SqlKind.OTHER_FUNCTION) {
funcName = funcSqlNode.getOperator().getName();
}
if (funcName.equalsIgnoreCase(SqlKind.COUNT.toString()) && (funcSqlNode.getFunctionQuantifier() != null) && funcSqlNode
.getFunctionQuantifier().toValue().equalsIgnoreCase(AggregationFunctionType.DISTINCT.getName())) {
if (funcName.equalsIgnoreCase(SqlKind.COUNT.toString()) && (funcSqlNode.getFunctionQuantifier() != null)
&& funcSqlNode.getFunctionQuantifier().toValue()
.equalsIgnoreCase(AggregationFunctionType.DISTINCT.getName())) {
funcName = AggregationFunctionType.DISTINCTCOUNT.getName();
}
final Expression funcExpr = RequestUtils.getFunctionExpression(funcName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ public void testUpperCase() {
Assert.assertTrue(equalsWithStandardExpressionTree(TransformExpressionTree.compileToExpressionTree(expression)));
}

@Test
public void testNoArgFunction() {
String expression = "now()";
TransformExpressionTree expressionTree = TransformExpressionTree.compileToExpressionTree(expression);
Assert.assertEquals(expressionTree.isFunction(), true);
Assert.assertEquals(expressionTree.getValue(), "now");
Assert.assertEquals(expressionTree.getChildren().size(), 0);
}

private static boolean equalsWithStandardExpressionTree(TransformExpressionTree expressionTree) {
return expressionTree.hashCode() == STANDARD_EXPRESSION_TREE.hashCode() && expressionTree
.equals(STANDARD_EXPRESSION_TREE) && expressionTree.toString().equals(STANDARD_EXPRESSION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1467,4 +1467,21 @@ public void testDistinctCountRewrite() {
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getIdentifier().getName(),
"distinct_bar");
}
}

@Test
public void testNoArgFunction() {
String query = "SELECT now() FROM foo ";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "now");

query = "SELECT a FROM foo where time_col > now()";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Function greaterThan = pinotQuery.getFilterExpression().getFunctionCall();
Function minus = greaterThan.getOperands().get(0).getFunctionCall();
Assert.assertEquals(minus.getOperands().get(1).getFunctionCall().getOperator(), "now");

query = "SELECT sum(a), now() FROM foo group by now()";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getGroupByList().get(0).getFunctionCall().getOperator(), "now");
}
}