Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions api/auth_middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"fmt"
"strings"

Expand Down Expand Up @@ -62,14 +63,14 @@ func (app *ApiServer) recoverAuthorityFromSignatureHeaders(c *fiber.Ctx) (int32,
}

// Checks if authedWallet is authorized to act on behalf of userId
func (app *ApiServer) isAuthorizedRequest(c *fiber.Ctx, userId int32, authedWallet string) bool {
func (app *ApiServer) isAuthorizedRequest(ctx context.Context, userId int32, authedWallet string) bool {
cacheKey := fmt.Sprintf("%d:%s", userId, authedWallet)
if hit, ok := app.resolveGrantCache.Get(cacheKey); ok {
return hit
}

var isAuthorized bool
err := app.pool.QueryRow(c.Context(), `
err := app.pool.QueryRow(ctx, `
SELECT EXISTS (
SELECT 1
FROM grants
Expand Down Expand Up @@ -99,37 +100,56 @@ func (app *ApiServer) getAuthedWallet(c *fiber.Ctx) string {
}

// Middleware to set authedUserId and authedWallet in context
// Returns a 403 if either
// - the user is not authorized to act on behalf of "myId"
// - the user is not authorized to act on behalf of "requestedWallet"
func (app *ApiServer) authMiddleware(c *fiber.Ctx) error {
userId, wallet := app.recoverAuthorityFromSignatureHeaders(c)
c.Locals("authedUserId", userId)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this fn cares at all about authedUserId by the looks of it - it really only cares that the wallet recovered has a grant or is the wallet of the user with userId = myId

maybe we can avoid fetching the user ID on every request then, and add a clause to the isAuthorizedRequest query to check if there's a row for userId <=> wallet.

then we can remove authedUserId from the context entirely.

wdyt?

(requireAuthMiddleware can do the query still, since it explicitly wants a user ID, but that's only used in one route)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like what you're thinking here.

I think the query can be done in a single union all. let me merge this first and open another PR

c.Locals("authedWallet", wallet)

myId := app.getMyId(c)
requestedWallet := c.Params("wallet")

// Not authorized to act on behalf of myId
if myId != 0 {
if userId != myId && !app.isAuthorizedRequest(c.Context(), myId, wallet) {
return fiber.NewError(
fiber.StatusForbidden,
fmt.Sprintf(
"You are not authorized to make this request authedUserId=%d authedWallet=%s myId=%d",
userId,
wallet,
myId,
),
)
}
}

// Not authorized to act on behalf of requestedWallet
if requestedWallet != "" && wallet != "" {
if !strings.EqualFold(requestedWallet, wallet) {
return fiber.NewError(
fiber.StatusForbidden,
fmt.Sprintf(
"You are not authorized to make this request authedUserId=%d authedWallet=%s requestedWallet=%s",
userId,
wallet,
requestedWallet,
),
)
}
}

return c.Next()
}

// Middleware that asserts the authedUserId is valid and is the same as the userId in
// the request path or a managed user of the authedUserId
// Middleware that asserts that there is an authedUserId
func (app *ApiServer) requireAuthMiddleware(c *fiber.Ctx) error {
authedUserId := app.getAuthedUserId(c)
authedWallet := app.getAuthedWallet(c)
myId := app.getMyId(c)
wallet := c.Params("wallet")
if authedUserId == 0 {
return fiber.NewError(fiber.StatusUnauthorized, "You must be logged in to make this request")
}

if myId != 0 && myId == authedUserId {
return c.Next()
}

if wallet != "" && strings.EqualFold(wallet, authedWallet) {
return c.Next()
}

if app.isAuthorizedRequest(c, myId, authedWallet) {
return c.Next()
}

msg := fmt.Sprintf("You are not authorized to make this request authedUserId=%d authedWallet=%s myId=%d wallet=%s", authedUserId, authedWallet, myId, wallet)
return fiber.NewError(fiber.StatusForbidden, msg)
return c.Next()
}
79 changes: 54 additions & 25 deletions api/auth_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,52 +29,81 @@ func TestRecoverAuthorityFromSignatureHeaders(t *testing.T) {
assert.Equal(t, "0x7d273271690538cf855e5b3002a0dd8c154bb060", wallet)
}

func TestRequireAuthMiddleware(t *testing.T) {
// Create a dummy endpoint to test the requireAuthMiddleware
func TestAuthorized(t *testing.T) {
// Create a dummy endpoint to test the authMiddleware
testApp := fiber.New()
testApp.Get("/", app.resolveMyIdMiddleware, app.authMiddleware, app.requireAuthMiddleware, func(c *fiber.Ctx) error {
testApp.Get("/", app.resolveMyIdMiddleware, app.authMiddleware, func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
testApp.Get("/account/:wallet", app.resolveMyIdMiddleware, app.authMiddleware, func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

// Unauthorized when no auth headers
req1 := httptest.NewRequest("GET", "/", nil)
res, err := testApp.Test(req1, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode)

// Forbidden when not authorized
req2 := httptest.NewRequest("GET", "/?user_id=1", nil)
req := httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
// wallet: 0x681c616ae836ceca1effe00bd07f2fdbf9a082bc
req2.Header.Set("Encoded-Data-Message", "signature:1745543704165")
req2.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
res, err = testApp.Test(req2, -1)
req.Header.Set("Encoded-Data-Message", "signature:1745543704165")
req.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusForbidden, res.StatusCode)

// Forbidden when grant is revoked
req3 := httptest.NewRequest("GET", "/?user_id=1", nil)
req = httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
// wallet: 0xc451c1f8943b575158310552b41230c61844a1c1
req3.Header.Set("Encoded-Data-Message", "signature:1745542789211")
req3.Header.Set("Encoded-Data-Signature", "0xffd5f92c0d253c7222cd407cf3398fac664530ef968bd4435ea698ba1daee1d73353330848b65d212eeeaae9f41e177e49078c4efa1131e5e517090626f6dd961c")
res, err = testApp.Test(req3, -1)
req.Header.Set("Encoded-Data-Message", "signature:1745542789211")
req.Header.Set("Encoded-Data-Signature", "0xffd5f92c0d253c7222cd407cf3398fac664530ef968bd4435ea698ba1daee1d73353330848b65d212eeeaae9f41e177e49078c4efa1131e5e517090626f6dd961c")
res, err = testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusForbidden, res.StatusCode)

// Authorized when grant is approved
req4 := httptest.NewRequest("GET", "/?user_id=1", nil)
req = httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
// wallet: 0x5f1a372b28956c8363f8bc3a231a6e9e1186ead8
req4.Header.Set("Encoded-Data-Message", "signature:1745544459796")
req4.Header.Set("Encoded-Data-Signature", "0x1c9cb405d8437d28ff5596918551f7a45f981e81618d65ee10892313292a8c7a325af002231d115b28ca2d244b082abe1bde4a7d9610f8140d3738a9be5c4fd91b")
res, err = testApp.Test(req4, -1)
req.Header.Set("Encoded-Data-Message", "signature:1745544459796")
req.Header.Set("Encoded-Data-Signature", "0x1c9cb405d8437d28ff5596918551f7a45f981e81618d65ee10892313292a8c7a325af002231d115b28ca2d244b082abe1bde4a7d9610f8140d3738a9be5c4fd91b")
res, err = testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, res.StatusCode)

// Authorized when own user
req5 := httptest.NewRequest("GET", "/?user_id=1", nil)
req = httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
// wallet: 0x7d273271690538cf855e5b3002a0dd8c154bb060
req5.Header.Set("Encoded-Data-Message", "signature:1744763856446")
req5.Header.Set("Encoded-Data-Signature", "0xbb202be3a7f3a0aa22c1458ef6a3f2f8360fb86791c7b137e8562df0707825c11fa1db01096efd2abc5e6613c4d1e8d4ae1e2b993abdd555fe270c1b17bff0d21c")
res, err = testApp.Test(req5, -1)
req.Header.Set("Encoded-Data-Message", "signature:1744763856446")
req.Header.Set("Encoded-Data-Signature", "0xbb202be3a7f3a0aa22c1458ef6a3f2f8360fb86791c7b137e8562df0707825c11fa1db01096efd2abc5e6613c4d1e8d4ae1e2b993abdd555fe270c1b17bff0d21c")
res, err = testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, res.StatusCode)

// Forbidden when not authorized to act on behalf of requested wallet
req = httptest.NewRequest("GET", "/account/0x111c616ae836ceca1effe00bd07f2fdbf9a082bc", nil)
// wallet: 0x681c616ae836ceca1effe00bd07f2fdbf9a082bc
req.Header.Set("Encoded-Data-Message", "signature:1745543704165")
req.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
res, err = testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusForbidden, res.StatusCode)

// Authorized when requesting wallet matches authed wallet
req = httptest.NewRequest("GET", "/account/0x681c616ae836ceca1effe00bd07f2fdbf9a082bc", nil)
// wallet: 0x681c616ae836ceca1effe00bd07f2fdbf9a082bc
req.Header.Set("Encoded-Data-Message", "signature:1745543704165")
req.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
res, err = testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, res.StatusCode)
}

func TestRequireAuthMiddleware(t *testing.T) {
// Create a dummy endpoint to test the requireAuthMiddleware
testApp := fiber.New()
testApp.Get("/", app.resolveMyIdMiddleware, app.authMiddleware, app.requireAuthMiddleware, func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

// Unauthorized when no auth headers
req1 := httptest.NewRequest("GET", "/", nil)
res, err := testApp.Test(req1, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode)
}
40 changes: 37 additions & 3 deletions api/dbv1/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,33 @@ type Access struct {
Download bool `json:"download"`
}

func (q *Queries) GetTrackAccess(ctx context.Context, myId int32, conditions *AccessGate, track *GetTracksRow, user *FullUser) bool {
func (q *Queries) GetTrackAccess(
ctx context.Context,
myId int32,
conditions *AccessGate,
track *GetTracksRow,
user *FullUser,
) bool {
// No track? no access
if track == nil || user == nil {
return false
}

// no conditions means open access
// No conditions means open access
if conditions == nil {
return true
}

// No myId? no access. we need to know who you are if there are conditions.
if myId == 0 {
return false
}

// You always have access to your own content
if myId == user.UserID {
return true
}

switch {
case conditions.FollowUserID != nil:
return user.DoesCurrentUserFollow
Expand Down Expand Up @@ -114,11 +131,28 @@ func (q *Queries) GetTrackAccess(ctx context.Context, myId int32, conditions *Ac
return false
}

func (q *Queries) GetPlaylistAccess(ctx context.Context, myId int32, conditions *AccessGate, playlist *GetPlaylistsRow, user *FullUser) bool {
func (q *Queries) GetPlaylistAccess(
ctx context.Context,
myId int32,
conditions *AccessGate,
playlist *GetPlaylistsRow,
user *FullUser,
) bool {
// No playlist? no access.
if playlist == nil || user == nil {
return false
}

// no conditions means open access
if conditions == nil {
return true
}

// I always have access to my own content
if myId != 0 && myId == user.UserID {
return true
}

switch {
case conditions.FollowUserID != nil:
return user.DoesCurrentUserFollow
Expand Down
19 changes: 14 additions & 5 deletions api/dbv1/full_playlists.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

type FullPlaylistsParams struct {
GetPlaylistsParams
}

type FullPlaylist struct {
GetPlaylistsRow

Expand All @@ -31,8 +35,8 @@ type FullPlaylistContentsItem struct {
MetadataTime int64 `json:"metadata_timestamp"`
}

func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams) (map[int32]FullPlaylist, error) {
rawPlaylists, err := q.GetPlaylists(ctx, arg)
func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg FullPlaylistsParams) (map[int32]FullPlaylist, error) {
rawPlaylists, err := q.GetPlaylists(ctx, arg.GetPlaylistsParams)
if err != nil {
return nil, err
}
Expand All @@ -51,7 +55,7 @@ func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams
loaded, err := q.Parallel(ctx, ParallelParams{
UserIds: userIds,
TrackIds: trackIds,
MyID: arg.MyID,
MyID: arg.MyID.(int32),
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -88,7 +92,12 @@ func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams
}

// For playlists, download access is the same as stream access
streamAccess := q.GetPlaylistAccess(ctx, arg.MyID.(int32), playlist.StreamConditions, &playlist, &user)
streamAccess := q.GetPlaylistAccess(
ctx,
arg.MyID.(int32),
playlist.StreamConditions,
&playlist,
&user)
downloadAccess := streamAccess

var playlistType string
Expand Down Expand Up @@ -120,7 +129,7 @@ func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams
return playlistMap, nil
}

func (q *Queries) FullPlaylists(ctx context.Context, arg GetPlaylistsParams) ([]FullPlaylist, error) {
func (q *Queries) FullPlaylists(ctx context.Context, arg FullPlaylistsParams) ([]FullPlaylist, error) {
playlistMap, err := q.FullPlaylistsKeyed(ctx, arg)
if err != nil {
return nil, err
Expand Down
Loading
Loading