Skip to content

Commit

Permalink
[Format][C][Java][Python] Simplify execute/query interface (#69)
Browse files Browse the repository at this point in the history
* [Format][C][Java][Python] Simplify execute/query interface

Fixes #61.

* Update vendored nanoarrow

* [C] Split Execute
  • Loading branch information
lidavidm authored Aug 26, 2022
1 parent 26235d9 commit d8c5821
Show file tree
Hide file tree
Showing 28 changed files with 2,075 additions and 2,022 deletions.
284 changes: 152 additions & 132 deletions adbc.h

Large diffs are not rendered by default.

106 changes: 59 additions & 47 deletions c/driver_manager/adbc_driver_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include "adbc_driver_manager.h"
#include <adbc.h>

#include <algorithm>
#include <cstring>
Expand Down Expand Up @@ -69,7 +70,7 @@ AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error)

AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, const char*,
const char*, const char**, const char*,
struct AdbcStatement*, struct AdbcError* error) {
struct ArrowArrayStream*, struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

Expand All @@ -79,7 +80,15 @@ AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, con
return ADBC_STATUS_NOT_IMPLEMENTED;
}

AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection*, struct AdbcStatement*,
AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection*, struct ArrowArrayStream*,
struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

AdbcStatusCode ConnectionReadPartition(struct AdbcConnection* connection,
const uint8_t* serialized_partition,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
Expand All @@ -98,7 +107,11 @@ AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*,
return ADBC_STATUS_NOT_IMPLEMENTED;
}

AdbcStatusCode StatementExecute(struct AdbcStatement*, struct AdbcError* error) {
AdbcStatusCode StatementExecutePartitions(struct AdbcStatement* statement,
struct ArrowSchema* schema,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

Expand All @@ -108,16 +121,6 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement* statement,
return ADBC_STATUS_NOT_IMPLEMENTED;
}

AdbcStatusCode StatementGetPartitionDesc(struct AdbcStatement*, uint8_t*,
struct AdbcError*) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

AdbcStatusCode StatementGetPartitionDescSize(struct AdbcStatement*, size_t*,
struct AdbcError*) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}

AdbcStatusCode StatementPrepare(struct AdbcStatement*, struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -328,27 +331,27 @@ AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,

AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
uint32_t* info_codes, size_t info_codes_length,
struct AdbcStatement* statement,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return connection->private_driver->ConnectionGetInfo(
connection, info_codes, info_codes_length, statement, error);
return connection->private_driver->ConnectionGetInfo(connection, info_codes,
info_codes_length, out, error);
}

AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth,
const char* catalog, const char* db_schema,
const char* table_name, const char** table_types,
const char* column_name,
struct AdbcStatement* statement,
struct ArrowArrayStream* stream,
struct AdbcError* error) {
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return connection->private_driver->ConnectionGetObjects(
connection, depth, catalog, db_schema, table_name, table_types, column_name,
statement, error);
connection, depth, catalog, db_schema, table_name, table_types, column_name, stream,
error);
}

AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
Expand All @@ -364,13 +367,12 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
}

AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct AdbcStatement* statement,
struct ArrowArrayStream* stream,
struct AdbcError* error) {
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return connection->private_driver->ConnectionGetTableTypes(connection, statement,
error);
return connection->private_driver->ConnectionGetTableTypes(connection, stream, error);
}

AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
Expand Down Expand Up @@ -404,6 +406,18 @@ AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
return ADBC_STATUS_OK;
}

AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection,
const uint8_t* serialized_partition,
size_t serialized_length,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (!connection->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return connection->private_driver->ConnectionReadPartition(
connection, serialized_partition, serialized_length, out, error);
}

AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
struct AdbcError* error) {
if (!connection->private_driver) {
Expand Down Expand Up @@ -454,32 +468,38 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
return statement->private_driver->StatementBindStream(statement, stream, error);
}

AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
struct AdbcError* error) {
// XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema'
AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement,
ArrowSchema* schema,
struct AdbcPartitions* partitions,
int64_t* rows_affected,
struct AdbcError* error) {
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return statement->private_driver->StatementExecute(statement, error);
return statement->private_driver->StatementExecutePartitions(
statement, schema, partitions, rows_affected, error);
}

AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement,
uint8_t* partition_desc,
struct AdbcError* error) {
AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
int64_t* rows_affected,
struct AdbcError* error) {
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return statement->private_driver->StatementGetPartitionDesc(statement, partition_desc,
error);
return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected,
error);
}

