Skip to content

Commit

Permalink
add roles to user and token
Browse files Browse the repository at this point in the history
  • Loading branch information
aradwann committed Feb 3, 2024
1 parent 0011075 commit bd65200
Show file tree
Hide file tree
Showing 18 changed files with 50 additions and 20 deletions.
1 change: 1 addition & 0 deletions db/migrations/000005_add_role.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "users" DROP COLUMN "role";
1 change: 1 addition & 0 deletions db/migrations/000005_add_role.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "users" ADD COLUMN "role" varchar NOT NULL DEFAULT 'generator';
3 changes: 2 additions & 1 deletion db/migrations/procs/user/create_user.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ RETURNS TABLE (
email VARCHAR,
password_changed_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE,
is_email_verified BOOLEAN
is_email_verified BOOLEAN,
role VARCHAR
) AS $$
BEGIN
RETURN QUERY
Expand Down
6 changes: 4 additions & 2 deletions db/migrations/procs/user/create_verify_email.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ RETURNS TABLE (
secret_code VARCHAR,
is_used BOOLEAN,
created_at TIMESTAMP WITH TIME ZONE,
expires_at_at TIMESTAMP WITH TIME ZONE
expires_at_at TIMESTAMP WITH TIME ZONE,
role VARCHAR
) AS $$
BEGIN
RETURN QUERY
Expand All @@ -24,7 +25,8 @@ BEGIN
verify_emails.secret_code,
verify_emails.is_used,
verify_emails.created_at,
verify_emails.expired_at;
verify_emails.expired_at,
verify_emails.role VARCHAR;

END;
$$ LANGUAGE plpgsql;
6 changes: 4 additions & 2 deletions db/migrations/procs/user/get_user.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ RETURNS TABLE (
email VARCHAR,
password_changed_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE,
is_email_verified BOOLEAN
is_email_verified BOOLEAN,
role VARCHAR
) AS $$
BEGIN
RETURN QUERY
Expand All @@ -20,7 +21,8 @@ BEGIN
users.email,
users.password_changed_at,
users.created_at,
users.is_email_verified
users.is_email_verified,
users.role
FROM users
WHERE users.username = p_username
LIMIT 1;
Expand Down
3 changes: 2 additions & 1 deletion db/migrations/procs/user/update_user.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ RETURNS TABLE (
email VARCHAR,
password_changed_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE,
is_email_verfied BOOLEAN
is_email_verfied BOOLEAN,
role VARCHAR
) AS $$
BEGIN
RETURN QUERY
Expand Down
6 changes: 4 additions & 2 deletions db/migrations/procs/user/update_verify_email.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ RETURNS TABLE (
secret_code VARCHAR,
is_used BOOLEAN,
created_at TIMESTAMP WITH TIME ZONE,
expires_at_at TIMESTAMP WITH TIME ZONE
expires_at_at TIMESTAMP WITH TIME ZONE,
role VARCHAR
) AS $$
BEGIN
RETURN QUERY
Expand All @@ -28,7 +29,8 @@ BEGIN
verify_emails.secret_code,
verify_emails.is_used,
verify_emails.created_at,
verify_emails.expired_at;
verify_emails.expired_at,
verify_emails.role;

END;
$$ LANGUAGE plpgsql;
4 changes: 4 additions & 0 deletions db/store/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"database/sql"
"errors"
"fmt"

