Skip to content
Permalink
Browse files
IGNITE-16268 JDBC. Validation of statement type should be done before…
… its execution (#613)
  • Loading branch information
vladErmakov07 committed Feb 24, 2022
1 parent f1a932f commit 3288848aa09076a8eceb1d69173f22ca251367e2
Showing 18 changed files with 386 additions and 61 deletions.
@@ -17,29 +17,54 @@

package org.apache.ignite.client.proto.query;

import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* JDBC statement type.
*/
public enum JdbcStatementType {
/** Any statement type. */
ANY_STATEMENT_TYPE,
ANY_STATEMENT_TYPE((byte) 0),

/** Select statement type. */
SELECT_STATEMENT_TYPE,
SELECT_STATEMENT_TYPE((byte) 1),

/** DML / DDL statement type. */
UPDATE_STMT_TYPE;
UPDATE_STATEMENT_TYPE((byte) 2);

private static final Map<Byte, JdbcStatementType> STATEMENT_TYPE_IDX;

/** Enumerated values. */
private static final JdbcStatementType[] VALS = values();
static {
STATEMENT_TYPE_IDX = Arrays.stream(values()).collect(
Collectors.toMap(JdbcStatementType::getId, Function.identity()));
}

/**
* Efficiently gets enumerated value from its ordinal.
* Gets statement type value by its id.
*
* @param ord Ordinal value.
* @return Enumerated value or {@code null} if ordinal out of range.
* @param id The id.
* @return JdbcStatementType value.
* @throws IllegalArgumentException If statement is not found.
*/
public static JdbcStatementType fromOrdinal(int ord) {
return ord >= 0 && ord < VALS.length ? VALS[ord] : null;
public static JdbcStatementType getStatement(byte id) {
JdbcStatementType value = STATEMENT_TYPE_IDX.get(id);

Objects.requireNonNull(value, String.format("Unknown jdbcStatementType %s", id));

return value;
}

private final byte id;

JdbcStatementType(byte id) {
this.id = id;
}

public byte getId() {
return id;
}
}
@@ -17,7 +17,9 @@

package org.apache.ignite.client.proto.query.event;

import java.util.Objects;
import org.apache.ignite.client.proto.query.ClientMessage;
import org.apache.ignite.client.proto.query.JdbcStatementType;
import org.apache.ignite.internal.client.proto.ClientMessagePacker;
import org.apache.ignite.internal.client.proto.ClientMessageUnpacker;
import org.apache.ignite.internal.tostring.S;
@@ -26,6 +28,9 @@
* JDBC query execute request.
*/
public class QueryExecuteRequest implements ClientMessage {
/** Expected statement type. */
private JdbcStatementType stmtType;

/** Schema name. */
private String schemaName;

@@ -50,14 +55,18 @@ public QueryExecuteRequest() {
/**
* Constructor.
*
* @param stmtType Expected statement type.
* @param schemaName Cache name.
* @param pageSize Fetch size.
* @param maxRows Max rows.
* @param sqlQry SQL query.
* @param args Arguments list.
*/
public QueryExecuteRequest(String schemaName, int pageSize, int maxRows, String sqlQry, Object[] args) {
public QueryExecuteRequest(JdbcStatementType stmtType,
String schemaName, int pageSize, int maxRows, String sqlQry, Object[] args) {
Objects.requireNonNull(stmtType);

this.stmtType = stmtType;
this.schemaName = schemaName == null || schemaName.isEmpty() ? null : schemaName;
this.pageSize = pageSize;
this.maxRows = maxRows;
@@ -66,7 +75,7 @@ public QueryExecuteRequest(String schemaName, int pageSize, int maxRows, String
}

/**
* Get the page size.
* Returns the page size.
*
* @return Page size.
*/
@@ -75,7 +84,7 @@ public int pageSize() {
}

/**
* Get the max rows.
* Returns the max rows.
*
* @return Max rows.
*/
@@ -84,7 +93,7 @@ public int maxRows() {
}

/**
* Get the sql query.
* Returns the sql query.
*
* @return Sql query.
*/
@@ -93,7 +102,7 @@ public String sqlQuery() {
}

/**
* Get the arguments.
* Returns the arguments.
*
* @return Sql query arguments.
*/
@@ -102,17 +111,27 @@ public Object[] arguments() {
}

/**
* Get the schema name.
* Returns the schema name.
*
* @return Schema name.
*/
public String schemaName() {
return schemaName;
}

/**
* Returns the expected statement type.
*
* @return Statement type.
*/
public JdbcStatementType getStmtType() {
return stmtType;
}

/** {@inheritDoc} */
@Override
public void writeBinary(ClientMessagePacker packer) {
packer.packByte(stmtType.getId());
packer.packString(schemaName);
packer.packInt(pageSize);
packer.packInt(maxRows);
@@ -124,6 +143,7 @@ public void writeBinary(ClientMessagePacker packer) {
/** {@inheritDoc} */
@Override
public void readBinary(ClientMessageUnpacker unpacker) {
stmtType = JdbcStatementType.getStatement(unpacker.unpackByte());
schemaName = unpacker.unpackString();
pageSize = unpacker.unpackInt();
maxRows = unpacker.unpackInt();
@@ -34,6 +34,7 @@
import org.apache.ignite.client.handler.requests.sql.JdbcMetadataCatalog;
import org.apache.ignite.client.handler.requests.sql.JdbcQueryCursor;
import org.apache.ignite.client.proto.query.JdbcQueryEventHandler;
import org.apache.ignite.client.proto.query.JdbcStatementType;
import org.apache.ignite.client.proto.query.event.BatchExecuteRequest;
import org.apache.ignite.client.proto.query.event.BatchExecuteResult;
import org.apache.ignite.client.proto.query.event.BatchPreparedStmntRequest;
@@ -55,10 +56,15 @@
import org.apache.ignite.client.proto.query.event.QueryFetchResult;
import org.apache.ignite.client.proto.query.event.QuerySingleResult;
import org.apache.ignite.client.proto.query.event.Response;
import org.apache.ignite.internal.sql.engine.QueryContext;
import org.apache.ignite.internal.sql.engine.QueryProcessor;
import org.apache.ignite.internal.sql.engine.QueryValidator;
import org.apache.ignite.internal.sql.engine.ResultFieldMetadata;
import org.apache.ignite.internal.sql.engine.ResultSetMetadata;
import org.apache.ignite.internal.sql.engine.SqlCursor;
import org.apache.ignite.internal.sql.engine.exec.QueryValidationException;
import org.apache.ignite.internal.sql.engine.prepare.QueryPlan;
import org.apache.ignite.internal.sql.engine.prepare.QueryPlan.Type;
import org.apache.ignite.internal.sql.engine.util.Commons;
import org.apache.ignite.internal.util.Cursor;

@@ -99,7 +105,9 @@ public CompletableFuture<QueryExecuteResult> queryAsync(QueryExecuteRequest req)

List<SqlCursor<List<?>>> cursors;
try {
List<SqlCursor<List<?>>> queryCursors = processor.query(req.schemaName(), req.sqlQuery(),
QueryContext context = createQueryContext(req.getStmtType());

List<SqlCursor<List<?>>> queryCursors = processor.query(context, req.schemaName(), req.sqlQuery(),
req.arguments() == null ? OBJECT_EMPTY_ARRAY : req.arguments());

cursors = queryCursors.stream()
@@ -134,6 +142,27 @@ public CompletableFuture<QueryExecuteResult> queryAsync(QueryExecuteRequest req)
return CompletableFuture.completedFuture(new QueryExecuteResult(results));
}

private QueryContext createQueryContext(JdbcStatementType stmtType) {
if (stmtType == JdbcStatementType.ANY_STATEMENT_TYPE) {
return QueryContext.of();
}

QueryValidator validator = (QueryPlan plan) -> {
if (plan.type() == Type.QUERY || plan.type() == Type.EXPLAIN) {
if (stmtType == JdbcStatementType.SELECT_STATEMENT_TYPE) {
return;
}
throw new QueryValidationException("Given statement type does not match that declared by JDBC driver.");
}
if (stmtType == JdbcStatementType.UPDATE_STATEMENT_TYPE) {
return;
}
throw new QueryValidationException("Given statement type does not match that declared by JDBC driver.");
};

return QueryContext.of(validator);
}

/** {@inheritDoc} */
@Override
public CompletableFuture<QueryFetchResult> fetchAsync(QueryFetchRequest req) {
@@ -172,9 +201,11 @@ public CompletableFuture<BatchExecuteResult> batchAsync(BatchExecuteRequest req)

IntList res = new IntArrayList(queries.size());

QueryContext context = createQueryContext(JdbcStatementType.UPDATE_STATEMENT_TYPE);

for (String query : queries) {
try {
executeAndCollectUpdateCount(req.schemaName(), query, OBJECT_EMPTY_ARRAY, res);
executeAndCollectUpdateCount(context, req.schemaName(), query, OBJECT_EMPTY_ARRAY, res);
} catch (Exception e) {
return handleBatchException(e, query, res);
}
@@ -188,9 +219,11 @@ public CompletableFuture<BatchExecuteResult> batchAsync(BatchExecuteRequest req)
public CompletableFuture<BatchExecuteResult> batchPrepStatementAsync(BatchPreparedStmntRequest req) {
IntList res = new IntArrayList(req.getArgs().size());

QueryContext context = createQueryContext(JdbcStatementType.UPDATE_STATEMENT_TYPE);

try {
for (Object[] arg : req.getArgs()) {
executeAndCollectUpdateCount(req.schemaName(), req.getQuery(), arg, res);
executeAndCollectUpdateCount(context, req.schemaName(), req.getQuery(), arg, res);
}
} catch (Exception e) {
return handleBatchException(e, req.getQuery(), res);
@@ -199,8 +232,8 @@ public CompletableFuture<BatchExecuteResult> batchPrepStatementAsync(BatchPrepar
return CompletableFuture.completedFuture(new BatchExecuteResult(res.toIntArray()));
}

private void executeAndCollectUpdateCount(String schema, String sql, Object[] arg, IntList res) {
List<SqlCursor<List<?>>> cursors = processor.query(schema, sql, arg);
private void executeAndCollectUpdateCount(QueryContext context, String schema, String sql, Object[] arg, IntList res) {
List<SqlCursor<List<?>>> cursors = processor.query(context, schema, sql, arg);
for (SqlCursor<List<?>> cursor : cursors) {
long updatedRows = (long) cursor.next().get(0);
res.add((int) updatedRows);
@@ -46,6 +46,7 @@
import java.util.List;
import java.util.Objects;
import org.apache.ignite.client.proto.query.IgniteQueryErrorCode;
import org.apache.ignite.client.proto.query.JdbcStatementType;
import org.apache.ignite.client.proto.query.SqlStateCode;
import org.apache.ignite.client.proto.query.event.BatchExecuteResult;
import org.apache.ignite.client.proto.query.event.BatchPreparedStmntRequest;
@@ -81,7 +82,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements PreparedStat
/** {@inheritDoc} */
@Override
public ResultSet executeQuery() throws SQLException {
executeWithArguments();
executeWithArguments(JdbcStatementType.SELECT_STATEMENT_TYPE);

ResultSet rs = getResultSet();

@@ -132,7 +133,7 @@ public int[] executeBatch() throws SQLException {
/** {@inheritDoc} */
@Override
public int executeUpdate() throws SQLException {
executeWithArguments();
executeWithArguments(JdbcStatementType.UPDATE_STATEMENT_TYPE);

int res = getUpdateCount();

@@ -174,7 +175,7 @@ public int executeUpdate(String sql, int[] colNames) throws SQLException {
/** {@inheritDoc} */
@Override
public boolean execute() throws SQLException {
executeWithArguments();
executeWithArguments(JdbcStatementType.ANY_STATEMENT_TYPE);

return isQuery();
}
@@ -649,10 +650,11 @@ private void setArgument(int paramIdx, Object val) throws SQLException {
/**
* Execute query with arguments and nullify them afterwards.
*
* @param JdbcStatementType Expected statement type.
* @throws SQLException If failed.
*/
private void executeWithArguments() throws SQLException {
execute0(sql, currentArgs);
private void executeWithArguments(JdbcStatementType statementType) throws SQLException {
execute0(statementType, sql, currentArgs);

currentArgs = null;
}
@@ -33,6 +33,7 @@
import java.util.List;
import java.util.Objects;
import org.apache.ignite.client.proto.query.IgniteQueryErrorCode;
import org.apache.ignite.client.proto.query.JdbcStatementType;
import org.apache.ignite.client.proto.query.SqlStateCode;
import org.apache.ignite.client.proto.query.event.BatchExecuteRequest;
import org.apache.ignite.client.proto.query.event.BatchExecuteResult;
@@ -100,7 +101,7 @@ public class JdbcStatement implements Statement {
/** {@inheritDoc} */
@Override
public ResultSet executeQuery(String sql) throws SQLException {
execute0(Objects.requireNonNull(sql), null);
execute0(JdbcStatementType.SELECT_STATEMENT_TYPE, Objects.requireNonNull(sql), null);

ResultSet rs = getResultSet();

@@ -118,7 +119,7 @@ public ResultSet executeQuery(String sql) throws SQLException {
* @param args Query parameters.
* @throws SQLException Onj error.
*/
protected void execute0(String sql, List<Object> args) throws SQLException {
protected void execute0(JdbcStatementType stmtType, String sql, List<Object> args) throws SQLException {
ensureNotClosed();

closeResults();
@@ -127,7 +128,7 @@ protected void execute0(String sql, List<Object> args) throws SQLException {
throw new SQLException("SQL query is empty.");
}

QueryExecuteRequest req = new QueryExecuteRequest(schema, pageSize, maxRows, sql,
QueryExecuteRequest req = new QueryExecuteRequest(stmtType, schema, pageSize, maxRows, sql,
args == null ? ArrayUtils.OBJECT_EMPTY_ARRAY : args.toArray());

QueryExecuteResult res = conn.handler().queryAsync(req).join();
@@ -156,7 +157,7 @@ protected void execute0(String sql, List<Object> args) throws SQLException {
/** {@inheritDoc} */
@Override
public int executeUpdate(String sql) throws SQLException {
execute0(Objects.requireNonNull(sql), null);
execute0(JdbcStatementType.UPDATE_STATEMENT_TYPE, Objects.requireNonNull(sql), null);

int res = getUpdateCount();

@@ -320,7 +321,7 @@ public void setCursorName(String name) throws SQLException {
public boolean execute(String sql) throws SQLException {
ensureNotClosed();

execute0(Objects.requireNonNull(sql), null);
execute0(JdbcStatementType.ANY_STATEMENT_TYPE, Objects.requireNonNull(sql), null);

return isQuery();
}
@@ -19,6 +19,7 @@

import java.util.Collections;
import java.util.List;
import org.apache.ignite.internal.sql.engine.QueryContext;
import org.apache.ignite.internal.sql.engine.QueryProcessor;
import org.apache.ignite.internal.sql.engine.SqlCursor;

@@ -31,6 +32,11 @@ public List<SqlCursor<List<?>>> query(String schemaName, String qry, Object... p
return Collections.singletonList(new FakeCursor());
}

@Override
public List<SqlCursor<List<?>>> query(QueryContext context, String schemaName, String qry, Object... params) {
return Collections.singletonList(new FakeCursor());
}

@Override
public void start() {

0 comments on commit 3288848

Please sign in to comment.