AdbcStatusCode AdbcStatementGetPartitionDescSize(struct AdbcStatement* statement,
size_t* length,
struct AdbcError* error) {
AdbcStatusCode AdbcStatementExecuteUpdate(struct AdbcStatement* statement,
int64_t* rows_affected,
struct AdbcError* error) {
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return statement->private_driver->StatementGetPartitionDescSize(statement, length,
error);
return statement->private_driver->StatementExecuteUpdate(statement, rows_affected,
error);
}

AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
Expand All @@ -491,15 +511,6 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
return statement->private_driver->StatementGetParameterSchema(statement, schema, error);
}

AdbcStatusCode AdbcStatementGetStream(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
struct AdbcError* error) {
if (!statement->private_driver) {
return ADBC_STATUS_INVALID_STATE;
}
return statement->private_driver->StatementGetStream(statement, out, error);
}

AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
struct AdbcStatement* statement,
struct AdbcError* error) {
Expand Down Expand Up @@ -715,16 +726,17 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint,
FILL_DEFAULT(driver, ConnectionGetObjects);
FILL_DEFAULT(driver, ConnectionGetTableSchema);
FILL_DEFAULT(driver, ConnectionGetTableTypes);
FILL_DEFAULT(driver, ConnectionReadPartition);
FILL_DEFAULT(driver, ConnectionRollback);
FILL_DEFAULT(driver, ConnectionSetOption);

FILL_DEFAULT(driver, StatementExecutePartitions);
CHECK_REQUIRED(driver, StatementExecuteQuery);
CHECK_REQUIRED(driver, StatementExecuteUpdate);
CHECK_REQUIRED(driver, StatementNew);
CHECK_REQUIRED(driver, StatementRelease);
FILL_DEFAULT(driver, StatementBind);
FILL_DEFAULT(driver, StatementExecute);
FILL_DEFAULT(driver, StatementGetParameterSchema);
FILL_DEFAULT(driver, StatementGetPartitionDesc);
FILL_DEFAULT(driver, StatementGetPartitionDescSize);
FILL_DEFAULT(driver, StatementPrepare);
FILL_DEFAULT(driver, StatementSetOption);
FILL_DEFAULT(driver, StatementSetSqlQuery);
Expand Down
26 changes: 10 additions & 16 deletions c/driver_manager/adbc_driver_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,13 @@ TEST_F(DriverManager, MetadataGetInfo) {
})),
});

AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
struct ArrowArrayStream stream;
ADBC_ASSERT_OK_WITH_ERROR(
error, AdbcConnectionGetInfo(&connection, nullptr, 0, &statement, &error));
error, AdbcConnectionGetInfo(&connection, nullptr, 0, &stream, &error));

std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
ReadStatement(&statement, &schema, &batches);
ReadStream(&stream, &schema, &batches);
ASSERT_SCHEMA_EQ(*schema, *kInfoSchema);
ASSERT_EQ(1, batches.size());

Expand All @@ -141,12 +139,10 @@ TEST_F(DriverManager, MetadataGetInfo) {
ADBC_INFO_VENDOR_NAME,
ADBC_INFO_VENDOR_VERSION,
};
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(
error,
AdbcConnectionGetInfo(&connection, info.data(), info.size(), &statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionGetInfo(&connection, info.data(),
info.size(), &stream, &error));
batches.clear();
ReadStatement(&statement, &schema, &batches);
ReadStream(&stream, &schema, &batches);
ASSERT_SCHEMA_EQ(*schema, *kInfoSchema);
ASSERT_EQ(1, batches.size());
ASSERT_EQ(4, batches[0]->num_rows());
Expand All @@ -159,7 +155,6 @@ TEST_F(DriverManager, SqlExecute) {
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(error,
AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));

std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
Expand Down Expand Up @@ -193,8 +188,6 @@ TEST_F(DriverManager, SqlPrepare) {
AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementPrepare(&statement, &error));

ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));

std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
Expand Down Expand Up @@ -227,7 +220,6 @@ TEST_F(DriverManager, SqlPrepareMultipleParams) {

ADBC_ASSERT_OK_WITH_ERROR(
error, AdbcStatementBind(&statement, &export_params, &export_schema, &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));

std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
Expand Down Expand Up @@ -263,7 +255,8 @@ TEST_F(DriverManager, BulkIngestStream) {
"bulk_insert", &error));
ADBC_ASSERT_OK_WITH_ERROR(
error, AdbcStatementBindStream(&statement, &export_stream, &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(error,
AdbcStatementExecuteUpdate(&statement, nullptr, &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
}

Expand All @@ -273,7 +266,8 @@ TEST_F(DriverManager, BulkIngestStream) {
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(
error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM bulk_insert", &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(error,
AdbcStatementExecuteUpdate(&statement, nullptr, &error));

std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
Expand Down
Loading

0 comments on commit d8c5821

Please sign in to comment.