diff --git a/examples/go.mod b/examples/go.mod index 4fe218e..2130016 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -1,14 +1,10 @@ module github.com/auth0/go-jwt-middleware/examples -go 1.14 +go 1.16 require ( github.com/auth0/go-jwt-middleware v0.0.0 - github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 // indirect - github.com/form3tech-oss/jwt-go v3.2.2+incompatible - github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab - github.com/gorilla/mux v1.7.4 - github.com/urfave/negroni v1.0.0 + gopkg.in/square/go-jose.v2 v2.5.1 ) -replace github.com/auth0/go-jwt-middleware => ../ +replace github.com/auth0/go-jwt-middleware => ./../ diff --git a/examples/go.sum b/examples/go.sum index 8df664e..785c4eb 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,25 +1,26 @@ -github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q= -github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab h1:xveKWz2iaueeTaUgdetzel+U7exyigDYBryyVfV/rZk= -github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= -github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0= -github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= -github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= -github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v4 v4.1.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= +gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/http-jwks-example/main.go b/examples/http-jwks-example/main.go index 77ca707..b3386c8 100644 --- a/examples/http-jwks-example/main.go +++ b/examples/http-jwks-example/main.go @@ -32,7 +32,7 @@ func main() { log.Fatalf("failed to parse the issuer url: %v", err) } - provider := josev2.NewCachingJWKSProvider(*issuerURL, 5*time.Minute) + provider := josev2.NewCachingJWKSProvider(issuerURL, 5*time.Minute) // Set up the josev2 validator. validator, err := josev2.New( diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index aee4081..dc02f7c 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -15,16 +15,19 @@ type WellKnownEndpoints struct { } // GetWellKnownEndpointsFromIssuerURL gets the well known endpoints for the passed in issuer url. -func GetWellKnownEndpointsFromIssuerURL(ctx context.Context, issuerURL url.URL) (*WellKnownEndpoints, error) { +func GetWellKnownEndpointsFromIssuerURL( + ctx context.Context, + httpClient *http.Client, + issuerURL url.URL, +) (*WellKnownEndpoints, error) { issuerURL.Path = path.Join(issuerURL.Path, ".well-known/openid-configuration") - request, err := http.NewRequest(http.MethodGet, issuerURL.String(), nil) + request, err := http.NewRequestWithContext(ctx, http.MethodGet, issuerURL.String(), nil) if err != nil { return nil, fmt.Errorf("could not build request to get well known endpoints: %w", err) } - request = request.WithContext(ctx) - response, err := http.DefaultClient.Do(request) + response, err := httpClient.Do(request) if err != nil { return nil, fmt.Errorf("could not get well known endpoints from url %s: %w", issuerURL.String(), err) } diff --git a/validate/josev2/jwks_provider.go b/validate/josev2/jwks_provider.go index 22e10e6..e503dc4 100644 --- a/validate/josev2/jwks_provider.go +++ b/validate/josev2/jwks_provider.go @@ -20,35 +20,60 @@ import ( // getting and caching JWKS which can help reduce request time and potential // rate limiting from your provider. type JWKSProvider struct { - IssuerURL url.URL + IssuerURL *url.URL // Required. + CustomJWKSURI *url.URL // Optional. + Client *http.Client } +// ProviderOption is how options for the JWKSProvider are set up. +type ProviderOption func(*JWKSProvider) + // NewJWKSProvider builds and returns a new *JWKSProvider. -func NewJWKSProvider(issuerURL url.URL) *JWKSProvider { - return &JWKSProvider{IssuerURL: issuerURL} +func NewJWKSProvider(issuerURL *url.URL, opts ...ProviderOption) *JWKSProvider { + p := &JWKSProvider{ + IssuerURL: issuerURL, + Client: &http.Client{}, + } + + for _, opt := range opts { + opt(p) + } + + return p +} + +// WithCustomJWKSURI will set a custom JWKS URI on the *JWKSProvider and +// call this directly inside the keyFunc in order to fetch the JWKS, +// skipping the oidc.GetWellKnownEndpointsFromIssuerURL call. +func WithCustomJWKSURI(jwksURI *url.URL) ProviderOption { + return func(p *JWKSProvider) { + p.CustomJWKSURI = jwksURI + } } // KeyFunc adheres to the keyFunc signature that the Validator requires. // While it returns an interface to adhere to keyFunc, as long as the // error is nil the type will be *jose.JSONWebKeySet. func (p *JWKSProvider) KeyFunc(ctx context.Context) (interface{}, error) { - wkEndpoints, err := oidc.GetWellKnownEndpointsFromIssuerURL(ctx, p.IssuerURL) - if err != nil { - return nil, err - } + jwksURI := p.CustomJWKSURI + if jwksURI == nil { + wkEndpoints, err := oidc.GetWellKnownEndpointsFromIssuerURL(ctx, p.Client, *p.IssuerURL) + if err != nil { + return nil, err + } - jwksURI, err := url.Parse(wkEndpoints.JWKSURI) - if err != nil { - return nil, fmt.Errorf("could not parse JWKS URI from well known endpoints: %w", err) + jwksURI, err = url.Parse(wkEndpoints.JWKSURI) + if err != nil { + return nil, fmt.Errorf("could not parse JWKS URI from well known endpoints: %w", err) + } } - request, err := http.NewRequest(http.MethodGet, jwksURI.String(), nil) + request, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURI.String(), nil) if err != nil { return nil, fmt.Errorf("could not build request to get JWKS: %w", err) } - request = request.WithContext(ctx) - response, err := http.DefaultClient.Do(request) + response, err := p.Client.Do(request) if err != nil { return nil, err } @@ -66,10 +91,10 @@ func (p *JWKSProvider) KeyFunc(ctx context.Context) (interface{}, error) { // and caching them for CacheTTL time. It exposes KeyFunc which adheres // to the keyFunc signature that the Validator requires. type CachingJWKSProvider struct { - IssuerURL url.URL - CacheTTL time.Duration - mu sync.Mutex - cache map[string]cachedJWKS + *JWKSProvider + CacheTTL time.Duration + mu sync.Mutex + cache map[string]cachedJWKS } type cachedJWKS struct { @@ -79,15 +104,15 @@ type cachedJWKS struct { // NewCachingJWKSProvider builds and returns a new CachingJWKSProvider. // If cacheTTL is zero then a default value of 1 minute will be used. -func NewCachingJWKSProvider(issuerURL url.URL, cacheTTL time.Duration) *CachingJWKSProvider { +func NewCachingJWKSProvider(issuerURL *url.URL, cacheTTL time.Duration, opts ...ProviderOption) *CachingJWKSProvider { if cacheTTL == 0 { cacheTTL = 1 * time.Minute } return &CachingJWKSProvider{ - IssuerURL: issuerURL, - CacheTTL: cacheTTL, - cache: map[string]cachedJWKS{}, + JWKSProvider: NewJWKSProvider(issuerURL, opts...), + CacheTTL: cacheTTL, + cache: map[string]cachedJWKS{}, } } @@ -95,12 +120,10 @@ func NewCachingJWKSProvider(issuerURL url.URL, cacheTTL time.Duration) *CachingJ // While it returns an interface to adhere to keyFunc, as long as the // error is nil the type will be *jose.JSONWebKeySet. func (c *CachingJWKSProvider) KeyFunc(ctx context.Context) (interface{}, error) { - issuer := c.IssuerURL.Hostname() - c.mu.Lock() - defer func() { - c.mu.Unlock() - }() + defer c.mu.Unlock() + + issuer := c.IssuerURL.Hostname() if cached, ok := c.cache[issuer]; ok { if !time.Now().After(cached.expiresAt) { @@ -108,8 +131,7 @@ func (c *CachingJWKSProvider) KeyFunc(ctx context.Context) (interface{}, error) } } - provider := JWKSProvider{IssuerURL: c.IssuerURL} - jwks, err := provider.KeyFunc(ctx) + jwks, err := c.JWKSProvider.KeyFunc(ctx) if err != nil { return nil, err } diff --git a/validate/josev2/jwks_provider_test.go b/validate/josev2/jwks_provider_test.go index 5ce8951..d36aa4c 100644 --- a/validate/josev2/jwks_provider_test.go +++ b/validate/josev2/jwks_provider_test.go @@ -6,11 +6,14 @@ import ( "crypto/rsa" "crypto/x509" "encoding/json" + "fmt" "math/big" "net/http" "net/http/httptest" "net/url" + "strings" "sync" + "sync/atomic" "testing" "time" @@ -21,259 +24,174 @@ import ( ) func Test_JWKSProvider(t *testing.T) { - var ( - p CachingJWKSProvider - server *httptest.Server - responseBytes []byte - responseStatusCode, reqCount int - serverURL *url.URL - ) + var requestCount int32 - tests := []struct { - name string - main func(t *testing.T) - }{ - { - name: "calls out to well known endpoint", - main: func(t *testing.T) { - _, jwks := genValidRSAKeyAndJWKS(t) - var err error - responseBytes, err = json.Marshal(jwks) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - _, err = p.KeyFunc(context.TODO()) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - }, - }, - { - name: "errors if it can't decode the jwks", - main: func(t *testing.T) { - responseBytes = []byte("<>") - _, err := p.KeyFunc(context.TODO()) - - wantErr := "could not decode jwks: invalid character '<' looking for beginning of value" - if !equalErrors(err, wantErr) { - t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", wantErr, err) - } - }, - }, - { - name: "passes back the valid jwks", - main: func(t *testing.T) { - _, jwks := genValidRSAKeyAndJWKS(t) - var err error - responseBytes, err = json.Marshal(jwks) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - p.CacheTTL = time.Minute * 5 - actualJWKS, err := p.KeyFunc(context.TODO()) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { - t.Fatalf("jwks did not match: %s", cmp.Diff(want, got)) - } - - if want, got := &jwks, p.cache[serverURL.Hostname()].jwks; !cmp.Equal(want, got) { - t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) - } - - expiresAt := p.cache[serverURL.Hostname()].expiresAt - if !time.Now().Before(expiresAt) { - t.Fatalf("wanted cache item expiration to be in the future but it was not: %s", expiresAt) - } - }, - }, - { - name: "returns the cached jwks when they are not expired", - main: func(t *testing.T) { - _, expectedCachedJWKS := genValidRSAKeyAndJWKS(t) - p.cache[serverURL.Hostname()] = cachedJWKS{ - jwks: &expectedCachedJWKS, - expiresAt: time.Now().Add(1 * time.Minute), - } - - actualJWKS, err := p.KeyFunc(context.TODO()) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - if want, got := &expectedCachedJWKS, actualJWKS; !cmp.Equal(want, got) { - t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) - } - - if reqCount > 0 { - t.Fatalf("did not want any requests since we should have read from the cache, but we got %d requests", reqCount) - } - }, - }, - { - name: "re-caches the jwks if they have expired", - main: func(t *testing.T) { - _, expiredCachedJWKS := genValidRSAKeyAndJWKS(t) - expiresAt := time.Now().Add(-10 * time.Minute) - p.cache[server.URL] = cachedJWKS{ - jwks: &expiredCachedJWKS, - expiresAt: expiresAt, - } - _, jwks := genValidRSAKeyAndJWKS(t) - var err error - responseBytes, err = json.Marshal(jwks) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - p.CacheTTL = time.Minute * 5 - actualJWKS, err := p.KeyFunc(context.TODO()) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { - t.Fatalf("jwks did not match: %s", cmp.Diff(want, got)) - } - - if want, got := &jwks, p.cache[serverURL.Hostname()].jwks; !cmp.Equal(want, got) { - t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) - } - - cacheExpiresAt := p.cache[serverURL.Hostname()].expiresAt - if !time.Now().Before(cacheExpiresAt) { - t.Fatalf("wanted cache item expiration to be in the future but it was not: %s", cacheExpiresAt) - } - }, - }, - { - name: "only calls the API once when multiple requests come in", - main: func(t *testing.T) { - _, jwks := genValidRSAKeyAndJWKS(t) - var err error - responseBytes, err = json.Marshal(jwks) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - p.CacheTTL = time.Minute * 5 - - wg := sync.WaitGroup{} - for i := 0; i < 50; i++ { - wg.Add(1) - go func(t *testing.T) { - actualJWKS, err := p.KeyFunc(context.TODO()) - if !equalErrors(err, "") { - t.Errorf("did not want an error, but got %s", err) - } - - if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { - t.Errorf("jwks did not match: %s", cmp.Diff(want, got)) - } - - wg.Done() - }(t) - } - wg.Wait() - - actualJWKS, err := p.KeyFunc(context.TODO()) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - - if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { - t.Fatalf("jwks did not match: %s", cmp.Diff(want, got)) - } - - if reqCount != 2 { - t.Fatalf("only wanted 2 requests (well known and jwks) , but we got %d requests", reqCount) - } - - if want, got := &jwks, p.cache[serverURL.Hostname()].jwks; !cmp.Equal(want, got) { - t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) - } - - cacheExpiresAt := p.cache[serverURL.Hostname()].expiresAt - if !time.Now().Before(cacheExpiresAt) { - t.Fatalf("wanted cache item expiration to be in the future but it was not: %s", cacheExpiresAt) - } - }, - }, + expectedJWKS, err := generateJWKS() + if err != nil { + t.Fatalf("did not expect an error but gone one: %v", err) } - for _, test := range tests { - var reqCallMutex sync.Mutex - - reqCount = 0 - responseBytes = []byte(`{"kid":""}`) - responseStatusCode = http.StatusOK - server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // handle mutex things - reqCallMutex.Lock() - defer reqCallMutex.Unlock() - reqCount++ - w.WriteHeader(responseStatusCode) - - switch r.URL.String() { - case "/.well-known/openid-configuration": - wk := oidc.WellKnownEndpoints{JWKSURI: server.URL + "/url_for_jwks"} - err := json.NewEncoder(w).Encode(wk) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - case "/url_for_jwks": - _, err := w.Write(responseBytes) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) - } - default: - t.Fatalf("do not know how to handle url %s", r.URL.String()) - } - })) - defer server.Close() - serverURL = mustParseURL(server.URL) - - p = CachingJWKSProvider{ - IssuerURL: *serverURL, - CacheTTL: 0, - cache: map[string]cachedJWKS{}, - } - - t.Run(test.name, test.main) + expectedCustomJWKS, err := generateJWKS() + if err != nil { + t.Fatalf("did not expect an error but gone one: %v", err) } -} -func mustParseURL(toParse string) *url.URL { - parsed, err := url.Parse(toParse) + server := setupTestServer(t, expectedJWKS, expectedCustomJWKS, &requestCount) + defer server.Close() + + serverURL, err := url.Parse(server.URL) if err != nil { - panic(err) + t.Fatalf("did not want an error, but got %s", err) } - return parsed + t.Run("It correctly fetches the JWKS after calling the discovery endpoint", func(t *testing.T) { + provider := NewJWKSProvider(serverURL) + actualJWKS, err := provider.KeyFunc(context.Background()) + if err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + + if !cmp.Equal(expectedJWKS, actualJWKS) { + t.Fatalf("jwks did not match: %s", cmp.Diff(expectedJWKS, actualJWKS)) + } + }) + + t.Run("It skips the discovery if a custom JWKS_URI is provided", func(t *testing.T) { + customJWKSURI, err := url.Parse(server.URL + "/custom/jwks.json") + if err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + + provider := NewJWKSProvider(serverURL, WithCustomJWKSURI(customJWKSURI)) + actualJWKS, err := provider.KeyFunc(context.Background()) + if err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + + if !cmp.Equal(expectedCustomJWKS, actualJWKS) { + t.Fatalf("jwks did not match: %s", cmp.Diff(expectedCustomJWKS, actualJWKS)) + } + }) + + t.Run("It tells the provider to cancel fetching the JWKS if request is cancelled", func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 0) + defer cancel() + + provider := NewJWKSProvider(serverURL) + _, err := provider.KeyFunc(ctx) + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("was expecting context deadline to exceed but error is: %v", err) + } + }) + + t.Run("It re-caches the JWKS if they have expired when using CachingJWKSProvider", func(t *testing.T) { + expiredCachedJWKS, err := generateJWKS() + if err != nil { + t.Fatalf("did not expect an error but gone one: %v", err) + } + + provider := NewCachingJWKSProvider(serverURL, 5*time.Minute) + provider.cache[serverURL.Hostname()] = cachedJWKS{ + jwks: expiredCachedJWKS, + expiresAt: time.Now().Add(-10 * time.Minute), + } + + actualJWKS, err := provider.KeyFunc(context.Background()) + if err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + + if !cmp.Equal(expectedJWKS, actualJWKS) { + t.Fatalf("jwks did not match: %s", cmp.Diff(expectedJWKS, actualJWKS)) + } + + if !cmp.Equal(expectedJWKS, provider.cache[serverURL.Hostname()].jwks) { + t.Fatalf("cached jwks did not match: %s", cmp.Diff(expectedJWKS, provider.cache[serverURL.Hostname()].jwks)) + } + + cacheExpiresAt := provider.cache[serverURL.Hostname()].expiresAt + if !time.Now().Before(cacheExpiresAt) { + t.Fatalf("wanted cache item expiration to be in the future but it was not: %s", cacheExpiresAt) + } + }) + + t.Run( + "It only calls the API once when multiple requests come in when using the CachingJWKSProvider", + func(t *testing.T) { + requestCount = 0 + + provider := NewCachingJWKSProvider(serverURL, 5*time.Minute) + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + _, _ = provider.KeyFunc(context.Background()) + wg.Done() + }() + } + wg.Wait() + + if requestCount != 2 { + t.Fatalf("only wanted 2 requests (well known and jwks) , but we got %d requests", requestCount) + } + }, + ) + + t.Run("It sets the caching TTL to 1 if 0 is provided when using the CachingJWKSProvider", func(t *testing.T) { + provider := NewCachingJWKSProvider(serverURL, 0) + if provider.CacheTTL != time.Minute { + t.Fatalf("was expecting cache ttl to be 1 minute") + } + }) + + t.Run( + "It fails to parse the jwks uri after fetching it from the discovery endpoint if malformed", + func(t *testing.T) { + malformedURL, err := url.Parse(server.URL+"/malformed") + if err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + + provider := NewJWKSProvider(malformedURL) + _, err = provider.KeyFunc(context.Background()) + if !strings.Contains(err.Error(), "could not parse JWKS URI from well known endpoints") { + t.Fatalf("wanted an error, but got %s", err) + } + }, + ) } -func genValidRSAKeyAndJWKS(t *testing.T) (*rsa.PrivateKey, jose.JSONWebKeySet) { - ca := &x509.Certificate{ +func generateJWKS() (*jose.JSONWebKeySet, error) { + certificate := &x509.Certificate{ SerialNumber: big.NewInt(1653), } - priv, _ := rsa.GenerateKey(rand.Reader, 2048) - rawCert, err := x509.CreateCertificate(rand.Reader, ca, ca, &priv.PublicKey, priv) - if !equalErrors(err, "") { - t.Fatalf("did not want an error, but got %s", err) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate private key") + } + + rawCertificate, err := x509.CreateCertificate( + rand.Reader, + certificate, + certificate, + &privateKey.PublicKey, + privateKey, + ) + if err != nil { + return nil, fmt.Errorf("failed to create certificate") } jwks := jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { - Key: priv, + Key: privateKey, KeyID: "kid", Certificates: []*x509.Certificate{ { - Raw: rawCert, + Raw: rawCertificate, }, }, CertificateThumbprintSHA1: []uint8{}, @@ -281,5 +199,44 @@ func genValidRSAKeyAndJWKS(t *testing.T) (*rsa.PrivateKey, jose.JSONWebKeySet) { }, }, } - return priv, jwks + + return &jwks, nil +} + +func setupTestServer( + t *testing.T, + expectedJWKS *jose.JSONWebKeySet, + expectedCustomJWKS *jose.JSONWebKeySet, + requestCount *int32, +) (server *httptest.Server) { + t.Helper() + + var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(requestCount, 1) + + switch r.URL.String() { + case "/malformed/.well-known/openid-configuration": + wk := oidc.WellKnownEndpoints{JWKSURI: ":"} + if err := json.NewEncoder(w).Encode(wk); err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + case "/.well-known/openid-configuration": + wk := oidc.WellKnownEndpoints{JWKSURI: server.URL + "/.well-known/jwks.json"} + if err := json.NewEncoder(w).Encode(wk); err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + case "/.well-known/jwks.json": + if err := json.NewEncoder(w).Encode(expectedJWKS); err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + case "/custom/jwks.json": + if err := json.NewEncoder(w).Encode(expectedCustomJWKS); err != nil { + t.Fatalf("did not want an error, but got %s", err) + } + default: + t.Fatalf("was not expecting to handle the following url: %s", r.URL.String()) + } + }) + + return httptest.NewServer(handler) }