Skip to content

Commit

Permalink
revert to stored functions
Browse files Browse the repository at this point in the history
  • Loading branch information
aradwann committed Dec 31, 2023
1 parent cbc50b0 commit 1d0d3fb
Show file tree
Hide file tree
Showing 13 changed files with 248 additions and 153 deletions.
26 changes: 0 additions & 26 deletions db/migrations/procs/create_user.sql

This file was deleted.

32 changes: 0 additions & 32 deletions db/migrations/procs/get_user.sql

This file was deleted.

44 changes: 0 additions & 44 deletions db/migrations/procs/update_user.sql

This file was deleted.

22 changes: 22 additions & 0 deletions db/migrations/procs/user/create_user.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
DROP FUNCTION IF EXISTS create_user;
CREATE OR REPLACE FUNCTION create_user(
p_username VARCHAR,
p_hashed_password VARCHAR,
p_fullname VARCHAR,
p_email VARCHAR
)
RETURNS TABLE (
username VARCHAR,
hashed_password VARCHAR,
fullname VARCHAR,
email VARCHAR,
password_changed_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE
) AS $$
BEGIN
RETURN QUERY
INSERT INTO users (username, hashed_password, fullname, email)
VALUES (p_username, p_hashed_password, p_fullname, p_email)
RETURNING *;
END;
$$ LANGUAGE plpgsql;
26 changes: 26 additions & 0 deletions db/migrations/procs/user/get_user.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
DROP FUNCTION IF EXISTS get_user;
CREATE OR REPLACE FUNCTION get_user(
p_username VARCHAR
)
RETURNS TABLE (
username VARCHAR,
hashed_password VARCHAR,
fullname VARCHAR,
email VARCHAR,
password_changed_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE
) AS $$
BEGIN
RETURN QUERY
SELECT
users.username,
users.hashed_password,
users.fullname,
users.email,
users.password_changed_at,
users.created_at
FROM users
WHERE users.username = p_username
LIMIT 1;
END;
$$ LANGUAGE plpgsql;
28 changes: 28 additions & 0 deletions db/migrations/procs/user/update_user.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
DROP FUNCTION IF EXISTS update_user;
CREATE OR REPLACE FUNCTION update_user(
p_username VARCHAR,
p_hashed_password VARCHAR,
p_password_changed_at TIMESTAMP WITH TIME ZONE,
p_fullname VARCHAR,
p_email VARCHAR
)
RETURNS TABLE (
username VARCHAR,
hashed_password VARCHAR,
fullname VARCHAR,
email VARCHAR,
password_changed_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE
) AS $$
BEGIN
RETURN QUERY
UPDATE users
SET
hashed_password = COALESCE(p_hashed_password, users.hashed_password),
password_changed_at = COALESCE(p_password_changed_at, users.password_changed_at),
fullname = COALESCE(p_fullname, users.fullname),
email = COALESCE(p_email, users.email)
WHERE users.username = p_username
RETURNING *;
END;
$$ LANGUAGE plpgsql;
39 changes: 34 additions & 5 deletions db/scripts/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,50 @@ func runDBMigrations(db *sql.DB, migrationsURL string) {

}

// Get a list of SQL files in the migration directory
func getSQLFiles(migrationDir string) ([]string, error) {
var sqlFiles []string

err := filepath.WalkDir(migrationDir, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}

// Skip directories
if d.IsDir() {
return nil
}

// Check if the file has a .sql extension
if strings.HasSuffix(path, ".sql") {
if err != nil {
return err
}
sqlFiles = append(sqlFiles, path)
}

return nil
})

return sqlFiles, err
}

