diff --git a/go/arrow/cdata/cdata.go b/go/arrow/cdata/cdata.go index 9e1f0b2076dbc..a2b583f268ef2 100644 --- a/go/arrow/cdata/cdata.go +++ b/go/arrow/cdata/cdata.go @@ -243,6 +243,39 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) { st := childFields[0].Type.(*arrow.StructType) dt = arrow.MapOf(st.Field(0).Type, st.Field(1).Type) dt.(*arrow.MapType).KeysSorted = (schema.flags & C.ARROW_FLAG_MAP_KEYS_SORTED) != 0 + case 'u': // union + var mode arrow.UnionMode + switch f[2] { + case 'd': + mode = arrow.DenseMode + case 's': + mode = arrow.SparseMode + default: + err = fmt.Errorf("%w: invalid union type", arrow.ErrInvalid) + return + } + + codes := strings.Split(strings.Split(f, ":")[1], ",") + typeCodes := make([]arrow.UnionTypeCode, 0, len(codes)) + for _, i := range codes { + v, e := strconv.ParseInt(i, 10, 8) + if e != nil { + err = fmt.Errorf("%w: invalid type code: %s", arrow.ErrInvalid, e) + return + } + if v < 0 { + err = fmt.Errorf("%w: negative type code in union: format string %s", arrow.ErrInvalid, f) + return + } + typeCodes = append(typeCodes, arrow.UnionTypeCode(v)) + } + + if len(childFields) != len(typeCodes) { + err = fmt.Errorf("%w: ArrowArray struct number of children incompatible with format string", arrow.ErrInvalid) + return + } + + dt = arrow.UnionOf(mode, childFields, typeCodes) } } @@ -311,6 +344,18 @@ func (imp *cimporter) doImportChildren() error { if err := imp.children[0].importChild(imp, children[0]); err != nil { return err } + case arrow.DENSE_UNION: + dt := imp.dt.(*arrow.DenseUnionType) + for i, c := range children { + imp.children[i].dt = dt.Fields()[i].Type + imp.children[i].importChild(imp, c) + } + case arrow.SPARSE_UNION: + dt := imp.dt.(*arrow.SparseUnionType) + for i, c := range children { + imp.children[i].dt = dt.Fields()[i].Type + imp.children[i].importChild(imp, c) + } } return nil @@ -407,6 +452,52 @@ func (imp *cimporter) doImport(src *CArrowArray) error { } imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nulls}, children, int(imp.arr.null_count), int(imp.arr.offset)) + case *arrow.DenseUnionType: + if err := imp.checkNoNulls(); err != nil { + return err + } + + bufs := []*memory.Buffer{nil, nil, nil} + if imp.arr.n_buffers == 3 { + // legacy format exported by older arrow c++ versions + bufs[1] = imp.importFixedSizeBuffer(1, 1) + bufs[2] = imp.importFixedSizeBuffer(2, int64(arrow.Int32SizeBytes)) + } else { + if err := imp.checkNumBuffers(2); err != nil { + return err + } + + bufs[1] = imp.importFixedSizeBuffer(0, 1) + bufs[2] = imp.importFixedSizeBuffer(1, int64(arrow.Int32SizeBytes)) + } + + children := make([]arrow.ArrayData, len(imp.children)) + for i := range imp.children { + children[i] = imp.children[i].data + } + imp.data = array.NewData(dt, int(imp.arr.length), bufs, children, 0, int(imp.arr.offset)) + case *arrow.SparseUnionType: + if err := imp.checkNoNulls(); err != nil { + return err + } + + var buf *memory.Buffer + if imp.arr.n_buffers == 2 { + // legacy format exported by older Arrow C++ versions + buf = imp.importFixedSizeBuffer(1, 1) + } else { + if err := imp.checkNumBuffers(1); err != nil { + return err + } + + buf = imp.importFixedSizeBuffer(0, 1) + } + + children := make([]arrow.ArrayData, len(imp.children)) + for i := range imp.children { + children[i] = imp.children[i].data + } + imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nil, buf}, children, 0, int(imp.arr.offset)) default: return fmt.Errorf("unimplemented type %s", dt) } @@ -494,6 +585,13 @@ func (imp *cimporter) importFixedSizePrimitive() error { func (imp *cimporter) checkNoChildren() error { return imp.checkNumChildren(0) } +func (imp *cimporter) checkNoNulls() error { + if imp.arr.null_count != 0 { + return fmt.Errorf("%w: unexpected non-zero null count for imported type %s", arrow.ErrInvalid, imp.dt) + } + return nil +} + func (imp *cimporter) checkNumChildren(n int64) error { if int64(imp.arr.n_children) != n { return fmt.Errorf("expected %d children, for imported type %s, ArrowArray has %d", n, imp.dt, imp.arr.n_children) @@ -558,6 +656,9 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) { rdr.stream = C.get_stream() C.ArrowArrayStreamMove(stream, rdr.stream) runtime.SetFinalizer(rdr, func(r *nativeCRecordBatchReader) { + if r.cur != nil { + r.cur.Release() + } C.ArrowArrayStreamRelease(r.stream) C.free(unsafe.Pointer(r.stream)) }) @@ -567,40 +668,98 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) { type nativeCRecordBatchReader struct { stream *CArrowArrayStream schema *arrow.Schema + + cur arrow.Record + err error } -func (n *nativeCRecordBatchReader) getError(errno int) error { - return fmt.Errorf("%w: %s", syscall.Errno(errno), C.GoString(C.stream_get_last_error(n.stream))) +// No need to implement retain and release here as we used runtime.SetFinalizer when constructing +// the reader to free up the ArrowArrayStream memory when the garbage collector cleans it up. +func (n *nativeCRecordBatchReader) Retain() {} +func (n *nativeCRecordBatchReader) Release() {} + +func (n *nativeCRecordBatchReader) Record() arrow.Record { return n.cur } + +func (n *nativeCRecordBatchReader) Next() bool { + err := n.next() + switch { + case err == nil: + return true + case err == io.EOF: + return false + } + n.err = err + return false } -func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) { +func (n *nativeCRecordBatchReader) next() error { if n.schema == nil { var sc CArrowSchema errno := C.stream_get_schema(n.stream, &sc) if errno != 0 { - return nil, n.getError(int(errno)) + return n.getError(int(errno)) } defer C.ArrowSchemaRelease(&sc) s, err := ImportCArrowSchema((*CArrowSchema)(&sc)) if err != nil { - return nil, err + return err } n.schema = s } + if n.cur != nil { + n.cur.Release() + n.cur = nil + } + arr := C.get_arr() defer C.free(unsafe.Pointer(arr)) errno := C.stream_get_next(n.stream, arr) if errno != 0 { - return nil, n.getError(int(errno)) + return n.getError(int(errno)) } if C.ArrowArrayIsReleased(arr) == 1 { - return nil, io.EOF + return io.EOF + } + + rec, err := ImportCRecordBatchWithSchema(arr, n.schema) + if err != nil { + return err } - return ImportCRecordBatchWithSchema(arr, n.schema) + n.cur = rec + return nil +} + +func (n *nativeCRecordBatchReader) Schema() *arrow.Schema { + if n.schema == nil { + var sc CArrowSchema + errno := C.stream_get_schema(n.stream, &sc) + if errno != 0 { + panic(n.getError(int(errno))) + } + defer C.ArrowSchemaRelease(&sc) + s, err := ImportCArrowSchema((*CArrowSchema)(&sc)) + if err != nil { + panic(err) + } + + n.schema = s + } + return n.schema +} + +func (n *nativeCRecordBatchReader) getError(errno int) error { + return fmt.Errorf("%w: %s", syscall.Errno(errno), C.GoString(C.stream_get_last_error(n.stream))) +} + +func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) { + if err := n.next(); err != nil { + return nil, err + } + return n.cur, nil } func releaseArr(arr *CArrowArray) { diff --git a/go/arrow/cdata/cdata_test.go b/go/arrow/cdata/cdata_test.go index 03c01181c13ef..0b73a08d6b0c8 100644 --- a/go/arrow/cdata/cdata_test.go +++ b/go/arrow/cdata/cdata_test.go @@ -646,7 +646,6 @@ func TestRecordReaderStream(t *testing.T) { } assert.NoError(t, err) } - defer rec.Release() assert.EqualValues(t, 2, rec.NumCols()) assert.Equal(t, "a", rec.ColumnName(0))