Skip to content

Commit

Permalink
VIP: make it so that only test code depends on "github.com/dgrijalva/…
Browse files Browse the repository at this point in the history
…jwt-go"
  • Loading branch information
aldas committed Dec 17, 2020
1 parent 4422e3b commit 2f2e986
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 49 deletions.
82 changes: 43 additions & 39 deletions middleware/jwt.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package middleware

import (
"errors"
"fmt"
"github.com/labstack/echo/v4"
"net/http"
"reflect"
"strings"

"github.com/dgrijalva/jwt-go"
"github.com/labstack/echo/v4"
)

type (
Expand Down Expand Up @@ -41,14 +39,17 @@ type (
// Optional. Default value HS256.
SigningMethod string

// KeyFunc is custom method to return signing key during token validation by TokenParser
// Optional. If not set middleware will use default implementation that checks if token algorithm matches singing
// method and returns matching singing key for kid from SigningKeys
// Both `alg` (https://tools.ietf.org/html/rfc7515#section-4.1.1)
// and `kid` (https://tools.ietf.org/html/rfc7515#section-4.1.4) are JWT header parameter values
KeyFunc func(alg string, kid interface{}) (interface{}, error)

// Context key to store user information from the token into context.
// Optional. Default value "user".
ContextKey string

// Claims are extendable claims data defining token content.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims

// TokenLookup is a string in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:Authorization".
Expand All @@ -64,7 +65,9 @@ type (
// Optional. Default value "Bearer".
AuthScheme string

keyFunc jwt.Keyfunc
// TokenParser is wrapper interface for different JWT token parsing implementations.
// Required.
TokenParser JwtTokenParser
}

// JWTSuccessHandler defines a function which is executed for a valid token.
Expand Down Expand Up @@ -98,7 +101,6 @@ var (
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
}
)

Expand All @@ -119,43 +121,46 @@ func JWT(key interface{}) echo.MiddlewareFunc {
// JWTWithConfig returns a JWT auth middleware with config.
// See: `JWT()`.
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.SigningKey == nil && len(config.SigningKeys) == 0 {
panic("echo: jwt middleware requires signing key")
}
if config.TokenParser == nil {
panic("echo: jwt middleware requires token parser instance")
}
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
if config.SigningKey == nil && len(config.SigningKeys) == 0 {
panic("echo: jwt middleware requires signing key")
}
if config.SigningMethod == "" {
config.SigningMethod = DefaultJWTConfig.SigningMethod
}
if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey
}
if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims
}
if config.TokenLookup == "" {
config.TokenLookup = DefaultJWTConfig.TokenLookup
}
if config.AuthScheme == "" {
config.AuthScheme = DefaultJWTConfig.AuthScheme
}
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
if config.KeyFunc == nil {
config.KeyFunc = func(alg string, kid interface{}) (interface{}, error) {
// signature algorithm used for token must match our signing method
if alg != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%s", alg)
}
if len(config.SigningKeys) == 0 {
return config.SigningKey, nil
}
kidStr, ok := kid.(string)
if !ok {
return nil, errors.New("failed to cast jwt key id as string")
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
if key, ok := config.SigningKeys[kidStr]; ok {
return key, nil
}
return nil, fmt.Errorf("unexpected jwt key id=%v", kid)
}

return config.SigningKey, nil
}

// Initialize
Expand Down Expand Up @@ -193,16 +198,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
return err
}
token := new(jwt.Token)
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.keyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
}
if err == nil && token.Valid {
token, err := config.TokenParser.Parse(auth, config)
if err == nil {
// Store user information from token into context.
c.Set(config.ContextKey, token)
if config.SuccessHandler != nil {
Expand All @@ -225,6 +222,13 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
}

// JwtTokenParser is wrapper interface for different JWT token parsing implementations.
type JwtTokenParser interface {
// Parse parses token string to token instance that is set to echo.Context under JWTConfig.ContextKey
// Must return error when parsing failed, token is not valid or otherwise incorrect
Parse(tokenString string, config JWTConfig) (interface{}, error)
}

// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
return func(c echo.Context) (string, error) {
Expand Down
68 changes: 58 additions & 10 deletions middleware/jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"errors"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -12,6 +13,36 @@ import (
"github.com/stretchr/testify/assert"
)

type DgrijalvaJwtGoParser struct {
// DefaultClaimsFunc returns new claims instance for parsed token. This instance is used as destination when
// marshalling json to claims. Use our own custom claims implementation here.
// Defaults to jwt.MapClaims when not set
// NB: ALWAYS return new instance!!! or requests (goroutines) we would see panics runtime
DefaultClaimsFunc func() jwt.Claims
}

func (p *DgrijalvaJwtGoParser) Parse(tokenString string, config JWTConfig) (interface{}, error) {
var claims jwt.Claims = jwt.MapClaims{}
if p.DefaultClaimsFunc != nil {
claims = p.DefaultClaimsFunc()
}
token, err := new(jwt.Parser).ParseWithClaims(tokenString, claims, p.keyFunc(config))
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("token is not valid")
}
return token, nil
}

func (p *DgrijalvaJwtGoParser) keyFunc(config JWTConfig) jwt.Keyfunc {
return func(t *jwt.Token) (interface{}, error) {
kid, _ := t.Header["kid"]
return config.KeyFunc(t.Method.Alg(), kid)
}
}

// jwtCustomInfo defines some custom types we're going to use within our tokens.
type jwtCustomInfo struct {
Name string `json:"name"`
Expand All @@ -34,7 +65,11 @@ func TestJWTRace(t *testing.T) {
validKey := []byte("secret")

h := JWTWithConfig(JWTConfig{
Claims: &jwtCustomClaims{},
TokenParser: &DgrijalvaJwtGoParser{
DefaultClaimsFunc: func() jwt.Claims {
return &jwtCustomClaims{}
},
},
SigningKey: validKey,
})(handler)

Expand Down Expand Up @@ -70,15 +105,18 @@ func TestJWT(t *testing.T) {
invalidKey := []byte("invalid-key")
validAuth := DefaultJWTConfig.AuthScheme + " " + token

dgrijalvaJwtGoParser := DgrijalvaJwtGoParser{}

for _, tc := range []struct {
expPanic bool
expErrCode int // 0 for Success
config JWTConfig
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
expPanic bool
expErrCode int // 0 for Success
config JWTConfig
tokenParser JwtTokenParser
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
}{
{
expPanic: true,
Expand Down Expand Up @@ -111,7 +149,11 @@ func TestJWT(t *testing.T) {
{
hdrAuth: validAuth,
config: JWTConfig{
Claims: &jwtCustomClaims{},
TokenParser: &DgrijalvaJwtGoParser{
DefaultClaimsFunc: func() jwt.Claims {
return &jwtCustomClaims{}
},
},
SigningKey: []byte("secret"),
},
info: "Valid JWT with custom claims",
Expand Down Expand Up @@ -247,6 +289,10 @@ func TestJWT(t *testing.T) {
c.SetParamValues(token)
}

if tc.config.TokenParser == nil {
tc.config.TokenParser = &dgrijalvaJwtGoParser
}

if tc.expPanic {
assert.Panics(t, func() {
JWTWithConfig(tc.config)
Expand Down Expand Up @@ -344,6 +390,8 @@ func TestJWTwithKID(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
c := e.NewContext(req, res)

tc.config.TokenParser = &DgrijalvaJwtGoParser{}

if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
Expand Down

0 comments on commit 2f2e986

Please sign in to comment.