Skip to content

Commit

Permalink
feat(go/adbc/drivermgr): Implement Remaining CGO Wrapper Methods that…
Browse files Browse the repository at this point in the history
… are Supported by SQLite Driver (#1304)

# What?
Implementations for the following methods in the CGO wrapper for
`adbc_driver_manager`:
- `GetTableSchema`
- `GetTableTypes`
- `Commit`
- `Rollback`
- `GetParameterSchema`
- `BindStream`

# Why?
Functionality exists in C++ driver manager but not yet accessible via Go
driver interface.

# Notes
Three methods in the wrapper remain unimplemented: `ExecutePartitions`,
`ReadPartition`, and `SetSubstraitPlan`. These methods are not currently
supported by the SQLite driver, which is the primary test target for
these changes. It is still possible to implement them in the drivermgr
wrapper without support in specific drivers, but it does make it more
difficult to verify correct behavior. The effort to add those methods
will likely involve some additional work to ensure we are able to test
their behaviors, so they are being left out of this current round of
implementations.

Closes part of: #1291
  • Loading branch information
joellubi committed Nov 21, 2023
1 parent cac2d83 commit cc0c4e9
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 11 deletions.
92 changes: 86 additions & 6 deletions go/adbc/drivermgr/wrapper.go
Expand Up @@ -32,6 +32,10 @@ package drivermgr
// return (struct ArrowArray*)malloc(sizeof(struct ArrowArray));
// }
//
// struct ArrowArrayStream* allocArrStream() {
// return (struct ArrowArrayStream*)malloc(sizeof(struct ArrowArrayStream));
// }
//
import "C"
import (
"context"
Expand Down Expand Up @@ -186,6 +190,15 @@ func getRdr(out *C.struct_ArrowArrayStream) (array.RecordReader, error) {
return rdr.(array.RecordReader), nil
}

func getSchema(out *C.struct_ArrowSchema) (*arrow.Schema, error) {
// Maybe: ImportCArrowSchema should perform this check?
if out.format == nil {
return nil, nil
}

return cdata.ImportCArrowSchema((*cdata.CArrowSchema)(unsafe.Pointer(out)))
}

type cnxn struct {
conn *C.struct_AdbcConnection
}
Expand Down Expand Up @@ -255,19 +268,68 @@ func (c *cnxn) GetObjects(_ context.Context, depth adbc.ObjectDepth, catalog, db
}

func (c *cnxn) GetTableSchema(_ context.Context, catalog, dbSchema *string, tableName string) (*arrow.Schema, error) {
return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
var (
schema C.struct_ArrowSchema
err C.struct_AdbcError
catalog_ *C.char
dbSchema_ *C.char
tableName_ *C.char
)

if catalog != nil {
catalog_ = C.CString(*catalog)
defer C.free(unsafe.Pointer(catalog_))
}

if dbSchema != nil {
dbSchema_ = C.CString(*dbSchema)
defer C.free(unsafe.Pointer(dbSchema_))
}

tableName_ = C.CString(tableName)
defer C.free(unsafe.Pointer(tableName_))

if code := adbc.Status(C.AdbcConnectionGetTableSchema(c.conn, catalog_, dbSchema_, tableName_, &schema, &err)); code != adbc.StatusOK {
return nil, toAdbcError(code, &err)
}

return getSchema(&schema)
}

func (c *cnxn) GetTableTypes(context.Context) (array.RecordReader, error) {
return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
var (
out C.struct_ArrowArrayStream
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcConnectionGetTableTypes(c.conn, &out, &err)); code != adbc.StatusOK {
return nil, toAdbcError(code, &err)
}
return getRdr(&out)
}

func (c *cnxn) Commit(context.Context) error {
return &adbc.Error{Code: adbc.StatusNotImplemented}
var (
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcConnectionCommit(c.conn, &err)); code != adbc.StatusOK {
return toAdbcError(code, &err)
}

return nil
}

func (c *cnxn) Rollback(context.Context) error {
return &adbc.Error{Code: adbc.StatusNotImplemented}
var (
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcConnectionRollback(c.conn, &err)); code != adbc.StatusOK {
return toAdbcError(code, &err)
}

return nil
}

func (c *cnxn) NewStatement() (adbc.Statement, error) {
Expand Down Expand Up @@ -405,11 +467,29 @@ func (s *stmt) Bind(_ context.Context, values arrow.Record) error {
}

func (s *stmt) BindStream(_ context.Context, stream array.RecordReader) error {
return &adbc.Error{Code: adbc.StatusNotImplemented}
var (
arrStream = C.allocArrStream()
cdArrStream = (*cdata.CArrowArrayStream)(unsafe.Pointer(arrStream))
err C.struct_AdbcError
)
cdata.ExportRecordReader(stream, cdArrStream)
if code := adbc.Status(C.AdbcStatementBindStream(s.st, arrStream, &err)); code != adbc.StatusOK {
return toAdbcError(code, &err)
}
return nil
}

func (s *stmt) GetParameterSchema() (*arrow.Schema, error) {
return nil, &adbc.Error{Code: adbc.StatusNotImplemented}
var (
schema C.struct_ArrowSchema
err C.struct_AdbcError
)

if code := adbc.Status(C.AdbcStatementGetParameterSchema(s.st, &schema, &err)); code != adbc.StatusOK {
return nil, toAdbcError(code, &err)
}

return getSchema(&schema)
}

func (s *stmt) ExecutePartitions(context.Context) (*arrow.Schema, adbc.Partitions, int64, error) {
Expand Down
163 changes: 158 additions & 5 deletions go/adbc/drivermgr/wrapper_sqlite_test.go
Expand Up @@ -53,20 +53,25 @@ func (dm *DriverMgrSuite) SetupSuite() {
})
dm.NoError(err)

db, err := dm.db.Open(dm.ctx)
cnxn, err := dm.db.Open(dm.ctx)
dm.NoError(err)
defer db.Close()
defer cnxn.Close()

stmt, err := db.NewStatement()
stmt, err := cnxn.NewStatement()
dm.NoError(err)
defer stmt.Close()

err = stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)")
dm.NoError(err)
dm.NoError(stmt.SetSqlQuery("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)"))

nrows, err := stmt.ExecuteUpdate(dm.ctx)
dm.NoError(err)
dm.Equal(int64(0), nrows)

dm.NoError(stmt.SetSqlQuery("INSERT INTO test_table (id, name) VALUES (1, 'test')"))

nrows, err = stmt.ExecuteUpdate(dm.ctx)
dm.NoError(err)
dm.Equal(int64(1), nrows)
}

func (dm *DriverMgrSuite) SetupTest() {
Expand Down Expand Up @@ -334,6 +339,83 @@ func (dm *DriverMgrSuite) TestGetObjectsTableType() {
dm.False(rdr.Next())
}

func (dm *DriverMgrSuite) TestGetTableSchema() {
schema, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "test_table")
dm.NoError(err)

expSchema := arrow.NewSchema(
[]arrow.Field{
{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
}, nil)
dm.True(expSchema.Equal(schema))
}

func (dm *DriverMgrSuite) TestGetTableSchemaInvalidTable() {
_, err := dm.conn.GetTableSchema(dm.ctx, nil, nil, "unknown_table")
dm.Error(err)
}

func (dm *DriverMgrSuite) TestGetTableSchemaCatalog() {
catalog := "does_not_exist"
schema, err := dm.conn.GetTableSchema(dm.ctx, &catalog, nil, "test_table")
dm.NoError(err)
dm.Nil(schema)
}

func (dm *DriverMgrSuite) TestGetTableSchemaDBSchema() {
dbSchema := "does_not_exist"
schema, err := dm.conn.GetTableSchema(dm.ctx, nil, &dbSchema, "test_table")
dm.NoError(err)
dm.Nil(schema)
}

func (dm *DriverMgrSuite) TestGetTableTypes() {
rdr, err := dm.conn.GetTableTypes(dm.ctx)
dm.NoError(err)
defer rdr.Release()

expSchema := adbc.TableTypesSchema
dm.True(expSchema.Equal(rdr.Schema()))
dm.True(rdr.Next())

rec := rdr.Record()
dm.Equal(int64(2), rec.NumRows())

expTableTypes := []string{"table", "view"}
dm.Contains(expTableTypes, rec.Column(0).ValueStr(0))
dm.Contains(expTableTypes, rec.Column(0).ValueStr(1))
dm.False(rdr.Next())
}

func (dm *DriverMgrSuite) TestCommit() {
err := dm.conn.Commit(dm.ctx)
dm.Error(err)
dm.ErrorContains(err, "No active transaction, cannot commit")
}

func (dm *DriverMgrSuite) TestCommitAutocommitDisabled() {
cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
dm.True(ok)

dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled))
dm.NoError(dm.conn.Commit(dm.ctx))
}

