Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-17638: [Go] Extend C Data API support for Union arrays and RecordReader interface #14057

Merged
merged 3 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 167 additions & 8 deletions go/arrow/cdata/cdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,39 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) {
st := childFields[0].Type.(*arrow.StructType)
dt = arrow.MapOf(st.Field(0).Type, st.Field(1).Type)
dt.(*arrow.MapType).KeysSorted = (schema.flags & C.ARROW_FLAG_MAP_KEYS_SORTED) != 0
case 'u': // union
	// The third byte of the format string selects the union mode ('d' for
	// dense, 's' for sparse). The format comes from a foreign producer, so
	// guard the index: a truncated string must yield an error, not a panic.
	if len(f) < 3 {
		err = fmt.Errorf("%w: invalid union type", arrow.ErrInvalid)
		return
	}

	var mode arrow.UnionMode
	switch f[2] {
	case 'd':
		mode = arrow.DenseMode
	case 's':
		mode = arrow.SparseMode
	default:
		err = fmt.Errorf("%w: invalid union type", arrow.ErrInvalid)
		return
	}

	// The type codes are the comma-separated list that follows the first ':'.
	// Reject a format string that has no ':' instead of indexing past the end
	// of the split result.
	sections := strings.Split(f, ":")
	if len(sections) < 2 {
		err = fmt.Errorf("%w: missing type codes in union: format string %s", arrow.ErrInvalid, f)
		return
	}

	codes := strings.Split(sections[1], ",")
	typeCodes := make([]arrow.UnionTypeCode, 0, len(codes))
	for _, i := range codes {
		// bitSize 8 makes ParseInt reject anything that does not fit in an int8.
		v, e := strconv.ParseInt(i, 10, 8)
		if e != nil {
			err = fmt.Errorf("%w: invalid type code: %s", arrow.ErrInvalid, e)
			return
		}
		if v < 0 {
			err = fmt.Errorf("%w: negative type code in union: format string %s", arrow.ErrInvalid, f)
			return
		}
		typeCodes = append(typeCodes, arrow.UnionTypeCode(v))
	}

	// Every child field needs exactly one type code.
	if len(childFields) != len(typeCodes) {
		err = fmt.Errorf("%w: ArrowArray struct number of children incompatible with format string", arrow.ErrInvalid)
		return
	}

	dt = arrow.UnionOf(mode, childFields, typeCodes)
}
}

Expand Down Expand Up @@ -311,6 +344,18 @@ func (imp *cimporter) doImportChildren() error {
if err := imp.children[0].importChild(imp, children[0]); err != nil {
return err
}
case arrow.DENSE_UNION:
	dt := imp.dt.(*arrow.DenseUnionType)
	for i, c := range children {
		// Each child is imported with the type of the matching union field.
		imp.children[i].dt = dt.Fields()[i].Type
		// Propagate child import failures instead of silently dropping them,
		// matching the error handling of the other cases in this switch.
		if err := imp.children[i].importChild(imp, c); err != nil {
			return err
		}
	}
case arrow.SPARSE_UNION:
	dt := imp.dt.(*arrow.SparseUnionType)
	for i, c := range children {
		// Each child is imported with the type of the matching union field.
		imp.children[i].dt = dt.Fields()[i].Type
		// Propagate child import failures instead of silently dropping them,
		// matching the error handling of the other cases in this switch.
		if err := imp.children[i].importChild(imp, c); err != nil {
			return err
		}
	}
}

return nil
Expand Down Expand Up @@ -407,6 +452,52 @@ func (imp *cimporter) doImport(src *CArrowArray) error {
}

imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nulls}, children, int(imp.arr.null_count), int(imp.arr.offset))
case *arrow.DenseUnionType:
if err := imp.checkNoNulls(); err != nil {
return err
}

bufs := []*memory.Buffer{nil, nil, nil}
if imp.arr.n_buffers == 3 {
// legacy format exported by older arrow c++ versions
bufs[1] = imp.importFixedSizeBuffer(1, 1)
bufs[2] = imp.importFixedSizeBuffer(2, int64(arrow.Int32SizeBytes))
} else {
if err := imp.checkNumBuffers(2); err != nil {
return err
}

bufs[1] = imp.importFixedSizeBuffer(0, 1)
bufs[2] = imp.importFixedSizeBuffer(1, int64(arrow.Int32SizeBytes))
}

children := make([]arrow.ArrayData, len(imp.children))
for i := range imp.children {
children[i] = imp.children[i].data
}
imp.data = array.NewData(dt, int(imp.arr.length), bufs, children, 0, int(imp.arr.offset))
case *arrow.SparseUnionType:
if err := imp.checkNoNulls(); err != nil {
return err
}

var buf *memory.Buffer
if imp.arr.n_buffers == 2 {
// legacy format exported by older Arrow C++ versions
buf = imp.importFixedSizeBuffer(1, 1)
} else {
if err := imp.checkNumBuffers(1); err != nil {
return err
}

buf = imp.importFixedSizeBuffer(0, 1)
}

