Skip to content

Commit

Permalink
Merge branch 'master' into multiple_jwks
Browse files Browse the repository at this point in the history
# Conflicts:
#	README.md
#	options.go
  • Loading branch information
MicahParks committed Dec 20, 2022
2 parents 6739ca5 + fb3c60d commit 9ce014e
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 24 deletions.
2 changes: 2 additions & 0 deletions README.md
Expand Up @@ -170,6 +170,8 @@ These features can be configured by populating fields in the
* Custom cryptographic algorithms can be used. Make sure to
use [`jwt.RegisterSigningMethod`](https://pkg.go.dev/github.com/golang-jwt/jwt/v4#RegisterSigningMethod) before
parsing JWTs. For an example, see the `examples/custom` directory.
* The remote JWKS resource can be refreshed manually using the `.Refresh` method. This can bypass the rate limit, if the
option is set.
* There is support for creating one `jwt.Keyfunc` from multiple JWK Sets through the use of the `keyfunc.GetMultiple`.

## Notes
Expand Down
86 changes: 64 additions & 22 deletions get.go
Expand Up @@ -3,13 +3,18 @@ package keyfunc
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"sync"
"time"
)

var (
// ErrRefreshImpossible is returned when a refresh is attempted on a JWKS that was not created from a remote
// resource.
ErrRefreshImpossible = errors.New("refresh impossible: JWKS was not created from a remote resource")

// defaultRefreshTimeout is the default duration for the context used to create the HTTP request for a refresh of
// the JWKS.
defaultRefreshTimeout = time.Minute
Expand Down Expand Up @@ -49,13 +54,53 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) {

if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
jwks.ctx, jwks.cancel = context.WithCancel(context.Background())
jwks.refreshRequests = make(chan context.CancelFunc, 1)
jwks.refreshRequests = make(chan refreshRequest, 1)
go jwks.backgroundRefresh()
}

return jwks, nil
}

// Refresh manually refreshes the JWKS with the remote resource. It can bypass the rate limit if configured to do so.
// This function will return an ErrRefreshImpossible if the JWKS was created from a static source like given keys or raw
// JSON, because there is no remote resource to refresh from.
//
// This function will block until the refresh is finished or an error occurs.
func (j *JWKS) Refresh(ctx context.Context, options RefreshOptions) error {
if j.jwksURL == "" {
return ErrRefreshImpossible
}

// Check if the background goroutine was launched.
if j.refreshInterval != 0 || j.refreshUnknownKID {
ctx, cancel := context.WithCancel(ctx)

req := refreshRequest{
cancel: cancel,
ignoreRateLimit: options.IgnoreRateLimit,
}

select {
case <-ctx.Done():
return fmt.Errorf("failed to send request refresh to background goroutine: %w", j.ctx.Err())
case j.refreshRequests <- req:
}

<-ctx.Done()

if !errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("unexpected keyfunc background refresh context error: %w", ctx.Err())
}
} else {
err := j.refresh()
if err != nil {
return fmt.Errorf("failed to refresh JWKS: %w", err)
}
}

return nil
}

