Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] check compare optmization more strict (backport #42936) #43343

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ public static Type getCompatibleTypeForBinary(BinaryType type, Type type1, Type
baseType = type1.isDecimalOfAnyVersion() ? type1 : type2;
}
}

if (ConnectContext.get() != null && SessionVariableConstants.DOUBLE.equalsIgnoreCase(ConnectContext.get()
.getSessionVariable().getCboEqBaseType())) {
baseType = Type.DOUBLE;
}
if ((type1.isStringType() && type2.isExactNumericType()) ||
(type1.isExactNumericType() && type2.isStringType())) {
return baseType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate,
}
}

Type compatibleType = TypeManager.getCompatibleTypeForBinary(predicate.getBinaryType(), type1, type2);
Type compatibleType =
TypeManager.getCompatibleTypeForBinary(predicate.getBinaryType(), type1, type2);

if (!type1.matchesType(compatibleType)) {
addCastChild(compatibleType, predicate, 0);
Expand All @@ -190,23 +191,36 @@ private Optional<BinaryPredicateOperator> optimizeConstantAndVariable(BinaryPred
Type typeConstant = constant.getType();
Type typeVariable = variable.getType();

if (typeVariable.isStringType() && typeConstant.isExactNumericType()) {
if (ConnectContext.get() == null || SessionVariableConstants.DECIMAL.equalsIgnoreCase(
ConnectContext.get().getSessionVariable().getCboEqBaseType())) {
// don't optimize when cbo_eq_base_type is decimal
return Optional.empty();
boolean checkStringCastToNumber = false;
if (typeVariable.isNumericType() && typeConstant.isStringType()) {
if (predicate.getBinaryType().isNotRangeComparison()) {
String baseType = ConnectContext.get() != null ?
ConnectContext.get().getSessionVariable().getCboEqBaseType() : SessionVariableConstants.VARCHAR;
checkStringCastToNumber = SessionVariableConstants.DECIMAL.equals(baseType) ||
SessionVariableConstants.DOUBLE.equals(baseType);
} else {
// range compare, base type must be double
checkStringCastToNumber = true;
}
}

Optional<ScalarOperator> op = Utils.tryCastConstant(constant, variable.getType());
if (op.isPresent()) {
predicate.getChildren().set(constantIndex, op.get());
return Optional.of(predicate);
} else if (variable.getType().isDateType() && !constant.getType().isDateType() &&
Type.canCastTo(constant.getType(), variable.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(variable.getType(), predicate, constantIndex);
return Optional.of(predicate);
// strict check, only support white check
if ((typeVariable.isNumericType() && typeConstant.isNumericType()) ||
(typeVariable.isNumericType() && typeConstant.isBoolean()) ||
(typeVariable.isDateType() && typeConstant.isNumericType()) ||
(typeVariable.isDateType() && typeConstant.isStringType()) ||
(typeVariable.isBoolean() && typeConstant.isStringType()) ||
checkStringCastToNumber) {

Optional<ScalarOperator> op = Utils.tryCastConstant(constant, variable.getType());
if (op.isPresent()) {
predicate.getChildren().set(constantIndex, op.get());
return Optional.of(predicate);
} else if (variable.getType().isDateType() && Type.canCastTo(constant.getType(), variable.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(variable.getType(), predicate, constantIndex);
return Optional.of(predicate);
}
}

return Optional.empty();
Expand Down
106 changes: 106 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/ExpressionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1743,4 +1743,110 @@ public void testCastConstantFn() throws Exception {
String plan = getFragmentPlan(sql);
assertContains(plan, " CAST(9: id_date AS DOUBLE) = 1998.0");
}

@Test
public void testCastStringNumber() throws Exception {
try {
// string number
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
String sql = "select t1a = 1.2345 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "[1: t1a, VARCHAR, true] = '1.2345'");

sql = "select t1a < 1 from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DOUBLE) < 1.0");

// number string
sql = "select t1f = '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] = 123.345");

sql = "select t1c >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([3: t1c, INT, true] as DOUBLE) >= 123.345");

sql = "select t1e >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= 123.345");

sql = "select t1f <= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] <= 123.345");

sql = "select t1e >= 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= cast('abc' as DOUBLE)");

sql = "select t1g = 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([7: t1g, BIGINT, true] as VARCHAR(1048576)) = 'abc'");

sql = "select id_bool = 'abc' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([11: id_bool, BOOLEAN, true] as DOUBLE) = cast('abc' as DOUBLE)");

sql = "select id_bool = 'false' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[11: id_bool, BOOLEAN, true] = FALSE");

