Skip to content

Commit

Permalink
Replace *jwt.validateClaimsWithLeeway with custom validation func
Browse files Browse the repository at this point in the history
+ refactor validator pkg
  • Loading branch information
sergiught committed Oct 31, 2022
1 parent 6f70e49 commit 33c3d1d
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 42 deletions.
124 changes: 88 additions & 36 deletions validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,62 +99,114 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte
return nil, fmt.Errorf("could not parse the token: %w", err)
}

if string(v.signatureAlgorithm) != token.Headers[0].Algorithm {
return nil, fmt.Errorf(
"expected %q signing algorithm but token specified %q",
v.signatureAlgorithm,
token.Headers[0].Algorithm,
)
if err = validateSigningMethod(string(v.signatureAlgorithm), token.Headers[0].Algorithm); err != nil {
return nil, fmt.Errorf("signing method is invalid: %w", err)
}

key, err := v.keyFunc(ctx)
registeredClaims, customClaims, err := v.deserializeClaims(ctx, token)
if err != nil {
return nil, fmt.Errorf("error getting the keys from the key func: %w", err)
}

claimDest := []interface{}{&jwt.Claims{}}
if v.customClaims != nil && v.customClaims() != nil {
claimDest = append(claimDest, v.customClaims())
return nil, fmt.Errorf("failed to deserialize token claims: %w", err)
}

if err = token.Claims(key, claimDest...); err != nil {
return nil, fmt.Errorf("could not get token claims: %w", err)
if err = validateClaimsWithLeeway(registeredClaims, v.expectedClaims, v.allowedClockSkew); err != nil {
return nil, fmt.Errorf("expected claims not validated: %w", err)
}

registeredClaims := *claimDest[0].(*jwt.Claims)
expectedClaims := v.expectedClaims
expectedClaims.Time = time.Now()
if err = registeredClaims.ValidateWithLeeway(expectedClaims, v.allowedClockSkew); err != nil {
return nil, fmt.Errorf("expected claims not validated: %w", err)
if customClaims != nil {
if err = customClaims.Validate(ctx); err != nil {
return nil, fmt.Errorf("custom claims not validated: %w", err)
}
}

validatedClaims := &ValidatedClaims{
RegisteredClaims: RegisteredClaims{
Issuer: registeredClaims.Issuer,
Subject: registeredClaims.Subject,
Audience: registeredClaims.Audience,
ID: registeredClaims.ID,
Issuer: registeredClaims.Issuer,
Subject: registeredClaims.Subject,
Audience: registeredClaims.Audience,
ID: registeredClaims.ID,
Expiry: numericDateToUnixTime(registeredClaims.Expiry),
NotBefore: numericDateToUnixTime(registeredClaims.NotBefore),
IssuedAt: numericDateToUnixTime(registeredClaims.IssuedAt),
},
CustomClaims: customClaims,
}

if registeredClaims.Expiry != nil {
validatedClaims.RegisteredClaims.Expiry = registeredClaims.Expiry.Time().Unix()
return validatedClaims, nil
}

func validateClaimsWithLeeway(actualClaims jwt.Claims, expected jwt.Expected, leeway time.Duration) error {
expectedClaims := expected
expectedClaims.Time = time.Now()

if actualClaims.Issuer != expectedClaims.Issuer {
return jwt.ErrInvalidIssuer
}

if registeredClaims.NotBefore != nil {
validatedClaims.RegisteredClaims.NotBefore = registeredClaims.NotBefore.Time().Unix()
foundAudience := false
for _, value := range expectedClaims.Audience {
if actualClaims.Audience.Contains(value) {
foundAudience = true
break
}
}
if !foundAudience {
return jwt.ErrInvalidAudience
}

if registeredClaims.IssuedAt != nil {
validatedClaims.RegisteredClaims.IssuedAt = registeredClaims.IssuedAt.Time().Unix()
if actualClaims.NotBefore != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.NotBefore.Time()) {
return jwt.ErrNotValidYet
}

if v.customClaims != nil && v.customClaims() != nil {
validatedClaims.CustomClaims = claimDest[1].(CustomClaims)
if err = validatedClaims.CustomClaims.Validate(ctx); err != nil {
return nil, fmt.Errorf("custom claims not validated: %w", err)
}
if actualClaims.Expiry != nil && expectedClaims.Time.Add(-leeway).After(actualClaims.Expiry.Time()) {
return jwt.ErrExpired
}

return validatedClaims, nil
if actualClaims.IssuedAt != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.IssuedAt.Time()) {
return jwt.ErrIssuedInTheFuture
}

return nil
}

func validateSigningMethod(validAlg, tokenAlg string) error {
if validAlg != tokenAlg {
return fmt.Errorf("expected %q signing algorithm but token specified %q", validAlg, tokenAlg)
}
return nil
}

func (v *Validator) customClaimsExist() bool {
return v.customClaims != nil && v.customClaims() != nil
}

func (v *Validator) deserializeClaims(ctx context.Context, token *jwt.JSONWebToken) (jwt.Claims, CustomClaims, error) {
key, err := v.keyFunc(ctx)
if err != nil {
return jwt.Claims{}, nil, fmt.Errorf("error getting the keys from the key func: %w", err)
}

claims := []interface{}{&jwt.Claims{}}
if v.customClaimsExist() {
claims = append(claims, v.customClaims())
}

if err = token.Claims(key, claims...); err != nil {
return jwt.Claims{}, nil, fmt.Errorf("could not get token claims: %w", err)
}

registeredClaims := *claims[0].(*jwt.Claims)

var customClaims CustomClaims
if len(claims) > 1 {
customClaims = claims[1].(CustomClaims)
}

return registeredClaims, customClaims, nil
}

func numericDateToUnixTime(date *jwt.NumericDate) int64 {
if date != nil {
return date.Time().Unix()
}
return 0
}
52 changes: 46 additions & 6 deletions validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package validator
import (
"context"
"errors"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)

type testClaims struct {
Expand Down Expand Up @@ -77,7 +80,7 @@ func TestValidator_ValidateToken(t *testing.T) {
return []byte("secret"), nil
},
algorithm: RS256,
expectedError: errors.New(`expected "RS256" signing algorithm but token specified "HS256"`),
expectedError: errors.New(`signing method is invalid: expected "RS256" signing algorithm but token specified "HS256"`),
},
{
name: "it throws an error when it cannot parse the token",
Expand All @@ -95,7 +98,7 @@ func TestValidator_ValidateToken(t *testing.T) {
return nil, errors.New("key func error message")
},
algorithm: HS256,
expectedError: errors.New("error getting the keys from the key func: key func error message"),
expectedError: errors.New("failed to deserialize token claims: error getting the keys from the key func: key func error message"),
},
{
name: "it throws an error when it fails to deserialize the claims because the signature is invalid",
Expand All @@ -104,7 +107,7 @@ func TestValidator_ValidateToken(t *testing.T) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: errors.New("could not get token claims: square/go-jose: error in cryptographic primitive"),
expectedError: errors.New("failed to deserialize token claims: could not get token claims: square/go-jose: error in cryptographic primitive"),
},
{
name: "it throws an error when it fails to validate the registered claims",
Expand Down Expand Up @@ -150,7 +153,7 @@ func TestValidator_ValidateToken(t *testing.T) {
},
{
name: "it successfully validates a token with exp, nbf and iat",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.36iSr7w8Q6b9iJoJo-swmfgAfm23w8SlX92NHIHGX2s",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo5NjY3OTM3Njg2fQ.FKZogkm08gTfYfPU6eYu7OHCjJKnKGLiC0IfoIOPEhs",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
Expand All @@ -160,12 +163,48 @@ func TestValidator_ValidateToken(t *testing.T) {
Issuer: issuer,
Subject: subject,
Audience: []string{audience},
Expiry: 1667937686,
Expiry: 9667937686,
NotBefore: 1666939000,
IssuedAt: 1666937686,
},
},
},
{
name: "it throws an error when token is not valid yet",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6OTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.yUizJ-zK_33tv1qBVvDKO0RuCWtvJ02UQKs8gBadgGY",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrNotValidYet),
},
{
name: "it throws an error when token is expired",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo2Njc5Mzc2ODZ9.SKvz82VOXRi_sjvZWIsPG9vSWAXKKgVS4DkGZcwFKL8",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrExpired),
},
{
name: "it throws an error when token is issued in the future",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjkxNjY2OTM3Njg2LCJuYmYiOjE2NjY5MzkwMDAsImV4cCI6ODY2NzkzNzY4Nn0.ieFV7XNJxiJyw8ARq9yHw-01Oi02e3P2skZO10ypxL8",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrIssuedInTheFuture),
},
{
name: "it throws an error when token issuer is invalid",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2hhY2tlZC1qd3QtbWlkZGxld2FyZS5ldS5hdXRoMC5jb20vIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6WyJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLWFwaS8iXSwiaWF0Ijo5MTY2NjkzNzY4NiwibmJmIjoxNjY2OTM5MDAwLCJleHAiOjg2Njc5Mzc2ODZ9.b5gXNrUNfd_jyCWZF-6IPK_UFfvTr9wBQk9_QgRQ8rA",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrInvalidIssuer),
},
}

for _, testCase := range testCases {
Expand All @@ -177,8 +216,9 @@ func TestValidator_ValidateToken(t *testing.T) {
testCase.keyFunc,
testCase.algorithm,
issuer,
[]string{audience},
[]string{audience, "another-audience"},
WithCustomClaims(testCase.customClaims),
WithAllowedClockSkew(time.Second),
)
require.NoError(t, err)

Expand Down

0 comments on commit 33c3d1d

Please sign in to comment.