Skip to content

Commit

Permalink
[+] feat: supports specifying auto-increment ID
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew-M-C committed Mar 18, 2021
1 parent 4b0c240 commit 2684619
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 20 deletions.
77 changes: 70 additions & 7 deletions basic_curd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,6 @@ func TestMultiConds(t *testing.T) {
return
}

// type DateRecord struct {
// ID int64 `db:"f_id" mysqlx:"increment:true"`
// Year int32 `db:"f_year"`
// Month int8 `db:"f_month"`
// Day int8 `db:"f_day"`
// }

type DateRecord struct {
ID int64 `db:"id" mysqlx:"increment:true" comment:"自增 ID"`
BusinessID int32 `db:"business_id" comment:"集成商 ID"`
Expand All @@ -376,3 +369,73 @@ func (DateRecord) Options() Options {
TableName: "t_date_record",
}
}

type recForAutoIncTest struct {
ID int64 `db:"id" mysqlx:"increment:true"`
String string `db:"string" mysqlx:"type:varchar(32)"`
}

const recForAutoIncTestTableName = "t_mysqlx_rec_for_auto_inc_test"

func (recForAutoIncTest) Options() Options {
return Options{
TableName: recForAutoIncTestTableName,
Uniques: []Unique{{
Name: "u_string",
Fields: []string{"string"},
}},
}
}