children := make([]arrow.ArrayData, len(imp.children))
for i := range imp.children {
children[i] = imp.children[i].data
}
imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nil, buf}, children, 0, int(imp.arr.offset))
default:
return fmt.Errorf("unimplemented type %s", dt)
}
Expand Down Expand Up @@ -494,6 +585,13 @@ func (imp *cimporter) importFixedSizePrimitive() error {

func (imp *cimporter) checkNoChildren() error { return imp.checkNumChildren(0) }

// checkNoNulls verifies that the array being imported reports a null count of
// exactly zero. Any other value produces an error wrapping arrow.ErrInvalid
// that names the imported type.
func (imp *cimporter) checkNoNulls() error {
	if imp.arr.null_count == 0 {
		return nil
	}
	return fmt.Errorf("%w: unexpected non-zero null count for imported type %s", arrow.ErrInvalid, imp.dt)
}

func (imp *cimporter) checkNumChildren(n int64) error {
if int64(imp.arr.n_children) != n {
return fmt.Errorf("expected %d children, for imported type %s, ArrowArray has %d", n, imp.dt, imp.arr.n_children)
Expand Down Expand Up @@ -558,6 +656,9 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) {
rdr.stream = C.get_stream()
C.ArrowArrayStreamMove(stream, rdr.stream)
runtime.SetFinalizer(rdr, func(r *nativeCRecordBatchReader) {
if r.cur != nil {
r.cur.Release()
}
C.ArrowArrayStreamRelease(r.stream)
C.free(unsafe.Pointer(r.stream))
})
Expand All @@ -567,40 +668,98 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) {
type nativeCRecordBatchReader struct {
stream *CArrowArrayStream
schema *arrow.Schema

cur arrow.Record
err error
}

func (n *nativeCRecordBatchReader) getError(errno int) error {
return fmt.Errorf("%w: %s", syscall.Errno(errno), C.GoString(C.stream_get_last_error(n.stream)))
// Retain and Release are deliberately no-ops. Reference counting is not needed
// here because runtime.SetFinalizer is used when constructing the reader: the
// finalizer releases the current record (if any), releases the
// ArrowArrayStream, and frees its memory when the garbage collector cleans the
// reader up.
func (n *nativeCRecordBatchReader) Retain() {}
func (n *nativeCRecordBatchReader) Release() {}

// Record returns the record currently held by the reader. It is nil until a
// record has been read successfully, and the reader releases the held record
// before fetching the next one.
func (n *nativeCRecordBatchReader) Record() arrow.Record { return n.cur }

// Next advances the reader to the next record in the stream and reports
// whether one is available through Record. It returns false both at the end
// of the stream (io.EOF from next) and on failure; only a failure other than
// io.EOF is stored in n.err.
func (n *nativeCRecordBatchReader) Next() bool {
	err := n.next()
	if err == nil {
		return true
	}
	// Reaching the end of the stream is not an error worth remembering.
	if err != io.EOF {
		n.err = err
	}
	return false
}

func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) {
func (n *nativeCRecordBatchReader) next() error {
if n.schema == nil {
var sc CArrowSchema
errno := C.stream_get_schema(n.stream, &sc)
if errno != 0 {
return nil, n.getError(int(errno))
return n.getError(int(errno))
}
defer C.ArrowSchemaRelease(&sc)
s, err := ImportCArrowSchema((*CArrowSchema)(&sc))
if err != nil {
return nil, err
return err
}

n.schema = s
}

if n.cur != nil {
n.cur.Release()
n.cur = nil
}

arr := C.get_arr()
defer C.free(unsafe.Pointer(arr))
errno := C.stream_get_next(n.stream, arr)
if errno != 0 {
return nil, n.getError(int(errno))
return n.getError(int(errno))
}

if C.ArrowArrayIsReleased(arr) == 1 {
return nil, io.EOF
return io.EOF
}

rec, err := ImportCRecordBatchWithSchema(arr, n.schema)
if err != nil {
return err
}

return ImportCRecordBatchWithSchema(arr, n.schema)
n.cur = rec
return nil
}

// Schema returns the schema of the stream. On first use it fetches the schema
// from the underlying C stream and caches it; later calls return the cached
// value. The signature has no error result, so a failure to fetch or import
// the schema panics.
func (n *nativeCRecordBatchReader) Schema() *arrow.Schema {
	if n.schema != nil {
		return n.schema
	}

	var sc CArrowSchema
	if errno := C.stream_get_schema(n.stream, &sc); errno != 0 {
		panic(n.getError(int(errno)))
	}
	// The C schema is only needed for the duration of the import.
	defer C.ArrowSchemaRelease(&sc)

	imported, err := ImportCArrowSchema((*CArrowSchema)(&sc))
	if err != nil {
		panic(err)
	}

	n.schema = imported
	return n.schema
}

// getError builds a Go error from an errno returned by a stream callback. The
// result wraps syscall.Errno(errno) and appends the message reported by the
// stream's last-error callback.
func (n *nativeCRecordBatchReader) getError(errno int) error {
	msg := C.GoString(C.stream_get_last_error(n.stream))
	return fmt.Errorf("%w: %s", syscall.Errno(errno), msg)
}

// Read fetches the next record from the stream and returns it. When the stream
// is exhausted it returns a nil record and io.EOF; any other failure is
// returned as-is with a nil record. The returned record is the reader's
// current record, which the reader releases before fetching the following one.
func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) {
	err := n.next()
	if err != nil {
		return nil, err
	}
	return n.cur, nil
}

func releaseArr(arr *CArrowArray) {
Expand Down
1 change: 0 additions & 1 deletion go/arrow/cdata/cdata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,6 @@ func TestRecordReaderStream(t *testing.T) {
}
assert.NoError(t, err)
}
defer rec.Release()

assert.EqualValues(t, 2, rec.NumCols())
assert.Equal(t, "a", rec.ColumnName(0))
Expand Down