From 8fe76e87a0755ef3da23aa67b5b28cfc49f9b658 Mon Sep 17 00:00:00 2001 From: Tatsuro Alpert Date: Tue, 21 Mar 2017 15:51:20 -0700 Subject: [PATCH 1/3] query param access token --- middleware_accesstoken.go | 55 +++++++++++-- middleware_accesstoken_test.go | 138 ++++++++++++++++++++++++++++++--- 2 files changed, 175 insertions(+), 18 deletions(-) diff --git a/middleware_accesstoken.go b/middleware_accesstoken.go index 8e0e4c5..2aad91e 100644 --- a/middleware_accesstoken.go +++ b/middleware_accesstoken.go @@ -7,12 +7,14 @@ import ( ) type accessTokens struct { - headerName string - tokens []string + paramName string + tokens []string + getFunc func(string, *http.Request) string + missingMessage string } /* -NewMiddlewareAccessToken creates a new handler to verify access tokens in a rye chain. +NewMiddlewareAccessToken creates a new handler to verify access tokens passed as a header. Example usage: @@ -23,19 +25,58 @@ Example usage: })).Methods("POST") */ func NewMiddlewareAccessToken(headerName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { + return newAccessTokenHandler(headerName, tokens, true) +} + +/* +NewMiddlewareAccessQueryToken creates a new handler to verify access tokens passed as a query parameter. + +Example usage: + + routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( + []rye.Handler{ + rye.NewMiddlewareAccessQueryToken(queryParamName, []string{token1, token2}), + yourHandler, + })).Methods("POST") +*/ +func NewMiddlewareAccessQueryToken(queryParamName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { + return newAccessTokenHandler(queryParamName, tokens, false) +} + +func newAccessTokenHandler(name string, tokens []string, headerToken bool) func(rw http.ResponseWriter, req *http.Request) *Response { a := &accessTokens{ - headerName: headerName, - tokens: tokens, + paramName: name, + tokens: tokens, + } + + switch headerToken { + case true: + a.getFunc = func(s string, r *http.Request) string { + return r.Header.Get(s) + } + a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name) + + case false: + a.getFunc = func(s string, r *http.Request) string { + q, ok := r.URL.Query()[s] + if !ok { + return "" + } + + return q[0] + } + a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name) } + return a.handle } func (a *accessTokens) handle(rw http.ResponseWriter, r *http.Request) *Response { - token := r.Header.Get(a.headerName) + token := a.getFunc(a.paramName, r) if token == "" { return &Response{ - Err: fmt.Errorf("No access token found; ensure you pass '%s' in header", a.headerName), + Err: errors.New(a.missingMessage), StatusCode: http.StatusUnauthorized, } } diff --git a/middleware_accesstoken_test.go b/middleware_accesstoken_test.go index ea752b0..09a8022 100644 --- a/middleware_accesstoken_test.go +++ b/middleware_accesstoken_test.go @@ -1,8 +1,10 @@ package rye import ( + "fmt" "net/http" "net/http/httptest" + "net/url" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -14,31 +16,40 @@ var _ = Describe("AccessToken Middleware", func() { request *http.Request response *httptest.ResponseRecorder - tokenHeaderName = "at-hname" - token1, token2 string + testHandler func(http.ResponseWriter, *http.Request) *Response + + token1, token2 string ) BeforeEach(func() { response = httptest.NewRecorder() - request = &http.Request{ - Header: map[string][]string{}, - } token1 = "test1" token2 = "test2" }) - Describe("handle", func() { + Context("header token", func() { + var ( + tokenHeaderName = "at-hname" + ) + + BeforeEach(func() { + testHandler = NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2}) + request = &http.Request{ + Header: map[string][]string{}, + } + }) + Context("when a valid token is used", func() { It("should return nil", func() { request.Header.Add(tokenHeaderName, token1) - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).To(BeNil()) }) It("should return nil", func() { request.Header.Add(tokenHeaderName, token2) - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).To(BeNil()) }) }) @@ -46,7 +57,7 @@ var _ = Describe("AccessToken Middleware", func() { Context("when an invalid token is used", func() { It("should return an error", func() { request.Header.Add(tokenHeaderName, "blah") - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).ToNot(BeNil()) Expect(resp.Err).To(HaveOccurred()) Expect(resp.Error()).To(ContainSubstring("invalid access token")) @@ -56,7 +67,7 @@ var _ = Describe("AccessToken Middleware", func() { Context("when no token header exists", func() { It("should return an error", func() { - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).ToNot(BeNil()) Expect(resp.Err).To(HaveOccurred()) Expect(resp.Error()).To(ContainSubstring("No access token found")) @@ -67,7 +78,7 @@ var _ = Describe("AccessToken Middleware", func() { Context("when token header is blank", func() { It("should return an error", func() { request.Header.Add(tokenHeaderName, "") - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).ToNot(BeNil()) Expect(resp.Err).To(HaveOccurred()) Expect(resp.Error()).To(ContainSubstring("No access token found")) @@ -75,4 +86,109 @@ var _ = Describe("AccessToken Middleware", func() { }) }) }) + + Context("query param token", func() { + var ( + qParamName string + qParams string + ) + + BeforeEach(func() { + qParamName = "token" + testHandler = NewMiddlewareAccessQueryToken(qParamName, []string{token1, token2}) + }) + + JustBeforeEach(func() { + u, err := url.Parse(fmt.Sprintf("http://doesntmatter.io/blah?%s", qParams)) + Expect(err).ToNot(HaveOccurred()) + + request = &http.Request{ + URL: u, + } + }) + + Context("when a valid token is used", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=%s", qParamName, token1) + }) + + It("should return nil", func() { + resp := testHandler(response, request) + Expect(resp).To(BeNil()) + }) + }) + + Context("when the other valid token is used", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=%s", qParamName, token2) + }) + + It("should return nil", func() { + resp := testHandler(response, request) + Expect(resp).To(BeNil()) + }) + }) + + Context("when an invalid token is used", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=blah", qParamName) + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("invalid access token")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("when no token param exists", func() { + BeforeEach(func() { + qParams = "something=else" + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("No access token found")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("when token param is blank", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=''", qParamName) + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("invalid access token")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("when no query params", func() { + JustBeforeEach(func() { + u, err := url.Parse("http://doesntmatter.io/blah") + Expect(err).ToNot(HaveOccurred()) + + request = &http.Request{ + URL: u, + } + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("No access token found")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + }) }) From 167c6f733771372a6bb1df62d99ad7d3cb9d7d97 Mon Sep 17 00:00:00 2001 From: Tatsuro Alpert Date: Wed, 22 Mar 2017 07:23:39 -0700 Subject: [PATCH 2/3] make header default --- middleware_accesstoken.go | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/middleware_accesstoken.go b/middleware_accesstoken.go index 2aad91e..19f9ee7 100644 --- a/middleware_accesstoken.go +++ b/middleware_accesstoken.go @@ -25,7 +25,7 @@ Example usage: })).Methods("POST") */ func NewMiddlewareAccessToken(headerName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { - return newAccessTokenHandler(headerName, tokens, true) + return newAccessTokenHandler(headerName, tokens, "header") } /* @@ -40,23 +40,16 @@ Example usage: })).Methods("POST") */ func NewMiddlewareAccessQueryToken(queryParamName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { - return newAccessTokenHandler(queryParamName, tokens, false) + return newAccessTokenHandler(queryParamName, tokens, "query") } -func newAccessTokenHandler(name string, tokens []string, headerToken bool) func(rw http.ResponseWriter, req *http.Request) *Response { +func newAccessTokenHandler(name string, tokens []string, tokenType string) func(rw http.ResponseWriter, req *http.Request) *Response { a := &accessTokens{ paramName: name, tokens: tokens, } - switch headerToken { - case true: - a.getFunc = func(s string, r *http.Request) string { - return r.Header.Get(s) - } - a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name) - - case false: + if tokenType == "query" { a.getFunc = func(s string, r *http.Request) string { q, ok := r.URL.Query()[s] if !ok { @@ -66,7 +59,15 @@ func newAccessTokenHandler(name string, tokens []string, headerToken bool) func( return q[0] } a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name) + + return a.handle + } + + // default to using the header + a.getFunc = func(s string, r *http.Request) string { + return r.Header.Get(s) } + a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name) return a.handle } From 139d65e3a974ec5658462153f30b0948203c01b9 Mon Sep 17 00:00:00 2001 From: Tatsuro Alpert Date: Wed, 22 Mar 2017 10:33:58 -0700 Subject: [PATCH 3/3] use switch --- middleware_accesstoken.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/middleware_accesstoken.go b/middleware_accesstoken.go index 19f9ee7..00bb9a2 100644 --- a/middleware_accesstoken.go +++ b/middleware_accesstoken.go @@ -49,7 +49,9 @@ func newAccessTokenHandler(name string, tokens []string, tokenType string) func( tokens: tokens, } - if tokenType == "query" { + switch tokenType { + + case "query": a.getFunc = func(s string, r *http.Request) string { q, ok := r.URL.Query()[s] if !ok { @@ -60,14 +62,13 @@ func newAccessTokenHandler(name string, tokens []string, tokenType string) func( } a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name) - return a.handle - } - - // default to using the header - a.getFunc = func(s string, r *http.Request) string { - return r.Header.Get(s) + default: + // default to using the header + a.getFunc = func(s string, r *http.Request) string { + return r.Header.Get(s) + } + a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name) } - a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name) return a.handle }