func (dm *DriverMgrSuite) TestRollback() {
err := dm.conn.Rollback(dm.ctx)
dm.Error(err)
dm.ErrorContains(err, "No active transaction, cannot rollback")
}

func (dm *DriverMgrSuite) TestRollbackAutocommitDisabled() {
cnxnopt, ok := dm.conn.(adbc.PostInitOptions)
dm.True(ok)

dm.NoError(cnxnopt.SetOption(adbc.OptionKeyAutoCommit, adbc.OptionValueDisabled))
dm.NoError(dm.conn.Rollback(dm.ctx))
}

func (dm *DriverMgrSuite) TestSqlExecute() {
query := "SELECT 1"
st, err := dm.conn.NewStatement()
Expand Down Expand Up @@ -429,6 +511,77 @@ func (dm *DriverMgrSuite) TestSqlPrepareMultipleParams() {
dm.False(rdr.Next())
}

func (dm *DriverMgrSuite) TestGetParameterSchema() {
query := "SELECT ?1, ?2"
st, err := dm.conn.NewStatement()
dm.Require().NoError(err)
dm.Require().NoError(st.SetSqlQuery(query))
defer st.Close()

expSchema := arrow.NewSchema([]arrow.Field{
{Name: "?1", Type: arrow.Null, Nullable: true},
{Name: "?2", Type: arrow.Null, Nullable: true},
}, nil)

schema, err := st.GetParameterSchema()
dm.NoError(err)

dm.True(expSchema.Equal(schema))
}

func (dm *DriverMgrSuite) TestBindStream() {
query := "SELECT ?1, ?2"
st, err := dm.conn.NewStatement()
dm.Require().NoError(err)
dm.Require().NoError(st.SetSqlQuery(query))
defer st.Close()

schema := arrow.NewSchema([]arrow.Field{
{Name: "1", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
{Name: "2", Type: arrow.BinaryTypes.String, Nullable: true},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)
bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"one", "two", "three"}, nil)

rec1 := bldr.NewRecord()
defer rec1.Release()

bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{4, 5, 6}, nil)
bldr.Field(1).(*array.StringBuilder).AppendValues([]string{"four", "five", "six"}, nil)

rec2 := bldr.NewRecord()
defer rec2.Release()

recsIn := []arrow.Record{rec1, rec2}
rdrIn, err := array.NewRecordReader(schema, recsIn)
dm.NoError(err)

dm.NoError(st.BindStream(dm.ctx, rdrIn))

rdrOut, _, err := st.ExecuteQuery(dm.ctx)
dm.NoError(err)
defer rdrOut.Release()

recsOut := make([]arrow.Record, 0)
for rdrOut.Next() {
rec := rdrOut.Record()
rec.Retain()
defer rec.Release()
recsOut = append(recsOut, rec)
}

tableIn := array.NewTableFromRecords(schema, recsIn)
defer tableIn.Release()
tableOut := array.NewTableFromRecords(schema, recsOut)
defer tableOut.Release()

dm.Truef(array.TableEqual(tableIn, tableOut), "expected: %s\ngot: %s", tableIn, tableOut)
}

func TestDriverMgr(t *testing.T) {
suite.Run(t, new(DriverMgrSuite))
}
Expand Down

0 comments on commit cc0c4e9

Please sign in to comment.