func runUnversionedMigrations(db *sql.DB, migrationDir string) error {
// Get a list of SQL files in the migration directory
files, err := filepath.Glob(filepath.Join(migrationDir, "*.sql"))

sqlFiles, err := getSQLFiles(migrationDir)

if err != nil {
return err
}

// Sort files to ensure execution order
// Note: You may need a custom sorting logic if file names include version numbers
// For simplicity, we assume alphabetical order here.
// Sorting ensures that the files are executed in the correct order.
sortFiles(files)
sortFiles(sqlFiles)

// Execute each SQL file
for _, file := range files {
for _, file := range sqlFiles {

contents, err := os.ReadFile(file)
if err != nil {
return err
Expand Down
35 changes: 22 additions & 13 deletions db/store/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,28 @@ type Queries struct {
db DBTX
}

type StoredProcedureParams struct {
InParams []interface{}
OutParams []interface{}
}

func (q *Queries) callStoredProcedure(ctx context.Context, procedureName string, params StoredProcedureParams) *sql.Row {
sqlStatement := fmt.Sprintf(`CALL %s(%s)`, procedureName, generateParamPlaceholders(len(params.InParams)))

return q.db.QueryRowContext(
ctx,
sqlStatement,
params.InParams...,
)
// type storedProcedureParams struct {
// InParams []interface{}
// OutParams []interface{}
// }

// func (q *Queries) callStoredProcedure(ctx context.Context, procedureName string, params storedProcedureParams) *sql.Row {
// sqlStatement := fmt.Sprintf(`CALL %s(%s)`, procedureName, generateParamPlaceholders(len(params.InParams)))

// return q.db.QueryRowContext(
// ctx,
// sqlStatement,
// params.InParams...,
// )
// }
func (q *Queries) callStoredFunction(ctx context.Context, functionName string, params ...interface{}) *sql.Row {
// Assuming generateParamPlaceholders generates the placeholders for parameters
placeholders := generateParamPlaceholders(len(params))

// Use SELECT statement to call the stored function
sqlStatement := fmt.Sprintf(`SELECT * FROM %s(%s)`, functionName, placeholders)

return q.db.QueryRowContext(ctx, sqlStatement, params...)
}

func generateParamPlaceholders(count int) string {
Expand Down
36 changes: 31 additions & 5 deletions db/store/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,47 @@ func RunDBMigrations(db *sql.DB, migrationsURL string) {

}

// Get a list of SQL files in the migration directory
func getSQLFiles(migrationDir string) ([]string, error) {
var sqlFiles []string

err := filepath.WalkDir(migrationDir, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}

// Skip directories
if d.IsDir() {
return nil
}

// Check if the file has a .sql extension
if strings.HasSuffix(path, ".sql") {
sqlFiles = append(sqlFiles, path)
}

return nil
})

return sqlFiles, err
}

func runUnversionedMigrations(db *sql.DB, migrationDir string) error {
// Get a list of SQL files in the migration directory
files, err := filepath.Glob(filepath.Join(migrationDir, "*.sql"))

sqlFiles, err := getSQLFiles(migrationDir)

if err != nil {
return err
}

// Sort files to ensure execution order
// Note: You may need a custom sorting logic if file names include version numbers
// For simplicity, we assume alphabetical order here.
// Sorting ensures that the files are executed in the correct order.
sortFiles(files)
sortFiles(sqlFiles)

// Execute each SQL file
for _, file := range files {
for _, file := range sqlFiles {

contents, err := os.ReadFile(file)
if err != nil {
return err
Expand Down
48 changes: 48 additions & 0 deletions db/store/migrate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package db

import (
"reflect"
"strings"
"testing"
)

func TestGetSQLFiles(t *testing.T) {
testCases := []struct {
name string
migrationDir string
expectedFiles []string
expectedErrMsgSubstring string
}{
{
name: "Valid Directory",
migrationDir: "test_data",
expectedFiles: []string{"test_data/file_1.sql", "test_data/subdir/file_2.sql"},
},
{
name: "Non-Existent Directory",
migrationDir: "nonexistent_directory",
expectedErrMsgSubstring: "no such file or directory",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
files, err := getSQLFiles(tc.migrationDir)

if tc.expectedErrMsgSubstring != "" {
if err == nil || !strings.Contains(err.Error(), tc.expectedErrMsgSubstring) {
t.Fatalf("Expected error containing '%s', but got '%v'", tc.expectedErrMsgSubstring, err)
}
return
}

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if !reflect.DeepEqual(files, tc.expectedFiles) {
t.Fatalf("Expected files %v, but got %v", tc.expectedFiles, files)
}
})
}
}
Empty file added db/store/test_data/file_1.sql
Empty file.
Empty file.
Loading

0 comments on commit 1d0d3fb

Please sign in to comment.