Skip to content
This repository has been archived by the owner on Mar 18, 2022. It is now read-only.

Commit

Permalink
Merge 5eadb21 into f5b3077
Browse files Browse the repository at this point in the history
  • Loading branch information
trusch committed Sep 6, 2018
2 parents f5b3077 + 5eadb21 commit 22e958f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 20 deletions.
16 changes: 9 additions & 7 deletions jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@ func ParsePrivateKey(data []byte) (interface{}, error) {

// GetTokenFromRequest takes the first Authorization header or `token` GET pararm , then
// extract the token prefix and json web token
func GetTokenFromRequest(r *http.Request) (prefix string, token string, err error) {

tokenList, ok := r.Header[AuthorizationHeader]
func GetTokenFromRequest(r *http.Request, header string) (prefix string, token string, err error) {
if header == "" {
header = AuthorizationHeader
}
tokenList, ok := r.Header[http.CanonicalHeaderKey(header)]
// pull from GET if not in the headers
if !ok || len(tokenList) < 1 {
tokenList, ok = r.URL.Query()["token"]
Expand All @@ -182,8 +184,8 @@ func GetTokenFromRequest(r *http.Request) (prefix string, token string, err erro
}

// GetClaimsFromRequestWithValidation extracts and validates the token from a request, returning the claims
func GetClaimsFromRequestWithValidation(r *http.Request, key interface{}) (prefix string, claims Claims, err error) {
prefix, token, err := GetTokenFromRequest(r)
func GetClaimsFromRequestWithValidation(r *http.Request, header string, key interface{}) (prefix string, claims Claims, err error) {
prefix, token, err := GetTokenFromRequest(r, header)
if err != nil {
return prefix, nil, err
}
Expand All @@ -196,8 +198,8 @@ func GetClaimsFromRequestWithValidation(r *http.Request, key interface{}) (prefi
// claims without validating the token. This should only be used in situations
// where you can already trust or if you are simply logging the claim
// information.
func GetClaimsFromRequest(r *http.Request) (prefix string, claims Claims, err error) {
prefix, token, err := GetTokenFromRequest(r)
func GetClaimsFromRequest(r *http.Request, header string) (prefix string, claims Claims, err error) {
prefix, token, err := GetTokenFromRequest(r, header)
if err != nil {
return prefix, nil, err
}
Expand Down
53 changes: 42 additions & 11 deletions jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ var _ = Describe("JWT", func() {
Expect(err).NotTo(HaveOccurred())
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
r.Header.Add("Authorization", "Bearer "+token)
prefix, reClaims, err := GetClaimsFromRequest(r)
prefix, reClaims, err := GetClaimsFromRequest(r, "")
Expect(err).NotTo(HaveOccurred())
Expect(reClaims).To(Equal(claims))
Expect(prefix).To(Equal("Bearer"))
})

It("should NOT be possible to get a claims from a request without token", func() {
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
prefix, reClaims, err := GetClaimsFromRequest(r)
prefix, reClaims, err := GetClaimsFromRequest(r, "")
Expect(err).To(HaveOccurred())
Expect(reClaims).To(BeEmpty())
Expect(prefix).To(BeEmpty())
Expand All @@ -153,7 +153,7 @@ var _ = Describe("JWT", func() {
Expect(err).NotTo(HaveOccurred())
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
r.Header.Add("Authorization", "Bearer "+token)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, pubKey)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, "", pubKey)
Expect(err).NotTo(HaveOccurred())
Expect(reClaims).To(Equal(claims))
Expect(prefix).To(Equal("Bearer"))
Expand All @@ -169,7 +169,7 @@ var _ = Describe("JWT", func() {
Expect(err).NotTo(HaveOccurred())
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
r.Header.Add("Authorization", "Token "+token)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, pubKey)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, "", pubKey)
Expect(err).NotTo(HaveOccurred())
Expect(reClaims).To(Equal(claims))
Expect(prefix).To(Equal("Token"))
Expand All @@ -188,7 +188,7 @@ var _ = Describe("JWT", func() {
q.Add("token", token)
r.URL.RawQuery = q.Encode()

prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, pubKey)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, "", pubKey)
Expect(err).NotTo(HaveOccurred())
Expect(reClaims).To(Equal(claims))
Expect(prefix).To(Equal("GET"))
Expand All @@ -204,7 +204,7 @@ var _ = Describe("JWT", func() {
Expect(err).NotTo(HaveOccurred())
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
r.Header.Add("Authorization", "bearder "+token+" garbage")
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, pubKey)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, "", pubKey)
Expect(err).To(HaveOccurred())
Expect(reClaims).To(BeEmpty())
Expect(prefix).To(BeEmpty())
Expand All @@ -214,7 +214,7 @@ var _ = Describe("JWT", func() {
pubKey, err := ParsePublicKey(rsaPubKey)
Expect(err).NotTo(HaveOccurred())
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, pubKey)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, "", pubKey)
Expect(err).To(HaveOccurred())
Expect(reClaims).To(BeEmpty())
Expect(prefix).To(BeEmpty())
Expand Down Expand Up @@ -285,15 +285,15 @@ var _ = Describe("JWT", func() {
Expect(err).NotTo(HaveOccurred())
Expect(token).NotTo(BeEmpty())

handlerA := ClaimsToContextMiddleware(RequireClaim(handler, "foo", "bar"), pubKey)
handlerA := ClaimsToContextMiddleware(RequireClaim(handler, "foo", "bar"), "", pubKey)
r, _ := http.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Authorization", "bearer "+token)
w := httptest.NewRecorder()
handlerA.ServeHTTP(w, r)
fmt.Println(w.Body)
Expect(w.Code).To(Equal(http.StatusOK))

handlerB := ClaimsToContextMiddleware(RequireClaim(handler, "foo", "barbara"), pubKey)
handlerB := ClaimsToContextMiddleware(RequireClaim(handler, "foo", "barbara"), "", pubKey)
r, _ = http.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Authorization", "bearer "+token)
w = httptest.NewRecorder()
Expand All @@ -313,7 +313,7 @@ var _ = Describe("JWT", func() {
r.Header.Add("Authorization", "Bearer "+token)

handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
handler = ClaimsToContextMiddleware(handler, pubKey)
handler = ClaimsToContextMiddleware(handler, "", pubKey)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
Expect(w.Code).To(Equal(http.StatusUnauthorized))
Expand Down Expand Up @@ -342,12 +342,43 @@ var _ = Describe("JWT", func() {

handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
handler = RequireClaim(handler, "foo", "bar")
handler = ClaimsToContextMiddleware(handler, pubKey)
handler = ClaimsToContextMiddleware(handler, "", pubKey)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
Expect(w.Code).To(Equal(http.StatusUnauthorized))
})

It("should be possible to get claims from a user specified header without prefix", func() {
claims := Claims{"foo": "bar"}
pubKey, err := ParsePublicKey(rsaPubKey)
Expect(err).NotTo(HaveOccurred())
privKey, err := ParsePrivateKey(rsaPrivKey)
Expect(err).NotTo(HaveOccurred())
token, err := CreateToken(claims, privKey)
Expect(err).NotTo(HaveOccurred())
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
r.Header.Add("X-Custom-Token", token)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, "X-Custom-Token", pubKey)
Expect(err).NotTo(HaveOccurred())
Expect(reClaims).To(Equal(claims))
Expect(prefix).To(Equal(""))
})
It("should be possible to get claims from a user specified header with prefix", func() {
claims := Claims{"foo": "bar"}
pubKey, err := ParsePublicKey(rsaPubKey)
Expect(err).NotTo(HaveOccurred())
privKey, err := ParsePrivateKey(rsaPrivKey)
Expect(err).NotTo(HaveOccurred())
token, err := CreateToken(claims, privKey)
Expect(err).NotTo(HaveOccurred())
r, _ := http.NewRequest("GET", "http://foobar.com", nil)
r.Header.Add("X-Custom-Token", "custom-prefix "+token)
prefix, reClaims, err := GetClaimsFromRequestWithValidation(r, "X-Custom-Token", pubKey)
Expect(err).NotTo(HaveOccurred())
Expect(reClaims).To(Equal(claims))
Expect(prefix).To(Equal("custom-prefix"))
})

})

var (
Expand Down
4 changes: 2 additions & 2 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ var (
)

// ClaimsToContextMiddleware is a http middleware which parses and validates a jwt from the authorization header and stores the claims in the requests context before calling the next handler.
func ClaimsToContextMiddleware(handler http.Handler, idpKey interface{}) http.Handler {
func ClaimsToContextMiddleware(handler http.Handler, header string, idpKey interface{}) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := GetClaimsFromRequestWithValidation(r, idpKey)
_, claims, err := GetClaimsFromRequestWithValidation(r, header, idpKey)
if err != nil {
http.Error(w, "not authorized: failed to validate token: "+err.Error(), http.StatusUnauthorized)
return
Expand Down

0 comments on commit 22e958f

Please sign in to comment.