Skip to content

Commit

Permalink
Fix after PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
letzya committed May 18, 2018
2 parents 7be5ddb + 84ba095 commit 401a697
Show file tree
Hide file tree
Showing 11 changed files with 308 additions and 1,145 deletions.
2 changes: 1 addition & 1 deletion apidef/api_definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ type APIDefinition struct {
EnableJWT bool `bson:"enable_jwt" json:"enable_jwt"`
UseStandardAuth bool `bson:"use_standard_auth" json:"use_standard_auth"`
EnableCoProcessAuth bool `bson:"enable_coprocess_auth" json:"enable_coprocess_auth"`
JWTSkipCheckKidAsId bool `bson:"jwt_skip_check_kid_as_id" json:"jwt_skip_check_kid_as_id"`
JWTSkipCheckKidAsId bool `bson:"jwt_skip_check_kid_as_id" json:"jwt_skip_check_kid_as_id"`
JWTSigningMethod string `bson:"jwt_signing_method" json:"jwt_signing_method"`
JWTSource string `bson:"jwt_source" json:"jwt_source"`
JWTIdentityBaseField string `bson:"jwt_identit_base_field" json:"jwt_identity_base_field"`
Expand Down
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ type Config struct {
VersionHeader string `json:"version_header"`
EnableHashedKeysListing bool `json:"enable_hashed_keys_listing"`
MinTokenLength int `json:"min_token_length"`
JWTSkipCheckKidAsId bool `json:"jwt_use_id_from_kid"`
JWTSkipCheckKidAsId bool `json:"jwt_skip_check_kid_as_id"`
}

type CertData struct {
Expand Down
3 changes: 3 additions & 0 deletions lint/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,9 @@ const confSchema = `{
},
"min_token_length": {
"type": "integer"
},
"jwt_skip_check_kid_as_id": {
"type": "boolean"
}
}
}`
134 changes: 61 additions & 73 deletions mw_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ type JWTMiddleware struct {
}

const (
KID string = "kid"
SUB string = "sub"
HMACSign = "hmac"
RSASign = "rsa"
ECDSASing = "ecdsa"
KID = "kid"
SUB = "sub"
HMACSign = "hmac"
RSASign = "rsa"
ECDSASign = "ecdsa"
)

func (k *JWTMiddleware) Name() string {
Expand Down Expand Up @@ -108,10 +108,8 @@ func (k *JWTMiddleware) getSecretFromURL(url, kid, keyType string) ([]byte, erro
return nil, errors.New("No matching KID could be found")
}

// Try using a kid or configured claim or sub claim
func (k *JWTMiddleware) getIdentityFomToken(token *jwt.Token) (string, bool, error) {
idFound := false
var tykId string
func (k *JWTMiddleware) getIdentityFromToken(token *jwt.Token) (string, error) {
// Try using a kid or configured claim or sub claim

// Global config overrides and sets the api def
if k.Spec.GlobalConfig.JWTSkipCheckKidAsId {
Expand All @@ -121,19 +119,14 @@ func (k *JWTMiddleware) getIdentityFomToken(token *jwt.Token) (string, bool, err
// Check which claim is used for the id - kid or sub header
// If is not supposed to ignore KID - will use this as ID if not empty
if !k.Spec.APIDefinition.JWTSkipCheckKidAsId {
tykId, idFound = token.Header[KID].(string)
if idFound {
if tykId, idFound := token.Header[KID].(string); idFound {
log.Debug("Found: ", tykId)
return tykId, idFound, nil
return tykId, nil
}
}
// In case KID was empty or was set to ignore KID ==> Will try to get the Id from JWTIdentityBaseField or fallback to SUM
tykId, err := k.getUserIdFromClaim(token)
if err != nil {
return "", false, err
}

return tykId, true, nil
// In case KID was empty or was set to ignore KID ==> Will try to get the Id from JWTIdentityBaseField or fallback to 'sub'
tykId, err := k.getUserIdFromClaim(token.Claims.(jwt.MapClaims))
return tykId, err
}

func (k *JWTMiddleware) getSecretToVerifySignature(token *jwt.Token) ([]byte, error) {
Expand Down Expand Up @@ -172,8 +165,8 @@ func (k *JWTMiddleware) getSecretToVerifySignature(token *jwt.Token) ([]byte, er
// If we are here, there's no central JWT source

// Get the ID from the token (in KID header or configured claim or SUB claim)
tykId, found, err := k.getIdentityFomToken(token)
if !found {
tykId, err := k.getIdentityFromToken(token)
if err != nil {
return nil, err
}

Expand All @@ -187,8 +180,8 @@ func (k *JWTMiddleware) getSecretToVerifySignature(token *jwt.Token) ([]byte, er
return []byte(session.JWTData.Secret), nil
}

func (k *JWTMiddleware) getPolicyIDFromToken(token *jwt.Token) (string, bool) {
policyID, foundPolicy := token.Claims.(jwt.MapClaims)[k.Spec.JWTPolicyFieldName].(string)
func (k *JWTMiddleware) getPolicyIDFromToken(claims jwt.MapClaims) (string, bool) {
policyID, foundPolicy := claims[k.Spec.JWTPolicyFieldName].(string)
if !foundPolicy {
log.Error("Could not identify a policy to apply to this token from field!")
return "", false
Expand All @@ -197,11 +190,11 @@ func (k *JWTMiddleware) getPolicyIDFromToken(token *jwt.Token) (string, bool) {
return policyID, true
}

func (k *JWTMiddleware) getBasePolicyID(token *jwt.Token) (string, bool) {
func (k *JWTMiddleware) getBasePolicyID(claims jwt.MapClaims) (string, bool) {
if k.Spec.JWTPolicyFieldName != "" {
return k.getPolicyIDFromToken(token)
return k.getPolicyIDFromToken(claims)
} else if k.Spec.JWTClientIDBaseField != "" {
clientID, clientIDFound := token.Claims.(jwt.MapClaims)[k.Spec.JWTClientIDBaseField].(string)
clientID, clientIDFound := claims[k.Spec.JWTClientIDBaseField].(string)
if !clientIDFound {
log.Error("Could not identify a policy to apply to this token from field!")
return "", false
Expand All @@ -225,56 +218,53 @@ func (k *JWTMiddleware) getBasePolicyID(token *jwt.Token) (string, bool) {
return "", false
}

func (k *JWTMiddleware) getUserIdFromClaim(token *jwt.Token) (string, error) {
func (k *JWTMiddleware) getUserIdFromClaim(claims jwt.MapClaims) (string, error) {
var userId string
var found bool
var found = false

if k.Spec.JWTIdentityBaseField != "" {
userId, found = token.Claims.(jwt.MapClaims)[k.Spec.JWTIdentityBaseField].(string)
if userId, found = claims[k.Spec.JWTIdentityBaseField].(string); found {
if len(userId) > 0 {
log.WithField("userId", userId).Debug("Found User Id in Base Field")
return userId, nil
}
message := "found an empty user ID in predefined base field claim " + k.Spec.JWTIdentityBaseField
log.Error(message)
return "", errors.New(message)
}

if !found {
log.Warning("Base Field not found, falling back to SUB")
log.WithField("Base Field", k.Spec.JWTIdentityBaseField).Warning("Base Field claim not found, trying to find user ID in 'sub' claim.")
}
}
if !found {
userId, found = token.Claims.(jwt.MapClaims)[SUB].(string)
if !found {
message := fmt.Sprintf("user id was not found in claims: %s", userId)
log.Error(message)
return "", errors.New(message)

if userId, found = claims[SUB].(string); found {
if len(userId) > 0 {
log.WithField("userId", userId).Debug("Found User Id in 'sub' claim")
return userId, nil
}
message := "found an empty user ID in sub claim"
log.Error(message)
return "", errors.New(message)
}
log.Debugf("Found User Id: ", userId)
return userId, nil

message := "no suitable claims for user ID were found"
log.Error(message)
return "", errors.New(message)
}

// processCentralisedJWT Will check a JWT token centrally against the secret stored in the API Definition.
func (k *JWTMiddleware) processCentralisedJWT(r *http.Request, token *jwt.Token) (error, int) {
log.Debug("JWT authority is centralised")

// Generate a virtual token
baseFieldData, baseFound := token.Claims.(jwt.MapClaims)[k.Spec.JWTIdentityBaseField].(string)
if !baseFound {
log.Warning("Base Field not found, using SUB")
var found bool
baseFieldData, found = token.Claims.(jwt.MapClaims)["sub"].(string)
if !found {
log.Error("ID Could not be generated. Failing Request.")
k.reportLoginFailure("[NOT FOUND]", r)
return errors.New("Key not authorized"), http.StatusForbidden
}


var baseFieldData string

baseFieldData, err := k.getUserIdFromClaim(token)
claims := token.Claims.(jwt.MapClaims)
baseFieldData, err := k.getUserIdFromClaim(claims)
if err != nil {
k.reportLoginFailure("[NOT FOUND]", r)
return err, 403
return err, http.StatusForbidden
}

// Generate a virtual token
log.Debug("Base Field ID set to: ", baseFieldData)
data := []byte(baseFieldData)
tokenID := fmt.Sprintf("%x", md5.Sum(data))
sessionID := k.Spec.OrgID + tokenID
Expand All @@ -288,7 +278,7 @@ func (k *JWTMiddleware) processCentralisedJWT(r *http.Request, token *jwt.Token)
session = user.SessionState{}

// We need a base policy as a template, either get it from the token itself OR a proxy client ID within Tyk
basePolicyID, foundPolicy := k.getBasePolicyID(token)
basePolicyID, foundPolicy := k.getBasePolicyID(claims)
if !foundPolicy {
k.reportLoginFailure(baseFieldData, r)
return errors.New("Key not authorized: no matching policy found"), http.StatusForbidden
Expand Down Expand Up @@ -320,7 +310,7 @@ func (k *JWTMiddleware) processCentralisedJWT(r *http.Request, token *jwt.Token)
return nil, http.StatusOK
} else if k.Spec.JWTPolicyFieldName != "" {
// extract policy ID from JWT token
policyID, foundPolicy := k.getPolicyIDFromToken(token)
policyID, foundPolicy := k.getPolicyIDFromToken(claims)
if !foundPolicy {
k.reportLoginFailure(baseFieldData, r)
return errors.New("Key not authorized: no matching policy found"), http.StatusForbidden
Expand Down Expand Up @@ -385,14 +375,14 @@ func (k *JWTMiddleware) reportLoginFailure(tykId string, r *http.Request) {

func (k *JWTMiddleware) processOneToOneTokenMap(r *http.Request, token *jwt.Token) (error, int) {
// Get the ID from the token
tykId, found, err := k.getIdentityFomToken(token)

if !found {
tykId, err := k.getIdentityFromToken(token)
if err != nil {
k.reportLoginFailure(tykId, r)
return err, http.StatusNotFound
}

log.Debug("Using raw key ID: ", tykId)

session, exists := k.CheckSessionAndIdentityForValidKey(tykId)
if !exists {
k.reportLoginFailure(tykId, r)
Expand Down Expand Up @@ -451,18 +441,18 @@ func (k *JWTMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _
switch k.Spec.JWTSigningMethod {
case HMACSign:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v and not HMACSign", token.Header["alg"])
return nil, fmt.Errorf("Unexpected signing method: %v and not HMAC signature", token.Header["alg"])
}
case RSASign:
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v and not RSASign", token.Header["alg"])
return nil, fmt.Errorf("Unexpected signing method: %v and not RSA signature", token.Header["alg"])
}
case ECDSASing:
case ECDSASign:
if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v and not ECDSASing", token.Header["alg"])
return nil, fmt.Errorf("Unexpected signing method: %v and not ECDSA signature", token.Header["alg"])
}
default:
log.Warning("No signing method found in API Definition, defaulting to HMAC")
log.Warning("No signing method found in API Definition, defaulting to HMAC signature")
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
Expand All @@ -474,7 +464,7 @@ func (k *JWTMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _
return nil, err
}

if k.Spec.JWTSigningMethod == "rsa" {
if k.Spec.JWTSigningMethod == RSASign {
asRSA, err := jwt.ParseRSAPublicKeyFromPEM(val)
if err != nil {
log.Error("Failed to decode JWT to RSA type")
Expand All @@ -487,7 +477,7 @@ func (k *JWTMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _
})

if err == nil && token.Valid {
if jwtErr := k.validateJWTClaims(token.Claims.(jwt.MapClaims)); jwtErr != nil {
if jwtErr := k.timeValidateJWTClaims(token.Claims.(jwt.MapClaims)); jwtErr != nil {
return errors.New("Key not authorized: " + jwtErr.Error()), http.StatusUnauthorized
}

Expand All @@ -503,17 +493,15 @@ func (k *JWTMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _
}
logEntry := getLogEntryForRequest(r, "", nil)
logEntry.Info("Attempted JWT access with non-existent key.")

k.reportLoginFailure(tykId, r)

if err != nil {
logEntry.Error("JWT validation error: ", err)
return errors.New("Key not authorized:" + err.Error()), http.StatusForbidden
}
return errors.New("Key not authorized"), 403
}
return errors.New("Key not authorized"), http.StatusForbidden
}

func (k *JWTMiddleware) validateJWTClaims(c jwt.MapClaims) *jwt.ValidationError {
func (k *JWTMiddleware) timeValidateJWTClaims(c jwt.MapClaims) *jwt.ValidationError {
vErr := new(jwt.ValidationError)
now := time.Now().Unix()

Expand Down
Loading

0 comments on commit 401a697

Please sign in to comment.