Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] check compare optmization more strict (backport #42936) #43345

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ public static Type getCompatibleTypeForBinary(boolean isNotRangeComparison, Type
baseType = type1.isDecimalOfAnyVersion() ? type1 : type2;
}
}

if (ConnectContext.get() != null && SessionVariableConstants.DOUBLE.equalsIgnoreCase(ConnectContext.get()
.getSessionVariable().getCboEqBaseType())) {
baseType = Type.DOUBLE;
}
if ((type1.isStringType() && type2.isExactNumericType()) ||
(type1.isExactNumericType() && type2.isStringType())) {
return baseType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,13 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate,
}
}

<<<<<<< HEAD
Type compatibleType = TypeManager.getCompatibleTypeForBinary(
predicate.getBinaryType().isNotRangeComparison(), type1, type2);
=======
Type compatibleType =
TypeManager.getCompatibleTypeForBinary(predicate.getBinaryType(), type1, type2);
>>>>>>> a04b659a9a ([Enhancement] check compare optmization more strict (#42936))

if (!type1.matchesType(compatibleType)) {
addCastChild(compatibleType, predicate, 0);
Expand All @@ -163,23 +168,44 @@ private Optional<BinaryPredicateOperator> optimizeConstantAndVariable(BinaryPred
Type typeConstant = constant.getType();
Type typeVariable = variable.getType();

<<<<<<< HEAD
if (typeVariable.isStringType() && typeConstant.isExactNumericType()) {
if (ConnectContext.get() == null || "decimal".equalsIgnoreCase(
ConnectContext.get().getSessionVariable().getCboEqBaseType())) {
// don't optimize when cbo_eq_base_type is decimal
return Optional.empty();
=======
boolean checkStringCastToNumber = false;
if (typeVariable.isNumericType() && typeConstant.isStringType()) {
if (predicate.getBinaryType().isNotRangeComparison()) {
String baseType = ConnectContext.get() != null ?
ConnectContext.get().getSessionVariable().getCboEqBaseType() : SessionVariableConstants.VARCHAR;
checkStringCastToNumber = SessionVariableConstants.DECIMAL.equals(baseType) ||
SessionVariableConstants.DOUBLE.equals(baseType);
} else {
// range compare, base type must be double
checkStringCastToNumber = true;
>>>>>>> a04b659a9a ([Enhancement] check compare optmization more strict (#42936))
}
}

Optional<ScalarOperator> op = Utils.tryCastConstant(constant, variable.getType());
if (op.isPresent()) {
predicate.getChildren().set(constantIndex, op.get());
return Optional.of(predicate);
} else if (variable.getType().isDateType() && !constant.getType().isDateType() &&
Type.canCastTo(constant.getType(), variable.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(variable.getType(), predicate, constantIndex);
return Optional.of(predicate);
// strict check, only support white check
if ((typeVariable.isNumericType() && typeConstant.isNumericType()) ||
(typeVariable.isNumericType() && typeConstant.isBoolean()) ||
(typeVariable.isDateType() && typeConstant.isNumericType()) ||
(typeVariable.isDateType() && typeConstant.isStringType()) ||
(typeVariable.isBoolean() && typeConstant.isStringType()) ||
checkStringCastToNumber) {

Optional<ScalarOperator> op = Utils.tryCastConstant(constant, variable.getType());
if (op.isPresent()) {
predicate.getChildren().set(constantIndex, op.get());
return Optional.of(predicate);
} else if (variable.getType().isDateType() && Type.canCastTo(constant.getType(), variable.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(variable.getType(), predicate, constantIndex);
return Optional.of(predicate);
}
}

return Optional.empty();
Expand Down
106 changes: 106 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/ExpressionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1431,4 +1431,110 @@ public void testCastConstantFn() throws Exception {
String plan = getFragmentPlan(sql);
assertContains(plan, " CAST(9: id_date AS DOUBLE) = 1998.0");
}

@Test
public void testCastStringNumber() throws Exception {
try {
// string number
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
String sql = "select t1a = 1.2345 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "[1: t1a, VARCHAR, true] = '1.2345'");

sql = "select t1a < 1 from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DOUBLE) < 1.0");

// number string
sql = "select t1f = '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] = 123.345");

sql = "select t1c >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([3: t1c, INT, true] as DOUBLE) >= 123.345");

sql = "select t1e >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= 123.345");

sql = "select t1f <= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] <= 123.345");

sql = "select t1e >= 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= cast('abc' as DOUBLE)");

sql = "select t1g = 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([7: t1g, BIGINT, true] as VARCHAR(1048576)) = 'abc'");

sql = "select id_bool = 'abc' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([11: id_bool, BOOLEAN, true] as DOUBLE) = cast('abc' as DOUBLE)");

sql = "select id_bool = 'false' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[11: id_bool, BOOLEAN, true] = FALSE");

sql = "select t1g = true from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[7: t1g, BIGINT, true] = 1");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}

try {
connectContext.getSessionVariable().setCboEqBaseType("DECIMAL");
// string number
String sql = "select t1a = 1.2345 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DECIMAL32(5,4)) = 1.2345");

sql = "select t1a < 1 from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DOUBLE) < 1.0");

// number string
sql = "select t1f = '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] = 123.345");

