From 54ca1d530c0f66a604e78f34f6bf78e6461a5aad Mon Sep 17 00:00:00 2001 From: Rafael Dantas Justo Date: Thu, 21 Aug 2025 10:06:47 -0300 Subject: [PATCH] Enhancement: Allow some protocol methods to bypass authentication This is a requirement for the Docker MCP registry. Related to: https://github.com/docker/mcp-registry/pull/164 --- cmd/mcp-http/main.go | 33 +++++++++++++- cmd/mcp-stdio/main.go | 103 ++++++++++++++++++++++++++++++++---------- 2 files changed, 111 insertions(+), 25 deletions(-) diff --git a/cmd/mcp-http/main.go b/cmd/mcp-http/main.go index 0e9c7fd..99e8a8b 100644 --- a/cmd/mcp-http/main.go +++ b/cmd/mcp-http/main.go @@ -8,6 +8,7 @@ import ( "os" "os/signal" "regexp" + "slices" "strings" "syscall" "time" @@ -173,13 +174,41 @@ func tracerMiddleware(resources config.Resources, next http.Handler) http.Handle } func authMiddleware(resources config.Resources, next http.Handler) http.Handler { + whitelistEndpoints := map[string][]string{ + // health checks don't require authentication + "/api/health": {http.MethodGet, http.MethodOptions}, + + // allow some protocol methods to bypass authentication + // + // https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#listing-tools + // https://modelcontextprotocol.io/specification/2025-06-18/server/resources#listing-resources + // https://modelcontextprotocol.io/specification/2025-06-18/server/resources#resource-templates + // https://modelcontextprotocol.io/specification/2025-06-18/server/prompts#listing-prompts + "/": {http.MethodPost}, + "/tools/list": {http.MethodPost}, + "/resources/list": {http.MethodPost}, + "/resources/templates/list": {http.MethodPost}, + "/prompts/list": {http.MethodPost}, + } + + whitelistPrefixEndpoints := map[string][]string{ + // OAuth2 endpoints cannot require authentication + "/.well-known": {"GET", "OPTIONS"}, + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // some endpoints don't require auth - if (r.URL.Path == "/api/health" || strings.HasPrefix(r.URL.Path, "/.well-known")) && - (r.Method == http.MethodGet || r.Method == http.MethodOptions) { + if methods, ok := whitelistEndpoints[r.URL.Path]; ok && slices.Contains(methods, r.Method) { next.ServeHTTP(w, r) return } + for prefix, methods := range whitelistPrefixEndpoints { + if strings.HasPrefix(r.URL.Path, prefix) && slices.Contains(methods, r.Method) { + next.ServeHTTP(w, r) + return + } + } requestLogger := resources.Logger().With( slog.String("method", r.Method), diff --git a/cmd/mcp-stdio/main.go b/cmd/mcp-stdio/main.go index bfba1ea..bc22080 100644 --- a/cmd/mcp-stdio/main.go +++ b/cmd/mcp-stdio/main.go @@ -8,6 +8,7 @@ import ( "fmt" "log/slog" "os" + "slices" "strings" "github.com/mark3labs/mcp-go/mcp" @@ -20,7 +21,23 @@ import ( ) var ( - methods = methodsInput([]toolsets.Method{toolsets.MethodAll}) + methods = methodsInput([]toolsets.Method{toolsets.MethodAll}) + methodsWhitelist = []string{ + // allow some protocol methods to bypass authentication + // + // https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#listing-tools + // https://modelcontextprotocol.io/specification/2025-06-18/server/resources#listing-resources + // https://modelcontextprotocol.io/specification/2025-06-18/server/resources#resource-templates + // https://modelcontextprotocol.io/specification/2025-06-18/server/prompts#listing-prompts + "initialize", + "notifications/initialized", + "logging/setLevel", + "tools/list", + "resources/list", + "resources/templates/list", + "prompts/list", + } readOnly bool ) @@ -34,33 +51,31 @@ func main() { flag.BoolVar(&readOnly, "read-only", false, "Restrict the server to read-only operations") flag.Parse() - if resources.Info.BearerToken == "" { - mcpError(resources, errors.New("TW_MCP_BEARER_TOKEN environment variable is not set"), mcp.INVALID_PARAMS) - exit(exitCodeSetupFailure) - } - ctx := context.Background() - // detect the installation from the bearer token - info, err := auth.GetBearerInfo(ctx, resources, resources.Info.BearerToken) - if err != nil { - mcpError(resources, fmt.Errorf("failed to authenticate: %s", err), mcp.INVALID_PARAMS) - exit(exitCodeSetupFailure) - } + if resources.Info.BearerToken != "" { + // detect the installation from the bearer token + info, err := auth.GetBearerInfo(ctx, resources, resources.Info.BearerToken) + if err != nil { + mcpError(resources.Logger(), fmt.Errorf("failed to authenticate: %s", err), mcp.INVALID_PARAMS) + exit(exitCodeSetupFailure) + } - // inject customer URL in the context - ctx = config.WithCustomerURL(ctx, info.URL) - // inject bearer token in the context - ctx = session.WithBearerTokenContext(ctx, session.NewBearerToken(resources.Info.BearerToken, info.URL)) + // inject customer URL in the context + ctx = config.WithCustomerURL(ctx, info.URL) + // inject bearer token in the context + ctx = session.WithBearerTokenContext(ctx, session.NewBearerToken(resources.Info.BearerToken, info.URL)) + } mcpServer, err := newMCPServer(resources) if err != nil { - mcpError(resources, fmt.Errorf("failed to create MCP server: %s", err), mcp.INTERNAL_ERROR) + mcpError(resources.Logger(), fmt.Errorf("failed to create MCP server: %s", err), mcp.INTERNAL_ERROR) exit(exitCodeSetupFailure) } mcpSTDIOServer := server.NewStdioServer(mcpServer) - if err := mcpSTDIOServer.Listen(ctx, os.Stdin, os.Stdout); err != nil { - mcpError(resources, fmt.Errorf("failed to serve: %s", err), mcp.INTERNAL_ERROR) + stdinWrapper := newStdinWrapper(resources.Logger(), resources.Info.BearerToken != "", methodsWhitelist) + if err := mcpSTDIOServer.Listen(ctx, stdinWrapper, os.Stdout); err != nil { + mcpError(resources.Logger(), fmt.Errorf("failed to serve: %s", err), mcp.INTERNAL_ERROR) exit(exitCodeSetupFailure) } } @@ -73,14 +88,16 @@ func newMCPServer(resources config.Resources) (*server.MCPServer, error) { return config.NewMCPServer(resources, group), nil } -func mcpError(resources config.Resources, err error, code int) { +func mcpError(logger *slog.Logger, err error, code int) { mcpError := mcp.NewJSONRPCError(mcp.NewRequestId("startup"), code, err.Error(), nil) - encoder := json.NewEncoder(os.Stdout) - if err := encoder.Encode(mcpError); err != nil { - resources.Logger().Error("failed to encode error", + encoded, err := json.Marshal(mcpError) + if err != nil { + logger.Error("failed to encode error", slog.String("error", err.Error()), ) + return } + fmt.Printf("%s\n", string(encoded)) } type methodsInput []toolsets.Method @@ -110,6 +127,46 @@ func (t *methodsInput) Set(value string) error { return errs } +type stdinWrapper struct { + logger *slog.Logger + authenticated bool + methodsWhitelist []string +} + +func newStdinWrapper(logger *slog.Logger, authenticated bool, methods []string) stdinWrapper { + return stdinWrapper{ + logger: logger, + authenticated: authenticated, + methodsWhitelist: methods, + } +} + +func (s stdinWrapper) Read(p []byte) (n int, err error) { + if s.authenticated { + return os.Stdin.Read(p) + } + buffer := make([]byte, len(p)) + n, err = os.Stdin.Read(buffer) + if err != nil { + return n, err + } + content := buffer[:n] + if len(content) == 0 { + return n, err + } + var baseMessage struct { + Method string `json:"method"` + } + if err := json.Unmarshal(content, &baseMessage); err != nil { + return 0, errors.New("parse error") + } + if !slices.Contains(s.methodsWhitelist, baseMessage.Method) { + return 0, errors.New("not authenticated") + } + copy(p, buffer) + return n, err +} + type exitCode int const (