Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions fe/fe-core/src/main/cup/sql_parser.cup
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ parser code {:
public boolean isVerbose = false;
public String wild;
public Expr where;
public ArrayList<PlaceHolderExpr> placeholder_expr_list = Lists.newArrayList();

// List of expected tokens ids from current parsing state for generating syntax error message
private final List<Integer> expectedTokenIds = Lists.newArrayList();
Expand Down Expand Up @@ -1074,7 +1075,11 @@ stmt ::=
| switch_stmt:stmt
{: RESULT = stmt; :}
| query_stmt:query
{: RESULT = query; :}
{:
RESULT = query;
query.setPlaceHolders(parser.placeholder_expr_list);
parser.placeholder_expr_list.clear();
:}
| drop_stmt:stmt
{: RESULT = stmt; :}
| recover_stmt:stmt
Expand Down Expand Up @@ -5185,6 +5190,8 @@ prepare_stmt ::=
KW_PREPARE variable_name:name KW_FROM select_stmt:s
{:
RESULT = new PrepareStmt(s, name, false);
s.setPlaceHolders(parser.placeholder_expr_list);
parser.placeholder_expr_list.clear();
:}
;

Expand Down Expand Up @@ -6741,9 +6748,9 @@ literal ::=
| KW_NULL
{: RESULT = new NullLiteral(); :}
| PLACEHOLDER
{: RESULT = new PlaceHolderExpr(); :}
{: RESULT = new PlaceHolderExpr(); parser.placeholder_expr_list.add((PlaceHolderExpr) RESULT); :}
| MOD
{: RESULT = new PlaceHolderExpr(); :}
{: RESULT = new PlaceHolderExpr(); parser.placeholder_expr_list.add((PlaceHolderExpr) RESULT); :}
| UNMATCHED_STRING_LITERAL:l expr:e
{:
// we have an unmatched string literal.
Expand Down
87 changes: 28 additions & 59 deletions fe/fe-core/src/main/java/org/apache/doris/analysis/PrepareStmt.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.doris.analysis;

import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.UserException;
import org.apache.doris.qe.ConnectContext;
Expand All @@ -42,9 +41,6 @@ public class PrepareStmt extends StatementBase {
private static final Logger LOG = LogManager.getLogger(PrepareStmt.class);
private StatementBase inner;
private String stmtName;
// select * from tbl where a = ? and b = ?
// `?` is the placeholder
protected List<PlaceHolderExpr> placeholders = new ArrayList<>();

// Cached for better CPU performance, since serialize DescriptorTable and
// outputExprs are heavy work
Expand Down Expand Up @@ -89,10 +85,6 @@ public UUID getID() {
return id;
}

public List<PlaceHolderExpr> placeholders() {
return this.placeholders;
}

public boolean isBinaryProtocol() {
return binaryRowFormat;
}
Expand Down Expand Up @@ -139,59 +131,14 @@ public ByteString getSerializedOutputExprs() {
return serializedOutputExpr;
}

public int getParmCount() {
return placeholders.size();
}

public List<Expr> getSlotRefOfPlaceHolders() {
ArrayList<Expr> slots = new ArrayList<>();
if (inner instanceof SelectStmt) {
SelectStmt select = (SelectStmt) inner;
for (PlaceHolderExpr pexpr : placeholders) {
// Only point query support
for (Map.Entry<SlotRef, Expr> entry :
select.getPointQueryEQPredicates().entrySet()) {
// same instance
if (entry.getValue() == pexpr) {
slots.add(entry.getKey());
}
}
}
return slots;
}
return null;
}

public List<String> getColLabelsOfPlaceHolders() {
ArrayList<String> lables = new ArrayList<>();
if (inner instanceof SelectStmt) {
for (Expr slotExpr : getSlotRefOfPlaceHolders()) {
SlotRef slot = (SlotRef) slotExpr;
Column c = slot.getColumn();
if (c != null) {
lables.add(c.getName());
continue;
}
lables.add("");
}
return lables;
}
return null;
}

@Override
public void analyze(Analyzer analyzer) throws UserException {
if (!(inner instanceof SelectStmt)) {
throw new UserException("Only support prepare SelectStmt now");
}
// Use tmpAnalyzer since selectStmt will be reAnalyzed
Analyzer tmpAnalyzer = new Analyzer(context.getEnv(), context);
// collect placeholders from stmt exprs tree
SelectStmt selectStmt = (SelectStmt) inner;
// TODO(lhy) support more clauses
if (selectStmt.getWhereClause() != null) {
selectStmt.getWhereClause().collect(PlaceHolderExpr.class, placeholders);
}
inner.analyze(tmpAnalyzer);
if (!selectStmt.checkAndSetPointQuery()) {
throw new UserException("Only support prepare SelectStmt point query now");
Expand All @@ -217,17 +164,40 @@ public StatementBase getInnerStmt() {
return inner;
}

public int argsSize() {
return placeholders.size();
public List<PlaceHolderExpr> placeholders() {
return inner.getPlaceHolders();
}

public int getParmCount() {
return inner.getPlaceHolders().size();
}

public List<Expr> getPlaceHolderExprList() {
ArrayList<Expr> slots = new ArrayList<>();
for (PlaceHolderExpr pexpr : inner.getPlaceHolders()) {
slots.add(pexpr);
}
return slots;
}

public List<String> getColLabelsOfPlaceHolders() {
ArrayList<String> lables = new ArrayList<>();
for (int i = 0; i < inner.getPlaceHolders().size(); ++i) {
lables.add("lable " + i);
}
return lables;
}

public void asignValues(List<LiteralExpr> values) throws UserException {
if (values.size() != placeholders.size()) {
if (values.size() != inner.getPlaceHolders().size()) {
throw new UserException("Invalid arguments size "
+ values.size() + ", expected " + placeholders.size());
+ values.size() + ", expected " + inner.getPlaceHolders().size());
}
for (int i = 0; i < values.size(); ++i) {
placeholders.get(i).setLiteral(values.get(i));
inner.getPlaceHolders().get(i).setLiteral(values.get(i));
}
if (!values.isEmpty()) {
LOG.debug("assign values {}", values.get(0).toSql());
}
}

Expand All @@ -237,7 +207,6 @@ public void reset() {
serializedOutputExpr = null;
descTable = null;
this.id = UUID.randomUUID();
placeholders.clear();
inner.reset();
serializedFields.clear();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

Expand All @@ -57,6 +58,10 @@ public abstract class StatementBase implements ParseNode {

private boolean isPrepared = false;

// select * from tbl where a = ? and b = ?
// `?` is the placeholder
private ArrayList<PlaceHolderExpr> placeholders = new ArrayList<>();

protected StatementBase() { }

/**
Expand Down Expand Up @@ -101,6 +106,14 @@ public boolean isExplain() {
return this.explainOptions != null;
}

public void setPlaceHolders(ArrayList<PlaceHolderExpr> placeholders) {
this.placeholders = new ArrayList<PlaceHolderExpr>(placeholders);
}

public ArrayList<PlaceHolderExpr> getPlaceHolders() {
return this.placeholders;
}

public boolean isVerbose() {
return explainOptions != null && explainOptions.isVerbose();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ public boolean supportClientLocalFile() {
return (flags & Flag.CLIENT_LOCAL_FILES.getFlagBit()) != 0;
}

public boolean isDeprecatedEOF() {
return (flags & Flag.CLIENT_DEPRECATE_EOF.getFlagBit()) != 0;
}

@Override
public boolean equals(Object obj) {
if (obj == null || !(obj instanceof MysqlCapability)) {
Expand Down
11 changes: 11 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,21 @@ public class MysqlChannel {

protected volatile MysqlSerializer serializer;

// mysql flag CLIENT_DEPRECATE_EOF
private boolean clientDeprecatedEOF;

protected MysqlChannel() {
// For DummyMysqlChannel
}

public void setClientDeprecatedEOF() {
clientDeprecatedEOF = true;
}

public boolean clientDeprecatedEOF() {
return clientDeprecatedEOF;
}

public MysqlChannel(StreamConnection connection) {
Preconditions.checkNotNull(connection);
this.sequenceId = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ public static boolean negotiate(ConnectContext context) throws IOException {
// receive response failed.
return false;
}
if (capability.isDeprecatedEOF()) {
context.getMysqlChannel().setClientDeprecatedEOF();
}
MysqlAuthPacket authPacket = new MysqlAuthPacket();
if (!authPacket.readFrom(handshakeResponse)) {
ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ private void dispatch() throws IOException {
LOG.warn("Unknown command(" + code + ")");
return;
}
LOG.debug("handle command {}", command);
ctx.setCommand(command);
ctx.setStartTime();

Expand Down
28 changes: 22 additions & 6 deletions fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -1922,8 +1922,6 @@ private void handlePrepareStmt() throws Exception {
if (prepareStmt.isBinaryProtocol()) {
sendStmtPrepareOK();
}
// context.getState().setEof();
context.getState().setOk();
}


Expand Down Expand Up @@ -1965,6 +1963,10 @@ private void sendMetaData(ResultSetMetaData metaData) throws IOException {
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
}

private List<PrimitiveType> exprToStringType(List<Expr> exprs) {
return exprs.stream().map(e -> PrimitiveType.STRING).collect(Collectors.toList());
}

private void sendStmtPrepareOK() throws IOException {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response
serializer.reset();
Expand All @@ -1979,13 +1981,27 @@ private void sendStmtPrepareOK() throws IOException {
int numParams = prepareStmt.getColLabelsOfPlaceHolders().size();
serializer.writeInt2(numParams);
// reserved_1
// serializer.writeInt1(0);
serializer.writeInt1(0);
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
if (numParams > 0) {
sendFields(prepareStmt.getColLabelsOfPlaceHolders(),
exprToType(prepareStmt.getSlotRefOfPlaceHolders()));
// send field one by one
// TODO use real type instead of string, for JDBC client it's ok
// but for other client, type should be correct
List<PrimitiveType> types = exprToStringType(prepareStmt.getPlaceHolderExprList());
List<String> colNames = prepareStmt.getColLabelsOfPlaceHolders();
LOG.debug("sendFields {}, {}", colNames, types);
for (int i = 0; i < colNames.size(); ++i) {
serializer.reset();
serializer.writeField(colNames.get(i), Type.fromPrimitiveType(types.get(i)));
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
}
}
// send EOF if nessessary
if (!context.getMysqlChannel().clientDeprecatedEOF()) {
context.getState().setEof();
} else {
context.getState().setOk();
}
context.getState().setOk();
}

private void sendFields(List<String> colNames, List<Type> types) throws IOException {
Expand Down