diff --git a/error_handler.go b/error_handler.go index 816387f..1360b3c 100644 --- a/error_handler.go +++ b/error_handler.go @@ -1,46 +1,174 @@ package jwtmiddleware import ( + "encoding/json" "errors" "fmt" "net/http" + + "github.com/auth0/go-jwt-middleware/v3/core" ) var ( // ErrJWTMissing is returned when the JWT is missing. - ErrJWTMissing = errors.New("jwt missing") + // This is the same as core.ErrJWTMissing for consistency. + ErrJWTMissing = core.ErrJWTMissing // ErrJWTInvalid is returned when the JWT is invalid. - ErrJWTInvalid = errors.New("jwt invalid") + // This is the same as core.ErrJWTInvalid for consistency. + ErrJWTInvalid = core.ErrJWTInvalid ) // ErrorHandler is a handler which is called when an error occurs in the -// JWTMiddleware. Among some general errors, this handler also determines the -// response of the JWTMiddleware when a token is not found or is invalid. The -// err can be checked to be ErrJWTMissing or ErrJWTInvalid for specific cases. -// The default handler will return a status code of 400 for ErrJWTMissing, -// 401 for ErrJWTInvalid, and 500 for all other errors. If you implement your -// own ErrorHandler you MUST take into consideration the error types as not -// properly responding to them or having a poorly implemented handler could -// result in the JWTMiddleware not functioning as intended. +// JWTMiddleware. The handler determines the HTTP response when a token is +// not found, is invalid, or other errors occur. +// +// The default handler (DefaultErrorHandler) provides: +// - Structured JSON error responses with error codes +// - RFC 6750 compliant WWW-Authenticate headers (Bearer tokens) +// - Appropriate HTTP status codes based on error type +// - Security-conscious error messages (no sensitive details by default) +// - Extensible architecture for future authentication schemes (e.g., DPoP per RFC 9449) +// +// Custom error handlers should check for ErrJWTMissing and ErrJWTInvalid +// sentinel errors, as well as core.ValidationError for detailed error codes. +// +// Future extensions (e.g., DPoP support) can use the same pattern: +// - Add DPoP-specific error codes to core.ValidationError +// - Update mapValidationError to handle DPoP errors +// - Return appropriate WWW-Authenticate headers with DPoP scheme type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) -// DefaultErrorHandler is the default error handler implementation for the -// JWTMiddleware. If an error handler is not provided via the WithErrorHandler -// option this will be used. +// ErrorResponse represents a structured error response. +type ErrorResponse struct { + // Error is the main error message + Error string `json:"error"` + + // ErrorDescription provides additional context (optional) + ErrorDescription string `json:"error_description,omitempty"` + + // ErrorCode is a machine-readable error code (optional) + ErrorCode string `json:"error_code,omitempty"` +} + +// DefaultErrorHandler is the default error handler implementation. +// It provides structured error responses with appropriate HTTP status codes +// and RFC 6750 compliant WWW-Authenticate headers. func DefaultErrorHandler(w http.ResponseWriter, _ *http.Request, err error) { + // Extract error details + statusCode, errorResp, wwwAuthenticate := mapErrorToResponse(err) + + // Set headers w.Header().Set("Content-Type", "application/json") + if wwwAuthenticate != "" { + w.Header().Set("WWW-Authenticate", wwwAuthenticate) + } + + // Write response + w.WriteHeader(statusCode) + _ = json.NewEncoder(w).Encode(errorResp) +} + +// mapErrorToResponse maps errors to appropriate HTTP responses +func mapErrorToResponse(err error) (statusCode int, resp ErrorResponse, wwwAuthenticate string) { + // Check for JWT missing error + if errors.Is(err, ErrJWTMissing) { + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "JWT is missing", + }, `Bearer error="invalid_token", error_description="JWT is missing"` + } + + // Check for validation error with specific code + var validationErr *core.ValidationError + if errors.As(err, &validationErr) { + return mapValidationError(validationErr) + } + + // Check for general JWT invalid error + if errors.Is(err, ErrJWTInvalid) { + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "JWT is invalid", + }, `Bearer error="invalid_token", error_description="JWT is invalid"` + } + + // Default to internal server error for unexpected errors + return http.StatusInternalServerError, ErrorResponse{ + Error: "server_error", + ErrorDescription: "An internal error occurred while processing the request", + }, "" +} + +// mapValidationError maps core.ValidationError codes to HTTP responses +// This function is extensible to support future authentication schemes like DPoP (RFC 9449) +func mapValidationError(err *core.ValidationError) (statusCode int, resp ErrorResponse, wwwAuthenticate string) { + // Map error codes to HTTP status codes and RFC 6750 Bearer token error types + // Future: Add DPoP-specific error codes and return appropriate DPoP challenge headers + switch err.Code { + case core.ErrorCodeTokenExpired: + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "The access token expired", + ErrorCode: string(err.Code), + }, `Bearer error="invalid_token", error_description="The access token expired"` + + case core.ErrorCodeTokenNotYetValid: + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "The access token is not yet valid", + ErrorCode: string(err.Code), + }, `Bearer error="invalid_token", error_description="The access token is not yet valid"` + + case core.ErrorCodeInvalidSignature: + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "The access token signature is invalid", + ErrorCode: string(err.Code), + }, `Bearer error="invalid_token", error_description="The access token signature is invalid"` + + case core.ErrorCodeTokenMalformed: + return http.StatusBadRequest, ErrorResponse{ + Error: "invalid_request", + ErrorDescription: "The access token is malformed", + ErrorCode: string(err.Code), + }, `Bearer error="invalid_request", error_description="The access token is malformed"` + + case core.ErrorCodeInvalidIssuer: + return http.StatusForbidden, ErrorResponse{ + Error: "insufficient_scope", + ErrorDescription: "The access token was issued by an untrusted issuer", + ErrorCode: string(err.Code), + }, `Bearer error="insufficient_scope", error_description="The access token was issued by an untrusted issuer"` + + case core.ErrorCodeInvalidAudience: + return http.StatusForbidden, ErrorResponse{ + Error: "insufficient_scope", + ErrorDescription: "The access token audience does not match", + ErrorCode: string(err.Code), + }, `Bearer error="insufficient_scope", error_description="The access token audience does not match"` + + case core.ErrorCodeInvalidAlgorithm: + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "The access token uses an unsupported algorithm", + ErrorCode: string(err.Code), + }, `Bearer error="invalid_token", error_description="The access token uses an unsupported algorithm"` + + case core.ErrorCodeJWKSFetchFailed, core.ErrorCodeJWKSKeyNotFound: + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "Unable to verify the access token", + ErrorCode: string(err.Code), + }, `Bearer error="invalid_token", error_description="Unable to verify the access token"` - switch { - case errors.Is(err, ErrJWTMissing): - w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte(`{"message":"JWT is missing."}`)) - case errors.Is(err, ErrJWTInvalid): - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"message":"JWT is invalid."}`)) default: - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(`{"message":"Something went wrong while checking the JWT."}`)) + // Generic invalid token error for other cases + return http.StatusUnauthorized, ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "The access token is invalid", + ErrorCode: string(err.Code), + }, `Bearer error="invalid_token", error_description="The access token is invalid"` } } diff --git a/error_handler_test.go b/error_handler_test.go index 4bf70d1..32f0942 100644 --- a/error_handler_test.go +++ b/error_handler_test.go @@ -1,34 +1,214 @@ package jwtmiddleware import ( - "errors" + "encoding/json" + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/auth0/go-jwt-middleware/v3/core" ) -func Test_invalidError(t *testing.T) { - t.Run("Is", func(t *testing.T) { - err := invalidError{details: errors.New("error details")} +func TestDefaultErrorHandler(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + wantError string + wantErrorDescription string + wantErrorCode string + wantWWWAuthenticate string + }{ + { + name: "ErrJWTMissing", + err: ErrJWTMissing, + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "JWT is missing", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="JWT is missing"`, + }, + { + name: "ErrJWTInvalid", + err: ErrJWTInvalid, + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "JWT is invalid", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="JWT is invalid"`, + }, + { + name: "token expired", + err: core.NewValidationError(core.ErrorCodeTokenExpired, "token expired", nil), + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "The access token expired", + wantErrorCode: "token_expired", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="The access token expired"`, + }, + { + name: "token not yet valid", + err: core.NewValidationError(core.ErrorCodeTokenNotYetValid, "token not yet valid", nil), + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "The access token is not yet valid", + wantErrorCode: "token_not_yet_valid", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="The access token is not yet valid"`, + }, + { + name: "invalid signature", + err: core.NewValidationError(core.ErrorCodeInvalidSignature, "invalid signature", nil), + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "The access token signature is invalid", + wantErrorCode: "invalid_signature", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="The access token signature is invalid"`, + }, + { + name: "token malformed", + err: core.NewValidationError(core.ErrorCodeTokenMalformed, "malformed token", nil), + wantStatus: http.StatusBadRequest, + wantError: "invalid_request", + wantErrorDescription: "The access token is malformed", + wantErrorCode: "token_malformed", + wantWWWAuthenticate: `Bearer error="invalid_request", error_description="The access token is malformed"`, + }, + { + name: "invalid issuer", + err: core.NewValidationError(core.ErrorCodeInvalidIssuer, "invalid issuer", nil), + wantStatus: http.StatusForbidden, + wantError: "insufficient_scope", + wantErrorDescription: "The access token was issued by an untrusted issuer", + wantErrorCode: "invalid_issuer", + wantWWWAuthenticate: `Bearer error="insufficient_scope", error_description="The access token was issued by an untrusted issuer"`, + }, + { + name: "invalid audience", + err: core.NewValidationError(core.ErrorCodeInvalidAudience, "invalid audience", nil), + wantStatus: http.StatusForbidden, + wantError: "insufficient_scope", + wantErrorDescription: "The access token audience does not match", + wantErrorCode: "invalid_audience", + wantWWWAuthenticate: `Bearer error="insufficient_scope", error_description="The access token audience does not match"`, + }, + { + name: "invalid algorithm", + err: core.NewValidationError(core.ErrorCodeInvalidAlgorithm, "invalid algorithm", nil), + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "The access token uses an unsupported algorithm", + wantErrorCode: "invalid_algorithm", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="The access token uses an unsupported algorithm"`, + }, + { + name: "JWKS fetch failed", + err: core.NewValidationError(core.ErrorCodeJWKSFetchFailed, "jwks fetch failed", nil), + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "Unable to verify the access token", + wantErrorCode: "jwks_fetch_failed", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="Unable to verify the access token"`, + }, + { + name: "JWKS key not found", + err: core.NewValidationError(core.ErrorCodeJWKSKeyNotFound, "key not found", nil), + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "Unable to verify the access token", + wantErrorCode: "jwks_key_not_found", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="Unable to verify the access token"`, + }, + { + name: "unknown validation error", + err: core.NewValidationError("unknown_code", "unknown error", nil), + wantStatus: http.StatusUnauthorized, + wantError: "invalid_token", + wantErrorDescription: "The access token is invalid", + wantErrorCode: "unknown_code", + wantWWWAuthenticate: `Bearer error="invalid_token", error_description="The access token is invalid"`, + }, + { + name: "generic error", + err: assert.AnError, + wantStatus: http.StatusInternalServerError, + wantError: "server_error", + wantErrorDescription: "An internal error occurred while processing the request", + wantWWWAuthenticate: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/test", nil) + + DefaultErrorHandler(w, r, tt.err) - if !errors.Is(err, ErrJWTInvalid) { - t.Fatal("expected invalidError to be ErrJWTInvalid via errors.Is, but it was not") - } - }) + // Check status code + assert.Equal(t, tt.wantStatus, w.Code) - t.Run("Error", func(t *testing.T) { - err := invalidError{details: errors.New("error details")} - expectedErrMsg := "jwt invalid: error details" + // Check Content-Type + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) - assert.EqualError(t, err, expectedErrMsg) - }) + // Check WWW-Authenticate header + if tt.wantWWWAuthenticate != "" { + assert.Equal(t, tt.wantWWWAuthenticate, w.Header().Get("WWW-Authenticate")) + } else { + assert.Empty(t, w.Header().Get("WWW-Authenticate")) + } + + // Check response body + var resp ErrorResponse + err := json.NewDecoder(w.Body).Decode(&resp) + require.NoError(t, err) + + assert.Equal(t, tt.wantError, resp.Error) + assert.Equal(t, tt.wantErrorDescription, resp.ErrorDescription) + if tt.wantErrorCode != "" { + assert.Equal(t, tt.wantErrorCode, resp.ErrorCode) + } + }) + } +} - t.Run("Unwrap", func(t *testing.T) { - expectedErr := errors.New("expected err") - err := invalidError{details: expectedErr} +func TestErrorResponse_JSON(t *testing.T) { + tests := []struct { + name string + response ErrorResponse + wantJSON string + }{ + { + name: "all fields", + response: ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "The token expired", + ErrorCode: "token_expired", + }, + wantJSON: `{"error":"invalid_token","error_description":"The token expired","error_code":"token_expired"}`, + }, + { + name: "without error code", + response: ErrorResponse{ + Error: "invalid_token", + ErrorDescription: "JWT is invalid", + }, + wantJSON: `{"error":"invalid_token","error_description":"JWT is invalid"}`, + }, + { + name: "without description", + response: ErrorResponse{ + Error: "server_error", + }, + wantJSON: `{"error":"server_error"}`, + }, + } - if !errors.Is(err, expectedErr) { - t.Fatal("expected invalidError to be expectedErr via errors.Is, but it was not") - } - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.response) + require.NoError(t, err) + assert.JSONEq(t, tt.wantJSON, string(data)) + }) + } } diff --git a/examples/echo-example/go.mod b/examples/echo-example/go.mod index 07da922..54c3012 100644 --- a/examples/echo-example/go.mod +++ b/examples/echo-example/go.mod @@ -7,11 +7,13 @@ toolchain go1.24.8 require ( github.com/auth0/go-jwt-middleware/v3 v3.0.0 github.com/labstack/echo/v4 v4.13.4 + github.com/stretchr/testify v1.11.1 ) replace github.com/auth0/go-jwt-middleware/v3 => ./../../ require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/labstack/gommon v0.4.2 // indirect @@ -25,6 +27,7 @@ require ( github.com/lestrrat-go/option/v2 v2.0.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.2.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fastjson v1.6.4 // indirect @@ -33,4 +36,5 @@ require ( golang.org/x/net v0.47.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/examples/echo-example/go.sum b/examples/echo-example/go.sum index c68eeff..feccc72 100644 --- a/examples/echo-example/go.sum +++ b/examples/echo-example/go.sum @@ -55,6 +55,7 @@ golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/echo-example/main.go b/examples/echo-example/main.go index 41b2a01..c867363 100644 --- a/examples/echo-example/main.go +++ b/examples/echo-example/main.go @@ -40,12 +40,17 @@ import ( // "shouldReject": true // } -func main() { +func setupRouter() *echo.Echo { app := echo.New() - app.GET("/", func(ctx echo.Context) error { - claims, ok := ctx.Request().Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) - if !ok { + app.GET("/api/public", func(ctx echo.Context) error { + return ctx.JSON(http.StatusOK, map[string]string{"message": "Hello from a public endpoint!"}) + }) + + app.GET("/api/private", func(ctx echo.Context) error { + // Modern type-safe claims retrieval using generics + claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx.Request().Context()) + if err != nil { ctx.JSON( http.StatusInternalServerError, map[string]string{"message": "Failed to get validated JWT claims."}, @@ -74,6 +79,12 @@ func main() { return nil }, checkJWT) + return app +} + +func main() { + app := setupRouter() + log.Print("Server listening on http://localhost:3000") err := app.Start(":3000") if err != nil { diff --git a/examples/echo-example/main_integration_test.go b/examples/echo-example/main_integration_test.go new file mode 100644 index 0000000..776b2e5 --- /dev/null +++ b/examples/echo-example/main_integration_test.go @@ -0,0 +1,80 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEchoExample_ValidToken(t *testing.T) { + e := setupRouter() + server := httptest.NewServer(e) + defer server.Close() + + // Valid token from the example + validToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.XFhrzWzntyINkgoRt2mb8dES84dJcuOoORdzKfwUX70" + + req, err := http.NewRequest(http.MethodGet, server.URL+"/api/public", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "message") + + // Test protected endpoint + req, err = http.NewRequest(http.MethodGet, server.URL+"/api/private", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+validToken) + + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "John Doe") + assert.Contains(t, string(body), "user123") +} + +func TestEchoExample_MissingToken(t *testing.T) { + e := setupRouter() + server := httptest.NewServer(e) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL+"/api/private", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} + +func TestEchoExample_InvalidToken(t *testing.T) { + e := setupRouter() + server := httptest.NewServer(e) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL+"/api/private", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer invalid.token.here") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} diff --git a/examples/echo-example/middleware.go b/examples/echo-example/middleware.go index 5da2209..77a209e 100644 --- a/examples/echo-example/middleware.go +++ b/examples/echo-example/middleware.go @@ -2,11 +2,12 @@ package main import ( "context" - "github.com/labstack/echo/v4" "log" "net/http" "time" + "github.com/labstack/echo/v4" + jwtmiddleware "github.com/auth0/go-jwt-middleware/v3" "github.com/auth0/go-jwt-middleware/v3/validator" ) @@ -25,7 +26,6 @@ var ( keyFunc = func(ctx context.Context) (interface{}, error) { return signingKey, nil } - ) // checkJWT is an echo.HandlerFunc middleware @@ -51,10 +51,14 @@ func checkJWT(next echo.HandlerFunc) echo.HandlerFunc { log.Printf("Encountered error while validating JWT: %v", err) } - middleware := jwtmiddleware.New( - jwtValidator.ValidateToken, + // Set up the middleware using pure options pattern + middleware, err := jwtmiddleware.New( + jwtmiddleware.WithValidator(jwtValidator), jwtmiddleware.WithErrorHandler(errorHandler), ) + if err != nil { + log.Fatalf("failed to set up the middleware: %v", err) + } return func(ctx echo.Context) error { encounteredError := true diff --git a/examples/gin-example/go.mod b/examples/gin-example/go.mod index ec8afe4..0e486d1 100644 --- a/examples/gin-example/go.mod +++ b/examples/gin-example/go.mod @@ -7,6 +7,7 @@ toolchain go1.24.8 require ( github.com/auth0/go-jwt-middleware/v3 v3.0.0 github.com/gin-gonic/gin v1.10.1 + github.com/stretchr/testify v1.11.1 ) replace github.com/auth0/go-jwt-middleware/v3 => ./../../ @@ -16,6 +17,7 @@ require ( github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/gabriel-vasile/mimetype v1.4.11 // indirect github.com/gin-contrib/sse v1.1.0 // indirect @@ -38,6 +40,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.2.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect diff --git a/examples/gin-example/main.go b/examples/gin-example/main.go index b280e23..2b6787b 100644 --- a/examples/gin-example/main.go +++ b/examples/gin-example/main.go @@ -40,11 +40,18 @@ import ( // "shouldReject": true // } -func main() { +func setupRouter() *gin.Engine { router := gin.Default() - router.GET("/", checkJWT(), func(ctx *gin.Context) { - claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) - if !ok { + + api := router.Group("/api") + api.GET("/public", func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, map[string]string{"message": "Hello from a public endpoint!"}) + }) + + api.GET("/private", checkJWT(), func(ctx *gin.Context) { + // Modern type-safe claims retrieval using generics + claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx.Request.Context()) + if err != nil { ctx.AbortWithStatusJSON( http.StatusInternalServerError, map[string]string{"message": "Failed to get validated JWT claims."}, @@ -72,6 +79,12 @@ func main() { ctx.JSON(http.StatusOK, claims) }) + return router +} + +func main() { + router := setupRouter() + log.Print("Server listening on http://localhost:3000") if err := http.ListenAndServe("0.0.0.0:3000", router); err != nil { log.Fatalf("There was an error with the http server: %v", err) diff --git a/examples/gin-example/main_integration_test.go b/examples/gin-example/main_integration_test.go new file mode 100644 index 0000000..9feda00 --- /dev/null +++ b/examples/gin-example/main_integration_test.go @@ -0,0 +1,87 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGinExample_ValidToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := setupRouter() + server := httptest.NewServer(router) + defer server.Close() + + // Valid token from the example + validToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.XFhrzWzntyINkgoRt2mb8dES84dJcuOoORdzKfwUX70" + + req, err := http.NewRequest(http.MethodGet, server.URL+"/api/public", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "message") + + // Test protected endpoint + req, err = http.NewRequest(http.MethodGet, server.URL+"/api/private", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+validToken) + + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "John Doe") + assert.Contains(t, string(body), "user123") +} + +func TestGinExample_MissingToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := setupRouter() + server := httptest.NewServer(router) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL+"/api/private", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} + +func TestGinExample_InvalidToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := setupRouter() + server := httptest.NewServer(router) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL+"/api/private", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer invalid.token.here") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} diff --git a/examples/gin-example/middleware.go b/examples/gin-example/middleware.go index a02758c..5267ba3 100644 --- a/examples/gin-example/middleware.go +++ b/examples/gin-example/middleware.go @@ -25,7 +25,6 @@ var ( keyFunc = func(ctx context.Context) (interface{}, error) { return signingKey, nil } - ) // checkJWT is a gin.HandlerFunc middleware @@ -51,10 +50,14 @@ func checkJWT() gin.HandlerFunc { log.Printf("Encountered error while validating JWT: %v", err) } - middleware := jwtmiddleware.New( - jwtValidator.ValidateToken, + // Set up the middleware using pure options pattern + middleware, err := jwtmiddleware.New( + jwtmiddleware.WithValidator(jwtValidator), jwtmiddleware.WithErrorHandler(errorHandler), ) + if err != nil { + log.Fatalf("failed to set up the middleware: %v", err) + } return func(ctx *gin.Context) { encounteredError := true diff --git a/examples/http-example/go.mod b/examples/http-example/go.mod index 155bc28..2de4730 100644 --- a/examples/http-example/go.mod +++ b/examples/http-example/go.mod @@ -6,12 +6,14 @@ toolchain go1.24.8 require ( github.com/auth0/go-jwt-middleware/v3 v3.0.0 + github.com/stretchr/testify v1.11.1 gopkg.in/go-jose/go-jose.v2 v2.6.3 ) replace github.com/auth0/go-jwt-middleware/v3 => ./../../ require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/lestrrat-go/blackmagic v1.0.4 // indirect @@ -22,8 +24,10 @@ require ( github.com/lestrrat-go/jwx/v3 v3.0.12 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/lestrrat-go/option/v2 v2.0.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.2.1 // indirect github.com/valyala/fastjson v1.6.4 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/sys v0.38.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/examples/http-example/go.sum b/examples/http-example/go.sum index 4a9d2db..2bdeab4 100644 --- a/examples/http-example/go.sum +++ b/examples/http-example/go.sum @@ -38,6 +38,7 @@ golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/go-jose/go-jose.v2 v2.6.3 h1:nt80fvSDlhKWQgSWyHyy5CfmlQr+asih51R8PTWNKKs= gopkg.in/go-jose/go-jose.v2 v2.6.3/go.mod h1:zzZDPkNNw/c9IE7Z9jr11mBZQhKQTMzoEEIoEdZlFBI= diff --git a/examples/http-example/main.go b/examples/http-example/main.go index caa866a..7ead1a0 100644 --- a/examples/http-example/main.go +++ b/examples/http-example/main.go @@ -34,8 +34,9 @@ func (c *CustomClaimsExample) Validate(ctx context.Context) error { } var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) - if !ok { + // Modern type-safe claims retrieval using generics + claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](r.Context()) + if err != nil { http.Error(w, "failed to get validated claims", http.StatusInternalServerError) return } @@ -43,10 +44,12 @@ var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { customClaims, ok := claims.CustomClaims.(*CustomClaimsExample) if !ok { http.Error(w, "could not cast custom claims to specific type", http.StatusInternalServerError) + return } if len(customClaims.Username) == 0 { http.Error(w, "username in JWT claims was empty", http.StatusBadRequest) + return } payload, err := json.Marshal(claims) @@ -81,7 +84,17 @@ func setupHandler() http.Handler { log.Fatalf("failed to set up the validator: %v", err) } - return jwtmiddleware.New(jwtValidator.ValidateToken).CheckJWT(handler) + // Set up the middleware using pure options pattern + middleware, err := jwtmiddleware.New( + jwtmiddleware.WithValidator(jwtValidator), + // Optional: Add a logger for debugging JWT validation flow + // jwtmiddleware.WithLogger(slog.Default()), + ) + if err != nil { + log.Fatalf("failed to set up the middleware: %v", err) + } + + return middleware.CheckJWT(handler) } func main() { diff --git a/examples/http-example/main_integration_test.go b/examples/http-example/main_integration_test.go new file mode 100644 index 0000000..68c4e1f --- /dev/null +++ b/examples/http-example/main_integration_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHTTPExample_ValidToken(t *testing.T) { + handler := setupHandler() + server := httptest.NewServer(handler) + defer server.Close() + + // Valid token from the example + validToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.XFhrzWzntyINkgoRt2mb8dES84dJcuOoORdzKfwUX70" + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+validToken) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Verify response contains the custom claims + assert.Contains(t, string(body), "John Doe") + assert.Contains(t, string(body), "user123") +} + +func TestHTTPExample_TokenWithShouldReject(t *testing.T) { + handler := setupHandler() + server := httptest.NewServer(handler) + defer server.Close() + + // Token with shouldReject: true + rejectToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyIsInNob3VsZFJlamVjdCI6dHJ1ZX0.Jf13PY_Oyu2x3Gx1JQ0jXRiWaCOb5T2RbKOrTPBNHJA" + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+rejectToken) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should be rejected due to custom validation + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} + +func TestHTTPExample_MissingToken(t *testing.T) { + handler := setupHandler() + server := httptest.NewServer(handler) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} + +func TestHTTPExample_InvalidToken(t *testing.T) { + handler := setupHandler() + server := httptest.NewServer(handler) + defer server.Close() + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer invalid.token.here") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} + +func TestHTTPExample_WrongIssuer(t *testing.T) { + handler := setupHandler() + server := httptest.NewServer(handler) + defer server.Close() + + // Token with wrong issuer + wrongIssuerToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ3cm9uZy1pc3N1ZXIiLCJhdWQiOiJhdWRpZW5jZS1leGFtcGxlIiwic3ViIjoiMTIzNDU2Nzg5MCIsIm5hbWUiOiJKb2huIERvZSIsImlhdCI6MTUxNjIzOTAyMiwidXNlcm5hbWUiOiJ1c2VyMTIzIn0.8m4cV8KJFmKnHvY4I0F4Y9L8x-vH7RxQ1qvQzc6YZ8M" + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+wrongIssuerToken) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) +} diff --git a/examples/http-jwks-example/main.go b/examples/http-jwks-example/main.go index f81aff9..9743715 100644 --- a/examples/http-jwks-example/main.go +++ b/examples/http-jwks-example/main.go @@ -13,14 +13,16 @@ import ( ) var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) - if !ok { + // Modern type-safe claims retrieval using generics + claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](r.Context()) + if err != nil { http.Error(w, "failed to get validated claims", http.StatusInternalServerError) return } if len(claims.RegisteredClaims.Subject) == 0 { http.Error(w, "subject in JWT claims was empty", http.StatusBadRequest) + return } payload, err := json.Marshal(claims) @@ -58,7 +60,15 @@ func setupHandler(issuer string, audience []string) http.Handler { log.Fatalf("failed to set up the validator: %v", err) } - return jwtmiddleware.New(jwtValidator.ValidateToken).CheckJWT(handler) + // Set up the middleware using pure options pattern + middleware, err := jwtmiddleware.New( + jwtmiddleware.WithValidator(jwtValidator), + ) + if err != nil { + log.Fatalf("failed to set up the middleware: %v", err) + } + + return middleware.CheckJWT(handler) } func main() { diff --git a/examples/iris-example/go.mod b/examples/iris-example/go.mod index f089e74..bc14f1f 100644 --- a/examples/iris-example/go.mod +++ b/examples/iris-example/go.mod @@ -17,16 +17,24 @@ require ( github.com/CloudyKit/jet/v6 v6.2.0 // indirect github.com/Joker/jade v1.1.3 // indirect github.com/Shopify/goreferrer v0.0.0-20240724165105-aceaa0259138 // indirect + github.com/ajg/form v1.5.1 // indirect github.com/andybalholm/brotli v1.1.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/fatih/color v1.15.0 // indirect github.com/fatih/structs v1.1.0 // indirect github.com/flosch/pongo2/v4 v4.0.2 // indirect + github.com/gobwas/glob v0.2.3 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/gomarkdown/markdown v0.0.0-20250207164621-7a1f277a159e // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/css v1.0.1 // indirect + github.com/gorilla/websocket v1.5.1 // indirect + github.com/imkira/go-interpol v1.1.0 // indirect + github.com/iris-contrib/httpexpect/v2 v2.15.2 // indirect github.com/iris-contrib/schema v0.0.6 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/kataras/blocks v0.0.11 // indirect @@ -45,18 +53,31 @@ require ( github.com/lestrrat-go/option/v2 v2.0.0 // indirect github.com/mailgun/raymond/v2 v2.0.48 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect + github.com/mitchellh/go-wordwrap v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/sanity-io/litter v1.5.5 // indirect github.com/schollz/closestmatch v2.1.0+incompatible // indirect github.com/segmentio/asm v1.2.1 // indirect + github.com/sergi/go-diff v1.0.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect + github.com/stretchr/testify v1.11.1 // indirect github.com/tdewolff/minify/v2 v2.20.37 // indirect github.com/tdewolff/parse/v2 v2.7.20 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fastjson v1.6.4 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 // indirect github.com/yosssi/ace v0.0.5 // indirect + github.com/yudai/gojsondiff v1.0.0 // indirect + github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/net v0.47.0 // indirect @@ -65,5 +86,7 @@ require ( golang.org/x/time v0.5.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + moul.io/http2curl/v2 v2.3.0 // indirect ) diff --git a/examples/iris-example/go.sum b/examples/iris-example/go.sum index 004d3a2..22feae6 100644 --- a/examples/iris-example/go.sum +++ b/examples/iris-example/go.sum @@ -16,6 +16,7 @@ github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7X github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= +github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -27,6 +28,8 @@ github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/flosch/pongo2/v4 v4.0.2 h1:gv+5Pe3vaSVmiJvh/BZa82b7/00YUGm0PIyVVLop0Hw= github.com/flosch/pongo2/v4 v4.0.2/go.mod h1:B5ObFANs/36VwxxlgKpdchIJHMvHB562PW+BWPhwZD8= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= @@ -35,6 +38,7 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomarkdown/markdown v0.0.0-20250207164621-7a1f277a159e h1:ESHlT0RVZphh4JGBz49I5R6nTdC8Qyc08vU25GQHzzQ= github.com/gomarkdown/markdown v0.0.0-20250207164621-7a1f277a159e/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -92,6 +96,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk= @@ -100,6 +105,14 @@ github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQ github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY= +github.com/nxadm/tail v1.4.11/go.mod h1:OTaG3NK980DZzxbRq6lEuzgU+mug70nY11sMd4JXXHc= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= +github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= +github.com/pkg/diff v0.0.0-20200914180035-5b29258ca4f7/go.mod h1:zO8QMzTeZd5cpnIkz/Gn6iK0jDfGicM1nynOkkPIl28= +github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= @@ -116,12 +129,16 @@ github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tailscale/depaware v0.0.0-20210622194025-720c4b409502/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= github.com/tdewolff/minify/v2 v2.20.37 h1:Q97cx4STXCh1dlWDlNHZniE8BJ2EBL0+2b0n92BJQhw= github.com/tdewolff/minify/v2 v2.20.37/go.mod h1:L1VYef/jwKw6Wwyk5A+T0mBjjn3mMPgmjjA688RNsxU= github.com/tdewolff/parse/v2 v2.7.20 h1:Y33JmRLjyGhX5JRvYh+CO6Sk6pGMw3iO5eKGhUhx8JE= @@ -153,33 +170,45 @@ github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCO github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= +github.com/yudai/pp v2.0.1+incompatible h1:Q4//iY4pNF6yPLZIigmvcl7k/bPgrcTPIFIcmawg5bI= +github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= @@ -188,9 +217,11 @@ golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= @@ -199,6 +230,9 @@ gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLF gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/iris-example/main.go b/examples/iris-example/main.go index 6f2e27f..b397adc 100644 --- a/examples/iris-example/main.go +++ b/examples/iris-example/main.go @@ -1,11 +1,12 @@ package main import ( + "log" + "net/http" + jwtmiddleware "github.com/auth0/go-jwt-middleware/v3" "github.com/auth0/go-jwt-middleware/v3/validator" "github.com/kataras/iris/v12" - "log" - "net/http" ) // Try it out with: @@ -39,12 +40,17 @@ import ( // "shouldReject": true // } -func main() { +func setupApp() *iris.Application { app := iris.New() - app.Get("/", checkJWT(), func(ctx iris.Context) { - claims, ok := ctx.Request().Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) - if !ok { + app.Get("/api/public", func(ctx iris.Context) { + ctx.JSON(map[string]string{"message": "Hello from a public endpoint!"}) + }) + + app.Get("/api/private", checkJWT(), func(ctx iris.Context) { + // Modern type-safe claims retrieval using generics + claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx.Request().Context()) + if err != nil { ctx.StopWithJSON( http.StatusInternalServerError, map[string]string{"message": "Failed to get validated JWT claims."}, @@ -72,6 +78,12 @@ func main() { ctx.JSON(claims) }) + return app +} + +func main() { + app := setupApp() + log.Print("Server listening on http://localhost:3000") if err := app.Listen(":3000"); err != nil { log.Fatalf("There was an error with the http server: %v", err) diff --git a/examples/iris-example/main_integration_test.go b/examples/iris-example/main_integration_test.go new file mode 100644 index 0000000..47050e4 --- /dev/null +++ b/examples/iris-example/main_integration_test.go @@ -0,0 +1,76 @@ +package main + +import ( + "testing" + + "github.com/kataras/iris/v12/httptest" +) + +func TestIrisExample_PublicEndpoint(t *testing.T) { + app := setupApp() + e := httptest.New(t, app) + + e.GET("/api/public"). + Expect(). + Status(httptest.StatusOK). + JSON().Object(). + ContainsKey("message"). + ValueEqual("message", "Hello from a public endpoint!") +} + +func TestIrisExample_ValidToken(t *testing.T) { + app := setupApp() + e := httptest.New(t, app) + + // Valid token from the example + validToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.XFhrzWzntyINkgoRt2mb8dES84dJcuOoORdzKfwUX70" + + e.GET("/api/private"). + WithHeader("Authorization", "Bearer "+validToken). + Expect(). + Status(httptest.StatusOK). + JSON().Object(). + ContainsKey("RegisteredClaims"). + ContainsKey("CustomClaims") +} + +func TestIrisExample_MissingToken(t *testing.T) { + app := setupApp() + e := httptest.New(t, app) + + e.GET("/api/private"). + Expect(). + Status(httptest.StatusUnauthorized). + JSON().Object(). + ContainsKey("message"). + ValueEqual("message", "JWT is invalid.") +} + +func TestIrisExample_InvalidToken(t *testing.T) { + app := setupApp() + e := httptest.New(t, app) + + e.GET("/api/private"). + WithHeader("Authorization", "Bearer invalid.token.here"). + Expect(). + Status(httptest.StatusUnauthorized). + JSON().Object(). + ContainsKey("message"). + ValueEqual("message", "JWT is invalid.") +} + +func TestIrisExample_WrongIssuer(t *testing.T) { + app := setupApp() + e := httptest.New(t, app) + + // Token with wrong issuer + wrongIssuerToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ3cm9uZy1pc3N1ZXIiLCJhdWQiOiJhdWRpZW5jZS1leGFtcGxlIiwic3ViIjoiMTIzNDU2Nzg5MCIsIm5hbWUiOiJKb2huIERvZSIsImlhdCI6MTUxNjIzOTAyMiwidXNlcm5hbWUiOiJ1c2VyMTIzIn0.8m4cV8KJFmKnHvY4I0F4Y9L8x-vH7RxQ1qvQzc6YZ8M" + + e.GET("/api/private"). + WithHeader("Authorization", "Bearer "+wrongIssuerToken). + Expect(). + Status(httptest.StatusUnauthorized). + JSON().Object(). + ContainsKey("message"). + ValueEqual("message", "JWT is invalid.") +} diff --git a/examples/iris-example/middleware.go b/examples/iris-example/middleware.go index 67fc295..9663538 100644 --- a/examples/iris-example/middleware.go +++ b/examples/iris-example/middleware.go @@ -2,11 +2,12 @@ package main import ( "context" - "github.com/kataras/iris/v12" "log" "net/http" "time" + "github.com/kataras/iris/v12" + jwtmiddleware "github.com/auth0/go-jwt-middleware/v3" "github.com/auth0/go-jwt-middleware/v3/validator" ) @@ -50,10 +51,14 @@ func checkJWT() iris.Handler { log.Printf("Encountered error while validating JWT: %v", err) } - middleware := jwtmiddleware.New( - jwtValidator.ValidateToken, + // Set up the middleware using pure options pattern + middleware, err := jwtmiddleware.New( + jwtmiddleware.WithValidator(jwtValidator), jwtmiddleware.WithErrorHandler(errorHandler), ) + if err != nil { + log.Fatalf("failed to set up the middleware: %v", err) + } return func(ctx iris.Context) { encounteredError := true diff --git a/extractor.go b/extractor.go index 376e513..d74a839 100644 --- a/extractor.go +++ b/extractor.go @@ -33,10 +33,20 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) { // extracts the token from the cookie using the passed in cookieName. func CookieTokenExtractor(cookieName string) TokenExtractor { return func(r *http.Request) (string, error) { + if cookieName == "" { + return "", errors.New("cookie name cannot be empty") + } + cookie, err := r.Cookie(cookieName) if err == http.ErrNoCookie { return "", nil // No cookie, then no JWT, so no error. } + if err != nil { + // Defensive: r.Cookie() rarely returns non-ErrNoCookie errors in practice, + // but we handle them properly for robustness. The http package's cookie + // parsing is very lenient and typically only returns ErrNoCookie. + return "", err + } return cookie.Value, nil } @@ -46,6 +56,9 @@ func CookieTokenExtractor(cookieName string) TokenExtractor { // the token from the specified query string parameter. func ParameterTokenExtractor(param string) TokenExtractor { return func(r *http.Request) (string, error) { + if param == "" { + return "", errors.New("parameter name cannot be empty") + } return r.URL.Query().Get(param), nil } } diff --git a/extractor_test.go b/extractor_test.go index 3101847..86d839c 100644 --- a/extractor_test.go +++ b/extractor_test.go @@ -40,6 +40,42 @@ func Test_AuthHeaderTokenExtractor(t *testing.T) { }, wantError: "Authorization header format must be Bearer {token}", }, + { + name: "bearer with uppercase", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"BEARER i-am-a-token"}, + }, + }, + wantToken: "i-am-a-token", + }, + { + name: "bearer with mixed case", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"BeArEr i-am-a-token"}, + }, + }, + wantToken: "i-am-a-token", + }, + { + name: "multiple spaces between bearer and token", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer i-am-a-token"}, + }, + }, + wantToken: "i-am-a-token", + }, + { + name: "extra parts after token", + request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer token extra-part"}, + }, + }, + wantError: "Authorization header format must be Bearer {token}", + }, } for _, testCase := range testCases { @@ -60,19 +96,33 @@ func Test_AuthHeaderTokenExtractor(t *testing.T) { } func Test_ParameterTokenExtractor(t *testing.T) { - wantToken := "i am a token" - param := "i-am-param" + t.Run("extracts token from query parameter", func(t *testing.T) { + wantToken := "i am a token" + param := "i-am-param" + + testURL, err := url.Parse(fmt.Sprintf("http://localhost?%s=%s", param, wantToken)) + require.NoError(t, err) + + request := &http.Request{URL: testURL} + tokenExtractor := ParameterTokenExtractor(param) + + gotToken, err := tokenExtractor(request) + require.NoError(t, err) - testURL, err := url.Parse(fmt.Sprintf("http://localhost?%s=%s", param, wantToken)) - require.NoError(t, err) + assert.Equal(t, wantToken, gotToken) + }) - request := &http.Request{URL: testURL} - tokenExtractor := ParameterTokenExtractor(param) + t.Run("returns error for empty parameter name", func(t *testing.T) { + testURL, err := url.Parse("http://localhost?token=abc") + require.NoError(t, err) - gotToken, err := tokenExtractor(request) - require.NoError(t, err) + request := &http.Request{URL: testURL} + tokenExtractor := ParameterTokenExtractor("") - assert.Equal(t, wantToken, gotToken) + gotToken, err := tokenExtractor(request) + assert.EqualError(t, err, "parameter name cannot be empty") + assert.Empty(t, gotToken) + }) } func Test_CookieTokenExtractor(t *testing.T) { @@ -121,6 +171,15 @@ func Test_CookieTokenExtractor(t *testing.T) { assert.Equal(t, testCase.wantToken, gotToken) }) } + + t.Run("returns error for empty cookie name", func(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + require.NoError(t, err) + + gotToken, err := CookieTokenExtractor("")(request) + assert.EqualError(t, err, "cookie name cannot be empty") + assert.Empty(t, gotToken) + }) } func Test_MultiTokenExtractor(t *testing.T) { diff --git a/go.mod b/go.mod index 41913ac..349d857 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,6 @@ require ( github.com/google/go-cmp v0.7.0 github.com/lestrrat-go/jwx/v3 v3.0.12 github.com/stretchr/testify v1.11.1 - golang.org/x/sync v0.18.0 - gopkg.in/go-jose/go-jose.v2 v2.6.3 ) require ( diff --git a/go.sum b/go.sum index ed3a1d2..e33c5bc 100644 --- a/go.sum +++ b/go.sum @@ -36,14 +36,10 @@ github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXV github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/go-jose/go-jose.v2 v2.6.3 h1:nt80fvSDlhKWQgSWyHyy5CfmlQr+asih51R8PTWNKKs= -gopkg.in/go-jose/go-jose.v2 v2.6.3/go.mod h1:zzZDPkNNw/c9IE7Z9jr11mBZQhKQTMzoEEIoEdZlFBI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middleware.go b/middleware.go index 2f82076..407802e 100644 --- a/middleware.go +++ b/middleware.go @@ -4,20 +4,36 @@ import ( "context" "fmt" "net/http" -) -// ContextKey is the key used in the request -// context where the information from a -// validated JWT will be stored. -type ContextKey struct{} + "github.com/auth0/go-jwt-middleware/v3/core" + "github.com/auth0/go-jwt-middleware/v3/validator" +) +// JWTMiddleware is a middleware that validates JWTs and makes claims available in the request context. +// It wraps the core validation engine and provides HTTP-specific functionality like token extraction +// and error handling. +// +// Claims are stored in the context using core.SetClaims() and can be retrieved using core.GetClaims[T](). type JWTMiddleware struct { - validateToken ValidateToken + core *core.Core errorHandler ErrorHandler tokenExtractor TokenExtractor - credentialsOptional bool validateOnOptions bool exclusionUrlHandler ExclusionUrlHandler + logger Logger + + // Temporary fields used during construction + validator *validator.Validator + credentialsOptional bool +} + +// Logger defines an optional logging interface compatible with log/slog. +// This is the same interface used by core for consistent logging across the stack. +type Logger interface { + Debug(msg string, args ...any) + Info(msg string, args ...any) + Warn(msg string, args ...any) + Error(msg string, args ...any) } // ValidateToken takes in a string JWT and makes sure it is valid and @@ -25,29 +41,148 @@ type JWTMiddleware struct { // an error message describing why validation failed. // Inside ValidateToken things like key and alg checking can happen. // In the default implementation we can add safe defaults for those. -type ValidateToken func(context.Context, string) (interface{}, error) +type ValidateToken func(context.Context, string) (any, error) // ExclusionUrlHandler is a function that takes in a http.Request and returns // true if the request should be excluded from JWT validation. type ExclusionUrlHandler func(r *http.Request) bool // New constructs a new JWTMiddleware instance with the supplied options. -// It requires a ValidateToken function to be passed in, so it can -// properly validate tokens. -func New(validateToken ValidateToken, opts ...Option) *JWTMiddleware { +// All parameters are passed via options (pure options pattern). +// +// Required options: +// - WithValidator: A configured validator instance +// +// Example: +// +// v, err := validator.New( +// validator.WithKeyFunc(keyFunc), +// validator.WithAlgorithm(validator.RS256), +// validator.WithIssuer("https://issuer.example.com/"), +// validator.WithAudience("my-api"), +// ) +// if err != nil { +// log.Fatal(err) +// } +// +// middleware, err := jwtmiddleware.New( +// jwtmiddleware.WithValidator(v), +// jwtmiddleware.WithCredentialsOptional(false), +// ) +// if err != nil { +// log.Fatalf("failed to create middleware: %v", err) +// } +func New(opts ...Option) (*JWTMiddleware, error) { m := &JWTMiddleware{ - validateToken: validateToken, - errorHandler: DefaultErrorHandler, - credentialsOptional: false, - tokenExtractor: AuthHeaderTokenExtractor, - validateOnOptions: true, + // Set secure defaults before applying options + validateOnOptions: true, // Validate OPTIONS by default + credentialsOptional: false, // Credentials required by default } + // Apply all options for _, opt := range opts { - opt(m) + if err := opt(m); err != nil { + return nil, fmt.Errorf("invalid option: %w", err) + } + } + + // Validate required configuration + if err := m.validate(); err != nil { + return nil, fmt.Errorf("invalid middleware configuration: %w", err) + } + + // Apply defaults for optional fields not set by options + m.applyDefaults() + + // Create the core with the configured validator and options + if err := m.createCore(); err != nil { + return nil, fmt.Errorf("failed to create core: %w", err) + } + + return m, nil +} + +// validate ensures all required fields are set +func (m *JWTMiddleware) validate() error { + if m.validator == nil { + return ErrValidatorNil + } + return nil +} + +// createCore creates the core.Core instance with the configured options +func (m *JWTMiddleware) createCore() error { + adapter := &validatorAdapter{validator: m.validator} + + // Build core options + coreOpts := []core.Option{ + core.WithValidator(adapter), + core.WithCredentialsOptional(m.credentialsOptional), + } + + // Add logger if configured + if m.logger != nil { + coreOpts = append(coreOpts, core.WithLogger(m.logger)) + } + + coreInstance, err := core.New(coreOpts...) + if err != nil { + return err } + m.core = coreInstance + return nil +} - return m +// applyDefaults sets secure default values for optional fields +func (m *JWTMiddleware) applyDefaults() { + if m.errorHandler == nil { + m.errorHandler = DefaultErrorHandler + } + if m.tokenExtractor == nil { + m.tokenExtractor = AuthHeaderTokenExtractor + } +} + +// GetClaims retrieves claims from the context with type safety using generics. +// This provides compile-time type checking and eliminates the need for manual type assertions. +// +// Example: +// +// claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](r.Context()) +// if err != nil { +// http.Error(w, "failed to get claims", http.StatusInternalServerError) +// return +// } +// fmt.Println(claims.RegisteredClaims.Subject) +func GetClaims[T any](ctx context.Context) (T, error) { + return core.GetClaims[T](ctx) +} + +// MustGetClaims retrieves claims from the context or panics. +// Use only when you are certain claims exist (e.g., after middleware has run). +// +// Example: +// +// claims := jwtmiddleware.MustGetClaims[*validator.ValidatedClaims](r.Context()) +// fmt.Println(claims.RegisteredClaims.Subject) +func MustGetClaims[T any](ctx context.Context) T { + claims, err := core.GetClaims[T](ctx) + if err != nil { + panic(err) + } + return claims +} + +// HasClaims checks if claims exist in the context. +// +// Example: +// +// if jwtmiddleware.HasClaims(r.Context()) { +// claims, _ := jwtmiddleware.GetClaims[*validator.ValidatedClaims](r.Context()) +// // Use claims... +// } +func HasClaims(ctx context.Context) bool { + return core.HasClaims(ctx) } // CheckJWT is the main JWTMiddleware function which performs the main logic. It @@ -56,47 +191,78 @@ func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // If there's an exclusion handler and the URL matches, skip JWT validation if m.exclusionUrlHandler != nil && m.exclusionUrlHandler(r) { + if m.logger != nil { + m.logger.Debug("skipping JWT validation for excluded URL", + "method", r.Method, + "path", r.URL.Path) + } next.ServeHTTP(w, r) return } // If we don't validate on OPTIONS and this is OPTIONS // then continue onto next without validating. if !m.validateOnOptions && r.Method == http.MethodOptions { + if m.logger != nil { + m.logger.Debug("skipping JWT validation for OPTIONS request") + } next.ServeHTTP(w, r) return } + if m.logger != nil { + m.logger.Debug("extracting JWT from request", + "method", r.Method, + "path", r.URL.Path) + } + token, err := m.tokenExtractor(r) if err != nil { // This is not ErrJWTMissing because an error here means that the // tokenExtractor had an error and _not_ that the token was missing. + if m.logger != nil { + m.logger.Error("failed to extract token from request", + "error", err, + "method", r.Method, + "path", r.URL.Path) + } m.errorHandler(w, r, fmt.Errorf("error extracting token: %w", err)) return } - if token == "" { - // If credentials are optional continue - // onto next without validating. - if m.credentialsOptional { - next.ServeHTTP(w, r) - return - } - - // Credentials were not optional so we error. - m.errorHandler(w, r, ErrJWTMissing) - return + if m.logger != nil { + m.logger.Debug("validating JWT") } - // Validate the token using the token validator. - validToken, err := m.validateToken(r.Context(), token) + // Validate the token using the core validator. + // Core handles empty token logic based on credentialsOptional setting. + validToken, err := m.core.CheckToken(r.Context(), token) if err != nil { + if m.logger != nil { + m.logger.Warn("JWT validation failed", + "error", err, + "method", r.Method, + "path", r.URL.Path) + } m.errorHandler(w, r, &invalidError{details: err}) return } + // If credentials are optional and no token was provided, + // core.CheckToken returns (nil, nil), so we continue without setting claims + if validToken == nil { + if m.logger != nil { + m.logger.Debug("no credentials provided, continuing without claims (credentials optional)") + } + next.ServeHTTP(w, r) + return + } + // No err means we have a valid token, so set // it into the context and continue onto next. - r = r.Clone(context.WithValue(r.Context(), ContextKey{}, validToken)) + if m.logger != nil { + m.logger.Debug("JWT validation successful, setting claims in context") + } + r = r.Clone(core.SetClaims(r.Context(), validToken)) next.ServeHTTP(w, r) }) } diff --git a/middleware_test.go b/middleware_test.go index a05b604..2ec3fc9 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -44,7 +44,7 @@ func Test_CheckJWT(t *testing.T) { testCases := []struct { name string - validateToken ValidateToken + validator *validator.Validator // Changed from validateToken options []Option method string token string @@ -55,7 +55,7 @@ func Test_CheckJWT(t *testing.T) { }{ { name: "it can successfully validate a token", - validateToken: jwtValidator.ValidateToken, + validator: jwtValidator, token: validToken, method: http.MethodGet, wantToken: tokenClaims, @@ -64,7 +64,7 @@ func Test_CheckJWT(t *testing.T) { }, { name: "it can validate on options", - validateToken: jwtValidator.ValidateToken, + validator: jwtValidator, method: http.MethodOptions, token: validToken, wantToken: tokenClaims, @@ -76,22 +76,22 @@ func Test_CheckJWT(t *testing.T) { token: "bad", method: http.MethodGet, wantStatusCode: http.StatusInternalServerError, - wantBody: `{"message":"Something went wrong while checking the JWT."}`, + wantBody: `{"error":"server_error","error_description":"An internal error occurred while processing the request"}`, }, { name: "it fails to validate if token is missing and credentials are not optional", token: "", method: http.MethodGet, - wantStatusCode: http.StatusBadRequest, - wantBody: `{"message":"JWT is missing."}`, + wantStatusCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_token","error_description":"JWT is missing"}`, }, { name: "it fails to validate an invalid token", - validateToken: jwtValidator.ValidateToken, + validator: jwtValidator, token: invalidToken, method: http.MethodGet, wantStatusCode: http.StatusUnauthorized, - wantBody: `{"message":"JWT is invalid."}`, + wantBody: `{"error":"invalid_token","error_description":"JWT is invalid"}`, }, { name: "it skips validation on OPTIONS if validateOnOptions is set to false", @@ -112,7 +112,7 @@ func Test_CheckJWT(t *testing.T) { }, method: http.MethodGet, wantStatusCode: http.StatusInternalServerError, - wantBody: `{"message":"Something went wrong while checking the JWT."}`, + wantBody: `{"error":"server_error","error_description":"An internal error occurred while processing the request"}`, }, { name: "credentialsOptional true", @@ -136,8 +136,8 @@ func Test_CheckJWT(t *testing.T) { }), }, method: http.MethodGet, - wantStatusCode: http.StatusBadRequest, - wantBody: `{"message":"JWT is missing."}`, + wantStatusCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_token","error_description":"JWT is missing"}`, }, { name: "JWT not required for /public", @@ -180,8 +180,8 @@ func Test_CheckJWT(t *testing.T) { method: http.MethodGet, path: "/secure", token: "", - wantStatusCode: http.StatusBadRequest, - wantBody: `{"message":"JWT is missing."}`, + wantStatusCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_token","error_description":"JWT is missing"}`, }, } @@ -190,11 +190,32 @@ func Test_CheckJWT(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { t.Parallel() - middleware := New(testCase.validateToken, testCase.options...) + // Use the test's validator if specified, otherwise create a default failing validator + v := testCase.validator + if v == nil { + // Create a validator that always fails + keyFunc := func(context.Context) (interface{}, error) { + return nil, errors.New("no key") + } + v, _ = validator.New( + validator.WithKeyFunc(keyFunc), + validator.WithAlgorithm(validator.HS256), + validator.WithIssuer("fail"), + validator.WithAudience("fail"), + ) + } + + opts := append([]Option{WithValidator(v)}, testCase.options...) + middleware, err := New(opts...) + require.NoError(t, err) - var actualValidatedClaims interface{} + var actualValidatedClaims any var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actualValidatedClaims = r.Context().Value(ContextKey{}) + // Use the public API to get claims + if HasClaims(r.Context()) { + claims, _ := GetClaims[any](r.Context()) + actualValidatedClaims = claims + } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) @@ -221,7 +242,11 @@ func Test_CheckJWT(t *testing.T) { assert.Equal(t, testCase.wantStatusCode, response.StatusCode) assert.Equal(t, "application/json", response.Header.Get("Content-Type")) - assert.Equal(t, testCase.wantBody, string(body)) + + // Compare JSON responses (ignoring formatting differences like newlines) + if testCase.wantBody != "" { + assert.JSONEq(t, testCase.wantBody, string(body)) + } if want, got := testCase.wantToken, actualValidatedClaims; !cmp.Equal(want, got) { t.Fatal(cmp.Diff(want, got)) diff --git a/option.go b/option.go index bb49c8a..5a09dbc 100644 --- a/option.go +++ b/option.go @@ -1,58 +1,118 @@ package jwtmiddleware import ( + "context" + "errors" "net/http" + + "github.com/auth0/go-jwt-middleware/v3/validator" ) -// Option is how options for the JWTMiddleware are set up. -type Option func(*JWTMiddleware) +// Option configures the JWTMiddleware. +// Returns error for validation failures. +type Option func(*JWTMiddleware) error + +// TokenValidator defines the interface for token validation. +// This interface is satisfied by *validator.Validator and allows +// explicit passing of validation methods. +type TokenValidator interface { + ValidateToken(ctx context.Context, token string) (any, error) +} + +// validatorAdapter adapts the TokenValidator to the core.TokenValidator interface +type validatorAdapter struct { + validator TokenValidator +} + +func (v *validatorAdapter) ValidateToken(ctx context.Context, token string) (any, error) { + return v.validator.ValidateToken(ctx, token) +} + +// WithValidator sets the validator instance to validate tokens (REQUIRED). +// The validator must be a *validator.Validator instance. +// This approach allows explicit passing of validation methods and future +// extensibility for methods like ValidateDPoP. +// +// Example: +// +// v, err := validator.New( +// validator.WithKeyFunc(keyFunc), +// validator.WithAlgorithm(validator.RS256), +// validator.WithIssuer("https://issuer.example.com/"), +// validator.WithAudience("my-api"), +// ) +// if err != nil { +// log.Fatal(err) +// } +// +// middleware, err := jwtmiddleware.New( +// jwtmiddleware.WithValidator(v), +// ) +func WithValidator(v *validator.Validator) Option { + return func(m *JWTMiddleware) error { + if v == nil { + return ErrValidatorNil + } + m.validator = v + return nil + } +} -// WithCredentialsOptional sets up if credentials are -// optional or not. If set to true then an empty token -// will be considered valid. +// WithCredentialsOptional sets whether credentials are optional. +// If set to true, an empty token will be considered valid. // -// Default value: false. +// Default: false (credentials required) func WithCredentialsOptional(value bool) Option { - return func(m *JWTMiddleware) { + return func(m *JWTMiddleware) error { m.credentialsOptional = value + return nil } } -// WithValidateOnOptions sets up if OPTIONS requests -// should have their JWT validated or not. +// WithValidateOnOptions sets whether OPTIONS requests should have their JWT validated. // -// Default value: true. +// Default: true (OPTIONS requests are validated) func WithValidateOnOptions(value bool) Option { - return func(m *JWTMiddleware) { + return func(m *JWTMiddleware) error { m.validateOnOptions = value + return nil } } -// WithErrorHandler sets the handler which is called -// when we encounter errors in the JWTMiddleware. +// WithErrorHandler sets the handler called when errors occur during JWT validation. // See the ErrorHandler type for more information. // -// Default value: DefaultErrorHandler. +// Default: DefaultErrorHandler func WithErrorHandler(h ErrorHandler) Option { - return func(m *JWTMiddleware) { + return func(m *JWTMiddleware) error { + if h == nil { + return ErrErrorHandlerNil + } m.errorHandler = h + return nil } } -// WithTokenExtractor sets up the function which extracts -// the JWT to be validated from the request. +// WithTokenExtractor sets the function to extract the JWT from the request. // -// Default value: AuthHeaderTokenExtractor. +// Default: AuthHeaderTokenExtractor func WithTokenExtractor(e TokenExtractor) Option { - return func(m *JWTMiddleware) { + return func(m *JWTMiddleware) error { + if e == nil { + return ErrTokenExtractorNil + } m.tokenExtractor = e + return nil } } -// WithExclusionUrls allows configuring the exclusion URL handler with multiple URLs -// that should be excluded from JWT validation. +// WithExclusionUrls configures URL patterns to exclude from JWT validation. +// URLs can be full URLs or just paths. func WithExclusionUrls(exclusions []string) Option { - return func(m *JWTMiddleware) { + return func(m *JWTMiddleware) error { + if len(exclusions) == 0 { + return ErrExclusionUrlsEmpty + } m.exclusionUrlHandler = func(r *http.Request) bool { requestFullURL := r.URL.String() requestPath := r.URL.Path @@ -64,5 +124,36 @@ func WithExclusionUrls(exclusions []string) Option { } return false } + return nil } } + +// WithLogger sets an optional logger for the middleware. +// The logger will be used throughout the validation flow in both middleware and core. +// +// The logger interface is compatible with log/slog.Logger and similar loggers. +// +// Example: +// +// middleware, err := jwtmiddleware.New( +// jwtmiddleware.WithValidateToken(validator.ValidateToken), +// jwtmiddleware.WithLogger(slog.Default()), +// ) +func WithLogger(logger Logger) Option { + return func(m *JWTMiddleware) error { + if logger == nil { + return ErrLoggerNil + } + m.logger = logger + return nil + } +} + +// Sentinel errors for configuration validation +var ( + ErrValidatorNil = errors.New("validator cannot be nil (use WithValidator)") + ErrErrorHandlerNil = errors.New("errorHandler cannot be nil") + ErrTokenExtractorNil = errors.New("tokenExtractor cannot be nil") + ErrExclusionUrlsEmpty = errors.New("exclusion URLs list cannot be empty") + ErrLoggerNil = errors.New("logger cannot be nil") +) diff --git a/option_test.go b/option_test.go new file mode 100644 index 0000000..62f392c --- /dev/null +++ b/option_test.go @@ -0,0 +1,795 @@ +package jwtmiddleware + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/auth0/go-jwt-middleware/v3/core" + "github.com/auth0/go-jwt-middleware/v3/validator" +) + +// Test token with issuer="test-issuer" and audience="test-audience", signed with HS256 and secret="secret" +// Expires in year 2099 to ensure it works in CI for a long time +const testToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjQxMDI0NDQ3OTksImlhdCI6MTU3NzgzNjgwMCwiaXNzIjoidGVzdC1pc3N1ZXIifQ.k34FmdKsA_3XaOhXsEihRUaAKk-4l4wbLRw7UCYNE2o" + +// createTestValidator creates a basic validator for testing +func createTestValidator(t *testing.T) *validator.Validator { + t.Helper() + keyFunc := func(context.Context) (interface{}, error) { + return []byte("secret"), nil + } + v, err := validator.New( + validator.WithKeyFunc(keyFunc), + validator.WithAlgorithm(validator.HS256), + validator.WithIssuer("test-issuer"), + validator.WithAudience("test-audience"), + ) + require.NoError(t, err) + return v +} + +func Test_New_OptionsValidation(t *testing.T) { + validValidator := createTestValidator(t) + + tests := []struct { + name string + opts []Option + wantErr bool + errMsg string + }{ + { + name: "missing validator", + opts: []Option{}, + wantErr: true, + errMsg: "validator cannot be nil", + }, + { + name: "nil validator", + opts: []Option{ + WithValidator(nil), + }, + wantErr: true, + errMsg: "validator cannot be nil", + }, + { + name: "valid minimal configuration", + opts: []Option{ + WithValidator(validValidator), + }, + wantErr: false, + }, + { + name: "nil error handler", + opts: []Option{ + WithValidator(validValidator), + WithErrorHandler(nil), + }, + wantErr: true, + errMsg: "errorHandler cannot be nil", + }, + { + name: "nil token extractor", + opts: []Option{ + WithValidator(validValidator), + WithTokenExtractor(nil), + }, + wantErr: true, + errMsg: "tokenExtractor cannot be nil", + }, + { + name: "empty exclusion URLs", + opts: []Option{ + WithValidator(validValidator), + WithExclusionUrls([]string{}), + }, + wantErr: true, + errMsg: "exclusion URLs list cannot be empty", + }, + { + name: "valid exclusion URLs", + opts: []Option{ + WithValidator(validValidator), + WithExclusionUrls([]string{"/health", "/metrics"}), + }, + wantErr: false, + }, + { + name: "nil logger", + opts: []Option{ + WithValidator(validValidator), + WithLogger(nil), + }, + wantErr: true, + errMsg: "logger cannot be nil", + }, + { + name: "valid logger", + opts: []Option{ + WithValidator(validValidator), + WithLogger(&mockLogger{}), + }, + wantErr: false, + }, + { + name: "valid configuration with all options", + opts: []Option{ + WithValidator(validValidator), + WithCredentialsOptional(true), + WithValidateOnOptions(false), + WithErrorHandler(DefaultErrorHandler), + WithTokenExtractor(AuthHeaderTokenExtractor), + WithExclusionUrls([]string{"/public"}), + WithLogger(&mockLogger{}), + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := New(tt.opts...) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + assert.Nil(t, middleware) + } else { + require.NoError(t, err) + assert.NotNil(t, middleware) + assert.NotNil(t, middleware.validator) + assert.NotNil(t, middleware.errorHandler) + assert.NotNil(t, middleware.tokenExtractor) + } + }) + } +} + +func Test_New_Defaults(t *testing.T) { + validValidator := createTestValidator(t) + + middleware, err := New( + WithValidator(validValidator), + ) + require.NoError(t, err) + + // Check defaults + assert.NotNil(t, middleware.errorHandler) + assert.NotNil(t, middleware.tokenExtractor) + assert.False(t, middleware.credentialsOptional) + assert.True(t, middleware.validateOnOptions) + assert.Nil(t, middleware.exclusionUrlHandler) +} + +func Test_WithCredentialsOptional(t *testing.T) { + validValidator := createTestValidator(t) + + tests := []struct { + name string + value bool + }{ + { + name: "credentials optional true", + value: true, + }, + { + name: "credentials optional false", + value: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := New( + WithValidator(validValidator), + WithCredentialsOptional(tt.value), + ) + require.NoError(t, err) + assert.Equal(t, tt.value, middleware.credentialsOptional) + }) + } +} + +func Test_WithValidateOnOptions(t *testing.T) { + validValidator := createTestValidator(t) + + tests := []struct { + name string + value bool + }{ + { + name: "validate on OPTIONS true", + value: true, + }, + { + name: "validate on OPTIONS false", + value: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := New( + WithValidator(validValidator), + WithValidateOnOptions(tt.value), + ) + require.NoError(t, err) + assert.Equal(t, tt.value, middleware.validateOnOptions) + }) + } +} + +func Test_WithErrorHandler(t *testing.T) { + validValidator := createTestValidator(t) + + customHandler := func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusTeapot) + } + + middleware, err := New( + WithValidator(validValidator), + WithErrorHandler(customHandler), + ) + require.NoError(t, err) + assert.NotNil(t, middleware.errorHandler) +} + +func Test_WithTokenExtractor(t *testing.T) { + validValidator := createTestValidator(t) + + customExtractor := func(r *http.Request) (string, error) { + return "custom-token", nil + } + + middleware, err := New( + WithValidator(validValidator), + WithTokenExtractor(customExtractor), + ) + require.NoError(t, err) + assert.NotNil(t, middleware.tokenExtractor) +} + +func Test_WithExclusionUrls(t *testing.T) { + validValidator := createTestValidator(t) + + exclusions := []string{"/health", "/metrics", "/public"} + + middleware, err := New( + WithValidator(validValidator), + WithExclusionUrls(exclusions), + ) + require.NoError(t, err) + assert.NotNil(t, middleware.exclusionUrlHandler) + + // Test the exclusion handler + testCases := []struct { + name string + path string + excluded bool + }{ + {"health endpoint", "/health", true}, + {"metrics endpoint", "/metrics", true}, + {"public endpoint", "/public", true}, + {"secure endpoint", "/secure", false}, + {"api endpoint", "/api/users", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "http://example.com"+tc.path, nil) + require.NoError(t, err) + + result := middleware.exclusionUrlHandler(req) + assert.Equal(t, tc.excluded, result) + }) + } +} + +func Test_WithLogger(t *testing.T) { + t.Run("credentials optional with no token and logging", func(t *testing.T) { + logger := &mockLogger{} + validator := createTestValidator(t) + + middleware, err := New( + WithValidator(validator), + WithLogger(logger), + WithCredentialsOptional(true), + WithTokenExtractor(func(r *http.Request) (string, error) { + return "", nil // No token + }), + ) + require.NoError(t, err) + + // Create a test server with the middleware + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + testServer := httptest.NewServer(middleware.CheckJWT(handler)) + defer testServer.Close() + + // Make a request without token but credentials optional + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + + resp, err := testServer.Client().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify logging occurred for optional credentials + assert.Greater(t, len(logger.debugCalls), 0, "expected debug logs") + // Should have log about continuing without claims + foundOptionalLog := false + for _, call := range logger.debugCalls { + if len(call) > 0 { + if msg, ok := call[0].(string); ok && msg == "no credentials provided, continuing without claims (credentials optional)" { + foundOptionalLog = true + break + } + } + } + assert.True(t, foundOptionalLog, "expected log about continuing without claims") + }) + + t.Run("successful validation with logging", func(t *testing.T) { + logger := &mockLogger{} + validator := createTestValidator(t) + + middleware, err := New( + WithValidator(validator), + WithLogger(logger), + ) + require.NoError(t, err) + assert.NotNil(t, middleware.logger) + + // Create a test server with the middleware + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + testServer := httptest.NewServer(middleware.CheckJWT(handler)) + defer testServer.Close() + + // Make a request with a valid token (matching the test validator) + validToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0LWlzc3VlciIsImF1ZCI6InRlc3QtYXVkaWVuY2UifQ.4Adcj0cmV2bkeH_6hFM8pE6yx_WJ6TqXn5n4F7l_AhI" + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+validToken) + + resp, err := testServer.Client().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify logging occurred + assert.Greater(t, len(logger.debugCalls), 0, "expected debug logs") + // Should have logs for: extracting JWT, validating JWT, validation successful (at least 2) + assert.GreaterOrEqual(t, len(logger.debugCalls), 2) + }) + + t.Run("validation failure with logging", func(t *testing.T) { + logger := &mockLogger{} + validator := createTestValidator(t) + + middleware, err := New( + WithValidator(validator), + WithLogger(logger), + ) + require.NoError(t, err) + + // Create a test server with the middleware + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + testServer := httptest.NewServer(middleware.CheckJWT(handler)) + defer testServer.Close() + + // Make a request with an invalid token + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer bad-token") + + resp, err := testServer.Client().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify logging occurred + assert.Greater(t, len(logger.debugCalls), 0, "expected debug logs") + assert.Greater(t, len(logger.warnCalls), 0, "expected warn logs for validation failure") + }) + + t.Run("excluded URL with logging", func(t *testing.T) { + logger := &mockLogger{} + validator := createTestValidator(t) + + middleware, err := New( + WithValidator(validator), + WithLogger(logger), + WithExclusionUrls([]string{"/health"}), + ) + require.NoError(t, err) + + // Create a test server with the middleware + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + testServer := httptest.NewServer(middleware.CheckJWT(handler)) + defer testServer.Close() + + // Make a request to excluded URL without token + req, err := http.NewRequest(http.MethodGet, testServer.URL+"/health", nil) + require.NoError(t, err) + + resp, err := testServer.Client().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify logging occurred for excluded URL + assert.Greater(t, len(logger.debugCalls), 0, "expected debug logs") + // Should have log about skipping validation + foundSkipLog := false + for _, call := range logger.debugCalls { + if len(call) > 0 { + if msg, ok := call[0].(string); ok && msg == "skipping JWT validation for excluded URL" { + foundSkipLog = true + break + } + } + } + assert.True(t, foundSkipLog, "expected log about skipping validation for excluded URL") + }) + + t.Run("OPTIONS request with logging", func(t *testing.T) { + logger := &mockLogger{} + validator := createTestValidator(t) + + middleware, err := New( + WithValidator(validator), + WithLogger(logger), + WithValidateOnOptions(false), + ) + require.NoError(t, err) + + // Create a test server with the middleware + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + testServer := httptest.NewServer(middleware.CheckJWT(handler)) + defer testServer.Close() + + // Make an OPTIONS request without token + req, err := http.NewRequest(http.MethodOptions, testServer.URL, nil) + require.NoError(t, err) + + resp, err := testServer.Client().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify logging occurred for OPTIONS request + assert.Greater(t, len(logger.debugCalls), 0, "expected debug logs") + // Should have log about skipping validation for OPTIONS + foundSkipLog := false + for _, call := range logger.debugCalls { + if len(call) > 0 { + if msg, ok := call[0].(string); ok && msg == "skipping JWT validation for OPTIONS request" { + foundSkipLog = true + break + } + } + } + assert.True(t, foundSkipLog, "expected log about skipping validation for OPTIONS request") + }) + + t.Run("token extraction error with logging", func(t *testing.T) { + logger := &mockLogger{} + validator := createTestValidator(t) + + customExtractor := func(r *http.Request) (string, error) { + return "", errors.New("extraction failed") + } + + middleware, err := New( + WithValidator(validator), + WithLogger(logger), + WithTokenExtractor(customExtractor), + ) + require.NoError(t, err) + + // Create a test server with the middleware + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + testServer := httptest.NewServer(middleware.CheckJWT(handler)) + defer testServer.Close() + + // Make a request + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + + resp, err := testServer.Client().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify error logging occurred + assert.Greater(t, len(logger.errorCalls), 0, "expected error logs for extraction failure") + }) +} + +func Test_GetClaims(t *testing.T) { + tests := []struct { + name string + setupCtx func() context.Context + wantErr bool + errMsg string + }{ + { + name: "valid claims from middleware", + setupCtx: func() context.Context { + // Create a validator that matches the token we'll use + keyFunc := func(context.Context) (interface{}, error) { + return []byte("secret"), nil + } + v, err := validator.New( + validator.WithKeyFunc(keyFunc), + validator.WithAlgorithm(validator.HS256), + validator.WithIssuer("test-issuer"), + validator.WithAudience("test-audience"), + ) + require.NoError(t, err) + + middleware, err := New(WithValidator(v)) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+testToken) + + var resultCtx context.Context + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resultCtx = r.Context() + w.WriteHeader(http.StatusOK) + }) + + rr := httptest.NewRecorder() + middleware.CheckJWT(handler).ServeHTTP(rr, req) + + // Verify the handler was called + require.NotNil(t, resultCtx, "Handler should have been called") + require.Equal(t, http.StatusOK, rr.Code, "Expected successful validation") + + return resultCtx + }, + wantErr: false, + }, + { + name: "claims not found", + setupCtx: func() context.Context { + return context.Background() + }, + wantErr: true, + errMsg: "claims not found", + }, + { + name: "claims wrong type", + setupCtx: func() context.Context { + // Use core.SetClaims to set wrong type + ctx := context.Background() + wrongClaims := map[string]any{"sub": "user-123"} + return core.SetClaims(ctx, wrongClaims) + }, + wantErr: true, + errMsg: "claims type assertion failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setupCtx() + claims, err := GetClaims[*validator.ValidatedClaims](ctx) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + assert.NotNil(t, claims) + } + }) + } +} + +func Test_MustGetClaims(t *testing.T) { + // Helper to create valid context with claims through middleware + createValidContext := func() context.Context { + v := createTestValidator(t) + + middleware, err := New(WithValidator(v)) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+testToken) + + var resultCtx context.Context + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resultCtx = r.Context() + }) + + rr := httptest.NewRecorder() + middleware.CheckJWT(handler).ServeHTTP(rr, req) + require.NotNil(t, resultCtx) + return resultCtx + } + + t.Run("valid claims", func(t *testing.T) { + ctx := createValidContext() + + result := MustGetClaims[*validator.ValidatedClaims](ctx) + assert.NotNil(t, result) + }) + + t.Run("panics on missing claims", func(t *testing.T) { + ctx := context.Background() + + assert.Panics(t, func() { + MustGetClaims[*validator.ValidatedClaims](ctx) + }) + }) + + t.Run("panics on wrong type", func(t *testing.T) { + wrongClaims := map[string]any{"sub": "user-123"} + ctx := core.SetClaims(context.Background(), wrongClaims) + + assert.Panics(t, func() { + MustGetClaims[*validator.ValidatedClaims](ctx) + }) + }) +} + +func Test_HasClaims(t *testing.T) { + // Helper to create context with claims through middleware + createContextWithClaims := func() context.Context { + validator := createTestValidator(t) + + middleware, _ := New(WithValidator(validator)) + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+testToken) + + var resultCtx context.Context + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resultCtx = r.Context() + }) + + rr := httptest.NewRecorder() + middleware.CheckJWT(handler).ServeHTTP(rr, req) + return resultCtx + } + + tests := []struct { + name string + setupCtx func() context.Context + want bool + }{ + { + name: "has claims", + setupCtx: func() context.Context { + return createContextWithClaims() + }, + want: true, + }, + { + name: "no claims", + setupCtx: func() context.Context { + return context.Background() + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setupCtx() + result := HasClaims(ctx) + assert.Equal(t, tt.want, result) + }) + } +} + +func Test_SentinelErrors(t *testing.T) { + t.Run("ErrValidatorNil", func(t *testing.T) { + assert.True(t, errors.Is(ErrValidatorNil, ErrValidatorNil)) + assert.Contains(t, ErrValidatorNil.Error(), "validator cannot be nil") + }) + + t.Run("ErrErrorHandlerNil", func(t *testing.T) { + assert.True(t, errors.Is(ErrErrorHandlerNil, ErrErrorHandlerNil)) + assert.Contains(t, ErrErrorHandlerNil.Error(), "errorHandler cannot be nil") + }) + + t.Run("ErrTokenExtractorNil", func(t *testing.T) { + assert.True(t, errors.Is(ErrTokenExtractorNil, ErrTokenExtractorNil)) + assert.Contains(t, ErrTokenExtractorNil.Error(), "tokenExtractor cannot be nil") + }) + + t.Run("ErrExclusionUrlsEmpty", func(t *testing.T) { + assert.True(t, errors.Is(ErrExclusionUrlsEmpty, ErrExclusionUrlsEmpty)) + assert.Contains(t, ErrExclusionUrlsEmpty.Error(), "exclusion URLs list cannot be empty") + }) +} + +func Test_validatorAdapter(t *testing.T) { + testValidator := createTestValidator(t) + adapter := &validatorAdapter{validator: testValidator} + + t.Run("successful validation", func(t *testing.T) { + result, err := adapter.ValidateToken(context.Background(), testToken) + require.NoError(t, err) + assert.NotNil(t, result) + }) + + t.Run("validation error with invalid token", func(t *testing.T) { + result, err := adapter.ValidateToken(context.Background(), "invalid-token") + assert.Error(t, err) + assert.Nil(t, result) + }) +} + +func Test_invalidError(t *testing.T) { + t.Run("Error method returns formatted message", func(t *testing.T) { + detailErr := errors.New("token signature is invalid") + invErr := &invalidError{details: detailErr} + + errMsg := invErr.Error() + assert.Contains(t, errMsg, "jwt invalid") + assert.Contains(t, errMsg, "token signature is invalid") + }) + + t.Run("Is method works with ErrJWTInvalid", func(t *testing.T) { + detailErr := errors.New("some validation error") + invErr := &invalidError{details: detailErr} + + assert.True(t, errors.Is(invErr, ErrJWTInvalid)) + }) + + t.Run("Unwrap returns the details error", func(t *testing.T) { + detailErr := errors.New("specific error details") + invErr := &invalidError{details: detailErr} + + assert.Equal(t, detailErr, errors.Unwrap(invErr)) + }) +} + +// mockLogger is a test implementation of the Logger interface +type mockLogger struct { + debugCalls [][]any + infoCalls [][]any + warnCalls [][]any + errorCalls [][]any +} + +func (m *mockLogger) Debug(msg string, args ...any) { + m.debugCalls = append(m.debugCalls, append([]any{msg}, args...)) +} + +func (m *mockLogger) Info(msg string, args ...any) { + m.infoCalls = append(m.infoCalls, append([]any{msg}, args...)) +} + +func (m *mockLogger) Warn(msg string, args ...any) { + m.warnCalls = append(m.warnCalls, append([]any{msg}, args...)) +} + +func (m *mockLogger) Error(msg string, args ...any) { + m.errorCalls = append(m.errorCalls, append([]any{msg}, args...)) +}