Skip to content

Commit

Permalink
[Enhancement] check compare optmization more strict (#42936)
Browse files Browse the repository at this point in the history
Signed-off-by: Seaven <seaven_7@qq.com>
  • Loading branch information
Seaven committed Mar 28, 2024
1 parent 079ee1a commit a04b659
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ public static Type getCompatibleTypeForBinary(BinaryType type, Type type1, Type
baseType = type1.isDecimalOfAnyVersion() ? type1 : type2;
}
}

if (ConnectContext.get() != null && SessionVariableConstants.DOUBLE.equalsIgnoreCase(ConnectContext.get()
.getSessionVariable().getCboEqBaseType())) {
baseType = Type.DOUBLE;
}
if ((type1.isStringType() && type2.isExactNumericType()) ||
(type1.isExactNumericType() && type2.isStringType())) {
return baseType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate,
}
}

Type compatibleType = TypeManager.getCompatibleTypeForBinary(predicate.getBinaryType(), type1, type2);
Type compatibleType =
TypeManager.getCompatibleTypeForBinary(predicate.getBinaryType(), type1, type2);

if (!type1.matchesType(compatibleType)) {
addCastChild(compatibleType, predicate, 0);
Expand All @@ -190,23 +191,36 @@ private Optional<BinaryPredicateOperator> optimizeConstantAndVariable(BinaryPred
Type typeConstant = constant.getType();
Type typeVariable = variable.getType();

if (typeVariable.isStringType() && typeConstant.isExactNumericType()) {
if (ConnectContext.get() == null || SessionVariableConstants.DECIMAL.equalsIgnoreCase(
ConnectContext.get().getSessionVariable().getCboEqBaseType())) {
// don't optimize when cbo_eq_base_type is decimal
return Optional.empty();
boolean checkStringCastToNumber = false;
if (typeVariable.isNumericType() && typeConstant.isStringType()) {
if (predicate.getBinaryType().isNotRangeComparison()) {
String baseType = ConnectContext.get() != null ?
ConnectContext.get().getSessionVariable().getCboEqBaseType() : SessionVariableConstants.VARCHAR;
checkStringCastToNumber = SessionVariableConstants.DECIMAL.equals(baseType) ||
SessionVariableConstants.DOUBLE.equals(baseType);
} else {
// range compare, base type must be double
checkStringCastToNumber = true;
}
}

Optional<ScalarOperator> op = Utils.tryCastConstant(constant, variable.getType());
if (op.isPresent()) {
predicate.getChildren().set(constantIndex, op.get());
return Optional.of(predicate);
} else if (variable.getType().isDateType() && !constant.getType().isDateType() &&
Type.canCastTo(constant.getType(), variable.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(variable.getType(), predicate, constantIndex);
return Optional.of(predicate);
// strict check, only support white check
if ((typeVariable.isNumericType() && typeConstant.isNumericType()) ||
(typeVariable.isNumericType() && typeConstant.isBoolean()) ||
(typeVariable.isDateType() && typeConstant.isNumericType()) ||
(typeVariable.isDateType() && typeConstant.isStringType()) ||
(typeVariable.isBoolean() && typeConstant.isStringType()) ||
checkStringCastToNumber) {

Optional<ScalarOperator> op = Utils.tryCastConstant(constant, variable.getType());
if (op.isPresent()) {
predicate.getChildren().set(constantIndex, op.get());
return Optional.of(predicate);
} else if (variable.getType().isDateType() && Type.canCastTo(constant.getType(), variable.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(variable.getType(), predicate, constantIndex);
return Optional.of(predicate);
}
}

return Optional.empty();
Expand Down
106 changes: 106 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/ExpressionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1747,4 +1747,110 @@ public void testCastConstantFn() throws Exception {
String plan = getFragmentPlan(sql);
assertContains(plan, " CAST(9: id_date AS DOUBLE) = 1998.0");
}

@Test
public void testCastStringNumber() throws Exception {
try {
// string number
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
String sql = "select t1a = 1.2345 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "[1: t1a, VARCHAR, true] = '1.2345'");

sql = "select t1a < 1 from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DOUBLE) < 1.0");

// number string
sql = "select t1f = '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] = 123.345");

sql = "select t1c >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([3: t1c, INT, true] as DOUBLE) >= 123.345");

sql = "select t1e >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= 123.345");

sql = "select t1f <= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] <= 123.345");

sql = "select t1e >= 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= cast('abc' as DOUBLE)");

sql = "select t1g = 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([7: t1g, BIGINT, true] as VARCHAR(1048576)) = 'abc'");

sql = "select id_bool = 'abc' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([11: id_bool, BOOLEAN, true] as DOUBLE) = cast('abc' as DOUBLE)");

sql = "select id_bool = 'false' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[11: id_bool, BOOLEAN, true] = FALSE");

