Skip to content

Commit

Permalink
[Enhancement] Control implicit cast optimization by cbo_eq_type (Star…
Browse files Browse the repository at this point in the history
…Rocks#40619)

Signed-off-by: Seaven <seaven_7@qq.com>
  • Loading branch information
Seaven committed Feb 27, 2024
1 parent 41f0d7d commit 3594985
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 24 deletions.
Expand Up @@ -1081,6 +1081,10 @@ public String getCboEqBaseType() {
return cboEqBaseType;
}

public void setCboEqBaseType(String cboEqBaseType) {
this.cboEqBaseType = cboEqBaseType;
}

@VarAttr(name = FOLLOWER_QUERY_FORWARD_MODE, flag = VariableMgr.INVISIBLE | VariableMgr.DISABLE_FORWARD_TO_LEADER)
private String followerForwardMode = "";

Expand Down
Expand Up @@ -9,6 +9,7 @@
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.qe.ConnectContext;
import com.starrocks.sql.common.TypeManager;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.scalar.BetweenPredicateOperator;
Expand Down Expand Up @@ -133,27 +134,12 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate,
}

// we will try cast const operator to variable operator
if (rightChild.isVariable() && leftChild.isConstantRef()) {
Optional<ScalarOperator> op = Utils.tryCastConstant(leftChild, type2);
if (op.isPresent()) {
predicate.getChildren().set(0, op.get());
return predicate;
} else if (rightChild.getType().isDateType() && !leftChild.getType().isDateType() &&
Type.canCastTo(leftChild.getType(), rightChild.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(rightChild.getType(), predicate, 0);
return predicate;
}
} else if (leftChild.isVariable() && rightChild.isConstantRef()) {
Optional<ScalarOperator> op = Utils.tryCastConstant(rightChild, type1);
if (op.isPresent()) {
predicate.getChildren().set(1, op.get());
return predicate;
} else if (leftChild.getType().isDateType() && !rightChild.getType().isDateType() &&
Type.canCastTo(rightChild.getType(), leftChild.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(leftChild.getType(), predicate, 1);
return predicate;
if (rightChild.isVariable() != leftChild.isVariable()) {
int constant = leftChild.isVariable() ? 1 : 0;
int variable = 1 - constant;
Optional<BinaryPredicateOperator> optional = optimizeConstantAndVariable(predicate, constant, variable);
if (optional.isPresent()) {
return optional.get();
}
}

Expand All @@ -169,6 +155,35 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate,
return predicate;
}

private Optional<BinaryPredicateOperator> optimizeConstantAndVariable(BinaryPredicateOperator predicate,
int constantIndex, int variableIndex) {
ScalarOperator constant = predicate.getChild(constantIndex);
ScalarOperator variable = predicate.getChild(variableIndex);
Type typeConstant = constant.getType();
Type typeVariable = variable.getType();

if (typeVariable.isStringType() && typeConstant.isExactNumericType()) {
if (ConnectContext.get() == null || "decimal".equalsIgnoreCase(
ConnectContext.get().getSessionVariable().getCboEqBaseType())) {
// don't optimize when cbo_eq_base_type is decimal
return Optional.empty();
}
}

Optional<ScalarOperator> op = Utils.tryCastConstant(constant, variable.getType());
if (op.isPresent()) {
predicate.getChildren().set(constantIndex, op.get());
return Optional.of(predicate);
} else if (variable.getType().isDateType() && !constant.getType().isDateType() &&
Type.canCastTo(constant.getType(), variable.getType())) {
// For like MySQL, convert to date type as much as possible
addCastChild(variable.getType(), predicate, constantIndex);
return Optional.of(predicate);
}

return Optional.empty();
}

@Override
public ScalarOperator visitCompoundPredicate(CompoundPredicateOperator predicate,
ScalarOperatorRewriteContext context) {
Expand Down
Expand Up @@ -82,10 +82,10 @@ public void testRewrite2() {
assertEquals(root, result);
assertEquals(OperatorType.BINARY, result.getChild(0).getOpType());
assertEquals(OperatorType.VARIABLE, result.getChild(0).getChild(0).getOpType());
assertEquals(OperatorType.CONSTANT, result.getChild(0).getChild(1).getOpType());
assertEquals(OperatorType.CALL, result.getChild(0).getChild(1).getOpType());

assertEquals(Type.VARCHAR, result.getChild(0).getChild(0).getType());
assertEquals(Type.VARCHAR, result.getChild(0).getChild(1).getType());
assertEquals(Type.VARCHAR.getPrimitiveType(), result.getChild(0).getChild(0).getType().getPrimitiveType());
assertEquals(Type.VARCHAR.getPrimitiveType(), result.getChild(0).getChild(1).getType().getPrimitiveType());

assertEquals(OperatorType.COMPOUND, result.getChild(1).getOpType());
assertEquals(OperatorType.BINARY, result.getChild(1).getChild(0).getOpType());
Expand Down
Expand Up @@ -1403,4 +1403,25 @@ public void testDecimalV2Cast1() throws Exception {
plan = getVerboseExplain(sql);
assertContains(plan, "3 <-> length[(cast(2480.0 as VARCHAR)); args: VARCHAR; result: INT;");
}

@Test
public void testCastStringDouble() throws Exception {
try {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
String sql = "select t1a = 1 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "11 <-> [1: t1a, VARCHAR, true] = '1'");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}

try {
connectContext.getSessionVariable().setCboEqBaseType("DECIMAL");
String sql = "select t1a = 1 from test_all_type";
String plan = getVerboseExplain(sql);
assertContains(plan, "cast([1: t1a, VARCHAR, true] as DECIMAL128(38,9)) = 1");
} finally {
connectContext.getSessionVariable().setCboEqBaseType("VARCHAR");
}
}
}

0 comments on commit 3594985

Please sign in to comment.