Skip to content

Commit

Permalink
refactor sql.ErrNoRows errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aradwann committed Jan 19, 2024
1 parent efe529e commit e5f9715
Show file tree
Hide file tree
Showing 10 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions db/store/error.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package db

import (
"database/sql"
"errors"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

Expand All @@ -12,7 +12,7 @@ const (
UniqueViolation = "23505"
)

var ErrRecordNotFound = pgx.ErrNoRows
var ErrRecordNotFound = sql.ErrNoRows

var ErrUniqueViolation = &pgconn.PgError{
Code: UniqueViolation,
Expand Down
4 changes: 2 additions & 2 deletions db/store/error_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package db

import (
"database/sql"
"errors"
"testing"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
)

func TestErrRecordNotFound(t *testing.T) {
err := ErrRecordNotFound
assert.True(t, pgx.ErrNoRows == err)
assert.True(t, sql.ErrNoRows == err)
}

func TestErrUniqueViolation(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion db/store/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"context"
"database/sql"
"errors"
"log"
"time"

Expand Down Expand Up @@ -70,7 +71,7 @@ func scanSessionFromRow(row *sql.Row, session *Session) error {
}
if err != nil {
// Check for a specific error related to the scan
if err == sql.ErrNoRows {
if errors.Is(err, ErrRecordNotFound) {
// fmt.Println("No rows were returned.")
return err
} else {
Expand Down
3 changes: 1 addition & 2 deletions db/store/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package db

import (
"context"
"database/sql"
"testing"
"time"

Expand Down Expand Up @@ -65,5 +64,5 @@ func TestGetSessionNotFound(t *testing.T) {
session, err := testQueries.GetSession(context.Background(), uuid.New())
require.Equal(t, session, Session{})

require.ErrorIs(t, err, sql.ErrNoRows)
require.ErrorIs(t, err, ErrRecordNotFound)
}
3 changes: 2 additions & 1 deletion db/store/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"context"
"database/sql"
"errors"
"log"
)

Expand Down Expand Up @@ -82,7 +83,7 @@ func scanUserFromRow(row *sql.Row, user *User) error {
// Check for errors after scanning
if err != nil {
// Handle scan-related errors
if err == sql.ErrNoRows {
if errors.Is(err, ErrRecordNotFound) {
// fmt.Println("No rows were returned.")
return err
} else {
Expand Down
3 changes: 2 additions & 1 deletion db/store/verify_email.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"context"
"database/sql"
"errors"
"log"
)

Expand Down Expand Up @@ -73,7 +74,7 @@ func scanVerifyEmailFromRow(row *sql.Row, verifyEmail *VerifyEmail) error {
// Check for errors after scanning
if err != nil {
// Handle scan-related errors
if err == sql.ErrNoRows {
if errors.Is(err, ErrRecordNotFound) {
// fmt.Println("No rows were returned.")
return err
} else {
Expand Down
3 changes: 1 addition & 2 deletions gapi/rpc_login_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gapi

import (
"context"
"database/sql"
"errors"

db "github.com/aradwann/eenergy/db/store"
Expand All @@ -24,7 +23,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) (

user, err := server.store.GetUser(ctx, req.GetUsername())
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, db.ErrRecordNotFound) {
return nil, status.Errorf(codes.NotFound, "user not found")

}
Expand Down
3 changes: 2 additions & 1 deletion gapi/rpc_update_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gapi
import (
"context"
"database/sql"
"errors"
"time"

db "github.com/aradwann/eenergy/db/store"
Expand Down Expand Up @@ -49,7 +50,7 @@ func (server *Server) UpdateUser(ctx context.Context, req *pb.UpdateUserRequest)

user, err := server.store.UpdateUser(ctx, arg)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, db.ErrRecordNotFound) {
return nil, status.Errorf(codes.NotFound, "user not found")
}
return nil, status.Errorf(codes.Internal, "failed to Update user: %s", err)
Expand Down
2 changes: 1 addition & 1 deletion gapi/rpc_update_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestUpdateUserAPI(t *testing.T) {
store.EXPECT().
UpdateUser(gomock.Any(), gomock.Any()).
Times(1).
Return(db.User{}, sql.ErrNoRows)
Return(db.User{}, db.ErrRecordNotFound)

},
buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context {
Expand Down
2 changes: 1 addition & 1 deletion worker/task_send_verify_email.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (processor *RedisTaskProcessor) ProcessTaskSendVerifyEmail(ctx context.Cont
user, err := processor.store.GetUser(ctx, payload.Username)
if err != nil {
// if the user is not found try again later, as the creation might be not commited yet
// if errors.Is(err, sql.ErrNoRows) {
// if errors.Is(err, ErrRecordNotFound) {
// return fmt.Errorf("user doesn't exist: %w", asynq.SkipRetry)

// }
Expand Down

0 comments on commit e5f9715

Please sign in to comment.