sql = "select t1g = true from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[7: t1g, BIGINT, true] = 1");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}

try {
connectContext.getSessionVariable().setCboEqBaseType("DECIMAL");
// string number
String sql = "select t1a = 1.2345 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DECIMAL32(5,4)) = 1.2345");

sql = "select t1a < 1 from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DOUBLE) < 1.0");

// number string
sql = "select t1f = '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] = 123.345");

sql = "select t1c >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([3: t1c, INT, true] as DOUBLE) >= 123.345");

sql = "select t1e >= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= 123.345");

sql = "select t1f <= '123.345' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "[6: t1f, DOUBLE, true] <= 123.345");

sql = "select t1e >= 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([5: t1e, FLOAT, true] as DOUBLE) >= cast('abc' as DOUBLE)");

sql = "select t1g = 'abc' from test_all_type";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([7: t1g, BIGINT, true] as DECIMAL128(38,9)) " +
"= cast('abc' as DECIMAL128(38,9))");

sql = "select id_bool = 'abc' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "cast([11: id_bool, BOOLEAN, true] as DOUBLE) = cast('abc' as DOUBLE)");

sql = "select id_bool = 'false' from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[11: id_bool, BOOLEAN, true] = FALSE");

sql = "select t1g = true from test_bool";
plan = getVerboseExplain(sql);
assertContains(plan, "[7: t1g, BIGINT, true] = 1");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ public RuleSet getRuleSet() {
Assert.assertTrue(replayPair.second, replayPair.second.contains(" 13:NESTLOOP JOIN\n" +
" | join op: INNER JOIN\n" +
" | colocate: false, reason: \n" +
" | other join predicates: CASE WHEN CAST(6: v3 AS BOOLEAN) THEN CAST(11: v2 AS VARCHAR) " +
"WHEN CAST(3: v3 AS BOOLEAN) THEN '123' ELSE CAST(12: v3 AS VARCHAR) END > '1', " +
"(2: v2 = CAST(8: v2 AS VARCHAR(1048576))) OR (3: v3 = 8: v2)\n"));
" | other join predicates: CAST(CASE WHEN CAST(6: v3 AS BOOLEAN) THEN CAST(11: v2 AS VARCHAR) " +
"WHEN CAST(3: v3 AS BOOLEAN) THEN '123' ELSE CAST(12: v3 AS VARCHAR) END AS DOUBLE) > " +
"1.0, (2: v2 = CAST(8: v2 AS VARCHAR(1048576))) OR (3: v3 = 8: v2)\n"));
connectContext.getSessionVariable().setEnableLocalShuffleAgg(true);
}