sql = "select t1c >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([3: t1c, INT, true] as DOUBLE) >= 123.345");

sql = "select t1e >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= 123.345");

sql = "select t1f <= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] <= 123.345");

sql = "select t1e >= 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= cast('abc' as DOUBLE)");

sql = "select t1g = 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([7: t1g, BIGINT, true] as DECIMAL128(38,9)) " +
"= cast('abc' as DECIMAL128(38,9))");

sql = "select id_bool = 'abc' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([11: id_bool, BOOLEAN, true] as DOUBLE) = cast('abc' as DOUBLE)");

sql = "select id_bool = 'false' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[11: id_bool, BOOLEAN, true] = FALSE");

sql = "select t1g = true from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[7: t1g, BIGINT, true] = 1");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,9 @@ public RuleSet getRuleSet() {
Assert.assertTrue(replayPair.second, replayPair.second.contains(" 14:NESTLOOP JOIN\n" +
" | join op: INNER JOIN\n" +
" | colocate: false, reason: \n" +
" | other join predicates: CASE WHEN CAST(6: v3 AS BOOLEAN) THEN CAST(11: v2 AS VARCHAR) " +
"WHEN CAST(3: v3 AS BOOLEAN) THEN '123' ELSE CAST(12: v3 AS VARCHAR) END > '1', " +
"(2: v2 = CAST(8: v2 AS VARCHAR(1048576))) OR (3: v3 = 8: v2)\n"));
" | other join predicates: CAST(CASE WHEN CAST(6: v3 AS BOOLEAN) THEN CAST(11: v2 AS VARCHAR) " +
"WHEN CAST(3: v3 AS BOOLEAN) THEN '123' ELSE CAST(12: v3 AS VARCHAR) END AS DOUBLE) > " +
"1.0, (2: v2 = CAST(8: v2 AS VARCHAR(1048576))) OR (3: v3 = 8: v2)\n"));
connectContext.getSessionVariable().setEnableLocalShuffleAgg(true);
}

