diff --git a/sqlite3.go b/sqlite3.go index 91aea8fc..4e68e966 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -121,6 +121,7 @@ type SQLiteTx struct { type SQLiteStmt struct { c *SQLiteConn s *C.sqlite3_stmt + nv int t string closed bool cls bool @@ -368,7 +369,19 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { if tail != nil && C.strlen(tail) > 0 { t = strings.TrimSpace(C.GoString(tail)) } - ss := &SQLiteStmt{c: c, s: s, t: t} + nv := int(C.sqlite3_bind_parameter_count(s)) + if nv > 0 { + pn := C.GoString(C.sqlite3_bind_parameter_name(s, 1)) + /* TODO: map argument for named parameters + if len(pn) > 0 && pn[0] == '$' && pn[1] != '1' { + nv = -1 + } + */ + if len(pn) > 0 && pn[0] != '?' { + nv = -1 + } + } + ss := &SQLiteStmt{c: c, s: s, nv: nv, t: t} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) return ss, nil } @@ -392,7 +405,12 @@ func (s *SQLiteStmt) Close() error { // Return a number of parameters. func (s *SQLiteStmt) NumInput() int { - return int(C.sqlite3_bind_parameter_count(s.s)) + return s.nv +} + +type bindArg struct { + n int + v driver.Value } func (s *SQLiteStmt) bind(args []driver.Value) error { @@ -401,8 +419,43 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { return s.c.lastError() } - for i, v := range args { - n := C.int(i + 1) + var vargs []bindArg + narg := len(args) + if s.nv == -1 { + /* TODO: map argument for named parameters + if narg == 1 { + if m, ok := args[0].(map[string]driver.Value); ok { + for k, v := range m { + pn := C.CString(k) + if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 { + println(pi) + vargs = append(vargs, bindArg{pi, v}) + } + C.free(unsafe.Pointer(pn)) + } + } + narg = 0 + } + */ + if narg > 0 { + for i := 0; i < narg; i++ { + pn := C.CString(fmt.Sprint(i + 1)) + if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 { + vargs = append(vargs, bindArg{pi, args[i]}) + } + C.free(unsafe.Pointer(pn)) + } + } + } else { + vargs = make([]bindArg, narg) + for i, v := range args { + vargs[i] = bindArg{i + 1, v} + } + } + + for _, varg := range vargs { + n := C.int(varg.n) + v := varg.v switch v := v.(type) { case nil: rv = C.sqlite3_bind_null(s.s, n) diff --git a/sqlite3_test.go b/sqlite3_test.go index 6570b527..245f363d 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -909,3 +909,39 @@ func TestVersion(t *testing.T) { t.Errorf("Version failed %q, %d, %q\n", s, n, id) } } + +func TestNumberNamedParams(t *testing.T) { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name text, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + _, err = db.Exec(`insert into foo(id, name, extra)) values($1, $2, $2)`, 1, "foo") + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, name, extra where id = $1 and extra = $2`, 1, "foo") + if row == nil { + t.Error("Failed to call db.QueryRow") + } + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != "foo" { + t.Error("Failed to db.QueryRow: not matched results") + } +}