Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.common.base.Preconditions;
import java.util.List;
import org.apache.commons.lang3.EnumUtils;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
Expand All @@ -46,8 +47,10 @@ public PinotQuery rewrite(PinotQuery pinotQuery) {

/**
* This method converts an expression to what Pinot could evaluate.
* 1. For comparison expression, left operand could be any expression, but right operand only
* supports literal. E.g. 'WHERE a > b' will be converted to 'WHERE a - b > 0'
* 1. For comparison expressions, the right operand should be a literal. If the left operand is a
* literal and the right is not, they are swapped. If the right operand is still non-literal
* (column-to-column comparison), the predicate is rewritten using a comparison transform
* function: e.g. 'WHERE a = b' becomes 'WHERE equals(a, b) = true'.
* 2. Updates boolean predicates (literals and scalar functions) that are missing an EQUALS filter.
* E.g. 1: 'WHERE a' will be updated to 'WHERE a = true'
* E.g. 2: "WHERE startsWith(col, 'str')" will be updated to "WHERE startsWith(col, 'str') = true"
Expand Down Expand Up @@ -115,11 +118,12 @@ private static Expression updateFunctionExpression(Expression expression) {
break;
}

// Handle predicate like 'a > b' -> 'a - b > 0'
if (!secondOperand.isSetLiteral()) {
Expression minusExpression = RequestUtils.getFunctionExpression("minus", firstOperand, secondOperand);
operands.set(0, minusExpression);
operands.set(1, RequestUtils.getLiteralExpression(0));
Expression comparisonExpression = RequestUtils.getFunctionExpression(
getComparisonFunctionName(filterKind), firstOperand, secondOperand);
function.setOperator(FilterKind.EQUALS.name());
operands.set(0, comparisonExpression);
operands.set(1, RequestUtils.getLiteralExpression(true));
break;
}
break;
Expand Down Expand Up @@ -204,6 +208,25 @@ private static Expression convertPredicateToEqualsBooleanExpression(Expression e
RequestUtils.getLiteralExpression(true));
}

private static String getComparisonFunctionName(FilterKind filterKind) {
switch (filterKind) {
case EQUALS:
return TransformFunctionType.EQUALS.getName();
case NOT_EQUALS:
return TransformFunctionType.NOT_EQUALS.getName();
case GREATER_THAN:
return TransformFunctionType.GREATER_THAN.getName();
case GREATER_THAN_OR_EQUAL:
return TransformFunctionType.GREATER_THAN_OR_EQUAL.getName();
case LESS_THAN:
return TransformFunctionType.LESS_THAN.getName();
case LESS_THAN_OR_EQUAL:
return TransformFunctionType.LESS_THAN_OR_EQUAL.getName();
default:
throw new IllegalStateException("Unsupported comparison operator: " + filterKind);
}
}

/**
* The purpose of this method is to convert expression "0 < columnA" to "columnA > 0".
* The conversion would be:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,13 @@ public void testFilterClauses() {
public void testFilterClausesWithRightExpression() {
PinotQuery pinotQuery = compileToPinotQuery("select * from vegetables where a > b");
Function func = pinotQuery.getFilterExpression().getFunctionCall();
Assert.assertEquals(func.getOperator(), FilterKind.GREATER_THAN.name());
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "minus");
Assert.assertEquals(func.getOperator(), FilterKind.EQUALS.name());
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "greater_than");
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"a");
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperands().get(1).getIdentifier().getName(),
"b");
Assert.assertEquals(func.getOperands().get(1).getLiteral().getIntValue(), 0);
Assert.assertTrue(func.getOperands().get(1).getLiteral().getBoolValue());
pinotQuery = compileToPinotQuery("select * from vegetables where 0 < a-b");
func = pinotQuery.getFilterExpression().getFunctionCall();
Assert.assertEquals(func.getOperator(), FilterKind.GREATER_THAN.name());
Expand All @@ -511,8 +511,8 @@ public void testFilterClausesWithRightExpression() {

pinotQuery = compileToPinotQuery("select * from vegetables where b < 100 + c");
func = pinotQuery.getFilterExpression().getFunctionCall();
Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN.name());
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "minus");
Assert.assertEquals(func.getOperator(), FilterKind.EQUALS.name());
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "less_than");
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"b");
Assert.assertEquals(
Expand All @@ -523,7 +523,7 @@ public void testFilterClausesWithRightExpression() {
Assert.assertEquals(
func.getOperands().get(0).getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(1)
.getIdentifier().getName(), "c");
Assert.assertEquals(func.getOperands().get(1).getLiteral().getIntValue(), 0);
Assert.assertTrue(func.getOperands().get(1).getLiteral().getBoolValue());
pinotQuery = compileToPinotQuery("select * from vegetables where b -(100+c)< 0");
func = pinotQuery.getFilterExpression().getFunctionCall();
Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN.name());
Expand All @@ -542,8 +542,8 @@ public void testFilterClausesWithRightExpression() {

pinotQuery = compileToPinotQuery("select * from vegetables where foo1(bar1(a-b)) <= foo2(bar2(c+d))");
func = pinotQuery.getFilterExpression().getFunctionCall();
Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN_OR_EQUAL.name());
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "minus");
Assert.assertEquals(func.getOperator(), FilterKind.EQUALS.name());
Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "less_than_or_equal");
Assert.assertEquals(
func.getOperands().get(0).getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "foo1");
Assert.assertEquals(
Expand Down Expand Up @@ -576,7 +576,7 @@ public void testFilterClausesWithRightExpression() {
func.getOperands().get(0).getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0)
.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(1).getIdentifier().getName(),
"d");
Assert.assertEquals(func.getOperands().get(1).getLiteral().getIntValue(), 0);
Assert.assertTrue(func.getOperands().get(1).getLiteral().getBoolValue());
pinotQuery = compileToPinotQuery("select * from vegetables where foo1(bar1(a-b)) - foo2(bar2(c+d)) <= 0");
func = pinotQuery.getFilterExpression().getFunctionCall();
Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN_OR_EQUAL.name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,56 +125,78 @@ public void testFilterPredicateLiteralIdentifierSwap() {

@Test
public void testFilterPredicateColumnComparisonRewrite() {
// Filters like 'col1 = col2' should be rewritten to 'col1 - col2 = 0'

PinotQuery pinotQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 = col2 AND col3 < col4;");
assertEquals(pinotQuery.getFilterExpression().getFunctionCall().getOperator(), "AND");
assertEquals(pinotQuery.getFilterExpression().getFunctionCall().getOperands().size(), 2);
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"EQUALS");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0)
.getIdentifier().getName(), "col1");
// col1 = col2 should be rewritten to equals(col1, col2) = true
PinotQuery equalsQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 = col2");
PinotQuery rewrittenEquals = _predicateComparisonRewriter.rewrite(equalsQuery);
assertEquals(rewrittenEquals.getFilterExpression().getFunctionCall().getOperator(), "EQUALS");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(1)
.getIdentifier().getName(), "col2");

PinotQuery rewrittenQuery = _predicateComparisonRewriter.rewrite(pinotQuery);
assertEquals(rewrittenQuery.getFilterExpression().getFunctionCall().getOperator(), "AND");
assertEquals(rewrittenQuery.getFilterExpression().getFunctionCall().getOperands().size(), 2);
rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"equals");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"EQUALS");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0)
.getFunctionCall().getOperator(), "minus");
rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands()
.get(0).getIdentifier().getName(), "col1");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0)
.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "col1");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0)
.getFunctionCall().getOperands().get(1).getIdentifier().getName(), "col2");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(1)
.getLiteral().getIntValue(), 0);
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperator(),
"LESS_THAN");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0)
.getFunctionCall().getOperator(), "minus");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0)
.getFunctionCall().getOperands().get(0).getIdentifier().getName(), "col3");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0)
.getFunctionCall().getOperands().get(1).getIdentifier().getName(), "col4");
assertEquals(
pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(1)
.getLiteral().getIntValue(), 0);

rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands()
.get(1).getIdentifier().getName(), "col2");
assertTrue(
rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getBoolValue());

// col3 < col4 should be rewritten to less_than(col3, col4) = true
PinotQuery lessThanQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col3 < col4");
PinotQuery rewrittenLt = _predicateComparisonRewriter.rewrite(lessThanQuery);
assertEquals(rewrittenLt.getFilterExpression().getFunctionCall().getOperator(), "EQUALS");
assertEquals(
rewrittenLt.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"less_than");

// col1 != col2 should be rewritten to not_equals(col1, col2) = true
PinotQuery notEqualsQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 != col2");
PinotQuery rewrittenNeq = _predicateComparisonRewriter.rewrite(notEqualsQuery);
assertEquals(rewrittenNeq.getFilterExpression().getFunctionCall().getOperator(), "EQUALS");
assertEquals(
rewrittenNeq.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"not_equals");

// Function on LHS with column on RHS
PinotQuery functionRhsQuery = CalciteSqlParser.compileToPinotQueryWithoutRewrites(
"SELECT * FROM mytable WHERE json_extract_scalar(col1, '$.f', 'STRING', 'null') = col2");
PinotQuery rewrittenFunc = _predicateComparisonRewriter.rewrite(functionRhsQuery);
assertEquals(rewrittenFunc.getFilterExpression().getFunctionCall().getOperator(), "EQUALS");
assertEquals(
rewrittenFunc.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"equals");

// col5 >= col6 should be rewritten to greater_than_or_equal(col5, col6) = true
PinotQuery gteQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col5 >= col6");
PinotQuery rewrittenGte = _predicateComparisonRewriter.rewrite(gteQuery);
assertEquals(rewrittenGte.getFilterExpression().getFunctionCall().getOperator(), "EQUALS");
assertEquals(
rewrittenGte.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"greater_than_or_equal");

// col5 <= col6 should be rewritten to less_than_or_equal(col5, col6) = true
PinotQuery lteQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col5 <= col6");
PinotQuery rewrittenLte = _predicateComparisonRewriter.rewrite(lteQuery);
assertEquals(rewrittenLte.getFilterExpression().getFunctionCall().getOperator(), "EQUALS");
assertEquals(
rewrittenLte.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"less_than_or_equal");

// col7 > col8 should be rewritten to greater_than(col7, col8) = true
PinotQuery gtQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col7 > col8");
PinotQuery rewrittenGt = _predicateComparisonRewriter.rewrite(gtQuery);
assertEquals(rewrittenGt.getFilterExpression().getFunctionCall().getOperator(), "EQUALS");
assertEquals(
rewrittenGt.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(),
"greater_than");

// BETWEEN with non-literal bounds should still throw (not a comparison operator)
PinotQuery betweenQuery =
CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 BETWEEN col2 AND col3");
assertThrows(SqlCompilationException.class, () -> _predicateComparisonRewriter.rewrite(betweenQuery));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.response.broker.BrokerResponseNative;
import org.apache.pinot.spi.config.table.FieldConfig;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.config.table.TableType;
Expand All @@ -30,6 +31,8 @@
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.annotations.Test;

import static org.testng.Assert.assertTrue;


public class JsonExtractScalarTest extends BaseJsonQueryTest {

Expand Down Expand Up @@ -191,6 +194,61 @@ public void testNullAsDefaultValueWithNullHandlingDisabled() {
);
}

@Test(dataProvider = "allJsonColumns")
public void testJsonExtractScalarComparedToColumn(String column) {
// Before the fix, column-to-column comparisons were rewritten as minus(a, b) = 0,
// which forced numeric coercion and threw NumberFormatException on string values.

// json.name.last ("duck", "mouse", ...) never equals stringColumn ("daffy duck", "mickey mouse", ...)
// so this should return 0 rows — but the key assertion is that it doesn't throw.
BrokerResponseNative noMatchResponse = getBrokerResponseForOptimizedQuery(
"SELECT intColumn FROM testTable "
+ "WHERE json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') = stringColumn "
+ "LIMIT 10",
SCHEMA);
assertTrue(noMatchResponse.getExceptions() == null || noMatchResponse.getExceptions().isEmpty(),
"Query should not throw for string column-to-column comparison, got: " + noMatchResponse.getExceptions());
assertTrue(noMatchResponse.getResultTable().getRows().isEmpty());

// Same expression on both sides — every row matches, so LIMIT 3 should return 3 rows.
checkResult(
"SELECT intColumn FROM testTable "
+ "WHERE json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') "
+ "= json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') "
+ "LIMIT 3",
new Object[][]{{1}, {2}, {3}});

// Also verify != works (all rows should satisfy since last name != full name)
BrokerResponseNative neqResponse = getBrokerResponseForOptimizedQuery(
"SELECT intColumn FROM testTable "
+ "WHERE json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') != stringColumn "
+ "ORDER BY intColumn LIMIT 3",
SCHEMA);
assertTrue(neqResponse.getExceptions() == null || neqResponse.getExceptions().isEmpty(),
"Query should not throw for != column comparison, got: " + neqResponse.getExceptions());
assertTrue(neqResponse.getResultTable().getRows().size() == 3);

// Numeric: json.id (101, 111, 121, ...) never equals intColumn (1, 2, 3, ...)
BrokerResponseNative numericResponse = getBrokerResponseForOptimizedQuery(
"SELECT intColumn FROM testTable "
+ "WHERE json_extract_scalar(" + column + ", '$.id', 'INT', '0') = intColumn "
+ "LIMIT 10",
SCHEMA);
assertTrue(numericResponse.getExceptions() == null || numericResponse.getExceptions().isEmpty(),
"Query should not throw for numeric column-to-column comparison, got: " + numericResponse.getExceptions());
assertTrue(numericResponse.getResultTable().getRows().isEmpty());

// Numeric > comparison: intColumn (1-19) is always < json.id (101+)
BrokerResponseNative gtResponse = getBrokerResponseForOptimizedQuery(
"SELECT intColumn FROM testTable "
+ "WHERE json_extract_scalar(" + column + ", '$.id', 'INT', '0') > intColumn "
+ "ORDER BY intColumn LIMIT 3",
SCHEMA);
assertTrue(gtResponse.getExceptions() == null || gtResponse.getExceptions().isEmpty(),
"Query should not throw for > column comparison, got: " + gtResponse.getExceptions());
assertTrue(gtResponse.getResultTable().getRows().size() == 3);
}

@Test
public void testNullAsDefaultValueWithNullHandlingEnabled() {
checkResult(
Expand Down