Skip to content

Commit

Permalink
add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aradwann committed Jan 1, 2024
1 parent ac9be44 commit d49df3c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
52 changes: 52 additions & 0 deletions db/store/error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package db

import (
"errors"
"testing"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
)

func TestErrRecordNotFound(t *testing.T) {
err := ErrRecordNotFound
assert.True(t, pgx.ErrNoRows == err)
}

func TestErrUniqueViolation(t *testing.T) {
err := ErrUniqueViolation
assert.True(t, errors.Is(err, ErrUniqueViolation))
}

func TestErrorCode(t *testing.T) {
t.Run("NoError", func(t *testing.T) {
err := errors.New("some generic error")
code := ErrorCode(err)
assert.Equal(t, "", code)
})

t.Run("PgError", func(t *testing.T) {
pgErr := &pgconn.PgError{
Code: "23505",
}
err := pgErr
code := ErrorCode(err)
assert.Equal(t, "23505", code)
})

// t.Run("WrappedPgError", func(t *testing.T) {
// wrappedPgErr := &pgconn.PgError{
// Code: "23503",
// }
// err := errors.Wrap(wrappedPgErr, "wrapped error")
// code := ErrorCode(err)
// assert.Equal(t, "23503", code)
// })

t.Run("NonPgError", func(t *testing.T) {
err := errors.New("some other error")
code := ErrorCode(err)
assert.Equal(t, "", code)
})
}
13 changes: 6 additions & 7 deletions db/store/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"

"github.com/golang-migrate/migrate/v4"
Expand Down Expand Up @@ -85,7 +84,7 @@ func runUnversionedMigrations(db *sql.DB, migrationDir string) error {
// Note: You may need a custom sorting logic if file names include version numbers
// For simplicity, we assume alphabetical order here.
// Sorting ensures that the files are executed in the correct order.
sortFiles(sqlFiles)
// sortFiles(sqlFiles)

// Execute each SQL file
for _, file := range sqlFiles {
Expand All @@ -108,8 +107,8 @@ func runUnversionedMigrations(db *sql.DB, migrationDir string) error {
}

// Simple alphabetical sorting function
func sortFiles(files []string) {
sort.Slice(files, func(i, j int) bool {
return strings.Compare(files[i], files[j]) < 0
})
}
// func sortFiles(files []string) {
// sort.Slice(files, func(i, j int) bool {
// return strings.Compare(files[i], files[j]) < 0
// })
// }
7 changes: 7 additions & 0 deletions db/store/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db

import (
"context"
"database/sql"
"testing"
"time"

Expand Down Expand Up @@ -60,3 +61,9 @@ func TestGetSession(t *testing.T) {
require.WithinDuration(t, session1.CreatedAt, session2.CreatedAt, time.Second)
require.WithinDuration(t, session1.ExpiresAt, session2.ExpiresAt, time.Second)
}
func TestGetSessionNotFound(t *testing.T) {
session, err := testQueries.GetSession(context.Background(), uuid.New())
require.Equal(t, session, Session{})

require.ErrorIs(t, err, sql.ErrNoRows)
}

0 comments on commit d49df3c

Please sign in to comment.