Skip to content

Commit

Permalink
perf: skip analyzeQuery for queries (#80)
Browse files Browse the repository at this point in the history
* perf: skip analyzeQuery for queries

Skip the analyzeQuery round trip for DescribePortal messages for
queries. This is safe, as:
1. The query will not make any changes to the database, so even if we
   never receive an Execute message, we have not made any changes.
2. ExecuteQuery returns the information we need for Describe.

Note that for DML/DDL statements the DescribePortal message is handled
internally by PGAdapter as it knows that the answer is NoDataResponse.

* chore: fix comment formatting

* fix: remove local file reference
  • Loading branch information
olavloite committed Apr 1, 2022
1 parent 4c24ef9 commit 98e430a
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

package com.google.cloud.spanner.pgadapter.statements;

import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode;
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
Expand Down Expand Up @@ -82,24 +80,18 @@ public void setResultFormatCodes(List<Short> resultFormatCodes) {

@Override
public DescribeMetadata describe() {
// TODO: Consider replacing this with an execute call, so we don't take two round-trips to the
// backend just to first describe and then execute a query.
try (ResultSet resultSet = connection.analyzeQuery(this.statement, QueryAnalyzeMode.PLAN)) {
// TODO: Remove ResultSet.next() call once this is supported in the client library.
// See https://github.com/googleapis/java-spanner/pull/1691
resultSet.next();
return new DescribePortalMetadata(resultSet);
try {
// Pre-emptively execute the statement, even though it is only asked to be described. This is
// a lot more efficient than taking two round trips to the server, and getting a
// DescribePortal message without a following Execute message is extremely rare, as that would
// only happen if the client is ill-behaved, or if the client crashes between the
// DescribePortal and Execute.
this.statementResult = connection.executeQuery(this.statement);
this.hasMoreData = this.statementResult.next();
return new DescribePortalMetadata(statementResult);
} catch (SpannerException e) {
/* Generally this error will occur when a non-SELECT portal statement is described in Spanner,
however, it could occur when a statement is incorrectly formatted. Though we could catch
this early if we could parse the type of statement, it is a significant burden on the
proxy. As such, we send the user a descriptive message to help them understand the issue
in case they misuse the method.
*/
logger.log(Level.SEVERE, e, e::toString);
throw new IllegalStateException(
"Something went wrong in Describing this statement."
+ "Note that non-SELECT result types in Spanner cannot be described.");
logger.log(Level.SEVERE, e, e::getMessage);
throw e;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ public void setParameterDataTypes(List<Integer> parameterDataTypes) {
@Override
public void execute() {
this.executed = true;
try {
StatementResult result = connection.execute(this.statement);
this.updateResultCount(result);
} catch (SpannerException e) {
handleExecutionException(e);
// If the portal has already been described, the statement has already been executed, and we
// don't need to do that once more.
if (this.statementResult == null) {
try {
StatementResult result = connection.execute(this.statement);
this.updateResultCount(result);
} catch (SpannerException e) {
handleExecutionException(e);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,12 @@ public void testQuery() throws SQLException {
}
}

// The statement is sent twice to the mock server:
// 1. The first time it is sent with the PLAN mode enabled.
// 2. The second time it is sent in normal execute mode.
// TODO: Consider skipping the PLAN step and always execute the query already when we receive a
// DESCRIBE portal message.
assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
ExecuteSqlRequest planRequest = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
assertEquals(QueryMode.PLAN, planRequest.getQueryMode());
// The statement is only sent once to the mock server. The DescribePortal message will trigger
// the execution of the query, and the result from that execution will be used for the Execute
// message.
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
ExecuteSqlRequest executeRequest =
mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(1);
mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).get(0);
assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode());

for (ExecuteSqlRequest request : mockSpanner.getRequestsOfType(ExecuteSqlRequest.class)) {
Expand Down Expand Up @@ -163,38 +159,32 @@ public void testQueryWithParameters() throws SQLException {
}

List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
assertEquals(2, requests.size());

ExecuteSqlRequest planRequest = requests.get(0);
assertEquals(QueryMode.PLAN, planRequest.getQueryMode());
assertEquals(pgSql, planRequest.getSql());
assertEquals(1, requests.size());

ExecuteSqlRequest executeRequest = requests.get(1);
ExecuteSqlRequest executeRequest = requests.get(0);
assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode());
assertEquals(pgSql, executeRequest.getSql());

for (ExecuteSqlRequest request : requests) {
Map<String, Value> params = request.getParams().getFieldsMap();
Map<String, Type> types = request.getParamTypesMap();

assertEquals(TypeCode.INT64, types.get("p1").getCode());
assertEquals("1", params.get("p1").getStringValue());
assertEquals(TypeCode.BOOL, types.get("p2").getCode());
assertTrue(params.get("p2").getBoolValue());
assertEquals(TypeCode.BYTES, types.get("p3").getCode());
assertEquals(
Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)),
params.get("p3").getStringValue());
assertEquals(TypeCode.FLOAT64, types.get("p4").getCode());
assertEquals(3.14d, params.get("p4").getNumberValue(), 0.0d);
assertEquals(TypeCode.NUMERIC, types.get("p5").getCode());
assertEquals(TypeAnnotationCode.PG_NUMERIC, types.get("p5").getTypeAnnotation());
assertEquals("6.626", params.get("p5").getStringValue());
assertEquals(TypeCode.TIMESTAMP, types.get("p6").getCode());
assertEquals("2022-02-16T13:18:02.123457000Z", params.get("p6").getStringValue());
assertEquals(TypeCode.STRING, types.get("p7").getCode());
assertEquals("test", params.get("p7").getStringValue());
}
Map<String, Value> params = executeRequest.getParams().getFieldsMap();
Map<String, Type> types = executeRequest.getParamTypesMap();

assertEquals(TypeCode.INT64, types.get("p1").getCode());
assertEquals("1", params.get("p1").getStringValue());
assertEquals(TypeCode.BOOL, types.get("p2").getCode());
assertTrue(params.get("p2").getBoolValue());
assertEquals(TypeCode.BYTES, types.get("p3").getCode());
assertEquals(
Base64.getEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)),
params.get("p3").getStringValue());
assertEquals(TypeCode.FLOAT64, types.get("p4").getCode());
assertEquals(3.14d, params.get("p4").getNumberValue(), 0.0d);
assertEquals(TypeCode.NUMERIC, types.get("p5").getCode());
assertEquals(TypeAnnotationCode.PG_NUMERIC, types.get("p5").getTypeAnnotation());
assertEquals("6.626", params.get("p5").getStringValue());
assertEquals(TypeCode.TIMESTAMP, types.get("p6").getCode());
assertEquals("2022-02-16T13:18:02.123457000Z", params.get("p6").getStringValue());
assertEquals(TypeCode.STRING, types.get("p7").getCode());
assertEquals("test", params.get("p7").getStringValue());
}

@Test
Expand Down Expand Up @@ -283,6 +273,7 @@ public void testNullValues() throws SQLException {
assertFalse(resultSet.next());
}
}
assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
}

