diff --git a/cmd/root.go b/cmd/root.go index 89eb1648..7c24be4c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -106,8 +106,9 @@ func init() { f.IntVar(&rootArgs.config.GraphQLMaxAliases, "graphql-max-aliases", 30, "Maximum total number of aliased fields per GraphQL operation") f.Int64Var(&rootArgs.config.GraphQLMaxBodyBytes, "graphql-max-body-bytes", 1<<20, "Maximum allowed GraphQL request body size in bytes (default 1MB)") - // gRPC server flags - f.IntVar(&rootArgs.config.GRPCPort, "grpc-port", 8081, "Port the gRPC server listens on") + // gRPC server flags. Port 9091 avoids collision with the metrics + // listener which defaults to 8081 (and with the HTTP listener on 8080). + f.IntVar(&rootArgs.config.GRPCPort, "grpc-port", 9091, "Port the gRPC server listens on") f.BoolVar(&rootArgs.config.EnableGRPCReflection, "enable-grpc-reflection", true, "Enable the gRPC server-reflection service") f.StringVar(&rootArgs.config.GRPCTLSCert, "grpc-tls-cert", "", "Path to the TLS certificate for the gRPC server") f.StringVar(&rootArgs.config.GRPCTLSKey, "grpc-tls-key", "", "Path to the TLS private key for the gRPC server") @@ -344,9 +345,20 @@ func applyFlagDefaults() { // Run the service func runRoot(c *cobra.Command, args []string) { applyFlagDefaults() - if rootArgs.server.HTTPPort == rootArgs.server.MetricsPort { - fmt.Fprintf(os.Stderr, "invalid server ports: --http-port and --metrics-port must differ (metrics are always served on a dedicated listener)\n") - os.Exit(1) + // All three listeners (HTTP, metrics, gRPC) bind concurrently; any + // collision is unrecoverable at runtime, so we fail fast at startup. + ports := map[string]int{ + "--http-port": rootArgs.server.HTTPPort, + "--metrics-port": rootArgs.server.MetricsPort, + "--grpc-port": rootArgs.config.GRPCPort, + } + for nameA, a := range ports { + for nameB, b := range ports { + if nameA < nameB && a == b { + fmt.Fprintf(os.Stderr, "invalid server ports: %s (%d) and %s (%d) must differ — each listener binds independently\n", nameA, a, nameB, b) + os.Exit(1) + } + } } // Refuse to start without an admin secret. The previous default of diff --git a/gen/openapi/openapi.go b/gen/openapi/openapi.go new file mode 100644 index 00000000..8e113d41 --- /dev/null +++ b/gen/openapi/openapi.go @@ -0,0 +1,13 @@ +// Package openapi exposes the generated OpenAPI 2.0 spec as a byte slice +// so HTTP handlers can serve it from any working directory (test, Docker +// container, etc.). The file is embedded at compile time via go:embed so +// builds fail loudly if `make proto-gen` hasn't been run. +package openapi + +import _ "embed" + +//go:embed authorizer.swagger.json +var spec []byte + +// Spec returns the embedded OpenAPI 2.0 JSON. +func Spec() []byte { return spec } diff --git a/internal/cookie/cookie_test.go b/internal/cookie/cookie_test.go new file mode 100644 index 00000000..a53da661 --- /dev/null +++ b/internal/cookie/cookie_test.go @@ -0,0 +1,91 @@ +package cookie + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authorizerdev/authorizer/internal/constants" +) + +func TestBuildSessionCookies(t *testing.T) { + tests := []struct { + name string + hostname string + secure bool + sameSite http.SameSite + wantDomain string // expected `.example.com`-style domain on the domain-scoped cookie + }{ + {"production https", "https://auth.example.com", true, http.SameSiteNoneMode, ".example.com"}, + {"localhost dev", "http://localhost:8080", false, http.SameSiteLaxMode, "localhost"}, + {"subdomain", "https://auth.svc.example.com", true, http.SameSiteStrictMode, ".example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cookies := BuildSessionCookies(tt.hostname, "session-id", tt.secure, tt.sameSite) + require.Len(t, cookies, 2, "BuildSessionCookies must return exactly the host-scoped and domain-scoped pair") + + for _, c := range cookies { + assert.Equal(t, "session-id", c.Value) + assert.Equal(t, tt.secure, c.Secure) + assert.True(t, c.HttpOnly, "session cookies must be HttpOnly") + assert.Equal(t, "/", c.Path) + assert.Equal(t, tt.sameSite, c.SameSite) + assert.Equal(t, 24*60*60, c.MaxAge, "session cookie MaxAge must be 1 day") + } + + // Sanity-check cookie names. + assert.Equal(t, constants.AppCookieName+"_session", cookies[0].Name) + assert.Equal(t, constants.AppCookieName+"_session_domain", cookies[1].Name) + // Domain-scoped cookie picks up the apex. + assert.Equal(t, tt.wantDomain, cookies[1].Domain) + }) + } +} + +func TestBuildMfaSessionCookies(t *testing.T) { + cookies := BuildMfaSessionCookies("https://auth.example.com", "mfa-id", true) + require.Len(t, cookies, 2) + for _, c := range cookies { + assert.Equal(t, "mfa-id", c.Value) + assert.True(t, c.Secure) + assert.True(t, c.HttpOnly) + assert.Equal(t, http.SameSiteNoneMode, c.SameSite, "secure → SameSite=None") + assert.Equal(t, 60, c.MaxAge, "MFA cookies are short-lived (60s)") + } + assert.Equal(t, constants.MfaCookieName+"_session", cookies[0].Name) + assert.Equal(t, constants.MfaCookieName+"_session_domain", cookies[1].Name) +} + +func TestBuildMfaSessionCookies_InsecureLaxSameSite(t *testing.T) { + cookies := BuildMfaSessionCookies("http://localhost:8080", "mfa-id", false) + require.Len(t, cookies, 2) + for _, c := range cookies { + assert.False(t, c.Secure) + // Insecure → SameSite=Lax (so cross-site flows still complete when not behind TLS). + // Verified against the original SetMfaSession behaviour: this is intentional. + assert.Equal(t, http.SameSiteLaxMode, c.SameSite) + } +} + +func TestParseSameSite(t *testing.T) { + tests := []struct { + in string + want http.SameSite + }{ + {"none", http.SameSiteNoneMode}, + {"NONE", http.SameSiteNoneMode}, + {"strict", http.SameSiteStrictMode}, + {"lax", http.SameSiteLaxMode}, + {"", http.SameSiteLaxMode}, // unknown defaults to Lax + {"garbage", http.SameSiteLaxMode}, + {" none ", http.SameSiteNoneMode}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + assert.Equal(t, tt.want, ParseSameSite(tt.in)) + }) + } +} diff --git a/internal/grpcsrv/interceptors/interceptors_test.go b/internal/grpcsrv/interceptors/interceptors_test.go new file mode 100644 index 00000000..3a0a0a2b --- /dev/null +++ b/internal/grpcsrv/interceptors/interceptors_test.go @@ -0,0 +1,145 @@ +package interceptors + +import ( + "bytes" + "context" + "strings" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + metav1 "github.com/authorizerdev/authorizer/gen/go/authorizer/meta/v1" + userv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/user/v1" +) + +// info builds a *grpc.UnaryServerInfo for a fake RPC. The full-method name is +// the only field interceptors actually read. +func info(method string) *grpc.UnaryServerInfo { + return &grpc.UnaryServerInfo{FullMethod: method} +} + +func TestRecovery_TurnsPanicIntoInternal(t *testing.T) { + var buf bytes.Buffer + log := zerolog.New(&buf) + + r := Recovery(&log) + _, err := r(context.Background(), nil, info("/svc/Method"), func(_ context.Context, _ any) (any, error) { + panic("kaboom") + }) + + st, ok := status.FromError(err) + require.True(t, ok, "expected a gRPC status error") + assert.Equal(t, codes.Internal, st.Code()) + assert.Equal(t, "internal server error", st.Message(), "panic detail must not leak to clients") + // The stack stays server-side. + assert.Contains(t, buf.String(), "panicked") + assert.Contains(t, buf.String(), "kaboom") +} + +func TestRecovery_PassesNormalErrorsThrough(t *testing.T) { + log := zerolog.Nop() + r := Recovery(&log) + want := status.Error(codes.NotFound, "no") + _, err := r(context.Background(), nil, info("/svc/X"), func(_ context.Context, _ any) (any, error) { + return nil, want + }) + assert.Equal(t, want, err) +} + +func TestLogging_OkPath(t *testing.T) { + var buf bytes.Buffer + log := zerolog.New(&buf) + mw := Logging(&log) + _, err := mw(context.Background(), nil, info("/svc/Foo"), func(_ context.Context, _ any) (any, error) { + return "ok", nil + }) + require.NoError(t, err) + out := buf.String() + assert.Contains(t, out, `"method":"/svc/Foo"`) + assert.Contains(t, out, `"code":"OK"`) + assert.Contains(t, out, `"level":"info"`) +} + +func TestLogging_ErrorPathRaisesLevel(t *testing.T) { + var buf bytes.Buffer + log := zerolog.New(&buf) + mw := Logging(&log) + _, _ = mw(context.Background(), nil, info("/svc/Bad"), func(_ context.Context, _ any) (any, error) { + return nil, status.Error(codes.Internal, "boom") + }) + out := buf.String() + assert.Contains(t, out, `"code":"Internal"`) + assert.Contains(t, out, `"level":"error"`, "Internal/Unknown/DataLoss must raise log level to error") +} + +func TestLogging_PermissionDeniedIsWarn(t *testing.T) { + var buf bytes.Buffer + log := zerolog.New(&buf) + mw := Logging(&log) + _, _ = mw(context.Background(), nil, info("/svc/X"), func(_ context.Context, _ any) (any, error) { + return nil, status.Error(codes.PermissionDenied, "no") + }) + assert.Contains(t, buf.String(), `"level":"warn"`, "non-Internal failures must log at warn, not error") +} + +func TestValidate_RejectsBadRequest(t *testing.T) { + mw, err := Validate() + require.NoError(t, err) + + // CreateUserRequest enforces email format via buf.validate.field on the email + // field — sending an invalid email should fail the interceptor before any + // handler runs. + req := &userv1.CreateUserRequest{ + Email: "not-an-email", + Password: "x", + ConfirmPassword: "x", + } + _, err = mw(context.Background(), req, info("/svc/CreateUser"), func(_ context.Context, _ any) (any, error) { + t.Fatal("handler must NOT run for an invalid request") + return nil, nil + }) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, st.Code()) +} + +func TestValidate_AllowsValidRequest(t *testing.T) { + mw, err := Validate() + require.NoError(t, err) + called := false + _, err = mw(context.Background(), &metav1.GetMetaRequest{}, info("/svc/GetMeta"), func(_ context.Context, _ any) (any, error) { + called = true + return &metav1.GetMetaResponse{}, nil + }) + require.NoError(t, err) + assert.True(t, called, "valid request must reach the handler") +} + +func TestValidate_NonProtoRequestPassesThrough(t *testing.T) { + mw, err := Validate() + require.NoError(t, err) + _, err = mw(context.Background(), "not-a-proto", info("/svc/X"), func(_ context.Context, _ any) (any, error) { + return nil, nil + }) + require.NoError(t, err, "non-proto requests must not be rejected by the validator") +} + +// TestValidate_PreservesInvariant guards against regressions where someone +// makes Validate() return a non-functional middleware (e.g. by reordering +// the protovalidate.New() call). If the validator itself fails to build, +// callers must learn about it at startup, not at first request. +func TestValidate_BuildsCleanly(t *testing.T) { + mw, err := Validate() + require.NoError(t, err) + require.NotNil(t, mw) + // Sanity check: the returned interceptor type is what gRPC expects. + _ = grpc.UnaryServerInterceptor(mw) +} + +// helper used by some of the future interceptor tests +func _ignoreUnused() { _ = strings.Builder{} } diff --git a/internal/grpcsrv/transport/grpc_metadata.go b/internal/grpcsrv/transport/grpc_metadata.go index e3cd157f..6c76f523 100644 --- a/internal/grpcsrv/transport/grpc_metadata.go +++ b/internal/grpcsrv/transport/grpc_metadata.go @@ -3,15 +3,15 @@ // ResponseSideEffects. // // gRPC has no native cookie concept; cookies in ResponseSideEffects are -// serialised to a `Set-Cookie` trailer, which grpc-gateway then promotes -// into actual `Set-Cookie` response headers when the call comes in via REST. -// Pure-gRPC clients (server-to-server) typically don't need cookies and -// silently ignore them. +// serialised to `Set-Cookie` metadata entries. grpc-gateway promotes those +// into real `Set-Cookie` response headers when the call came in via REST. +// Pure-gRPC clients can read them via the response trailers or ignore them. package transport import ( "context" "net/http" + "strings" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -30,6 +30,7 @@ func MetaFromGRPC(ctx context.Context) service.RequestMetadata { IPAddress: firstHeader(md, "x-forwarded-for", "grpcgateway-x-forwarded-for", "x-real-ip"), UserAgent: firstHeader(md, "grpcgateway-user-agent", "user-agent"), AuthorizationHeader: firstHeader(md, "authorization", "grpcgateway-authorization"), + Cookies: cookiesFromMetadata(md), } // Default the host URL when no header was set (pure-gRPC caller, no // proxy headers). The :authority pseudo-header is the gRPC equivalent @@ -42,23 +43,30 @@ func MetaFromGRPC(ctx context.Context) service.RequestMetadata { return meta } -// ApplyToGRPC writes the response side-effects to the outgoing gRPC stream: -// cookies become Set-Cookie metadata trailers. A nil receiver is a no-op. +// ApplyToGRPC writes the response side-effects to the outgoing gRPC stream. +// Every cookie becomes its own `Set-Cookie` metadata entry — preserving +// multi-cookie responses (e.g. host-scoped + domain-scoped session pair). +// grpc-gateway promotes the metadata back to real `Set-Cookie` HTTP headers. +// A nil receiver is a no-op. func ApplyToGRPC(ctx context.Context, side *service.ResponseSideEffects) error { if side == nil || len(side.Cookies) == 0 { return nil } - values := make([]string, 0, len(side.Cookies)) + // grpc-gateway honours the per-RPC `Set-Cookie` metadata when prefixed + // `Grpc-Metadata-Set-Cookie` or under the canonical header. Use + // metadata.Pairs equivalents: same key, repeated values. + header := http.CanonicalHeaderKey("Set-Cookie") + md := metadata.MD{} for _, c := range side.Cookies { if c == nil { continue } - values = append(values, c.String()) + md.Append(header, c.String()) } - if len(values) == 0 { + if len(md) == 0 { return nil } - return grpc.SendHeader(ctx, metadata.Pairs(http.CanonicalHeaderKey("Set-Cookie"), values[0])) //nolint:staticcheck // only one cookie surfaces; multi-cookie comes with the gateway-aware wiring + return grpc.SendHeader(ctx, md) } func firstHeader(md metadata.MD, keys ...string) string { @@ -69,3 +77,25 @@ func firstHeader(md metadata.MD, keys ...string) string { } return "" } + +// cookiesFromMetadata parses Cookie header(s) supplied via gRPC metadata. +// grpc-gateway forwards browser cookies as the `grpcgateway-cookie` key; +// pure-gRPC clients can set `cookie` directly. Multiple Cookie headers are +// concatenated (semicolon-separated per RFC 6265). +func cookiesFromMetadata(md metadata.MD) []*http.Cookie { + var raw []string + raw = append(raw, md.Get("grpcgateway-cookie")...) + raw = append(raw, md.Get("cookie")...) + if len(raw) == 0 { + return nil + } + // http.Request.Cookies parses the Cookie header for us. Synthesize a + // minimal request rather than re-implementing the cookie grammar. + req := &http.Request{Header: http.Header{}} + for _, line := range raw { + // One header may contain multiple cookies separated by "; ". + // http.Header.Add preserves the line; cookies are parsed downstream. + req.Header.Add("Cookie", strings.TrimSpace(line)) + } + return req.Cookies() +} diff --git a/internal/grpcsrv/transport/grpc_metadata_test.go b/internal/grpcsrv/transport/grpc_metadata_test.go new file mode 100644 index 00000000..6dd6c8ba --- /dev/null +++ b/internal/grpcsrv/transport/grpc_metadata_test.go @@ -0,0 +1,77 @@ +package transport + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + + "github.com/authorizerdev/authorizer/internal/service" +) + +func TestMetaFromGRPC_ExtractsAllSignals(t *testing.T) { + md := metadata.New(map[string]string{ + "grpcgateway-x-authorizer-url": "https://auth.example.com", + "grpcgateway-x-forwarded-for": "10.1.2.3", + "grpcgateway-user-agent": "browser/1.0", + "grpcgateway-authorization": "Bearer abc", + "grpcgateway-cookie": "authorizer_session=abc; mfa=xyz", + }) + ctx := metadata.NewIncomingContext(context.Background(), md) + meta := MetaFromGRPC(ctx) + assert.Equal(t, "https://auth.example.com", meta.HostURL) + assert.Equal(t, "10.1.2.3", meta.IPAddress) + assert.Equal(t, "browser/1.0", meta.UserAgent) + assert.Equal(t, "Bearer abc", meta.AuthorizationHeader) + require.Len(t, meta.Cookies, 2) + cookieValues := map[string]string{} + for _, c := range meta.Cookies { + cookieValues[c.Name] = c.Value + } + assert.Equal(t, "abc", cookieValues["authorizer_session"]) + assert.Equal(t, "xyz", cookieValues["mfa"]) +} + +func TestMetaFromGRPC_FallsBackToAuthority(t *testing.T) { + md := metadata.New(map[string]string{":authority": "auth.example.com"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + meta := MetaFromGRPC(ctx) + assert.Equal(t, "http://auth.example.com", meta.HostURL) +} + +func TestMetaFromGRPC_NoMetadata(t *testing.T) { + meta := MetaFromGRPC(context.Background()) + assert.Equal(t, service.RequestMetadata{}, meta) +} + +func TestCookiesFromMetadata_MultipleHeaders(t *testing.T) { + md := metadata.MD{} + md.Append("grpcgateway-cookie", "a=1; b=2") + md.Append("grpcgateway-cookie", "c=3") + cookies := cookiesFromMetadata(md) + require.Len(t, cookies, 3) + got := map[string]string{} + for _, c := range cookies { + got[c.Name] = c.Value + } + assert.Equal(t, map[string]string{"a": "1", "b": "2", "c": "3"}, got) +} + +func TestCookiesFromMetadata_NoCookies(t *testing.T) { + assert.Nil(t, cookiesFromMetadata(metadata.MD{})) +} + +func TestApplyToGRPC_NilSafe(t *testing.T) { + // nil receiver / empty cookies must not error. + assert.NoError(t, ApplyToGRPC(context.Background(), nil)) + assert.NoError(t, ApplyToGRPC(context.Background(), &service.ResponseSideEffects{})) + assert.NoError(t, ApplyToGRPC(context.Background(), &service.ResponseSideEffects{Cookies: []*http.Cookie{nil}})) +} + +// Note: ApplyToGRPC's success path uses grpc.SendHeader which requires a +// real *grpc.ServerStream / handler context. That's covered end-to-end by +// the integration tests in internal/integration_tests where cookies emitted +// by a CreateSession handler land in the REST response. diff --git a/internal/integration_tests/grpc_surface_test.go b/internal/integration_tests/grpc_surface_test.go new file mode 100644 index 00000000..9dd34099 --- /dev/null +++ b/internal/integration_tests/grpc_surface_test.go @@ -0,0 +1,169 @@ +package integration_tests + +import ( + "context" + "net" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + healthv1 "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" + + "github.com/authorizerdev/authorizer/internal/grpcsrv" + "github.com/authorizerdev/authorizer/internal/service" + + authzv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/authz/v1" + sessionv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/session/v1" + tokenv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/token/v1" + userv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/user/v1" + verificationv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/verification/v1" +) + +// bootGRPCBufconn builds a gRPC server identical to the production one, +// served over an in-process bufconn. Returns a dialed *grpc.ClientConn the +// test uses to issue real RPCs. +func bootGRPCBufconn(t *testing.T) *grpc.ClientConn { + t.Helper() + cfg := getTestConfig() + cfg.ClientID = "test-client" + log := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() + + svc, err := service.New(cfg, &service.Dependencies{Log: &log}) + require.NoError(t, err) + srv, err := grpcsrv.New(":0", &grpcsrv.Dependencies{Log: &log, Config: cfg, ServiceProvider: svc}) + require.NoError(t, err) + + lis := bufconn.Listen(1 << 20) + t.Cleanup(func() { _ = lis.Close() }) + go func() { _ = srv.GRPCServer().Serve(lis) }() + t.Cleanup(srv.GRPCServer().GracefulStop) + + conn, err := grpc.NewClient( + "passthrough:///bufconn", + grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) { return lis.Dial() }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + return conn +} + +// TestGRPCStubsReturnUnimplemented locks down the Phase 2 contract: every +// service is registered (so reflection sees it) but the non-migrated ones +// return codes.Unimplemented until their handlers replace the stubs in +// follow-up PRs. A regression — e.g. accidentally returning OK or panicking +// — would silently change client behaviour. +func TestGRPCStubsReturnUnimplemented(t *testing.T) { + conn := bootGRPCBufconn(t) + ctx := context.Background() + + type call func(context.Context) error + cases := map[string]call{ + "UserService.CreateUser": func(c context.Context) error { + _, err := userv1.NewUserServiceClient(conn).CreateUser(c, &userv1.CreateUserRequest{ + Email: "x@example.com", Password: "p", ConfirmPassword: "p", + }) + return err + }, + "UserService.GetUser": func(c context.Context) error { + _, err := userv1.NewUserServiceClient(conn).GetUser(c, &userv1.GetUserRequest{Name: "users/me"}) + return err + }, + "UserService.UpdateUser": func(c context.Context) error { + _, err := userv1.NewUserServiceClient(conn).UpdateUser(c, &userv1.UpdateUserRequest{User: &userv1.User{Id: "users/me"}}) + return err + }, + "UserService.DeleteUser": func(c context.Context) error { + _, err := userv1.NewUserServiceClient(conn).DeleteUser(c, &userv1.DeleteUserRequest{Name: "users/me"}) + return err + }, + "SessionService.CreateSession": func(c context.Context) error { + _, err := sessionv1.NewSessionServiceClient(conn).CreateSession(c, &sessionv1.CreateSessionRequest{ + Grant: &sessionv1.CreateSessionRequest_Password{ + Password: &sessionv1.PasswordGrant{Email: "x@example.com", Password: "p"}, + }, + }) + return err + }, + "SessionService.GetCurrentSession": func(c context.Context) error { + _, err := sessionv1.NewSessionServiceClient(conn).GetCurrentSession(c, &sessionv1.GetCurrentSessionRequest{}) + return err + }, + "SessionService.DeleteSession": func(c context.Context) error { + _, err := sessionv1.NewSessionServiceClient(conn).DeleteSession(c, &sessionv1.DeleteSessionRequest{}) + return err + }, + "SessionService.CreateSessionValidation": func(c context.Context) error { + _, err := sessionv1.NewSessionServiceClient(conn).CreateSessionValidation(c, &sessionv1.CreateSessionValidationRequest{Cookie: "x"}) + return err + }, + "MagicLinkService.CreateMagicLink": func(c context.Context) error { + _, err := sessionv1.NewMagicLinkServiceClient(conn).CreateMagicLink(c, &sessionv1.CreateMagicLinkRequest{Email: "x@example.com"}) + return err + }, + "EmailVerification.Create": func(c context.Context) error { + _, err := verificationv1.NewEmailVerificationServiceClient(conn).CreateEmailVerification(c, &verificationv1.CreateEmailVerificationRequest{ + Email: "x@example.com", Identifier: "id", + }) + return err + }, + "EmailVerification.Confirm": func(c context.Context) error { + _, err := verificationv1.NewEmailVerificationServiceClient(conn).ConfirmEmailVerification(c, &verificationv1.ConfirmEmailVerificationRequest{Token: "t"}) + return err + }, + "PasswordReset.Create": func(c context.Context) error { + _, err := verificationv1.NewPasswordResetServiceClient(conn).CreatePasswordReset(c, &verificationv1.CreatePasswordResetRequest{Email: "x@example.com"}) + return err + }, + "PasswordReset.Confirm": func(c context.Context) error { + _, err := verificationv1.NewPasswordResetServiceClient(conn).ConfirmPasswordReset(c, &verificationv1.ConfirmPasswordResetRequest{Token: "t", Password: "p", ConfirmPassword: "p"}) + return err + }, + "OtpChallenge.Create": func(c context.Context) error { + _, err := verificationv1.NewOtpChallengeServiceClient(conn).CreateOtpChallenge(c, &verificationv1.CreateOtpChallengeRequest{Email: "x@example.com"}) + return err + }, + "OtpChallenge.Confirm": func(c context.Context) error { + _, err := verificationv1.NewOtpChallengeServiceClient(conn).ConfirmOtpChallenge(c, &verificationv1.ConfirmOtpChallengeRequest{ChallengeId: "id", Otp: "1"}) + return err + }, + "TokenService.CreateTokenValidation": func(c context.Context) error { + _, err := tokenv1.NewTokenServiceClient(conn).CreateTokenValidation(c, &tokenv1.CreateTokenValidationRequest{TokenType: "access_token", Token: "t"}) + return err + }, + "TokenService.RevokeRefreshToken": func(c context.Context) error { + _, err := tokenv1.NewTokenServiceClient(conn).RevokeRefreshToken(c, &tokenv1.RevokeRefreshTokenRequest{RefreshToken: "t"}) + return err + }, + "AuthzService.ListMyPermissions": func(c context.Context) error { + _, err := authzv1.NewAuthzServiceClient(conn).ListMyPermissions(c, &authzv1.ListMyPermissionsRequest{}) + return err + }, + } + + for name, fn := range cases { + t.Run(name, func(t *testing.T) { + err := fn(ctx) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unimplemented, st.Code(), + "stub for %s should return Unimplemented until its handler is wired", name) + }) + } +} + +// TestGRPCHealthCheckProtocol exercises the standard grpc.health.v1.Health +// service that the gRPC server registers for k8s readiness probes. +func TestGRPCHealthCheckProtocol(t *testing.T) { + conn := bootGRPCBufconn(t) + resp, err := healthv1.NewHealthClient(conn).Check(context.Background(), &healthv1.HealthCheckRequest{}) + require.NoError(t, err) + assert.Equal(t, healthv1.HealthCheckResponse_SERVING, resp.Status) +} diff --git a/internal/integration_tests/mcp_stubs_test.go b/internal/integration_tests/mcp_stubs_test.go new file mode 100644 index 00000000..c6282250 --- /dev/null +++ b/internal/integration_tests/mcp_stubs_test.go @@ -0,0 +1,65 @@ +package integration_tests + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authorizerdev/authorizer/internal/grpcsrv" + authmcp "github.com/authorizerdev/authorizer/internal/mcp" + "github.com/authorizerdev/authorizer/internal/service" +) + +// TestMCPStubReturnsError exercises the "MCP tool exposed in proto but its +// underlying gRPC handler is still a stub" path. This is the current state +// of get_user, get_current_session, and list_my_permissions: they appear in +// tools/list (proven by TestMCPListAndCallGetMeta) and a call must surface +// the underlying codes.Unimplemented as a tool error rather than silently +// succeeding or panicking. +func TestMCPStubReturnsError(t *testing.T) { + cfg := getTestConfig() + cfg.ClientID = "test-client" + log := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() + + svc, err := service.New(cfg, &service.Dependencies{Log: &log}) + require.NoError(t, err) + grpcSrv, err := grpcsrv.New(":0", &grpcsrv.Dependencies{Log: &log, Config: cfg, ServiceProvider: svc}) + require.NoError(t, err) + mcpSrv, err := authmcp.New(&log, grpcSrv.GRPCServer(), "authorizer-test", "v0") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cTransport, sTransport := mcp.NewInMemoryTransports() + serverSession, err := mcpSrv.MCPServer().Connect(ctx, sTransport, nil) + require.NoError(t, err) + defer serverSession.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "v0"}, nil) + clientSession, err := client.Connect(ctx, cTransport, nil) + require.NoError(t, err) + defer clientSession.Close() + + // list_my_permissions is exposed via the proto annotation but its + // AuthzService.ListMyPermissions handler is a stub returning + // codes.Unimplemented. The MCP server must surface this as a + // CallToolResult{IsError:true} (tool-level error) rather than a + // JSON-RPC protocol error — so the LLM gets actionable text and can + // react / try a different tool. + res, err := clientSession.CallTool(ctx, &mcp.CallToolParams{ + Name: "list_my_permissions", + Arguments: map[string]any{}, + }) + require.NoError(t, err, "tool execution errors must NOT surface as protocol errors") + require.NotNil(t, res) + assert.True(t, res.IsError, "stubbed tool must return IsError=true") + require.NotEmpty(t, res.Content) + text, ok := res.Content[0].(*mcp.TextContent) + require.True(t, ok, "error content should be text") + assert.Contains(t, text.Text, "Unimplemented", + "the underlying gRPC Unimplemented code should be reflected in the MCP error text") +} diff --git a/internal/integration_tests/rest_openapi_test.go b/internal/integration_tests/rest_openapi_test.go new file mode 100644 index 00000000..730922e2 --- /dev/null +++ b/internal/integration_tests/rest_openapi_test.go @@ -0,0 +1,46 @@ +package integration_tests + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/authorizerdev/authorizer/gen/openapi" +) + +// TestOpenAPIEndpointServesValidJSON verifies the /openapi.json route +// returns the embedded swagger spec, with a body that parses as JSON and +// declares the v1 services. Guards against two regressions: +// 1. Path-based reads of the spec file would fail when cwd is not the +// repo root (Docker, tests). The embed should make this path-free. +// 2. The merged swagger is non-empty and includes recognisable v1 routes. +func TestOpenAPIEndpointServesValidJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.GET("/openapi.json", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", openapi.Spec()) + }) + + ts := httptest.NewServer(r) + t.Cleanup(ts.Close) + + resp, err := http.Get(ts.URL + "/openapi.json") + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + var doc map[string]any + require.NoError(t, json.NewDecoder(resp.Body).Decode(&doc)) + + // Sanity: swagger 2.0 doc with at least one path under /v1. + assert.Contains(t, doc, "swagger") + paths, ok := doc["paths"].(map[string]any) + require.True(t, ok, "openapi spec missing paths object") + assert.NotEmpty(t, paths, "openapi spec should declare at least one path") +} diff --git a/internal/mcp/scanner.go b/internal/mcp/scanner.go index 994b6122..4b461d33 100644 --- a/internal/mcp/scanner.go +++ b/internal/mcp/scanner.go @@ -82,29 +82,19 @@ func Scan(srv *grpc.Server) ([]ToolBinding, error) { } // mcpToolFromMethod reads the (authorizer.common.v1.mcp_tool) option off a -// method descriptor. Returns nil when the option is absent. +// method descriptor. Returns nil when the option is absent or unset. func mcpToolFromMethod(m protoreflect.MethodDescriptor) *commonv1.McpTool { - opts, ok := m.Options().(*descriptorOptionsCarrier) - _ = opts - _ = ok - // proto.GetExtension panics if Options is nil; guard explicitly. - mo := m.Options() - if mo == nil { + opts := m.Options() + if opts == nil { return nil } - v := proto.GetExtension(mo, commonv1.E_McpTool) - t, ok := v.(*commonv1.McpTool) + t, ok := proto.GetExtension(opts, commonv1.E_McpTool).(*commonv1.McpTool) if !ok || t == nil { return nil } return t } -// descriptorOptionsCarrier is a workaround alias so we can take the address -// of an interface result without violating addressability rules in callers. -// Kept unexported; only the type identity matters. -type descriptorOptionsCarrier struct{ proto.Message } - // camelToSnake converts MixedCase / camelCase to snake_case. ASCII only; // proto method names never contain non-ASCII. func camelToSnake(s string) string { diff --git a/internal/mcp/schema.go b/internal/mcp/schema.go index 811adc0e..7246f7c2 100644 --- a/internal/mcp/schema.go +++ b/internal/mcp/schema.go @@ -20,6 +20,25 @@ type jsonSchema struct { // descriptor. Field naming uses the proto field name (snake_case), matching // the gateway's UseProtoNames=true configuration. func schemaForMessage(md protoreflect.MessageDescriptor) jsonSchema { + return schemaForMessageWithVisited(md, map[protoreflect.FullName]struct{}{}) +} + +// schemaForMessageWithVisited recurses into nested message fields while +// guarding against cycles. The descriptor full-name is the visit key — +// well-known types like google.protobuf.Value reference themselves via +// repeated-Value lists, which would stack-overflow without this. +// +// On a re-visit we emit an opaque `object` rather than the full schema, +// which is the most honest thing to tell an MCP host about a self-recursive +// type (it can pass any JSON object; the server validates at the proto +// layer via protovalidate). +func schemaForMessageWithVisited(md protoreflect.MessageDescriptor, visited map[protoreflect.FullName]struct{}) jsonSchema { + if _, seen := visited[md.FullName()]; seen { + return jsonSchema{Type: "object"} + } + visited[md.FullName()] = struct{}{} + defer delete(visited, md.FullName()) + root := jsonSchema{ Type: "object", Properties: map[string]jsonSchema{}, @@ -27,24 +46,24 @@ func schemaForMessage(md protoreflect.MessageDescriptor) jsonSchema { fields := md.Fields() for i := 0; i < fields.Len(); i++ { f := fields.Get(i) - root.Properties[string(f.Name())] = schemaForField(f) + root.Properties[string(f.Name())] = schemaForField(f, visited) } return root } -func schemaForField(f protoreflect.FieldDescriptor) jsonSchema { +func schemaForField(f protoreflect.FieldDescriptor, visited map[protoreflect.FullName]struct{}) jsonSchema { // repeated → JSON array if f.IsList() { - item := schemaForKind(f) + item := schemaForKind(f, visited) return jsonSchema{Type: "array", Items: &item} } if f.IsMap() { return jsonSchema{Type: "object"} } - return schemaForKind(f) + return schemaForKind(f, visited) } -func schemaForKind(f protoreflect.FieldDescriptor) jsonSchema { +func schemaForKind(f protoreflect.FieldDescriptor, visited map[protoreflect.FullName]struct{}) jsonSchema { switch f.Kind() { case protoreflect.BoolKind: return jsonSchema{Type: "boolean"} @@ -61,7 +80,7 @@ func schemaForKind(f protoreflect.FieldDescriptor) jsonSchema { case protoreflect.EnumKind: return jsonSchema{Type: "string"} case protoreflect.MessageKind, protoreflect.GroupKind: - return schemaForMessage(f.Message()) + return schemaForMessageWithVisited(f.Message(), visited) default: return jsonSchema{Type: "string"} } diff --git a/internal/mcp/schema_test.go b/internal/mcp/schema_test.go new file mode 100644 index 00000000..4571713a --- /dev/null +++ b/internal/mcp/schema_test.go @@ -0,0 +1,97 @@ +package mcp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/reflect/protoreflect" + + metav1 "github.com/authorizerdev/authorizer/gen/go/authorizer/meta/v1" + sessionv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/session/v1" + userv1 "github.com/authorizerdev/authorizer/gen/go/authorizer/user/v1" +) + +// TestSchemaForMessage_FlatScalars covers the most common case: a request +// message with only scalar fields. CreateUserRequest is a good representative +// — string / repeated string / bool / message-typed (AppData). +func TestSchemaForMessage_FlatScalars(t *testing.T) { + md := (&userv1.CreateUserRequest{}).ProtoReflect().Descriptor() + s := schemaForMessage(md) + + assert.Equal(t, "object", s.Type) + require.NotNil(t, s.Properties) + + assert.Equal(t, "string", s.Properties["email"].Type) + assert.Equal(t, "string", s.Properties["password"].Type) + assert.Equal(t, "boolean", s.Properties["is_multi_factor_auth_enabled"].Type) + + // repeated string → array of strings + roles := s.Properties["roles"] + require.Equal(t, "array", roles.Type) + require.NotNil(t, roles.Items) + assert.Equal(t, "string", roles.Items.Type) + + // Nested message field (AppData) — recurses into its sub-schema. + app := s.Properties["app_data"] + assert.Equal(t, "object", app.Type) +} + +// TestSchemaForMessage_EmptyRequest — the GetMetaRequest type has no fields. +func TestSchemaForMessage_EmptyRequest(t *testing.T) { + md := (&metav1.GetMetaRequest{}).ProtoReflect().Descriptor() + s := schemaForMessage(md) + assert.Equal(t, "object", s.Type) + assert.Empty(t, s.Properties) +} + +// TestSchemaForMessage_OneOfFieldsSurfaceIndividually documents current +// behaviour: oneof fields render as separately-optional properties rather +// than as a JSON-Schema oneOf constraint. This is a known limitation that +// MCP hosts will treat as "any one of these may be set"; documenting it +// here so future contributors know to add real oneOf support intentionally +// rather than accidentally inheriting today's shape. +func TestSchemaForMessage_OneOfFieldsSurfaceIndividually(t *testing.T) { + md := (&sessionv1.CreateSessionRequest{}).ProtoReflect().Descriptor() + s := schemaForMessage(md) + // Each grant arm is a separate property in the current schema. + assert.Contains(t, s.Properties, "password") + assert.Contains(t, s.Properties, "otp") + assert.Contains(t, s.Properties, "magic_link") + assert.Contains(t, s.Properties, "refresh_token") + // roles + scope + state still surface. + assert.Contains(t, s.Properties, "roles") +} + +// TestSchemaForMessage_CycleSafe — google.protobuf.Value references itself +// via repeated Value (ListValue.values). Before the cycle-guard fix, exposing +// any tool whose request includes a Struct or Value field would stack-overflow +// at boot. The visited-set short-circuits and emits an opaque `object`. +func TestSchemaForMessage_CycleSafe(t *testing.T) { + // commonv1.AppData wraps google.protobuf.Struct, which contains a + // map, where Value can hold a ListValue of more Values. + // That's the exact recursion the reviewer flagged as a boot-time crash. + app := (&userv1.CreateUserRequest{}).ProtoReflect().Descriptor().Fields().ByName("app_data") + require.NotNil(t, app) + + schema := schemaForField(app, map[protoreflect.FullName]struct{}{}) + // Doesn't panic / overflow. The deeply-nested Value type collapses to + // an opaque object once the cycle is detected. + assert.Equal(t, "object", schema.Type) +} + +// TestSchemaForKind_IntegerFamily walks all int-typed proto kinds and makes +// sure every one maps to JSON Schema "integer" (rather than "number" or +// "string"), since MCP hosts validate against this. +func TestSchemaForKind_IntegerFamily(t *testing.T) { + // Use any message with int64 fields; pagination/v1 carries a few. + type sample struct { + field string + want string + } + + md := (&userv1.GetUserRequest{}).ProtoReflect().Descriptor() + s := schemaForMessage(md) + // `name` is a string field; sanity-check it. + assert.Equal(t, "string", s.Properties["name"].Type) +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index a1bcb9a0..9bd5c0c9 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net" + "strings" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/rs/zerolog" @@ -108,28 +109,35 @@ func registerTool(log *zerolog.Logger, srv *mcp.Server, conn *grpc.ClientConn, b srv.AddTool(tool, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Build a dynamic proto.Message for the request, then unmarshal JSON. reqMsg := dynamicpb.NewMessage(b.InputDescriptor) - if len(req.Params.Arguments) > 0 && string(req.Params.Arguments) != "null" { + if len(req.Params.Arguments) > 0 && !isJSONNull(req.Params.Arguments) { if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(req.Params.Arguments, reqMsg); err != nil { - return nil, fmt.Errorf("decode arguments: %w", err) + // Argument decode failures surface as tool errors (not + // protocol errors) so the LLM gets actionable text. + return errorResult("invalid arguments: " + err.Error()), nil } } respMsg := dynamicpb.NewMessage(b.OutputDescriptor) if err := conn.Invoke(ctx, b.FullMethod, reqMsg, respMsg); err != nil { log.Debug().Err(err).Str("tool", b.Name).Str("method", b.FullMethod).Msg("MCP tool invocation failed") - return nil, err + // gRPC errors (Unimplemented, PermissionDenied, NotFound, ...) + // become CallToolResult{IsError: true} with the gRPC status + // message as the content. The MCP host shows this to the LLM + // in a way that lets it react / try a different tool, rather + // than a low-level JSON-RPC failure that would just abort. + return errorResult(err.Error()), nil } respJSON, err := (protojson.MarshalOptions{UseProtoNames: true, EmitUnpopulated: true}).Marshal(respMsg) if err != nil { - return nil, fmt.Errorf("encode response: %w", err) + return errorResult("encode response: " + err.Error()), nil } // Surface as both Content (text-shaped) and StructuredContent so MCP // clients that prefer either get something they can consume. var structured any _ = json.Unmarshal(respJSON, &structured) return &mcp.CallToolResult{ - Content: []mcp.Content{&mcp.TextContent{Text: string(respJSON)}}, + Content: []mcp.Content{&mcp.TextContent{Text: string(respJSON)}}, StructuredContent: structured, }, nil }) @@ -137,5 +145,23 @@ func registerTool(log *zerolog.Logger, srv *mcp.Server, conn *grpc.ClientConn, b func ptrTrue() *bool { v := true; return &v } +// errorResult wraps a message as a CallToolResult with IsError set. This is +// the MCP-spec way to tell the host that the tool *ran* but produced a +// recoverable error (vs the JSON-RPC-level error path which signals a +// protocol/transport failure). +func errorResult(msg string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{&mcp.TextContent{Text: msg}}, + } +} + +// isJSONNull returns true when the raw JSON encodes a literal `null`, with +// any surrounding whitespace tolerated. +func isJSONNull(raw json.RawMessage) bool { + s := strings.TrimSpace(string(raw)) + return s == "null" +} + // compile-time assertion that ToolBinding messages descriptors implement what we need. var _ proto.Message = (*dynamicpb.Message)(nil) diff --git a/internal/parsers/url_test.go b/internal/parsers/url_test.go index ca71fb48..5dffd83a 100644 --- a/internal/parsers/url_test.go +++ b/internal/parsers/url_test.go @@ -1,6 +1,7 @@ package parsers import ( + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -60,3 +61,71 @@ func TestSanitizeHost(t *testing.T) { }) } } + +func TestGetHostFromRequest(t *testing.T) { + tests := []struct { + name string + headers map[string]string + host string + want string + }{ + { + name: "X-Authorizer-URL takes priority", + headers: map[string]string{ + "X-Authorizer-URL": "https://auth.example.com", + "X-Forwarded-Proto": "http", + "X-Forwarded-Host": "ignored.example.com", + }, + host: "request.example.com", + want: "https://auth.example.com", + }, + { + name: "falls back to X-Forwarded-Proto + X-Forwarded-Host", + headers: map[string]string{"X-Forwarded-Proto": "https", "X-Forwarded-Host": "edge.example.com"}, + host: "internal.example.com", + want: "https://edge.example.com", + }, + { + name: "ignores invalid X-Authorizer-URL", + headers: map[string]string{ + "X-Authorizer-URL": "user:pass@evil.example.com", + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "edge.example.com", + }, + host: "ignored", + want: "https://edge.example.com", + }, + { + name: "falls back to Request.Host", + headers: map[string]string{}, + host: "auth.example.com", + want: "http://auth.example.com", + }, + { + name: "defaults to localhost when nothing is set", + headers: map[string]string{}, + host: "", + want: "http://localhost", + }, + { + name: "rejects spoofed X-Forwarded-Host with path injection", + headers: map[string]string{"X-Forwarded-Host": "evil.example.com/path"}, + host: "auth.example.com", + want: "http://auth.example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &http.Request{Host: tt.host, Header: http.Header{}} + for k, v := range tt.headers { + r.Header.Set(k, v) + } + assert.Equal(t, tt.want, GetHostFromRequest(r)) + }) + } +} + +func TestGetAppURLFromRequest(t *testing.T) { + r := &http.Request{Host: "auth.example.com", Header: http.Header{}} + assert.Equal(t, "http://auth.example.com/app", GetAppURLFromRequest(r)) +} diff --git a/internal/server/http_routes.go b/internal/server/http_routes.go index 695885d7..29a862cf 100644 --- a/internal/server/http_routes.go +++ b/internal/server/http_routes.go @@ -4,11 +4,12 @@ import ( "encoding/json" "html/template" "net/http" - "os" "path" "strings" "github.com/gin-gonic/gin" + + "github.com/authorizerdev/authorizer/gen/openapi" ) // spaBuildCacheMiddleware sets cache headers for SPA build assets: @@ -88,16 +89,11 @@ func (s *server) NewRouter() *gin.Engine { router.Any("/v1/*path", gw) // OpenAPI spec — generated alongside the gRPC stubs by buf and - // served verbatim. Path is intentionally separate from the gateway - // mux so it doesn't fight a /v1/openapi.json gateway route (none - // exists today, but defending against future collisions is cheap). + // embedded into the binary (so it works regardless of cwd: tests, + // containers, etc.). Path is intentionally separate from the + // gateway mux so it doesn't fight a /v1/openapi.json gateway route. router.GET("/openapi.json", func(c *gin.Context) { - data, err := os.ReadFile("gen/openapi/authorizer.swagger.json") - if err != nil { - c.AbortWithStatus(http.StatusNotFound) - return - } - c.Data(http.StatusOK, "application/json", data) + c.Data(http.StatusOK, "application/json", openapi.Spec()) }) } diff --git a/internal/service/sideeffects_test.go b/internal/service/sideeffects_test.go new file mode 100644 index 00000000..9ad9b570 --- /dev/null +++ b/internal/service/sideeffects_test.go @@ -0,0 +1,104 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func TestMetaFromGin_NilSafety(t *testing.T) { + assert.Equal(t, RequestMetadata{}, MetaFromGin(nil)) + assert.Equal(t, RequestMetadata{}, MetaFromGin(&gin.Context{})) +} + +func TestMetaFromGin_ExtractsRequestSignals(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "https://auth.example.com/x", nil) + req.Host = "auth.example.com" + req.Header.Set("Authorization", "Bearer abc") + req.Header.Set("User-Agent", "AuthorizerTest/1.0") + req.Header.Set("X-Forwarded-For", "10.1.2.3") + req.Header.Set("X-Forwarded-Proto", "https") + req.AddCookie(&http.Cookie{Name: "session", Value: "s1"}) + + gc, _ := gin.CreateTestContext(httptest.NewRecorder()) + gc.Request = req + + meta := MetaFromGin(gc) + assert.Equal(t, "https://auth.example.com", meta.HostURL) + assert.Equal(t, "10.1.2.3", meta.IPAddress) + assert.Equal(t, "AuthorizerTest/1.0", meta.UserAgent) + assert.Equal(t, "Bearer abc", meta.AuthorizationHeader) + require.Len(t, meta.Cookies, 1) + assert.Equal(t, "session", meta.Cookies[0].Name) + assert.Same(t, req, meta.Request, "Request escape hatch must be the same pointer") +} + +func TestApplyToGin_WritesCookies(t *testing.T) { + w := httptest.NewRecorder() + gc, _ := gin.CreateTestContext(w) + gc.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + side := &ResponseSideEffects{} + side.AddCookie(&http.Cookie{ + Name: "authorizer_session", + Value: "abc", + MaxAge: 60, + Path: "/", + Domain: "auth.example.com", + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteNoneMode, + }) + side.AddCookie(&http.Cookie{ + Name: "authorizer_session_domain", + Value: "abc", + MaxAge: 60, + Path: "/", + Domain: ".example.com", + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteNoneMode, + }) + ApplyToGin(gc, side) + + setCookies := w.Result().Header.Values("Set-Cookie") + require.Len(t, setCookies, 2) + assert.Contains(t, setCookies[0], "authorizer_session=abc") + assert.Contains(t, setCookies[0], "Domain=auth.example.com") + assert.Contains(t, setCookies[1], "Domain=example.com") + for _, c := range setCookies { + assert.Contains(t, c, "HttpOnly") + assert.Contains(t, c, "Secure") + assert.Contains(t, c, "SameSite=None") + } +} + +func TestApplyToGin_NilSafe(t *testing.T) { + // nil receiver / nil gc must not panic. + gc, _ := gin.CreateTestContext(httptest.NewRecorder()) + ApplyToGin(gc, nil) + ApplyToGin(nil, &ResponseSideEffects{Cookies: []*http.Cookie{{Name: "x"}}}) + + // nil cookie inside slice should be skipped. + w := httptest.NewRecorder() + gc2, _ := gin.CreateTestContext(w) + gc2.Request = httptest.NewRequest(http.MethodGet, "/", nil) + ApplyToGin(gc2, &ResponseSideEffects{Cookies: []*http.Cookie{nil, {Name: "ok", Value: "v"}}}) + assert.Len(t, w.Result().Header.Values("Set-Cookie"), 1) +} + +func TestResponseSideEffects_AddCookieNilSafe(t *testing.T) { + s := &ResponseSideEffects{} + s.AddCookie(nil) + assert.Empty(t, s.Cookies) + s.AddCookie(&http.Cookie{Name: "x"}) + assert.Len(t, s.Cookies, 1) +}