Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
aradwann committed Dec 22, 2023
1 parent 323e931 commit 23864c8
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 132 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
uses: actions/checkout@v4

- name: Run migrations
run: go run db/scripts/migrate.go
run: make migrateup

- name: Run Unit tests
run: make testci
Expand Down
10 changes: 2 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dropdb:
docker exec -it postgres15 dropdb eenergy

migrateup:
migrate -path $(MIGRATIONS_PATH) -database $(DB_SOURCE) -verbose up
go run db/scripts/migrate.go

migrateprocsup:
migrate -path $(PROCS_PATH) -database $(DB_SOURCE) -verbose up
Expand All @@ -27,12 +27,6 @@ migratedown1:
createmigration:
migrate create -ext sql -dir $(MIGRATIONS_PATH) -seq "$(filter-out $@,$(MAKECMDGOALS))"

createprocmigration:
migrate create -ext sql -dir $(PROCS_PATH) -seq <migration_file_name>

sqlc:
sqlc generate

test:
go test -v -cover ./...

Expand All @@ -57,5 +51,5 @@ protoc:
evans:
evans --host localhost --port 9090 -r repl

.PHONEY: postgres pgadmin4 createdb dropdb migrateup migrateup1 migratedown migratedown1 sqlc test server mock protoc evans migrateprocsup
.PHONEY: createdb dropdb migrateup migrateup1 migratedown migratedown1 test server protoc evans migrateprocsup

4 changes: 2 additions & 2 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (server *Server) createUser(ctx *gin.Context) {
ctx.JSON(http.StatusInternalServerError, errResponse(err))
return
}
rsp := newUserResponse(user)
rsp := newUserResponse(*user)
ctx.JSON(http.StatusOK, rsp)

}
Expand Down Expand Up @@ -145,7 +145,7 @@ func (server *Server) loginUser(ctx *gin.Context) {
AccessTokenExpiresAt: accessPayload.ExpiredAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshPayload.ExpiredAt,
User: newUserResponse(user),
User: newUserResponse(*user),
}
ctx.JSON(http.StatusOK, rsp)
}
36 changes: 14 additions & 22 deletions db/migrations/procs/create_user.sql
Original file line number Diff line number Diff line change
@@ -1,33 +1,25 @@
CREATE OR REPLACE PROCEDURE create_user(
p_username VARCHAR,
p_hashed_password VARCHAR,
p_fullname VARCHAR,
p_email VARCHAR,
OUT username_out VARCHAR,
OUT hashed_password_out VARCHAR,
OUT fullname_out VARCHAR,
OUT email_out VARCHAR,
INOUT p_username VARCHAR,
INOUT p_hashed_password VARCHAR,
INOUT p_fullname VARCHAR,
INOUT p_email VARCHAR,
OUT password_changed_at_out TIMESTAMP WITH TIME ZONE,
OUT created_at_out TIMESTAMP WITH TIME ZONE
)
LANGUAGE plpgsql
AS $$
BEGIN
-- Insert new user and get the password_changed_at and created_at values
INSERT INTO users (username, hashed_password, fullname, email)
VALUES (p_username, p_hashed_password, p_fullname, p_email)
RETURNING
username,
hashed_password,
fullname,
email,
password_changed_at,
created_at
INTO
username_out,
hashed_password_out,
fullname_out,
email_out,
password_changed_at_out,
created_at_out;
RETURNING password_changed_at, created_at
INTO password_changed_at_out, created_at_out;

-- Optionally update the input parameters with the inserted values
SELECT username, hashed_password, fullname, email
INTO p_username, p_hashed_password, p_fullname, p_email
FROM users
WHERE username = p_username;

END;
$$;
4 changes: 4 additions & 0 deletions db/migrations/procs/update_user.sql
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,9 @@ BEGIN
email_out,
password_changed_at_out,
created_at_out;