// backgroundRefresh is meant to be a separate goroutine that will update the keys in a JWKS over a given interval of
// time.
func (j *JWKS) backgroundRefresh() {
Expand All @@ -69,6 +114,14 @@ func (j *JWKS) backgroundRefresh() {
// Create a channel that will never send anything unless there is a refresh interval.
refreshInterval := make(<-chan time.Time)

refresh := func() {
err := j.refresh()
if err != nil && j.refreshErrorHandler != nil {
j.refreshErrorHandler(err)
}
lastRefresh = time.Now()
}

// Enter an infinite loop that ends when the background ends.
for {
if j.refreshInterval != 0 {
Expand All @@ -80,16 +133,15 @@ func (j *JWKS) backgroundRefresh() {
select {
case <-j.ctx.Done():
return
case j.refreshRequests <- func() {}:
case j.refreshRequests <- refreshRequest{}:
default: // If the j.refreshRequests channel is full, don't send another request.
}

case cancel := <-j.refreshRequests:
case req := <-j.refreshRequests:
refreshMux.Lock()
if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {
// Don't make the JWT parsing goroutine wait for the JWKS to refresh.
cancel()

if req.ignoreRateLimit {
refresh()
} else if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {
// Launch a goroutine that will get a reservation for a JWKS refresh or fail to and immediately return.
queueOnce.Do(func() {
go func() {
Expand All @@ -104,25 +156,15 @@ func (j *JWKS) backgroundRefresh() {

refreshMux.Lock()
defer refreshMux.Unlock()
err := j.refresh()
if err != nil && j.refreshErrorHandler != nil {
j.refreshErrorHandler(err)
}

lastRefresh = time.Now()
refresh()
queueOnce = sync.Once{}
}()
})
} else {
err := j.refresh()
if err != nil && j.refreshErrorHandler != nil {
j.refreshErrorHandler(err)
}

lastRefresh = time.Now()

// Allow the JWT parsing goroutine to continue with the refreshed JWKS.
cancel()
refresh()
}
if req.cancel != nil {
req.cancel()
}
refreshMux.Unlock()

Expand Down
82 changes: 82 additions & 0 deletions get_test.go
@@ -0,0 +1,82 @@
package keyfunc_test

import (
"context"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/MicahParks/keyfunc"
)

func TestJWKS_Refresh(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

var counter uint64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint64(&counter, 1)
_, err := w.Write([]byte(jwksJSON))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
defer server.Close()

jwksURL := server.URL
opts := keyfunc.Options{
Ctx: ctx,
}
jwks, err := keyfunc.Get(jwksURL, opts)
if err != nil {
t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
}

err = jwks.Refresh(ctx, keyfunc.RefreshOptions{IgnoreRateLimit: true})
if err != nil {
t.Fatalf(logFmt, "Failed to refresh JWKS.", err)
}

count := atomic.LoadUint64(&counter)
if count != 2 {
t.Fatalf("Expected 2 refreshes, got %d.", count)
}
}

func TestJWKS_RefreshUsingBackgroundGoroutine(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

var counter uint64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint64(&counter, 1)
_, err := w.Write([]byte(jwksJSON))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
defer server.Close()

jwksURL := server.URL
opts := keyfunc.Options{
Ctx: ctx,
RefreshInterval: time.Hour,
RefreshRateLimit: time.Hour,
}
jwks, err := keyfunc.Get(jwksURL, opts)
if err != nil {
t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
}

err = jwks.Refresh(ctx, keyfunc.RefreshOptions{IgnoreRateLimit: true})
if err != nil {
t.Fatalf(logFmt, "Failed to refresh JWKS.", err)
}

count := atomic.LoadUint64(&counter)
if count != 2 {
t.Fatalf("Expected 2 refreshes, got %d.", count)
}
}
7 changes: 5 additions & 2 deletions jwks.go
Expand Up @@ -77,7 +77,7 @@ type JWKS struct {
refreshErrorHandler ErrorHandler
refreshInterval time.Duration
refreshRateLimit time.Duration
refreshRequests chan context.CancelFunc
refreshRequests chan refreshRequest
refreshTimeout time.Duration
refreshUnknownKID bool
requestFactory func(ctx context.Context, url string) (*http.Request, error)
Expand Down Expand Up @@ -199,12 +199,15 @@ func (j *JWKS) getKey(alg, kid string) (jsonKey interface{}, err error) {
}

ctx, cancel := context.WithCancel(j.ctx)
req := refreshRequest{
cancel: cancel,
}

// Refresh the JWKS.
select {
case <-j.ctx.Done():
return
case j.refreshRequests <- cancel:
case j.refreshRequests <- req:
default:
// If the j.refreshRequests channel is full, return the error early.
return nil, ErrKIDNotFound
Expand Down
10 changes: 10 additions & 0 deletions options.go
Expand Up @@ -93,6 +93,16 @@ type MultipleOptions struct {
KeySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error)
}

// RefreshOptions are used to specify manual refresh behavior.
type RefreshOptions struct {
IgnoreRateLimit bool
}

type refreshRequest struct {
cancel context.CancelFunc
ignoreRateLimit bool
}

// ResponseExtractorStatusOK is meant to be used as the ResponseExtractor field for Options. It confirms that response
// status code is 200 OK and returns the raw JSON from the response body.
func ResponseExtractorStatusOK(ctx context.Context, resp *http.Response) (json.RawMessage, error) {
Expand Down

0 comments on commit 9ce014e

Please sign in to comment.