Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace *jwt.validateClaimsWithLeeway with custom validation func #176

Merged
merged 3 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ jobs:
uses: codecov/codecov-action@v3
with:
files: coverage.out
fail_ci_if_error: true
fail_ci_if_error: false
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick toggle to not fail the CI build if we fail to upload codecov.

verbose: true
4 changes: 2 additions & 2 deletions examples/http-jwks-example/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ func TestHandler(t *testing.T) {
t.Fatal(err)
}

token := buildJWTForTesting(t, jwk, testServer.URL, test.subject, []string{})
token := buildJWTForTesting(t, jwk, testServer.URL, test.subject, []string{"my-audience"})
req.Header.Set("Authorization", "Bearer "+token)

rr := httptest.NewRecorder()

mainHandler := setupHandler(testServer.URL, []string{})
mainHandler := setupHandler(testServer.URL, []string{"my-audience"})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was incorrectly set to validate no audience.

mainHandler.ServeHTTP(rr, req)

if want, got := test.wantStatusCode, rr.Code; want != got {
Expand Down
124 changes: 88 additions & 36 deletions validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,62 +99,114 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte
return nil, fmt.Errorf("could not parse the token: %w", err)
}

if string(v.signatureAlgorithm) != token.Headers[0].Algorithm {
return nil, fmt.Errorf(
"expected %q signing algorithm but token specified %q",
v.signatureAlgorithm,
token.Headers[0].Algorithm,
)
if err = validateSigningMethod(string(v.signatureAlgorithm), token.Headers[0].Algorithm); err != nil {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to refactor this func so that the steps read better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👏🏼

return nil, fmt.Errorf("signing method is invalid: %w", err)
}

key, err := v.keyFunc(ctx)
registeredClaims, customClaims, err := v.deserializeClaims(ctx, token)
if err != nil {
return nil, fmt.Errorf("error getting the keys from the key func: %w", err)
}

claimDest := []interface{}{&jwt.Claims{}}
if v.customClaims != nil && v.customClaims() != nil {
claimDest = append(claimDest, v.customClaims())
return nil, fmt.Errorf("failed to deserialize token claims: %w", err)
}

if err = token.Claims(key, claimDest...); err != nil {
return nil, fmt.Errorf("could not get token claims: %w", err)
if err = validateClaimsWithLeeway(registeredClaims, v.expectedClaims, v.allowedClockSkew); err != nil {
return nil, fmt.Errorf("expected claims not validated: %w", err)
}

registeredClaims := *claimDest[0].(*jwt.Claims)
expectedClaims := v.expectedClaims
expectedClaims.Time = time.Now()
if err = registeredClaims.ValidateWithLeeway(expectedClaims, v.allowedClockSkew); err != nil {
return nil, fmt.Errorf("expected claims not validated: %w", err)
if customClaims != nil {
if err = customClaims.Validate(ctx); err != nil {
return nil, fmt.Errorf("custom claims not validated: %w", err)
}
}

validatedClaims := &ValidatedClaims{
RegisteredClaims: RegisteredClaims{
Issuer: registeredClaims.Issuer,
Subject: registeredClaims.Subject,
Audience: registeredClaims.Audience,
ID: registeredClaims.ID,
Issuer: registeredClaims.Issuer,
Subject: registeredClaims.Subject,
Audience: registeredClaims.Audience,
ID: registeredClaims.ID,
Expiry: numericDateToUnixTime(registeredClaims.Expiry),
NotBefore: numericDateToUnixTime(registeredClaims.NotBefore),
IssuedAt: numericDateToUnixTime(registeredClaims.IssuedAt),
},
CustomClaims: customClaims,
}

if registeredClaims.Expiry != nil {
validatedClaims.RegisteredClaims.Expiry = registeredClaims.Expiry.Time().Unix()
return validatedClaims, nil
}

func validateClaimsWithLeeway(actualClaims jwt.Claims, expected jwt.Expected, leeway time.Duration) error {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@Widcket Widcket Oct 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also include the sub and jti validation?

Copy link
Contributor Author

@sergiught sergiught Nov 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although that original func does include them, we were never actually using them for the validation. We're only setting as expected the audience, issuer and the time based claims.

Re:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that in other SDKs we do validate the sub claim: https://github.com/auth0/Auth0.swift/blob/master/Auth0/IDTokenValidator.swift#L56

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add that validation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(if applicable, of course).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately we can't introduce that right now without a breaking change. The expectedClaims are a private internal property inside the validator and we set it with the passed in issuer and audience when constructing the validator. We'll have to change this func https://github.com/auth0/go-jwt-middleware/blob/master/validator/validator.go#L59 to allow for the subject as well. I'll note this down for a potential upcoming v3 perhaps, wdyt?

expectedClaims := expected
expectedClaims.Time = time.Now()

if actualClaims.Issuer != expectedClaims.Issuer {
return jwt.ErrInvalidIssuer
}

if registeredClaims.NotBefore != nil {
validatedClaims.RegisteredClaims.NotBefore = registeredClaims.NotBefore.Time().Unix()
foundAudience := false
for _, value := range expectedClaims.Audience {
if actualClaims.Audience.Contains(value) {
foundAudience = true
break
}
}
if !foundAudience {
return jwt.ErrInvalidAudience
Comment on lines +145 to +153
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mainly the fix for the multiple audiences described in #148

}

if registeredClaims.IssuedAt != nil {
validatedClaims.RegisteredClaims.IssuedAt = registeredClaims.IssuedAt.Time().Unix()
if actualClaims.NotBefore != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.NotBefore.Time()) {
return jwt.ErrNotValidYet
}

if v.customClaims != nil && v.customClaims() != nil {
validatedClaims.CustomClaims = claimDest[1].(CustomClaims)
if err = validatedClaims.CustomClaims.Validate(ctx); err != nil {
return nil, fmt.Errorf("custom claims not validated: %w", err)
}
if actualClaims.Expiry != nil && expectedClaims.Time.Add(-leeway).After(actualClaims.Expiry.Time()) {
return jwt.ErrExpired
}

return validatedClaims, nil
if actualClaims.IssuedAt != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.IssuedAt.Time()) {
return jwt.ErrIssuedInTheFuture
}

return nil
}

func validateSigningMethod(validAlg, tokenAlg string) error {
if validAlg != tokenAlg {
return fmt.Errorf("expected %q signing algorithm but token specified %q", validAlg, tokenAlg)
}
return nil
}

func (v *Validator) customClaimsExist() bool {
return v.customClaims != nil && v.customClaims() != nil
}

func (v *Validator) deserializeClaims(ctx context.Context, token *jwt.JSONWebToken) (jwt.Claims, CustomClaims, error) {
key, err := v.keyFunc(ctx)
if err != nil {
return jwt.Claims{}, nil, fmt.Errorf("error getting the keys from the key func: %w", err)
}

claims := []interface{}{&jwt.Claims{}}
if v.customClaimsExist() {
claims = append(claims, v.customClaims())
}

if err = token.Claims(key, claims...); err != nil {
return jwt.Claims{}, nil, fmt.Errorf("could not get token claims: %w", err)
}

registeredClaims := *claims[0].(*jwt.Claims)

var customClaims CustomClaims
if len(claims) > 1 {
customClaims = claims[1].(CustomClaims)
}

return registeredClaims, customClaims, nil
}

func numericDateToUnixTime(date *jwt.NumericDate) int64 {
if date != nil {
return date.Time().Unix()
}
return 0
}
52 changes: 46 additions & 6 deletions validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package validator
import (
"context"
"errors"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)

type testClaims struct {
Expand Down Expand Up @@ -77,7 +80,7 @@ func TestValidator_ValidateToken(t *testing.T) {
return []byte("secret"), nil
},
algorithm: RS256,
expectedError: errors.New(`expected "RS256" signing algorithm but token specified "HS256"`),
expectedError: errors.New(`signing method is invalid: expected "RS256" signing algorithm but token specified "HS256"`),
},
{
name: "it throws an error when it cannot parse the token",
Expand All @@ -95,7 +98,7 @@ func TestValidator_ValidateToken(t *testing.T) {
return nil, errors.New("key func error message")
},
algorithm: HS256,
expectedError: errors.New("error getting the keys from the key func: key func error message"),
expectedError: errors.New("failed to deserialize token claims: error getting the keys from the key func: key func error message"),
},
{
name: "it throws an error when it fails to deserialize the claims because the signature is invalid",
Expand All @@ -104,7 +107,7 @@ func TestValidator_ValidateToken(t *testing.T) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: errors.New("could not get token claims: square/go-jose: error in cryptographic primitive"),
expectedError: errors.New("failed to deserialize token claims: could not get token claims: square/go-jose: error in cryptographic primitive"),
},
{
name: "it throws an error when it fails to validate the registered claims",
Expand Down Expand Up @@ -150,7 +153,7 @@ func TestValidator_ValidateToken(t *testing.T) {
},
{
name: "it successfully validates a token with exp, nbf and iat",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.36iSr7w8Q6b9iJoJo-swmfgAfm23w8SlX92NHIHGX2s",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo5NjY3OTM3Njg2fQ.FKZogkm08gTfYfPU6eYu7OHCjJKnKGLiC0IfoIOPEhs",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
Expand All @@ -160,12 +163,48 @@ func TestValidator_ValidateToken(t *testing.T) {
Issuer: issuer,
Subject: subject,
Audience: []string{audience},
Expiry: 1667937686,
Expiry: 9667937686,
NotBefore: 1666939000,
IssuedAt: 1666937686,
},
},
},
{
name: "it throws an error when token is not valid yet",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6OTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.yUizJ-zK_33tv1qBVvDKO0RuCWtvJ02UQKs8gBadgGY",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrNotValidYet),
},
{
name: "it throws an error when token is expired",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo2Njc5Mzc2ODZ9.SKvz82VOXRi_sjvZWIsPG9vSWAXKKgVS4DkGZcwFKL8",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrExpired),
},
{
name: "it throws an error when token is issued in the future",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjkxNjY2OTM3Njg2LCJuYmYiOjE2NjY5MzkwMDAsImV4cCI6ODY2NzkzNzY4Nn0.ieFV7XNJxiJyw8ARq9yHw-01Oi02e3P2skZO10ypxL8",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrIssuedInTheFuture),
},
{
name: "it throws an error when token issuer is invalid",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2hhY2tlZC1qd3QtbWlkZGxld2FyZS5ldS5hdXRoMC5jb20vIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6WyJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLWFwaS8iXSwiaWF0Ijo5MTY2NjkzNzY4NiwibmJmIjoxNjY2OTM5MDAwLCJleHAiOjg2Njc5Mzc2ODZ9.b5gXNrUNfd_jyCWZF-6IPK_UFfvTr9wBQk9_QgRQ8rA",
keyFunc: func(context.Context) (interface{}, error) {
return []byte("secret"), nil
},
algorithm: HS256,
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrInvalidIssuer),
},
}

for _, testCase := range testCases {
Expand All @@ -177,8 +216,9 @@ func TestValidator_ValidateToken(t *testing.T) {
testCase.keyFunc,
testCase.algorithm,
issuer,
[]string{audience},
[]string{audience, "another-audience"},
WithCustomClaims(testCase.customClaims),
WithAllowedClockSkew(time.Second),
)
require.NoError(t, err)

Expand Down