Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sqlmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ func (c *sqlmock) Close() error {

func (c *sqlmock) ExpectationsWereMet() error {
for _, e := range c.expected {
if !e.fulfilled() {
e.Lock()
fulfilled := e.fulfilled()
e.Unlock()

if !fulfilled {
return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
}

Expand Down
50 changes: 50 additions & 0 deletions sqlmock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1167,3 +1167,53 @@ func TestNewRows(t *testing.T) {
t.Errorf("expecting to create a row with columns %v, actual colmns are %v", r.cols, columns)
}
}

// This is actually a test of ExpectationsWereMet. Without a lock around e.fulfilled() inside
// ExpectationWereMet, the race detector complains if e.triggered is being read while it is also
// being written by the query running in another goroutine.
func TestQueryWithTimeout(t *testing.T) {
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()

rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")

mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WillDelayFor(15 * time.Millisecond). // Query will take longer than timeout
WithArgs(5).
WillReturnRows(rs)

_, err = queryWithTimeout(10*time.Millisecond, db, "SELECT (.+) FROM articles WHERE id = ?", 5)
if err == nil {
t.Errorf("expecting query to time out")
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func queryWithTimeout(t time.Duration, db *sql.DB, query string, args ...interface{}) (*sql.Rows, error) {
rowsChan := make(chan *sql.Rows, 1)
errChan := make(chan error, 1)

go func() {
rows, err := db.Query(query, args...)
if err != nil {
errChan <- err
return
}
rowsChan <- rows
}()

select {
case rows := <-rowsChan:
return rows, nil
case err := <-errChan:
return nil, err
case <-time.After(t):
return nil, fmt.Errorf("query timed out after %v", t)
}
}