Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ public enum TableFrom {
// map placeholder id to comparison slot, which will used to replace conjuncts
// directly
private final Map<PlaceholderId, SlotReference> idToComparisonSlot = new TreeMap<>();
// map placeholder id to slot for IN predicate options, used to replace IN predicate
// conjuncts in short circuit plan for prepared statement
private final Map<PlaceholderId, SlotReference> idToInPredicateSlot = new TreeMap<>();

// collect all hash join conditions to compute node connectivity in join graph
private final List<Expression> joinFilters = new ArrayList<>();
Expand Down Expand Up @@ -628,6 +631,10 @@ public Map<PlaceholderId, SlotReference> getIdToComparisonSlot() {
return idToComparisonSlot;
}

public Map<PlaceholderId, SlotReference> getIdToInPredicateSlot() {
return idToInPredicateSlot;
}

public Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -893,12 +893,37 @@ public Expression visitWhenClause(WhenClause whenClause, ExpressionRewriteContex

@Override
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
// Register placeholder ids to slot BEFORE children are visited (i.e., before Placeholders
// are replaced with actual literal values by visitPlaceholder).
// Used to replace expressions in ShortCircuit plan for prepared statement.
registerInPredicatePlaceholderToSlot(inPredicate, context);
List<Expression> rewrittenChildren = inPredicate.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList());
InPredicate newInPredicate = inPredicate.withChildren(rewrittenChildren);
return TypeCoercionUtils.processInPredicate(newInPredicate);
}

// Register prepared statement placeholder ids to related slot in IN predicate.
// Each placeholder in the IN list is mapped to the compare slot for short circuit plan replacement.
// Must be called BEFORE children are recursively visited (before Placeholder→Literal substitution).
private void registerInPredicatePlaceholderToSlot(InPredicate inPredicate,
ExpressionRewriteContext context) {
if (context == null) {
return;
}
if (ConnectContext.get() != null
&& ConnectContext.get().getCommand() == MysqlCommand.COM_STMT_EXECUTE
&& inPredicate.getCompareExpr() instanceof SlotReference) {
SlotReference slot = (SlotReference) inPredicate.getCompareExpr();
for (Expression option : inPredicate.getOptions()) {
if (option instanceof Placeholder) {
PlaceholderId id = ((Placeholder) option).getPlaceholderId();
context.cascadesContext.getStatementContext().getIdToInPredicateSlot().put(id, slot);
}
}
}
}