Expand Down
85 changes: 85 additions & 0 deletions fe/fe-core/src/test/resources/sql/subquery/in-subquery.sql
Original file line number Diff line number Diff line change
Expand Up @@ -884,4 +884,89 @@ NULL AWARE LEFT ANTI JOIN (join-predicate [19: cast = 20: abs AND add(cast(9: t1
SCAN (columns[1: v1, 3: v3] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[7: t1a, 9: t1c, 10: t1d] predicate[null])
<<<<<<< HEAD
=======
[end]

[sql]
select t0.v1 from t0 where (v1, v2) IN (select t1.v4, t1.v5 from t1)
[result]
LEFT SEMI JOIN (join-predicate [1: v1 = 4: v4 AND 2: v2 = 5: v5] post-join-predicate [null])
SCAN (columns[1: v1, 2: v2] predicate[null])
EXCHANGE SHUFFLE[4]
SCAN (columns[4: v4, 5: v5] predicate[4: v4 IS NOT NULL AND 5: v5 IS NOT NULL])
[end]

[sql]
select t0.v1 from t0 where (v1, v2) NOT IN (select t1.v4, t1.v5 from t1)
[result]
NULL AWARE LEFT ANTI JOIN (join-predicate [1: v1 = 4: v4 AND 2: v2 = 5: v5] post-join-predicate [null])
SCAN (columns[1: v1, 2: v2] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[4: v4, 5: v5] predicate[null])
[end]

[sql]
select * from t0 where (v1, v2) IN (select t1.v4, t1.v5 from t1 WHERE t1.v6 = t0.v3 AND t1.v5 > 10)
[result]
LEFT SEMI JOIN (join-predicate [1: v1 = 4: v4 AND 2: v2 = 5: v5 AND 3: v3 = 6: v6] post-join-predicate [null])
SCAN (columns[1: v1, 2: v2, 3: v3] predicate[2: v2 > 10])
EXCHANGE SHUFFLE[4]
SCAN (columns[4: v4, 5: v5, 6: v6] predicate[5: v5 > 10])
[end]

[sql]
select * from t0 where (v1, v2) NOT IN (select t1.v4, t1.v5 from t1 WHERE t1.v6 = t0.v3 AND t1.v5 > 10)
[result]
NULL AWARE LEFT ANTI JOIN (join-predicate [1: v1 = 4: v4 AND 2: v2 = 5: v5 AND 6: v6 = 3: v3] post-join-predicate [null])
SCAN (columns[1: v1, 2: v2, 3: v3] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[4: v4, 5: v5, 6: v6] predicate[5: v5 > 10])
[end]

[sql]
select * from t0 where (v1, v2) IN (select t1.v4, t1.v5 from t1 WHERE t1.v6 = t0.v3 AND t1.v5 > 10) AND v3 > 10
[result]
LEFT SEMI JOIN (join-predicate [1: v1 = 4: v4 AND 2: v2 = 5: v5 AND 3: v3 = 6: v6] post-join-predicate [null])
SCAN (columns[1: v1, 2: v2, 3: v3] predicate[2: v2 > 10 AND 3: v3 > 10])
EXCHANGE SHUFFLE[4]
SCAN (columns[4: v4, 5: v5, 6: v6] predicate[5: v5 > 10 AND 6: v6 > 10])
[end]

[sql]
select * from test_all_type where (t1e, t1f) IN (select v4, v5 from t1)
[result]
LEFT SEMI JOIN (join-predicate [15: cast = 16: cast AND 6: t1f = 17: cast] post-join-predicate [null])
SCAN (columns[1: t1a, 2: t1b, 3: t1c, 4: t1d, 5: t1e, 6: t1f, 7: t1g, 8: id_datetime, 9: id_date, 10: id_decimal] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[11: v4, 12: v5] predicate[cast(11: v4 as double) IS NOT NULL AND cast(12: v5 as double) IS NOT NULL])
[end]

[sql]
select t0.v1 from t0 where (v1 + 10, v2 + v2) IN (select t1.v4 + t1.v5, t1.v5 from t1)
[result]
LEFT SEMI JOIN (join-predicate [9: add = 7: expr AND 10: add = 5: v5] post-join-predicate [null])
SCAN (columns[1: v1, 2: v2] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[4: v4, 5: v5] predicate[add(4: v4, 5: v5) IS NOT NULL AND 5: v5 IS NOT NULL])
[end]

[sql]
select 1 from customer where (C_NATIONKEY, C_NAME) IN (select P_NAME, P_RETAILPRICE from part)
[result]
LEFT SEMI JOIN (join-predicate [22: cast = 23: cast AND 24: cast = 17: P_RETAILPRICE] post-join-predicate [null])
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[17: P_RETAILPRICE, 11: P_NAME] predicate[cast(11: P_NAME as double) IS NOT NULL])
[end]

[sql]
select 1 from customer where (C_NATIONKEY, C_NAME) IN (select "aa", 123.45)
[result]
LEFT SEMI JOIN (join-predicate [15: cast = 16: cast AND 17: cast = 18: cast] post-join-predicate [null])
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
EXCHANGE BROADCAST
PREDICATE cast(aa as double) IS NOT NULL
VALUES (null)
>>>>>>> a04b659a9a ([Enhancement] check compare optmization more strict (#42936))
[end]