sql = "select t1g = true from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[7: t1g, BIGINT, true] = 1");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}

try {
connectContext.getSessionVariable().setCboEqBaseType("DECIMAL");
// string number
String sql = "select t1a = 1.2345 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DECIMAL32(5,4)) = 1.2345");

sql = "select t1a < 1 from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DOUBLE) < 1.0");

// number string
sql = "select t1f = '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] = 123.345");

sql = "select t1c >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([3: t1c, INT, true] as DOUBLE) >= 123.345");

sql = "select t1e >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= 123.345");

sql = "select t1f <= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] <= 123.345");

sql = "select t1e >= 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= cast('abc' as DOUBLE)");

sql = "select t1g = 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([7: t1g, BIGINT, true] as DECIMAL128(38,9)) " +
"= cast('abc' as DECIMAL128(38,9))");

sql = "select id_bool = 'abc' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([11: id_bool, BOOLEAN, true] as DOUBLE) = cast('abc' as DOUBLE)");

sql = "select id_bool = 'false' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[11: id_bool, BOOLEAN, true] = FALSE");

sql = "select t1g = true from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[7: t1g, BIGINT, true] = 1");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ public RuleSet getRuleSet() {
Assert.assertTrue(replayPair.second, replayPair.second.contains(" 13:NESTLOOP JOIN\n" +
" | join op: INNER JOIN\n" +
" | colocate: false, reason: \n" +
" | other join predicates: CASE WHEN CAST(6: v3 AS BOOLEAN) THEN CAST(11: v2 AS VARCHAR) " +
"WHEN CAST(3: v3 AS BOOLEAN) THEN '123' ELSE CAST(12: v3 AS VARCHAR) END > '1', " +
"(2: v2 = CAST(8: v2 AS VARCHAR(1048576))) OR (3: v3 = 8: v2)\n"));
" | other join predicates: CAST(CASE WHEN CAST(6: v3 AS BOOLEAN) THEN CAST(11: v2 AS VARCHAR) " +
"WHEN CAST(3: v3 AS BOOLEAN) THEN '123' ELSE CAST(12: v3 AS VARCHAR) END AS DOUBLE) > " +
"1.0, (2: v2 = CAST(8: v2 AS VARCHAR(1048576))) OR (3: v3 = 8: v2)\n"));
connectContext.getSessionVariable().setEnableLocalShuffleAgg(true);
}

Expand Down
14 changes: 7 additions & 7 deletions fe/fe-core/src/test/resources/sql/subquery/in-subquery.sql
Original file line number Diff line number Diff line change
Expand Up @@ -709,18 +709,18 @@ LEFT SEMI JOIN (join-predicate [9: add = 7: expr AND 10: add = 5: v5] post-join-
[sql]
select 1 from customer where (C_NATIONKEY, C_NAME) IN (select P_NAME, P_RETAILPRICE from part)
[result]
RIGHT SEMI JOIN (join-predicate [11: P_NAME = 22: cast AND 17: P_RETAILPRICE = 23: cast] post-join-predicate [null])
EXCHANGE SHUFFLE[11, 17]
SCAN (columns[17: P_RETAILPRICE, 11: P_NAME] predicate[null])
EXCHANGE SHUFFLE[22, 23]
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
LEFT SEMI JOIN (join-predicate [22: cast = 23: cast AND 24: cast = 17: P_RETAILPRICE] post-join-predicate [null])
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[17: P_RETAILPRICE, 11: P_NAME] predicate[cast(11: P_NAME as double) IS NOT NULL])
[end]

[sql]
select 1 from customer where (C_NATIONKEY, C_NAME) IN (select "aa", 123.45)
[result]
LEFT SEMI JOIN (join-predicate [15: cast = 11: expr AND 2: C_NAME = 16: cast] post-join-predicate [null])
LEFT SEMI JOIN (join-predicate [15: cast = 16: cast AND 17: cast = 18: cast] post-join-predicate [null])
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
EXCHANGE BROADCAST
VALUES (null)
PREDICATE cast(aa as double) IS NOT NULL
VALUES (null)
[end]