Expand Down
14 changes: 7 additions & 7 deletions fe/fe-core/src/test/resources/sql/subquery/in-subquery.sql
Original file line number Diff line number Diff line change
Expand Up @@ -709,18 +709,18 @@ LEFT SEMI JOIN (join-predicate [9: add = 7: expr AND 10: add = 5: v5] post-join-
[sql]
select 1 from customer where (C_NATIONKEY, C_NAME) IN (select P_NAME, P_RETAILPRICE from part)
[result]
RIGHT SEMI JOIN (join-predicate [11: P_NAME = 22: cast AND 17: P_RETAILPRICE = 23: cast] post-join-predicate [null])
EXCHANGE SHUFFLE[11, 17]
SCAN (columns[17: P_RETAILPRICE, 11: P_NAME] predicate[null])
EXCHANGE SHUFFLE[22, 23]
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
LEFT SEMI JOIN (join-predicate [22: cast = 23: cast AND 24: cast = 17: P_RETAILPRICE] post-join-predicate [null])
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
EXCHANGE BROADCAST
SCAN (columns[17: P_RETAILPRICE, 11: P_NAME] predicate[cast(11: P_NAME as double) IS NOT NULL])
[end]

[sql]
select 1 from customer where (C_NATIONKEY, C_NAME) IN (select "aa", 123.45)
[result]
LEFT SEMI JOIN (join-predicate [15: cast = 11: expr AND 2: C_NAME = 16: cast] post-join-predicate [null])
LEFT SEMI JOIN (join-predicate [15: cast = 16: cast AND 17: cast = 18: cast] post-join-predicate [null])
SCAN (columns[2: C_NAME, 4: C_NATIONKEY] predicate[null])
EXCHANGE BROADCAST
VALUES (null)
PREDICATE cast(aa as double) IS NOT NULL
VALUES (null)
[end]
42 changes: 19 additions & 23 deletions fe/fe-core/src/test/resources/sql/subquery/scalar-subquery.sql
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ PLAN FRAGMENT 1
partitions=1/1
rollup: t0
tabletRatio=3/3
tabletList=10008,10010,10012
tabletList=10005,10007,10009
cardinality=1
avgRowSize=2.0

Expand Down Expand Up @@ -76,7 +76,7 @@ PLAN FRAGMENT 3
partitions=1/1
rollup: t3
tabletRatio=3/3
tabletList=10035,10037,10039
tabletList=10032,10034,10036
cardinality=1
avgRowSize=1.0
[end]
Expand Down Expand Up @@ -156,7 +156,7 @@ PLAN FRAGMENT 2
partitions=1/1
rollup: t3
tabletRatio=3/3
tabletList=10035,10037,10039
tabletList=10032,10034,10036
cardinality=1
avgRowSize=2.0

Expand All @@ -175,7 +175,7 @@ PLAN FRAGMENT 3
partitions=1/1
rollup: t0
tabletRatio=3/3
tabletList=10008,10010,10012
tabletList=10005,10007,10009
cardinality=1
avgRowSize=3.0
[end]
Expand Down Expand Up @@ -272,7 +272,7 @@ PLAN FRAGMENT 2
partitions=1/1
rollup: t3
tabletRatio=3/3
tabletList=10035,10037,10039
tabletList=10032,10034,10036
cardinality=1
avgRowSize=5.0

Expand All @@ -297,7 +297,7 @@ PLAN FRAGMENT 3
partitions=1/1
rollup: t0
tabletRatio=3/3
tabletList=10008,10010,10012
tabletList=10005,10007,10009
cardinality=1
avgRowSize=4.0
[end]
Expand Down Expand Up @@ -497,6 +497,7 @@ INNER JOIN (join-predicate [2: v2 = 4: v4] post-join-predicate [null])
SCAN (columns[4: v4, 5: v5] predicate[null])
[end]

/* test PushDownApplyAggFilterRule */
/* test PushDownApplyAggFilterRule */

[sql]
Expand Down Expand Up @@ -605,13 +606,13 @@ INNER JOIN (join-predicate [1: v1 = 23: cast] post-join-predicate [null])
AGGREGATE ([GLOBAL] aggregate [{20: min=min(20: min)}] group by [[]] having [null]
EXCHANGE GATHER
AGGREGATE ([LOCAL] aggregate [{20: min=min(6: t1c)}] group by [[]] having [null]
INNER JOIN (join-predicate [4: t1a = 22: cast AND 7: t1d = 18: max] post-join-predicate [null])
SCAN (columns[4: t1a, 6: t1c, 7: t1d] predicate[4: t1a IS NOT NULL AND 7: t1d IS NOT NULL])
EXCHANGE SHUFFLE[22]
INNER JOIN (join-predicate [24: cast = 22: cast AND 7: t1d = 18: max] post-join-predicate [null])
SCAN (columns[4: t1a, 6: t1c, 7: t1d] predicate[cast(4: t1a as double) IS NOT NULL AND 7: t1d IS NOT NULL])
EXCHANGE BROADCAST
AGGREGATE ([GLOBAL] aggregate [{18: max=max(18: max)}] group by [[22: cast]] having [18: max IS NOT NULL]
EXCHANGE SHUFFLE[22]
AGGREGATE ([LOCAL] aggregate [{18: max=max(17: expr)}] group by [[22: cast]] having [null]
SCAN (columns[14: v4, 15: v5] predicate[cast(14: v4 as varchar(1048576)) IS NOT NULL AND 14: v4 = 2])
SCAN (columns[14: v4, 15: v5] predicate[cast(14: v4 as double) IS NOT NULL AND 14: v4 = 2])
[end]

[sql]
Expand All @@ -624,13 +625,13 @@ CROSS JOIN (join-predicate [null] post-join-predicate [null])
AGGREGATE ([GLOBAL] aggregate [{20: min=min(20: min)}] group by [[]] having [null]
EXCHANGE GATHER
AGGREGATE ([LOCAL] aggregate [{20: min=min(6: t1c)}] group by [[]] having [null]
INNER JOIN (join-predicate [4: t1a = 22: cast AND 7: t1d = 18: max] post-join-predicate [null])
SCAN (columns[4: t1a, 6: t1c, 7: t1d] predicate[4: t1a IS NOT NULL AND 7: t1d IS NOT NULL])
EXCHANGE SHUFFLE[22]
INNER JOIN (join-predicate [23: cast = 22: cast AND 7: t1d = 18: max] post-join-predicate [null])
SCAN (columns[4: t1a, 6: t1c, 7: t1d] predicate[cast(4: t1a as double) IS NOT NULL AND 7: t1d IS NOT NULL])
EXCHANGE BROADCAST
AGGREGATE ([GLOBAL] aggregate [{18: max=max(18: max)}] group by [[22: cast]] having [18: max IS NOT NULL]
EXCHANGE SHUFFLE[22]
AGGREGATE ([LOCAL] aggregate [{18: max=max(17: expr)}] group by [[22: cast]] having [null]
SCAN (columns[14: v4, 15: v5] predicate[cast(14: v4 as varchar(1048576)) IS NOT NULL AND 14: v4 = 2])
SCAN (columns[14: v4, 15: v5] predicate[cast(14: v4 as double) IS NOT NULL AND 14: v4 = 2])
[end]

[sql]
Expand Down Expand Up @@ -702,11 +703,6 @@ LEFT OUTER JOIN (join-predicate [add(add(1: v1, 4: v4), 9: v9) = if(23: expr, 1,
SCAN (columns[12: t1c, 13: t1d] predicate[null])
[end]

/* test ScalarApply2JoinRule */
/* test ScalarApply2JoinRule */
/* test ScalarApply2JoinRule */
/* test ScalarApply2JoinRule */

[sql]
select * from t0 where 1 = (select v5 + 1 from t1 where t0.v2 = t1.v4);
[result]
Expand Down Expand Up @@ -846,12 +842,12 @@ INNER JOIN (join-predicate [1: v1 = 24: cast] post-join-predicate [null])
ASSERT LE 1
EXCHANGE GATHER
PREDICATE 7: t1d = 18: expr
RIGHT OUTER JOIN (join-predicate [20: cast = 4: t1a] post-join-predicate [null])
RIGHT OUTER JOIN (join-predicate [20: cast = 25: cast] post-join-predicate [null])
AGGREGATE ([GLOBAL] aggregate [{21: countRows=count(21: countRows), 22: anyValue=any_value(22: anyValue)}] group by [[20: cast]] having [null]
EXCHANGE SHUFFLE[20]
AGGREGATE ([LOCAL] aggregate [{21: countRows=count(1), 22: anyValue=any_value(add(14: v4, 15: v5))}] group by [[20: cast]] having [null]
SCAN (columns[14: v4, 15: v5] predicate[14: v4 = 2])
EXCHANGE SHUFFLE[4]
EXCHANGE SHUFFLE[25]
SCAN (columns[4: t1a, 6: t1c, 7: t1d] predicate[null])
[end]

Expand All @@ -864,12 +860,12 @@ CROSS JOIN (join-predicate [null] post-join-predicate [null])
ASSERT LE 1
EXCHANGE GATHER
PREDICATE 7: t1d = 18: expr
RIGHT OUTER JOIN (join-predicate [20: cast = 4: t1a] post-join-predicate [null])
RIGHT OUTER JOIN (join-predicate [20: cast = 24: cast] post-join-predicate [null])
AGGREGATE ([GLOBAL] aggregate [{21: countRows=count(21: countRows), 22: anyValue=any_value(22: anyValue)}] group by [[20: cast]] having [null]
EXCHANGE SHUFFLE[20]
AGGREGATE ([LOCAL] aggregate [{21: countRows=count(1), 22: anyValue=any_value(add(14: v4, 15: v5))}] group by [[20: cast]] having [null]
SCAN (columns[14: v4, 15: v5] predicate[14: v4 = 2])
EXCHANGE SHUFFLE[4]
EXCHANGE SHUFFLE[24]
SCAN (columns[4: t1a, 6: t1c, 7: t1d] predicate[null])
[end]

Expand Down
13 changes: 11 additions & 2 deletions test/sql/test_multi_ops/R/test_depends_ops
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,16 @@ group by
-- !result
select l.c1, sum(l.c0) ,sum(r.c0) from aggregated_table l join aggregated_table r on l.c0 <= r.c0 and r.c0 < 10 group by 1 order by 2, 3 limit 10;
-- result:
4095 1.0 1.0
4087 9.0 9.0
4095 9.0 45.0
4088 16.0 17.0
4094 16.0 44.0
4089 21.0 24.0
4093 21.0 42.0
4090 24.0 30.0
4092 24.0 39.0
4091 25.0 35.0
4016 80.0 9.0
-- !result
select l.c1, l.c0 from aggregated_table l except select r.c1, r.c0 from aggregated_table r;
-- result:
Expand All @@ -152,4 +161,4 @@ select l.c1, l.c0 from aggregated_table l intersect select r.c1, r.c0 from aggre
7 4089
8 4088
9 4087
-- !result
-- !result

0 comments on commit a04b659

Please sign in to comment.