Skip to content

Commit

Permalink
fix(pkce): session generated needlessly (#12)
Browse files Browse the repository at this point in the history
This fixes an issue where the PKCE session is generated when not required. This also avoids a particular error that can occur in some situations.
  • Loading branch information
james-d-elliott committed Dec 21, 2023
1 parent 327810e commit dbdadf5
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 128 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/format.yml
Expand Up @@ -2,7 +2,11 @@ name: Format

on:
pull_request:
branches:
- master
push:
branches:
- master

jobs:
format:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Expand Up @@ -2,7 +2,11 @@ name: Unit tests

on:
pull_request:
branches:
- master
push:
branches:
- master

jobs:
test:
Expand Down
120 changes: 78 additions & 42 deletions handler/pkce/handler.go
Expand Up @@ -6,6 +6,7 @@ package pkce
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"regexp"

Expand Down Expand Up @@ -33,55 +34,52 @@ var (

var verifierWrongFormat = regexp.MustCompile(`[^\w.~-]`)

func (c *Handler) HandleAuthorizeEndpointRequest(ctx context.Context, ar oauth2.AuthorizeRequester, resp oauth2.AuthorizeResponder) error {
// This let's us define multiple response types, for example open id connect's id_token
if !ar.GetResponseTypes().Has(consts.ResponseTypeAuthorizationCodeFlow) {
func (c *Handler) HandleAuthorizeEndpointRequest(ctx context.Context, requester oauth2.AuthorizeRequester, responder oauth2.AuthorizeResponder) error {
// This let's us define multiple response types, for example the OpenID Connect 1.0 `id_token`.
if !requester.GetResponseTypes().Has(consts.ResponseTypeAuthorizationCodeFlow) {
return nil
}

challenge := ar.GetRequestForm().Get(consts.FormParameterCodeChallenge)
method := ar.GetRequestForm().Get(consts.FormParameterCodeChallengeMethod)
client := ar.GetClient()
challenge := requester.GetRequestForm().Get(consts.FormParameterCodeChallenge)
method := requester.GetRequestForm().Get(consts.FormParameterCodeChallengeMethod)
client := requester.GetClient()

if err := c.validate(ctx, challenge, method, client); err != nil {
return err
}

code := resp.GetCode()
// We don't need a session if it's not enforced and the PKCE parameters are not provided by the client.
if challenge == "" && method == "" {
return nil
}

code := responder.GetCode()

if len(code) == 0 {
return errorsx.WithStack(oauth2.ErrServerError.WithDebug("The PKCE handler must be loaded after the authorize code handler."))
}

signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code)
if err := c.Storage.CreatePKCERequestSession(ctx, signature, ar.Sanitize([]string{

if err := c.Storage.CreatePKCERequestSession(ctx, signature, requester.Sanitize([]string{
consts.FormParameterCodeChallenge,
consts.FormParameterCodeChallengeMethod,
})); err != nil {
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebug(err.Error()))
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebug(oauth2.ErrorToRFC6749Error(err).Error()))
}

return nil
}

func (c *Handler) validate(ctx context.Context, challenge, method string, client oauth2.Client) error {
if challenge == "" {
if len(challenge) == 0 {
// If the server requires Proof Key for Code Exchange (PKCE) by OAuth
// clients and the client does not send the "code_challenge" in
// the request, the authorization endpoint MUST return the authorization
// error response with the "error" value set to "invalid_request". The
// "error_description" or the response of "error_uri" SHOULD explain the
// nature of error, e.g., code challenge required.
if c.Config.GetEnforcePKCE(ctx) {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("Clients must include a code_challenge when performing the authorize code flow, but it is missing.").
WithDebug("The server is configured in a way that enforces PKCE for clients."))
}
if c.Config.GetEnforcePKCEForPublicClients(ctx) && client.IsPublic() {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("This client must include a code_challenge when performing the authorize code flow, but it is missing.").
WithDebug("The server is configured in a way that enforces PKCE for this client."))
}
return nil
return c.validateNoPKCE(ctx, client)
}

// If the server supporting PKCE does not support the requested
Expand All @@ -105,11 +103,12 @@ func (c *Handler) validate(ctx context.Context, challenge, method string, client
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("The code_challenge_method is not supported, use S256 instead."))
}

return nil
}

func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2.AccessRequester) error {
if !c.CanHandleTokenEndpointRequest(ctx, request) {
func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester oauth2.AccessRequester) error {
if !c.CanHandleTokenEndpointRequest(ctx, requester) {
return errorsx.WithStack(oauth2.ErrUnknownRequest)
}

Expand All @@ -119,29 +118,39 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2
// The "code_challenge_method" is bound to the Authorization Code when
// the Authorization Code is issued. That is the method that the token
// endpoint MUST use to verify the "code_verifier".
verifier := request.GetRequestForm().Get(consts.FormParameterCodeVerifier)
verifier := requester.GetRequestForm().Get(consts.FormParameterCodeVerifier)

code := request.GetRequestForm().Get(consts.FormParameterAuthorizationCode)
code := requester.GetRequestForm().Get(consts.FormParameterAuthorizationCode)
signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code)
authorizeRequest, err := c.Storage.GetPKCERequestSession(ctx, signature, request.GetSession())
requesterPKCE, err := c.Storage.GetPKCERequestSession(ctx, signature, requester.GetSession())

nv := len(verifier)

if errors.Is(err, oauth2.ErrNotFound) {
return errorsx.WithStack(oauth2.ErrInvalidGrant.WithHint("Unable to find initial PKCE data tied to this request").WithWrap(err).WithDebug(err.Error()))
if nv == 0 {
return c.validateNoPKCE(ctx, requester.GetClient())
}

return errorsx.WithStack(oauth2.ErrInvalidGrant.WithHint("Unable to find initial PKCE data tied to this request.").WithWrap(err).WithDebug(oauth2.ErrorToDebugRFC6749Error(err).Error()))
} else if err != nil {
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebug(err.Error()))
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebug(oauth2.ErrorToDebugRFC6749Error(err).Error()))
}