IF NOT FOUND THEN
RAISE EXCEPTION 'User not found with username: %', p_username;
END IF;
END;
$$;
30 changes: 19 additions & 11 deletions db/store/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,25 @@ type Queries struct {
db DBTX
}

// callStoredProcedure executes a stored procedure with parameters.
func (q *Queries) callStoredProcedure(ctx context.Context, procedureName string, params ...interface{}) (*sql.Row, error) {
// Construct placeholders for the parameters
placeholders := make([]string, len(params))
for i := range placeholders {
placeholders[i] = fmt.Sprintf("$%d", i+1)
}
type StoredProcedureParams struct {
InParams []interface{}
OutParams []interface{}
}

// Construct the SQL statement for calling the stored procedure
sqlStatement := fmt.Sprintf(`CALL %s(%s)`, procedureName, strings.Join(placeholders, ", "))
func (q *Queries) callStoredProcedure(ctx context.Context, procedureName string, params StoredProcedureParams) *sql.Row {
sqlStatement := fmt.Sprintf(`CALL %s(%s)`, procedureName, generateParamPlaceholders(len(params.InParams)))

// Execute the stored procedure and return the result
return q.db.QueryRowContext(ctx, sqlStatement, params...), nil
return q.db.QueryRowContext(
ctx,
sqlStatement,
params.InParams...,
)
}

func generateParamPlaceholders(count int) string {
placeholders := make([]string, count)
for i := 1; i <= count; i++ {
placeholders[i-1] = fmt.Sprintf("$%d", i)
}
return strings.Join(placeholders, ", ")
}
36 changes: 36 additions & 0 deletions db/store/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package db

import (
"fmt"
"testing"
)

func TestGenerateParamPlaceholders(t *testing.T) {
testCases := []struct {
count int
expectedResult string
expectedErrorMsg string
}{
{
count: 0,
expectedResult: "",
expectedErrorMsg: "",
},
{
count: 3,
expectedResult: "$1, $2, $3",
expectedErrorMsg: "",
},
}

for _, testCase := range testCases {
t.Run(fmt.Sprintf("Count%d", testCase.count), func(t *testing.T) {
result := generateParamPlaceholders(testCase.count)

// Check if the result matches the expected value
if result != testCase.expectedResult {
t.Errorf("Expected: %s, Got: %s", testCase.expectedResult, result)
}
})
}
}
6 changes: 3 additions & 3 deletions db/store/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
)

type Querier interface {
CreateUser(ctx context.Context, arg CreateUserParams) (User, error)
CreateUser(ctx context.Context, arg CreateUserParams) (*User, error)
// GetSession(ctx context.Context, id uuid.UUID) (Session, error)
GetUser(ctx context.Context, username string) (User, error)
UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error)
GetUser(ctx context.Context, username string) (*User, error)
UpdateUser(ctx context.Context, arg UpdateUserParams) (*User, error)
}

var _ Querier = (*Queries)(nil)
139 changes: 55 additions & 84 deletions db/store/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package db
import (
"context"
"database/sql"
"fmt"
"log"
)

type CreateUserParams struct {
Expand All @@ -12,68 +14,6 @@ type CreateUserParams struct {
Email string `json:"email"`
}

// CreateUser calls the create_user stored procedure
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
var user User
interf := []interface{}{
arg.Username,
arg.HashedPassword,
arg.Fullname,
arg.Email,
&user.Username,
&user.HashedPassword,
&user.Fullname,
&user.Email,
&user.PasswordChangedAt,
&user.CreatedAt,
}
row, err := q.callStoredProcedure(ctx, "create_user", interf...)
if err != nil {
return User{}, err
}
// Execute the stored procedure and scan the results into the variables
err = row.Scan(
&user.Username,
&user.HashedPassword,
&user.Fullname,
&user.Email,
&user.PasswordChangedAt,
&user.CreatedAt,
)

if err != nil {
return User{}, err
}

return user, nil
}