@Override
public Expression visitBetween(Between between, ExpressionRewriteContext context) {
Expression compareExpr = between.getCompareExpr().accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
Expand All @@ -39,7 +42,7 @@

/**
* short circuit query optimization
* pattern : select xxx from tbl where key = ?
* pattern : select xxx from tbl where key = ? or key IN (?, ?, ...)
*/
public class LogicalResultSinkToShortCircuitPointQuery implements RewriteRuleFactory {

Expand All @@ -50,14 +53,35 @@ private Expression removeCast(Expression expression) {
return expression;
}

private boolean filterMatchShortCircuitCondition(LogicalFilter<LogicalOlapScan> filter) {
return filter.getConjuncts().stream().allMatch(
// all conjuncts match with pattern `key = ?`
expression -> (expression instanceof EqualTo)
&& (removeCast(expression.child(0)).isKeyColumnFromTable()
// Check if an expression in the filter is a valid short-circuit condition.
// Supports: key = literal/placeholder, or key IN (literal/placeholder, ...)
private boolean isValidShortCircuitExpression(Expression expression) {
if (expression instanceof EqualTo) {
// key = literal or key = placeholder
return (removeCast(expression.child(0)).isKeyColumnFromTable()
|| (expression.child(0) instanceof SlotReference
&& ((SlotReference) expression.child(0)).getName().equals(Column.DELETE_SIGN)))
&& expression.child(1).isLiteral());
&& (expression.child(1).isLiteral() || expression.child(1) instanceof Placeholder);
} else if (expression instanceof InPredicate) {
// key IN (literal/placeholder, ...)
InPredicate inPredicate = (InPredicate) expression;
Expression compareExpr = removeCast(inPredicate.getCompareExpr());
if (!compareExpr.isKeyColumnFromTable()) {
return false;
}
// All options must be literals or placeholders
for (Expression option : inPredicate.getOptions()) {
if (!(option instanceof Literal) && !(option instanceof Placeholder)) {
return false;
}
}
return !inPredicate.getOptions().isEmpty();
}
return false;
}

private boolean filterMatchShortCircuitCondition(LogicalFilter<LogicalOlapScan> filter) {
return filter.getConjuncts().stream().allMatch(this::isValidShortCircuitExpression);
}

private boolean scanMatchShortCircuitCondition(LogicalOlapScan olapScan) {
Expand All @@ -75,12 +99,21 @@ private boolean scanMatchShortCircuitCondition(LogicalOlapScan olapScan) {
// set short circuit flag and return the original plan
private Plan shortCircuit(Plan root, OlapTable olapTable,
Set<Expression> conjuncts, StatementContext statementContext) {
// All key columns in conjuncts
// All key columns covered by conjuncts (EqualTo or InPredicate)
Set<String> colNames = Sets.newHashSet();
for (Expression expr : conjuncts) {
SlotReference slot = ((SlotReference) removeCast((expr.child(0))));
if (slot.isKeyColumnFromTable()) {
colNames.add(slot.getName());
if (expr instanceof EqualTo) {
SlotReference slot = (SlotReference) removeCast(expr.child(0));
if (slot.isKeyColumnFromTable()) {
colNames.add(slot.getName());
}
} else if (expr instanceof InPredicate) {
InPredicate inPredicate = (InPredicate) expr;
Expression compareExpr = removeCast(inPredicate.getCompareExpr());
if (compareExpr instanceof SlotReference
&& ((SlotReference) compareExpr).isKeyColumnFromTable()) {
colNames.add(((SlotReference) compareExpr).getName());
}
}
}
// set short circuit flag and modify nothing to the plan
Expand All @@ -99,7 +132,6 @@ public List<Rule> buildRules() {
).when(this::filterMatchShortCircuitCondition)))
.thenApply(ctx -> {
return shortCircuit(ctx.root, ctx.root.child().child().child().getTable(),

ctx.root.child().child().getConjuncts(), ctx.statementContext);
})),
RuleType.SHOR_CIRCUIT_POINT_QUERY.build(
Expand Down
132 changes: 98 additions & 34 deletions fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.ExprToSqlVisitor;
import org.apache.doris.analysis.InPredicate;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.ToSqlParams;
Expand Down Expand Up @@ -145,47 +146,74 @@ public static void directExecuteShortCircuitQuery(StmtExecutor executor,
StatementContext statementContext) throws Exception {
Preconditions.checkNotNull(preparedStmtCtx.shortCircuitQueryContext);
ShortCircuitQueryContext shortCircuitQueryContext = preparedStmtCtx.shortCircuitQueryContext.get();
// update conjuncts
// update conjuncts for equality predicates: colName -> LiteralExpr
Map<String, Expr> colNameToConjunct = Maps.newHashMap();
for (Entry<PlaceholderId, SlotReference> entry : statementContext.getIdToComparisonSlot().entrySet()) {
String colName = entry.getValue().getOriginalColumn().get().getName();
Expr conjunctVal = ((Literal) statementContext.getIdToPlaceholderRealExpr()
.get(entry.getKey())).toLegacyLiteral();
colNameToConjunct.put(colName, conjunctVal);
}
if (colNameToConjunct.size() != preparedStmtCtx.command.placeholderCount()) {
// update conjuncts for IN predicate: colName -> List<LiteralExpr>
Map<String, List<Expr>> colNameToInValues = Maps.newHashMap();
for (Entry<PlaceholderId, SlotReference> entry : statementContext.getIdToInPredicateSlot().entrySet()) {
String colName = entry.getValue().getOriginalColumn().get().getName();
Expr inVal = ((Literal) statementContext.getIdToPlaceholderRealExpr()
.get(entry.getKey())).toLegacyLiteral();
colNameToInValues.computeIfAbsent(colName, k -> new ArrayList<>()).add(inVal);
}
int totalPlaceholderCount = preparedStmtCtx.command.placeholderCount();
int resolvedCount = colNameToConjunct.size() + colNameToInValues.values().stream()
.mapToInt(List::size).sum();
if (resolvedCount != totalPlaceholderCount) {
throw new AnalysisException("Mismatched conjuncts values size with prepared"
+ "statement parameters size, expected "
+ preparedStmtCtx.command.placeholderCount()
+ ", but meet " + colNameToConjunct.size());
+ totalPlaceholderCount
+ ", but meet " + resolvedCount);
}
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, colNameToConjunct);
updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, colNameToConjunct, colNameToInValues);
// short circuit plan and execution
executor.executeAndSendResult(false, false,
shortCircuitQueryContext.analzyedQuery, executor.getContext()
.getMysqlChannel(), null, null);
}

private static void updateScanNodeConjuncts(OlapScanNode scanNode,
Map<String, Expr> colNameToConjunct) {
Map<String, Expr> colNameToConjunct, Map<String, List<Expr>> colNameToInValues) {
for (Expr conjunct : scanNode.getConjuncts()) {
BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct;
SlotRef slot = null;
int updateChildIdx = 0;
if (binaryPredicate.getChild(0) instanceof LiteralExpr) {
slot = (SlotRef) binaryPredicate.getChildWithoutCast(1);
} else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
slot = (SlotRef) binaryPredicate.getChildWithoutCast(0);
updateChildIdx = 1;
} else {
Preconditions.checkState(false, "Should contains literal in "
+ binaryPredicate.accept(ExprToSqlVisitor.INSTANCE, ToSqlParams.WITH_TABLE));
}
// not a placeholder to replace
if (!colNameToConjunct.containsKey(slot.getColumnName())) {
continue;
if (conjunct instanceof BinaryPredicate) {
BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct;
SlotRef slot = null;
int updateChildIdx = 0;
if (binaryPredicate.getChild(0) instanceof LiteralExpr) {
slot = (SlotRef) binaryPredicate.getChildWithoutCast(1);
} else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
slot = (SlotRef) binaryPredicate.getChildWithoutCast(0);
updateChildIdx = 1;
} else {
Preconditions.checkState(false, "Should contains literal in "
+ binaryPredicate.accept(ExprToSqlVisitor.INSTANCE, ToSqlParams.WITH_TABLE));
}
// not a placeholder to replace
if (!colNameToConjunct.containsKey(slot.getColumnName())) {
continue;
}
binaryPredicate.setChild(updateChildIdx, colNameToConjunct.get(slot.getColumnName()));
} else if (conjunct instanceof InPredicate) {
InPredicate inPredicate = (InPredicate) conjunct;
SlotRef slot = inPredicate.getChild(0).unwrapSlotRef();
if (slot == null || !colNameToInValues.containsKey(slot.getColumnName())) {
continue;
}
List<Expr> newValues = colNameToInValues.get(slot.getColumnName());
// Replace all list children (children[1..n]) with the new literal values
while (inPredicate.getChildren().size() > 1) {
inPredicate.getChildren().remove(inPredicate.getChildren().size() - 1);
}
for (Expr val : newValues) {
inPredicate.addChild(val);
}
}
binaryPredicate.setChild(updateChildIdx, colNameToConjunct.get(slot.getColumnName()));
}
}

Expand All @@ -195,21 +223,57 @@ public void setTimeout(long timeoutMs) {

void addKeyTuples(
InternalService.PTabletKeyLookupRequest.Builder requestBuilder) {
// TODO handle IN predicates
Map<String, Expr> columnExpr = Maps.newHashMap();
KeyTuple.Builder kBuilder = KeyTuple.newBuilder();
// Separate equality predicates from IN predicates
Map<String, Expr> equalityColumnExpr = Maps.newHashMap();
Map<String, List<Expr>> inColumnExprs = Maps.newHashMap();

for (Expr expr : shortCircuitQueryContext.scanNode.getConjuncts()) {
BinaryPredicate predicate = (BinaryPredicate) expr;
Expr left = predicate.getChild(0);
Expr right = predicate.getChild(1);
SlotRef columnSlot = left.unwrapSlotRef();
columnExpr.put(columnSlot.getColumnName(), right);
if (expr instanceof BinaryPredicate) {
BinaryPredicate predicate = (BinaryPredicate) expr;
Expr left = predicate.getChild(0);
Expr right = predicate.getChild(1);
SlotRef columnSlot = left.unwrapSlotRef();
equalityColumnExpr.put(columnSlot.getColumnName(), right);
} else if (expr instanceof InPredicate) {
InPredicate inPredicate = (InPredicate) expr;
SlotRef slot = inPredicate.getChild(0).unwrapSlotRef();
if (slot != null) {
List<Expr> values = new ArrayList<>();
for (int i = 1; i < inPredicate.getChildren().size(); i++) {
values.add(inPredicate.getChild(i));
}
inColumnExprs.put(slot.getColumnName(), values);
}
}
}
// add key tuple in keys order
for (Column column : shortCircuitQueryContext.scanNode.getOlapTable().getBaseSchemaKeyColumns()) {
kBuilder.addKeyColumnRep(columnExpr.get(column.getName()).getStringValue());

List<Column> keyColumns = shortCircuitQueryContext.scanNode.getOlapTable().getBaseSchemaKeyColumns();

if (inColumnExprs.isEmpty()) {
// Pure equality case: generate one KeyTuple
KeyTuple.Builder kBuilder = KeyTuple.newBuilder();
for (Column column : keyColumns) {
kBuilder.addKeyColumnRep(equalityColumnExpr.get(column.getName()).getStringValue());
}
requestBuilder.addKeyTuples(kBuilder);
} else {
// IN predicate case: generate one KeyTuple per combination
// Find the IN column and its values, combine with equality columns
// Note: currently only supports single IN column per query (all keys must be covered)
String inColName = inColumnExprs.keySet().iterator().next();
List<Expr> inValues = inColumnExprs.get(inColName);
for (Expr inVal : inValues) {
KeyTuple.Builder kBuilder = KeyTuple.newBuilder();
for (Column column : keyColumns) {
if (column.getName().equals(inColName)) {
kBuilder.addKeyColumnRep(inVal.getStringValue());
} else {
kBuilder.addKeyColumnRep(equalityColumnExpr.get(column.getName()).getStringValue());
}
}
requestBuilder.addKeyTuples(kBuilder);
}
}
requestBuilder.addKeyTuples(kBuilder);
}

@Override
Expand Down
Loading
Loading