if err = c.Storage.DeletePKCERequestSession(ctx, signature); err != nil {
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebug(err.Error()))
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebug(oauth2.ErrorToDebugRFC6749Error(err).Error()))
}

challenge := authorizeRequest.GetRequestForm().Get(consts.FormParameterCodeChallenge)
method := authorizeRequest.GetRequestForm().Get(consts.FormParameterCodeChallengeMethod)
client := authorizeRequest.GetClient()
challenge := requesterPKCE.GetRequestForm().Get(consts.FormParameterCodeChallenge)
method := requesterPKCE.GetRequestForm().Get(consts.FormParameterCodeChallengeMethod)
client := requesterPKCE.GetClient()

if err = c.validate(ctx, challenge, method, client); err != nil {
return err
}

if !c.Config.GetEnforcePKCE(ctx) && challenge == "" && verifier == "" {
nc := len(challenge)

if !c.Config.GetEnforcePKCE(ctx) && nc == 0 && nv == 0 {
return nil
}

Expand All @@ -152,13 +161,17 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2
// 43-octet URL safe string to use as the code verifier.

// Validation
if len(verifier) < 43 {
switch {
case nv < 43:
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code verifier must be at least 43 characters."))
} else if len(verifier) > 128 {
case nv > 128:
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code verifier can not be longer than 128 characters."))
} else if verifierWrongFormat.MatchString(verifier) {
case nc == 0:
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code verifier was provided but the code challenge was absent from the authorization request."))
case verifierWrongFormat.MatchString(verifier):
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code verifier must only contain [a-Z], [0-9], '-', '.', '_', '~'."))
}
Expand Down Expand Up @@ -186,19 +199,20 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request oauth2
// Section 5.2 of [RFC6749] MUST be returned.
switch method {
case consts.PKCEChallengeMethodSHA256:
hash := sha256.New()
if _, err = hash.Write([]byte(verifier)); err != nil {
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
sum := sha256.Sum256([]byte(verifier))

expected := make([]byte, base64.RawURLEncoding.EncodedLen(len(sum)))

if base64.RawURLEncoding.EncodeToString(hash.Sum([]byte{})) != challenge {
base64.RawURLEncoding.Strict().Encode(expected, sum[:])

if subtle.ConstantTimeCompare(expected, []byte(challenge)) == 0 {
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code challenge did not match the code verifier."))
}
case consts.PKCEChallengeMethodPlain:
fallthrough
default:
if verifier != challenge {
if subtle.ConstantTimeCompare([]byte(verifier), []byte(challenge)) == 0 {
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code challenge did not match the code verifier."))
}
Expand All @@ -220,3 +234,25 @@ func (c *Handler) CanHandleTokenEndpointRequest(ctx context.Context, requester o
// Value MUST be set to "authorization_code"
return requester.GetGrantTypes().ExactOne(consts.GrantTypeAuthorizationCode)
}

func (c *Handler) validateNoPKCE(ctx context.Context, client oauth2.Client) error {
if c.Config.GetEnforcePKCE(ctx) {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("Clients must include a code_challenge when performing the authorize code flow, but it is missing.").
WithDebug("The server is configured in a way that enforces PKCE for clients."))
}

if c.Config.GetEnforcePKCEForPublicClients(ctx) {
if client == nil {
return errorsx.WithStack(oauth2.ErrServerError.WithDebug("The client for the request wasn't properly loaded."))
}

if client.IsPublic() {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("This client must include a code_challenge when performing the authorize code flow, but it is missing.").
WithDebug("The server is configured in a way that enforces PKCE for this client."))
}
}

return nil
}

0 comments on commit dbdadf5

Please sign in to comment.