"github.com/jackc/pgx/v5/pgconn"
)
Expand All @@ -20,6 +21,9 @@ var ErrUniqueViolation = &pgconn.PgError{

func ErrorCode(err error) string {
var pgErr *pgconn.PgError
// TODO: handle err conversion
fmt.Printf("errrrrr %#v", err)

if errors.As(err, &pgErr) {
return pgErr.Code
}
Expand Down
1 change: 1 addition & 0 deletions db/store/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type User struct {
PasswordChangedAt time.Time `json:"password_changed_at"`
CreatedAt time.Time `json:"created_at"`
IsEmailVerified bool `json:"is_email_verified"`
Role string `json:"role"`
}

type VerifyEmail struct {
Expand Down
1 change: 1 addition & 0 deletions db/store/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func scanUserFromRow(row *sql.Row, user *User) error {
&user.PasswordChangedAt,
&user.CreatedAt,
&user.IsEmailVerified,
&user.Role,
)

// Check for errors after scanning
Expand Down
4 changes: 2 additions & 2 deletions gapi/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ func newTestServer(t *testing.T, store db.Store, taskDistributor worker.TaskDist
return server
}

func newNewContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username string, duration time.Duration) context.Context {
func newNewContextWithBearerToken(t *testing.T, tokenMaker token.Maker, username, role string, duration time.Duration) context.Context {
ctx := context.Background()
accessToken, payload, err := tokenMaker.CreateToken(username, duration)
accessToken, payload, err := tokenMaker.CreateToken(username, role, duration)
require.NoError(t, err)
require.NotNil(t, payload)

Expand Down
2 changes: 2 additions & 0 deletions gapi/rpc_login_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) (

accessToken, accessPayload, err := server.tokenMaker.CreateToken(
user.Username,
user.Role,
server.config.AccessTokenDuration,
)
if err != nil {
Expand All @@ -46,6 +47,7 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) (

refreshToken, refreshPayload, err := server.tokenMaker.CreateToken(
user.Username,
user.Role,
server.config.RefreshTokenDuration,
)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions gapi/rpc_update_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestUpdateUserAPI(t *testing.T) {

},
buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context {
return newNewContextWithBearerToken(t, tokenMaker, user.Username, time.Minute)
return newNewContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute)
},
checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) {
require.NoError(t, err)
Expand Down Expand Up @@ -94,7 +94,7 @@ func TestUpdateUserAPI(t *testing.T) {

},
buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context {
return newNewContextWithBearerToken(t, tokenMaker, user.Username, time.Minute)
return newNewContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute)
},
checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) {
require.Error(t, err)
Expand All @@ -119,7 +119,7 @@ func TestUpdateUserAPI(t *testing.T) {

},
buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context {
return newNewContextWithBearerToken(t, tokenMaker, user.Username, -time.Minute)
return newNewContextWithBearerToken(t, tokenMaker, user.Username, user.Role, -time.Minute)
},
checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) {
require.Error(t, err)
Expand Down Expand Up @@ -168,7 +168,7 @@ func TestUpdateUserAPI(t *testing.T) {

},
buildContext: func(t *testing.T, tokenMaker token.Maker) context.Context {
return newNewContextWithBearerToken(t, tokenMaker, user.Username, time.Minute)
return newNewContextWithBearerToken(t, tokenMaker, user.Username, user.Role, time.Minute)
},
checkResponse: func(t *testing.T, res *pb.UpdateUserResponse, err error) {
require.Error(t, err)
Expand Down
2 changes: 1 addition & 1 deletion token/maker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import "time"
// Maker is an interface for managing token
type Maker interface {
// CreateToken creates a new token for a specific username and duration
CreateToken(username string, duration time.Duration) (string, *Payload, error)
CreateToken(username, role string, duration time.Duration) (string, *Payload, error)

// VerifyToken checks if the token is valid or not
VerifyToken(token string) (*Payload, error)
Expand Down
4 changes: 2 additions & 2 deletions token/paseto_maker.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func NewPASETOMaker(symmetricKey string) (Maker, error) {
}

// CreateToken creates a new token for a specific username and duration
func (maker *PASETOMaker) CreateToken(username string, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(username, duration)
func (maker *PASETOMaker) CreateToken(username, role string, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(username, role, duration)
if err != nil {
return "", payload, err
}
Expand Down
8 changes: 6 additions & 2 deletions token/paseto_maker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ func TestPASETOMaker(t *testing.T) {

issuedAt := time.Now()
expiredAt := issuedAt.Add(duration)
role := util.GeneratorRole

token, payload, err := maker.CreateToken(username, duration)
token, payload, err := maker.CreateToken(username, role, duration)
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotEmpty(t, payload)
Expand All @@ -30,6 +31,7 @@ func TestPASETOMaker(t *testing.T) {

require.NotZero(t, payload.ID)
require.Equal(t, username, payload.Username)
require.Equal(t, role, payload.Role)
require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second)
require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second)

Expand All @@ -39,7 +41,9 @@ func TestExpiredPASETOToken(t *testing.T) {
maker, err := NewPASETOMaker(util.RandomString(32))
require.NoError(t, err)

token, payload, err := maker.CreateToken(util.RandomOwner(), -time.Minute)
role := util.GeneratorRole

token, payload, err := maker.CreateToken(util.RandomOwner(), role, -time.Minute)
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotEmpty(t, payload)
Expand Down
4 changes: 3 additions & 1 deletion token/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ var ErrExpiredToken = errors.New("token has expired")
type Payload struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
Role string `json:"role"`
IssuedAt time.Time `json:"issued_at"`
ExpiredAt time.Time `json:"expired_at"`
}

// NewPayload create a new token payload with a specfic username and duration
func NewPayload(username string, duration time.Duration) (*Payload, error) {
func NewPayload(username, role string, duration time.Duration) (*Payload, error) {
tokenID, err := uuid.NewRandom()
if err != nil {
return nil, err
Expand All @@ -27,6 +28,7 @@ func NewPayload(username string, duration time.Duration) (*Payload, error) {
payload := &Payload{
ID: tokenID,
Username: username,
Role: role,
IssuedAt: time.Now(),
ExpiredAt: time.Now().Add(duration),
}
Expand Down
6 changes: 6 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package util

const (
GeneratorRole = "generator" //normal user
AdminRole = "admin" // admin
)

0 comments on commit bd65200

Please sign in to comment.