ARROW-17638: [Go] Extend C Data API support for Union arrays and RecordReader interface (#14057)

Lead-authored-by: Matt Topol <zotthewizard@gmail.com>
Co-authored-by: Matthew Topol <zotthewizard@gmail.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
zeroshade committed Sep 7, 2022
1 parent cbf0ec0 commit ff3aa3b
Showing 2 changed files with 167 additions and 9 deletions.
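For context, a hedged end-to-end sketch of how the extended stream reader might be consumed. It assumes the package's existing ImportCArrayStream helper (not part of this diff) and the v10 module path; error handling is abbreviated.

```go
package example

import (
	"fmt"
	"io"

	"github.com/apache/arrow/go/v10/arrow/cdata"
)

// drain reads every record batch from an ArrowArrayStream produced by some
// C/C++ component. Passing a nil schema lets the reader fetch the schema
// from the stream itself, which the reader below does lazily.
func drain(stream *cdata.CArrowArrayStream) error {
	rdr := cdata.ImportCArrayStream(stream, nil) // assumed helper, returns an arrio.Reader
	for {
		rec, err := rdr.Read()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		// With this change the reader releases the previous record itself
		// (see next() below), so Retain rec if it must outlive this iteration.
		fmt.Println("got batch with", rec.NumRows(), "rows")
	}
}
```
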
175 changes: 167 additions & 8 deletions go/arrow/cdata/cdata.go
@@ -243,6 +243,39 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) {
st := childFields[0].Type.(*arrow.StructType)
dt = arrow.MapOf(st.Field(0).Type, st.Field(1).Type)
dt.(*arrow.MapType).KeysSorted = (schema.flags & C.ARROW_FLAG_MAP_KEYS_SORTED) != 0
case 'u': // union
var mode arrow.UnionMode
switch f[2] {
case 'd':
mode = arrow.DenseMode
case 's':
mode = arrow.SparseMode
default:
err = fmt.Errorf("%w: invalid union type", arrow.ErrInvalid)
return
}

codes := strings.Split(strings.Split(f, ":")[1], ",")
typeCodes := make([]arrow.UnionTypeCode, 0, len(codes))
for _, i := range codes {
v, e := strconv.ParseInt(i, 10, 8)
if e != nil {
err = fmt.Errorf("%w: invalid type code: %s", arrow.ErrInvalid, e)
return
}
if v < 0 {
err = fmt.Errorf("%w: negative type code in union: format string %s", arrow.ErrInvalid, f)
return
}
typeCodes = append(typeCodes, arrow.UnionTypeCode(v))
}

if len(childFields) != len(typeCodes) {
err = fmt.Errorf("%w: ArrowArray struct number of children incompatible with format string", arrow.ErrInvalid)
return
}

dt = arrow.UnionOf(mode, childFields, typeCodes)
}
}
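
For reference, the format strings handled here follow the Arrow C data interface: "+ud:<type ids>" for dense unions and "+us:<type ids>" for sparse unions, with comma-separated child type ids after the colon. A standalone sketch of the same parsing, for illustration only (parseUnionFormat is hypothetical and not part of this change; the real importer also resolves the child fields from the C schema):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseUnionFormat splits an Arrow C data union format string into its mode
// and type codes, mirroring the validation in importSchema above.
func parseUnionFormat(f string) (mode string, codes []int8, err error) {
	parts := strings.SplitN(f, ":", 2)
	if len(parts) != 2 || len(parts[0]) != 3 || parts[0][:2] != "+u" {
		return "", nil, fmt.Errorf("not a union format string: %q", f)
	}
	switch parts[0][2] {
	case 'd':
		mode = "dense"
	case 's':
		mode = "sparse"
	default:
		return "", nil, fmt.Errorf("invalid union mode in %q", f)
	}
	for _, c := range strings.Split(parts[1], ",") {
		v, e := strconv.ParseInt(c, 10, 8)
		if e != nil || v < 0 {
			return "", nil, fmt.Errorf("bad type code %q in %q", c, f)
		}
		codes = append(codes, int8(v))
	}
	return mode, codes, nil
}

func main() {
	fmt.Println(parseUnionFormat("+ud:5,10")) // dense [5 10] <nil>
	fmt.Println(parseUnionFormat("+us:0,1"))  // sparse [0 1] <nil>
}
```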

@@ -311,6 +344,18 @@ func (imp *cimporter) doImportChildren() error {
if err := imp.children[0].importChild(imp, children[0]); err != nil {
return err
}
case arrow.DENSE_UNION:
dt := imp.dt.(*arrow.DenseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
case arrow.SPARSE_UNION:
dt := imp.dt.(*arrow.SparseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
}

return nil
@@ -407,6 +452,52 @@ func (imp *cimporter) doImport(src *CArrowArray) error {
}

imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nulls}, children, int(imp.arr.null_count), int(imp.arr.offset))
case *arrow.DenseUnionType:
if err := imp.checkNoNulls(); err != nil {
return err
}

bufs := []*memory.Buffer{nil, nil, nil}
if imp.arr.n_buffers == 3 {
// legacy format exported by older Arrow C++ versions
bufs[1] = imp.importFixedSizeBuffer(1, 1)
bufs[2] = imp.importFixedSizeBuffer(2, int64(arrow.Int32SizeBytes))
} else {
if err := imp.checkNumBuffers(2); err != nil {
return err
}

bufs[1] = imp.importFixedSizeBuffer(0, 1)
bufs[2] = imp.importFixedSizeBuffer(1, int64(arrow.Int32SizeBytes))
}

children := make([]arrow.ArrayData, len(imp.children))
for i := range imp.children {
children[i] = imp.children[i].data
}
imp.data = array.NewData(dt, int(imp.arr.length), bufs, children, 0, int(imp.arr.offset))
case *arrow.SparseUnionType:
if err := imp.checkNoNulls(); err != nil {
return err
}

var buf *memory.Buffer
if imp.arr.n_buffers == 2 {
// legacy format exported by older Arrow C++ versions
buf = imp.importFixedSizeBuffer(1, 1)
} else {
if err := imp.checkNumBuffers(1); err != nil {
return err
}

buf = imp.importFixedSizeBuffer(0, 1)
}

children := make([]arrow.ArrayData, len(imp.children))
for i := range imp.children {
children[i] = imp.children[i].data
}
imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nil, buf}, children, 0, int(imp.arr.offset))
default:
return fmt.Errorf("unimplemented type %s", dt)
}
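
To illustrate the layout the importer builds here, a sketch that assembles the same ArrayData shape by hand for a sparse union: buffer 0 is nil (unions carry no validity bitmap) and buffer 1 holds the int8 type ids, with one equal-length child per union field. It assumes the v10 module path and that array.MakeFromData already handles union types; it is not part of this change.

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/v10/arrow"
	"github.com/apache/arrow/go/v10/arrow/array"
	"github.com/apache/arrow/go/v10/arrow/memory"
)

func main() {
	mem := memory.DefaultAllocator

	// Two children of equal length, as a sparse union requires.
	ib := array.NewInt32Builder(mem)
	ib.AppendValues([]int32{1, 0, 3}, []bool{true, false, true})
	ints := ib.NewInt32Array()
	defer ints.Release()

	sb := array.NewStringBuilder(mem)
	sb.AppendValues([]string{"", "bar", ""}, []bool{false, true, false})
	strs := sb.NewStringArray()
	defer strs.Release()

	dt := arrow.UnionOf(arrow.SparseMode, []arrow.Field{
		{Name: "i", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
		{Name: "s", Type: arrow.BinaryTypes.String, Nullable: true},
	}, []arrow.UnionTypeCode{0, 1})

	// Rows 0 and 2 select child "i" (code 0); row 1 selects child "s" (code 1).
	typeIDs := memory.NewBufferBytes([]byte{0, 1, 0})

	// Same buffer layout the importer produces: [nil, typeIDs].
	data := array.NewData(dt, 3,
		[]*memory.Buffer{nil, typeIDs},
		[]arrow.ArrayData{ints.Data(), strs.Data()},
		0, 0)
	defer data.Release()

	arr := array.MakeFromData(data)
	defer arr.Release()
	fmt.Println(arr)
}
```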
@@ -494,6 +585,13 @@ func (imp *cimporter) importFixedSizePrimitive() error {

func (imp *cimporter) checkNoChildren() error { return imp.checkNumChildren(0) }

func (imp *cimporter) checkNoNulls() error {
if imp.arr.null_count != 0 {
return fmt.Errorf("%w: unexpected non-zero null count for imported type %s", arrow.ErrInvalid, imp.dt)
}
return nil
}

func (imp *cimporter) checkNumChildren(n int64) error {
if int64(imp.arr.n_children) != n {
return fmt.Errorf("expected %d children, for imported type %s, ArrowArray has %d", n, imp.dt, imp.arr.n_children)
@@ -558,6 +656,9 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) {
rdr.stream = C.get_stream()
C.ArrowArrayStreamMove(stream, rdr.stream)
runtime.SetFinalizer(rdr, func(r *nativeCRecordBatchReader) {
if r.cur != nil {
r.cur.Release()
}
C.ArrowArrayStreamRelease(r.stream)
C.free(unsafe.Pointer(r.stream))
})
@@ -567,40 +668,98 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) {
type nativeCRecordBatchReader struct {
stream *CArrowArrayStream
schema *arrow.Schema

cur arrow.Record
err error
}

func (n *nativeCRecordBatchReader) getError(errno int) error {
return fmt.Errorf("%w: %s", syscall.Errno(errno), C.GoString(C.stream_get_last_error(n.stream)))
// No need to implement retain and release here as we used runtime.SetFinalizer when constructing
// the reader to free up the ArrowArrayStream memory when the garbage collector cleans it up.
func (n *nativeCRecordBatchReader) Retain() {}
func (n *nativeCRecordBatchReader) Release() {}

func (n *nativeCRecordBatchReader) Record() arrow.Record { return n.cur }

func (n *nativeCRecordBatchReader) Next() bool {
err := n.next()
switch {
case err == nil:
return true
case err == io.EOF:
return false
}
n.err = err
return false
}

func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) {
func (n *nativeCRecordBatchReader) next() error {
if n.schema == nil {
var sc CArrowSchema
errno := C.stream_get_schema(n.stream, &sc)
if errno != 0 {
return nil, n.getError(int(errno))
return n.getError(int(errno))
}
defer C.ArrowSchemaRelease(&sc)
s, err := ImportCArrowSchema((*CArrowSchema)(&sc))
if err != nil {
return nil, err
return err
}

n.schema = s
}

if n.cur != nil {
n.cur.Release()
n.cur = nil
}

arr := C.get_arr()
defer C.free(unsafe.Pointer(arr))
errno := C.stream_get_next(n.stream, arr)
if errno != 0 {
return nil, n.getError(int(errno))
return n.getError(int(errno))
}

if C.ArrowArrayIsReleased(arr) == 1 {
return nil, io.EOF
return io.EOF
}

rec, err := ImportCRecordBatchWithSchema(arr, n.schema)
if err != nil {
return err
}

return ImportCRecordBatchWithSchema(arr, n.schema)
n.cur = rec
return nil
}

func (n *nativeCRecordBatchReader) Schema() *arrow.Schema {
if n.schema == nil {
var sc CArrowSchema
errno := C.stream_get_schema(n.stream, &sc)
if errno != 0 {
panic(n.getError(int(errno)))
}
defer C.ArrowSchemaRelease(&sc)
s, err := ImportCArrowSchema((*CArrowSchema)(&sc))
if err != nil {
panic(err)
}

n.schema = s
}
return n.schema
}

func (n *nativeCRecordBatchReader) getError(errno int) error {
return fmt.Errorf("%w: %s", syscall.Errno(errno), C.GoString(C.stream_get_last_error(n.stream)))
}

func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) {
if err := n.next(); err != nil {
return nil, err
}
return n.cur, nil
}
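
The Next/Record pair above enables a RecordReader-style loop in addition to the Read-style loop. A small sketch of that usage; RecordIter is a local stand-in interface naming only the methods shown in this diff, and how a caller obtains such a value from the package is outside this change.

```go
package example

import (
	"fmt"

	"github.com/apache/arrow/go/v10/arrow"
)

// RecordIter is a hypothetical stand-in for a reader exposing the methods above.
type RecordIter interface {
	Schema() *arrow.Schema
	Next() bool
	Record() arrow.Record
}

// consume drains the reader. The reader releases the previous record before
// fetching the next one, so a record must be Retained to outlive its iteration.
func consume(rdr RecordIter) {
	fmt.Println("schema:", rdr.Schema())
	for rdr.Next() {
		rec := rdr.Record()
		fmt.Println("batch with", rec.NumRows(), "rows")
	}
}
```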

func releaseArr(arr *CArrowArray) {
1 change: 0 additions & 1 deletion go/arrow/cdata/cdata_test.go
@@ -646,7 +646,6 @@ func TestRecordReaderStream(t *testing.T) {
}
assert.NoError(t, err)
}
defer rec.Release()

assert.EqualValues(t, 2, rec.NumCols())
assert.Equal(t, "a", rec.ColumnName(0))