Skip to content

Commit

Permalink
fix: make no argument passed validation opt-in
Browse files Browse the repository at this point in the history
  • Loading branch information
IvoGoman committed Dec 11, 2023
1 parent b2d135c commit a6a27b7
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 15 deletions.
33 changes: 33 additions & 0 deletions expectations.go
Expand Up @@ -134,11 +134,27 @@ type ExpectedQuery struct {
// WithArgs will match given expected args to actual database query arguments.
// if at least one argument does not match, it will return an error. For specific
// arguments an sqlmock.Argument interface can be used to match an argument.
// Must not be used together with WithoutArgs()
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
if e.noArgs {
panic("WithArgs() and WithoutArgs() must not be used together")
}
e.args = args
return e
}

// WithoutArgs will ensure that no arguments are passed for this query.
// if at least one argument is passed, it will return an error. This allows
// for stricter validation of the query arguments.
// Must no be used together with WithArgs()
func (e *ExpectedQuery) WithoutArgs() *ExpectedQuery {
if len(e.args) > 0 {
panic("WithoutArgs() and WithArgs() must not be used together")
}
e.noArgs = true
return e
}

// RowsWillBeClosed expects this query rows to be closed.
func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery {
e.rowsMustBeClosed = true
Expand Down Expand Up @@ -195,11 +211,27 @@ type ExpectedExec struct {
// WithArgs will match given expected args to actual database exec operation arguments.
// if at least one argument does not match, it will return an error. For specific
// arguments an sqlmock.Argument interface can be used to match an argument.
// Must not be used together with WithoutArgs()
func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec {
if len(e.args) > 0 {
panic("WithArgs() and WithoutArgs() must not be used together")
}
e.args = args
return e
}

// WithoutArgs will ensure that no args are passed for this expected database exec action.
// if at least one argument is passed, it will return an error. This allows for stricter
// validation of the query arguments.
// Must not be used together with WithArgs()
func (e *ExpectedExec) WithoutArgs() *ExpectedExec {
if len(e.args) > 0 {
panic("WithoutArgs() and WithArgs() must not be used together")
}
e.noArgs = true
return e
}

// WillReturnError allows to set an error for expected database exec action
func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
e.err = err
Expand Down Expand Up @@ -338,6 +370,7 @@ type queryBasedExpectation struct {
expectSQL string
converter driver.ValueConverter
args []driver.Value
noArgs bool // ensure no args are passed
}

// ExpectedPing is used to manage *sql.DB.Ping expectations.
Expand Down
3 changes: 2 additions & 1 deletion expectations_before_go18.go
@@ -1,3 +1,4 @@
//go:build !go1.8
// +build !go1.8

package sqlmock
Expand All @@ -17,7 +18,7 @@ func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {

func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
if nil == e.args {
if len(args) > 0 {
if e.noArgs && len(args) > 0 {
return fmt.Errorf("expected 0, but got %d arguments", len(args))
}
return nil
Expand Down
10 changes: 8 additions & 2 deletions expectations_before_go18_test.go
@@ -1,3 +1,4 @@
//go:build !go1.8
// +build !go1.8

package sqlmock
Expand All @@ -9,10 +10,15 @@ import (
)

func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true}
against := []namedValue{{Value: int64(5), Ordinal: 1}}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since no expectation was set, but argument was passed")
t.Error("arguments should not match, since argument was passed, but noArgs was set")
}

e.noArgs = false
if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set")
}

e.args = []driver.Value{5, "str"}
Expand Down
3 changes: 2 additions & 1 deletion expectations_go18.go
@@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8

package sqlmock
Expand Down Expand Up @@ -30,7 +31,7 @@ func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {

func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error {
if nil == e.args {
if len(args) > 0 {
if e.noArgs && len(args) > 0 {
return fmt.Errorf("expected 0, but got %d arguments", len(args))
}
return nil
Expand Down
19 changes: 15 additions & 4 deletions expectations_go18_test.go
@@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8

package sqlmock
Expand All @@ -10,10 +11,15 @@ import (
)

func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true}
against := []driver.NamedValue{{Value: int64(5), Ordinal: 1}}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since no expectation was set, but argument was passed")
t.Error("arguments should not match, since argument was passed, but noArgs was set")
}

e.noArgs = false
if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set")
}

e.args = []driver.Value{5, "str"}
Expand Down Expand Up @@ -102,10 +108,15 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
}

func TestQueryExpectationNamedArgComparison(t *testing.T) {
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true}
against := []driver.NamedValue{{Value: int64(5), Name: "id"}}
if err := e.argsMatches(against); err == nil {
t.Errorf("arguments should not match, since no expectation was set, but argument was passed")
t.Error("arguments should not match, since argument was passed, but noArgs was set")
}

e.noArgs = false
if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set")
}

e.args = []driver.Value{
Expand Down
22 changes: 22 additions & 0 deletions expectations_test.go
Expand Up @@ -101,3 +101,25 @@ func TestCustomValueConverterQueryScan(t *testing.T) {
t.Error(err)
}
}

func TestQueryWithNoArgsAndWithArgsPanic(t *testing.T) {
defer func() {
if r := recover(); r != nil {
return
}
t.Error("Expected panic for using WithArgs and ExpectNoArgs together")
}()
mock := &sqlmock{}
mock.ExpectQuery("SELECT (.+) FROM user").WithArgs("John").WithoutArgs()
}

func TestExecWithNoArgsAndWithArgsPanic(t *testing.T) {
defer func() {
if r := recover(); r != nil {
return
}
t.Error("Expected panic for using WithArgs and ExpectNoArgs together")
}()
mock := &sqlmock{}
mock.ExpectExec("^INSERT INTO user").WithArgs("John").WithoutArgs()
}
2 changes: 1 addition & 1 deletion sqlmock_go18_test.go
@@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8

package sqlmock
Expand Down Expand Up @@ -437,7 +438,6 @@ func TestContextExecErrorDelay(t *testing.T) {
// test that return of error is delayed
var delay time.Duration = 100 * time.Millisecond
mock.ExpectExec("^INSERT INTO articles").
WithArgs("hello").
WillReturnError(errors.New("slow fail")).
WillDelayFor(delay)

Expand Down
37 changes: 31 additions & 6 deletions sqlmock_test.go
Expand Up @@ -749,6 +749,16 @@ func TestRunExecsWithExpectedErrorMeetsExpectations(t *testing.T) {
}
}

func TestRunExecsWithNoArgsExpectedMeetsExpectations(t *testing.T) {
db, dbmock, _ := New()
dbmock.ExpectExec("THE FIRST EXEC").WithoutArgs().WillReturnResult(NewResult(0, 0))

_, err := db.Exec("THE FIRST EXEC", "foobar")
if err == nil {
t.Fatalf("expected error, but there wasn't any")
}
}

func TestRunQueryWithExpectedErrorMeetsExpectations(t *testing.T) {
db, dbmock, _ := New()
dbmock.ExpectQuery("THE FIRST QUERY").WillReturnError(fmt.Errorf("big bad bug"))
Expand Down Expand Up @@ -959,7 +969,7 @@ func TestPrepareExec(t *testing.T) {
mock.ExpectBegin()
ep := mock.ExpectPrepare("INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)")
for i := 0; i < 3; i++ {
ep.ExpectExec().WithArgs(i, "Hello"+strconv.Itoa(i)).WillReturnResult(NewResult(1, 1))
ep.ExpectExec().WillReturnResult(NewResult(1, 1))
}
mock.ExpectCommit()
tx, _ := db.Begin()
Expand Down Expand Up @@ -1073,7 +1083,7 @@ func TestPreparedStatementCloseExpectation(t *testing.T) {
defer db.Close()

ep := mock.ExpectPrepare("INSERT INTO ORDERS").WillBeClosed()
ep.ExpectExec().WithArgs(1, "Hello").WillReturnResult(NewResult(1, 1))
ep.ExpectExec().WillReturnResult(NewResult(1, 1))

stmt, err := db.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)")
if err != nil {
Expand Down Expand Up @@ -1104,7 +1114,6 @@ func TestExecExpectationErrorDelay(t *testing.T) {
// test that return of error is delayed
var delay time.Duration = 100 * time.Millisecond
mock.ExpectExec("^INSERT INTO articles").
WithArgs("hello").
WillReturnError(errors.New("slow fail")).
WillDelayFor(delay)

Expand Down Expand Up @@ -1230,10 +1239,10 @@ func Test_sqlmock_Prepare_and_Exec(t *testing.T) {

mock.ExpectPrepare("SELECT (.+) FROM users WHERE (.+)")
expected := NewResult(1, 1)
mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").
mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)").
WillReturnResult(expected)
expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com")
mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows)
mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows)

got, err := mock.(*sqlmock).Prepare(query)
if err != nil {
Expand Down Expand Up @@ -1326,7 +1335,7 @@ func Test_sqlmock_Query(t *testing.T) {
}
defer db.Close()
expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com")
mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows)
mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows)
query := "SELECT name, email FROM users WHERE name = ?"
rows, err := mock.(*sqlmock).Query(query, []driver.Value{"test"})
if err != nil {
Expand All @@ -1340,3 +1349,19 @@ func Test_sqlmock_Query(t *testing.T) {
return
}
}

func Test_sqlmock_QueryExpectWithoutArgs(t *testing.T) {
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com")
mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows).WithoutArgs()
query := "SELECT name, email FROM users WHERE name = ?"
_, err = mock.(*sqlmock).Query(query, []driver.Value{"test"})
if err == nil {
t.Errorf("error expected")
return
}
}

0 comments on commit a6a27b7

Please sign in to comment.