Skip to content

Commit

Permalink
fix: respect result format code from Bind msg (#238)
Browse files Browse the repository at this point in the history
* fix: respect result format code from Bind msg

The result format code given in a Bind message should always be
respected by PGAdapter. This was not done if PGAdapter received a
request for text format for a column with a data type that has binary
format as its default type.

* test: request data in binary and text format for all types
  • Loading branch information
olavloite committed Jul 1, 2022
1 parent a006dec commit 708fa42
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ protected void sendPayload() throws Exception {
this.resultSet.getType().getStructFields().get(column_index).getName().getBytes(UTF8));
// If it can be identified as a column of a table, the object ID of the table.
this.outputStream.writeByte(DEFAULT_FLAG);
// TODO: pass through Postgres types
// If it can be identified as a column of a table, the attribute number of the column
this.outputStream.writeInt(DEFAULT_FLAG);
// The object ID of the field's data type.
Expand All @@ -105,9 +104,7 @@ protected void sendPayload() throws Exception {
short format =
this.statement == null
? defaultFormat.getCode()
: this.statement.getResultFormatCode(column_index) == 0
? defaultFormat.getCode()
: this.statement.getResultFormatCode(column_index);
: this.statement.getResultFormatCode(column_index);
this.outputStream.writeShort(format);
}
}
Expand Down
18 changes: 15 additions & 3 deletions src/test/golang/pgadapter_pgx_tests/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestQueryWithParameter(connString string) *C.char {
}

