diff --git a/context.go b/context.go index 8231034..a8f3418 100644 --- a/context.go +++ b/context.go @@ -7,6 +7,16 @@ import ( "github.com/goccy/go-zetasqlite/internal" ) +// DisableQueryFormattingKey use to disable query formatting for queries that require raw SQLite access +type DisableQueryFormattingKey = internal.DisableQueryFormattingKey + +// WithQueryFormattingDisabled use for queries that require raw SQLite SQL. +// This is useful for queries that do not require additional functionality from go-zetasqlite. +// Utilizing this option often allows the SQLite query planner to generate more efficient plans. +func WithQueryFormattingDisabled(ctx context.Context) context.Context { + return context.WithValue(ctx, internal.DisableQueryFormattingKey{}, true) +} + // WithCurrentTime use to replace the current time with the specified time. // To replace the time, you need to pass the returned context as an argument to QueryContext. // `CURRENT_DATE`, `CURRENT_DATETIME`, `CURRENT_TIME`, `CURRENT_TIMESTAMP` functions are targeted. diff --git a/driver_test.go b/driver_test.go index 1a866c5..2f56048 100644 --- a/driver_test.go +++ b/driver_test.go @@ -206,7 +206,7 @@ CREATE TABLE IF NOT EXISTS Singers ( t.Fatal("found unexpected row; expected no rows") } }) - t.Run("prepared insert", func(t *testing.T) { + t.Run("prepared insert with named values", func(t *testing.T) { db, err := sql.Open("zetasqlite", ":memory:") if err != nil { t.Fatal(err) @@ -224,11 +224,11 @@ CREATE TABLE IF NOT EXISTS Singers ( t.Fatal("expected error when inserting without args; got no error") } - stmt, err := db.Prepare("INSERT `Items` (`ItemId`) VALUES (?)") + stmt, err := db.Prepare("INSERT `Items` (`ItemId`) VALUES (@itemID)") if err != nil { t.Fatal(err) } - if _, err := stmt.Exec(456); err != nil { + if _, err := stmt.Exec(sql.Named("itemID", 456)); err != nil { t.Fatal(err) } @@ -248,4 +248,29 @@ CREATE TABLE IF NOT EXISTS Singers ( t.Fatal("expected no rows; expected one row") } }) + + t.Run("prepared select with named values, formatting disabled, uppercased parameter", func(t *testing.T) { + db, err := sql.Open("zetasqlite", ":memory:") + ctx := zetasqlite.WithQueryFormattingDisabled(context.Background()) + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS Items (ItemId INT64 NOT NULL)`); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("INSERT `Items` (`ItemId`) VALUES (123)"); err != nil { + t.Fatal(err) + } + + stmt, err := db.PrepareContext(ctx, "SELECT `ItemID` FROM `Items` WHERE `ItemID` = @itemID AND @bool = TRUE") + if err != nil { + t.Fatal("unexpected error when preparing stmt; got %w", err) + } + + var itemID string + err = stmt.QueryRowContext(ctx, sql.Named("itemID", 123), sql.Named("bool", true)).Scan(&itemID) + if err != nil { + t.Fatal("expected one row; got error %w", err) + } + }) } diff --git a/internal/analyzer.go b/internal/analyzer.go index acb81b7..ae420f8 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -20,6 +20,8 @@ type Analyzer struct { opt *zetasql.AnalyzerOptions } +type DisableQueryFormattingKey struct{} + func NewAnalyzer(catalog *Catalog) (*Analyzer, error) { opt, err := newAnalyzerOptions() if err != nil { @@ -511,14 +513,30 @@ func (a *Analyzer) newQueryStmtAction(ctx context.Context, query string, args [] Type: newType(col.Column().Type()), }) } - formattedQuery, err := newNode(node).FormatSQL(ctx) - if err != nil { - return nil, fmt.Errorf("failed to format query %s: %w", query, err) + var formattedQuery string + params := getParamsFromNode(node) + if disabledFormatting, ok := ctx.Value(DisableQueryFormattingKey{}).(bool); ok && disabledFormatting { + formattedQuery = query + // ZetaSQL will always lowercase parameter names, so we must match it in the query + queryBytes := []byte(query) + for _, param := range params { + location := param.ParseLocationRange() + start := location.Start().ByteOffset() + end := location.End().ByteOffset() + // Finds the parameter including its prefix i.e. @itemID + parameter := string(queryBytes[start:end]) + formattedQuery = strings.ReplaceAll(formattedQuery, parameter, strings.ToLower(parameter)) + } + } else { + var err error + formattedQuery, err = newNode(node).FormatSQL(ctx) + if err != nil { + return nil, fmt.Errorf("failed to format query %s: %w", query, err) + } } if formattedQuery == "" { return nil, fmt.Errorf("failed to format query %s", query) } - params := getParamsFromNode(node) queryArgs, err := getArgsFromParams(args, params) if err != nil { return nil, err diff --git a/internal/stmt.go b/internal/stmt.go index 54f1f01..a92e55b 100644 --- a/internal/stmt.go +++ b/internal/stmt.go @@ -152,15 +152,15 @@ func (s *DMLStmt) NumInput() int { } func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) { - values := make([]interface{}, 0, len(args)) - for _, arg := range args { - values = append(values, arg) - } - newArgs, err := EncodeGoValues(values, s.args) + return s.ExecContext(context.Background(), valuesToNamedValues(args)) +} + +func (s *DMLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + newArgs, err := getArgsFromParams(args, s.args) if err != nil { return nil, err } - result, err := s.stmt.Exec(newArgs...) + result, err := s.stmt.ExecContext(ctx, newArgs...) if err != nil { return nil, fmt.Errorf( "failed to execute query %s: args %v: %w", @@ -172,10 +172,6 @@ func (s *DMLStmt) Exec(args []driver.Value) (driver.Result, error) { return result, nil } -func (s *DMLStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - return nil, fmt.Errorf("unimplemented ExecContext for DMLStmt") -} - func (s *DMLStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, fmt.Errorf("unsupported query for DMLStmt") } @@ -224,16 +220,28 @@ func (s *QueryStmt) ExecContext(ctx context.Context, query string, args []driver return nil, fmt.Errorf("unsupported exec for QueryStmt") } -func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) { - values := make([]interface{}, 0, len(args)) +func valuesToNamedValues(args []driver.Value) []driver.NamedValue { + values := make([]driver.NamedValue, 0, len(args)) for _, arg := range args { - values = append(values, arg) + if namedValue, ok := arg.(driver.NamedValue); ok { + values = append(values, namedValue) + } + values = append(values, driver.NamedValue{Value: arg}) } - newArgs, err := EncodeGoValues(values, s.args) + + return values +} + +func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) { + return s.QueryContext(context.Background(), valuesToNamedValues(args)) +} + +func (s *QueryStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + newArgs, err := getArgsFromParams(args, s.args) if err != nil { return nil, err } - rows, err := s.stmt.Query(newArgs...) + rows, err := s.stmt.QueryContext(ctx, newArgs...) if err != nil { return nil, fmt.Errorf( "failed to query %s: args: %v: %w", @@ -244,7 +252,3 @@ func (s *QueryStmt) Query(args []driver.Value) (driver.Rows, error) { } return &Rows{rows: rows, columns: s.outputColumns}, nil } - -func (s *QueryStmt) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - return nil, fmt.Errorf("unimplemented QueryContext for QueryStmt") -}