Skip to content

Commit

Permalink
fix: fix bugs in authentication and CORS(#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
vishnu-deepsource committed Aug 1, 2023
1 parent 84829ca commit 7e12c36
Show file tree
Hide file tree
Showing 19 changed files with 171 additions and 161 deletions.
35 changes: 13 additions & 22 deletions artifact/facade.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"net/http"

"github.com/deepsourcecorp/runner/middleware"

"github.com/labstack/echo/v4"
)

Expand All @@ -25,12 +27,12 @@ var (

type Facade struct {
ArtifactHandler *Handler
CORSMiddleware echo.MiddlewareFunc
allowedOrigin string
}

type Opts struct {
AllowedOrigin string // For CORS
Bucket string
AllowedOrigin string
Storage StorageClient
}

Expand All @@ -39,30 +41,19 @@ func New(ctx context.Context, opts *Opts) (*Facade, error) {
return nil, ErrMissingOpts
}

cors := corsMiddleware(opts.AllowedOrigin)

return &Facade{
allowedOrigin: opts.AllowedOrigin,
ArtifactHandler: NewHandler(opts.Storage, opts.Bucket),
CORSMiddleware: cors,
}, nil
}

func (f *Facade) AddRoutes(router Router, middleware []echo.MiddlewareFunc) Router {
middleware = append([]echo.MiddlewareFunc{f.CORSMiddleware}, middleware...)
router.AddRoute(http.MethodOptions, "apps/:app_id/artifacts/*", f.ArtifactHandler.HandleOptions, middleware...)
router.AddRoute(http.MethodPost, "apps/:app_id/artifacts/analysis", f.ArtifactHandler.HandleAnalysis, middleware...)
router.AddRoute(http.MethodPost, "apps/:app_id/artifacts/autofix", f.ArtifactHandler.HandleAutofix, middleware...)
return router
}
func (f *Facade) AddRoutes(router Router, m []echo.MiddlewareFunc) Router {
cors := middleware.CorsMiddleware(f.allowedOrigin)
router.AddRoute(http.MethodOptions, "apps/:app_id/artifacts", func(c echo.Context) error { return c.NoContent(http.StatusOK) }, cors)

func corsMiddleware(origin string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Response().Header().Set(HeaderAccessControlAllowOrigin, origin)
c.Response().Header().Set(HeaderAccessControlAllowMethods, "GET, POST, OPTIONS")
c.Response().Header().Set(HeaderAccessControlAllowHeaders, "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, Cache-Control, Pragma")
c.Response().Header().Set(HeaderAccessControlAllowCredentials, "true")
return next(c)
}
}
m = append([]echo.MiddlewareFunc{cors}, m...)
router.AddRoute(http.MethodPost, "apps/:app_id/artifacts/analysis", f.ArtifactHandler.HandleAnalysis, m...)
router.AddRoute(http.MethodPost, "apps/:app_id/artifacts/autofix", f.ArtifactHandler.HandleAutofix, m...)

return router
}
5 changes: 0 additions & 5 deletions artifact/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package artifact
import (
"fmt"
"io"
"net/http"

"github.com/labstack/echo/v4"
"golang.org/x/exp/slog"
Expand Down Expand Up @@ -120,7 +119,3 @@ func (h *Handler) HandleAutofix(c echo.Context) error {
}
return c.JSON(200, autofixArtifactsResponse)
}

func (*Handler) HandleOptions(c echo.Context) error {
return c.NoContent(http.StatusOK)
}
15 changes: 0 additions & 15 deletions artifact/middleware.go

This file was deleted.

17 changes: 11 additions & 6 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/deepsourcecorp/runner/auth/store"
"github.com/deepsourcecorp/runner/auth/token"
"github.com/deepsourcecorp/runner/httperror"
"github.com/deepsourcecorp/runner/middleware"
"github.com/labstack/echo/v4"
"golang.org/x/exp/slog"
)
Expand All @@ -27,14 +28,16 @@ type Facade struct {
SAMLHandlers *saml.Handler
TokenMiddleware echo.MiddlewareFunc
SessionMiddleware echo.MiddlewareFunc
allowedOrigin string
}

type Opts struct {
Runner *model.Runner
DeepSource *model.DeepSource
Apps map[string]*oauth.App
SAML *saml.Opts
Store store.Store
Runner *model.Runner
DeepSource *model.DeepSource
Apps map[string]*oauth.App
SAML *saml.Opts
Store store.Store
AllowedOrigin string // For CORS
}

func New(ctx context.Context, opts *Opts, client *http.Client) (*Facade, error) {
Expand Down Expand Up @@ -80,11 +83,13 @@ func New(ctx context.Context, opts *Opts, client *http.Client) (*Facade, error)
TokenMiddleware: tokenMiddleware,
SessionMiddleware: sessionMiddleware,
SAMLHandlers: samlHandlers,
allowedOrigin: opts.AllowedOrigin,
}, nil
}

