Skip to content
This repository was archived by the owner on Dec 23, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions middleware_accesstoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import (
)

type accessTokens struct {
headerName string
tokens []string
paramName string
tokens []string
getFunc func(string, *http.Request) string
missingMessage string
}

/*
NewMiddlewareAccessToken creates a new handler to verify access tokens in a rye chain.
NewMiddlewareAccessToken creates a new handler to verify access tokens passed as a header.

Example usage:

Expand All @@ -23,19 +25,60 @@ Example usage:
})).Methods("POST")
*/
func NewMiddlewareAccessToken(headerName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response {
return newAccessTokenHandler(headerName, tokens, "header")
}

/*
NewMiddlewareAccessQueryToken creates a new handler to verify access tokens passed as a query parameter.

Example usage:

routes.Handle("/some/route", a.Dependencies.MWHandler.Handle(
[]rye.Handler{
rye.NewMiddlewareAccessQueryToken(queryParamName, []string{token1, token2}),
yourHandler,
})).Methods("POST")
*/
func NewMiddlewareAccessQueryToken(queryParamName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response {
return newAccessTokenHandler(queryParamName, tokens, "query")
}

func newAccessTokenHandler(name string, tokens []string, tokenType string) func(rw http.ResponseWriter, req *http.Request) *Response {
a := &accessTokens{
headerName: headerName,
tokens: tokens,
paramName: name,
tokens: tokens,
}

switch tokenType {

case "query":
a.getFunc = func(s string, r *http.Request) string {
q, ok := r.URL.Query()[s]
if !ok {
return ""
}

return q[0]
}
a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name)

default:
// default to using the header
a.getFunc = func(s string, r *http.Request) string {
return r.Header.Get(s)
}
a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name)
}

return a.handle
}

func (a *accessTokens) handle(rw http.ResponseWriter, r *http.Request) *Response {
token := r.Header.Get(a.headerName)
token := a.getFunc(a.paramName, r)

if token == "" {
return &Response{
Err: fmt.Errorf("No access token found; ensure you pass '%s' in header", a.headerName),
Err: errors.New(a.missingMessage),
StatusCode: http.StatusUnauthorized,
}
}
Expand Down
138 changes: 127 additions & 11 deletions middleware_accesstoken_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package rye

import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
Expand All @@ -14,39 +16,48 @@ var _ = Describe("AccessToken Middleware", func() {
request *http.Request
response *httptest.ResponseRecorder

tokenHeaderName = "at-hname"
token1, token2 string
testHandler func(http.ResponseWriter, *http.Request) *Response

token1, token2 string
)

BeforeEach(func() {
response = httptest.NewRecorder()
request = &http.Request{
Header: map[string][]string{},
}

token1 = "test1"
token2 = "test2"
})

Describe("handle", func() {
Context("header token", func() {
var (
tokenHeaderName = "at-hname"
)

BeforeEach(func() {
testHandler = NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})
request = &http.Request{
Header: map[string][]string{},
}
})

Context("when a valid token is used", func() {
It("should return nil", func() {
request.Header.Add(tokenHeaderName, token1)
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})

It("should return nil", func() {
request.Header.Add(tokenHeaderName, token2)
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})
})

Context("when an invalid token is used", func() {
It("should return an error", func() {
request.Header.Add(tokenHeaderName, "blah")
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
Expand All @@ -56,7 +67,7 @@ var _ = Describe("AccessToken Middleware", func() {

Context("when no token header exists", func() {
It("should return an error", func() {
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expand All @@ -67,12 +78,117 @@ var _ = Describe("AccessToken Middleware", func() {
Context("when token header is blank", func() {
It("should return an error", func() {
request.Header.Add(tokenHeaderName, "")
resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request)
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})
})

Context("query param token", func() {
var (
qParamName string
qParams string
)

BeforeEach(func() {
qParamName = "token"
testHandler = NewMiddlewareAccessQueryToken(qParamName, []string{token1, token2})
})

JustBeforeEach(func() {
u, err := url.Parse(fmt.Sprintf("http://doesntmatter.io/blah?%s", qParams))
Expect(err).ToNot(HaveOccurred())

request = &http.Request{
URL: u,
}
})

Context("when a valid token is used", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=%s", qParamName, token1)
})

It("should return nil", func() {
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})
})

Context("when the other valid token is used", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=%s", qParamName, token2)
})

It("should return nil", func() {
resp := testHandler(response, request)
Expect(resp).To(BeNil())
})
})

Context("when an invalid token is used", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=blah", qParamName)
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

Context("when no token param exists", func() {
BeforeEach(func() {
qParams = "something=else"
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

Context("when token param is blank", func() {
BeforeEach(func() {
qParams = fmt.Sprintf("%s=''", qParamName)
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("invalid access token"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

Context("when no query params", func() {
JustBeforeEach(func() {
u, err := url.Parse("http://doesntmatter.io/blah")
Expect(err).ToNot(HaveOccurred())

request = &http.Request{
URL: u,
}
})

It("should return an error", func() {
resp := testHandler(response, request)
Expect(resp).ToNot(BeNil())
Expect(resp.Err).To(HaveOccurred())
Expect(resp.Error()).To(ContainSubstring("No access token found"))
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})

})
})