diff --git a/driver_test.go b/driver_test.go index b3e3afa..3ac1c13 100644 --- a/driver_test.go +++ b/driver_test.go @@ -178,3 +178,66 @@ CREATE TABLE IF NOT EXISTS Singers ( } }) } + +func TestPreparedStatements(t *testing.T) { + t.Run("prepared select", func(t *testing.T) { + db, err := sql.Open("zetasqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) +)`); err != nil { + t.Fatal(err) + } + stmt, err := db.Prepare("SELECT * FROM Singers WHERE SingerId = ?") + if err != nil { + t.Fatal(err) + } + rows, err := stmt.Query("123") + if err != nil { + t.Fatal(err) + } + if rows.Next() { + t.Fatal("found unexpected row; expected no rows") + } + }) + t.Run("prepared insert", func(t *testing.T) { + db, err := sql.Open("zetasqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS Items (ItemId INT64 NOT NULL)`); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("INSERT `Items` (`ItemId`) VALUES (123)"); err != nil { + t.Fatal(err) + } + + // Test that executing without args fails + _, err = db.Exec("INSERT `Items` (`ItemId`) VALUES (?)") + if err == nil { + t.Fatal("expected error when inserting without args; got no error") + } + + stmt, err := db.Prepare("INSERT `Items` (`ItemId`) VALUES (?)") + if err != nil { + t.Fatal(err) + } + if _, err := stmt.Exec(456); err != nil { + t.Fatal(err) + } + + rows, err := db.Query("SELECT * FROM Items WHERE ItemId = 456") + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("expected no rows; expected one row") + } + }) +} diff --git a/internal/analyzer.go b/internal/analyzer.go index f089b93..c778cf0 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -696,6 +696,9 @@ func getParamsFromNode(node ast.Node) []*ast.ParameterNode { } func getArgsFromParams(values []driver.NamedValue, params []*ast.ParameterNode) ([]interface{}, error) { + if values == nil { + return nil, nil + } argNum := len(params) if len(values) < argNum { return nil, fmt.Errorf("not enough query arguments")