func (f *Facade) AddRoutes(r Router) Router {
r.AddRoute(http.MethodPost, "/refresh", f.TokenHandlers.HandleRefresh)
cors := middleware.CorsMiddleware(f.allowedOrigin)
r.AddRoute(http.MethodPost, "/refresh", f.TokenHandlers.HandleRefresh, cors)
r.AddRoute(http.MethodPost, "/logout", f.TokenHandlers.HandleLogout)

r.AddRoute(http.MethodGet, "/apps/:app_id/auth/authorize", f.OAuthHandlers.HandleAuthorize)
Expand Down
14 changes: 5 additions & 9 deletions auth/oauth/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ import (
"golang.org/x/oauth2"
)

const (
ExpiryAccessToken = 15 * time.Minute
)

type Handler struct {
runner *model.Runner
deepsource *model.DeepSource
Expand Down Expand Up @@ -97,7 +93,7 @@ func (h *Handler) HandleCallback(c echo.Context) error {
return c.JSON(500, err.Error())
}

accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser, token.ScopeCodeRead}, user)
accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser, token.ScopeCodeRead}, user, token.ExpiryAccessToken)
if err != nil {
return c.JSON(500, err.Error())
}
Expand All @@ -111,7 +107,7 @@ func (h *Handler) HandleCallback(c echo.Context) error {
HttpOnly: true,
})

refreshToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeRefresh}, user)
refreshToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeRefresh}, user, token.ExpiryRefreshToken)
if err != nil {
return c.JSON(500, err.Error())
}
Expand Down Expand Up @@ -185,11 +181,11 @@ func (h *Handler) HandleToken(c echo.Context) error {
return c.JSON(http.StatusForbidden, err.Error())
}

accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user)
accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user, token.ExpiryAccessToken)
if err != nil {
return c.JSON(500, err.Error())
}
refreshtToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user)
refreshtToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user, token.ExpiryRefreshToken)
if err != nil {
return c.JSON(500, err.Error())
}
Expand Down Expand Up @@ -272,7 +268,7 @@ func (h *Handler) HandleRefresh(c echo.Context) error {
return c.JSON(http.StatusUnauthorized, err.Error())
}

accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user)
accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user, token.ExpiryAccessToken)
if err != nil {
return c.JSON(500, err.Error())
}
Expand Down
10 changes: 5 additions & 5 deletions auth/saml/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (h *Handler) AuthorizationHandler() echo.HandlerFunc {
Email: attr.Get("email"),
Name: attr.Get("first_name") + " " + attr.Get("last_name"),
}
accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser, token.ScopeCodeRead}, user)
accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser, token.ScopeCodeRead}, user, token.ExpiryAccessToken)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
if _, err := w.Write([]byte(err.Error())); err != nil {
Expand All @@ -108,7 +108,7 @@ func (h *Handler) AuthorizationHandler() echo.HandlerFunc {
HttpOnly: true,
})

refreshToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeRefresh}, user)
refreshToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeRefresh}, user, token.ExpiryRefreshToken)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
if _, err := w.Write([]byte(err.Error())); err != nil {
Expand Down Expand Up @@ -187,12 +187,12 @@ func (h *Handler) HandleToken(c echo.Context) error {
return c.JSON(http.StatusForbidden, err.Error())
}

accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user)
accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user, token.ExpiryAccessToken)
if err != nil {
return c.JSON(http.StatusInternalServerError, err.Error())
}

refreshtoken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeRefresh}, user)
refreshtoken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeRefresh}, user, token.ExpiryRefreshToken)
if err != nil {
return c.JSON(http.StatusInternalServerError, err.Error())
}
Expand Down Expand Up @@ -227,7 +227,7 @@ func (h *Handler) HandleRefresh(c echo.Context) error {
return c.JSON(http.StatusUnauthorized, err.Error())
}

accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user)
accessToken, err := h.tokenService.GenerateToken(h.runner.ID, []string{token.ScopeUser}, user, token.ExpiryAccessToken)
if err != nil {
return c.JSON(500, err.Error())
}
Expand Down
8 changes: 6 additions & 2 deletions auth/token/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ type Handler struct {
func NewHandler(runner *model.Runner, service *Service) *Handler {
return &Handler{
service: service,
runner: runner,
}
}

func (h *Handler) HandleRefresh(c echo.Context) error {
referrer := c.Request().Referer()

referrer := c.QueryParam("redirect")

cookie, err := c.Cookie("refresh")
if err != nil {
return c.JSON(http.StatusUnauthorized, err.Error())
Expand All @@ -33,10 +36,11 @@ func (h *Handler) HandleRefresh(c echo.Context) error {
return c.JSON(http.StatusUnauthorized, err.Error())
}

accessToken, err := h.service.GenerateToken(h.runner.ID, []string{ScopeUser, ScopeCodeRead}, user)
accessToken, err := h.service.GenerateToken(h.runner.ID, []string{ScopeUser, ScopeCodeRead}, user, ExpiryAccessToken)
if err != nil {
return c.JSON(500, err.Error())
}

c.SetCookie(&http.Cookie{
Name: "session",
Value: accessToken,
Expand Down
9 changes: 6 additions & 3 deletions auth/token/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ import (
func SessionAuthMiddleware(runnerID string, service *Service) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
referer := c.Request().URL.String()
c.Response().Header().Set("referer", referer)

cookie, err := c.Cookie("session")
if err != nil {
return c.Redirect(http.StatusTemporaryRedirect, "/refresh")
return c.Redirect(http.StatusTemporaryRedirect, "/refresh?redirect="+referer)
}
if cookie.Value == "" {
return c.Redirect(http.StatusTemporaryRedirect, "/refresh")
return c.Redirect(http.StatusTemporaryRedirect, "/refresh?redirect="+referer)
}
_, err = service.ReadToken(runnerID, ScopeCodeRead, cookie.Value)
if err != nil {
return c.Redirect(http.StatusTemporaryRedirect, "/refresh")
return c.Redirect(http.StatusTemporaryRedirect, "/refresh?redirect="+referer)
}
return next(c)
}
Expand Down
4 changes: 2 additions & 2 deletions auth/token/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestSessionAuthMiddleware(t *testing.T) {

t.Run("token expired", func(t *testing.T) {
ExpiryAccessToken = -1 * time.Minute
token, err := service.GenerateToken("runner-id", []string{ScopeUser}, user)
token, err := service.GenerateToken("runner-id", []string{ScopeUser}, user, ExpiryAccessToken)
require.NoError(t, err)
req := httptest.NewRequest("GET", "/", nil)
req.AddCookie(&http.Cookie{Name: "session", Value: token})
Expand All @@ -93,7 +93,7 @@ func TestSessionAuthMiddleware(t *testing.T) {

t.Run("valid token", func(t *testing.T) {
ExpiryAccessToken = 10 * time.Minute
token, err := service.GenerateToken("runner-id", []string{ScopeCodeRead}, user)
token, err := service.GenerateToken("runner-id", []string{ScopeCodeRead}, user, ExpiryAccessToken)
require.NoError(t, err)
req := httptest.NewRequest("GET", "/", nil)
req.AddCookie(&http.Cookie{Name: "session", Value: token})
Expand Down
60 changes: 4 additions & 56 deletions auth/token/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ const (
)

var (
ExpiryAccessToken = 15 * time.Minute
ExpiryAccessToken = 15 * time.Minute
ExpiryRefreshToken = 15 * 24 * time.Hour
)

type Service struct {
Expand All @@ -32,8 +33,8 @@ func NewService(signer *jwtutil.Signer, verifier *jwtutil.Verifier) *Service {
}
}

func (s *Service) GenerateToken(issuer string, scopes []string, user *model.User) (string, error) {
return s.signer.GenerateToken(issuer, scopes, user.Claims(), ExpiryAccessToken)
func (s *Service) GenerateToken(issuer string, scopes []string, user *model.User, expiry time.Duration) (string, error) {
return s.signer.GenerateToken(issuer, scopes, user.Claims(), expiry)
}

func (s *Service) ReadToken(issuer string, scope string, token string) (*model.User, error) {
Expand Down Expand Up @@ -69,56 +70,3 @@ func (s *Service) ReadToken(issuer string, scope string, token string) (*model.U
Provider: claims["provider"].(string),
}, nil
}

// func (s *Service) ReadAccessToken(issuer string, token string) (*model.User, error) {
// claims, err := s.verifier.Verify(token)
// if err != nil {
// return nil, err
// }
// for _, v := range []string{"id", "name", "email", "login", "provider"} {
// if _, ok := claims[v]; !ok {
// return nil, errors.New("invalid claims")
// }
// }

// if claims["iss"] != issuer {
// return nil, errors.New("invalid issuer")
// }

// if claims["scp"] != ScopeCodeRead {
// return nil, errors.New("invalid scope")
// }

// return &model.User{
// ID: claims["id"].(string),
// Name: claims["name"].(string),
// Email: claims["email"].(string),
// Login: claims["login"].(string),
// Provider: claims["provider"].(string),
// }, nil
// }

// func (s *Service) ReadRefreshToken(issuer string, token string) (*model.User, error) {
// claims, err := s.verifier.Verify(token)
// if err != nil {
// return nil, err
// }
// for _, v := range []string{"id", "name", "email", "login", "provider"} {
// if _, ok := claims[v]; !ok {
// return nil, errors.New("invalid claims")
// }
// }
// if claims["iss"] != issuer {
// return nil, errors.New("invalid issuer")
// }
// if claims["scp"] != ScopeRefresh {
// return nil, errors.New("invalid scope")
// }
// return &model.User{
// ID: claims["id"].(string),
// Name: claims["name"].(string),
// Email: claims["email"].(string),
// Login: claims["login"].(string),
// Provider: claims["provider"].(string),
// }, nil
// }
Loading

0 comments on commit 7e12c36

Please sign in to comment.