Skip to content

Commit

Permalink
Support $NNN-style named parameter. Close #187
Browse files Browse the repository at this point in the history
  • Loading branch information
mattn committed Mar 21, 2015
1 parent 5253daf commit a6c2085
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 4 deletions.
61 changes: 57 additions & 4 deletions sqlite3.go
Expand Up @@ -121,6 +121,7 @@ type SQLiteTx struct {
type SQLiteStmt struct {
c *SQLiteConn
s *C.sqlite3_stmt
nv int
t string
closed bool
cls bool
Expand Down Expand Up @@ -368,7 +369,19 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
if tail != nil && C.strlen(tail) > 0 {
t = strings.TrimSpace(C.GoString(tail))
}
ss := &SQLiteStmt{c: c, s: s, t: t}
nv := int(C.sqlite3_bind_parameter_count(s))
if nv > 0 {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, 1))
/* TODO: map argument for named parameters
if len(pn) > 0 && pn[0] == '$' && pn[1] != '1' {
nv = -1
}
*/
if len(pn) > 0 && pn[0] != '?' {
nv = -1
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil
}
Expand All @@ -392,7 +405,12 @@ func (s *SQLiteStmt) Close() error {

// Return a number of parameters.
func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
return s.nv
}

type bindArg struct {
n int
v driver.Value
}

func (s *SQLiteStmt) bind(args []driver.Value) error {
Expand All @@ -401,8 +419,43 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
return s.c.lastError()
}

for i, v := range args {
n := C.int(i + 1)
var vargs []bindArg
narg := len(args)
if s.nv == -1 {
/* TODO: map argument for named parameters
if narg == 1 {
if m, ok := args[0].(map[string]driver.Value); ok {
for k, v := range m {
pn := C.CString(k)
if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 {
println(pi)
vargs = append(vargs, bindArg{pi, v})
}
C.free(unsafe.Pointer(pn))
}
}
narg = 0
}
*/
if narg > 0 {
for i := 0; i < narg; i++ {
pn := C.CString(fmt.Sprint(i + 1))
if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 {
vargs = append(vargs, bindArg{pi, args[i]})
}
C.free(unsafe.Pointer(pn))
}
}
} else {
vargs = make([]bindArg, narg)
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
}
}

for _, varg := range vargs {
n := C.int(varg.n)
v := varg.v
switch v := v.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
Expand Down
36 changes: 36 additions & 0 deletions sqlite3_test.go
Expand Up @@ -909,3 +909,39 @@ func TestVersion(t *testing.T) {
t.Errorf("Version failed %q, %d, %q\n", s, n, id)
}
}

func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()

_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}

_, err = db.Exec(`insert into foo(id, name, extra)) values($1, $2, $2)`, 1, "foo")
if err != nil {
t.Error("Failed to call db.Exec:", err)
}

row := db.QueryRow(`select id, name, extra where id = $1 and extra = $2`, 1, "foo")
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}

0 comments on commit a6c2085

Please sign in to comment.