//export TestQueryAllDataTypes
func TestQueryAllDataTypes(connString string) *C.char {
func TestQueryAllDataTypes(connString string, oid, format int16) *C.char {
ctx := context.Background()
conn, err := pgx.Connect(ctx, connString)
if err != nil {
Expand All @@ -120,7 +120,19 @@ func TestQueryAllDataTypes(connString string) *C.char {
var dateValue time.Time
var varcharValue string

row := conn.QueryRow(ctx, "SELECT * FROM all_types WHERE col_bigint=1")
var row pgx.Row
if oid != 0 {
formats := make(pgx.QueryResultFormatsByOID)
for _, o := range []uint32{
pgtype.Int8OID, pgtype.BoolOID, pgtype.ByteaOID, pgtype.Float8OID, pgtype.Int4OID,
pgtype.NumericOID, pgtype.TimestamptzOID, pgtype.DateOID, pgtype.VarcharOID} {
formats[o] = conn.ConnInfo().ResultFormatCodeForOID(o)
}
formats[uint32(oid)] = format
row = conn.QueryRow(ctx, "SELECT * FROM all_types WHERE col_bigint=1", formats)
} else {
row = conn.QueryRow(ctx, "SELECT * FROM all_types WHERE col_bigint=1")
}
err = row.Scan(
&bigintValue,
&boolValue,
Expand Down Expand Up @@ -161,7 +173,7 @@ func TestQueryAllDataTypes(connString string) *C.char {
}
// Encoding the timestamp values as a parameter will truncate it to microsecond precision.
wantTimestamptzValue, _ := time.Parse(time.RFC3339Nano, "2022-02-16T13:18:02.123456+00:00")
if strings.Contains(connString, "prefer_simple_protocol=true") {
if strings.Contains(connString, "prefer_simple_protocol=true") || (oid == pgtype.TimestamptzOID && format == 0) {
// Simple protocol writes the timestamp as a string and preserves nanosecond precision.
wantTimestamptzValue, _ = time.Parse(time.RFC3339Nano, "2022-02-16T13:18:02.123456789+00:00")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ public void testSelectHelloWorld() {

@Test
public void testQueryAllDataTypes() {
assertNull(pgxTest.TestQueryAllDataTypes(createConnString()));
assertNull(pgxTest.TestQueryAllDataTypes(createConnString(), 0, 0));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
import org.postgresql.core.Oid;

/**
* Tests PGAdapter using the native Go pgx driver. The Go code can be found in
Expand Down Expand Up @@ -214,20 +215,41 @@ public void testQueryAllDataTypes() {
String sql = "SELECT * FROM all_types WHERE col_bigint=1";
mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), ALL_TYPES_RESULTSET));

String res = pgxTest.TestQueryAllDataTypes(createConnString());

assertNull(res);
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
// pgx by default always uses prepared statements. As this statement does not contain any
// parameters, we don't need to describe the parameter types, so it is 'only' sent twice to the
// backend.
assertEquals(2, requests.size());
ExecuteSqlRequest describeRequest = requests.get(0);
assertEquals(sql, describeRequest.getSql());
assertEquals(QueryMode.PLAN, describeRequest.getQueryMode());
ExecuteSqlRequest executeRequest = requests.get(1);
assertEquals(sql, executeRequest.getSql());
assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode());
// Request the data of each column once in both text and binary format to ensure that we support
// each format for all data types, *AND* that PGAdapter actually uses the format that the client
// requests.
for (int oid :
new int[] {
Oid.INT8,
Oid.BOOL,
Oid.BYTEA,
Oid.FLOAT8,
Oid.INT4,
Oid.NUMERIC,
Oid.DATE,
Oid.TIMESTAMPTZ,
Oid.VARCHAR
}) {
for (int format : new int[] {0, 1}) {
String res = pgxTest.TestQueryAllDataTypes(createConnString(), oid, format);

assertNull(res);
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
// pgx by default always uses prepared statements. As this statement does not contain any
// parameters, we don't need to describe the parameter types, so it is 'only' sent twice to
// the
// backend.
assertEquals(2, requests.size());
ExecuteSqlRequest describeRequest = requests.get(0);
assertEquals(sql, describeRequest.getSql());
assertEquals(QueryMode.PLAN, describeRequest.getQueryMode());
ExecuteSqlRequest executeRequest = requests.get(1);
assertEquals(sql, executeRequest.getSql());
assertEquals(QueryMode.NORMAL, executeRequest.getQueryMode());

mockSpanner.clearRequests();
}
}
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public void testQueryAllDataTypes() {
String sql = "SELECT * FROM all_types WHERE col_bigint=1";
mockSpanner.putStatementResult(StatementResult.query(Statement.of(sql), ALL_TYPES_RESULTSET));

String res = pgxTest.TestQueryAllDataTypes(createConnString());
String res = pgxTest.TestQueryAllDataTypes(createConnString(), 0, 0);

assertNull(res);
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public interface PgxTest extends Library {

String TestQueryWithParameter(GoString connString);

String TestQueryAllDataTypes(GoString connString);
String TestQueryAllDataTypes(GoString connString, int oid, int format);

String TestInsertAllDataTypes(GoString connString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,79 @@ public void SendPayloadNullStatementTest() throws Exception {
assertEquals(0, outputReader.readShort());
}

@Test
public void SendPayloadStatementWithBinarayAndTextOptionTest() throws Exception {
Type rowType =
Type.struct(
StructField.of("numeric-text", Type.pgNumeric()),
StructField.of("numeric-binary", Type.pgNumeric()));
when(metadata.getColumnCount()).thenReturn(rowType.getStructFields().size());
when(metadata.getType()).thenReturn(rowType);
when(metadata.getColumnType(Mockito.anyInt())).thenReturn(Type.pgNumeric());
when(statement.getResultFormatCode(0)).thenReturn((short) 0);
when(statement.getResultFormatCode(1)).thenReturn((short) 1);
JSONParser parser = new JSONParser();
JSONObject commandMetadata = (JSONObject) parser.parse(EMPTY_COMMAND_JSON);
OptionsMetadata options =
new OptionsMetadata(
"jdbc:cloudspanner:/projects/test-project/instances/test-instance/databases/test-database",
8888,
TextFormat.POSTGRESQL,
true,
false,
false,
false,
commandMetadata);
QueryMode mode = QueryMode.EXTENDED;
RowDescriptionResponse response =
new RowDescriptionResponse(output, statement, metadata, options, mode);
response.sendPayload();
DataInputStream outputReader =
new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()));

// column count
assertEquals(2, outputReader.readShort());
// column name
int numOfBytes = "numeric-text".getBytes(UTF8).length;
byte[] bytes = new byte[numOfBytes];
assertEquals(numOfBytes, outputReader.read(bytes, 0, numOfBytes));
assertEquals("numeric-text", new String(bytes, UTF8));
// null terminator
assertEquals(DEFAULT_FLAG, outputReader.readByte());
// table oid
assertEquals(DEFAULT_FLAG, outputReader.readInt());
// column index
assertEquals(DEFAULT_FLAG, outputReader.readShort());
// type oid
assertEquals(Oid.NUMERIC, outputReader.readInt());
// type size
assertEquals(-1, outputReader.readShort());
// type modifier
assertEquals(DEFAULT_FLAG, outputReader.readInt());
// format code
assertEquals(0, outputReader.readShort());

// column name
numOfBytes = "numeric-binary".getBytes(UTF8).length;
bytes = new byte[numOfBytes];
assertEquals(numOfBytes, outputReader.read(bytes, 0, numOfBytes));
assertEquals("numeric-binary", new String(bytes, UTF8));
// null terminator
assertEquals(DEFAULT_FLAG, outputReader.readByte());
// table oid
assertEquals(DEFAULT_FLAG, outputReader.readInt());
// column index
assertEquals(DEFAULT_FLAG, outputReader.readShort());
// type oid
assertEquals(Oid.NUMERIC, outputReader.readInt());
// type size
assertEquals(-1, outputReader.readShort());
// type modifier
assertEquals(DEFAULT_FLAG, outputReader.readInt());
// format code
assertEquals(1, outputReader.readShort());
}

@Test
public void SendPayloadStatementWithBinaryOptionTest() throws Exception {
int COLUMN_COUNT = 1;
Expand All @@ -168,7 +241,7 @@ public void SendPayloadStatementWithBinaryOptionTest() throws Exception {
when(metadata.getColumnCount()).thenReturn(COLUMN_COUNT);
when(metadata.getType()).thenReturn(rowType);
when(metadata.getColumnType(Mockito.anyInt())).thenReturn(Type.int64());
when(statement.getResultFormatCode(Mockito.anyInt())).thenReturn((short) 0);
when(statement.getResultFormatCode(Mockito.anyInt())).thenReturn((short) 1);
JSONParser parser = new JSONParser();
JSONObject commandMetadata = (JSONObject) parser.parse(EMPTY_COMMAND_JSON);
OptionsMetadata options =
Expand Down

0 comments on commit 708fa42

Please sign in to comment.