diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java index 5b9fff5659f5..208955e55f17 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java @@ -21,6 +21,7 @@ import com.google.common.base.Preconditions; import java.util.List; import org.apache.commons.lang3.EnumUtils; +import org.apache.pinot.common.function.TransformFunctionType; import org.apache.pinot.common.request.Expression; import org.apache.pinot.common.request.ExpressionType; import org.apache.pinot.common.request.Function; @@ -46,8 +47,10 @@ public PinotQuery rewrite(PinotQuery pinotQuery) { /** * This method converts an expression to what Pinot could evaluate. - * 1. For comparison expression, left operand could be any expression, but right operand only - * supports literal. E.g. 'WHERE a > b' will be converted to 'WHERE a - b > 0' + * 1. For comparison expressions, the right operand should be a literal. If the left operand is a + * literal and the right is not, they are swapped. If the right operand is still non-literal + * (column-to-column comparison), the predicate is rewritten using a comparison transform + * function: e.g. 'WHERE a = b' becomes 'WHERE equals(a, b) = true'. * 2. Updates boolean predicates (literals and scalar functions) that are missing an EQUALS filter. * E.g. 1: 'WHERE a' will be updated to 'WHERE a = true' * E.g. 2: "WHERE startsWith(col, 'str')" will be updated to "WHERE startsWith(col, 'str') = true" @@ -115,11 +118,12 @@ private static Expression updateFunctionExpression(Expression expression) { break; } - // Handle predicate like 'a > b' -> 'a - b > 0' if (!secondOperand.isSetLiteral()) { - Expression minusExpression = RequestUtils.getFunctionExpression("minus", firstOperand, secondOperand); - operands.set(0, minusExpression); - operands.set(1, RequestUtils.getLiteralExpression(0)); + Expression comparisonExpression = RequestUtils.getFunctionExpression( + getComparisonFunctionName(filterKind), firstOperand, secondOperand); + function.setOperator(FilterKind.EQUALS.name()); + operands.set(0, comparisonExpression); + operands.set(1, RequestUtils.getLiteralExpression(true)); break; } break; @@ -204,6 +208,25 @@ private static Expression convertPredicateToEqualsBooleanExpression(Expression e RequestUtils.getLiteralExpression(true)); } + private static String getComparisonFunctionName(FilterKind filterKind) { + switch (filterKind) { + case EQUALS: + return TransformFunctionType.EQUALS.getName(); + case NOT_EQUALS: + return TransformFunctionType.NOT_EQUALS.getName(); + case GREATER_THAN: + return TransformFunctionType.GREATER_THAN.getName(); + case GREATER_THAN_OR_EQUAL: + return TransformFunctionType.GREATER_THAN_OR_EQUAL.getName(); + case LESS_THAN: + return TransformFunctionType.LESS_THAN.getName(); + case LESS_THAN_OR_EQUAL: + return TransformFunctionType.LESS_THAN_OR_EQUAL.getName(); + default: + throw new IllegalStateException("Unsupported comparison operator: " + filterKind); + } + } + /** * The purpose of this method is to convert expression "0 < columnA" to "columnA > 0". * The conversion would be: diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java index 532b28ab0bd5..7e34bdabed62 100644 --- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java @@ -492,13 +492,13 @@ public void testFilterClauses() { public void testFilterClausesWithRightExpression() { PinotQuery pinotQuery = compileToPinotQuery("select * from vegetables where a > b"); Function func = pinotQuery.getFilterExpression().getFunctionCall(); - Assert.assertEquals(func.getOperator(), FilterKind.GREATER_THAN.name()); - Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "minus"); + Assert.assertEquals(func.getOperator(), FilterKind.EQUALS.name()); + Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "greater_than"); Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(), "a"); Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperands().get(1).getIdentifier().getName(), "b"); - Assert.assertEquals(func.getOperands().get(1).getLiteral().getIntValue(), 0); + Assert.assertTrue(func.getOperands().get(1).getLiteral().getBoolValue()); pinotQuery = compileToPinotQuery("select * from vegetables where 0 < a-b"); func = pinotQuery.getFilterExpression().getFunctionCall(); Assert.assertEquals(func.getOperator(), FilterKind.GREATER_THAN.name()); @@ -511,8 +511,8 @@ public void testFilterClausesWithRightExpression() { pinotQuery = compileToPinotQuery("select * from vegetables where b < 100 + c"); func = pinotQuery.getFilterExpression().getFunctionCall(); - Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN.name()); - Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "minus"); + Assert.assertEquals(func.getOperator(), FilterKind.EQUALS.name()); + Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "less_than"); Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(), "b"); Assert.assertEquals( @@ -523,7 +523,7 @@ public void testFilterClausesWithRightExpression() { Assert.assertEquals( func.getOperands().get(0).getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(1) .getIdentifier().getName(), "c"); - Assert.assertEquals(func.getOperands().get(1).getLiteral().getIntValue(), 0); + Assert.assertTrue(func.getOperands().get(1).getLiteral().getBoolValue()); pinotQuery = compileToPinotQuery("select * from vegetables where b -(100+c)< 0"); func = pinotQuery.getFilterExpression().getFunctionCall(); Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN.name()); @@ -542,8 +542,8 @@ public void testFilterClausesWithRightExpression() { pinotQuery = compileToPinotQuery("select * from vegetables where foo1(bar1(a-b)) <= foo2(bar2(c+d))"); func = pinotQuery.getFilterExpression().getFunctionCall(); - Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN_OR_EQUAL.name()); - Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "minus"); + Assert.assertEquals(func.getOperator(), FilterKind.EQUALS.name()); + Assert.assertEquals(func.getOperands().get(0).getFunctionCall().getOperator(), "less_than_or_equal"); Assert.assertEquals( func.getOperands().get(0).getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "foo1"); Assert.assertEquals( @@ -576,7 +576,7 @@ public void testFilterClausesWithRightExpression() { func.getOperands().get(0).getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0) .getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(1).getIdentifier().getName(), "d"); - Assert.assertEquals(func.getOperands().get(1).getLiteral().getIntValue(), 0); + Assert.assertTrue(func.getOperands().get(1).getLiteral().getBoolValue()); pinotQuery = compileToPinotQuery("select * from vegetables where foo1(bar1(a-b)) - foo2(bar2(c+d)) <= 0"); func = pinotQuery.getFilterExpression().getFunctionCall(); Assert.assertEquals(func.getOperator(), FilterKind.LESS_THAN_OR_EQUAL.name()); diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriterTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriterTest.java index 4a11a79d8760..63cff46ad1a2 100644 --- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriterTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriterTest.java @@ -125,56 +125,78 @@ public void testFilterPredicateLiteralIdentifierSwap() { @Test public void testFilterPredicateColumnComparisonRewrite() { - // Filters like 'col1 = col2' should be rewritten to 'col1 - col2 = 0' - - PinotQuery pinotQuery = - CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 = col2 AND col3 < col4;"); - assertEquals(pinotQuery.getFilterExpression().getFunctionCall().getOperator(), "AND"); - assertEquals(pinotQuery.getFilterExpression().getFunctionCall().getOperands().size(), 2); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), - "EQUALS"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0) - .getIdentifier().getName(), "col1"); + // col1 = col2 should be rewritten to equals(col1, col2) = true + PinotQuery equalsQuery = + CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 = col2"); + PinotQuery rewrittenEquals = _predicateComparisonRewriter.rewrite(equalsQuery); + assertEquals(rewrittenEquals.getFilterExpression().getFunctionCall().getOperator(), "EQUALS"); assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(1) - .getIdentifier().getName(), "col2"); - - PinotQuery rewrittenQuery = _predicateComparisonRewriter.rewrite(pinotQuery); - assertEquals(rewrittenQuery.getFilterExpression().getFunctionCall().getOperator(), "AND"); - assertEquals(rewrittenQuery.getFilterExpression().getFunctionCall().getOperands().size(), 2); + rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), + "equals"); assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), - "EQUALS"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0) - .getFunctionCall().getOperator(), "minus"); + rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands() + .get(0).getIdentifier().getName(), "col1"); assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0) - .getFunctionCall().getOperands().get(0).getIdentifier().getName(), "col1"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0) - .getFunctionCall().getOperands().get(1).getIdentifier().getName(), "col2"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(1) - .getLiteral().getIntValue(), 0); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperator(), - "LESS_THAN"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0) - .getFunctionCall().getOperator(), "minus"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0) - .getFunctionCall().getOperands().get(0).getIdentifier().getName(), "col3"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(0) - .getFunctionCall().getOperands().get(1).getIdentifier().getName(), "col4"); - assertEquals( - pinotQuery.getFilterExpression().getFunctionCall().getOperands().get(1).getFunctionCall().getOperands().get(1) - .getLiteral().getIntValue(), 0); - + rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperands() + .get(1).getIdentifier().getName(), "col2"); + assertTrue( + rewrittenEquals.getFilterExpression().getFunctionCall().getOperands().get(1).getLiteral().getBoolValue()); + + // col3 < col4 should be rewritten to less_than(col3, col4) = true + PinotQuery lessThanQuery = + CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col3 < col4"); + PinotQuery rewrittenLt = _predicateComparisonRewriter.rewrite(lessThanQuery); + assertEquals(rewrittenLt.getFilterExpression().getFunctionCall().getOperator(), "EQUALS"); + assertEquals( + rewrittenLt.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), + "less_than"); + + // col1 != col2 should be rewritten to not_equals(col1, col2) = true + PinotQuery notEqualsQuery = + CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 != col2"); + PinotQuery rewrittenNeq = _predicateComparisonRewriter.rewrite(notEqualsQuery); + assertEquals(rewrittenNeq.getFilterExpression().getFunctionCall().getOperator(), "EQUALS"); + assertEquals( + rewrittenNeq.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), + "not_equals"); + + // Function on LHS with column on RHS + PinotQuery functionRhsQuery = CalciteSqlParser.compileToPinotQueryWithoutRewrites( + "SELECT * FROM mytable WHERE json_extract_scalar(col1, '$.f', 'STRING', 'null') = col2"); + PinotQuery rewrittenFunc = _predicateComparisonRewriter.rewrite(functionRhsQuery); + assertEquals(rewrittenFunc.getFilterExpression().getFunctionCall().getOperator(), "EQUALS"); + assertEquals( + rewrittenFunc.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), + "equals"); + + // col5 >= col6 should be rewritten to greater_than_or_equal(col5, col6) = true + PinotQuery gteQuery = + CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col5 >= col6"); + PinotQuery rewrittenGte = _predicateComparisonRewriter.rewrite(gteQuery); + assertEquals(rewrittenGte.getFilterExpression().getFunctionCall().getOperator(), "EQUALS"); + assertEquals( + rewrittenGte.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), + "greater_than_or_equal"); + + // col5 <= col6 should be rewritten to less_than_or_equal(col5, col6) = true + PinotQuery lteQuery = + CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col5 <= col6"); + PinotQuery rewrittenLte = _predicateComparisonRewriter.rewrite(lteQuery); + assertEquals(rewrittenLte.getFilterExpression().getFunctionCall().getOperator(), "EQUALS"); + assertEquals( + rewrittenLte.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), + "less_than_or_equal"); + + // col7 > col8 should be rewritten to greater_than(col7, col8) = true + PinotQuery gtQuery = + CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col7 > col8"); + PinotQuery rewrittenGt = _predicateComparisonRewriter.rewrite(gtQuery); + assertEquals(rewrittenGt.getFilterExpression().getFunctionCall().getOperator(), "EQUALS"); + assertEquals( + rewrittenGt.getFilterExpression().getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), + "greater_than"); + + // BETWEEN with non-literal bounds should still throw (not a comparison operator) PinotQuery betweenQuery = CalciteSqlParser.compileToPinotQueryWithoutRewrites("SELECT * FROM mytable WHERE col1 BETWEEN col2 AND col3"); assertThrows(SqlCompilationException.class, () -> _predicateComparisonRewriter.rewrite(betweenQuery)); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/JsonExtractScalarTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/JsonExtractScalarTest.java index 5c7049369a01..a91236ef616a 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/JsonExtractScalarTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/JsonExtractScalarTest.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import org.apache.pinot.common.response.broker.BrokerResponseNative; import org.apache.pinot.spi.config.table.FieldConfig; import org.apache.pinot.spi.config.table.TableConfig; import org.apache.pinot.spi.config.table.TableType; @@ -30,6 +31,8 @@ import org.apache.pinot.spi.utils.builder.TableConfigBuilder; import org.testng.annotations.Test; +import static org.testng.Assert.assertTrue; + public class JsonExtractScalarTest extends BaseJsonQueryTest { @@ -191,6 +194,61 @@ public void testNullAsDefaultValueWithNullHandlingDisabled() { ); } + @Test(dataProvider = "allJsonColumns") + public void testJsonExtractScalarComparedToColumn(String column) { + // Before the fix, column-to-column comparisons were rewritten as minus(a, b) = 0, + // which forced numeric coercion and threw NumberFormatException on string values. + + // json.name.last ("duck", "mouse", ...) never equals stringColumn ("daffy duck", "mickey mouse", ...) + // so this should return 0 rows — but the key assertion is that it doesn't throw. + BrokerResponseNative noMatchResponse = getBrokerResponseForOptimizedQuery( + "SELECT intColumn FROM testTable " + + "WHERE json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') = stringColumn " + + "LIMIT 10", + SCHEMA); + assertTrue(noMatchResponse.getExceptions() == null || noMatchResponse.getExceptions().isEmpty(), + "Query should not throw for string column-to-column comparison, got: " + noMatchResponse.getExceptions()); + assertTrue(noMatchResponse.getResultTable().getRows().isEmpty()); + + // Same expression on both sides — every row matches, so LIMIT 3 should return 3 rows. + checkResult( + "SELECT intColumn FROM testTable " + + "WHERE json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') " + + "= json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') " + + "LIMIT 3", + new Object[][]{{1}, {2}, {3}}); + + // Also verify != works (all rows should satisfy since last name != full name) + BrokerResponseNative neqResponse = getBrokerResponseForOptimizedQuery( + "SELECT intColumn FROM testTable " + + "WHERE json_extract_scalar(" + column + ", '$.name.last', 'STRING', '') != stringColumn " + + "ORDER BY intColumn LIMIT 3", + SCHEMA); + assertTrue(neqResponse.getExceptions() == null || neqResponse.getExceptions().isEmpty(), + "Query should not throw for != column comparison, got: " + neqResponse.getExceptions()); + assertTrue(neqResponse.getResultTable().getRows().size() == 3); + + // Numeric: json.id (101, 111, 121, ...) never equals intColumn (1, 2, 3, ...) + BrokerResponseNative numericResponse = getBrokerResponseForOptimizedQuery( + "SELECT intColumn FROM testTable " + + "WHERE json_extract_scalar(" + column + ", '$.id', 'INT', '0') = intColumn " + + "LIMIT 10", + SCHEMA); + assertTrue(numericResponse.getExceptions() == null || numericResponse.getExceptions().isEmpty(), + "Query should not throw for numeric column-to-column comparison, got: " + numericResponse.getExceptions()); + assertTrue(numericResponse.getResultTable().getRows().isEmpty()); + + // Numeric > comparison: intColumn (1-19) is always < json.id (101+) + BrokerResponseNative gtResponse = getBrokerResponseForOptimizedQuery( + "SELECT intColumn FROM testTable " + + "WHERE json_extract_scalar(" + column + ", '$.id', 'INT', '0') > intColumn " + + "ORDER BY intColumn LIMIT 3", + SCHEMA); + assertTrue(gtResponse.getExceptions() == null || gtResponse.getExceptions().isEmpty(), + "Query should not throw for > column comparison, got: " + gtResponse.getExceptions()); + assertTrue(gtResponse.getResultTable().getRows().size() == 3); + } + @Test public void testNullAsDefaultValueWithNullHandlingEnabled() { checkResult(