Skip to content

Commit

Permalink
address PR review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lyoshenka committed May 12, 2020
1 parent 6ebb456 commit b0348fc
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 94 deletions.
10 changes: 6 additions & 4 deletions app/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import (

var logger = monitor.NewModuleLogger("auth")

const ContextKey = "user"
type ctxKey int

const contextKey ctxKey = iota

var ErrNoAuthInfo = errors.Base("unauthorized")

Expand All @@ -27,9 +29,9 @@ type result struct {
}

func FromRequest(r *http.Request) (*models.User, error) {
v := r.Context().Value(ContextKey)
v := r.Context().Value(contextKey)
if v == nil {
panic("auth.Middleware is required")
return nil, errors.Err("auth.Middleware is required")
}
res := v.(result)
return res.user, res.err
Expand Down Expand Up @@ -61,7 +63,7 @@ func Middleware(provider Provider) mux.MiddlewareFunc {
} else {
err = errors.Err(ErrNoAuthInfo)
}
next.ServeHTTP(w, r.Clone(context.WithValue(r.Context(), ContextKey, result{user, err})))
next.ServeHTTP(w, r.Clone(context.WithValue(r.Context(), contextKey, result{user, err})))
})
}
}
Expand Down
7 changes: 5 additions & 2 deletions app/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestMiddleware_Error(t *testing.T) {

func TestFromRequestSuccess(t *testing.T) {
expected := result{nil, errors.Base("a test")}
ctx := context.WithValue(context.Background(), ContextKey, expected)
ctx := context.WithValue(context.Background(), contextKey, expected)

r, err := http.NewRequestWithContext(ctx, http.MethodPost, "", &bytes.Buffer{})
require.NoError(t, err)
Expand All @@ -119,7 +119,10 @@ func TestFromRequestSuccess(t *testing.T) {
func TestFromRequestFail(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "", &bytes.Buffer{})
require.NoError(t, err)
assert.Panics(t, func() { FromRequest(r) })
user, err := FromRequest(r)
assert.Nil(t, user)
assert.Error(t, err)
assert.Equal(t, "auth.Middleware is required", err.Error())
}

func authChecker(w http.ResponseWriter, r *http.Request) {
Expand Down
27 changes: 23 additions & 4 deletions app/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
"github.com/lbryio/lbrytv/internal/errors"
"github.com/lbryio/lbrytv/internal/monitor"
"github.com/lbryio/lbrytv/internal/responses"
"github.com/lbryio/lbrytv/models"

"github.com/volatiletech/sqlboiler/boil"
"github.com/ybbus/jsonrpc"
)

Expand Down Expand Up @@ -56,8 +56,12 @@ func Handle(w http.ResponseWriter, r *http.Request) {
logger.Log().Tracef("call to method %s", req.Method)

user, err := auth.FromRequest(r)
if query.MethodRequiresWallet(req.Method) && !rpcerrors.EnsureAuthenticated(w, user, err) {
return
if query.MethodRequiresWallet(req.Method) {
authErr := EnsureAuthenticated(user, err)
if authErr != nil {
w.Write(rpcerrors.ErrorToJSON(authErr))
return
}
}

var userID int
Expand All @@ -77,7 +81,6 @@ func Handle(w http.ResponseWriter, r *http.Request) {
}
c := query.NewCaller(sdkAddress, userID)
c.Cache = qCache
c.DB = boil.GetDB()
w.Write(c.Call(&req))
}

Expand All @@ -89,3 +92,19 @@ func HandleCORS(w http.ResponseWriter, r *http.Request) {
hs.Set("Access-Control-Allow-Headers", wallet.TokenHeader+", Origin, X-Requested-With, Content-Type, Accept")
w.WriteHeader(http.StatusOK)
}

func EnsureAuthenticated(user *models.User, err error) error {
if err == nil && user != nil {
return nil
}

if errors.Is(err, auth.ErrNoAuthInfo) {
return rpcerrors.NewAuthRequiredError(errors.Err(responses.AuthRequiredErrorMessage))
} else if err != nil {
return rpcerrors.NewForbiddenError(err)
} else if user == nil {
return rpcerrors.NewForbiddenError(errors.Err("must authenticate"))
}

return errors.Err("unknown auth error")
}
7 changes: 4 additions & 3 deletions app/publish/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ func TestHandler_NoAuthMiddleware(t *testing.T) {
handler := &Handler{UploadPath: os.TempDir()}

rr := httptest.NewRecorder()
assert.Panics(t, func() {
handler.Handle(rr, r)
})
handler.Handle(rr, r)
respBody, err := ioutil.ReadAll(rr.Result().Body)
require.NoError(t, err)
assert.Equal(t, "auth.Middleware is required", test.StrToRes(t, string(respBody)).Error.Message)
}

func TestHandler_NoSDKAddress(t *testing.T) {
Expand Down
8 changes: 3 additions & 5 deletions app/publish/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path"

"github.com/lbryio/lbrytv/app/auth"
"github.com/lbryio/lbrytv/app/proxy"
"github.com/lbryio/lbrytv/app/query"
"github.com/lbryio/lbrytv/app/query/cache"
"github.com/lbryio/lbrytv/app/rpcerrors"
Expand All @@ -17,7 +18,6 @@ import (

"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/volatiletech/sqlboiler/boil"
)

var logger = monitor.NewModuleLogger("publish")
Expand All @@ -40,10 +40,9 @@ type Handler struct {
// It should be wrapped with users.Authenticator.Wrap before it can be used
// in a mux.Router.
func (h Handler) Handle(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

user, err := auth.FromRequest(r)
if !rpcerrors.EnsureAuthenticated(w, user, err) {
if authErr := proxy.EnsureAuthenticated(user, err); authErr != nil {
w.Write(rpcerrors.ErrorToJSON(authErr))
return
}
if auth.SDKAddress(user) == "" {
Expand Down Expand Up @@ -84,7 +83,6 @@ func (h Handler) Handle(w http.ResponseWriter, r *http.Request) {
func publish(sdkAddress, filename string, userID int, qCache cache.QueryCache, rawQuery []byte) []byte {
c := query.NewCaller(sdkAddress, userID)
c.Cache = qCache
c.DB = boil.GetDB()
c.Preprocessor = func(q *query.Query) {
params := q.ParamsAsMap()
params[fileNameParam] = filename
Expand Down
27 changes: 3 additions & 24 deletions app/query/caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/lbryio/lbrytv/app/rpcerrors"
"github.com/lbryio/lbrytv/app/sdkrouter"
"github.com/lbryio/lbrytv/app/wallet"
"github.com/lbryio/lbrytv/app/wallet/accesstracker"
"github.com/lbryio/lbrytv/config"
"github.com/lbryio/lbrytv/internal/errors"
"github.com/lbryio/lbrytv/internal/lbrynet"
Expand All @@ -20,7 +19,6 @@ import (

"github.com/davecgh/go-spew/spew"
"github.com/sirupsen/logrus"
"github.com/volatiletech/sqlboiler/boil"
"github.com/ybbus/jsonrpc"
)

Expand All @@ -36,8 +34,6 @@ type Caller struct {
Preprocessor func(q *Query)
// Cache stores cachable queries to improve performance
Cache cache.QueryCache
// DB is used to track when a wallet is used so that unused wallets can be unloaded
DB boil.Executor

client jsonrpc.RPCClient
userID int
Expand Down Expand Up @@ -69,7 +65,7 @@ func (c *Caller) CallRaw(rawQuery []byte) []byte {
var req jsonrpc.RPCRequest
err := json.Unmarshal(rawQuery, &req)
if err != nil {
return errorToJSON(rpcerrors.NewJSONParseError(err))
return rpcerrors.ErrorToJSON(rpcerrors.NewJSONParseError(err))
}
return c.Call(&req)
}
Expand All @@ -81,23 +77,14 @@ func (c *Caller) Call(req *jsonrpc.RPCRequest) []byte {
if err != nil {
monitor.ErrorToSentry(err, map[string]string{"request": spew.Sdump(req), "response": fmt.Sprintf("%v", r)})
logger.Log().Errorf("error calling lbrynet: %v, request: %s", err, spew.Sdump(req))
return errorToJSON(err)
return rpcerrors.ErrorToJSON(err)
}

serialized, err := json.MarshalIndent(r, "", " ")
if err != nil {
monitor.ErrorToSentry(err)
logger.Log().Errorf("error marshaling response: %v", err)
return errorToJSON(rpcerrors.NewInternalError(err))
}

if c.DB != nil {
// TODO: run this in a goroutine so it doesn't block the response?
err = accesstracker.Touch(c.DB, c.userID)
if err != nil {
monitor.ErrorToSentry(err)
logger.Log().Errorf("error touching wallet access time: %v", err)
}
return rpcerrors.ErrorToJSON(rpcerrors.NewInternalError(err))
}

return serialized
Expand Down Expand Up @@ -264,14 +251,6 @@ func (c *Caller) callQueryWithRetry(q *Query) (*jsonrpc.RPCResponse, error) {
return r, err
}

func errorToJSON(err error) []byte {
var rpcErr rpcerrors.RPCError
if errors.As(err, &rpcErr) {
return rpcErr.JSON()
}
return rpcerrors.NewInternalError(err).JSON()
}

func isErrWalletNotLoaded(r *jsonrpc.RPCResponse) bool {
return r.Error != nil && errors.Is(lbrynet.NewWalletError(0, errors.Err(r.Error.Message)), lbrynet.ErrWalletNotLoaded)
}
Expand Down
22 changes: 5 additions & 17 deletions app/rpcerrors/rpcerrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@ package rpcerrors

import (
"encoding/json"
"net/http"

"github.com/lbryio/lbrytv/app/auth"
"github.com/lbryio/lbrytv/internal/errors"
"github.com/lbryio/lbrytv/internal/monitor"
"github.com/lbryio/lbrytv/internal/responses"
"github.com/lbryio/lbrytv/models"

"github.com/ybbus/jsonrpc"
)

Expand Down Expand Up @@ -68,17 +63,10 @@ func isJSONParseError(err error) bool {
return err != nil && errors.As(err, &e) && e.code == rpcErrorCodeJSONParse
}

func EnsureAuthenticated(w http.ResponseWriter, user *models.User, err error) bool {
if err == nil && user != nil {
return true
}

if errors.Is(err, auth.ErrNoAuthInfo) {
w.Write(NewAuthRequiredError(errors.Err(responses.AuthRequiredErrorMessage)).JSON())
} else if err != nil {
w.Write(NewForbiddenError(err).JSON())
} else if user == nil {
w.Write(NewForbiddenError(errors.Err("must authenticate")).JSON())
func ErrorToJSON(err error) []byte {
var rpcErr RPCError
if errors.As(err, &rpcErr) {
return rpcErr.JSON()
}
return false
return NewInternalError(err).JSON()
}
8 changes: 5 additions & 3 deletions app/sdkrouter/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"github.com/gorilla/mux"
)

const ContextKey = "sdkrouter"
type ctxKey int

const contextKey ctxKey = iota

func FromRequest(r *http.Request) *Router {
v := r.Context().Value(ContextKey)
v := r.Context().Value(contextKey)
if v == nil {
panic("sdkrouter.Middleware is required")
}
Expand All @@ -19,7 +21,7 @@ func FromRequest(r *http.Request) *Router {

func AddToRequest(rt *Router, fn http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fn(w, r.Clone(context.WithValue(r.Context(), ContextKey, rt)))
fn(w, r.Clone(context.WithValue(r.Context(), contextKey, rt)))
}
}

Expand Down
33 changes: 33 additions & 0 deletions app/wallet/tracker/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package tracker

import (
"net/http"

"github.com/gorilla/mux"
"github.com/lbryio/lbrytv/app/auth"
"github.com/lbryio/lbrytv/internal/monitor"
"github.com/volatiletech/sqlboiler/boil"
)

func Middleware(db boil.Executor) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)

user, err := auth.FromRequest(r)
if err != nil {
logger.Log().Error(err)
return
}
if user == nil {
return
}

err = Touch(db, user.ID)
if err != nil {
monitor.ErrorToSentry(err)
logger.Log().Errorf("error touching wallet access time: %v", err)
}
})
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package accesstracker
package tracker

import (
"fmt"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package accesstracker
package tracker

import (
"math/rand"
Expand Down Expand Up @@ -46,7 +46,8 @@ func TestTouch(t *testing.T) {
assert.False(t, u.WalletAccessedAt.Valid)

// set access time back in the past
u.WalletAccessedAt = null.TimeFrom(TimeNow().Add(-1 * time.Hour))
oneHourAgo := TimeNow().Add(-1 * time.Hour)
u.WalletAccessedAt = null.TimeFrom(oneHourAgo)
_, err = u.UpdateG(boil.Infer())
require.NoError(t, err)

Expand All @@ -62,7 +63,7 @@ func TestTouch(t *testing.T) {
err = u.ReloadG()
require.NoError(t, err)
assert.True(t, u.WalletAccessedAt.Valid)
assert.True(t, TimeNow().Add(-1*time.Minute).Before(u.WalletAccessedAt.Time))
assert.True(t, oneHourAgo.Before(u.WalletAccessedAt.Time))
}

func TestUnload(t *testing.T) {
Expand Down

0 comments on commit b0348fc

Please sign in to comment.