diff --git a/rows.go b/rows.go index 941544b..2b51d46 100644 --- a/rows.go +++ b/rows.go @@ -4,8 +4,10 @@ import ( "bytes" "database/sql/driver" "encoding/csv" + "errors" "fmt" "io" + "reflect" "strings" ) @@ -127,6 +129,103 @@ type Rows struct { closeErr error } +// new Rows and add a set of driver.Value by using the struct reflect tag +func newRowsFromStruct(m interface{}, tagName ...string) (*Rows, error) { + val := reflect.ValueOf(m).Elem() + num := val.NumField() + if num == 0 { + return nil, errors.New("no properties available") + } + columns := make([]string, 0, num) + var values []driver.Value + tag := "json" + if len(tagName) > 0 { + tag = tagName[0] + } + for i := 0; i < num; i++ { + f := val.Type().Field(i) + column := f.Tag.Get(tag) + if len(column) > 0 { + columns = append(columns, column) + values = append(values, val.Field(i).Interface()) + } + } + if len(columns) == 0 { + return nil, errors.New("tag not match") + } + rows := &Rows{ + cols: columns, + nextErr: make(map[int]error), + converter: driver.DefaultParameterConverter, + } + return rows.AddRow(values...), nil +} + +// NewRowsFromInterface new Rows from struct or slice or array reflect with tagName +// NOTE: arr/slice must be of the same type +// tagName default "json" +func NewRowsFromInterface(m interface{}, tagName string) (*Rows, error) { + kind := reflect.TypeOf(m).Elem().Kind() + if kind == reflect.Ptr { + kind = reflect.TypeOf(m).Kind() + } + switch kind { + case reflect.Slice, reflect.Array: + return newRowsFromSliceOrArray(m, tagName) + case reflect.Struct: + return newRowsFromStruct(m, tagName) + default: + return nil, errors.New("the type m must in struct or slice or array") + } +} + +// new Rows and add multiple sets of driver.Value by using the tags of the element in reflect type slice/array +func newRowsFromSliceOrArray(m interface{}, tagName string) (*Rows, error) { + vals := reflect.ValueOf(m) + if vals.Len() == 0 { + return nil, errors.New("the len of m is zero") + } + typ := reflect.TypeOf(vals.Index(0).Interface()).Elem() + if typ.Kind() != reflect.Struct { + return nil, errors.New("param type must be struct") + } + if typ.NumField() == 0 { + return nil, errors.New("no properties available") + } + var idx []int + tag := "json" + if len(tagName) > 0 { + tag = tagName + } + columns := make([]string, 0, typ.NumField()) + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + column := f.Tag.Get(tag) + if len(column) > 0 { + columns = append(columns, column) + idx = append(idx, i) + } + } + if len(columns) == 0 { + return nil, errors.New("tag not match") + } + rows := &Rows{ + cols: columns, + nextErr: make(map[int]error), + converter: driver.DefaultParameterConverter, + } + for i := 0; i < vals.Len(); i++ { + val := vals.Index(i).Elem() + values := make([]driver.Value, 0, len(idx)) + for _, i := range idx { + // NOTE: field by name ptr nil + values = append(values, val.Field(i).Interface()) + } + rows.AddRow(values...) + } + return rows, nil +} + // NewRows allows Rows to be created from a // sql driver.Value slice or from the CSV string and // to be used as sql driver.Rows. diff --git a/rows_test.go b/rows_test.go index ef17521..c0a84f6 100644 --- a/rows_test.go +++ b/rows_test.go @@ -5,7 +5,9 @@ import ( "database/sql" "database/sql/driver" "fmt" + "reflect" "testing" + "time" ) const invalid = `☠☠☠ MEMORY OVERWRITTEN ☠☠☠ ` @@ -753,3 +755,57 @@ func ExampleRows_AddRows() { // Output: scanned id: 1 and title: one // scanned id: 2 and title: two } + +type MockStruct struct { + Type int `mock:"type"` + Name string `mock:"name"` + CreateTime time.Time `mock:"createTime"` +} + +func TestNewRowsFromInterface(t *testing.T) { + m := &MockStruct{ + Type: 1, + Name: "sqlMock", + CreateTime: time.Now(), + } + want := NewRows([]string{"type", "name", "createTime"}).AddRow(m.Type, m.Name, m.CreateTime) + actual, err := NewRowsFromInterface(m, "mock") + if err != nil { + t.Fatal(err) + } + same := reflect.DeepEqual(want.cols, actual.cols) + if !same { + t.Fatal("custom tag reflect failed") + } + same = reflect.DeepEqual(want.rows, actual.rows) + if !same { + t.Fatal("reflect value from tag failed") + } + m1 := &MockStruct{ + Type: 1, + Name: "sqlMock1", + CreateTime: time.Now(), + } + m2 := &MockStruct{ + Type: 2, + Name: "sqlMock2", + CreateTime: time.Now(), + } + arr := [3]*MockStruct{m, m1, m2} + want2 := NewRows([]string{"type", "name", "createTime"}) + for _, v := range arr { + want2.AddRow(v.Type, v.Name, v.CreateTime) + } + actual2, err := NewRowsFromInterface(arr, "mock") + if err != nil { + t.Fatal(err) + } + same = reflect.DeepEqual(want2.cols, actual2.cols) + if !same { + t.Fatal("custom tag reflect failed") + } + same = reflect.DeepEqual(want2.rows, actual2.rows) + if !same { + t.Fatal("reflect value from tag failed") + } +}