func (q *Queries) GetUser(ctx context.Context, username string) (User, error) {
var user User
interf := []interface{}{
username,
&user.Username,
&user.HashedPassword,
&user.Fullname,
&user.Email,
&user.PasswordChangedAt,
&user.CreatedAt,
}
row, err := q.callStoredProcedure(ctx, "get_user", interf...)
if err != nil {
return User{}, err
}
err = row.Scan(
&user.Username,
&user.HashedPassword,
&user.Fullname,
&user.Email,
&user.PasswordChangedAt,
&user.CreatedAt,
)
return user, err
}

type UpdateUserParams struct {
HashedPassword sql.NullString `json:"hashed_password"`
PasswordChangedAt sql.NullTime `json:"password_changed_at"`
Expand All @@ -82,38 +22,69 @@ type UpdateUserParams struct {
Username string `json:"username"`
}

func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) {
var user User
interf := []interface{}{
arg.Username,
arg.HashedPassword,
arg.PasswordChangedAt,
arg.Fullname,
arg.Email,
&user.Username,
&user.HashedPassword,
&user.Fullname,
&user.Email,
&user.PasswordChangedAt,
&user.CreatedAt,
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (*User, error) {
user := User{}
params := StoredProcedureParams{
InParams: []interface{}{arg.Username, arg.HashedPassword, arg.Fullname, arg.Email, &user.PasswordChangedAt, &user.CreatedAt},
OutParams: []interface{}{&user.Username, &user.HashedPassword, &user.Fullname, &user.Email, &user.PasswordChangedAt, &user.CreatedAt},
}
row, err := q.callStoredProcedure(ctx, "update_user", interf...)
if err != nil {
return User{}, err

row := q.callStoredProcedure(ctx, "create_user", params)
err := scanUserFromRow(row, &user)
return &user, err
}

func (q *Queries) GetUser(ctx context.Context, username string) (*User, error) {
user := User{}
params := StoredProcedureParams{
InParams: []interface{}{username, &user.Username, &user.HashedPassword, &user.Fullname, &user.Email, &user.PasswordChangedAt, &user.CreatedAt},
OutParams: []interface{}{&user.Username, &user.HashedPassword, &user.Fullname, &user.Email, &user.PasswordChangedAt, &user.CreatedAt},
}
// Execute the stored procedure and scan the results into the variables
err = row.Scan(

row := q.callStoredProcedure(ctx, "get_user", params)
err := scanUserFromRow(row, &user)
return &user, err
}

func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (*User, error) {
user := User{}
params := StoredProcedureParams{
InParams: []interface{}{arg.Username, arg.HashedPassword, arg.PasswordChangedAt, arg.Fullname, arg.Email, &user.Username, &user.HashedPassword, &user.Fullname, &user.Email, &user.PasswordChangedAt, &user.CreatedAt},
OutParams: []interface{}{&user.Username, &user.HashedPassword, &user.Fullname, &user.Email, &user.PasswordChangedAt, &user.CreatedAt},
}

row := q.callStoredProcedure(ctx, "update_user", params)
err := scanUserFromRow(row, &user)
return &user, err
}

func scanUserFromRow(row *sql.Row, user *User) error {

err := row.Scan(
&user.Username,
&user.HashedPassword,
&user.Fullname,
&user.Email,
&user.PasswordChangedAt,
&user.CreatedAt,
)

// Check for errors after scanning
if err := row.Err(); err != nil {
// Handle row-related errors
log.Fatal(err)
return err
}
if err != nil {
return User{}, err
// Check for a specific error related to the scan
if err == sql.ErrNoRows {
fmt.Println("No rows were returned.")
return err
} else {
// Handle other scan-related errors
log.Fatal(err)
return err
}
}

return user, nil
return nil
}
2 changes: 1 addition & 1 deletion db/store/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)

func createRandomUser(t *testing.T) User {
func createRandomUser(t *testing.T) *User {
hashedPassword, err := util.HashPassword(util.RandomString(6))
require.NoError(t, err)

Expand Down

0 comments on commit 23864c8

Please sign in to comment.