@Test
Expand Down Expand Up @@ -376,6 +367,7 @@ public void testTwoQueries() throws SQLException {
assertEquals(-1, statement.getUpdateCount());
}
}
assertEquals(2, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -310,18 +311,16 @@ public void testPreparedStatementDescribeDoesNotThrowException() throws Exceptio
}

@Test
public void testPortalStatement() throws Exception {
public void testPortalStatement() {
String sqlStatement = "SELECT * FROM users WHERE age > $1 AND age < $2 AND name = $3";
when(connection.analyzeQuery(Statement.of(sqlStatement), QueryAnalyzeMode.PLAN))
.thenReturn(resultSet);
when(connection.executeQuery(Statement.of(sqlStatement))).thenReturn(resultSet);

IntermediatePortalStatement intermediateStatement =
new IntermediatePortalStatement(options, parse(sqlStatement), connection);

intermediateStatement.describe();

Mockito.verify(connection, Mockito.times(1))
.analyzeQuery(Statement.of(sqlStatement), QueryAnalyzeMode.PLAN);
Mockito.verify(connection, Mockito.times(1)).executeQuery(Statement.of(sqlStatement));

assertEquals(intermediateStatement.getParameterFormatCode(0), 0);
assertEquals(intermediateStatement.getParameterFormatCode(1), 0);
Expand Down Expand Up @@ -349,18 +348,20 @@ public void testPortalStatement() throws Exception {
assertEquals(intermediateStatement.getResultFormatCode(2), 0);
}

@Test(expected = IllegalStateException.class)
public void testPortalStatementDescribePropagatesFailure() throws Exception {
@Test
public void testPortalStatementDescribePropagatesFailure() {
String sqlStatement = "SELECT * FROM users WHERE age > $1 AND age < $2 AND name = $3";

IntermediatePortalStatement intermediateStatement =
new IntermediatePortalStatement(options, parse(sqlStatement), connection);

when(connection.analyzeQuery(Statement.of(sqlStatement), QueryAnalyzeMode.PLAN))
when(connection.executeQuery(Statement.of(sqlStatement)))
.thenThrow(
SpannerExceptionFactory.newSpannerException(ErrorCode.INVALID_ARGUMENT, "test error"));

intermediateStatement.describe();
SpannerException exception =
assertThrows(SpannerException.class, intermediateStatement::describe);
assertEquals(ErrorCode.INVALID_ARGUMENT, exception.getErrorCode());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,14 @@ public void testHelloWorld() {

assertNull(res);
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
// pgx sends the query three times:
// pgx sends the query two times:
// 1. DESCRIBE statement
// 2. DESCRIBE portal
// 3. EXECUTE portal
assertEquals(3, requests.size());
// 2. DESCRIBE/EXECUTE portal
assertEquals(2, requests.size());
int index = 0;
for (ExecuteSqlRequest request : requests) {
assertEquals(sql, request.getSql());
if (index < 2) {
if (index < 1) {
assertEquals(QueryMode.PLAN, request.getQueryMode());
} else {
assertEquals(QueryMode.NORMAL, request.getQueryMode());
Expand All @@ -167,15 +166,14 @@ public void testSelect1() {

assertNull(res);
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
// pgx sends the query three times:
// pgx sends the query two times:
// 1. DESCRIBE statement
// 2. DESCRIBE portal
// 3. EXECUTE portal
assertEquals(3, requests.size());
// 2. DESCRIBE/EXECUTE portal
assertEquals(2, requests.size());
int index = 0;
for (ExecuteSqlRequest request : requests) {
assertEquals(sql, request.getSql());
if (index < 2) {
if (index < 1) {
assertEquals(QueryMode.PLAN, request.getQueryMode());
} else {
assertEquals(QueryMode.NORMAL, request.getQueryMode());
Expand Down Expand Up @@ -221,15 +219,14 @@ public void testQueryWithParameter() {

assertNull(res);
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
// pgx sends the query three times:
// pgx sends the query two times:
// 1. DESCRIBE statement
// 2. DESCRIBE portal
// 3. EXECUTE portal
assertEquals(3, requests.size());
// 2. DESCRIBE/EXECUTE portal
assertEquals(2, requests.size());
int index = 0;
for (ExecuteSqlRequest request : requests) {
assertEquals(sql, request.getSql());
if (index < 2) {
if (index < 1) {
assertEquals(QueryMode.PLAN, request.getQueryMode());
} else {
assertEquals(QueryMode.NORMAL, request.getQueryMode());
Expand All @@ -247,15 +244,14 @@ public void testQueryAllDataTypes() {

assertNull(res);
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
// pgx sends the query three times:
// pgx sends the query two times:
// 1. DESCRIBE statement
// 2. DESCRIBE portal
// 3. EXECUTE portal
assertEquals(3, requests.size());
// 2. EXECUTE portal
assertEquals(2, requests.size());
int index = 0;
for (ExecuteSqlRequest request : requests) {
assertEquals(sql, request.getSql());
if (index < 2) {
if (index < 1) {
assertEquals(QueryMode.PLAN, request.getQueryMode());
} else {
assertEquals(QueryMode.NORMAL, request.getQueryMode());
Expand Down

0 comments on commit 98e430a

Please sign in to comment.