func TestSpecifyingAutoIncrementID(t *testing.T) {
var err error

d, err := getDB()
if err != nil {
t.Errorf("open failed: %v", err)
return
}

d.Sqlx().Exec("DROP TABLE ?", recForAutoIncTestTableName)
id := int64(2)
s := "Hello, world!"
r := recForAutoIncTest{
ID: id,
String: s,
}

err = d.CreateTable(r)
if err != nil {
t.Errorf("CreateTable error: %v", err)
return
}

ins, err := d.Insert(&r)
if err != nil {
t.Errorf("Insert error: %w", err)
return
}
if inserted, _ := ins.LastInsertId(); inserted != id {
t.Errorf("expected insert ID %d, but got %d", id, inserted)
return
}

// ensure again
var res []*recForAutoIncTest
err = d.Select(&res, Condition("string", "=", s))
if err != nil {
t.Errorf("Select error: %v", err)
return
}
if 0 == len(res) {
t.Errorf("read no records from %s", recForAutoIncTestTableName)
return
}

if res[0].ID != id {
t.Errorf("expected insert ID %d, but got %d", id, res[0].ID)
return
}

return
}
25 changes: 20 additions & 5 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ var (

// ========

// InsertFields return keys and values for inserting. Auto-increment fields will be ignored
// InsertFields return keys and values for inserting, auto-increment fields will be ignored if its value is zero
func (d *xdb) InsertFields(s interface{}, backQuoted bool) (keys []string, values []string, err error) {
return d.insertFields(s, backQuoted, false)
}

func (d *xdb) insertFields(s interface{}, backQuoted bool, ignoreNonZeroIncrement bool) (keys []string, values []string, err error) {
t := reflect.TypeOf(s)
v := reflect.ValueOf(s)

Expand All @@ -45,6 +49,7 @@ func (d *xdb) InsertFields(s interface{}, backQuoted bool) (keys []string, value
continue
}

incrementField := false
fieldName := getFieldName(&tf)
if fieldName == "-" {
continue
Expand All @@ -56,10 +61,11 @@ func (d *xdb) InsertFields(s interface{}, backQuoted bool) (keys []string, value
}
} else {
f, exist := fieldMap[fieldName]
if false == exist || f.AutoIncrement {
if false == exist {
// log.Println(fieldName, "not exists")
continue
}
incrementField = f.AutoIncrement
}

var val string
Expand Down Expand Up @@ -105,7 +111,7 @@ func (d *xdb) InsertFields(s interface{}, backQuoted bool) (keys []string, value
default:
if reflect.Struct == tf.Type.Kind() {
// log.Println("Embedded struct: ", tf.Type)
embedKey, embedValue, err := d.InsertFields(vf.Interface(), false)
embedKey, embedValue, err := d.insertFields(vf.Interface(), false, ignoreNonZeroIncrement)
if err != nil {
return nil, nil, err
}
Expand All @@ -117,6 +123,15 @@ func (d *xdb) InsertFields(s interface{}, backQuoted bool) (keys []string, value
continue
}

if incrementField {
if ignoreNonZeroIncrement {
continue
}
if val == "0" {
continue
}
}

keys = append(keys, fieldName)
values = append(values, val)
// continue
Expand All @@ -131,7 +146,7 @@ func (d *xdb) InsertFields(s interface{}, backQuoted bool) (keys []string, value
return
}

// Insert insert a given structure. auto-increment fields will be ignored
// Insert insert a given structure. auto-increment fields will be ignored if its value is zero
func (d *xdb) Insert(v interface{}, opts ...Options) (result sql.Result, err error) {
return d.insert(d.db, v, opts...)
}
Expand All @@ -152,7 +167,7 @@ func (d *xdb) insert(obj sqlObj, v interface{}, opts ...Options) (result sql.Res
return nil, fmt.Errorf("parameter type invalid (%v)", prototypeType)
}

keys, values, err := d.InsertFields(v, true)
keys, values, err := d.insertFields(v, true, false)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions insert_many.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

// InsertMany insert multiple records into table. If additional option with table name is not given,
// mysqlx will use the FIRST table name in records for all.
// mysqlx will use the FIRST table name in records for all. All auto-increment fields will be ignored.
func (d *xdb) InsertMany(records interface{}, opts ...Options) (result sql.Result, err error) {
return d.insertMany(d.db, records, opts...)
}
Expand Down Expand Up @@ -58,7 +58,7 @@ func (d *xdb) insertMany(obj sqlObj, records interface{}, opts ...Options) (resu
return nil, fmt.Errorf("empty table name for type %v", reflect.TypeOf(v))
}

keys, values, err := d.InsertFields(v, true)
keys, values, err := d.insertFields(v, true, true)
if err != nil {
return
}
Expand All @@ -80,7 +80,7 @@ func (d *xdb) insertMany(obj sqlObj, records interface{}, opts ...Options) (resu
v = va.Index(i).Interface()
}

_, values, err = d.InsertFields(v, true)
_, values, err = d.insertFields(v, true, true)
if err != nil {
return
}
Expand Down
6 changes: 3 additions & 3 deletions insert_on_duplicate_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (d *xdb) insertOnDuplicateKeyUpdate(
return nil, fmt.Errorf("parameter type invalid (%v)", prototypeType)
}

keys, values, err := d.InsertFields(v, true)
keys, values, err := d.insertFields(v, true, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -133,7 +133,7 @@ func (d *xdb) insertManyOnDuplicateKeyUpdate(
return nil, fmt.Errorf("empty table name for type %v", reflect.TypeOf(v))
}

keys, values, err := d.InsertFields(v, true)
keys, values, err := d.insertFields(v, true, false)
if err != nil {
return
}
Expand Down Expand Up @@ -164,7 +164,7 @@ func (d *xdb) insertManyOnDuplicateKeyUpdate(
v = va.Index(i).Interface()
}

_, values, err = d.InsertFields(v, true)
_, values, err = d.insertFields(v, true, true)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion select_or_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (d *xdb) selectOrInsert(
}

// handle insert fields and values
keys, values, err := d.InsertFields(insert, false)
keys, values, err := d.insertFields(insert, false, false)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion structs/structs.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package structs is a simple test package for mysqlx. Please do not use this.
// Package structs is a utility package for mysqlx. Please do not use this.
package structs

import (
Expand Down

0 comments on commit 2684619

Please sign in to comment.