diff --git a/arrow/array/arreflect/doc.go b/arrow/array/arreflect/doc.go new file mode 100644 index 00000000..7730810b --- /dev/null +++ b/arrow/array/arreflect/doc.go @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package arreflect provides utilities for converting between +// Apache Arrow arrays and Go structs using reflection. +// +// The primary entry points are the generic functions [At], [ToSlice], +// [FromSlice], [RecordToSlice], and [RecordFromSlice], which convert +// between Arrow arrays/records and Go slices of structs. +// +// Schema inference is available via [InferSchema] and [InferType]. +// +// Arrow struct tags control field mapping: +// +// type MyRow struct { +// Name string `arrow:"name"` +// Score float64 `arrow:"score"` +// Skip string `arrow:"-"` +// Enc string `arrow:"enc,dict"` +// T32 time.Time `arrow:"t32,time32"` +// } +// +// Temporal type overrides for time.Time fields: +// +// arrow:"field,date32" — use Date32 instead of Timestamp +// arrow:"field,date64" — use Date64 instead of Timestamp +// arrow:"field,time32" — use Time32(ms) instead of Timestamp +// arrow:"field,time64" — use Time64(ns) instead of Timestamp +// +// Additional tag options: +// +// arrow:"field,view" — use STRING_VIEW/BINARY_VIEW for string/bytes fields, or LIST_VIEW for slice fields +// arrow:"field,ree" — run-end encoding at top-level only (struct fields not supported) +// arrow:"field,decimal(precision,scale)" — override decimal precision and scale (e.g., arrow:",decimal(18,2)") +package arreflect diff --git a/arrow/array/arreflect/example_test.go b/arrow/array/arreflect/example_test.go new file mode 100644 index 00000000..8a3a6c64 --- /dev/null +++ b/arrow/array/arreflect/example_test.go @@ -0,0 +1,401 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect_test + +import ( + "fmt" + "reflect" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/array/arreflect" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +func ExampleFromSlice() { + mem := memory.NewGoAllocator() + + arr, err := arreflect.FromSlice([]int32{10, 20, 30}, mem) + if err != nil { + panic(err) + } + defer arr.Release() + + fmt.Println("Type:", arr.DataType()) + fmt.Println("Len:", arr.Len()) + for i := 0; i < arr.Len(); i++ { + fmt.Println(arr.(*array.Int32).Value(i)) + } + // Output: + // Type: int32 + // Len: 3 + // 10 + // 20 + // 30 +} + +func ExampleFromSlice_structSlice() { + mem := memory.NewGoAllocator() + + type Row struct { + Name string `arrow:"name"` + Score float64 `arrow:"score"` + } + + arr, err := arreflect.FromSlice([]Row{ + {"alice", 9.5}, + {"bob", 7.0}, + }, mem) + if err != nil { + panic(err) + } + defer arr.Release() + + sa := arr.(*array.Struct) + fmt.Println("Type:", sa.DataType()) + fmt.Println("Names:", sa.Field(0)) + fmt.Println("Scores:", sa.Field(1)) + // Output: + // Type: struct + // Names: ["alice" "bob"] + // Scores: [9.5 7] +} + +func ExampleFromSlice_withDecimal() { + mem := memory.NewGoAllocator() + + vals := []decimal128.Num{ + decimal128.FromI64(12345), + decimal128.FromI64(-67890), + } + arr, err := arreflect.FromSlice(vals, mem, arreflect.WithDecimal(10, 2)) + if err != nil { + panic(err) + } + defer arr.Release() + + fmt.Println("Type:", arr.DataType()) + fmt.Println("Len:", arr.Len()) + // Output: + // Type: decimal(10, 2) + // Len: 2 +} + +func ExampleToSlice() { + mem := memory.NewGoAllocator() + + b := array.NewFloat64Builder(mem) + defer b.Release() + b.Append(1.1) + b.Append(2.2) + b.Append(3.3) + arr := b.NewArray() + defer arr.Release() + + vals, err := arreflect.ToSlice[float64](arr) + if err != nil { + panic(err) + } + fmt.Println(vals) + // Output: + // [1.1 2.2 3.3] +} + +type Measurement struct { + Sensor string `arrow:"sensor"` + Value float64 `arrow:"value"` +} + +func ExampleRecordFromSlice() { + mem := memory.NewGoAllocator() + + rows := []Measurement{ + {"temp-1", 23.5}, + {"temp-2", 19.8}, + } + rec, err := arreflect.RecordFromSlice(rows, mem) + if err != nil { + panic(err) + } + defer rec.Release() + + fmt.Println("Schema:", rec.Schema()) + fmt.Println("Rows:", rec.NumRows()) + fmt.Println("Col 0:", rec.Column(0)) + fmt.Println("Col 1:", rec.Column(1)) + // Output: + // Schema: schema: + // fields: 2 + // - sensor: type=utf8 + // - value: type=float64 + // Rows: 2 + // Col 0: ["temp-1" "temp-2"] + // Col 1: [23.5 19.8] +} + +func ExampleRecordToSlice() { + mem := memory.NewGoAllocator() + + rows := []Measurement{ + {"temp-1", 23.5}, + {"temp-2", 19.8}, + } + rec, err := arreflect.RecordFromSlice(rows, mem) + if err != nil { + panic(err) + } + defer rec.Release() + + got, err := arreflect.RecordToSlice[Measurement](rec) + if err != nil { + panic(err) + } + for _, m := range got { + fmt.Printf("%s: %.1f\n", m.Sensor, m.Value) + } + // Output: + // temp-1: 23.5 + // temp-2: 19.8 +} + +func ExampleAt() { + mem := memory.NewGoAllocator() + + b := array.NewStringBuilder(mem) + defer b.Release() + b.Append("alpha") + b.Append("beta") + b.Append("gamma") + arr := b.NewArray() + defer arr.Release() + + val, err := arreflect.At[string](arr, 1) + if err != nil { + panic(err) + } + fmt.Println(val) + // Output: + // beta +} + +func ExampleInferSchema() { + type Event struct { + ID int64 `arrow:"id"` + Name string `arrow:"name"` + Score float64 `arrow:"score"` + Comment *string `arrow:"comment"` + } + + schema, err := arreflect.InferSchema[Event]() + if err != nil { + panic(err) + } + fmt.Println(schema) + // Output: + // schema: + // fields: 4 + // - id: type=int64 + // - name: type=utf8 + // - score: type=float64 + // - comment: type=utf8, nullable +} + +func ExampleFromSlice_withDict() { + mem := memory.NewGoAllocator() + + arr, err := arreflect.FromSlice( + []string{"red", "green", "red", "blue", "green"}, + mem, + arreflect.WithDict(), + ) + if err != nil { + panic(err) + } + defer arr.Release() + + fmt.Println("Type:", arr.DataType()) + dict := arr.(*array.Dictionary) + fmt.Println("Indices:", dict.Indices()) + fmt.Println("Dictionary:", dict.Dictionary()) + // Output: + // Type: dictionary + // Indices: [0 1 0 2 1] + // Dictionary: ["red" "green" "blue"] +} + +func ExampleToAnySlice() { + mem := memory.NewGoAllocator() + + st := arrow.StructOf( + arrow.Field{Name: "city", Type: arrow.BinaryTypes.String}, + arrow.Field{Name: "pop", Type: arrow.PrimitiveTypes.Int64}, + ) + sb := array.NewStructBuilder(mem, st) + defer sb.Release() + + sb.Append(true) + sb.FieldBuilder(0).(*array.StringBuilder).Append("Tokyo") + sb.FieldBuilder(1).(*array.Int64Builder).Append(14000000) + + sb.Append(true) + sb.FieldBuilder(0).(*array.StringBuilder).Append("Paris") + sb.FieldBuilder(1).(*array.Int64Builder).Append(2200000) + + arr := sb.NewArray() + defer arr.Release() + + rows, err := arreflect.ToAnySlice(arr) + if err != nil { + panic(err) + } + for _, row := range rows { + fmt.Println(row) + } + // Output: + // {Tokyo 14000000} + // {Paris 2200000} +} + +func ExampleToAnySlice_nullableFields() { + mem := memory.NewGoAllocator() + + st := arrow.StructOf( + arrow.Field{Name: "name", Type: arrow.BinaryTypes.String, Nullable: false}, + arrow.Field{Name: "score", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + ) + sb := array.NewStructBuilder(mem, st) + defer sb.Release() + + sb.Append(true) + sb.FieldBuilder(0).(*array.StringBuilder).Append("alice") + sb.FieldBuilder(1).(*array.Float64Builder).Append(9.5) + + sb.Append(true) + sb.FieldBuilder(0).(*array.StringBuilder).Append("bob") + sb.FieldBuilder(1).(*array.Float64Builder).AppendNull() + + arr := sb.NewArray() + defer arr.Release() + + rows, err := arreflect.ToAnySlice(arr) + if err != nil { + panic(err) + } + for _, row := range rows { + v := reflect.ValueOf(row) + var name string + var scoreField reflect.Value + for i := 0; i < v.NumField(); i++ { + switch v.Type().Field(i).Tag.Get("arrow") { + case "name": + name = v.Field(i).String() + case "score": + scoreField = v.Field(i) + } + } + if scoreField.IsNil() { + fmt.Printf("%s: \n", name) + } else { + fmt.Printf("%s: %.1f\n", name, scoreField.Elem().Float()) + } + } + // Output: + // alice: 9.5 + // bob: +} + +func ExampleWithLarge() { + mem := memory.NewGoAllocator() + + arr, err := arreflect.FromSlice([]string{"hello", "world"}, mem, arreflect.WithLarge()) + if err != nil { + panic(err) + } + defer arr.Release() + + fmt.Println("Type:", arr.DataType()) + fmt.Println("Len:", arr.Len()) + // Output: + // Type: large_utf8 + // Len: 2 +} + +func ExampleFromSlice_largeStruct() { + type Event struct { + Name string `arrow:"name,large"` + Code int32 `arrow:"code"` + } + + schema, err := arreflect.InferSchema[Event]() + if err != nil { + panic(err) + } + fmt.Println("Schema:", schema) + + mem := memory.NewGoAllocator() + arr, err := arreflect.FromSlice([]Event{{"click", 1}, {"view", 2}}, mem) + if err != nil { + panic(err) + } + defer arr.Release() + + sa := arr.(*array.Struct) + fmt.Println("Name type:", sa.Field(0).DataType()) + fmt.Println("Code type:", sa.Field(1).DataType()) + // Output: + // Schema: schema: + // fields: 2 + // - name: type=large_utf8 + // - code: type=int32 + // Name type: large_utf8 + // Code type: int32 +} + +func ExampleWithView() { + mem := memory.NewGoAllocator() + + arr, err := arreflect.FromSlice([]string{"hello", "world"}, mem, arreflect.WithView()) + if err != nil { + panic(err) + } + defer arr.Release() + + fmt.Println("Type:", arr.DataType()) + fmt.Println("Len:", arr.Len()) + // Output: + // Type: string_view + // Len: 2 +} + +func ExampleFromSlice_viewStruct() { + type Event struct { + Name string `arrow:"name,view"` + Code int32 `arrow:"code"` + } + + schema, err := arreflect.InferSchema[Event]() + if err != nil { + panic(err) + } + fmt.Println("Schema:", schema) + // Output: + // Schema: schema: + // fields: 2 + // - name: type=string_view + // - code: type=int32 +} diff --git a/arrow/array/arreflect/reflect.go b/arrow/array/arreflect/reflect.go new file mode 100644 index 00000000..248350b5 --- /dev/null +++ b/arrow/array/arreflect/reflect.go @@ -0,0 +1,605 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +var ( + ErrUnsupportedType = errors.New("arreflect: unsupported type") + ErrTypeMismatch = errors.New("arreflect: type mismatch") +) + +type tagOpts struct { + Name string + Skip bool + Dict bool + View bool + REE bool + Large bool + DecimalPrecision int32 + DecimalScale int32 + HasDecimalOpts bool + Temporal string // "timestamp" (default), "date32", "date64", "time32", "time64" + ParseErr string // diagnostic set when decimal(p,s) tag fails to parse; surfaced by validateOptions +} + +type fieldMeta struct { + Name string + Index []int + Type reflect.Type + Nullable bool + Opts tagOpts +} + +func parseTag(tag string) tagOpts { + if tag == "-" { + return tagOpts{Skip: true} + } + + name, rest, _ := strings.Cut(tag, ",") + opts := tagOpts{Name: name} + + if rest == "" { + return opts + } + + parseOptions(&opts, rest) + return opts +} + +func splitTagTokens(rest string) []string { + var tokens []string + depth := 0 + start := 0 + for i := 0; i < len(rest); i++ { + switch rest[i] { + case '(': + depth++ + case ')': + depth-- + case ',': + if depth == 0 { + tokens = append(tokens, strings.TrimSpace(rest[start:i])) + start = i + 1 + } + } + } + if start < len(rest) { + tokens = append(tokens, strings.TrimSpace(rest[start:])) + } + return tokens +} + +func parseOptions(opts *tagOpts, rest string) { + for _, token := range splitTagTokens(rest) { + if strings.HasPrefix(token, "decimal(") && strings.HasSuffix(token, ")") { + parseDecimalOpt(opts, token) + continue + } + switch token { + case "dict": + opts.Dict = true + case "view": + opts.View = true + case "ree": + opts.REE = true + case "large": + opts.Large = true + case "date32", "date64", "time32", "time64", "timestamp": + opts.Temporal = token + default: + opts.ParseErr = fmt.Sprintf("unknown option %q", token) + } + } +} + +func parseDecimalOpt(opts *tagOpts, token string) { + inner := strings.TrimPrefix(token, "decimal(") + inner = strings.TrimSuffix(inner, ")") + parts := strings.SplitN(inner, ",", 2) + if len(parts) != 2 { + opts.ParseErr = fmt.Sprintf("invalid decimal tag %q: expected decimal(precision,scale)", token) + return + } + p, errP := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 32) + if errP != nil { + opts.ParseErr = fmt.Sprintf("invalid decimal tag %q: precision %q is not an integer", token, strings.TrimSpace(parts[0])) + return + } + s, errS := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 32) + if errS != nil { + opts.ParseErr = fmt.Sprintf("invalid decimal tag %q: scale %q is not an integer", token, strings.TrimSpace(parts[1])) + return + } + opts.HasDecimalOpts = true + opts.DecimalPrecision = int32(p) + opts.DecimalScale = int32(s) +} + +type bfsEntry struct { + t reflect.Type + index []int + depth int +} + +type candidate struct { + meta fieldMeta + depth int + tagged bool + order int +} + +type resolvedField struct { + meta fieldMeta + order int +} + +func collectFieldCandidates(t reflect.Type) map[string][]candidate { + nameMap := make(map[string][]candidate) + orderCounter := 0 + + queue := []bfsEntry{{t: t, index: nil, depth: 0}} + visited := make(map[reflect.Type]bool) + + for len(queue) > 0 { + entry := queue[0] + queue = queue[1:] + + st := entry.t + for st.Kind() == reflect.Ptr { + st = st.Elem() + } + if st.Kind() != reflect.Struct { + continue + } + + if visited[st] { + continue + } + if entry.depth > 0 { + visited[st] = true + } + + for i := 0; i < st.NumField(); i++ { + sf := st.Field(i) + + fullIndex := make([]int, len(entry.index)+1) + copy(fullIndex, entry.index) + fullIndex[len(entry.index)] = i + + if !sf.IsExported() && !sf.Anonymous { + continue + } + + tagVal, hasTag := sf.Tag.Lookup("arrow") + var opts tagOpts + if hasTag { + opts = parseTag(tagVal) + } + + if opts.Skip { + continue + } + + arrowName := opts.Name + if arrowName == "" { + arrowName = sf.Name + } + + if sf.Anonymous && !hasTag { + ft := sf.Type + for ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if ft.Kind() == reflect.Struct { + queue = append(queue, bfsEntry{ + t: ft, + index: fullIndex, + depth: entry.depth + 1, + }) + continue + } + } + + nullable := sf.Type.Kind() == reflect.Ptr + tagged := hasTag && opts.Name != "" + + meta := fieldMeta{ + Name: arrowName, + Index: fullIndex, + Type: sf.Type, + Nullable: nullable, + Opts: opts, + } + + existingCands := nameMap[arrowName] + order := orderCounter + if len(existingCands) > 0 { + order = existingCands[0].order + } else { + orderCounter++ + } + + nameMap[arrowName] = append(existingCands, candidate{ + meta: meta, + depth: entry.depth, + tagged: tagged, + order: order, + }) + } + } + + return nameMap +} + +func resolveFieldCandidates(nameMap map[string][]candidate) []fieldMeta { + resolved := make([]resolvedField, 0, len(nameMap)) + for _, candidates := range nameMap { + minDepth := candidates[0].depth + for _, c := range candidates[1:] { + if c.depth < minDepth { + minDepth = c.depth + } + } + + var atMin []candidate + for _, c := range candidates { + if c.depth == minDepth { + atMin = append(atMin, c) + } + } + + var winner *candidate + if len(atMin) == 1 { + winner = &atMin[0] + } else { + var tagged []candidate + for _, c := range atMin { + if c.tagged { + tagged = append(tagged, c) + } + } + if len(tagged) == 1 { + winner = &tagged[0] + } + } + + if winner != nil { + resolved = append(resolved, resolvedField{meta: winner.meta, order: winner.order}) + } + } + + sort.Slice(resolved, func(i, j int) bool { + return resolved[i].order < resolved[j].order + }) + + result := make([]fieldMeta, len(resolved)) + for i, r := range resolved { + result[i] = r.meta + } + return result +} + +func getStructFields(t reflect.Type) []fieldMeta { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil + } + + return resolveFieldCandidates(collectFieldCandidates(t)) +} + +var structFieldCache sync.Map + +func cachedStructFields(t reflect.Type) []fieldMeta { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if v, ok := structFieldCache.Load(t); ok { + return v.([]fieldMeta) + } + + fields := getStructFields(t) + v, _ := structFieldCache.LoadOrStore(t, fields) + return v.([]fieldMeta) +} + +func fieldByIndexSafe(v reflect.Value, index []int) (reflect.Value, bool) { + for _, idx := range index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{}, false + } + v = v.Elem() + } + v = v.Field(idx) + } + return v, true +} + +func At[T any](arr arrow.Array, i int) (T, error) { + var result T + v := reflect.ValueOf(&result).Elem() + if err := setValue(v, arr, i); err != nil { + var zero T + return zero, err + } + return result, nil +} + +func ToSlice[T any](arr arrow.Array) ([]T, error) { + n := arr.Len() + result := make([]T, n) + for i := 0; i < n; i++ { + v := reflect.ValueOf(&result[i]).Elem() + if err := setValue(v, arr, i); err != nil { + return nil, fmt.Errorf("index %d: %w", i, err) + } + } + return result, nil +} + +// Option configures encoding behavior for [FromSlice] and [RecordFromSlice]. +type Option func(*tagOpts) + +// WithDict requests dictionary encoding for the top-level array. +func WithDict() Option { return func(o *tagOpts) { o.Dict = true } } + +// WithView requests view-type encoding (STRING_VIEW, BINARY_VIEW, LIST_VIEW) +// for the top-level array and recursively for nested types. +func WithView() Option { return func(o *tagOpts) { o.View = true } } + +// WithREE requests run-end encoding for the top-level array. +func WithREE() Option { return func(o *tagOpts) { o.REE = true } } + +// WithDecimal sets the precision and scale for decimal types. +func WithDecimal(precision, scale int32) Option { + return func(o *tagOpts) { + o.DecimalPrecision = precision + o.DecimalScale = scale + o.HasDecimalOpts = true + } +} + +// WithTemporal overrides the Arrow temporal encoding for time.Time slices. +// Valid values: "date32", "date64", "time32", "time64", "timestamp" (default). +// Equivalent to tagging a struct field with arrow:",date32" etc. +// Invalid values cause FromSlice to return an error. +func WithTemporal(temporal string) Option { + return func(o *tagOpts) { o.Temporal = temporal } +} + +// WithLarge requests Large type variants (LARGE_STRING, LARGE_BINARY, LARGE_LIST, +// LARGE_LIST_VIEW) for the top-level array and recursively for nested types. +func WithLarge() Option { return func(o *tagOpts) { o.Large = true } } + +func validateTemporalOpt(temporal string) error { + switch temporal { + case "", "timestamp", "date32", "date64", "time32", "time64": + return nil + default: + return fmt.Errorf("arreflect: invalid WithTemporal value %q; valid values are date32, date64, time32, time64, timestamp: %w", temporal, ErrUnsupportedType) + } +} + +func validateOptions(opts tagOpts) error { + if opts.ParseErr != "" { + return fmt.Errorf("arreflect: %s: %w", opts.ParseErr, ErrUnsupportedType) + } + n := 0 + if opts.Dict { + n++ + } + if opts.REE { + n++ + } + if opts.View { + n++ + } + if n > 1 { + return fmt.Errorf("arreflect: conflicting options: only one of WithDict, WithREE, WithView may be specified: %w", ErrUnsupportedType) + } + return nil +} + +func buildEmptyTyped(goType reflect.Type, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + dt, err := inferArrowType(goType) + if err != nil { + return nil, err + } + derefType := goType + for derefType.Kind() == reflect.Ptr { + derefType = derefType.Elem() + } + dt = applyDecimalOpts(dt, derefType, opts) + dt = applyTemporalOpts(dt, derefType, opts) + if opts.Large { + if !hasLargeableType(dt) { + return nil, fmt.Errorf("arreflect: large option has no effect on type %s: %w", dt, ErrUnsupportedType) + } + dt = applyLargeOpts(dt) + } + if opts.View { + if derefType.Kind() == reflect.Slice && derefType != typeOfByteSlice { + // slice-of-slices: build a LIST_VIEW or LARGE_LIST_VIEW + innerElem := derefType.Elem() + for innerElem.Kind() == reflect.Ptr { + innerElem = innerElem.Elem() + } + innerDT, err := inferArrowType(innerElem) + if err != nil { + return nil, err + } + innerDT = applyViewOpts(innerDT) + if opts.Large { + dt = arrow.LargeListViewOf(innerDT) + } else { + dt = arrow.ListViewOf(innerDT) + } + } else { + // primitive/string/binary: apply view recursively + if !hasViewableType(dt) { + return nil, fmt.Errorf("arreflect: view option has no effect on type %s: %w", dt, ErrUnsupportedType) + } + dt = applyViewOpts(dt) + } + } + if opts.Dict { + if err := validateDictValueType(dt); err != nil { + return nil, err + } + dt = &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: dt} + } else if opts.REE { + dt = arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, dt) + } + b := array.NewBuilder(mem, dt) + defer b.Release() + return b.NewArray(), nil +} + +func FromSlice[T any](vals []T, mem memory.Allocator, opts ...Option) (arrow.Array, error) { + if mem == nil { + mem = memory.DefaultAllocator + } + var tOpts tagOpts + for _, o := range opts { + o(&tOpts) + } + if err := validateOptions(tOpts); err != nil { + return nil, err + } + if err := validateTemporalOpt(tOpts.Temporal); err != nil { + return nil, err + } + if tOpts.Temporal != "" { + goType := reflect.TypeFor[T]() + deref := goType + for deref.Kind() == reflect.Ptr { + deref = deref.Elem() + } + if deref != typeOfTime { + return nil, fmt.Errorf("arreflect: WithTemporal requires a time.Time element type, got %s: %w", deref, ErrUnsupportedType) + } + } + if len(vals) == 0 { + return buildEmptyTyped(reflect.TypeFor[T](), tOpts, mem) + } + sv := reflect.ValueOf(vals) + return buildArray(sv, tOpts, mem) +} + +func RecordToSlice[T any](rec arrow.RecordBatch) ([]T, error) { + sa := array.RecordToStructArray(rec) + defer sa.Release() + return ToSlice[T](sa) +} + +func RecordFromSlice[T any](vals []T, mem memory.Allocator, opts ...Option) (arrow.RecordBatch, error) { + arr, err := FromSlice[T](vals, mem, opts...) + if err != nil { + return nil, err + } + defer arr.Release() + sa, ok := arr.(*array.Struct) + if !ok { + return nil, fmt.Errorf("arreflect: RecordFromSlice requires a struct type T, got %v", reflect.TypeFor[T]()) + } + return array.RecordFromStructArray(sa, nil), nil +} + +// RecordAt converts the row at index i of an Arrow Record to a Go value of type T. +// T must be a struct type whose fields correspond to the record's columns. +func RecordAt[T any](rec arrow.RecordBatch, i int) (T, error) { + sa := array.RecordToStructArray(rec) + defer sa.Release() + return At[T](sa, i) +} + +// RecordAtAny converts the row at index i of an Arrow Record to a Go value, +// inferring the Go type from the record's schema at runtime via [InferGoType]. +// Equivalent to AtAny on the struct array underlying the record. +func RecordAtAny(rec arrow.RecordBatch, i int) (any, error) { + sa := array.RecordToStructArray(rec) + defer sa.Release() + return AtAny(sa, i) +} + +// RecordToAnySlice converts all rows of an Arrow Record to Go values, +// inferring the Go type at runtime via [InferGoType]. +// Equivalent to ToAnySlice on the struct array underlying the record. +func RecordToAnySlice(rec arrow.RecordBatch) ([]any, error) { + sa := array.RecordToStructArray(rec) + defer sa.Release() + return ToAnySlice(sa) +} + +// AtAny converts a single element at index i of an Arrow array to a Go value, +// inferring the Go type from the Arrow DataType at runtime via [InferGoType]. +// Useful when the column type is not known at compile time. +// Null elements are returned as the Go zero value of the inferred type; use +// arr.IsNull(i) to distinguish a null element from a genuine zero. +// For typed access when T is known, prefer [At]. +func AtAny(arr arrow.Array, i int) (any, error) { + goType, err := InferGoType(arr.DataType()) + if err != nil { + return nil, err + } + result := reflect.New(goType).Elem() + if err := setValue(result, arr, i); err != nil { + return nil, err + } + return result.Interface(), nil +} + +// ToAnySlice converts all elements of an Arrow array to Go values, +// inferring the Go type from the Arrow DataType at runtime via [InferGoType]. +// All elements share the same inferred Go type. Null elements are returned as +// the Go zero value of the inferred type; use arr.IsNull(i) to distinguish +// a null element from a genuine zero value. +// For typed access when T is known, prefer [ToSlice]. +func ToAnySlice(arr arrow.Array) ([]any, error) { + goType, err := InferGoType(arr.DataType()) + if err != nil { + return nil, err + } + n := arr.Len() + result := make([]any, n) + for i := 0; i < n; i++ { + v := reflect.New(goType).Elem() + if err := setValue(v, arr, i); err != nil { + return nil, fmt.Errorf("index %d: %w", i, err) + } + result[i] = v.Interface() + } + return result, nil +} diff --git a/arrow/array/arreflect/reflect_arrow_to_go.go b/arrow/array/arreflect/reflect_arrow_to_go.go new file mode 100644 index 00000000..c5b1288b --- /dev/null +++ b/arrow/array/arreflect/reflect_arrow_to_go.go @@ -0,0 +1,441 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "fmt" + "reflect" + "strings" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" +) + +func assertArray[T any](arr arrow.Array) (*T, error) { + a, ok := any(arr).(*T) + if !ok { + var zero T + return nil, fmt.Errorf("expected *%T, got %T: %w", zero, arr, ErrTypeMismatch) + } + return a, nil +} + +func isIntKind(k reflect.Kind) bool { + return k == reflect.Int || k == reflect.Int8 || k == reflect.Int16 || + k == reflect.Int32 || k == reflect.Int64 +} + +func isUintKind(k reflect.Kind) bool { + return k == reflect.Uint || k == reflect.Uint8 || k == reflect.Uint16 || + k == reflect.Uint32 || k == reflect.Uint64 || k == reflect.Uintptr +} + +func isFloatKind(k reflect.Kind) bool { return k == reflect.Float32 || k == reflect.Float64 } + +func setValue(v reflect.Value, arr arrow.Array, i int) error { + if arr.IsNull(i) { + v.Set(reflect.Zero(v.Type())) + return nil + } + for v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + switch arr.DataType().ID() { + case arrow.BOOL: + a, err := assertArray[array.Boolean](arr) + if err != nil { + return err + } + if v.Kind() != reflect.Bool { + return fmt.Errorf("cannot set bool into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetBool(a.Value(i)) + + case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64, + arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, + arrow.FLOAT32, arrow.FLOAT64: + return setPrimitiveValue(v, arr, i) + + case arrow.STRING, arrow.LARGE_STRING, arrow.STRING_VIEW: + type stringer interface{ Value(int) string } + a, ok := arr.(stringer) + if !ok { + return fmt.Errorf("expected string array, got %T: %w", arr, ErrTypeMismatch) + } + if v.Kind() != reflect.String { + return fmt.Errorf("cannot set string into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetString(strings.Clone(a.Value(i))) + + case arrow.BINARY, arrow.LARGE_BINARY, arrow.BINARY_VIEW: + type byter interface{ Value(int) []byte } + a, ok := arr.(byter) + if !ok { + return fmt.Errorf("expected binary array, got %T: %w", arr, ErrTypeMismatch) + } + if v.Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("cannot set []byte into %s: %w", v.Type(), ErrTypeMismatch) + } + src := a.Value(i) + dst := make([]byte, len(src)) + copy(dst, src) + v.SetBytes(dst) + + case arrow.TIMESTAMP, arrow.DATE32, arrow.DATE64, + arrow.TIME32, arrow.TIME64, arrow.DURATION: + return setTemporalValue(v, arr, i) + + case arrow.DECIMAL128, arrow.DECIMAL256, arrow.DECIMAL32, arrow.DECIMAL64: + return setDecimalValue(v, arr, i) + + case arrow.STRUCT: + a, err := assertArray[array.Struct](arr) + if err != nil { + return err + } + return setStructValue(v, a, i) + + case arrow.LIST, arrow.LARGE_LIST, arrow.LIST_VIEW, arrow.LARGE_LIST_VIEW: + a, ok := arr.(array.ListLike) + if !ok { + return fmt.Errorf("expected ListLike, got %T: %w", arr, ErrTypeMismatch) + } + return setListValue(v, a, i) + + case arrow.MAP: + a, err := assertArray[array.Map](arr) + if err != nil { + return err + } + return setMapValue(v, a, i) + + case arrow.FIXED_SIZE_LIST: + a, err := assertArray[array.FixedSizeList](arr) + if err != nil { + return err + } + return setFixedSizeListValue(v, a, i) + + case arrow.DICTIONARY: + a, err := assertArray[array.Dictionary](arr) + if err != nil { + return err + } + return setDictionaryValue(v, a, i) + + case arrow.RUN_END_ENCODED: + a, err := assertArray[array.RunEndEncoded](arr) + if err != nil { + return err + } + return setRunEndEncodedValue(v, a, i) + + default: + return fmt.Errorf("unsupported Arrow type %v for reflection: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setPrimitiveValue(v reflect.Value, arr arrow.Array, i int) error { + switch arr.DataType().ID() { + case arrow.INT8: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int8 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetInt(int64(arr.(*array.Int8).Value(i))) + case arrow.INT16: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int16 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetInt(int64(arr.(*array.Int16).Value(i))) + case arrow.INT32: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetInt(int64(arr.(*array.Int32).Value(i))) + case arrow.INT64: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetInt(arr.(*array.Int64).Value(i)) + case arrow.UINT8: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint8 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(uint64(arr.(*array.Uint8).Value(i))) + case arrow.UINT16: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint16 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(uint64(arr.(*array.Uint16).Value(i))) + case arrow.UINT32: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(uint64(arr.(*array.Uint32).Value(i))) + case arrow.UINT64: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(arr.(*array.Uint64).Value(i)) + case arrow.FLOAT32: + if !isFloatKind(v.Kind()) { + return fmt.Errorf("cannot set float32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetFloat(float64(arr.(*array.Float32).Value(i))) + case arrow.FLOAT64: + if !isFloatKind(v.Kind()) { + return fmt.Errorf("cannot set float64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetFloat(arr.(*array.Float64).Value(i)) + default: + return fmt.Errorf("unsupported primitive type %v: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setTime(v reflect.Value, t time.Time) error { + if v.Type() != typeOfTime { + return fmt.Errorf("cannot set time.Time into %s: %w", v.Type(), ErrTypeMismatch) + } + v.Set(reflect.ValueOf(t)) + return nil +} + +func setTemporalValue(v reflect.Value, arr arrow.Array, i int) error { + switch arr.DataType().ID() { + case arrow.TIMESTAMP: + a, err := assertArray[array.Timestamp](arr) + if err != nil { + return err + } + unit := arr.DataType().(*arrow.TimestampType).Unit + return setTime(v, a.Value(i).ToTime(unit)) + + case arrow.DATE32: + a, err := assertArray[array.Date32](arr) + if err != nil { + return err + } + return setTime(v, a.Value(i).ToTime()) + + case arrow.DATE64: + a, err := assertArray[array.Date64](arr) + if err != nil { + return err + } + return setTime(v, a.Value(i).ToTime()) + + case arrow.TIME32: + a, err := assertArray[array.Time32](arr) + if err != nil { + return err + } + unit := arr.DataType().(*arrow.Time32Type).Unit + return setTime(v, a.Value(i).ToTime(unit)) + + case arrow.TIME64: + a, err := assertArray[array.Time64](arr) + if err != nil { + return err + } + unit := arr.DataType().(*arrow.Time64Type).Unit + return setTime(v, a.Value(i).ToTime(unit)) + + case arrow.DURATION: + a, err := assertArray[array.Duration](arr) + if err != nil { + return err + } + if v.Type() != typeOfDuration { + return fmt.Errorf("cannot set time.Duration into %s: %w", v.Type(), ErrTypeMismatch) + } + unit := arr.DataType().(*arrow.DurationType).Unit + dur := time.Duration(a.Value(i)) * unit.Multiplier() + v.Set(reflect.ValueOf(dur)) + + default: + return fmt.Errorf("unsupported temporal type %v: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setDecimalValue(v reflect.Value, arr arrow.Array, i int) error { + switch arr.DataType().ID() { + case arrow.DECIMAL128: + a, err := assertArray[array.Decimal128](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec128 { + return fmt.Errorf("cannot set decimal128.Num into %s: %w", v.Type(), ErrTypeMismatch) + } + num := a.Value(i) + v.Set(reflect.ValueOf(num)) + + case arrow.DECIMAL256: + a, err := assertArray[array.Decimal256](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec256 { + return fmt.Errorf("cannot set decimal256.Num into %s: %w", v.Type(), ErrTypeMismatch) + } + num := a.Value(i) + v.Set(reflect.ValueOf(num)) + + case arrow.DECIMAL32: + a, err := assertArray[array.Decimal32](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec32 { + return fmt.Errorf("cannot set decimal.Decimal32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.Set(reflect.ValueOf(a.Value(i))) + + case arrow.DECIMAL64: + a, err := assertArray[array.Decimal64](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec64 { + return fmt.Errorf("cannot set decimal.Decimal64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.Set(reflect.ValueOf(a.Value(i))) + + default: + return fmt.Errorf("unsupported decimal type %v: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setStructValue(v reflect.Value, sa *array.Struct, i int) error { + if v.Kind() != reflect.Struct { + return fmt.Errorf("cannot set struct into %s: %w", v.Type(), ErrTypeMismatch) + } + + fields := cachedStructFields(v.Type()) + st := sa.DataType().(*arrow.StructType) + + for _, fm := range fields { + arrowIdx, found := st.FieldIdx(fm.Name) + if !found { + continue + } + fv, ok := fieldByIndexSafe(v, fm.Index) + if !ok { + // embedded pointer is nil; leave the field at its zero value + continue + } + if err := setValue(fv, sa.Field(arrowIdx), i); err != nil { + return fmt.Errorf("arreflect: field %q: %w", fm.Name, err) + } + } + return nil +} + +func setListValue(v reflect.Value, arr array.ListLike, i int) error { + if v.Kind() != reflect.Slice { + return fmt.Errorf("cannot set list into %s: %w", v.Type(), ErrTypeMismatch) + } + + start, end := arr.ValueOffsets(i) + child := arr.ListValues() + length := int(end - start) + + result := reflect.MakeSlice(v.Type(), length, length) + for j := 0; j < length; j++ { + if err := setValue(result.Index(j), child, int(start)+j); err != nil { + return fmt.Errorf("arreflect: list element %d: %w", j, err) + } + } + v.Set(result) + return nil +} + +func setMapValue(v reflect.Value, arr *array.Map, i int) error { + if v.Kind() != reflect.Map { + return fmt.Errorf("cannot set map into %s: %w", v.Type(), ErrTypeMismatch) + } + + start, end := arr.ValueOffsets(i) + keys := arr.Keys() + items := arr.Items() + keyType := v.Type().Key() + elemType := v.Type().Elem() + + result := reflect.MakeMap(v.Type()) + for j := int(start); j < int(end); j++ { + keyVal := reflect.New(keyType).Elem() + if err := setValue(keyVal, keys, j); err != nil { + return fmt.Errorf("arreflect: map key %d: %w", j-int(start), err) + } + elemVal := reflect.New(elemType).Elem() + if err := setValue(elemVal, items, j); err != nil { + return fmt.Errorf("arreflect: map value %d: %w", j-int(start), err) + } + result.SetMapIndex(keyVal, elemVal) + } + v.Set(result) + return nil +} + +func fillFixedSizeList(dst reflect.Value, child arrow.Array, start, n int) error { + for k := 0; k < n; k++ { + if err := setValue(dst.Index(k), child, start+k); err != nil { + return fmt.Errorf("arreflect: fixed-size list element %d: %w", k, err) + } + } + return nil +} + +func setFixedSizeListValue(v reflect.Value, arr *array.FixedSizeList, i int) error { + n := int(arr.DataType().(*arrow.FixedSizeListType).Len()) + child := arr.ListValues() + start, _ := arr.ValueOffsets(i) + + switch v.Kind() { + case reflect.Array: + if v.Len() != n { + return fmt.Errorf("fixed-size list length %d does not match Go array length %d: %w", n, v.Len(), ErrTypeMismatch) + } + return fillFixedSizeList(v, child, int(start), n) + case reflect.Slice: + result := reflect.MakeSlice(v.Type(), n, n) + if err := fillFixedSizeList(result, child, int(start), n); err != nil { + return err + } + v.Set(result) + default: + return fmt.Errorf("cannot set fixed-size list into %s: %w", v.Type(), ErrTypeMismatch) + } + return nil +} + +func setDictionaryValue(v reflect.Value, arr *array.Dictionary, i int) error { + return setValue(v, arr.Dictionary(), arr.GetValueIndex(i)) +} + +func setRunEndEncodedValue(v reflect.Value, arr *array.RunEndEncoded, i int) error { + return setValue(v, arr.Values(), arr.GetPhysicalIndex(i)) +} diff --git a/arrow/array/arreflect/reflect_arrow_to_go_test.go b/arrow/array/arreflect/reflect_arrow_to_go_test.go new file mode 100644 index 00000000..1a17b33a --- /dev/null +++ b/arrow/array/arreflect/reflect_arrow_to_go_test.go @@ -0,0 +1,955 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "reflect" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setValueAt[T any](t *testing.T, arr arrow.Array, i int) T { + t.Helper() + var got T + setValueInto(t, &got, arr, i) + return got +} + +func TestSetValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("bool", func(t *testing.T) { + b := array.NewBooleanBuilder(mem) + defer b.Release() + b.Append(true) + b.AppendNull() + arr := b.NewBooleanArray() + defer arr.Release() + + got := setValueAt[bool](t, arr, 0) + assert.True(t, got, "expected true, got false") + + got = true + setValueInto(t, &got, arr, 1) + assert.False(t, got, "expected false (null → zero), got true") + }) + + t.Run("string", func(t *testing.T) { + b := array.NewStringBuilder(mem) + defer b.Release() + b.Append("hello") + arr := b.NewStringArray() + defer arr.Release() + + got := setValueAt[string](t, arr, 0) + assert.Equal(t, "hello", got) + }) + + t.Run("binary", func(t *testing.T) { + b := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + defer b.Release() + b.Append([]byte("data")) + arr := b.NewBinaryArray() + defer arr.Release() + + got := setValueAt[[]byte](t, arr, 0) + assert.Equal(t, "data", string(got)) + }) + + t.Run("unsupported type error", func(t *testing.T) { + b := array.NewBooleanBuilder(mem) + defer b.Release() + b.Append(true) + arr := b.NewBooleanArray() + defer arr.Release() + + var got int32 + err := setValue(reflect.ValueOf(&got).Elem(), arr, 0) + assert.Error(t, err, "expected error for bool→int32 mismatch") + }) + + t.Run("pointer allocation", func(t *testing.T) { + b := array.NewStringBuilder(mem) + defer b.Release() + b.Append("ptr") + b.AppendNull() + arr := b.NewStringArray() + defer arr.Release() + + got := setValueAt[*string](t, arr, 0) + if assert.NotNil(t, got) { + assert.Equal(t, "ptr", *got) + } + + got = new(string) + setValueInto(t, &got, arr, 1) + assert.Nil(t, got, "expected nil for null, got %v", got) + }) +} + +func TestSetPrimitiveValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("int32", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.Append(42) + b.AppendNull() + arr := b.NewInt32Array() + defer arr.Release() + + got := setValueAt[int32](t, arr, 0) + assert.Equal(t, int32(42), got) + + got = 99 + setValueInto(t, &got, arr, 1) + assert.Equal(t, int32(0), got, "expected 0 for null, got %d", got) + }) + + t.Run("int64", func(t *testing.T) { + b := array.NewInt64Builder(mem) + defer b.Release() + b.Append(int64(1 << 40)) + arr := b.NewInt64Array() + defer arr.Release() + + var got int64 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, int64(1<<40), got) + }) + + t.Run("uint8", func(t *testing.T) { + b := array.NewUint8Builder(mem) + defer b.Release() + b.Append(255) + arr := b.NewUint8Array() + defer arr.Release() + + var got uint8 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, uint8(255), got) + }) + + t.Run("float64", func(t *testing.T) { + b := array.NewFloat64Builder(mem) + defer b.Release() + b.Append(3.14) + arr := b.NewFloat64Array() + defer arr.Release() + + var got float64 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, 3.14, got) + }) + + t.Run("type mismatch returns error", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.Append(10) + arr := b.NewInt32Array() + defer arr.Release() + + var got float64 + err := setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0) + assert.Error(t, err, "expected error for int32→float64 mismatch") + }) + + t.Run("int8", func(t *testing.T) { + b := array.NewInt8Builder(mem) + defer b.Release() + b.Append(-42) + arr := b.NewArray().(*array.Int8) + defer arr.Release() + + var got int8 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, int8(-42), got) + + var bad float32 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("int16", func(t *testing.T) { + b := array.NewInt16Builder(mem) + defer b.Release() + b.Append(-1234) + arr := b.NewArray().(*array.Int16) + defer arr.Release() + + var got int16 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, int16(-1234), got) + + var bad float32 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("int64 mismatch", func(t *testing.T) { + b := array.NewInt64Builder(mem) + defer b.Release() + b.Append(1) + arr := b.NewArray().(*array.Int64) + defer arr.Release() + + var bad string + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("uint16", func(t *testing.T) { + b := array.NewUint16Builder(mem) + defer b.Release() + b.Append(65535) + arr := b.NewArray().(*array.Uint16) + defer arr.Release() + + var got uint16 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, uint16(65535), got) + + var bad int32 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("uint32", func(t *testing.T) { + b := array.NewUint32Builder(mem) + defer b.Release() + b.Append(4_000_000_000) + arr := b.NewArray().(*array.Uint32) + defer arr.Release() + + var got uint32 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, uint32(4_000_000_000), got) + + var bad int32 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("uint64", func(t *testing.T) { + b := array.NewUint64Builder(mem) + defer b.Release() + b.Append(1 << 63) + arr := b.NewArray().(*array.Uint64) + defer arr.Release() + + var got uint64 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, uint64(1<<63), got) + + var bad float64 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("uint8 mismatch", func(t *testing.T) { + b := array.NewUint8Builder(mem) + defer b.Release() + b.Append(1) + arr := b.NewArray().(*array.Uint8) + defer arr.Release() + + var bad int8 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("float32", func(t *testing.T) { + b := array.NewFloat32Builder(mem) + defer b.Release() + b.Append(2.5) + arr := b.NewArray().(*array.Float32) + defer arr.Release() + + var got float32 + require.NoError(t, setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0)) + assert.Equal(t, float32(2.5), got) + + var bad int32 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("float64 mismatch", func(t *testing.T) { + b := array.NewFloat64Builder(mem) + defer b.Release() + b.Append(1.0) + arr := b.NewArray().(*array.Float64) + defer arr.Release() + + var bad int32 + assert.ErrorIs(t, setPrimitiveValue(reflect.ValueOf(&bad).Elem(), arr, 0), ErrTypeMismatch) + }) + + t.Run("unsupported primitive type returns error", func(t *testing.T) { + b := array.NewBooleanBuilder(mem) + defer b.Release() + b.Append(true) + arr := b.NewArray().(*array.Boolean) + defer arr.Release() + + var got bool + err := setPrimitiveValue(reflect.ValueOf(&got).Elem(), arr, 0) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestSetTemporalValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("timestamp", func(t *testing.T) { + dt := &arrow.TimestampType{Unit: arrow.Second} + b := array.NewTimestampBuilder(mem, dt) + defer b.Release() + now := time.Unix(1700000000, 0).UTC() + b.Append(arrow.Timestamp(now.Unix())) + arr := b.NewArray().(*array.Timestamp) + defer arr.Release() + + got := setValueAt[time.Time](t, arr, 0) + assert.True(t, got.Equal(now), "expected %v, got %v", now, got) + }) + + t.Run("date32", func(t *testing.T) { + b := array.NewDate32Builder(mem) + defer b.Release() + b.Append(arrow.Date32(19000)) + arr := b.NewArray().(*array.Date32) + defer arr.Release() + + got := setValueAt[time.Time](t, arr, 0) + expected := arrow.Date32(19000).ToTime() + assert.True(t, got.Equal(expected), "expected %v, got %v", expected, got) + }) + + t.Run("duration", func(t *testing.T) { + dt := &arrow.DurationType{Unit: arrow.Second} + b := array.NewDurationBuilder(mem, dt) + defer b.Release() + b.Append(arrow.Duration(5)) + arr := b.NewArray().(*array.Duration) + defer arr.Release() + + got := setValueAt[time.Duration](t, arr, 0) + expected := 5 * time.Second + assert.Equal(t, expected, got) + }) + + t.Run("null temporal", func(t *testing.T) { + dt := &arrow.TimestampType{Unit: arrow.Second} + b := array.NewTimestampBuilder(mem, dt) + defer b.Release() + b.AppendNull() + arr := b.NewArray().(*array.Timestamp) + defer arr.Release() + + got := setValueAt[*time.Time](t, arr, 0) + assert.Nil(t, got, "expected nil for null timestamp pointer") + }) + + t.Run("time32", func(t *testing.T) { + dt := &arrow.Time32Type{Unit: arrow.Millisecond} + b := array.NewTime32Builder(mem, dt) + defer b.Release() + // 10h30m0s500ms = (10*3600 + 30*60)*1000 + 500 = 37800500 ms + b.Append(arrow.Time32(37800500)) + arr := b.NewArray() + defer arr.Release() + + var got time.Time + v := reflect.ValueOf(&got).Elem() + require.NoError(t, setValue(v, arr, 0)) + assert.True(t, got.Hour() == 10 && got.Minute() == 30 && got.Second() == 0 && got.Nanosecond()/1_000_000 == 500, + "time32: got %v, want 10:30:00.500", got) + }) + + t.Run("time64", func(t *testing.T) { + dt := &arrow.Time64Type{Unit: arrow.Nanosecond} + b := array.NewTime64Builder(mem, dt) + defer b.Release() + // 10h30m0s123456789ns + nanos := int64(10*3600+30*60)*1_000_000_000 + 123456789 + b.Append(arrow.Time64(nanos)) + arr := b.NewArray() + defer arr.Release() + + var got time.Time + v := reflect.ValueOf(&got).Elem() + require.NoError(t, setValue(v, arr, 0)) + assert.True(t, got.Hour() == 10 && got.Minute() == 30 && got.Second() == 0 && got.Nanosecond() == 123456789, + "time64: got %v, want 10:30:00.123456789", got) + }) + + t.Run("date64", func(t *testing.T) { + b := array.NewDate64Builder(mem) + defer b.Release() + ms := int64(1705276800000) + b.Append(arrow.Date64(ms)) + arr := b.NewArray().(*array.Date64) + defer arr.Release() + + got := setValueAt[time.Time](t, arr, 0) + expected := arrow.Date64(ms).ToTime() + assert.True(t, got.Equal(expected), "date64: expected %v, got %v", expected, got) + }) + + t.Run("type mismatch into non-time returns error", func(t *testing.T) { + b := array.NewTimestampBuilder(mem, &arrow.TimestampType{Unit: arrow.Second}) + defer b.Release() + b.Append(arrow.Timestamp(0)) + arr := b.NewArray().(*array.Timestamp) + defer arr.Release() + + var bad int64 + err := setTemporalValue(reflect.ValueOf(&bad).Elem(), arr, 0) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + + t.Run("duration into non-duration returns error", func(t *testing.T) { + dt := &arrow.DurationType{Unit: arrow.Second} + b := array.NewDurationBuilder(mem, dt) + defer b.Release() + b.Append(arrow.Duration(1)) + arr := b.NewArray().(*array.Duration) + defer arr.Release() + + var bad int64 + err := setTemporalValue(reflect.ValueOf(&bad).Elem(), arr, 0) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + + t.Run("unsupported temporal type returns error", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.Append(1) + arr := b.NewArray().(*array.Int32) + defer arr.Release() + + var got time.Time + err := setTemporalValue(reflect.ValueOf(&got).Elem(), arr, 0) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestSetDecimalValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("decimal128", func(t *testing.T) { + dt := &arrow.Decimal128Type{Precision: 10, Scale: 2} + b := array.NewDecimal128Builder(mem, dt) + defer b.Release() + num := decimal128.New(0, 12345) + b.Append(num) + b.AppendNull() + arr := b.NewDecimal128Array() + defer arr.Release() + + got := setValueAt[decimal128.Num](t, arr, 0) + assert.Equal(t, num, got) + + gotPtr := setValueAt[*decimal128.Num](t, arr, 1) + assert.Nil(t, gotPtr, "expected nil for null decimal128") + }) + + t.Run("decimal256", func(t *testing.T) { + dt := &arrow.Decimal256Type{Precision: 20, Scale: 4} + b := array.NewDecimal256Builder(mem, dt) + defer b.Release() + num := decimal256.New(0, 0, 0, 9876) + b.Append(num) + arr := b.NewDecimal256Array() + defer arr.Release() + + got := setValueAt[decimal256.Num](t, arr, 0) + assert.Equal(t, num, got) + }) + + t.Run("decimal32", func(t *testing.T) { + dt := &arrow.Decimal32Type{Precision: 9, Scale: 2} + b := array.NewDecimal32Builder(mem, dt) + defer b.Release() + num := decimal.Decimal32(12345) + b.Append(num) + b.AppendNull() + arr := b.NewArray().(*array.Decimal32) + defer arr.Release() + + got := setValueAt[decimal.Decimal32](t, arr, 0) + assert.Equal(t, num, got) + + gotPtr := setValueAt[*decimal.Decimal32](t, arr, 1) + assert.Nil(t, gotPtr, "expected nil for null decimal32") + }) + + t.Run("decimal64", func(t *testing.T) { + dt := &arrow.Decimal64Type{Precision: 18, Scale: 3} + b := array.NewDecimal64Builder(mem, dt) + defer b.Release() + num := decimal.Decimal64(987654321) + b.Append(num) + arr := b.NewArray().(*array.Decimal64) + defer arr.Release() + + got := setValueAt[decimal.Decimal64](t, arr, 0) + assert.Equal(t, num, got) + }) + + t.Run("type mismatch into wrong decimal kind returns error", func(t *testing.T) { + b128 := array.NewDecimal128Builder(mem, &arrow.Decimal128Type{Precision: 10, Scale: 2}) + defer b128.Release() + b128.Append(decimal128.New(0, 1)) + arr128 := b128.NewDecimal128Array() + defer arr128.Release() + + var got256 decimal256.Num + assert.ErrorIs(t, setDecimalValue(reflect.ValueOf(&got256).Elem(), arr128, 0), ErrTypeMismatch) + + var got32 decimal.Decimal32 + assert.ErrorIs(t, setDecimalValue(reflect.ValueOf(&got32).Elem(), arr128, 0), ErrTypeMismatch) + + b256 := array.NewDecimal256Builder(mem, &arrow.Decimal256Type{Precision: 20, Scale: 4}) + defer b256.Release() + b256.Append(decimal256.New(0, 0, 0, 1)) + arr256 := b256.NewDecimal256Array() + defer arr256.Release() + + var got128 decimal128.Num + assert.ErrorIs(t, setDecimalValue(reflect.ValueOf(&got128).Elem(), arr256, 0), ErrTypeMismatch) + + b32 := array.NewDecimal32Builder(mem, &arrow.Decimal32Type{Precision: 9, Scale: 2}) + defer b32.Release() + b32.Append(decimal.Decimal32(1)) + arr32 := b32.NewArray().(*array.Decimal32) + defer arr32.Release() + + var got64 decimal.Decimal64 + assert.ErrorIs(t, setDecimalValue(reflect.ValueOf(&got64).Elem(), arr32, 0), ErrTypeMismatch) + + b64 := array.NewDecimal64Builder(mem, &arrow.Decimal64Type{Precision: 18, Scale: 3}) + defer b64.Release() + b64.Append(decimal.Decimal64(1)) + arr64 := b64.NewArray().(*array.Decimal64) + defer arr64.Release() + + var badF float64 + assert.ErrorIs(t, setDecimalValue(reflect.ValueOf(&badF).Elem(), arr64, 0), ErrTypeMismatch) + }) + + t.Run("unsupported decimal type returns error", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.Append(1) + arr := b.NewArray().(*array.Int32) + defer arr.Release() + + var got decimal128.Num + err := setDecimalValue(reflect.ValueOf(&got).Elem(), arr, 0) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestAssertArrayTypeMismatch(t *testing.T) { + mem := checkedMem(t) + b := array.NewInt32Builder(mem) + defer b.Release() + b.Append(1) + arr := b.NewInt32Array() + defer arr.Release() + + _, err := assertArray[array.Float64](arr) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) +} + +func TestSetStructValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("basic struct", func(t *testing.T) { + nameArr := makeStringArray(t, mem, "Alice", "Bob") + ageArr := makeInt32Array(t, mem, 30, 25) + sa := makeStructArray(t, []arrow.Array{nameArr, ageArr}, []string{"Name", "Age"}) + + type Person struct { + Name string + Age int32 + } + + var got Person + setValueInto(t, &got, sa, 0) + assert.Equal(t, "Alice", got.Name) + assert.Equal(t, int32(30), got.Age) + + setValueInto(t, &got, sa, 1) + assert.Equal(t, "Bob", got.Name) + assert.Equal(t, int32(25), got.Age) + }) + + t.Run("arrow tag mapping", func(t *testing.T) { + nameArr := makeStringArray(t, mem, "Charlie") + sa := makeStructArray(t, []arrow.Array{nameArr}, []string{"full_name"}) + + type TaggedPerson struct { + FullName string `arrow:"full_name"` + } + + var got TaggedPerson + setValueInto(t, &got, sa, 0) + assert.Equal(t, "Charlie", got.FullName) + }) + + t.Run("missing arrow field leaves go field zero", func(t *testing.T) { + nameArr := makeStringArray(t, mem, "Dave") + sa := makeStructArray(t, []arrow.Array{nameArr}, []string{"Name"}) + + type PersonWithExtra struct { + Name string + Email string + } + + var got PersonWithExtra + setValueInto(t, &got, sa, 0) + assert.Equal(t, "Dave", got.Name) + assert.Equal(t, "", got.Email) + }) + + t.Run("nil embedded pointer leaves promoted fields zero", func(t *testing.T) { + // Regression: reflect.Value.FieldByIndex panics on nil embedded pointer; + // the walker must stop and leave promoted fields at their zero value. + nameArr := makeStringArray(t, mem, "Alice") + cityArr := makeStringArray(t, mem, "NYC") + sa := makeStructArray(t, []arrow.Array{nameArr, cityArr}, []string{"Name", "City"}) + + type Inner struct { + City string + } + type Outer struct { + Name string + *Inner + } + + var got Outer + setValueInto(t, &got, sa, 0) + assert.Equal(t, "Alice", got.Name) + assert.Nil(t, got.Inner, "nil embedded pointer should remain nil; promoted City left at zero value") + }) +} + +func TestSetValueClonesStringAndBytes(t *testing.T) { + // Regression: String.Value / Binary.Value return views into the array's + // backing buffer. setValue must copy so Go values outlive the array. + mem := checkedMem(t) + + t.Run("string", func(t *testing.T) { + sb := array.NewStringBuilder(mem) + sb.Append("hello world") + arr := sb.NewStringArray() + sb.Release() + + var got string + setValueInto(t, &got, arr, 0) + assert.Equal(t, "hello world", got) + arr.Release() + assert.Equal(t, "hello world", got, "string must survive Arrow array release") + }) + + t.Run("bytes", func(t *testing.T) { + bb := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + bb.Append([]byte{0x01, 0x02, 0x03, 0x04}) + arr := bb.NewBinaryArray() + bb.Release() + + var got []byte + setValueInto(t, &got, arr, 0) + assert.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, got) + arr.Release() + assert.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, got, "[]byte must survive Arrow array release") + }) +} + +func TestSetListValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("list of int32", func(t *testing.T) { + lb := array.NewListBuilder(mem, arrow.PrimitiveTypes.Int32) + defer lb.Release() + + vb := lb.ValueBuilder().(*array.Int32Builder) + lb.Append(true) + vb.AppendValues([]int32{1, 2, 3}, nil) + lb.Append(true) + vb.AppendValues([]int32{4, 5}, nil) + lb.AppendNull() + + arr := lb.NewListArray() + defer arr.Release() + + got := setValueAt[[]int32](t, arr, 0) + assert.Equal(t, []int32{1, 2, 3}, got) + + setValueInto(t, &got, arr, 1) + assert.Equal(t, []int32{4, 5}, got) + + setValueInto(t, &got, arr, 2) + assert.Nil(t, got, "expected nil slice for null list, got %v", got) + }) + + t.Run("nested list of lists", func(t *testing.T) { + inner := array.NewListBuilder(mem, arrow.PrimitiveTypes.Int32) + defer inner.Release() + outer := array.NewListBuilder(mem, arrow.ListOf(arrow.PrimitiveTypes.Int32)) + defer outer.Release() + + innerVB := inner.ValueBuilder().(*array.Int32Builder) + + inner.Append(true) + innerVB.AppendValues([]int32{1, 2}, nil) + inner.Append(true) + innerVB.AppendValues([]int32{3}, nil) + innerArr := inner.NewListArray() + defer innerArr.Release() + + outerVB := outer.ValueBuilder().(*array.ListBuilder) + outerInnerVB := outerVB.ValueBuilder().(*array.Int32Builder) + outer.Append(true) + outerVB.Append(true) + outerInnerVB.AppendValues([]int32{10, 20}, nil) + outerVB.Append(true) + outerInnerVB.AppendValues([]int32{30}, nil) + + outerArr := outer.NewListArray() + defer outerArr.Release() + + var got [][]int32 + setValueInto(t, &got, outerArr, 0) + require.Len(t, got, 2, "expected 2 inner slices, got %d", len(got)) + assert.Equal(t, []int32{10, 20}, got[0]) + assert.Equal(t, []int32{30}, got[1]) + }) + + t.Run("large list view of int32", func(t *testing.T) { + lvb := array.NewLargeListViewBuilder(mem, arrow.PrimitiveTypes.Int32) + defer lvb.Release() + vb := lvb.ValueBuilder().(*array.Int32Builder) + + lvb.AppendWithSize(true, 2) + vb.AppendValues([]int32{1, 2}, nil) + lvb.AppendWithSize(true, 1) + vb.AppendValues([]int32{3}, nil) + + arr := lvb.NewLargeListViewArray() + defer arr.Release() + + got := setValueAt[[]int32](t, arr, 0) + assert.Equal(t, []int32{1, 2}, got, "row 0: expected [1,2], got %v", got) + + setValueInto(t, &got, arr, 1) + assert.Equal(t, []int32{3}, got, "row 1: expected [3], got %v", got) + }) +} + +func TestSetMapValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("map string to int32", func(t *testing.T) { + mb := array.NewMapBuilder(mem, arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32, false) + defer mb.Release() + + kb := mb.KeyBuilder().(*array.StringBuilder) + ib := mb.ItemBuilder().(*array.Int32Builder) + + mb.Append(true) + kb.Append("a") + ib.Append(1) + kb.Append("b") + ib.Append(2) + + mb.Append(true) + kb.Append("x") + ib.Append(10) + + mb.AppendNull() + + arr := mb.NewMapArray() + defer arr.Release() + + got := setValueAt[map[string]int32](t, arr, 0) + assert.Equal(t, int32(1), got["a"]) + assert.Equal(t, int32(2), got["b"]) + + setValueInto(t, &got, arr, 1) + assert.Equal(t, int32(10), got["x"]) + + setValueInto(t, &got, arr, 2) + assert.Nil(t, got, "expected nil map for null, got %v", got) + }) +} + +func TestSetFixedSizeListValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("go array", func(t *testing.T) { + b := array.NewFixedSizeListBuilder(mem, 3, arrow.PrimitiveTypes.Int32) + defer b.Release() + vb := b.ValueBuilder().(*array.Int32Builder) + + b.Append(true) + vb.AppendValues([]int32{10, 20, 30}, nil) + b.Append(true) + vb.AppendValues([]int32{40, 50, 60}, nil) + b.AppendNull() + + arr := b.NewArray().(*array.FixedSizeList) + defer arr.Release() + + got := setValueAt[[3]int32](t, arr, 0) + assert.Equal(t, [3]int32{10, 20, 30}, got) + + setValueInto(t, &got, arr, 1) + assert.Equal(t, [3]int32{40, 50, 60}, got) + + got = [3]int32{1, 2, 3} + setValueInto(t, &got, arr, 2) + assert.Equal(t, [3]int32{}, got, "expected zero array for null, got %v", got) + }) + + t.Run("go slice", func(t *testing.T) { + b := array.NewFixedSizeListBuilder(mem, 2, arrow.PrimitiveTypes.Int32) + defer b.Release() + vb := b.ValueBuilder().(*array.Int32Builder) + + b.Append(true) + vb.AppendValues([]int32{7, 8}, nil) + + arr := b.NewArray().(*array.FixedSizeList) + defer arr.Release() + + got := setValueAt[[]int32](t, arr, 0) + assert.Equal(t, []int32{7, 8}, got) + }) + + t.Run("size mismatch returns error", func(t *testing.T) { + b := array.NewFixedSizeListBuilder(mem, 3, arrow.PrimitiveTypes.Int32) + defer b.Release() + vb := b.ValueBuilder().(*array.Int32Builder) + b.Append(true) + vb.AppendValues([]int32{1, 2, 3}, nil) + + arr := b.NewArray().(*array.FixedSizeList) + defer arr.Release() + + var got [2]int32 + err := setValue(reflect.ValueOf(&got).Elem(), arr, 0) + assert.Error(t, err, "expected error for size mismatch") + }) + + t.Run("child element type mismatch errors", func(t *testing.T) { + b := array.NewFixedSizeListBuilder(mem, 3, arrow.PrimitiveTypes.Int32) + defer b.Release() + vb := b.ValueBuilder().(*array.Int32Builder) + b.Append(true) + vb.AppendValues([]int32{1, 2, 3}, nil) + + arr := b.NewArray().(*array.FixedSizeList) + defer arr.Release() + + var got [3]string + err := setValue(reflect.ValueOf(&got).Elem(), arr, 0) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + + t.Run("child element type mismatch on slice errors", func(t *testing.T) { + b := array.NewFixedSizeListBuilder(mem, 2, arrow.PrimitiveTypes.Int32) + defer b.Release() + vb := b.ValueBuilder().(*array.Int32Builder) + b.Append(true) + vb.AppendValues([]int32{10, 20}, nil) + + arr := b.NewArray().(*array.FixedSizeList) + defer arr.Release() + + var got []string + err := setValue(reflect.ValueOf(&got).Elem(), arr, 0) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) +} + +func TestSetDictionaryValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("dictionary int8 to string", func(t *testing.T) { + dt := &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: arrow.BinaryTypes.String} + bldr := array.NewDictionaryBuilder(mem, dt) + defer bldr.Release() + db := bldr.(*array.BinaryDictionaryBuilder) + + db.AppendString("foo") + db.AppendString("bar") + db.AppendString("foo") + db.AppendNull() + + arr := bldr.NewDictionaryArray() + defer arr.Release() + + got := setValueAt[string](t, arr, 0) + assert.Equal(t, "foo", got) + + setValueInto(t, &got, arr, 1) + assert.Equal(t, "bar", got) + + setValueInto(t, &got, arr, 2) + assert.Equal(t, "foo", got) + + gotPtr := setValueAt[*string](t, arr, 3) + assert.Nil(t, gotPtr, "expected nil for null dictionary entry") + }) +} + +func TestSetRunEndEncodedValue(t *testing.T) { + mem := checkedMem(t) + + t.Run("ree int32 to string", func(t *testing.T) { + b := array.NewRunEndEncodedBuilder(mem, arrow.PrimitiveTypes.Int32, arrow.BinaryTypes.String) + defer b.Release() + vb := b.ValueBuilder().(*array.StringBuilder) + + b.Append(3) + vb.Append("aaa") + b.Append(2) + vb.Append("bbb") + + arr := b.NewRunEndEncodedArray() + defer arr.Release() + + got := setValueAt[string](t, arr, 0) + assert.Equal(t, "aaa", got, "expected aaa at logical 0, got %q", got) + + setValueInto(t, &got, arr, 2) + assert.Equal(t, "aaa", got, "expected aaa at logical 2, got %q", got) + + setValueInto(t, &got, arr, 3) + assert.Equal(t, "bbb", got, "expected bbb at logical 3, got %q", got) + + setValueInto(t, &got, arr, 4) + assert.Equal(t, "bbb", got, "expected bbb at logical 4, got %q", got) + }) +} diff --git a/arrow/array/arreflect/reflect_go_to_arrow.go b/arrow/array/arreflect/reflect_go_to_arrow.go new file mode 100644 index 00000000..e2acfabc --- /dev/null +++ b/arrow/array/arreflect/reflect_go_to_arrow.go @@ -0,0 +1,846 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "fmt" + "reflect" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +func buildArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + if vals.Kind() != reflect.Slice { + return nil, fmt.Errorf("arreflect: expected slice, got %v", vals.Kind()) + } + + elemType := vals.Type().Elem() + for elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + + if opts.Large || opts.View { + dt, err := inferArrowType(elemType) + if err != nil { + return nil, err + } + if opts.Large && !hasLargeableType(dt) { + return nil, fmt.Errorf("arreflect: large option has no effect on type %s: %w", dt, ErrUnsupportedType) + } + if opts.View && !hasViewableType(dt) { + return nil, fmt.Errorf("arreflect: view option has no effect on type %s: %w", dt, ErrUnsupportedType) + } + } + + if opts.Dict { + return buildDictionaryArray(vals, opts, mem) + } + if opts.REE { + return buildRunEndEncodedArray(vals, opts, mem) + } + if opts.View { + if elemType.Kind() != reflect.Slice || elemType == typeOfByteSlice { + return buildPrimitiveArray(vals, opts, mem) + } + return buildListViewArray(vals, opts, mem) + } + + switch elemType { + case typeOfDec32, typeOfDec64, typeOfDec128, typeOfDec256: + return buildDecimalArray(vals, opts, mem) + } + + switch elemType.Kind() { + case reflect.Slice: + if elemType == typeOfByteSlice { + return buildPrimitiveArray(vals, opts, mem) + } + return buildListArray(vals, opts, mem) + + case reflect.Array: + return buildFixedSizeListArray(vals, opts, mem) + + case reflect.Map: + return buildMapArray(vals, opts, mem) + + case reflect.Struct: + switch elemType { + case typeOfTime: + return buildTemporalArray(vals, opts, mem) + default: + return buildStructArray(vals, opts, mem) + } + + default: + return buildPrimitiveArray(vals, opts, mem) + } +} + +func buildPrimitiveArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + + dt, err := inferArrowType(elemType) + if err != nil { + return nil, err + } + if opts.Large { + dt = applyLargeOpts(dt) + } + if opts.View { + dt = applyViewOpts(dt) + } + + b := array.NewBuilder(mem, dt) + defer b.Release() + b.Reserve(vals.Len()) + + if err := iterSlice(vals, isPtr, b.AppendNull, func(v reflect.Value) error { + return appendValue(b, v) + }); err != nil { + return nil, err + } + return b.NewArray(), nil +} + +func timeOfDayNanos(t time.Time) int64 { + t = t.UTC() + midnight := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) + return t.Sub(midnight).Nanoseconds() +} + +func asTime(v reflect.Value) (time.Time, error) { + t, ok := reflect.TypeAssert[time.Time](v) + if !ok { + return time.Time{}, fmt.Errorf("expected time.Time, got %s: %w", v.Type(), ErrTypeMismatch) + } + return t, nil +} + +func asDuration(v reflect.Value) (time.Duration, error) { + d, ok := reflect.TypeAssert[time.Duration](v) + if !ok { + return 0, fmt.Errorf("expected time.Duration, got %s: %w", v.Type(), ErrTypeMismatch) + } + return d, nil +} + +func derefSliceElem(vals reflect.Value) (elemType reflect.Type, isPtr bool) { + elemType = vals.Type().Elem() + isPtr = elemType.Kind() == reflect.Ptr + for elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + return +} + +func iterSlice(vals reflect.Value, isPtr bool, appendNull func(), appendVal func(reflect.Value) error) error { + for i := 0; i < vals.Len(); i++ { + v := vals.Index(i) + if isPtr { + wasNull := false + for v.Kind() == reflect.Ptr { + if v.IsNil() { + appendNull() + wasNull = true + break + } + v = v.Elem() + } + if wasNull { + continue + } + } + if err := appendVal(v); err != nil { + return err + } + } + return nil +} + +func inferListElemDT(vals reflect.Value) (elemDT arrow.DataType, err error) { + outerSliceType, _ := derefSliceElem(vals) + innerElemType := outerSliceType.Elem() + for innerElemType.Kind() == reflect.Ptr { + innerElemType = innerElemType.Elem() + } + elemDT, err = inferArrowType(innerElemType) + return +} + +func temporalBuilder(opts tagOpts, mem memory.Allocator) array.Builder { + switch opts.Temporal { + case "date32": + return array.NewDate32Builder(mem) + case "date64": + return array.NewDate64Builder(mem) + case "time32": + return array.NewTime32Builder(mem, &arrow.Time32Type{Unit: arrow.Millisecond}) + case "time64": + return array.NewTime64Builder(mem, &arrow.Time64Type{Unit: arrow.Nanosecond}) + default: + return array.NewTimestampBuilder(mem, &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}) + } +} + +func buildTemporalArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + if elemType != typeOfTime { + return nil, fmt.Errorf("unsupported temporal type %v: %w", elemType, ErrUnsupportedType) + } + b := temporalBuilder(opts, mem) + defer b.Release() + b.Reserve(vals.Len()) + if err := iterSlice(vals, isPtr, b.AppendNull, func(v reflect.Value) error { + return appendTemporalValue(b, v) + }); err != nil { + return nil, err + } + return b.NewArray(), nil +} + +func decimalPrecisionScale(opts tagOpts, defaultPrec int32) (precision, scale int32) { + if opts.HasDecimalOpts { + return opts.DecimalPrecision, opts.DecimalScale + } + return defaultPrec, 0 +} + +func buildDecimalArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + + var b array.Builder + switch elemType { + case typeOfDec128: + p, s := decimalPrecisionScale(opts, dec128DefaultPrecision) + b = array.NewDecimal128Builder(mem, &arrow.Decimal128Type{Precision: p, Scale: s}) + case typeOfDec256: + p, s := decimalPrecisionScale(opts, dec256DefaultPrecision) + b = array.NewDecimal256Builder(mem, &arrow.Decimal256Type{Precision: p, Scale: s}) + case typeOfDec32: + p, s := decimalPrecisionScale(opts, dec32DefaultPrecision) + b = array.NewDecimal32Builder(mem, &arrow.Decimal32Type{Precision: p, Scale: s}) + case typeOfDec64: + p, s := decimalPrecisionScale(opts, dec64DefaultPrecision) + b = array.NewDecimal64Builder(mem, &arrow.Decimal64Type{Precision: p, Scale: s}) + default: + return nil, fmt.Errorf("unsupported decimal type %v: %w", elemType, ErrUnsupportedType) + } + defer b.Release() + b.Reserve(vals.Len()) + if err := iterSlice(vals, isPtr, b.AppendNull, func(v reflect.Value) error { + return appendDecimalValue(b, v) + }); err != nil { + return nil, err + } + return b.NewArray(), nil +} + +func appendStructFields(sb *array.StructBuilder, v reflect.Value, fields []fieldMeta) error { + sb.Append(true) + for fi, fm := range fields { + fv, ok := fieldByIndexSafe(v, fm.Index) + if !ok { + sb.FieldBuilder(fi).AppendNull() + continue + } + if err := appendValue(sb.FieldBuilder(fi), fv); err != nil { + return fmt.Errorf("struct field %q: %w", fm.Name, err) + } + } + return nil +} + +func buildStructArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + + st, err := inferStructType(elemType) + if err != nil { + return nil, err + } + if opts.Large { + // applyLargeOpts is idempotent, so per-field "large" tags already applied + // by inferStructType are safe to walk again here. + st = applyLargeOpts(st).(*arrow.StructType) + } + if opts.View { + st = applyViewOpts(st).(*arrow.StructType) + } + + fields := cachedStructFields(elemType) + sb := array.NewStructBuilder(mem, st) + defer sb.Release() + sb.Reserve(vals.Len()) + + if err := iterSlice(vals, isPtr, sb.AppendNull, func(v reflect.Value) error { + return appendStructFields(sb, v, fields) + }); err != nil { + return nil, err + } + + return sb.NewArray(), nil +} + +func appendTemporalValue(b array.Builder, v reflect.Value) error { + switch tb := b.(type) { + case *array.TimestampBuilder: + unit := tb.Type().(*arrow.TimestampType).Unit + t, err := asTime(v) + if err != nil { + return err + } + tb.Append(arrow.Timestamp(t.UnixNano() / int64(unit.Multiplier()))) + case *array.Date32Builder: + t, err := asTime(v) + if err != nil { + return err + } + tb.Append(arrow.Date32FromTime(t)) + case *array.Date64Builder: + t, err := asTime(v) + if err != nil { + return err + } + tb.Append(arrow.Date64FromTime(t)) + case *array.Time32Builder: + unit := tb.Type().(*arrow.Time32Type).Unit + t, err := asTime(v) + if err != nil { + return err + } + tb.Append(arrow.Time32(timeOfDayNanos(t) / int64(unit.Multiplier()))) + case *array.Time64Builder: + unit := tb.Type().(*arrow.Time64Type).Unit + t, err := asTime(v) + if err != nil { + return err + } + tb.Append(arrow.Time64(timeOfDayNanos(t) / int64(unit.Multiplier()))) + case *array.DurationBuilder: + unit := tb.Type().(*arrow.DurationType).Unit + d, err := asDuration(v) + if err != nil { + return err + } + tb.Append(arrow.Duration(d.Nanoseconds() / int64(unit.Multiplier()))) + default: + return fmt.Errorf("unexpected temporal builder %T: %w", b, ErrUnsupportedType) + } + return nil +} + +func appendDecimalValue(b array.Builder, v reflect.Value) error { + switch tb := b.(type) { + case *array.Decimal128Builder: + n, ok := reflect.TypeAssert[decimal128.Num](v) + if !ok { + return fmt.Errorf("expected decimal128.Num, got %s: %w", v.Type(), ErrTypeMismatch) + } + tb.Append(n) + case *array.Decimal256Builder: + n, ok := reflect.TypeAssert[decimal256.Num](v) + if !ok { + return fmt.Errorf("expected decimal256.Num, got %s: %w", v.Type(), ErrTypeMismatch) + } + tb.Append(n) + case *array.Decimal32Builder: + tb.Append(decimal.Decimal32(v.Int())) + case *array.Decimal64Builder: + tb.Append(decimal.Decimal64(v.Int())) + default: + return fmt.Errorf("unexpected decimal builder %T: %w", b, ErrUnsupportedType) + } + return nil +} + +func appendValue(b array.Builder, v reflect.Value) error { + for v.Kind() == reflect.Ptr { + if v.IsNil() { + b.AppendNull() + return nil + } + v = v.Elem() + } + + switch tb := b.(type) { + case *array.Int8Builder: + tb.Append(int8(v.Int())) + case *array.Int16Builder: + tb.Append(int16(v.Int())) + case *array.Int32Builder: + tb.Append(int32(v.Int())) + case *array.Int64Builder: + tb.Append(int64(v.Int())) + case *array.Uint8Builder: + tb.Append(uint8(v.Uint())) + case *array.Uint16Builder: + tb.Append(uint16(v.Uint())) + case *array.Uint32Builder: + tb.Append(uint32(v.Uint())) + case *array.Uint64Builder: + tb.Append(uint64(v.Uint())) + case *array.Float32Builder: + tb.Append(float32(v.Float())) + case *array.Float64Builder: + tb.Append(float64(v.Float())) + case *array.BooleanBuilder: + tb.Append(v.Bool()) + case array.StringLikeBuilder: + tb.Append(v.String()) + case array.BinaryLikeBuilder: + if v.IsNil() { + tb.AppendNull() + } else { + tb.Append(v.Bytes()) + } + case *array.TimestampBuilder, *array.Date32Builder, *array.Date64Builder, + *array.Time32Builder, *array.Time64Builder, *array.DurationBuilder: + return appendTemporalValue(b, v) + case *array.Decimal128Builder, *array.Decimal256Builder, *array.Decimal32Builder, *array.Decimal64Builder: + return appendDecimalValue(b, v) + case *array.ListBuilder, *array.LargeListBuilder, *array.ListViewBuilder, *array.LargeListViewBuilder: + return appendListElement(b, v) + case *array.FixedSizeListBuilder: + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return fmt.Errorf("arreflect: cannot set fixed-size list from %s: %w", v.Type(), ErrTypeMismatch) + } + if v.Kind() == reflect.Slice && v.IsNil() { + tb.AppendNull() + return nil + } + expectedLen := int(tb.Type().(*arrow.FixedSizeListType).Len()) + if v.Len() != expectedLen { + return fmt.Errorf("arreflect: fixed-size list length mismatch: got %d, want %d", v.Len(), expectedLen) + } + tb.Append(true) + vb := tb.ValueBuilder() + for i := 0; i < v.Len(); i++ { + if err := appendValue(vb, v.Index(i)); err != nil { + return err + } + } + case *array.MapBuilder: + if v.IsNil() { + tb.AppendNull() + } else { + tb.Append(true) + kb := tb.KeyBuilder() + ib := tb.ItemBuilder() + for _, key := range v.MapKeys() { + if err := appendValue(kb, key); err != nil { + return err + } + if err := appendValue(ib, v.MapIndex(key)); err != nil { + return err + } + } + } + case *array.StructBuilder: + fields := cachedStructFields(v.Type()) + return appendStructFields(tb, v, fields) + default: + if db, ok := b.(array.DictionaryBuilder); ok { + return appendToDictBuilder(db, v) + } + return fmt.Errorf("unsupported builder type %T: %w", b, ErrUnsupportedType) + } + return nil +} + +func appendToDictBuilder(db array.DictionaryBuilder, v reflect.Value) error { + switch bdb := db.(type) { + case *array.BinaryDictionaryBuilder: + switch v.Kind() { + case reflect.String: + return bdb.AppendString(v.String()) + case reflect.Slice: + if v.IsNil() { + bdb.AppendNull() + return nil + } + return bdb.Append(v.Bytes()) + default: + return fmt.Errorf("unsupported value kind %v for BinaryDictionaryBuilder: %w", v.Kind(), ErrUnsupportedType) + } + case *array.Int8DictionaryBuilder: + return bdb.Append(int8(v.Int())) + case *array.Int16DictionaryBuilder: + return bdb.Append(int16(v.Int())) + case *array.Int32DictionaryBuilder: + return bdb.Append(int32(v.Int())) + case *array.Int64DictionaryBuilder: + return bdb.Append(int64(v.Int())) + case *array.Uint8DictionaryBuilder: + return bdb.Append(uint8(v.Uint())) + case *array.Uint16DictionaryBuilder: + return bdb.Append(uint16(v.Uint())) + case *array.Uint32DictionaryBuilder: + return bdb.Append(uint32(v.Uint())) + case *array.Uint64DictionaryBuilder: + return bdb.Append(uint64(v.Uint())) + case *array.Float32DictionaryBuilder: + return bdb.Append(float32(v.Float())) + case *array.Float64DictionaryBuilder: + return bdb.Append(float64(v.Float())) + } + return fmt.Errorf("unsupported builder type %T: %w", db, ErrUnsupportedType) +} + +type listBuilderLike interface { + array.Builder + ValueBuilder() array.Builder +} + +func appendListElement(b array.Builder, v reflect.Value) error { + if v.Kind() == reflect.Slice && v.IsNil() { + b.AppendNull() + return nil + } + + var vb array.Builder + switch lb := b.(type) { + case *array.ListBuilder: + lb.Append(true) + vb = lb.ValueBuilder() + case *array.LargeListBuilder: + lb.Append(true) + vb = lb.ValueBuilder() + case *array.ListViewBuilder: + lb.AppendWithSize(true, v.Len()) + vb = lb.ValueBuilder() + case *array.LargeListViewBuilder: + lb.AppendWithSize(true, v.Len()) + vb = lb.ValueBuilder() + default: + return fmt.Errorf("unexpected list builder type %T: %w", b, ErrUnsupportedType) + } + for i := 0; i < v.Len(); i++ { + if err := appendValue(vb, v.Index(i)); err != nil { + return err + } + } + return nil +} + +func buildListLikeArray(vals reflect.Value, mem memory.Allocator, opts tagOpts, isView bool) (arrow.Array, error) { + elemDT, err := inferListElemDT(vals) + if err != nil { + return nil, err + } + if opts.Large { + elemDT = applyLargeOpts(elemDT) + } + if opts.View { + elemDT = applyViewOpts(elemDT) + } + + label := "list element" + if isView { + label = "list-view element" + } + + var bldr listBuilderLike + var beginRow func(int) + switch { + case isView && opts.Large: + b := array.NewLargeListViewBuilder(mem, elemDT) + bldr = b + beginRow = func(n int) { b.AppendWithSize(true, n) } + case isView: + b := array.NewListViewBuilder(mem, elemDT) + bldr = b + beginRow = func(n int) { b.AppendWithSize(true, n) } + case opts.Large: + b := array.NewLargeListBuilder(mem, elemDT) + bldr = b + beginRow = func(_ int) { b.Append(true) } + default: + b := array.NewListBuilder(mem, elemDT) + bldr = b + beginRow = func(_ int) { b.Append(true) } + } + defer bldr.Release() + + vb := bldr.ValueBuilder() + for i := 0; i < vals.Len(); i++ { + outer := vals.Index(i) + for outer.Kind() == reflect.Ptr { + if outer.IsNil() { + bldr.AppendNull() + break + } + outer = outer.Elem() + } + if outer.Kind() == reflect.Ptr { + continue + } + if outer.Kind() == reflect.Slice && outer.IsNil() { + bldr.AppendNull() + continue + } + if outer.Kind() != reflect.Slice && outer.Kind() != reflect.Array { + return nil, fmt.Errorf("arreflect: %s [%d]: expected slice, got %s: %w", label, i, outer.Type(), ErrTypeMismatch) + } + beginRow(outer.Len()) + for j := 0; j < outer.Len(); j++ { + if err := appendValue(vb, outer.Index(j)); err != nil { + return nil, fmt.Errorf("%s [%d][%d]: %w", label, i, j, err) + } + } + } + return bldr.NewArray(), nil +} + +func buildListArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + return buildListLikeArray(vals, mem, opts, false) +} + +func buildListViewArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + return buildListLikeArray(vals, mem, opts, true) +} + +func buildMapArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + mapType, isPtr := derefSliceElem(vals) + + keyType := mapType.Key() + valType := mapType.Elem() + + for keyType.Kind() == reflect.Ptr { + keyType = keyType.Elem() + } + for valType.Kind() == reflect.Ptr { + valType = valType.Elem() + } + + keyDT, err := inferArrowType(keyType) + if err != nil { + return nil, fmt.Errorf("map key type: %w", err) + } + valDT, err := inferArrowType(valType) + if err != nil { + return nil, fmt.Errorf("map value type: %w", err) + } + if opts.Large { + keyDT = applyLargeOpts(keyDT) + valDT = applyLargeOpts(valDT) + } + if opts.View { + keyDT = applyViewOpts(keyDT) + valDT = applyViewOpts(valDT) + } + + mb := array.NewMapBuilder(mem, keyDT, valDT, false) + defer mb.Release() + + kb := mb.KeyBuilder() + ib := mb.ItemBuilder() + + if err := iterSlice(vals, isPtr, mb.AppendNull, func(m reflect.Value) error { + if m.IsNil() { + mb.AppendNull() + return nil + } + mb.Append(true) + for _, key := range m.MapKeys() { + if err := appendValue(kb, key); err != nil { + return fmt.Errorf("map key: %w", err) + } + if err := appendValue(ib, m.MapIndex(key)); err != nil { + return fmt.Errorf("map value: %w", err) + } + } + return nil + }); err != nil { + return nil, err + } + + return mb.NewArray(), nil +} + +func buildFixedSizeListArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + + if elemType.Kind() != reflect.Array { + return nil, fmt.Errorf("arreflect: expected array element, got %v", elemType.Kind()) + } + + n := int32(elemType.Len()) + innerElemType := elemType.Elem() + for innerElemType.Kind() == reflect.Ptr { + innerElemType = innerElemType.Elem() + } + + innerDT, err := inferArrowType(innerElemType) + if err != nil { + return nil, err + } + if opts.Large { + innerDT = applyLargeOpts(innerDT) + } + if opts.View { + innerDT = applyViewOpts(innerDT) + } + + fb := array.NewFixedSizeListBuilder(mem, n, innerDT) + defer fb.Release() + + vb := fb.ValueBuilder() + + idx := 0 + appendNullIdx := func() { fb.AppendNull(); idx++ } + if err := iterSlice(vals, isPtr, appendNullIdx, func(elem reflect.Value) error { + fb.Append(true) + for j := 0; j < int(n); j++ { + if err := appendValue(vb, elem.Index(j)); err != nil { + return fmt.Errorf("fixed-size list element [%d][%d]: %w", idx, j, err) + } + } + idx++ + return nil + }); err != nil { + return nil, err + } + + return fb.NewArray(), nil +} + +func validateDictValueType(dt arrow.DataType) error { + switch dt.ID() { + case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64, + arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, + arrow.FLOAT32, arrow.FLOAT64, + arrow.STRING, arrow.BINARY: + return nil + default: + return fmt.Errorf("arreflect: dictionary encoding not supported for %s: %w", dt, ErrUnsupportedType) + } +} + +func buildDictionaryArray(vals reflect.Value, _ tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + + valDT, err := inferArrowType(elemType) + if err != nil { + return nil, err + } + // large is intentionally NOT applied here: Dictionary is + // unimplemented in the Arrow library (NewDictionaryBuilder panics). + + if err := validateDictValueType(valDT); err != nil { + return nil, err + } + + dt := &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, + ValueType: valDT, + } + db := array.NewDictionaryBuilder(mem, dt) + defer db.Release() + + if err := iterSlice(vals, isPtr, db.AppendNull, func(elem reflect.Value) error { + return appendToDictBuilder(db, elem) + }); err != nil { + return nil, err + } + return db.NewArray(), nil +} + +func buildRunEndEncodedArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + valOpts := opts + valOpts.REE = false + valOpts.View = false + if vals.Len() == 0 { + runEndsArr, err := buildPrimitiveArray(reflect.MakeSlice(reflect.TypeOf([]int32{}), 0, 0), tagOpts{}, mem) + if err != nil { + return nil, err + } + defer runEndsArr.Release() + valuesArr, err := buildArray(reflect.MakeSlice(vals.Type(), 0, 0), valOpts, mem) + if err != nil { + return nil, err + } + defer valuesArr.Release() + return array.NewRunEndEncodedArray(runEndsArr, valuesArr, 0, 0), nil + } + + type run struct { + end int32 + val reflect.Value + } + + // For comparable element types use reflect.Value.Equal (fast, avoids boxing). + // For non-comparable types (e.g. slices, maps) fall back to reflect.DeepEqual, + // which handles structural equality but cannot compress runs of function values. + elemType := vals.Type().Elem() + for elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + comparable := elemType.Comparable() + + equal := func(a, b reflect.Value) bool { + if comparable { + da, db := a, b + for da.Kind() == reflect.Ptr { + if da.IsNil() || db.IsNil() { + return da.IsNil() && db.IsNil() + } + da, db = da.Elem(), db.Elem() + } + return da.Equal(db) + } + return reflect.DeepEqual(a.Interface(), b.Interface()) + } + + var runs []run + current := vals.Index(0) + for i := 1; i < vals.Len(); i++ { + next := vals.Index(i) + if !equal(current, next) { + runs = append(runs, run{end: int32(i), val: current}) + current = next + } + } + runs = append(runs, run{end: int32(vals.Len()), val: current}) + + runEnds := make([]int32, len(runs)) + for i, r := range runs { + runEnds[i] = r.end + } + runEndsSlice := reflect.ValueOf(runEnds) + runEndsArr, err := buildPrimitiveArray(runEndsSlice, tagOpts{}, mem) + if err != nil { + return nil, fmt.Errorf("run-end encoded run ends: %w", err) + } + defer runEndsArr.Release() + + runValues := reflect.MakeSlice(vals.Type(), len(runs), len(runs)) + for i, r := range runs { + runValues.Index(i).Set(r.val) + } + valuesArr, err := buildArray(runValues, valOpts, mem) + if err != nil { + return nil, fmt.Errorf("run-end encoded values: %w", err) + } + defer valuesArr.Release() + + return array.NewRunEndEncodedArray(runEndsArr, valuesArr, vals.Len(), 0), nil +} diff --git a/arrow/array/arreflect/reflect_go_to_arrow_test.go b/arrow/array/arreflect/reflect_go_to_arrow_test.go new file mode 100644 index 00000000..7f098874 --- /dev/null +++ b/arrow/array/arreflect/reflect_go_to_arrow_test.go @@ -0,0 +1,1208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "reflect" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildPrimitiveArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("int32", func(t *testing.T) { + vals := []int32{1, 2, 3, 4, 5} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, 5, arr.Len()) + assert.Equal(t, arrow.INT32, arr.DataType().ID()) + typed := arr.(*array.Int32) + for i, want := range vals { + assert.Equal(t, want, typed.Value(i), "[%d] value mismatch", i) + } + }) + + t.Run("multi_level_pointer_int32", func(t *testing.T) { + v := int32(42) + pv := &v + var nilPv *int32 + vals := []**int32{&pv, &nilPv, &pv} + arr := mustBuildDefault(t, vals, mem) + assertMultiLevelPtrNullPattern(t, arr) + assert.Equal(t, int32(42), arr.(*array.Int32).Value(0)) + }) + + t.Run("pointer_with_null", func(t *testing.T) { + v1, v3 := int32(10), int32(30) + vals := []*int32{&v1, nil, &v3} + arr := mustBuildDefault(t, vals, mem) + assert.True(t, arr.IsNull(1), "expected index 1 to be null") + typed := arr.(*array.Int32) + assert.Equal(t, int32(10), typed.Value(0)) + assert.Equal(t, int32(30), typed.Value(2)) + }) + + t.Run("bool", func(t *testing.T) { + vals := []bool{true, false, true} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.BOOL, arr.DataType().ID()) + typed := arr.(*array.Boolean) + assert.True(t, typed.Value(0), "expected Value(0) to be true") + assert.False(t, typed.Value(1), "expected Value(1) to be false") + assert.True(t, typed.Value(2), "expected Value(2) to be true") + }) + + t.Run("binary", func(t *testing.T) { + vals := [][]byte{{1, 2, 3}, {4, 5}, {6}} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.BINARY, arr.DataType().ID()) + }) + + t.Run("numeric_types", func(t *testing.T) { + cases := []struct { + vals any + id arrow.Type + }{ + {[]int8{1, -2, 3}, arrow.INT8}, + {[]int16{100, -200}, arrow.INT16}, + {[]int64{1000, -2000}, arrow.INT64}, + {[]uint8{1, 2, 3}, arrow.UINT8}, + {[]uint16{1, 2}, arrow.UINT16}, + {[]uint32{1, 2}, arrow.UINT32}, + {[]uint64{1, 2}, arrow.UINT64}, + {[]float32{1.0, 2.0}, arrow.FLOAT32}, + {[]float64{1.1, 2.2}, arrow.FLOAT64}, + {[]int{1, -2, 3}, arrow.INT64}, + {[]uint{1, 2, 3}, arrow.UINT64}, + } + for _, tc := range cases { + arr, err := buildArray(reflect.ValueOf(tc.vals), tagOpts{}, mem) + require.NoError(t, err, "type %v", tc.id) + assert.Equal(t, tc.id, arr.DataType().ID(), "expected %v, got %v", tc.id, arr.DataType()) + arr.Release() + } + }) +} + +func TestBuildTemporalArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("time_time", func(t *testing.T) { + now := time.Now().UTC() + vals := []time.Time{now, now.Add(time.Hour)} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.TIMESTAMP, arr.DataType().ID()) + typed := arr.(*array.Timestamp) + for i, want := range vals { + assert.Equal(t, arrow.Timestamp(want.UnixNano()), typed.Value(i), "[%d] timestamp mismatch", i) + } + }) + + t.Run("time_duration", func(t *testing.T) { + vals := []time.Duration{time.Second, time.Minute, time.Hour} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.DURATION, arr.DataType().ID()) + typed := arr.(*array.Duration) + for i, want := range vals { + assert.Equal(t, arrow.Duration(want.Nanoseconds()), typed.Value(i), "[%d] duration mismatch", i) + } + }) +} + +func TestBuildDecimalArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("decimal128", func(t *testing.T) { + vals := []decimal128.Num{ + decimal128.New(0, 100), + decimal128.New(0, 200), + decimal128.New(0, 300), + } + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.DECIMAL128, arr.DataType().ID()) + typed := arr.(*array.Decimal128) + for i, want := range vals { + assert.Equal(t, want, typed.Value(i), "[%d] decimal128 mismatch", i) + } + }) + + t.Run("decimal256", func(t *testing.T) { + vals := []decimal256.Num{ + decimal256.New(0, 0, 0, 100), + decimal256.New(0, 0, 0, 200), + } + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.DECIMAL256, arr.DataType().ID()) + typed := arr.(*array.Decimal256) + for i, want := range vals { + assert.Equal(t, want, typed.Value(i), "[%d] decimal256 mismatch", i) + } + }) + + t.Run("decimal128_custom_opts", func(t *testing.T) { + vals := []decimal128.Num{decimal128.New(0, 12345)} + opts := tagOpts{HasDecimalOpts: true, DecimalPrecision: 10, DecimalScale: 3} + arr := mustBuildArray(t, vals, opts, mem) + dt := arr.DataType().(*arrow.Decimal128Type) + assert.Equal(t, int32(10), dt.Precision, "expected p=10, got p=%d", dt.Precision) + assert.Equal(t, int32(3), dt.Scale, "expected s=3, got s=%d", dt.Scale) + }) + + t.Run("decimal32", func(t *testing.T) { + vals := []decimal.Decimal32{100, 200, 300} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.DECIMAL32, arr.DataType().ID()) + typed := arr.(*array.Decimal32) + for i, want := range vals { + assert.Equal(t, want, typed.Value(i), "[%d] decimal32 mismatch", i) + } + }) + + t.Run("decimal64", func(t *testing.T) { + vals := []decimal.Decimal64{1000, 2000} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.DECIMAL64, arr.DataType().ID()) + typed := arr.(*array.Decimal64) + for i, want := range vals { + assert.Equal(t, want, typed.Value(i), "[%d] decimal64 mismatch", i) + } + }) + + t.Run("decimal32_custom_opts", func(t *testing.T) { + vals := []decimal.Decimal32{12345} + opts := tagOpts{HasDecimalOpts: true, DecimalPrecision: 9, DecimalScale: 2} + arr := mustBuildArray(t, vals, opts, mem) + dt := arr.DataType().(*arrow.Decimal32Type) + assert.Equal(t, int32(9), dt.Precision, "expected p=9, got p=%d", dt.Precision) + assert.Equal(t, int32(2), dt.Scale, "expected s=2, got s=%d", dt.Scale) + }) +} + +type buildSimpleStruct struct { + X int32 + Y string +} + +type buildNestedStruct struct { + A int32 + B buildSimpleStruct +} + +type buildNullableStruct struct { + X *int32 + Y *string +} + +func TestBuildStructArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("simple", func(t *testing.T) { + vals := []buildSimpleStruct{ + {X: 1, Y: "one"}, + {X: 2, Y: "two"}, + {X: 3, Y: "three"}, + } + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.STRUCT, arr.DataType().ID(), "expected STRUCT, got %v", arr.DataType()) + typed := arr.(*array.Struct) + assert.Equal(t, 3, typed.Len()) + xArr := typed.Field(0).(*array.Int32) + yArr := typed.Field(1).(*array.String) + for i, want := range vals { + assert.Equal(t, want.X, xArr.Value(i), "[%d] X mismatch", i) + assert.Equal(t, want.Y, yArr.Value(i), "[%d] Y mismatch", i) + } + }) + + t.Run("pointer_null_row", func(t *testing.T) { + v1 := buildSimpleStruct{X: 42, Y: "answer"} + vals := []*buildSimpleStruct{&v1, nil} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, 2, arr.Len()) + assert.True(t, arr.IsNull(1), "expected index 1 to be null") + }) + + t.Run("nullable_fields", func(t *testing.T) { + x1 := int32(10) + y1 := "hello" + vals := []buildNullableStruct{ + {X: &x1, Y: &y1}, + {X: nil, Y: nil}, + } + arr := mustBuildDefault(t, vals, mem) + typed := arr.(*array.Struct) + assert.True(t, typed.Field(0).IsNull(1), "expected X[1] to be null") + assert.True(t, typed.Field(1).IsNull(1), "expected Y[1] to be null") + }) + + t.Run("nested_struct", func(t *testing.T) { + vals := []buildNestedStruct{ + {A: 1, B: buildSimpleStruct{X: 10, Y: "inner1"}}, + {A: 2, B: buildSimpleStruct{X: 20, Y: "inner2"}}, + } + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.STRUCT, arr.DataType().ID(), "expected STRUCT, got %v", arr.DataType()) + typed := arr.(*array.Struct) + aArr := typed.Field(0).(*array.Int32) + assert.Equal(t, int32(1), aArr.Value(0)) + assert.Equal(t, int32(2), aArr.Value(1)) + bArr := typed.Field(1).(*array.Struct) + bxArr := bArr.Field(0).(*array.Int32) + assert.Equal(t, int32(10), bxArr.Value(0)) + assert.Equal(t, int32(20), bxArr.Value(1)) + }) + + t.Run("multi_level_pointer_struct", func(t *testing.T) { + type S struct { + X int32 + } + s := S{X: 99} + ps := &s + var nilPs *S + vals := []**S{&ps, &nilPs, &ps} + arr := mustBuildDefault(t, vals, mem) + assertMultiLevelPtrNullPattern(t, arr) + sa := arr.(*array.Struct) + xArr := sa.Field(0).(*array.Int32) + assert.Equal(t, int32(99), xArr.Value(0)) + assert.Equal(t, int32(99), xArr.Value(2)) + }) + + t.Run("nil_embedded_pointer_promoted_field", func(t *testing.T) { + // Regression: reflect.Value.FieldByIndex panics when traversing a nil + // embedded pointer; promoted fields must become null instead. + type Inner struct { + City string + Zip int32 + } + type Outer struct { + Name string + *Inner + } + vals := []Outer{ + {Name: "Alice", Inner: &Inner{City: "NYC", Zip: 10001}}, + {Name: "Bob", Inner: nil}, + {Name: "Carol", Inner: &Inner{City: "LA", Zip: 90001}}, + } + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.STRUCT, arr.DataType().ID()) + sa := arr.(*array.Struct) + require.Equal(t, 3, sa.Len()) + require.Equal(t, 3, sa.NumField(), "expected 3 promoted fields (Name, City, Zip)") + + nameArr := sa.Field(0).(*array.String) + cityArr := sa.Field(1).(*array.String) + zipArr := sa.Field(2).(*array.Int32) + + assert.Equal(t, "Alice", nameArr.Value(0)) + assert.False(t, cityArr.IsNull(0)) + assert.Equal(t, "NYC", cityArr.Value(0)) + assert.False(t, zipArr.IsNull(0)) + assert.Equal(t, int32(10001), zipArr.Value(0)) + + assert.Equal(t, "Bob", nameArr.Value(1)) + assert.True(t, cityArr.IsNull(1), "City should be null when *Inner is nil") + assert.True(t, zipArr.IsNull(1), "Zip should be null when *Inner is nil") + + assert.Equal(t, "Carol", nameArr.Value(2)) + assert.False(t, cityArr.IsNull(2)) + assert.Equal(t, "LA", cityArr.Value(2)) + assert.False(t, zipArr.IsNull(2)) + assert.Equal(t, int32(90001), zipArr.Value(2)) + }) +} + +func TestBuildListArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("int32_lists", func(t *testing.T) { + vals := [][]int32{{1, 2, 3}, {4, 5}, {6}} + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.LIST, arr.DataType().ID(), "expected LIST, got %v", arr.DataType()) + typed := arr.(*array.List) + assert.Equal(t, 3, typed.Len()) + assert.Equal(t, 6, typed.ListValues().(*array.Int32).Len(), "expected 6 total values") + }) + + t.Run("null_inner", func(t *testing.T) { + vals := [][]int32{{1, 2}, nil, {3}} + arr := mustBuildDefault(t, vals, mem) + assert.True(t, arr.IsNull(1), "expected index 1 to be null") + }) + + t.Run("nil_pointer_list_element", func(t *testing.T) { + a := []int32{1, 2} + vals := []*[]int32{&a, nil, &a} + arr := mustBuildDefault(t, vals, mem) + assert.Equal(t, arrow.LIST, arr.DataType().ID()) + assertMultiLevelPtrNullPattern(t, arr) + }) + + t.Run("multi_level_pointer_list", func(t *testing.T) { + a := []int32{1, 2} + pa := &a + var nilPa *[]int32 + vals := []**[]int32{&pa, &nilPa, &pa} + arr := mustBuildDefault(t, vals, mem) + assertMultiLevelPtrNullPattern(t, arr) + }) + + t.Run("string_lists", func(t *testing.T) { + vals := [][]string{{"a", "b"}, {"c"}} + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.LIST, arr.DataType().ID(), "expected LIST, got %v", arr.DataType()) + }) + + t.Run("nested", func(t *testing.T) { + vals := [][][]int32{{{1, 2}, {3}}, {{4, 5, 6}}} + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.LIST, arr.DataType().ID(), "expected outer LIST, got %v", arr.DataType()) + outer := arr.(*array.List) + assert.Equal(t, 2, outer.Len(), "expected 2 outer rows, got %d", outer.Len()) + require.Equal(t, arrow.LIST, outer.ListValues().DataType().ID(), "expected inner LIST, got %v", outer.ListValues().DataType()) + }) +} + +func TestBuildMapArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("string_int32", func(t *testing.T) { + vals := []map[string]int32{ + {"a": 1, "b": 2}, + {"c": 3}, + } + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.MAP, arr.DataType().ID(), "expected MAP, got %v", arr.DataType()) + assert.Equal(t, 2, arr.(*array.Map).Len()) + }) + + t.Run("null_map", func(t *testing.T) { + vals := []map[string]int32{{"a": 1}, nil} + arr := mustBuildDefault(t, vals, mem) + assert.True(t, arr.IsNull(1), "expected index 1 to be null") + }) + + t.Run("entry_count", func(t *testing.T) { + vals := []map[string]int32{{"x": 10, "y": 20, "z": 30}} + arr := mustBuildDefault(t, vals, mem) + kvArr := arr.(*array.Map).ListValues().(*array.Struct) + assert.Equal(t, 3, kvArr.Len(), "expected 3 key-value pairs, got %d", kvArr.Len()) + }) + + t.Run("multi_level_pointer_map", func(t *testing.T) { + m := map[string]int32{"x": 1} + pm := &m + var nilPm *map[string]int32 + vals := []**map[string]int32{&pm, &nilPm, &pm} + arr := mustBuildDefault(t, vals, mem) + assertMultiLevelPtrNullPattern(t, arr) + }) +} + +func TestBuildFixedSizeListArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("int32_n3", func(t *testing.T) { + vals := [][3]int32{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}} + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.FIXED_SIZE_LIST, arr.DataType().ID(), "expected FIXED_SIZE_LIST, got %v", arr.DataType()) + typed := arr.(*array.FixedSizeList) + assert.Equal(t, 3, typed.Len()) + assert.Equal(t, int32(3), typed.DataType().(*arrow.FixedSizeListType).Len(), "expected fixed size 3") + values := typed.ListValues().(*array.Int32) + assert.Equal(t, 9, values.Len()) + assert.Equal(t, int32(1), values.Value(0)) + assert.Equal(t, int32(4), values.Value(3)) + assert.Equal(t, int32(7), values.Value(6)) + }) + + t.Run("float64_n2", func(t *testing.T) { + vals := [][2]float64{{1.0, 2.0}, {3.0, 4.0}} + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.FIXED_SIZE_LIST, arr.DataType().ID(), "expected FIXED_SIZE_LIST, got %v", arr.DataType()) + assert.Equal(t, int32(2), arr.DataType().(*arrow.FixedSizeListType).Len(), "expected fixed size 2") + }) + + t.Run("nil_slice_appends_null", func(t *testing.T) { + bldr := array.NewFixedSizeListBuilder(mem, int32(3), arrow.PrimitiveTypes.Int32) + defer bldr.Release() + + var nilSlice []int32 + err := appendValue(bldr, reflect.ValueOf(&nilSlice).Elem()) + require.NoError(t, err) + + bldr.Append(true) + vb := bldr.ValueBuilder().(*array.Int32Builder) + vb.AppendValues([]int32{1, 2, 3}, nil) + + arr := bldr.NewArray() + defer arr.Release() + assert.True(t, arr.IsNull(0), "nil slice should be null") + assert.False(t, arr.IsNull(1), "non-nil should not be null") + }) + + t.Run("multi_level_pointer_fixed_size_list", func(t *testing.T) { + a := [3]int32{1, 2, 3} + pa := &a + var nilPa *[3]int32 + vals := []**[3]int32{&pa, &nilPa, &pa} + arr := mustBuildDefault(t, vals, mem) + assertMultiLevelPtrNullPattern(t, arr) + }) +} + +func TestBuildDictionaryArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("string_dict", func(t *testing.T) { + vals := []string{"apple", "banana", "apple", "cherry", "banana", "apple"} + arr := mustBuildArray(t, vals, tagOpts{Dict: true}, mem) + require.Equal(t, arrow.DICTIONARY, arr.DataType().ID(), "expected DICTIONARY, got %v", arr.DataType()) + typed := arr.(*array.Dictionary) + assert.Equal(t, 6, typed.Len()) + assert.Equal(t, 3, typed.Dictionary().Len(), "expected 3 unique, got %d", typed.Dictionary().Len()) + }) + + t.Run("int32_dict", func(t *testing.T) { + vals := []int32{1, 2, 1, 3, 2, 1} + arr := mustBuildArray(t, vals, tagOpts{Dict: true}, mem) + require.Equal(t, arrow.DICTIONARY, arr.DataType().ID(), "expected DICTIONARY, got %v", arr.DataType()) + typed := arr.(*array.Dictionary) + assert.Equal(t, 6, typed.Len()) + assert.Equal(t, 3, typed.Dictionary().Len(), "expected 3 unique, got %d", typed.Dictionary().Len()) + }) + + t.Run("index_type_is_int32", func(t *testing.T) { + vals := []string{"x", "y", "z"} + arr := mustBuildArray(t, vals, tagOpts{Dict: true}, mem) + dt := arr.DataType().(*arrow.DictionaryType) + assert.Equal(t, arrow.INT32, dt.IndexType.ID(), "expected INT32 index, got %v", dt.IndexType) + }) + + t.Run("bool_dict_returns_error", func(t *testing.T) { + _, err := buildArray(reflect.ValueOf([]bool{true, false}), tagOpts{Dict: true}, mem) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("pointer_string_with_nil", func(t *testing.T) { + s := "hello" + vals := []*string{&s, nil, &s} + arr := mustBuildArray(t, vals, tagOpts{Dict: true}, mem) + typed := arr.(*array.Dictionary) + assert.Equal(t, arrow.DICTIONARY, arr.DataType().ID()) + assert.Equal(t, 3, arr.Len()) + assert.False(t, arr.IsNull(0)) + assert.True(t, arr.IsNull(1)) + assert.False(t, arr.IsNull(2)) + assert.Equal(t, 1, typed.Dictionary().Len(), "expected 1 unique value") + }) + + t.Run("multi_level_pointer_string", func(t *testing.T) { + s := "world" + ps := &s + var nilPs *string + vals := []**string{&ps, &nilPs, &ps} + arr := mustBuildArray(t, vals, tagOpts{Dict: true}, mem) + typed := arr.(*array.Dictionary) + assert.Equal(t, arrow.DICTIONARY, arr.DataType().ID()) + assertMultiLevelPtrNullPattern(t, arr) + assert.Equal(t, 1, typed.Dictionary().Len(), "expected 1 unique value") + }) +} + +func TestBuildRunEndEncodedArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("int32_runs", func(t *testing.T) { + vals := []int32{1, 1, 1, 2, 2, 3} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + require.Equal(t, arrow.RUN_END_ENCODED, arr.DataType().ID(), "expected RUN_END_ENCODED, got %v", arr.DataType()) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 6, ree.Len()) + runEnds := ree.RunEndsArr().(*array.Int32) + assert.Equal(t, 3, runEnds.Len(), "expected 3 runs, got %d", runEnds.Len()) + assert.Equal(t, int32(3), runEnds.Value(0)) + assert.Equal(t, int32(5), runEnds.Value(1)) + assert.Equal(t, int32(6), runEnds.Value(2)) + values := ree.Values().(*array.Int32) + assert.Equal(t, 3, values.Len(), "expected 3 values, got %d", values.Len()) + assert.Equal(t, int32(1), values.Value(0)) + assert.Equal(t, int32(2), values.Value(1)) + assert.Equal(t, int32(3), values.Value(2)) + }) + + t.Run("string_runs", func(t *testing.T) { + vals := []string{"a", "a", "b", "b", "b", "c"} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + require.Equal(t, arrow.RUN_END_ENCODED, arr.DataType().ID(), "expected RUN_END_ENCODED, got %v", arr.DataType()) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 6, ree.Len()) + assert.Equal(t, 3, ree.RunEndsArr().Len(), "expected 3 runs, got %d", ree.RunEndsArr().Len()) + }) + + t.Run("single_run", func(t *testing.T) { + vals := []int32{42, 42, 42} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 3, ree.Len()) + runEnds := ree.RunEndsArr().(*array.Int32) + assert.Equal(t, 1, runEnds.Len()) + assert.Equal(t, int32(3), runEnds.Value(0)) + }) + + t.Run("all_distinct", func(t *testing.T) { + vals := []int32{1, 2, 3, 4, 5} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 5, ree.Len()) + assert.Equal(t, 5, ree.RunEndsArr().Len(), "expected 5 runs for all-distinct, got %d", ree.RunEndsArr().Len()) + }) + + t.Run("pointer_value_equality", func(t *testing.T) { + x1 := "x" + x2 := "x" + y := "y" + vals := []*string{&x1, &x2, &y} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 2, ree.RunEndsArr().Len(), "expected 2 runs (x+x coalesced, y), got %d", ree.RunEndsArr().Len()) + }) + + t.Run("ree_with_temporal_date32", func(t *testing.T) { + t1 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + t2 := time.Date(2024, 6, 15, 0, 0, 0, 0, time.UTC) + vals := []time.Time{t1, t1, t2} + arr := mustBuildArray(t, vals, tagOpts{REE: true, Temporal: "date32"}, mem) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 3, ree.Len()) + assert.Equal(t, arrow.DATE32, ree.Values().DataType().ID()) + }) +} + +func TestBuildViewArray(t *testing.T) { + mem := checkedMem(t) + + t.Run("string→STRING_VIEW", func(t *testing.T) { + arr := mustBuildArray(t, []string{"a", "b"}, tagOpts{View: true}, mem) + assert.Equal(t, arrow.STRING_VIEW, arr.DataType().ID()) + sv := arr.(*array.StringView) + assert.Equal(t, "a", sv.Value(0)) + assert.Equal(t, "b", sv.Value(1)) + }) + + t.Run("[]byte→BINARY_VIEW", func(t *testing.T) { + arr := mustBuildArray(t, [][]byte{{1, 2}, {3}}, tagOpts{View: true}, mem) + assert.Equal(t, arrow.BINARY_VIEW, arr.DataType().ID()) + }) + + t.Run("int32_view", func(t *testing.T) { + vals := [][]int32{{1, 2, 3}, {4, 5}} + arr := mustBuildArray(t, vals, tagOpts{View: true}, mem) + assert.Equal(t, arrow.LIST_VIEW, arr.DataType().ID()) + typed := arr.(*array.ListView) + assert.Equal(t, 2, typed.Len()) + }) + + t.Run("nil_outer_listview", func(t *testing.T) { + var nilSlice [][]int32 + arr := mustBuildArray(t, nilSlice, tagOpts{View: true}, mem) + assert.Equal(t, 0, arr.Len()) + }) + + t.Run("string_listview", func(t *testing.T) { + vals := [][]string{{"a", "b"}, {"c"}} + arr := mustBuildArray(t, vals, tagOpts{View: true}, mem) + assert.Equal(t, arrow.LIST_VIEW, arr.DataType().ID()) + lv := arr.DataType().(*arrow.ListViewType) + assert.Equal(t, arrow.STRING_VIEW, lv.Elem().ID()) + }) + + t.Run("null_in_listview", func(t *testing.T) { + vals := [][]int32{{1, 2, 3}, nil, {4, 5}} + arr := mustBuildArray(t, vals, tagOpts{View: true}, mem) + allVals := arr.(*array.ListView).ListValues().(*array.Int32) + assert.Equal(t, 5, allVals.Len()) + }) + + t.Run("nil_pointer_view_element", func(t *testing.T) { + a := []int32{1, 2} + vals := []*[]int32{&a, nil} + arr := mustBuildArray(t, vals, tagOpts{View: true}, mem) + assert.True(t, arr.IsNull(1)) + }) +} + +func TestBuildTemporalTaggedArray(t *testing.T) { + mem := checkedMem(t) + + ref := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + + t.Run("date32", func(t *testing.T) { + vals := []time.Time{ref, ref.AddDate(0, 0, 1)} + opts := tagOpts{Temporal: "date32"} + sv := reflect.ValueOf(vals) + arr, err := buildTemporalArray(sv, opts, mem) + require.NoError(t, err, "unexpected error") + defer arr.Release() + assert.Equal(t, arrow.DATE32, arr.DataType().ID()) + assert.Equal(t, 2, arr.Len()) + d32arr := arr.(*array.Date32) + got0 := d32arr.Value(0).ToTime() + assert.Equal(t, ref.Year(), got0.Year()) + assert.Equal(t, ref.Month(), got0.Month()) + assert.Equal(t, ref.Day(), got0.Day()) + }) + + t.Run("date64", func(t *testing.T) { + vals := []time.Time{ref} + opts := tagOpts{Temporal: "date64"} + sv := reflect.ValueOf(vals) + arr, err := buildTemporalArray(sv, opts, mem) + require.NoError(t, err, "unexpected error") + defer arr.Release() + assert.Equal(t, arrow.DATE64, arr.DataType().ID()) + d64arr := arr.(*array.Date64) + got0 := d64arr.Value(0).ToTime() + assert.Equal(t, ref.Year(), got0.Year()) + assert.Equal(t, ref.Month(), got0.Month()) + assert.Equal(t, ref.Day(), got0.Day()) + }) + + t.Run("time32", func(t *testing.T) { + vals := []time.Time{ref} + opts := tagOpts{Temporal: "time32"} + sv := reflect.ValueOf(vals) + arr, err := buildTemporalArray(sv, opts, mem) + require.NoError(t, err, "unexpected error") + defer arr.Release() + assert.Equal(t, arrow.TIME32, arr.DataType().ID()) + assert.Equal(t, 1, arr.Len()) + t32arr := arr.(*array.Time32) + unit := arr.DataType().(*arrow.Time32Type).Unit + got0 := t32arr.Value(0).ToTime(unit) + assert.Equal(t, ref.Hour(), got0.Hour()) + assert.Equal(t, ref.Minute(), got0.Minute()) + assert.Equal(t, ref.Second(), got0.Second()) + refWithMs := time.Date(ref.Year(), ref.Month(), ref.Day(), ref.Hour(), ref.Minute(), ref.Second(), 500_000_000, ref.Location()) + svMs := reflect.ValueOf([]time.Time{refWithMs}) + arrMs, err := buildTemporalArray(svMs, tagOpts{Temporal: "time32"}, mem) + require.NoError(t, err, "time32 with ms") + defer arrMs.Release() + t32ms := arrMs.(*array.Time32) + unitMs := arrMs.DataType().(*arrow.Time32Type).Unit + gotMs := t32ms.Value(0).ToTime(unitMs) + assert.Equal(t, 500, gotMs.Nanosecond()/1_000_000, "time32 millisecond: got %d ms, want 500 ms", gotMs.Nanosecond()/1_000_000) + }) + + t.Run("time64", func(t *testing.T) { + vals := []time.Time{ref} + opts := tagOpts{Temporal: "time64"} + sv := reflect.ValueOf(vals) + arr, err := buildTemporalArray(sv, opts, mem) + require.NoError(t, err, "unexpected error") + defer arr.Release() + assert.Equal(t, arrow.TIME64, arr.DataType().ID()) + t64arr := arr.(*array.Time64) + unit := arr.DataType().(*arrow.Time64Type).Unit + got0 := t64arr.Value(0).ToTime(unit) + assert.Equal(t, ref.Hour(), got0.Hour()) + assert.Equal(t, ref.Minute(), got0.Minute()) + assert.Equal(t, ref.Second(), got0.Second()) + refWithNanos := time.Date(ref.Year(), ref.Month(), ref.Day(), ref.Hour(), ref.Minute(), ref.Second(), 123456789, ref.Location()) + sv64 := reflect.ValueOf([]time.Time{refWithNanos}) + arr64, err := buildTemporalArray(sv64, tagOpts{Temporal: "time64"}, mem) + require.NoError(t, err, "time64 with nanos") + defer arr64.Release() + t64arr64 := arr64.(*array.Time64) + unit64 := arr64.DataType().(*arrow.Time64Type).Unit + got64 := t64arr64.Value(0).ToTime(unit64) + assert.Equal(t, refWithNanos.Nanosecond(), got64.Nanosecond(), + "time64 nanosecond: got %d, want %d", got64.Nanosecond(), refWithNanos.Nanosecond()) + }) +} + +func TestNilByteSliceIsNull(t *testing.T) { + mem := memory.NewGoAllocator() + arr, err := FromSlice([][]byte{[]byte("hello"), nil}, mem) + require.NoError(t, err) + defer arr.Release() + assert.False(t, arr.IsNull(0), "non-nil byte slice should not be null") + assert.True(t, arr.IsNull(1), "nil byte slice should be null") +} + +func TestAppendToDictBuilderAllTypes(t *testing.T) { + mem := checkedMem(t) + + cases := []struct { + name string + run func(t *testing.T) + }{ + {"int8", func(t *testing.T) { + arr := mustBuildArray(t, []int8{1, 2, 1, 3}, tagOpts{Dict: true}, mem) + assert.Equal(t, arrow.DICTIONARY, arr.DataType().ID()) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"int16", func(t *testing.T) { + arr := mustBuildArray(t, []int16{1, 2, 1, 3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"int64", func(t *testing.T) { + arr := mustBuildArray(t, []int64{1, 2, 1, 3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"uint8", func(t *testing.T) { + arr := mustBuildArray(t, []uint8{1, 2, 1, 3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"uint16", func(t *testing.T) { + arr := mustBuildArray(t, []uint16{1, 2, 1, 3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"uint32", func(t *testing.T) { + arr := mustBuildArray(t, []uint32{1, 2, 1, 3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"uint64", func(t *testing.T) { + arr := mustBuildArray(t, []uint64{1, 2, 1, 3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"float32", func(t *testing.T) { + arr := mustBuildArray(t, []float32{1.1, 2.2, 1.1, 3.3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"float64", func(t *testing.T) { + arr := mustBuildArray(t, []float64{1.1, 2.2, 1.1, 3.3}, tagOpts{Dict: true}, mem) + assert.Equal(t, 3, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"binary bytes", func(t *testing.T) { + arr := mustBuildArray(t, [][]byte{[]byte("a"), []byte("b"), []byte("a")}, tagOpts{Dict: true}, mem) + assert.Equal(t, 2, arr.(*array.Dictionary).Dictionary().Len()) + }}, + {"binary nil is null", func(t *testing.T) { + arr := mustBuildArray(t, [][]byte{[]byte("a"), nil, []byte("a")}, tagOpts{Dict: true}, mem) + assert.True(t, arr.IsNull(1)) + assert.Equal(t, 1, arr.(*array.Dictionary).Dictionary().Len()) + }}, + } + for _, tc := range cases { + t.Run(tc.name, tc.run) + } + + t.Run("binary with unsupported kind returns error", func(t *testing.T) { + db := array.NewDictionaryBuilder(mem, &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, ValueType: arrow.BinaryTypes.Binary, + }).(*array.BinaryDictionaryBuilder) + defer db.Release() + err := appendToDictBuilder(db, reflect.ValueOf(int32(7))) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("unsupported dict builder type returns error", func(t *testing.T) { + db := array.NewDictionaryBuilder(mem, &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, + ValueType: &arrow.Decimal128Type{Precision: 10, Scale: 2}, + }) + defer db.Release() + err := appendToDictBuilder(db, reflect.ValueOf(decimal128.New(0, 1))) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestAppendListElementDirect(t *testing.T) { + mem := checkedMem(t) + + t.Run("nil slice appends null", func(t *testing.T) { + lb := array.NewListBuilder(mem, arrow.PrimitiveTypes.Int32) + defer lb.Release() + var empty []int32 + require.NoError(t, appendListElement(lb, reflect.ValueOf(empty))) + arr := lb.NewArray() + defer arr.Release() + assert.True(t, arr.IsNull(0)) + }) + + t.Run("large list builder", func(t *testing.T) { + lb := array.NewLargeListBuilder(mem, arrow.PrimitiveTypes.Int32) + defer lb.Release() + require.NoError(t, appendListElement(lb, reflect.ValueOf([]int32{1, 2, 3}))) + arr := lb.NewArray().(*array.LargeList) + defer arr.Release() + assert.Equal(t, 1, arr.Len()) + vb := arr.ListValues().(*array.Int32) + assert.Equal(t, 3, vb.Len()) + }) + + t.Run("list view builder", func(t *testing.T) { + lb := array.NewListViewBuilder(mem, arrow.PrimitiveTypes.Int32) + defer lb.Release() + require.NoError(t, appendListElement(lb, reflect.ValueOf([]int32{4, 5}))) + arr := lb.NewArray().(*array.ListView) + defer arr.Release() + assert.Equal(t, 1, arr.Len()) + }) + + t.Run("large list view builder", func(t *testing.T) { + lb := array.NewLargeListViewBuilder(mem, arrow.PrimitiveTypes.Int32) + defer lb.Release() + require.NoError(t, appendListElement(lb, reflect.ValueOf([]int32{6}))) + arr := lb.NewArray().(*array.LargeListView) + defer arr.Release() + assert.Equal(t, 1, arr.Len()) + }) + + t.Run("unexpected builder type returns error", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + err := appendListElement(b, reflect.ValueOf([]int32{1})) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestBuildRunEndEncodedArrayExtras(t *testing.T) { + mem := checkedMem(t) + + t.Run("empty_slice_direct", func(t *testing.T) { + empty := reflect.MakeSlice(reflect.TypeOf([]int32{}), 0, 0) + arr, err := buildRunEndEncodedArray(empty, tagOpts{REE: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, 0, arr.Len()) + assert.Equal(t, arrow.RUN_END_ENCODED, arr.DataType().ID()) + }) + + t.Run("nil_pointer_runs_collapse", func(t *testing.T) { + s := "x" + vals := []*string{nil, nil, &s, nil} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 4, ree.Len()) + assert.Equal(t, 3, ree.RunEndsArr().Len(), + "expected 3 runs (nil,nil + x + nil), got %d", ree.RunEndsArr().Len()) + }) + + t.Run("nil_and_non_nil_pointer_are_not_equal", func(t *testing.T) { + s := "x" + vals := []*string{nil, &s} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 2, ree.RunEndsArr().Len(), + "expected 2 runs (nil != &x), got %d", ree.RunEndsArr().Len()) + }) + + t.Run("non_comparable_elem_uses_deep_equal", func(t *testing.T) { + vals := [][]int32{{1, 2}, {1, 2}, {3}} + arr := mustBuildArray(t, vals, tagOpts{REE: true}, mem) + ree := arr.(*array.RunEndEncoded) + assert.Equal(t, 3, ree.Len()) + assert.Equal(t, 2, ree.RunEndsArr().Len(), + "expected 2 runs via DeepEqual, got %d", ree.RunEndsArr().Len()) + }) +} + +func TestBuildMapArrayExtras(t *testing.T) { + mem := checkedMem(t) + + t.Run("pointer_key_type", func(t *testing.T) { + k1, k2 := "a", "b" + vals := []map[*string]int32{{&k1: 1, &k2: 2}} + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.MAP, arr.DataType().ID()) + assert.Equal(t, 1, arr.Len()) + }) + + t.Run("pointer_value_type", func(t *testing.T) { + v1, v2 := int32(1), int32(2) + vals := []map[string]*int32{{"a": &v1, "b": &v2}} + arr := mustBuildDefault(t, vals, mem) + require.Equal(t, arrow.MAP, arr.DataType().ID()) + assert.Equal(t, 1, arr.Len()) + }) + + t.Run("unsupported_key_type_errors", func(t *testing.T) { + vals := []map[complex64]int32{{1 + 2i: 1}} + _, err := buildArray(reflect.ValueOf(vals), tagOpts{}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("unsupported_value_type_errors", func(t *testing.T) { + vals := []map[string]complex64{{"a": 1 + 2i}} + _, err := buildArray(reflect.ValueOf(vals), tagOpts{}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestAppendTemporalValueErrors(t *testing.T) { + mem := checkedMem(t) + notATime := reflect.ValueOf(int32(42)) + + builderCases := []struct { + name string + builder array.Builder + }{ + {"timestamp", array.NewTimestampBuilder(mem, &arrow.TimestampType{Unit: arrow.Nanosecond})}, + {"date32", array.NewDate32Builder(mem)}, + {"date64", array.NewDate64Builder(mem)}, + {"time32", array.NewTime32Builder(mem, &arrow.Time32Type{Unit: arrow.Millisecond})}, + {"time64", array.NewTime64Builder(mem, &arrow.Time64Type{Unit: arrow.Nanosecond})}, + } + for _, tc := range builderCases { + t.Run(tc.name+"_requires_time_Time", func(t *testing.T) { + defer tc.builder.Release() + err := appendTemporalValue(tc.builder, notATime) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + } + + t.Run("duration_requires_time_Duration", func(t *testing.T) { + b := array.NewDurationBuilder(mem, &arrow.DurationType{Unit: arrow.Nanosecond}) + defer b.Release() + err := appendTemporalValue(b, notATime) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + + t.Run("unexpected_builder_type", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + err := appendTemporalValue(b, reflect.ValueOf(time.Now())) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestAppendDecimalValueErrors(t *testing.T) { + mem := checkedMem(t) + notDecimal := reflect.ValueOf("not a decimal") + + t.Run("decimal128_wrong_type", func(t *testing.T) { + b := array.NewDecimal128Builder(mem, &arrow.Decimal128Type{Precision: 10, Scale: 2}) + defer b.Release() + err := appendDecimalValue(b, notDecimal) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + + t.Run("decimal256_wrong_type", func(t *testing.T) { + b := array.NewDecimal256Builder(mem, &arrow.Decimal256Type{Precision: 40, Scale: 2}) + defer b.Release() + err := appendDecimalValue(b, notDecimal) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + + t.Run("unexpected_builder_type", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + err := appendDecimalValue(b, reflect.ValueOf(decimal128.New(0, 1))) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestBuildLargeTypes(t *testing.T) { + mem := checkedMem(t) + largeOpts := tagOpts{Large: true} + + t.Run("string→LARGE_STRING", func(t *testing.T) { + arr := mustBuildArray(t, []string{"a", "b", "c"}, largeOpts, mem) + assert.Equal(t, arrow.LARGE_STRING, arr.DataType().ID()) + ls := arr.(*array.LargeString) + assert.Equal(t, "a", ls.Value(0)) + assert.Equal(t, "b", ls.Value(1)) + assert.Equal(t, "c", ls.Value(2)) + }) + + t.Run("[]byte→LARGE_BINARY", func(t *testing.T) { + arr := mustBuildArray(t, [][]byte{{1, 2}, {3}}, largeOpts, mem) + assert.Equal(t, arrow.LARGE_BINARY, arr.DataType().ID()) + }) + + t.Run("[]string→LARGE_LIST", func(t *testing.T) { + arr := mustBuildArray(t, [][]string{{"a", "b"}, {"c"}}, largeOpts, mem) + assert.Equal(t, arrow.LARGE_LIST, arr.DataType().ID()) + ll := arr.DataType().(*arrow.LargeListType) + assert.Equal(t, arrow.LARGE_STRING, ll.Elem().ID()) + }) + + t.Run("[][]byte→LARGE_LIST", func(t *testing.T) { + arr := mustBuildArray(t, [][][]byte{{{1}, {2}}, {{3}}}, largeOpts, mem) + assert.Equal(t, arrow.LARGE_LIST, arr.DataType().ID()) + ll := arr.DataType().(*arrow.LargeListType) + assert.Equal(t, arrow.LARGE_BINARY, ll.Elem().ID()) + }) + + t.Run("view+large→LARGE_LIST_VIEW", func(t *testing.T) { + // large→LARGE_LIST, then view→LARGE_LIST_VIEW; view wins on string elem (no LARGE_STRING_VIEW) + opts := tagOpts{Large: true, View: true} + arr := mustBuildArray(t, [][]string{{"x"}, {"y", "z"}}, opts, mem) + assert.Equal(t, arrow.LARGE_LIST_VIEW, arr.DataType().ID()) + llv := arr.DataType().(*arrow.LargeListViewType) + assert.Equal(t, arrow.STRING_VIEW, llv.Elem().ID()) + }) + + t.Run("map with large", func(t *testing.T) { + arr := mustBuildArray(t, []map[string]string{{"k": "v"}}, largeOpts, mem) + assert.Equal(t, arrow.MAP, arr.DataType().ID()) + mt := arr.DataType().(*arrow.MapType) + assert.Equal(t, arrow.LARGE_STRING, mt.KeyType().ID()) + assert.Equal(t, arrow.LARGE_STRING, mt.ItemField().Type.ID()) + }) + + t.Run("dict+large on string→Dictionary (large ignored for dict)", func(t *testing.T) { + opts := tagOpts{Large: true, Dict: true} + arr := mustBuildArray(t, []string{"a", "b", "a"}, opts, mem) + assert.Equal(t, arrow.DICTIONARY, arr.DataType().ID()) + dt := arr.DataType().(*arrow.DictionaryType) + assert.Equal(t, arrow.STRING, dt.ValueType.ID()) // large not applied, library limitation + }) +} + +func TestAppendTemporalValueUnitHandling(t *testing.T) { + mem := checkedMem(t) + ref := time.Date(2024, 1, 15, 12, 34, 56, 789_000_000, time.UTC) + + timestampCases := []struct { + name string + unit arrow.TimeUnit + }{ + {"timestamp_second", arrow.Second}, + {"timestamp_millisecond", arrow.Millisecond}, + {"timestamp_microsecond", arrow.Microsecond}, + {"timestamp_nanosecond", arrow.Nanosecond}, + } + for _, tc := range timestampCases { + t.Run(tc.name, func(t *testing.T) { + dt := &arrow.TimestampType{Unit: tc.unit} + b := array.NewTimestampBuilder(mem, dt) + defer b.Release() + require.NoError(t, appendTemporalValue(b, reflect.ValueOf(ref))) + arr := b.NewArray().(*array.Timestamp) + defer arr.Release() + got := int64(arr.Value(0)) + want := ref.UnixNano() / int64(tc.unit.Multiplier()) + assert.Equal(t, want, got, "%s: stored value should be scaled by unit", tc.name) + }) + } + + durationCases := []struct { + name string + unit arrow.TimeUnit + d time.Duration + }{ + {"duration_second", arrow.Second, 90 * time.Second}, + {"duration_millisecond", arrow.Millisecond, 1500 * time.Millisecond}, + {"duration_microsecond", arrow.Microsecond, 2500 * time.Microsecond}, + {"duration_nanosecond", arrow.Nanosecond, 12345 * time.Nanosecond}, + } + for _, tc := range durationCases { + t.Run(tc.name, func(t *testing.T) { + dt := &arrow.DurationType{Unit: tc.unit} + b := array.NewDurationBuilder(mem, dt) + defer b.Release() + require.NoError(t, appendTemporalValue(b, reflect.ValueOf(tc.d))) + arr := b.NewArray().(*array.Duration) + defer arr.Release() + got := int64(arr.Value(0)) + want := tc.d.Nanoseconds() / int64(tc.unit.Multiplier()) + assert.Equal(t, want, got, "%s: stored value should be scaled by unit", tc.name) + }) + } +} + +func TestWithLargeErrors(t *testing.T) { + mem := checkedMem(t) + + t.Run("large on int64 slice errors", func(t *testing.T) { + _, err := FromSlice([]int64{1, 2, 3}, mem, WithLarge()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + assert.Contains(t, err.Error(), "large option has no effect") + }) + + t.Run("large on float32 slice errors", func(t *testing.T) { + _, err := FromSlice([]float32{1.0}, mem, WithLarge()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("large on struct with no string fields errors", func(t *testing.T) { + type NoStrings struct { + X int32 + Y float64 + } + _, err := FromSlice([]NoStrings{{1, 2.0}}, mem, WithLarge()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + assert.Contains(t, err.Error(), "large option has no effect") + }) + + t.Run("large on string slice succeeds", func(t *testing.T) { + arr, err := FromSlice([]string{"a"}, mem, WithLarge()) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, arrow.LARGE_STRING, arr.DataType().ID()) + }) +} + +func TestAppendValueViewBuilders(t *testing.T) { + mem := checkedMem(t) + + t.Run("StringViewBuilder appends string value", func(t *testing.T) { + b := array.NewStringViewBuilder(mem) + defer b.Release() + err := appendValue(b, reflect.ValueOf("hello")) + require.NoError(t, err) + arr := b.NewArray() + defer arr.Release() + assert.Equal(t, 1, arr.Len()) + assert.Equal(t, "hello", arr.(*array.StringView).Value(0)) + }) + + t.Run("BinaryViewBuilder appends binary value", func(t *testing.T) { + b := array.NewBinaryViewBuilder(mem) + defer b.Release() + err := appendValue(b, reflect.ValueOf([]byte{1, 2, 3})) + require.NoError(t, err) + arr := b.NewArray() + defer arr.Release() + assert.Equal(t, 1, arr.Len()) + assert.Equal(t, []byte{1, 2, 3}, arr.(*array.BinaryView).Value(0)) + }) + + t.Run("BinaryViewBuilder appends null for nil slice", func(t *testing.T) { + b := array.NewBinaryViewBuilder(mem) + defer b.Release() + var nilSlice []byte + err := appendValue(b, reflect.ValueOf(nilSlice)) + require.NoError(t, err) + arr := b.NewArray() + defer arr.Release() + assert.True(t, arr.IsNull(0)) + }) +} diff --git a/arrow/array/arreflect/reflect_helpers_test.go b/arrow/array/arreflect/reflect_helpers_test.go new file mode 100644 index 00000000..6f1f241f --- /dev/null +++ b/arrow/array/arreflect/reflect_helpers_test.go @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "reflect" + "testing" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func checkedMem(t *testing.T) *memory.CheckedAllocator { + t.Helper() + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + t.Cleanup(func() { mem.AssertSize(t, 0) }) + return mem +} + +func setValueInto[T any](t *testing.T, dst *T, arr arrow.Array, i int) { + t.Helper() + require.NoError(t, setValue(reflect.ValueOf(dst).Elem(), arr, i)) +} + +func assertMultiLevelPtrNullPattern(t *testing.T, arr arrow.Array) { + t.Helper() + assert.Equal(t, 3, arr.Len()) + assert.False(t, arr.IsNull(0), "index 0 should not be null") + assert.True(t, arr.IsNull(1), "index 1 should be null") + assert.False(t, arr.IsNull(2), "index 2 should not be null") +} + +func makeStringArray(t *testing.T, mem memory.Allocator, vals ...string) *array.String { + t.Helper() + b := array.NewStringBuilder(mem) + defer b.Release() + b.AppendValues(vals, nil) + a := b.NewStringArray() + t.Cleanup(a.Release) + return a +} + +func makeInt32Array(t *testing.T, mem memory.Allocator, vals ...int32) *array.Int32 { + t.Helper() + b := array.NewInt32Builder(mem) + defer b.Release() + b.AppendValues(vals, nil) + a := b.NewInt32Array() + t.Cleanup(a.Release) + return a +} + +func makeStructArray(t *testing.T, arrays []arrow.Array, names []string) *array.Struct { + t.Helper() + sa, err := array.NewStructArray(arrays, names) + require.NoError(t, err) + t.Cleanup(sa.Release) + return sa +} + +func mustBuildArray(t *testing.T, vals any, opts tagOpts, mem memory.Allocator) arrow.Array { + t.Helper() + arr, err := buildArray(reflect.ValueOf(vals), opts, mem) + require.NoError(t, err) + t.Cleanup(arr.Release) + return arr +} + +func mustBuildDefault(t *testing.T, vals any, mem memory.Allocator) arrow.Array { + t.Helper() + return mustBuildArray(t, vals, tagOpts{}, mem) +} diff --git a/arrow/array/arreflect/reflect_infer.go b/arrow/array/arreflect/reflect_infer.go new file mode 100644 index 00000000..811edf11 --- /dev/null +++ b/arrow/array/arreflect/reflect_infer.go @@ -0,0 +1,551 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "fmt" + "reflect" + "time" + "unicode" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" +) + +var ( + typeOfTime = reflect.TypeOf(time.Time{}) + typeOfDuration = reflect.TypeOf(time.Duration(0)) + typeOfDec32 = reflect.TypeOf(decimal.Decimal32(0)) + typeOfDec64 = reflect.TypeOf(decimal.Decimal64(0)) + typeOfDec128 = reflect.TypeOf(decimal128.Num{}) + typeOfDec256 = reflect.TypeOf(decimal256.Num{}) + typeOfByteSlice = reflect.TypeOf([]byte{}) + typeOfInt = reflect.TypeOf(int(0)) + typeOfUint = reflect.TypeOf(uint(0)) + typeOfInt8 = reflect.TypeOf(int8(0)) + typeOfInt16 = reflect.TypeOf(int16(0)) + typeOfInt32 = reflect.TypeOf(int32(0)) + typeOfInt64 = reflect.TypeOf(int64(0)) + typeOfUint8 = reflect.TypeOf(uint8(0)) + typeOfUint16 = reflect.TypeOf(uint16(0)) + typeOfUint32 = reflect.TypeOf(uint32(0)) + typeOfUint64 = reflect.TypeOf(uint64(0)) + typeOfFloat32 = reflect.TypeOf(float32(0)) + typeOfFloat64 = reflect.TypeOf(float64(0)) + typeOfBool = reflect.TypeOf(false) + typeOfString = reflect.TypeOf("") +) + +const ( + dec32DefaultPrecision int32 = 9 + dec64DefaultPrecision int32 = 18 + dec128DefaultPrecision int32 = 38 + dec256DefaultPrecision int32 = 76 +) + +type listElemTyper interface{ Elem() arrow.DataType } + +func inferPrimitiveArrowType(t reflect.Type) (arrow.DataType, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t { + case typeOfInt8: + return arrow.PrimitiveTypes.Int8, nil + case typeOfInt16: + return arrow.PrimitiveTypes.Int16, nil + case typeOfInt32: + return arrow.PrimitiveTypes.Int32, nil + case typeOfInt64: + return arrow.PrimitiveTypes.Int64, nil + case typeOfInt: + return arrow.PrimitiveTypes.Int64, nil + case typeOfUint8: + return arrow.PrimitiveTypes.Uint8, nil + case typeOfUint16: + return arrow.PrimitiveTypes.Uint16, nil + case typeOfUint32: + return arrow.PrimitiveTypes.Uint32, nil + case typeOfUint64: + return arrow.PrimitiveTypes.Uint64, nil + case typeOfUint: + return arrow.PrimitiveTypes.Uint64, nil + case typeOfFloat32: + return arrow.PrimitiveTypes.Float32, nil + case typeOfFloat64: + return arrow.PrimitiveTypes.Float64, nil + case typeOfBool: + return arrow.FixedWidthTypes.Boolean, nil + case typeOfString: + return arrow.BinaryTypes.String, nil + case typeOfByteSlice: + return arrow.BinaryTypes.Binary, nil + case typeOfTime: + return &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, nil + case typeOfDuration: + return &arrow.DurationType{Unit: arrow.Nanosecond}, nil + case typeOfDec128: + return &arrow.Decimal128Type{Precision: dec128DefaultPrecision, Scale: 0}, nil + case typeOfDec32: + return &arrow.Decimal32Type{Precision: dec32DefaultPrecision, Scale: 0}, nil + case typeOfDec64: + return &arrow.Decimal64Type{Precision: dec64DefaultPrecision, Scale: 0}, nil + case typeOfDec256: + return &arrow.Decimal256Type{Precision: dec256DefaultPrecision, Scale: 0}, nil + default: + return nil, fmt.Errorf("unsupported Go type for Arrow inference %v: %w", t, ErrUnsupportedType) + } +} + +func inferArrowType(t reflect.Type) (arrow.DataType, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t == typeOfByteSlice { + return arrow.BinaryTypes.Binary, nil + } + + switch t.Kind() { + case reflect.Slice: + elemDT, err := inferArrowType(t.Elem()) + if err != nil { + return nil, err + } + return arrow.ListOf(elemDT), nil + + case reflect.Array: + elemDT, err := inferArrowType(t.Elem()) + if err != nil { + return nil, err + } + return arrow.FixedSizeListOf(int32(t.Len()), elemDT), nil + + case reflect.Map: + keyDT, err := inferArrowType(t.Key()) + if err != nil { + return nil, err + } + valDT, err := inferArrowType(t.Elem()) + if err != nil { + return nil, err + } + return arrow.MapOf(keyDT, valDT), nil + + case reflect.Struct: + return inferStructType(t) + + default: + return inferPrimitiveArrowType(t) + } +} + +func applyDecimalOpts(dt arrow.DataType, origType reflect.Type, opts tagOpts) arrow.DataType { + if !opts.HasDecimalOpts { + return dt + } + prec, scale := opts.DecimalPrecision, opts.DecimalScale + switch origType { + case typeOfDec128: + return &arrow.Decimal128Type{Precision: prec, Scale: scale} + case typeOfDec256: + return &arrow.Decimal256Type{Precision: prec, Scale: scale} + case typeOfDec32: + return &arrow.Decimal32Type{Precision: prec, Scale: scale} + case typeOfDec64: + return &arrow.Decimal64Type{Precision: prec, Scale: scale} + } + return dt +} + +func applyTemporalOpts(dt arrow.DataType, origType reflect.Type, opts tagOpts) arrow.DataType { + if origType != typeOfTime || opts.Temporal == "" || opts.Temporal == "timestamp" { + return dt + } + switch opts.Temporal { + case "date32": + return arrow.FixedWidthTypes.Date32 + case "date64": + return arrow.FixedWidthTypes.Date64 + case "time32": + return &arrow.Time32Type{Unit: arrow.Millisecond} + case "time64": + return &arrow.Time64Type{Unit: arrow.Nanosecond} + } + return dt +} + +func applyLargeOpts(dt arrow.DataType) arrow.DataType { + switch dt.ID() { + case arrow.STRING: + return arrow.BinaryTypes.LargeString + case arrow.BINARY: + return arrow.BinaryTypes.LargeBinary + case arrow.LIST: + return arrow.LargeListOf(applyLargeOpts(dt.(*arrow.ListType).Elem())) + case arrow.LIST_VIEW: + return arrow.LargeListViewOf(applyLargeOpts(dt.(*arrow.ListViewType).Elem())) + case arrow.LARGE_LIST: + return arrow.LargeListOf(applyLargeOpts(dt.(*arrow.LargeListType).Elem())) + case arrow.LARGE_LIST_VIEW: + return arrow.LargeListViewOf(applyLargeOpts(dt.(*arrow.LargeListViewType).Elem())) + case arrow.FIXED_SIZE_LIST: + fsl := dt.(*arrow.FixedSizeListType) + return arrow.FixedSizeListOf(fsl.Len(), applyLargeOpts(fsl.Elem())) + case arrow.MAP: + mt := dt.(*arrow.MapType) + return arrow.MapOf(applyLargeOpts(mt.KeyType()), applyLargeOpts(mt.ItemField().Type)) + case arrow.STRUCT: + st := dt.(*arrow.StructType) + fields := make([]arrow.Field, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + f := st.Field(i) + f.Type = applyLargeOpts(f.Type) + fields[i] = f + } + return arrow.StructOf(fields...) + default: + return dt + } +} + +func hasLargeableType(dt arrow.DataType) bool { + switch dt.ID() { + case arrow.STRING, arrow.BINARY, arrow.LIST, arrow.LIST_VIEW: + return true + case arrow.STRUCT: + st := dt.(*arrow.StructType) + for i := 0; i < st.NumFields(); i++ { + if hasLargeableType(st.Field(i).Type) { + return true + } + } + return false + case arrow.FIXED_SIZE_LIST: + return hasLargeableType(dt.(*arrow.FixedSizeListType).Elem()) + case arrow.MAP: + mt := dt.(*arrow.MapType) + return hasLargeableType(mt.KeyType()) || hasLargeableType(mt.ItemField().Type) + default: + return false + } +} + +func applyViewOpts(dt arrow.DataType) arrow.DataType { + switch dt.ID() { + case arrow.STRING, arrow.LARGE_STRING: + return arrow.BinaryTypes.StringView + case arrow.BINARY, arrow.LARGE_BINARY: + return arrow.BinaryTypes.BinaryView + case arrow.LIST: + return arrow.ListViewOf(applyViewOpts(dt.(*arrow.ListType).Elem())) + case arrow.LIST_VIEW: + return arrow.ListViewOf(applyViewOpts(dt.(*arrow.ListViewType).Elem())) + case arrow.LARGE_LIST: + return arrow.LargeListViewOf(applyViewOpts(dt.(*arrow.LargeListType).Elem())) + case arrow.LARGE_LIST_VIEW: + return arrow.LargeListViewOf(applyViewOpts(dt.(*arrow.LargeListViewType).Elem())) + case arrow.FIXED_SIZE_LIST: + fsl := dt.(*arrow.FixedSizeListType) + return arrow.FixedSizeListOf(fsl.Len(), applyViewOpts(fsl.Elem())) + case arrow.MAP: + mt := dt.(*arrow.MapType) + return arrow.MapOf(applyViewOpts(mt.KeyType()), applyViewOpts(mt.ItemField().Type)) + case arrow.STRUCT: + st := dt.(*arrow.StructType) + fields := make([]arrow.Field, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + f := st.Field(i) + f.Type = applyViewOpts(f.Type) + fields[i] = f + } + return arrow.StructOf(fields...) + default: + return dt + } +} + +func hasViewableType(dt arrow.DataType) bool { + switch dt.ID() { + case arrow.STRING, arrow.BINARY, arrow.LARGE_STRING, arrow.LARGE_BINARY, + arrow.STRING_VIEW, arrow.BINARY_VIEW, + arrow.LIST, arrow.LIST_VIEW, arrow.LARGE_LIST, arrow.LARGE_LIST_VIEW: + return true + case arrow.STRUCT: + st := dt.(*arrow.StructType) + for i := 0; i < st.NumFields(); i++ { + if hasViewableType(st.Field(i).Type) { + return true + } + } + return false + case arrow.FIXED_SIZE_LIST: + return hasViewableType(dt.(*arrow.FixedSizeListType).Elem()) + case arrow.MAP: + mt := dt.(*arrow.MapType) + return hasViewableType(mt.KeyType()) || hasViewableType(mt.ItemField().Type) + default: + return false + } +} + +func applyEncodingOpts(dt arrow.DataType, fm fieldMeta) (arrow.DataType, error) { + switch { + case fm.Opts.Dict: + if err := validateDictValueType(dt); err != nil { + return nil, fmt.Errorf("arreflect: dict tag on field %q: %w", fm.Name, err) + } + return &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: dt}, nil + case fm.Opts.REE: + return nil, fmt.Errorf("arreflect: ree tag on struct field %q is not supported; use ree at top-level via FromSlice", fm.Name) + } + return dt, nil +} + +func inferStructType(t reflect.Type) (*arrow.StructType, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("arreflect: expected struct, got %v", t) + } + + fields := cachedStructFields(t) + arrowFields := make([]arrow.Field, 0, len(fields)) + + for _, fm := range fields { + if err := validateOptions(fm.Opts); err != nil { + return nil, fmt.Errorf("struct field %q: %w", fm.Name, err) + } + origType := fm.Type + for origType.Kind() == reflect.Ptr { + origType = origType.Elem() + } + + dt, err := inferArrowType(fm.Type) + if err != nil { + return nil, fmt.Errorf("struct field %q: %w", fm.Name, err) + } + + dt = applyDecimalOpts(dt, origType, fm.Opts) + dt = applyTemporalOpts(dt, origType, fm.Opts) + if fm.Opts.Large { + dt = applyLargeOpts(dt) + } + if fm.Opts.View { + dt = applyViewOpts(dt) + } + dt, err = applyEncodingOpts(dt, fm) + if err != nil { + return nil, err + } + + arrowFields = append(arrowFields, arrow.Field{ + Name: fm.Name, + Type: dt, + Nullable: fm.Nullable, + }) + } + + return arrow.StructOf(arrowFields...), nil +} + +// InferSchema infers an *arrow.Schema from a Go struct type T. +// T must be a struct type; returns an error otherwise. +// For column-level Arrow type inspection, use [InferType]. +// Field names come from arrow struct tags or Go field names. +// Pointer fields are marked Nullable=true. +func InferSchema[T any]() (*arrow.Schema, error) { + t := reflect.TypeFor[T]() + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("arreflect: InferSchema requires a struct type T, got %v", t) + } + st, err := inferStructType(t) + if err != nil { + return nil, err + } + fields := make([]arrow.Field, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + fields[i] = st.Field(i) + } + return arrow.NewSchema(fields, nil), nil +} + +// InferType infers the Arrow DataType for a Go type T. +// For struct types, [InferSchema] is preferred when the result will be used with +// arrow.Record or array.NewRecord; InferType returns an arrow.DataType that would +// require an additional cast to *arrow.StructType. +func InferType[T any]() (arrow.DataType, error) { + t := reflect.TypeFor[T]() + return inferArrowType(t) +} + +// InferGoType returns the Go reflect.Type corresponding to the given Arrow DataType. +// For STRUCT types it constructs an anonymous struct type at runtime using +// [reflect.StructOf]; field names are exported (capitalised) with the original +// Arrow field name preserved in an arrow struct tag. Nullable Arrow fields +// (field.Nullable == true) become pointer types (*T). +// For DICTIONARY and RUN_END_ENCODED types it returns the Go type of the +// value/encoded type respectively (dictionaries are resolved transparently). +func InferGoType(dt arrow.DataType) (reflect.Type, error) { + switch dt.ID() { + case arrow.INT8: + return typeOfInt8, nil + case arrow.INT16: + return typeOfInt16, nil + case arrow.INT32: + return typeOfInt32, nil + case arrow.INT64: + return typeOfInt64, nil + case arrow.UINT8: + return typeOfUint8, nil + case arrow.UINT16: + return typeOfUint16, nil + case arrow.UINT32: + return typeOfUint32, nil + case arrow.UINT64: + return typeOfUint64, nil + case arrow.FLOAT32: + return typeOfFloat32, nil + case arrow.FLOAT64: + return typeOfFloat64, nil + case arrow.BOOL: + return typeOfBool, nil + case arrow.STRING, arrow.LARGE_STRING, arrow.STRING_VIEW: + return typeOfString, nil + case arrow.BINARY, arrow.LARGE_BINARY, arrow.BINARY_VIEW: + return typeOfByteSlice, nil + case arrow.TIMESTAMP, arrow.DATE32, arrow.DATE64, arrow.TIME32, arrow.TIME64: + return typeOfTime, nil + case arrow.DURATION: + return typeOfDuration, nil + case arrow.DECIMAL128: + return typeOfDec128, nil + case arrow.DECIMAL256: + return typeOfDec256, nil + case arrow.DECIMAL32: + return typeOfDec32, nil + case arrow.DECIMAL64: + return typeOfDec64, nil + + case arrow.LIST, arrow.LARGE_LIST, arrow.LIST_VIEW, arrow.LARGE_LIST_VIEW: + ll, ok := dt.(listElemTyper) + if !ok { + return nil, fmt.Errorf("unsupported Arrow type for Go inference: %v: %w", dt, ErrUnsupportedType) + } + elemDT := ll.Elem() + elemType, err := InferGoType(elemDT) + if err != nil { + return nil, err + } + return reflect.SliceOf(elemType), nil + + case arrow.FIXED_SIZE_LIST: + fsl := dt.(*arrow.FixedSizeListType) + elemType, err := InferGoType(fsl.Elem()) + if err != nil { + return nil, err + } + return reflect.ArrayOf(int(fsl.Len()), elemType), nil + + case arrow.MAP: + mt := dt.(*arrow.MapType) + keyType, err := InferGoType(mt.KeyType()) + if err != nil { + return nil, err + } + if !keyType.Comparable() { + return nil, fmt.Errorf("arreflect: InferGoType: MAP key type %v is not comparable in Go: %w", mt.KeyType(), ErrUnsupportedType) + } + valType, err := InferGoType(mt.ItemField().Type) + if err != nil { + return nil, err + } + return reflect.MapOf(keyType, valType), nil + + case arrow.STRUCT: + return inferGoStructType(dt.(*arrow.StructType)) + + case arrow.DICTIONARY: + return InferGoType(dt.(*arrow.DictionaryType).ValueType) + + case arrow.RUN_END_ENCODED: + return InferGoType(dt.(*arrow.RunEndEncodedType).Encoded()) + + default: + return nil, fmt.Errorf("unsupported Arrow type for Go inference: %v: %w", dt, ErrUnsupportedType) + } +} + +func exportedFieldName(name string, index int) (string, error) { + if len(name) == 0 { + return fmt.Sprintf("Field%d", index), nil + } + runes := []rune(name) + // If the first rune is not a letter (e.g. '_', digit), prefix with "X" + // to produce a valid exported Go identifier while preserving the original + // name in the struct tag. + if !unicode.IsLetter(runes[0]) { + runes = append([]rune{'X'}, runes...) + } else { + runes[0] = unicode.ToUpper(runes[0]) + } + for j, r := range runes { + if j == 0 { + continue + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + return "", fmt.Errorf("arreflect: InferGoType: field name %q produces invalid Go identifier: %w", name, ErrUnsupportedType) + } + } + return string(runes), nil +} + +func inferGoStructType(st *arrow.StructType) (reflect.Type, error) { + fields := make([]reflect.StructField, st.NumFields()) + seen := make(map[string]string, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + f := st.Field(i) + ft, err := InferGoType(f.Type) + if err != nil { + return nil, err + } + if f.Nullable { + ft = reflect.PointerTo(ft) + } + exportedName, err := exportedFieldName(f.Name, i) + if err != nil { + return nil, err + } + if origName, dup := seen[exportedName]; dup { + return nil, fmt.Errorf("arreflect: InferGoType: field names %q and %q both export as %q: %w", origName, f.Name, exportedName, ErrUnsupportedType) + } + seen[exportedName] = f.Name + fields[i] = reflect.StructField{ + Name: exportedName, + Type: ft, + Tag: reflect.StructTag(fmt.Sprintf(`arrow:%q`, f.Name)), + } + } + return reflect.StructOf(fields), nil +} diff --git a/arrow/array/arreflect/reflect_infer_test.go b/arrow/array/arreflect/reflect_infer_test.go new file mode 100644 index 00000000..744d296f --- /dev/null +++ b/arrow/array/arreflect/reflect_infer_test.go @@ -0,0 +1,869 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "reflect" + "strings" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInferPrimitiveArrowType(t *testing.T) { + cases := []struct { + name string + goType reflect.Type + wantID arrow.Type + wantErr bool + }{ + {"int8", reflect.TypeOf(int8(0)), arrow.INT8, false}, + {"int16", reflect.TypeOf(int16(0)), arrow.INT16, false}, + {"int32", reflect.TypeOf(int32(0)), arrow.INT32, false}, + {"int64", reflect.TypeOf(int64(0)), arrow.INT64, false}, + {"int", reflect.TypeOf(int(0)), arrow.INT64, false}, + {"uint8", reflect.TypeOf(uint8(0)), arrow.UINT8, false}, + {"uint16", reflect.TypeOf(uint16(0)), arrow.UINT16, false}, + {"uint32", reflect.TypeOf(uint32(0)), arrow.UINT32, false}, + {"uint64", reflect.TypeOf(uint64(0)), arrow.UINT64, false}, + {"uint", reflect.TypeOf(uint(0)), arrow.UINT64, false}, + {"float32", reflect.TypeOf(float32(0)), arrow.FLOAT32, false}, + {"float64", reflect.TypeOf(float64(0)), arrow.FLOAT64, false}, + {"bool", reflect.TypeOf(false), arrow.BOOL, false}, + {"string", reflect.TypeOf(""), arrow.STRING, false}, + {"[]byte", reflect.TypeOf([]byte{}), arrow.BINARY, false}, + {"time.Time", reflect.TypeOf(time.Time{}), arrow.TIMESTAMP, false}, + {"time.Duration", reflect.TypeOf(time.Duration(0)), arrow.DURATION, false}, + {"decimal128.Num", reflect.TypeOf(decimal128.Num{}), arrow.DECIMAL128, false}, + {"decimal256.Num", reflect.TypeOf(decimal256.Num{}), arrow.DECIMAL256, false}, + {"decimal.Decimal32", reflect.TypeOf(decimal.Decimal32(0)), arrow.DECIMAL32, false}, + {"decimal.Decimal64", reflect.TypeOf(decimal.Decimal64(0)), arrow.DECIMAL64, false}, + {"*int32 pointer transparent", reflect.TypeOf((*int32)(nil)), arrow.INT32, false}, + {"chan int unsupported", reflect.TypeOf(make(chan int)), 0, true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := inferPrimitiveArrowType(tc.goType) + if tc.wantErr { + require.Error(t, err, "expected error, got nil (type: %v)", got) + return + } + require.NoError(t, err) + assert.Equal(t, tc.wantID, got.ID()) + }) + } +} + +func TestInferArrowType(t *testing.T) { + t.Run("[]int32 is LIST", func(t *testing.T) { + dt, err := inferArrowType(reflect.TypeOf([]int32{})) + require.NoError(t, err) + assert.Equal(t, arrow.LIST, dt.ID()) + }) + + t.Run("[3]float64 is FIXED_SIZE_LIST size 3", func(t *testing.T) { + dt, err := inferArrowType(reflect.TypeOf([3]float64{})) + require.NoError(t, err) + assert.Equal(t, arrow.FIXED_SIZE_LIST, dt.ID()) + fsl := dt.(*arrow.FixedSizeListType) + assert.Equal(t, int32(3), fsl.Len()) + }) + + t.Run("map[string]int64 is MAP", func(t *testing.T) { + dt, err := inferArrowType(reflect.TypeOf(map[string]int64{})) + require.NoError(t, err) + assert.Equal(t, arrow.MAP, dt.ID()) + }) + + t.Run("struct with 2 fields is STRUCT", func(t *testing.T) { + type S struct { + Name string + Age int32 + } + dt, err := inferArrowType(reflect.TypeOf(S{})) + require.NoError(t, err) + assert.Equal(t, arrow.STRUCT, dt.ID()) + st := dt.(*arrow.StructType) + assert.Equal(t, 2, st.NumFields()) + }) + + t.Run("[]map[string]struct{Score float64} nested", func(t *testing.T) { + type Inner struct { + Score float64 + } + dt, err := inferArrowType(reflect.TypeOf([]map[string]Inner{})) + require.NoError(t, err) + assert.Equal(t, arrow.LIST, dt.ID()) + lt := dt.(*arrow.ListType) + assert.Equal(t, arrow.MAP, lt.Elem().ID()) + mt := lt.Elem().(*arrow.MapType) + assert.Equal(t, arrow.STRUCT, mt.ItemField().Type.ID()) + }) + + t.Run("*[]string pointer to slice is LIST", func(t *testing.T) { + dt, err := inferArrowType(reflect.TypeOf((*[]string)(nil))) + require.NoError(t, err) + assert.Equal(t, arrow.LIST, dt.ID()) + }) +} + +func TestInferStructType(t *testing.T) { + t.Run("simple struct field names and types", func(t *testing.T) { + type S struct { + Name string + Score float32 + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + require.Equal(t, 2, st.NumFields()) + assert.Equal(t, "Name", st.Field(0).Name) + assert.Equal(t, arrow.STRING, st.Field(0).Type.ID()) + assert.Equal(t, "Score", st.Field(1).Name) + assert.Equal(t, arrow.FLOAT32, st.Field(1).Type.ID()) + }) + + t.Run("pointer fields are nullable", func(t *testing.T) { + type S struct { + ID int32 + Label *string + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + assert.False(t, st.Field(0).Nullable, "ID should not be nullable") + assert.True(t, st.Field(1).Nullable, "Label should be nullable") + }) + + t.Run("arrow:\"-\" tagged field is excluded", func(t *testing.T) { + type S struct { + Keep string + Hidden int32 `arrow:"-"` + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + assert.Equal(t, 1, st.NumFields()) + assert.Equal(t, "Keep", st.Field(0).Name) + }) + + t.Run("arrow custom name tag", func(t *testing.T) { + type S struct { + GoName int64 `arrow:"custom_name"` + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + assert.Equal(t, "custom_name", st.Field(0).Name) + }) + + t.Run("decimal128 with precision/scale tag", func(t *testing.T) { + type S struct { + Amount decimal128.Num `arrow:",decimal(18,2)"` + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + dt := st.Field(0).Type + require.Equal(t, arrow.DECIMAL128, dt.ID()) + d128 := dt.(*arrow.Decimal128Type) + assert.Equal(t, int32(18), d128.Precision) + assert.Equal(t, int32(2), d128.Scale) + }) + + t.Run("decimal256 with precision/scale tag", func(t *testing.T) { + type S struct { + Amount decimal256.Num `arrow:",decimal(40,5)"` + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + dt := st.Field(0).Type + require.Equal(t, arrow.DECIMAL256, dt.ID()) + d256 := dt.(*arrow.Decimal256Type) + assert.Equal(t, int32(40), d256.Precision) + assert.Equal(t, int32(5), d256.Scale) + }) + + t.Run("decimal32 with precision/scale tag", func(t *testing.T) { + type S struct { + Amount decimal.Decimal32 `arrow:",decimal(9,2)"` + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + dt := st.Field(0).Type + require.Equal(t, arrow.DECIMAL32, dt.ID()) + d32 := dt.(*arrow.Decimal32Type) + assert.Equal(t, int32(9), d32.Precision) + assert.Equal(t, int32(2), d32.Scale) + }) + + t.Run("non-struct returns error", func(t *testing.T) { + _, err := inferStructType(reflect.TypeOf(42)) + assert.Error(t, err, "expected error for non-struct, got nil") + }) + + t.Run("time.Time with date32 tag maps to DATE32", func(t *testing.T) { + type S struct { + Ts time.Time `arrow:",date32"` + } + st, err := inferStructType(reflect.TypeOf(S{})) + require.NoError(t, err) + dt := st.Field(0).Type + assert.Equal(t, arrow.DATE32, dt.ID()) + }) +} + +func TestInferArrowSchema(t *testing.T) { + t.Run("simple struct mixed fields", func(t *testing.T) { + type S struct { + Name string + Age int32 + Score float64 + } + schema, err := InferSchema[S]() + require.NoError(t, err) + require.Equal(t, 3, schema.NumFields()) + assert.Equal(t, "Name", schema.Field(0).Name) + assert.Equal(t, arrow.STRING, schema.Field(0).Type.ID()) + assert.Equal(t, "Age", schema.Field(1).Name) + assert.Equal(t, arrow.INT32, schema.Field(1).Type.ID()) + assert.Equal(t, "Score", schema.Field(2).Name) + assert.Equal(t, arrow.FLOAT64, schema.Field(2).Type.ID()) + }) + + t.Run("pointer fields are nullable", func(t *testing.T) { + type S struct { + ID int32 + Label *string + } + schema, err := InferSchema[S]() + require.NoError(t, err) + assert.False(t, schema.Field(0).Nullable, "ID should not be nullable") + assert.True(t, schema.Field(1).Nullable, "Label should be nullable") + }) + + t.Run("arrow:\"-\" tag excludes field", func(t *testing.T) { + type S struct { + Keep string + Hidden int32 `arrow:"-"` + } + schema, err := InferSchema[S]() + require.NoError(t, err) + assert.Equal(t, 1, schema.NumFields()) + assert.Equal(t, "Keep", schema.Field(0).Name) + }) + + t.Run("arrow custom name tag", func(t *testing.T) { + type S struct { + GoName int64 `arrow:"custom_name"` + } + schema, err := InferSchema[S]() + require.NoError(t, err) + assert.Equal(t, "custom_name", schema.Field(0).Name) + }) + + t.Run("non-struct type returns error", func(t *testing.T) { + _, err := InferSchema[int]() + assert.Error(t, err, "expected error for non-struct, got nil") + }) +} + +func TestInferArrowTypePublic(t *testing.T) { + t.Run("int32 is INT32", func(t *testing.T) { + dt, err := InferType[int32]() + require.NoError(t, err) + assert.Equal(t, arrow.INT32, dt.ID()) + }) + + t.Run("[]string is LIST", func(t *testing.T) { + dt, err := InferType[[]string]() + require.NoError(t, err) + assert.Equal(t, arrow.LIST, dt.ID()) + }) + + t.Run("map[string]float64 is MAP", func(t *testing.T) { + dt, err := InferType[map[string]float64]() + require.NoError(t, err) + assert.Equal(t, arrow.MAP, dt.ID()) + }) + + t.Run("struct{X int32} is STRUCT", func(t *testing.T) { + type S struct{ X int32 } + dt, err := InferType[S]() + require.NoError(t, err) + assert.Equal(t, arrow.STRUCT, dt.ID()) + }) +} + +func TestInferArrowSchemaStructFieldEncoding(t *testing.T) { + t.Run("dict-tagged string field becomes DICTIONARY", func(t *testing.T) { + type S struct { + Name string `arrow:"name,dict"` + } + schema, err := InferSchema[S]() + require.NoError(t, err) + f, ok := schema.FieldsByName("name") + require.True(t, ok && len(f) > 0, "field 'name' not found in schema") + assert.Equal(t, arrow.DICTIONARY, f[0].Type.ID()) + }) + + t.Run("view-tagged []string field becomes LIST_VIEW", func(t *testing.T) { + type S struct { + Tags []string `arrow:"tags,view"` + } + schema, err := InferSchema[S]() + require.NoError(t, err) + f, ok := schema.FieldsByName("tags") + require.True(t, ok && len(f) > 0, "field 'tags' not found in schema") + assert.Equal(t, arrow.LIST_VIEW, f[0].Type.ID()) + }) + + t.Run("ree-tagged field on struct is unsupported", func(t *testing.T) { + type REERow struct { + Val string `arrow:"val,ree"` + } + _, err := InferSchema[REERow]() + require.Error(t, err, "expected error for ree tag on struct field, got nil") + assert.True(t, strings.Contains(err.Error(), "ree tag on struct field"), "unexpected error message: %v", err) + }) +} + +func TestInferGoType(t *testing.T) { + primitives := []struct { + dt arrow.DataType + want reflect.Type + }{ + {arrow.PrimitiveTypes.Int32, reflect.TypeOf(int32(0))}, + {arrow.PrimitiveTypes.Float64, reflect.TypeOf(float64(0))}, + {arrow.FixedWidthTypes.Boolean, reflect.TypeOf(bool(false))}, + {arrow.BinaryTypes.String, reflect.TypeOf("")}, + {arrow.BinaryTypes.Binary, reflect.TypeOf([]byte{})}, + {&arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, reflect.TypeOf(time.Time{})}, + {&arrow.DurationType{Unit: arrow.Nanosecond}, reflect.TypeOf(time.Duration(0))}, + } + for _, tt := range primitives { + got, err := InferGoType(tt.dt) + if assert.NoError(t, err, "InferGoType(%v)", tt.dt) { + assert.Equal(t, tt.want, got, "InferGoType(%v)", tt.dt) + } + } + + st := arrow.StructOf( + arrow.Field{Name: "id", Type: arrow.PrimitiveTypes.Int64}, + arrow.Field{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true}, + ) + structType, err := InferGoType(st) + require.NoError(t, err, "struct") + require.Equal(t, reflect.Struct, structType.Kind()) + require.Equal(t, 2, structType.NumField()) + assert.Equal(t, reflect.Ptr, structType.Field(1).Type.Kind(), "nullable field should be pointer") + assert.Equal(t, reflect.String, structType.Field(1).Type.Elem().Kind(), "nullable field should be *string") + + listType, err := InferGoType(arrow.ListOf(arrow.PrimitiveTypes.Int32)) + require.NoError(t, err, "list") + require.Equal(t, reflect.Slice, listType.Kind()) + assert.Equal(t, reflect.TypeOf(int32(0)), listType.Elem(), "list elem wrong") + + fslType, err := InferGoType(arrow.FixedSizeListOf(3, arrow.PrimitiveTypes.Float32)) + require.NoError(t, err, "fsl") + require.Equal(t, reflect.Array, fslType.Kind()) + assert.Equal(t, 3, fslType.Len(), "array len want 3") + + _, err = InferGoType(arrow.Null) + require.Error(t, err, "expected error for unsupported type") + assert.ErrorIs(t, err, ErrUnsupportedType) +} + +func TestInferGoTypeMapNonComparableKey(t *testing.T) { + t.Run("MAP with non-comparable key returns error", func(t *testing.T) { + dt := arrow.MapOf(arrow.ListOf(arrow.PrimitiveTypes.Int32), arrow.BinaryTypes.String) + _, err := InferGoType(dt) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestInferGoTypeStructDuplicateExportedNames(t *testing.T) { + t.Run("STRUCT with colliding exported names returns error", func(t *testing.T) { + st := arrow.StructOf( + arrow.Field{Name: "foo", Type: arrow.PrimitiveTypes.Int32}, + arrow.Field{Name: "Foo", Type: arrow.PrimitiveTypes.Int64}, + ) + _, err := InferGoType(st) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestInferGoTypeStructInvalidIdentifier(t *testing.T) { + cases := []struct { + name string + fieldName string + }{ + {"hyphenated", "my-field"}, + {"space", "a b"}, + {"dot", "first.name"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + st := arrow.StructOf(arrow.Field{Name: tc.fieldName, Type: arrow.PrimitiveTypes.Int32}) + _, err := InferGoType(st) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + } + + t.Run("non-letter prefix mapped", func(t *testing.T) { + for _, tc := range []struct { + name string + expected string + }{ + {"_id", "X_id"}, + {"1st", "X1st"}, + } { + st := arrow.StructOf(arrow.Field{Name: tc.name, Type: arrow.PrimitiveTypes.Int32}) + goType, err := InferGoType(st) + assert.NoError(t, err) + assert.Equal(t, tc.expected, goType.Field(0).Name) + } + }) +} + +func TestInferGoTypeAllPrimitives(t *testing.T) { + cases := []struct { + name string + dt arrow.DataType + want reflect.Type + }{ + {"int8", arrow.PrimitiveTypes.Int8, reflect.TypeOf(int8(0))}, + {"int16", arrow.PrimitiveTypes.Int16, reflect.TypeOf(int16(0))}, + {"int64", arrow.PrimitiveTypes.Int64, reflect.TypeOf(int64(0))}, + {"uint8", arrow.PrimitiveTypes.Uint8, reflect.TypeOf(uint8(0))}, + {"uint16", arrow.PrimitiveTypes.Uint16, reflect.TypeOf(uint16(0))}, + {"uint32", arrow.PrimitiveTypes.Uint32, reflect.TypeOf(uint32(0))}, + {"uint64", arrow.PrimitiveTypes.Uint64, reflect.TypeOf(uint64(0))}, + {"float32", arrow.PrimitiveTypes.Float32, reflect.TypeOf(float32(0))}, + {"large_string", arrow.BinaryTypes.LargeString, reflect.TypeOf("")}, + {"large_binary", arrow.BinaryTypes.LargeBinary, reflect.TypeOf([]byte{})}, + {"date32", arrow.FixedWidthTypes.Date32, reflect.TypeOf(time.Time{})}, + {"date64", arrow.FixedWidthTypes.Date64, reflect.TypeOf(time.Time{})}, + {"time32_ms", &arrow.Time32Type{Unit: arrow.Millisecond}, reflect.TypeOf(time.Time{})}, + {"time64_ns", &arrow.Time64Type{Unit: arrow.Nanosecond}, reflect.TypeOf(time.Time{})}, + {"decimal32", &arrow.Decimal32Type{Precision: 9, Scale: 2}, reflect.TypeOf(decimal.Decimal32(0))}, + {"decimal64", &arrow.Decimal64Type{Precision: 18, Scale: 3}, reflect.TypeOf(decimal.Decimal64(0))}, + {"decimal128", &arrow.Decimal128Type{Precision: 10, Scale: 2}, reflect.TypeOf(decimal128.Num{})}, + {"decimal256", &arrow.Decimal256Type{Precision: 20, Scale: 4}, reflect.TypeOf(decimal256.Num{})}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := InferGoType(tc.dt) + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestInferGoTypeCompositeTypes(t *testing.T) { + t.Run("large_list", func(t *testing.T) { + got, err := InferGoType(arrow.LargeListOf(arrow.PrimitiveTypes.Int64)) + require.NoError(t, err) + assert.Equal(t, reflect.Slice, got.Kind()) + assert.Equal(t, reflect.Int64, got.Elem().Kind()) + }) + + t.Run("list_view", func(t *testing.T) { + got, err := InferGoType(arrow.ListViewOf(arrow.PrimitiveTypes.Int32)) + require.NoError(t, err) + assert.Equal(t, reflect.Slice, got.Kind()) + assert.Equal(t, reflect.Int32, got.Elem().Kind()) + }) + + t.Run("large_list_view", func(t *testing.T) { + got, err := InferGoType(arrow.LargeListViewOf(arrow.PrimitiveTypes.Int32)) + require.NoError(t, err) + assert.Equal(t, reflect.Slice, got.Kind()) + }) + + t.Run("list with unsupported element returns error", func(t *testing.T) { + _, err := InferGoType(arrow.ListOf(arrow.Null)) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("fixed size list with unsupported element returns error", func(t *testing.T) { + _, err := InferGoType(arrow.FixedSizeListOf(3, arrow.Null)) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("map with unsupported key returns error", func(t *testing.T) { + _, err := InferGoType(arrow.MapOf(arrow.Null, arrow.BinaryTypes.String)) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("map with unsupported value returns error", func(t *testing.T) { + _, err := InferGoType(arrow.MapOf(arrow.BinaryTypes.String, arrow.Null)) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("map with comparable key builds map type", func(t *testing.T) { + got, err := InferGoType(arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32)) + require.NoError(t, err) + assert.Equal(t, reflect.Map, got.Kind()) + assert.Equal(t, reflect.String, got.Key().Kind()) + assert.Equal(t, reflect.Int32, got.Elem().Kind()) + }) + + t.Run("dictionary unwraps to value type", func(t *testing.T) { + dt := &arrow.DictionaryType{ + IndexType: arrow.PrimitiveTypes.Int32, + ValueType: arrow.BinaryTypes.String, + } + got, err := InferGoType(dt) + require.NoError(t, err) + assert.Equal(t, reflect.String, got.Kind()) + }) + + t.Run("run end encoded unwraps to encoded type", func(t *testing.T) { + dt := arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int64) + got, err := InferGoType(dt) + require.NoError(t, err) + assert.Equal(t, reflect.Int64, got.Kind()) + }) +} + +func TestApplyTemporalOptsAllBranches(t *testing.T) { + timeType := reflect.TypeOf(time.Time{}) + base := arrow.FixedWidthTypes.Timestamp_ns + + t.Run("non-time type returns dt unchanged", func(t *testing.T) { + got := applyTemporalOpts(base, reflect.TypeOf(int32(0)), tagOpts{Temporal: "date32"}) + assert.Equal(t, base, got) + }) + + t.Run("empty temporal returns dt unchanged", func(t *testing.T) { + got := applyTemporalOpts(base, timeType, tagOpts{Temporal: ""}) + assert.Equal(t, base, got) + }) + + t.Run("timestamp returns dt unchanged", func(t *testing.T) { + got := applyTemporalOpts(base, timeType, tagOpts{Temporal: "timestamp"}) + assert.Equal(t, base, got) + }) + + t.Run("date32", func(t *testing.T) { + got := applyTemporalOpts(base, timeType, tagOpts{Temporal: "date32"}) + assert.Equal(t, arrow.DATE32, got.ID()) + }) + + t.Run("date64", func(t *testing.T) { + got := applyTemporalOpts(base, timeType, tagOpts{Temporal: "date64"}) + assert.Equal(t, arrow.DATE64, got.ID()) + }) + + t.Run("time32", func(t *testing.T) { + got := applyTemporalOpts(base, timeType, tagOpts{Temporal: "time32"}) + assert.Equal(t, arrow.TIME32, got.ID()) + }) + + t.Run("time64", func(t *testing.T) { + got := applyTemporalOpts(base, timeType, tagOpts{Temporal: "time64"}) + assert.Equal(t, arrow.TIME64, got.ID()) + }) + + t.Run("unknown temporal falls through", func(t *testing.T) { + got := applyTemporalOpts(base, timeType, tagOpts{Temporal: "bogus"}) + assert.Equal(t, base, got) + }) +} + +func TestApplyDecimalOptsAllBranches(t *testing.T) { + base := arrow.BinaryTypes.String + opts := tagOpts{HasDecimalOpts: true, DecimalPrecision: 18, DecimalScale: 4} + + t.Run("no_decimal_opts_returns_dt_unchanged", func(t *testing.T) { + got := applyDecimalOpts(base, reflect.TypeOf(decimal128.Num{}), tagOpts{}) + assert.Equal(t, base, got) + }) + + t.Run("decimal128", func(t *testing.T) { + got := applyDecimalOpts(base, reflect.TypeOf(decimal128.Num{}), opts) + dt, ok := got.(*arrow.Decimal128Type) + require.True(t, ok, "expected *arrow.Decimal128Type, got %T", got) + assert.Equal(t, int32(18), dt.Precision) + assert.Equal(t, int32(4), dt.Scale) + }) + + t.Run("decimal256", func(t *testing.T) { + got := applyDecimalOpts(base, reflect.TypeOf(decimal256.Num{}), opts) + dt, ok := got.(*arrow.Decimal256Type) + require.True(t, ok, "expected *arrow.Decimal256Type, got %T", got) + assert.Equal(t, int32(18), dt.Precision) + assert.Equal(t, int32(4), dt.Scale) + }) + + t.Run("decimal32", func(t *testing.T) { + got := applyDecimalOpts(base, reflect.TypeOf(decimal.Decimal32(0)), opts) + dt, ok := got.(*arrow.Decimal32Type) + require.True(t, ok, "expected *arrow.Decimal32Type, got %T", got) + assert.Equal(t, int32(18), dt.Precision) + assert.Equal(t, int32(4), dt.Scale) + }) + + t.Run("decimal64", func(t *testing.T) { + got := applyDecimalOpts(base, reflect.TypeOf(decimal.Decimal64(0)), opts) + dt, ok := got.(*arrow.Decimal64Type) + require.True(t, ok, "expected *arrow.Decimal64Type, got %T", got) + assert.Equal(t, int32(18), dt.Precision) + assert.Equal(t, int32(4), dt.Scale) + }) + + t.Run("non_decimal_type_returns_dt_unchanged", func(t *testing.T) { + got := applyDecimalOpts(base, reflect.TypeOf(int32(0)), opts) + assert.Equal(t, base, got) + }) +} + +func TestApplyLargeOpts(t *testing.T) { + cases := []struct { + name string + input arrow.DataType + want arrow.Type + }{ + {"string→large_string", arrow.BinaryTypes.String, arrow.LARGE_STRING}, + {"binary→large_binary", arrow.BinaryTypes.Binary, arrow.LARGE_BINARY}, + {"list→large_list", arrow.ListOf(arrow.BinaryTypes.String), arrow.LARGE_LIST}, + {"list_view→large_list_view", arrow.ListViewOf(arrow.BinaryTypes.Binary), arrow.LARGE_LIST_VIEW}, + {"int64 unchanged", arrow.PrimitiveTypes.Int64, arrow.INT64}, + {"float32 unchanged", arrow.PrimitiveTypes.Float32, arrow.FLOAT32}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := applyLargeOpts(tc.input) + assert.Equal(t, tc.want, got.ID()) + }) + } + + t.Run("list elem is large_binary", func(t *testing.T) { + got := applyLargeOpts(arrow.ListOf(arrow.BinaryTypes.Binary)) + ll, ok := got.(*arrow.LargeListType) + require.True(t, ok) + assert.Equal(t, arrow.LARGE_BINARY, ll.Elem().ID()) + }) + + t.Run("fixed_size_list recurses", func(t *testing.T) { + got := applyLargeOpts(arrow.FixedSizeListOf(3, arrow.BinaryTypes.String)) + fsl, ok := got.(*arrow.FixedSizeListType) + require.True(t, ok) + assert.Equal(t, arrow.LARGE_STRING, fsl.Elem().ID()) + }) + + t.Run("map recurses", func(t *testing.T) { + got := applyLargeOpts(arrow.MapOf(arrow.BinaryTypes.String, arrow.BinaryTypes.Binary)) + mt, ok := got.(*arrow.MapType) + require.True(t, ok) + assert.Equal(t, arrow.LARGE_STRING, mt.KeyType().ID()) + assert.Equal(t, arrow.LARGE_BINARY, mt.ItemField().Type.ID()) + }) + + t.Run("struct recurses into fields", func(t *testing.T) { + st := arrow.StructOf( + arrow.Field{Name: "name", Type: arrow.BinaryTypes.String}, + arrow.Field{Name: "count", Type: arrow.PrimitiveTypes.Int64}, + ) + got := applyLargeOpts(st) + gst, ok := got.(*arrow.StructType) + require.True(t, ok) + assert.Equal(t, arrow.LARGE_STRING, gst.Field(0).Type.ID()) + assert.Equal(t, arrow.INT64, gst.Field(1).Type.ID()) + }) +} + +func TestInferStructTypeWithLarge(t *testing.T) { + type Row struct { + Name string `arrow:",large"` + Count int64 + } + st, err := inferStructType(reflect.TypeOf(Row{})) + require.NoError(t, err) + assert.Equal(t, arrow.LARGE_STRING, st.Field(0).Type.ID(), "Name should be LARGE_STRING") + assert.Equal(t, arrow.INT64, st.Field(1).Type.ID(), "Count should be INT64") +} + +func TestApplyViewOptsViewCombinations(t *testing.T) { + t.Run("view+large: LARGE_LIST→LARGE_LIST_VIEW", func(t *testing.T) { + dt := applyLargeOpts(arrow.ListOf(arrow.BinaryTypes.String)) + // dt is now LARGE_LIST + got := applyViewOpts(dt) + assert.Equal(t, arrow.LARGE_LIST_VIEW, got.ID()) + }) + + t.Run("view only: LIST→LIST_VIEW", func(t *testing.T) { + dt := arrow.ListOf(arrow.BinaryTypes.String) + got := applyViewOpts(dt) + assert.Equal(t, arrow.LIST_VIEW, got.ID()) + lv := got.(*arrow.ListViewType) + assert.Equal(t, arrow.STRING_VIEW, lv.Elem().ID()) + }) +} + +func TestInferStructTypeWithView(t *testing.T) { + type Row struct { + Name string `arrow:",view"` + Tags []string `arrow:"tags,view"` + } + st, err := inferStructType(reflect.TypeOf(Row{})) + require.NoError(t, err) + assert.Equal(t, arrow.STRING_VIEW, st.Field(0).Type.ID(), "Name should be STRING_VIEW") + assert.Equal(t, arrow.LIST_VIEW, st.Field(1).Type.ID(), "Tags should be LIST_VIEW") + lv := st.Field(1).Type.(*arrow.ListViewType) + assert.Equal(t, arrow.STRING_VIEW, lv.Elem().ID()) +} + +func TestHasLargeableType(t *testing.T) { + assert.True(t, hasLargeableType(arrow.BinaryTypes.String)) + assert.True(t, hasLargeableType(arrow.BinaryTypes.Binary)) + assert.True(t, hasLargeableType(arrow.ListOf(arrow.PrimitiveTypes.Int64))) + assert.True(t, hasLargeableType(arrow.ListViewOf(arrow.PrimitiveTypes.Int64))) + assert.False(t, hasLargeableType(arrow.PrimitiveTypes.Int64)) + assert.False(t, hasLargeableType(arrow.PrimitiveTypes.Float32)) + + t.Run("struct with string field is true", func(t *testing.T) { + st := arrow.StructOf(arrow.Field{Name: "x", Type: arrow.BinaryTypes.String}) + assert.True(t, hasLargeableType(st)) + }) + t.Run("struct with only ints is false", func(t *testing.T) { + st := arrow.StructOf(arrow.Field{Name: "x", Type: arrow.PrimitiveTypes.Int32}) + assert.False(t, hasLargeableType(st)) + }) + t.Run("fixed_size_list is true", func(t *testing.T) { + assert.True(t, hasLargeableType(arrow.FixedSizeListOf(4, arrow.BinaryTypes.String))) + }) + t.Run("fixed_size_list is false", func(t *testing.T) { + assert.False(t, hasLargeableType(arrow.FixedSizeListOf(4, arrow.PrimitiveTypes.Int32))) + }) + t.Run("map with string key is true", func(t *testing.T) { + assert.True(t, hasLargeableType(arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int64))) + }) + t.Run("map with no strings is false", func(t *testing.T) { + assert.False(t, hasLargeableType(arrow.MapOf(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Int64))) + }) +} + +func TestApplyViewOpts(t *testing.T) { + cases := []struct { + name string + input arrow.DataType + want arrow.Type + }{ + {"string→string_view", arrow.BinaryTypes.String, arrow.STRING_VIEW}, + {"binary→binary_view", arrow.BinaryTypes.Binary, arrow.BINARY_VIEW}, + {"large_string→string_view", arrow.BinaryTypes.LargeString, arrow.STRING_VIEW}, + {"large_binary→binary_view", arrow.BinaryTypes.LargeBinary, arrow.BINARY_VIEW}, + {"list→list_view", arrow.ListOf(arrow.BinaryTypes.String), arrow.LIST_VIEW}, + {"int64 unchanged", arrow.PrimitiveTypes.Int64, arrow.INT64}, + {"float32 unchanged", arrow.PrimitiveTypes.Float32, arrow.FLOAT32}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := applyViewOpts(tc.input) + assert.Equal(t, tc.want, got.ID()) + }) + } + + t.Run("list elem is string_view", func(t *testing.T) { + got := applyViewOpts(arrow.ListOf(arrow.BinaryTypes.String)) + lv, ok := got.(*arrow.ListViewType) + require.True(t, ok) + assert.Equal(t, arrow.STRING_VIEW, lv.Elem().ID()) + }) + + t.Run("fixed_size_list recurses", func(t *testing.T) { + got := applyViewOpts(arrow.FixedSizeListOf(3, arrow.BinaryTypes.String)) + fsl, ok := got.(*arrow.FixedSizeListType) + require.True(t, ok) + assert.Equal(t, arrow.STRING_VIEW, fsl.Elem().ID()) + }) + + t.Run("map recurses", func(t *testing.T) { + got := applyViewOpts(arrow.MapOf(arrow.BinaryTypes.String, arrow.BinaryTypes.Binary)) + mt, ok := got.(*arrow.MapType) + require.True(t, ok) + assert.Equal(t, arrow.STRING_VIEW, mt.KeyType().ID()) + assert.Equal(t, arrow.BINARY_VIEW, mt.ItemField().Type.ID()) + }) + + t.Run("struct recurses into fields", func(t *testing.T) { + st := arrow.StructOf( + arrow.Field{Name: "name", Type: arrow.BinaryTypes.String}, + arrow.Field{Name: "count", Type: arrow.PrimitiveTypes.Int64}, + ) + got := applyViewOpts(st) + gst, ok := got.(*arrow.StructType) + require.True(t, ok) + assert.Equal(t, arrow.STRING_VIEW, gst.Field(0).Type.ID()) + assert.Equal(t, arrow.INT64, gst.Field(1).Type.ID()) + }) + + t.Run("list_view is idempotent", func(t *testing.T) { + got := applyViewOpts(arrow.ListViewOf(arrow.BinaryTypes.String)) + lv, ok := got.(*arrow.ListViewType) + require.True(t, ok) + assert.Equal(t, arrow.STRING_VIEW, lv.Elem().ID()) + }) +} + +func TestHasViewableType(t *testing.T) { + assert.True(t, hasViewableType(arrow.BinaryTypes.String)) + assert.True(t, hasViewableType(arrow.BinaryTypes.Binary)) + assert.True(t, hasViewableType(arrow.ListOf(arrow.PrimitiveTypes.Int64))) + assert.False(t, hasViewableType(arrow.PrimitiveTypes.Int64)) + assert.False(t, hasViewableType(arrow.PrimitiveTypes.Float32)) + + t.Run("struct with string field is true", func(t *testing.T) { + st := arrow.StructOf(arrow.Field{Name: "x", Type: arrow.BinaryTypes.String}) + assert.True(t, hasViewableType(st)) + }) + t.Run("struct with only ints is false", func(t *testing.T) { + st := arrow.StructOf(arrow.Field{Name: "x", Type: arrow.PrimitiveTypes.Int32}) + assert.False(t, hasViewableType(st)) + }) + t.Run("fixed_size_list is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.FixedSizeListOf(4, arrow.BinaryTypes.String))) + }) + t.Run("map with string key is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int64))) + }) + t.Run("large_string is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.BinaryTypes.LargeString)) + }) + t.Run("large_binary is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.BinaryTypes.LargeBinary)) + }) + t.Run("string_view is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.BinaryTypes.StringView)) + }) + t.Run("binary_view is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.BinaryTypes.BinaryView)) + }) + t.Run("list_view is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.ListViewOf(arrow.PrimitiveTypes.Int64))) + }) + t.Run("large_list is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.LargeListOf(arrow.PrimitiveTypes.Int64))) + }) + t.Run("large_list_view is true", func(t *testing.T) { + assert.True(t, hasViewableType(arrow.LargeListViewOf(arrow.PrimitiveTypes.Int64))) + }) +} diff --git a/arrow/array/arreflect/reflect_integration_test.go b/arrow/array/arreflect/reflect_integration_test.go new file mode 100644 index 00000000..e958dc25 --- /dev/null +++ b/arrow/array/arreflect/reflect_integration_test.go @@ -0,0 +1,438 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type integOrderItem struct { + Product string + Tags map[string]string + Ratings [5]float32 +} + +type integOrder struct { + ID int64 + Items []integOrderItem +} + +type integLargeRow struct { + X int32 + Y float64 +} + +type integNullable struct { + A *string + B *int32 + C *float64 +} + +type integMixed struct { + Required string + Optional *string + Count int32 + MaybeCount *int32 +} + +type integBase struct { + ID int64 +} + +type integExtended struct { + integBase + Name string `arrow:"name"` + Skip string `arrow:"-"` +} + +func TestReflectIntegration(t *testing.T) { + mem := testMem() + + t.Run("complex nested round-trip", func(t *testing.T) { + orders := []integOrder{ + { + ID: 1001, + Items: []integOrderItem{ + {Product: "widget", Tags: map[string]string{"color": "red"}, Ratings: [5]float32{4.5, 3.0, 5.0, 4.0, 3.5}}, + {Product: "gadget", Tags: map[string]string{"size": "large"}, Ratings: [5]float32{1.0, 2.0, 3.0, 4.0, 5.0}}, + }, + }, + { + ID: 1002, + Items: []integOrderItem{ + {Product: "thingamajig", Tags: map[string]string{"material": "steel", "finish": "matte"}, Ratings: [5]float32{5.0, 5.0, 5.0, 5.0, 5.0}}, + }, + }, + { + ID: 1003, + Items: nil, + }, + { + ID: 1004, + Items: []integOrderItem{ + {Product: "doohickey", Tags: map[string]string{"brand": "acme"}, Ratings: [5]float32{2.5, 3.5, 4.5, 1.5, 0.5}}, + {Product: "whatchamacallit", Tags: map[string]string{"type": "premium"}, Ratings: [5]float32{3.0, 3.0, 3.0, 3.0, 3.0}}, + {Product: "thingy", Tags: map[string]string{"category": "misc"}, Ratings: [5]float32{1.0, 1.0, 1.0, 1.0, 1.0}}, + }, + }, + { + ID: 1005, + Items: []integOrderItem{ + {Product: "sprocket", Tags: map[string]string{"grade": "A"}, Ratings: [5]float32{4.0, 4.0, 4.0, 4.0, 4.0}}, + }, + }, + } + + arr, err := FromSlice(orders, mem) + require.NoError(t, err, "FromSlice") + defer arr.Release() + + output, err := ToSlice[integOrder](arr) + require.NoError(t, err, "ToSlice") + + require.Len(t, output, len(orders)) + + for i, want := range orders { + got := output[i] + assert.Equal(t, want.ID, got.ID, "[%d] ID", i) + if assert.Len(t, got.Items, len(want.Items), "[%d] Items length", i) { + for j, wantItem := range want.Items { + gotItem := got.Items[j] + assert.Equal(t, wantItem.Product, gotItem.Product, "[%d][%d] Product", i, j) + assert.Equal(t, wantItem.Ratings, gotItem.Ratings, "[%d][%d] Ratings", i, j) + assert.Equal(t, wantItem.Tags, gotItem.Tags, "[%d][%d] Tags", i, j) + } + } + } + }) + + t.Run("large array round-trip", func(t *testing.T) { + const n = 10000 + rows := make([]integLargeRow, n) + for i := range rows { + rows[i] = integLargeRow{X: int32(i), Y: float64(i) * 1.5} + } + + arr, err := FromSlice(rows, mem) + require.NoError(t, err, "FromSlice") + defer arr.Release() + + require.Equal(t, n, arr.Len()) + + output, err := ToSlice[integLargeRow](arr) + require.NoError(t, err, "ToSlice") + + require.Len(t, output, n) + for i, want := range rows { + assert.Equal(t, want.X, output[i].X, "[%d] X", i) + assert.Equal(t, want.Y, output[i].Y, "[%d] Y", i) + } + }) + + t.Run("all-null fields", func(t *testing.T) { + rows := []integNullable{ + {A: nil, B: nil, C: nil}, + {A: nil, B: nil, C: nil}, + {A: nil, B: nil, C: nil}, + } + + arr, err := FromSlice(rows, mem) + require.NoError(t, err, "FromSlice") + defer arr.Release() + + output, err := ToSlice[integNullable](arr) + require.NoError(t, err, "ToSlice") + + require.Len(t, output, 3) + for i, got := range output { + assert.Nil(t, got.A, "[%d] A: expected nil", i) + assert.Nil(t, got.B, "[%d] B: expected nil", i) + assert.Nil(t, got.C, "[%d] C: expected nil", i) + } + }) + + t.Run("empty int32 slice", func(t *testing.T) { + arr, err := FromSlice[int32]([]int32{}, mem) + require.NoError(t, err, "FromSlice") + defer arr.Release() + + assert.Equal(t, 0, arr.Len()) + + output, err := ToSlice[int32](arr) + require.NoError(t, err, "ToSlice") + assert.NotNil(t, output, "ToSlice returned nil, want non-nil empty slice") + assert.Len(t, output, 0) + }) + + t.Run("empty struct slice", func(t *testing.T) { + type simpleXY struct{ X int32 } + arr, err := FromSlice[simpleXY]([]simpleXY{}, mem) + require.NoError(t, err, "FromSlice empty struct") + defer arr.Release() + + assert.Equal(t, 0, arr.Len()) + assert.Equal(t, arrow.STRUCT, arr.DataType().ID()) + }) + + t.Run("mixed nullability round-trip", func(t *testing.T) { + s1 := "hello" + s2 := "world" + c1 := int32(42) + c3 := int32(99) + + rows := []integMixed{ + {Required: "first", Optional: &s1, Count: 10, MaybeCount: &c1}, + {Required: "second", Optional: nil, Count: 20, MaybeCount: nil}, + {Required: "third", Optional: &s2, Count: 30, MaybeCount: &c3}, + {Required: "fourth", Optional: nil, Count: 40, MaybeCount: nil}, + } + + arr, err := FromSlice(rows, mem) + require.NoError(t, err, "FromSlice") + defer arr.Release() + + output, err := ToSlice[integMixed](arr) + require.NoError(t, err, "ToSlice") + + require.Len(t, output, len(rows)) + + for i, want := range rows { + got := output[i] + assert.Equal(t, want.Required, got.Required, "[%d] Required", i) + assert.Equal(t, want.Count, got.Count, "[%d] Count", i) + if assert.Equal(t, want.Optional == nil, got.Optional == nil, "[%d] Optional nil mismatch", i) { + if got.Optional != nil { + assert.Equal(t, *want.Optional, *got.Optional, "[%d] Optional value", i) + } + } + if assert.Equal(t, want.MaybeCount == nil, got.MaybeCount == nil, "[%d] MaybeCount nil mismatch", i) { + if got.MaybeCount != nil { + assert.Equal(t, *want.MaybeCount, *got.MaybeCount, "[%d] MaybeCount value", i) + } + } + } + }) + + t.Run("embedded struct with tags", func(t *testing.T) { + rows := []integExtended{ + {integBase: integBase{ID: 1}, Name: "alice"}, + {integBase: integBase{ID: 2}, Name: "bob"}, + {integBase: integBase{ID: 3}, Name: "carol"}, + } + + arr, err := FromSlice(rows, mem) + require.NoError(t, err, "FromSlice") + defer arr.Release() + + st, ok := arr.DataType().(*arrow.StructType) + require.True(t, ok, "expected StructType, got %T", arr.DataType()) + + var hasID, hasName, hasSkip bool + for i := 0; i < st.NumFields(); i++ { + switch st.Field(i).Name { + case "ID": + hasID = true + case "name": + hasName = true + case "Skip": + hasSkip = true + } + } + assert.True(t, hasID, "expected field 'ID' in schema") + assert.True(t, hasName, "expected field 'name' in schema") + assert.False(t, hasSkip, "unexpected field 'Skip' in schema (should be skipped by arrow:\"-\" tag)") + + output, err := ToSlice[integExtended](arr) + require.NoError(t, err, "ToSlice") + + require.Len(t, output, len(rows)) + for i, want := range rows { + got := output[i] + assert.Equal(t, want.ID, got.ID, "[%d] ID", i) + assert.Equal(t, want.Name, got.Name, "[%d] Name", i) + assert.Equal(t, "", got.Skip, "[%d] Skip: expected empty string", i) + } + }) + + t.Run("schema consistency", func(t *testing.T) { + orders := []integOrder{ + {ID: 1, Items: []integOrderItem{{Product: "a", Tags: map[string]string{"k": "v"}, Ratings: [5]float32{1, 2, 3, 4, 5}}}}, + } + + schema, err := InferSchema[integOrder]() + require.NoError(t, err, "SchemaOf") + + arr, err := FromSlice(orders, mem) + require.NoError(t, err, "FromSlice") + defer arr.Release() + + st, ok := arr.DataType().(*arrow.StructType) + require.True(t, ok, "expected StructType, got %T", arr.DataType()) + + require.Equal(t, schema.NumFields(), st.NumFields()) + + for i := 0; i < schema.NumFields(); i++ { + schemaField := schema.Field(i) + structField := st.Field(i) + assert.Equal(t, schemaField.Name, structField.Name, "field[%d] name", i) + } + }) + + t.Run("cache reuse without corruption", func(t *testing.T) { + batch1 := make([]integLargeRow, 3) + for i := range batch1 { + batch1[i] = integLargeRow{X: int32(i + 1), Y: float64(i+1) * 2.0} + } + + arr1, err := FromSlice(batch1, mem) + require.NoError(t, err, "FromSlice batch1") + defer arr1.Release() + + batch2 := make([]integLargeRow, 5) + for i := range batch2 { + batch2[i] = integLargeRow{X: int32(i * 10), Y: float64(i) * 3.14} + } + + arr2, err := FromSlice(batch2, mem) + require.NoError(t, err, "FromSlice batch2") + defer arr2.Release() + + out1, err := ToSlice[integLargeRow](arr1) + require.NoError(t, err, "ToSlice batch1") + out2, err := ToSlice[integLargeRow](arr2) + require.NoError(t, err, "ToSlice batch2") + + require.Len(t, out1, len(batch1)) + require.Len(t, out2, len(batch2)) + + for i, want := range batch1 { + assert.Equal(t, want, out1[i], "batch1[%d]", i) + } + for i, want := range batch2 { + assert.Equal(t, want, out2[i], "batch2[%d]", i) + } + }) + + t.Run("record batch round-trip", func(t *testing.T) { + rows := []integLargeRow{ + {X: 10, Y: 1.1}, + {X: 20, Y: 2.2}, + {X: 30, Y: 3.3}, + {X: 40, Y: 4.4}, + {X: 50, Y: 5.5}, + } + + rec, err := RecordFromSlice(rows, mem) + require.NoError(t, err, "RecordFromSlice") + defer rec.Release() + + require.Equal(t, int64(len(rows)), rec.NumRows()) + + output, err := RecordToSlice[integLargeRow](rec) + require.NoError(t, err, "RecordToSlice") + + require.Len(t, output, len(rows)) + assert.Equal(t, rows, output) + }) + + t.Run("listview_struct_field_roundtrip", func(t *testing.T) { + type Row struct { + Name string `arrow:"name"` + Tags []string `arrow:"tags,view"` + } + rows := []Row{ + {"alice", []string{"admin", "user"}}, + {"bob", []string{"guest"}}, + } + arr, err := FromSlice(rows, nil) + require.NoError(t, err) + defer arr.Release() + + sa := arr.(*array.Struct) + require.Equal(t, arrow.LIST_VIEW, sa.Field(1).DataType().ID()) + + output, err := ToSlice[Row](arr) + require.NoError(t, err) + assert.Equal(t, rows, output) + }) + + t.Run("duration_struct_field_roundtrip", func(t *testing.T) { + type Row struct { + Name string `arrow:"name"` + Elapsed time.Duration `arrow:"elapsed"` + } + rows := []Row{ + {"fast", 100 * time.Millisecond}, + {"slow", 5 * time.Second}, + } + arr, err := FromSlice(rows, nil) + require.NoError(t, err) + defer arr.Release() + + sa := arr.(*array.Struct) + assert.Equal(t, arrow.DURATION, sa.Field(1).DataType().ID()) + + output, err := ToSlice[Row](arr) + require.NoError(t, err) + assert.Equal(t, rows, output) + }) +} + +func BenchmarkReflectFromGoSlice(b *testing.B) { + mem := testMem() + rows := make([]integLargeRow, 1000) + for i := range rows { + rows[i] = integLargeRow{X: int32(i), Y: float64(i) * 1.5} + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + arr, err := FromSlice(rows, mem) + if err != nil { + b.Fatal(err) + } + arr.Release() + } +} + +func BenchmarkReflectToGoSlice(b *testing.B) { + mem := testMem() + rows := make([]integLargeRow, 1000) + for i := range rows { + rows[i] = integLargeRow{X: int32(i), Y: float64(i) * 1.5} + } + + arr, err := FromSlice(rows, mem) + if err != nil { + b.Fatal(err) + } + defer arr.Release() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out, err := ToSlice[integLargeRow](arr) + if err != nil { + b.Fatal(err) + } + _ = out + } +} diff --git a/arrow/array/arreflect/reflect_public_test.go b/arrow/array/arreflect/reflect_public_test.go new file mode 100644 index 00000000..edb3adf6 --- /dev/null +++ b/arrow/array/arreflect/reflect_public_test.go @@ -0,0 +1,820 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "reflect" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testMem() memory.Allocator { return memory.NewGoAllocator() } + +func fieldValueByTag(v reflect.Value, tag string) reflect.Value { + for i := 0; i < v.NumField(); i++ { + if v.Type().Field(i).Tag.Get("arrow") == tag { + return v.Field(i) + } + } + return reflect.Value{} +} + +func TestToGo(t *testing.T) { + mem := testMem() + + t.Run("int32 element 0", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.AppendValues([]int32{10, 20, 30}, nil) + arr := b.NewInt32Array() + defer arr.Release() + + got, err := At[int32](arr, 0) + require.NoError(t, err) + assert.Equal(t, int32(10), got) + }) + + t.Run("string element 1", func(t *testing.T) { + b := array.NewStringBuilder(mem) + defer b.Release() + b.AppendValues([]string{"hello", "world"}, nil) + arr := b.NewStringArray() + defer arr.Release() + + got, err := At[string](arr, 1) + require.NoError(t, err) + assert.Equal(t, "world", got) + }) + + t.Run("struct element 0", func(t *testing.T) { + type Person struct { + Name string + Age int32 + } + vals := []Person{{"Alice", 30}, {"Bob", 25}} + arr, err := FromSlice(vals, mem) + require.NoError(t, err) + defer arr.Release() + + got, err := At[Person](arr, 0) + require.NoError(t, err) + assert.Equal(t, "Alice", got.Name) + assert.Equal(t, int32(30), got.Age) + }) + + t.Run("null element to *int32 is nil", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.AppendNull() + arr := b.NewInt32Array() + defer arr.Release() + + got, err := At[*int32](arr, 0) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("null element to int32 is zero", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.AppendNull() + arr := b.NewInt32Array() + defer arr.Release() + + got, err := At[int32](arr, 0) + require.NoError(t, err) + assert.Equal(t, int32(0), got) + }) +} + +func TestToGoSlice(t *testing.T) { + mem := testMem() + + t.Run("[]int32", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.AppendValues([]int32{1, 2, 3}, nil) + arr := b.NewInt32Array() + defer arr.Release() + + got, err := ToSlice[int32](arr) + require.NoError(t, err) + want := []int32{1, 2, 3} + require.Len(t, got, len(want)) + for i, v := range want { + assert.Equal(t, v, got[i], "index %d", i) + } + }) + + t.Run("[]string", func(t *testing.T) { + b := array.NewStringBuilder(mem) + defer b.Release() + b.AppendValues([]string{"foo", "bar", "baz"}, nil) + arr := b.NewStringArray() + defer arr.Release() + + got, err := ToSlice[string](arr) + require.NoError(t, err) + want := []string{"foo", "bar", "baz"} + require.Len(t, got, len(want)) + for i, v := range want { + assert.Equal(t, v, got[i], "index %d", i) + } + }) + + t.Run("[]struct{Name string}", func(t *testing.T) { + type Row struct { + Name string + } + vals := []Row{{"Alice"}, {"Bob"}, {"Charlie"}} + arr, err := FromSlice(vals, mem) + require.NoError(t, err) + defer arr.Release() + + got, err := ToSlice[Row](arr) + require.NoError(t, err) + require.Len(t, got, len(vals)) + for i, want := range vals { + assert.Equal(t, want.Name, got[i].Name, "index %d", i) + } + }) + + t.Run("empty array gives empty slice", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + arr := b.NewInt32Array() + defer arr.Release() + + got, err := ToSlice[int32](arr) + require.NoError(t, err) + assert.NotNil(t, got, "expected non-nil empty slice, got nil") + assert.Len(t, got, 0) + }) +} + +func TestFromGoSlice(t *testing.T) { + mem := testMem() + + t.Run("[]int32", func(t *testing.T) { + arr, err := FromSlice([]int32{1, 2, 3}, mem) + require.NoError(t, err) + defer arr.Release() + + require.Equal(t, 3, arr.Len()) + typed := arr.(*array.Int32) + for i, want := range []int32{1, 2, 3} { + assert.Equal(t, want, typed.Value(i), "index %d", i) + } + }) + + t.Run("[]string", func(t *testing.T) { + arr, err := FromSlice([]string{"a", "b"}, mem) + require.NoError(t, err) + defer arr.Release() + + require.Equal(t, 2, arr.Len()) + typed := arr.(*array.String) + assert.Equal(t, "a", typed.Value(0)) + assert.Equal(t, "b", typed.Value(1)) + }) + + t.Run("[]struct{Name string; Score float64}", func(t *testing.T) { + type Row struct { + Name string + Score float64 + } + vals := []Row{{"Alice", 9.5}, {"Bob", 8.0}} + arr, err := FromSlice(vals, mem) + require.NoError(t, err) + defer arr.Release() + + require.Equal(t, 2, arr.Len()) + got, err := ToSlice[Row](arr) + require.NoError(t, err) + for i, want := range vals { + assert.Equal(t, want.Name, got[i].Name, "index %d Name", i) + assert.Equal(t, want.Score, got[i].Score, "index %d Score", i) + } + }) + + t.Run("[]*int32 with nil produces null", func(t *testing.T) { + v := int32(42) + arr, err := FromSlice([]*int32{&v, nil}, mem) + require.NoError(t, err) + defer arr.Release() + + require.Equal(t, 2, arr.Len()) + assert.True(t, arr.IsNull(1), "expected index 1 to be null") + typed := arr.(*array.Int32) + assert.Equal(t, int32(42), typed.Value(0)) + }) + + t.Run("empty []int32 gives length-0 array", func(t *testing.T) { + arr, err := FromSlice([]int32{}, mem) + require.NoError(t, err) + defer arr.Release() + + assert.Equal(t, 0, arr.Len()) + }) + + t.Run("empty slice with WithView", func(t *testing.T) { + arr, err := FromSlice([][]int32{}, mem, WithView()) + require.NoError(t, err) + defer arr.Release() + + assert.Equal(t, arrow.LIST_VIEW, arr.DataType().ID()) + assert.Equal(t, arrow.INT32, arr.DataType().(*arrow.ListViewType).Elem().ID()) + }) + + t.Run("empty slice with WithREE", func(t *testing.T) { + arr, err := FromSlice([]int32{}, mem, WithREE()) + require.NoError(t, err) + defer arr.Release() + + assert.Equal(t, arrow.RUN_END_ENCODED, arr.DataType().ID()) + }) + + t.Run("WithTemporal invalid value returns error", func(t *testing.T) { + _, err := FromSlice([]time.Time{}, mem, WithTemporal("invalid")) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("WithTemporal on non-time type returns error", func(t *testing.T) { + _, err := FromSlice([]string{}, mem, WithTemporal("date32")) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("WithTemporal timestamp on non-time type returns error", func(t *testing.T) { + _, err := FromSlice([]string{}, mem, WithTemporal("timestamp")) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("struct field with malformed decimal tag returns error", func(t *testing.T) { + type BadDecimal struct { + Amount decimal128.Num `arrow:",decimal(18,two)"` + } + _, err := FromSlice([]BadDecimal{}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("conflicting options return error", func(t *testing.T) { + cases := []struct { + name string + opts []Option + }{ + {"WithDict+WithREE", []Option{WithDict(), WithREE()}}, + {"WithDict+WithView", []Option{WithDict(), WithView()}}, + {"WithREE+WithView", []Option{WithREE(), WithView()}}, + {"all three", []Option{WithDict(), WithREE(), WithView()}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := FromSlice([]int32{1, 2, 3}, mem, tc.opts...) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + } + }) +} + +func TestRecordToSlice(t *testing.T) { + mem := testMem() + + type Row struct { + Name string + Score float64 + } + + buildRecord := func(rows []Row) arrow.RecordBatch { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "Name", Type: arrow.BinaryTypes.String}, + {Name: "Score", Type: arrow.PrimitiveTypes.Float64}, + }, nil) + nameB := array.NewStringBuilder(mem) + defer nameB.Release() + scoreB := array.NewFloat64Builder(mem) + defer scoreB.Release() + for _, r := range rows { + nameB.Append(r.Name) + scoreB.Append(r.Score) + } + nameArr := nameB.NewStringArray() + defer nameArr.Release() + scoreArr := scoreB.NewFloat64Array() + defer scoreArr.Release() + return array.NewRecordBatch(schema, []arrow.Array{nameArr, scoreArr}, int64(len(rows))) + } + + t.Run("basic 3-row record", func(t *testing.T) { + want := []Row{{"Alice", 9.5}, {"Bob", 8.0}, {"Carol", 7.5}} + rec := buildRecord(want) + defer rec.Release() + + got, err := RecordToSlice[Row](rec) + require.NoError(t, err) + require.Len(t, got, len(want)) + for i, w := range want { + assert.Equal(t, w.Name, got[i].Name, "index %d Name", i) + assert.Equal(t, w.Score, got[i].Score, "index %d Score", i) + } + }) + + t.Run("empty record gives empty slice", func(t *testing.T) { + rec := buildRecord(nil) + defer rec.Release() + + got, err := RecordToSlice[Row](rec) + require.NoError(t, err) + assert.Len(t, got, 0) + }) +} + +func TestRecordFromSlice(t *testing.T) { + mem := testMem() + + type Row struct { + Name string + Score float64 + } + + t.Run("struct slice produces correct schema and values", func(t *testing.T) { + vals := []Row{{"Alice", 9.5}, {"Bob", 8.0}} + rec, err := RecordFromSlice(vals, mem) + require.NoError(t, err) + defer rec.Release() + + require.Equal(t, int64(2), rec.NumCols()) + require.Equal(t, int64(2), rec.NumRows()) + assert.Equal(t, "Name", rec.Schema().Field(0).Name) + assert.Equal(t, "Score", rec.Schema().Field(1).Name) + nameCol := rec.Column(0).(*array.String) + assert.Equal(t, "Alice", nameCol.Value(0)) + assert.Equal(t, "Bob", nameCol.Value(1)) + scoreCol := rec.Column(1).(*array.Float64) + assert.Equal(t, 9.5, scoreCol.Value(0)) + assert.Equal(t, 8.0, scoreCol.Value(1)) + }) + + t.Run("non-struct T returns error", func(t *testing.T) { + _, err := RecordFromSlice([]int32{1, 2, 3}, mem) + require.Error(t, err) + }) + + t.Run("round-trip RecordFromSlice then RecordToSlice", func(t *testing.T) { + want := []Row{{"Alice", 9.5}, {"Bob", 8.0}, {"Carol", 7.5}} + rec, err := RecordFromSlice(want, mem) + require.NoError(t, err) + defer rec.Release() + + got, err := RecordToSlice[Row](rec) + require.NoError(t, err) + require.Len(t, got, len(want)) + for i, w := range want { + assert.Equal(t, w.Name, got[i].Name, "index %d Name", i) + assert.Equal(t, w.Score, got[i].Score, "index %d Score", i) + } + }) +} + +func TestAtAny(t *testing.T) { + mem := testMem() + b := array.NewInt32Builder(mem) + defer b.Release() + b.Append(42) + b.AppendNull() + arr := b.NewArray() + defer arr.Release() + + got, err := AtAny(arr, 0) + require.NoError(t, err, "AtAny(0)") + v, ok := got.(int32) + assert.True(t, ok, "AtAny(0): expected int32 type, got %T", got) + assert.Equal(t, int32(42), v, "AtAny(0) value") + + got, err = AtAny(arr, 1) + require.NoError(t, err, "AtAny(1)") + v, ok = got.(int32) + assert.True(t, ok, "AtAny(1): expected int32 type, got %T", got) + assert.Equal(t, int32(0), v, "AtAny(1) value") +} + +func TestAtAnyErrors(t *testing.T) { + arr := array.NewNull(1) + defer arr.Release() + + _, err := AtAny(arr, 0) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) +} + +func TestToAnySlice(t *testing.T) { + mem := testMem() + b := array.NewStringBuilder(mem) + defer b.Release() + b.Append("hello") + b.Append("world") + arr := b.NewArray() + defer arr.Release() + + got, err := ToAnySlice(arr) + require.NoError(t, err, "ToAnySlice") + require.Len(t, got, 2) + assert.Equal(t, "hello", got[0].(string)) + assert.Equal(t, "world", got[1].(string)) +} + +func TestErrSentinels(t *testing.T) { + mem := testMem() + + t.Run("ErrTypeMismatch via setValue wrong kind", func(t *testing.T) { + b := array.NewInt32Builder(mem) + defer b.Release() + b.Append(42) + arr := b.NewArray() + defer arr.Release() + + var got string + v := reflect.ValueOf(&got).Elem() + err := setValue(v, arr, 0) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) + + t.Run("ErrUnsupportedType via InferGoType", func(t *testing.T) { + _, err := InferGoType(arrow.Null) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("ErrTypeMismatch propagates through struct field context wrapper", func(t *testing.T) { + st := arrow.StructOf(arrow.Field{Name: "name", Type: arrow.BinaryTypes.String}) + sb := array.NewStructBuilder(mem, st) + defer sb.Release() + sb.Append(true) + sb.FieldBuilder(0).(*array.StringBuilder).Append("hello") + arr := sb.NewArray() + defer arr.Release() + + type wrongType struct { + Name int32 `arrow:"name"` + } + _, err := At[wrongType](arr, 0) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTypeMismatch) + }) +} + +func TestRecordAt(t *testing.T) { + mem := testMem() + type Row struct { + Name string `arrow:"name"` + Score float64 `arrow:"score"` + } + rows := []Row{{"alice", 9.5}, {"bob", 7.0}} + rec, err := RecordFromSlice(rows, mem) + require.NoError(t, err, "RecordFromSlice") + defer rec.Release() + + got, err := RecordAt[Row](rec, 0) + require.NoError(t, err, "RecordAt(0)") + assert.Equal(t, rows[0], got) + + got, err = RecordAt[Row](rec, 1) + require.NoError(t, err, "RecordAt(1)") + assert.Equal(t, rows[1], got) +} + +func TestRecordAtAny(t *testing.T) { + mem := testMem() + type Row struct { + Name string `arrow:"name"` + Score float64 `arrow:"score"` + } + rows := []Row{{"alice", 9.5}, {"bob", 7.0}} + rec, err := RecordFromSlice(rows, mem) + require.NoError(t, err, "RecordFromSlice") + defer rec.Release() + + got, err := RecordAtAny(rec, 0) + require.NoError(t, err, "RecordAtAny(0)") + v := reflect.ValueOf(got) + require.Equal(t, reflect.Struct, v.Kind()) + nameField := fieldValueByTag(v, "name") + scoreField := fieldValueByTag(v, "score") + require.True(t, nameField.IsValid(), "name field not found") + require.True(t, scoreField.IsValid(), "score field not found") + assert.Equal(t, "alice", nameField.String()) + assert.Equal(t, 9.5, scoreField.Float()) +} + +func TestRecordToAnySlice(t *testing.T) { + mem := testMem() + type Row struct { + Name string `arrow:"name"` + Score float64 `arrow:"score"` + } + rows := []Row{{"alice", 9.5}, {"bob", 7.0}} + rec, err := RecordFromSlice(rows, mem) + require.NoError(t, err, "RecordFromSlice") + defer rec.Release() + + got, err := RecordToAnySlice(rec) + require.NoError(t, err, "RecordToAnySlice") + require.Len(t, got, 2) + for i, row := range got { + v := reflect.ValueOf(row) + require.Equal(t, reflect.Struct, v.Kind(), "row %d", i) + nameField := fieldValueByTag(v, "name") + assert.Equal(t, rows[i].Name, nameField.String(), "row %d name", i) + } +} + +func TestAtAnyComposite(t *testing.T) { + mem := testMem() + + t.Run("struct", func(t *testing.T) { + st := arrow.StructOf( + arrow.Field{Name: "id", Type: arrow.PrimitiveTypes.Int32}, + arrow.Field{Name: "name", Type: arrow.BinaryTypes.String}, + ) + sb := array.NewStructBuilder(mem, st) + defer sb.Release() + sb.Append(true) + sb.FieldBuilder(0).(*array.Int32Builder).Append(99) + sb.FieldBuilder(1).(*array.StringBuilder).Append("alice") + arr := sb.NewArray() + defer arr.Release() + + got, err := AtAny(arr, 0) + require.NoError(t, err, "AtAny") + + v := reflect.ValueOf(got) + require.Equal(t, reflect.Struct, v.Kind()) + + idField := fieldValueByTag(v, "id") + nameField := fieldValueByTag(v, "name") + require.True(t, idField.IsValid(), "id field not found") + require.True(t, nameField.IsValid(), "name field not found") + assert.Equal(t, int64(99), idField.Int()) + assert.Equal(t, "alice", nameField.String()) + }) + + t.Run("list", func(t *testing.T) { + lb := array.NewListBuilder(mem, arrow.PrimitiveTypes.Int32) + defer lb.Release() + lb.Append(true) + lb.ValueBuilder().(*array.Int32Builder).Append(1) + lb.ValueBuilder().(*array.Int32Builder).Append(2) + lb.ValueBuilder().(*array.Int32Builder).Append(3) + arr := lb.NewArray() + defer arr.Release() + + got, err := AtAny(arr, 0) + require.NoError(t, err, "AtAny") + + v := reflect.ValueOf(got) + require.Equal(t, reflect.Slice, v.Kind()) + require.Equal(t, 3, v.Len()) + assert.Equal(t, int64(1), v.Index(0).Int()) + assert.Equal(t, int64(3), v.Index(2).Int()) + }) + + t.Run("map", func(t *testing.T) { + mb := array.NewMapBuilder(mem, arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32, false) + defer mb.Release() + mb.Append(true) + mb.KeyBuilder().(*array.StringBuilder).Append("x") + mb.ItemBuilder().(*array.Int32Builder).Append(7) + arr := mb.NewArray() + defer arr.Release() + + got, err := AtAny(arr, 0) + require.NoError(t, err, "AtAny") + + v := reflect.ValueOf(got) + require.Equal(t, reflect.Map, v.Kind()) + key := reflect.ValueOf("x") + val := v.MapIndex(key) + require.True(t, val.IsValid(), "key 'x' not found in map") + assert.Equal(t, int64(7), val.Int()) + }) +} + +func TestToAnySliceStructArray(t *testing.T) { + mem := testMem() + st := arrow.StructOf( + arrow.Field{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: false}, + arrow.Field{Name: "label", Type: arrow.BinaryTypes.String, Nullable: false}, + arrow.Field{Name: "score", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + ) + sb := array.NewStructBuilder(mem, st) + defer sb.Release() + + sb.Append(true) + sb.FieldBuilder(0).(*array.Int64Builder).Append(1) + sb.FieldBuilder(1).(*array.StringBuilder).Append("alpha") + sb.FieldBuilder(2).(*array.Float64Builder).Append(9.5) + + sb.Append(true) + sb.FieldBuilder(0).(*array.Int64Builder).Append(2) + sb.FieldBuilder(1).(*array.StringBuilder).Append("beta") + sb.FieldBuilder(2).(*array.Float64Builder).Append(3.14) + + sb.Append(true) + sb.FieldBuilder(0).(*array.Int64Builder).Append(3) + sb.FieldBuilder(1).(*array.StringBuilder).Append("gamma") + sb.FieldBuilder(2).(*array.Float64Builder).AppendNull() + + arr := sb.NewArray() + defer arr.Release() + + got, err := ToAnySlice(arr) + require.NoError(t, err, "ToAnySlice") + require.Len(t, got, 3) + + type expected struct { + id int64 + label string + score float64 + } + want := []expected{ + {1, "alpha", 9.5}, + {2, "beta", 3.14}, + {3, "gamma", 0}, + } + + for i, row := range got { + v := reflect.ValueOf(row) + require.Equal(t, reflect.Struct, v.Kind(), "row %d", i) + require.Equal(t, 3, v.NumField(), "row %d", i) + + id := fieldValueByTag(v, "id") + label := fieldValueByTag(v, "label") + score := fieldValueByTag(v, "score") + require.True(t, id.IsValid(), "row %d: id field not found", i) + require.True(t, label.IsValid(), "row %d: label field not found", i) + require.True(t, score.IsValid(), "row %d: score field not found", i) + assert.Equal(t, want[i].id, id.Int(), "row %d id", i) + assert.Equal(t, want[i].label, label.String(), "row %d label", i) + if score.Kind() == reflect.Ptr { + if i == 2 { + assert.True(t, score.IsNil(), "row 2 score: want nil") + } else { + if assert.False(t, score.IsNil(), "row %d score: unexpected nil", i) { + assert.Equal(t, want[i].score, score.Elem().Float(), "row %d score", i) + } + } + } else { + assert.Equal(t, want[i].score, score.Float(), "row %d score", i) + } + } +} + +func TestWithLargeRoundTrip(t *testing.T) { + mem := testMem() + + t.Run("[]string WithLarge round-trips via ToSlice", func(t *testing.T) { + input := []string{"alpha", "beta", "gamma"} + arr, err := FromSlice(input, mem, WithLarge()) + require.NoError(t, err) + defer arr.Release() + + assert.Equal(t, arrow.LARGE_STRING, arr.DataType().ID()) + + got, err := ToSlice[string](arr) + require.NoError(t, err) + assert.Equal(t, input, got) + }) + + t.Run("struct with large tag round-trips", func(t *testing.T) { + type Row struct { + Label string `arrow:"label,large"` + Count int32 `arrow:"count"` + } + input := []Row{{"a", 1}, {"b", 2}} + arr, err := FromSlice(input, mem) + require.NoError(t, err) + defer arr.Release() + + sa := arr.(*array.Struct) + assert.Equal(t, arrow.LARGE_STRING, sa.Field(0).DataType().ID()) + assert.Equal(t, arrow.INT32, sa.Field(1).DataType().ID()) + + got, err := ToSlice[Row](arr) + require.NoError(t, err) + assert.Equal(t, input, got) + }) +} + +func TestUnknownTagOptionError(t *testing.T) { + type Bad struct { + Name string `arrow:"name,unknown_option"` + } + mem := testMem() + + t.Run("FromSlice surfaces ErrUnsupportedType for unknown tag", func(t *testing.T) { + _, err := FromSlice([]Bad{{"x"}}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("InferSchema surfaces ErrUnsupportedType for unknown tag", func(t *testing.T) { + _, err := InferSchema[Bad]() + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestWithViewRoundTrip(t *testing.T) { + mem := testMem() + + t.Run("[]string WithView round-trips via ToSlice", func(t *testing.T) { + input := []string{"alpha", "beta", "gamma"} + arr, err := FromSlice(input, mem, WithView()) + require.NoError(t, err) + defer arr.Release() + + assert.Equal(t, arrow.STRING_VIEW, arr.DataType().ID()) + + got, err := ToSlice[string](arr) + require.NoError(t, err) + assert.Equal(t, input, got) + }) + + t.Run("[][]string WithView produces LIST_VIEW", func(t *testing.T) { + input := [][]string{{"a", "b"}, {"c"}} + arr, err := FromSlice(input, mem, WithView()) + require.NoError(t, err) + defer arr.Release() + + assert.Equal(t, arrow.LIST_VIEW, arr.DataType().ID()) + lv := arr.DataType().(*arrow.ListViewType) + assert.Equal(t, arrow.STRING_VIEW, lv.Elem().ID()) + }) + + t.Run("struct with view tag round-trips", func(t *testing.T) { + type Row struct { + Label string `arrow:"label,view"` + Count int32 `arrow:"count"` + } + input := []Row{{"a", 1}, {"b", 2}} + arr, err := FromSlice(input, mem) + require.NoError(t, err) + defer arr.Release() + + sa := arr.(*array.Struct) + assert.Equal(t, arrow.STRING_VIEW, sa.Field(0).DataType().ID()) + assert.Equal(t, arrow.INT32, sa.Field(1).DataType().ID()) + + got, err := ToSlice[Row](arr) + require.NoError(t, err) + assert.Equal(t, input, got) + }) + + t.Run("WithView on int64 errors", func(t *testing.T) { + _, err := FromSlice([]int64{1}, mem, WithView()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("struct with view-tagged fields via WithView is idempotent", func(t *testing.T) { + // Fields already tagged ,view infer to STRING_VIEW; WithView() should still + // accept the struct and the top-level applyViewOpts walk is a no-op on views. + type Row struct { + Name string `arrow:"name,view"` + Code int32 `arrow:"code"` + } + input := []Row{{"click", 1}, {"view", 2}} + arr, err := FromSlice(input, mem, WithView()) + require.NoError(t, err) + defer arr.Release() + sa := arr.(*array.Struct) + assert.Equal(t, arrow.STRING_VIEW, sa.Field(0).DataType().ID()) + + got, err := ToSlice[Row](arr) + require.NoError(t, err) + assert.Equal(t, input, got) + }) +} diff --git a/arrow/array/arreflect/reflect_test.go b/arrow/array/arreflect/reflect_test.go new file mode 100644 index 00000000..b0ad487c --- /dev/null +++ b/arrow/array/arreflect/reflect_test.go @@ -0,0 +1,387 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "reflect" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseTag(t *testing.T) { + tests := []struct { + input string + want tagOpts + }{ + { + input: "custom_name", + want: tagOpts{Name: "custom_name"}, + }, + { + input: "-", + want: tagOpts{Skip: true}, + }, + { + input: "-,", + want: tagOpts{Name: "-"}, + }, + { + input: "", + want: tagOpts{}, + }, + { + input: "name,dict", + want: tagOpts{Name: "name", Dict: true}, + }, + { + input: "name,view", + want: tagOpts{Name: "name", View: true}, + }, + { + input: "name,ree", + want: tagOpts{Name: "name", REE: true}, + }, + { + input: "name,decimal(38,10)", + want: tagOpts{Name: "name", HasDecimalOpts: true, DecimalPrecision: 38, DecimalScale: 10}, + }, + { + input: ",decimal(18,2)", + want: tagOpts{Name: "", HasDecimalOpts: true, DecimalPrecision: 18, DecimalScale: 2}, + }, + { + input: "name,dict,ree", + want: tagOpts{Name: "name", Dict: true, REE: true}, + }, + { + input: "name,unknown_option", + want: tagOpts{Name: "name", ParseErr: "unknown option \"unknown_option\""}, + }, + { + input: `field,Date32`, + want: tagOpts{Name: "field", ParseErr: "unknown option \"Date32\""}, + }, + { + input: "name,large", + want: tagOpts{Name: "name", Large: true}, + }, + { + input: "name,large,view", + want: tagOpts{Name: "name", Large: true, View: true}, + }, + { + input: "name,large,dict", + want: tagOpts{Name: "name", Large: true, Dict: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := parseTag(tt.input) + assert.Equal(t, tt.want, got, "parseTag(%q)", tt.input) + }) + } +} + +func TestGetStructFields(t *testing.T) { + t.Run("simple struct", func(t *testing.T) { + type Simple struct { + Name string + Age int32 + } + fields := getStructFields(reflect.TypeOf(Simple{})) + require.Len(t, fields, 2) + assert.Equal(t, "Name", fields[0].Name) + assert.Equal(t, "Age", fields[1].Name) + }) + + t.Run("struct with arrow tags", func(t *testing.T) { + type Tagged struct { + UserName string `arrow:"user_name"` + Score float64 `arrow:"score"` + Internal string `arrow:"-"` + } + fields := getStructFields(reflect.TypeOf(Tagged{})) + require.Len(t, fields, 2) + assert.Equal(t, "user_name", fields[0].Name) + assert.Equal(t, "score", fields[1].Name) + }) + + t.Run("unexported fields skipped", func(t *testing.T) { + type Mixed struct { + Exported string + unexported string //nolint:unused + } + fields := getStructFields(reflect.TypeOf(Mixed{})) + require.Len(t, fields, 1) + assert.Equal(t, "Exported", fields[0].Name) + }) + + t.Run("pointer fields are nullable", func(t *testing.T) { + type WithPointers struct { + Required string + Optional *string + } + fields := getStructFields(reflect.TypeOf(WithPointers{})) + require.Len(t, fields, 2) + assert.False(t, fields[0].Nullable, "Required.Nullable = true, want false") + assert.True(t, fields[1].Nullable, "Optional.Nullable = false, want true") + }) + + t.Run("embedded struct promotion", func(t *testing.T) { + type Inner struct { + City string + Zip int32 + } + type Outer struct { + Name string + Inner + } + fields := getStructFields(reflect.TypeOf(Outer{})) + require.Len(t, fields, 3) + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + wantNames := []string{"Name", "City", "Zip"} + for i, want := range wantNames { + assert.Equal(t, want, names[i], "fields[%d].Name", i) + } + }) + + t.Run("embedded struct conflict excluded", func(t *testing.T) { + type A struct{ ID string } + type B struct{ ID string } + type Conflicted struct { + A + B + } + fields := getStructFields(reflect.TypeOf(Conflicted{})) + assert.Len(t, fields, 0, "expected 0 fields due to conflict") + }) + + t.Run("embedded with tag overrides promotion", func(t *testing.T) { + type Inner struct { + City string + Zip int32 + } + type HasTag struct { + Inner `arrow:"inner_struct"` + } + fields := getStructFields(reflect.TypeOf(HasTag{})) + require.Len(t, fields, 1) + assert.Equal(t, "inner_struct", fields[0].Name) + }) + + t.Run("pointer to struct is dereferenced", func(t *testing.T) { + type Simple struct { + X int32 + Y string + } + fields := getStructFields(reflect.TypeOf(&Simple{})) + require.Len(t, fields, 2) + assert.Equal(t, "X", fields[0].Name) + assert.Equal(t, "Y", fields[1].Name) + }) + + t.Run("multi-level pointer to struct is dereferenced", func(t *testing.T) { + type Simple struct { + X int32 + } + var pp **Simple + fields := getStructFields(reflect.TypeOf(pp)) + require.Len(t, fields, 1) + assert.Equal(t, "X", fields[0].Name) + }) + + t.Run("non-struct type returns nil", func(t *testing.T) { + assert.Nil(t, getStructFields(reflect.TypeOf(int32(0)))) + assert.Nil(t, getStructFields(reflect.TypeOf(""))) + assert.Nil(t, getStructFields(reflect.TypeOf([]int32{}))) + }) +} + +func TestCachedStructFields(t *testing.T) { + type S struct { + X int32 + Y string + } + + fields1 := cachedStructFields(reflect.TypeOf(S{})) + fields2 := cachedStructFields(reflect.TypeOf(S{})) + + require.Len(t, fields2, len(fields1), "cached call returned different lengths") + + for i := range fields1 { + assert.Equal(t, fields1[i].Name, fields2[i].Name, "fields[%d].Name mismatch", i) + } + + require.Len(t, fields1, 2) + assert.Equal(t, "X", fields1[0].Name) + assert.Equal(t, "Y", fields1[1].Name) +} + +func TestBuildEmptyTyped(t *testing.T) { + mem := checkedMem(t) + + t.Run("unsupported_type_returns_error", func(t *testing.T) { + _, err := buildEmptyTyped(reflect.TypeOf((chan int)(nil)), tagOpts{}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("pointer_element_type_is_dereferenced", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf((*int32)(nil)), tagOpts{}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, 0, arr.Len()) + assert.Equal(t, arrow.INT32, arr.DataType().ID()) + }) + + t.Run("multi_level_pointer_element_type", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf((**int32)(nil)), tagOpts{}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, 0, arr.Len()) + assert.Equal(t, arrow.INT32, arr.DataType().ID()) + }) + + t.Run("view_on_non_slice_type_errors", func(t *testing.T) { + _, err := buildEmptyTyped(reflect.TypeOf(int32(0)), tagOpts{View: true}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("view_happy_path_binary", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf([]byte(nil)), tagOpts{View: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, arrow.BINARY_VIEW, arr.DataType().ID()) + }) + + t.Run("view_with_slice_of_pointers_derefs_inner", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf([]*int32(nil)), tagOpts{View: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, 0, arr.Len()) + assert.Equal(t, arrow.LIST_VIEW, arr.DataType().ID()) + }) + + t.Run("view_happy_path_list", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf([]int32(nil)), tagOpts{View: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, arrow.LIST_VIEW, arr.DataType().ID()) + }) + + t.Run("view_happy_path_string", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf(""), tagOpts{View: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, arrow.STRING_VIEW, arr.DataType().ID()) + }) + + t.Run("large_view_empty", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf([]string(nil)), tagOpts{Large: true, View: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, arrow.LARGE_LIST_VIEW, arr.DataType().ID()) + llv := arr.DataType().(*arrow.LargeListViewType) + assert.Equal(t, arrow.STRING_VIEW, llv.Elem().ID()) + }) + + t.Run("large_view_string_empty", func(t *testing.T) { + // large applied first: STRING→LARGE_STRING; then view: LARGE_STRING→STRING_VIEW + arr, err := buildEmptyTyped(reflect.TypeOf(""), tagOpts{Large: true, View: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, arrow.STRING_VIEW, arr.DataType().ID()) + }) + + t.Run("dict_with_unsupported_value_type_errors", func(t *testing.T) { + _, err := buildEmptyTyped(reflect.TypeOf(time.Time{}), tagOpts{Dict: true}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) + + t.Run("dict_happy_path", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf(""), tagOpts{Dict: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, 0, arr.Len()) + assert.Equal(t, arrow.DICTIONARY, arr.DataType().ID()) + }) + + t.Run("ree_happy_path", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf(int32(0)), tagOpts{REE: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, 0, arr.Len()) + assert.Equal(t, arrow.RUN_END_ENCODED, arr.DataType().ID()) + }) + + t.Run("large_string_empty", func(t *testing.T) { + arr, err := buildEmptyTyped(reflect.TypeOf(""), tagOpts{Large: true}, mem) + require.NoError(t, err) + defer arr.Release() + assert.Equal(t, arrow.LARGE_STRING, arr.DataType().ID()) + }) + + t.Run("large_on_int_errors", func(t *testing.T) { + _, err := buildEmptyTyped(reflect.TypeOf(int32(0)), tagOpts{Large: true}, mem) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +} + +func TestParseDecimalOpt(t *testing.T) { + t.Run("valid_tag_sets_precision_and_scale", func(t *testing.T) { + got := parseTag(",decimal(18,2)") + assert.True(t, got.HasDecimalOpts) + assert.Equal(t, int32(18), got.DecimalPrecision) + assert.Equal(t, int32(2), got.DecimalScale) + assert.Empty(t, got.ParseErr) + }) + + t.Run("non_integer_precision_records_error", func(t *testing.T) { + got := parseTag(",decimal(abc,2)") + assert.False(t, got.HasDecimalOpts) + assert.NotEmpty(t, got.ParseErr) + }) + + t.Run("non_integer_scale_records_error", func(t *testing.T) { + got := parseTag(",decimal(18,two)") + assert.False(t, got.HasDecimalOpts) + assert.NotEmpty(t, got.ParseErr) + }) + + t.Run("missing_scale_records_error", func(t *testing.T) { + got := parseTag(",decimal(18)") + assert.False(t, got.HasDecimalOpts) + assert.NotEmpty(t, got.ParseErr) + }) + + t.Run("validateOptions_surfaces_parse_error", func(t *testing.T) { + err := validateOptions(tagOpts{ParseErr: "bad decimal tag"}) + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedType) + }) +}