Skip to content

Commit

Permalink
GH-34724: [Go] Add Schema AddField (#34748)
Browse files Browse the repository at this point in the history
Similar to the C++ [`AddField`](https://arrow.apache.org/docs/cpp/api/datatype.html#_CPPv4NK5arrow10StructType8AddFieldEiRKNSt10shared_ptrI5FieldEE) and [`AddColumn`](https://arrow.apache.org/docs/cpp/api/table.html#_CPPv4NK5arrow11RecordBatch9AddColumnEiRKNSt10shared_ptrI5FieldEERKNSt10shared_ptrI5ArrayEE) APIs.
* Closes: #34724

Lead-authored-by: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com>
Co-authored-by: Kemal Hadimli <disq@users.noreply.github.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
yevgenypats and disq committed Apr 14, 2023
1 parent 5c73973 commit 2ba6628
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 5 deletions.
28 changes: 24 additions & 4 deletions go/arrow/array/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ func NewTable(schema *arrow.Schema, cols []arrow.Column, rows int64) *simpleTabl
// of slices of arrow.Array.
//
// Like other NewTable functions this can panic if:
// - len(schema.Fields) != len(data)
// - the total length of each column's array slice (ie: number of rows
// in the column) aren't the same for all columns.
// - len(schema.Fields) != len(data)
// - the total length of each column's array slice (ie: number of rows
// in the column) aren't the same for all columns.
func NewTableFromSlice(schema *arrow.Schema, data [][]arrow.Array) *simpleTable {
if len(data) != len(schema.Fields()) {
panic("array/table: mismatch in number of columns and data for creating a table")
Expand Down Expand Up @@ -197,7 +197,27 @@ func NewTableFromRecords(schema *arrow.Schema, recs []arrow.Record) *simpleTable
return NewTable(schema, cols, -1)
}

func (tbl *simpleTable) Schema() *arrow.Schema { return tbl.schema }
// Schema returns the schema held by the table.
func (tbl *simpleTable) Schema() *arrow.Schema {
	return tbl.schema
}

// AddColumn returns a new table whose columns are the receiver's columns with
// column inserted at position i, and whose schema is the receiver's schema
// with field inserted at the same position. The new table is built from a
// freshly allocated column slice; the receiver's slice is only read.
//
// An error is returned when:
//   - the column length differs from the table's row count,
//   - the type of field does not equal the column's data type, or
//   - Schema.AddField rejects the index i.
func (tbl *simpleTable) AddColumn(i int, field arrow.Field, column arrow.Column) (arrow.Table, error) {
	if int64(column.Len()) != tbl.rows {
		return nil, fmt.Errorf("arrow/array: column length mismatch: %d != %d", column.Len(), tbl.rows)
	}
	// Use structural type equality: comparing DataType values with != can
	// reject two distinct instances that describe the same type.
	if !arrow.TypeEqual(field.Type, column.DataType()) {
		return nil, fmt.Errorf("arrow/array: column type mismatch: %v != %v", field.Type, column.DataType())
	}
	// AddField validates i, so the slicing below only runs for an index the
	// schema accepted.
	newSchema, err := tbl.schema.AddField(i, field)
	if err != nil {
		return nil, err
	}
	cols := make([]arrow.Column, len(tbl.cols)+1)
	copy(cols[:i], tbl.cols[:i])
	cols[i] = column
	copy(cols[i+1:], tbl.cols[i:])
	return NewTable(newSchema, cols, tbl.rows), nil
}

// NumRows returns the row count recorded for the table.
func (tbl *simpleTable) NumRows() int64 {
	return tbl.rows
}
// NumCols returns the number of columns held by the table.
func (tbl *simpleTable) NumCols() int64 {
	count := len(tbl.cols)
	return int64(count)
}
// Column returns a pointer to the i-th element of the table's column slice.
// The pointer addresses the table's own element rather than a copy; an
// out-of-range i fails the slice bounds check.
func (tbl *simpleTable) Column(i int) *arrow.Column {
	col := &tbl.cols[i]
	return col
}
Expand Down
17 changes: 16 additions & 1 deletion go/arrow/array/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ func TestTable(t *testing.T) {
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer mem.AssertSize(t, 0)

preSchema := arrow.NewSchema(
[]arrow.Field{
{Name: "f1-i32", Type: arrow.PrimitiveTypes.Int32},
},
nil,
)
schema := arrow.NewSchema(
[]arrow.Field{
{Name: "f1-i32", Type: arrow.PrimitiveTypes.Int32},
Expand Down Expand Up @@ -469,8 +475,17 @@ func TestTable(t *testing.T) {

slices := [][]arrow.Array{col1.Data().Chunks(), col2.Data().Chunks()}

tbl := array.NewTable(schema, cols, -1)
preTbl := array.NewTable(preSchema, []arrow.Column{*col1}, -1)
defer preTbl.Release()
tbl, err := preTbl.AddColumn(
1,
arrow.Field{Name: "f2-f64", Type: arrow.PrimitiveTypes.Float64},
*col2,
)
defer tbl.Release()
if err != nil {
t.Fatalf("could not add column: %+v", err)
}

tbl2 := array.NewTableFromSlice(schema, slices)
defer tbl2.Release()
Expand Down
13 changes: 13 additions & 0 deletions go/arrow/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,19 @@ func (sc *Schema) Equal(o *Schema) bool {
return true
}

// AddField returns a new schema in which field has been inserted at index i.
// Existing fields at positions >= i are shifted one place to the right; the
// receiver's field slice is only read.
//
// Valid indices are 0 through len(fields) inclusive (the latter appends).
// Any other index yields an error and a nil schema.
func (s *Schema) AddField(i int, field Field) (*Schema, error) {
	if n := len(s.fields); i < 0 || i > n {
		return nil, fmt.Errorf("arrow: invalid field index %d", i)
	}

	// Assemble head + new field + tail in a fresh backing array.
	updated := make([]Field, 0, len(s.fields)+1)
	updated = append(updated, s.fields[:i]...)
	updated = append(updated, field)
	updated = append(updated, s.fields[i:]...)
	return NewSchema(updated, &s.meta), nil
}

func (s *Schema) String() string {
o := new(strings.Builder)
fmt.Fprintf(o, "schema:\n fields: %d\n", len(s.Fields()))
Expand Down
24 changes: 24 additions & 0 deletions go/arrow/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,30 @@ func TestSchema(t *testing.T) {
}
}

// TestSchemaAddField checks that AddField rejects indices outside
// [0, len(fields)] and that a field added at the end of a two-field schema is
// present at index 2 of the returned schema.
func TestSchemaAddField(t *testing.T) {
	s := NewSchema([]Field{
		{Name: "f1", Type: PrimitiveTypes.Int32},
		{Name: "f2", Type: PrimitiveTypes.Int64},
	}, nil)

	// Exercise both sides of the bounds check: below zero and past the end.
	for _, idx := range []int{-1, 3} {
		if _, err := s.AddField(idx, Field{Name: "f3", Type: PrimitiveTypes.Int32}); err == nil {
			t.Fatalf("expected an error for index %d", idx)
		}
	}

	s, err := s.AddField(2, Field{Name: "f3", Type: PrimitiveTypes.Int32})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got, want := len(s.Fields()), 3; got != want {
		t.Fatalf("invalid number of fields. got=%d, want=%d", got, want)
	}
	got, want := s.Field(2), Field{Name: "f3", Type: PrimitiveTypes.Int32}
	if !got.Equal(want) {
		t.Fatalf("invalid field: got=%#v, want=%#v", got, want)
	}
}

func TestSchemaEqual(t *testing.T) {
fields := []Field{
{Name: "f1", Type: PrimitiveTypes.Int32},
Expand Down
4 changes: 4 additions & 0 deletions go/arrow/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ type Table interface {
NumCols() int64
Column(i int) *Column

// AddColumn adds a new column to the table and a corresponding field (of the same type)
// to its schema, at the specified position. Returns the new table with updated columns and schema.
AddColumn(pos int, f Field, c Column) (Table, error)

Retain()
Release()
}
Expand Down

0 comments on commit 2ba6628

Please sign in to comment.