From 46801ae7fa333c855d16930f8b88411ca93545be Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 2 Dec 2025 20:15:25 +0000 Subject: [PATCH] feat: add header-based authentication for HTTP transports Adds optional token auth via AUTH_HEADER and AUTH_VALUE env vars. When configured, requests without the correct header receive 401. This is supposed to be used in tests. --- cmd/yardstick-server/main.go | 30 +++++++++++- cmd/yardstick-server/main_test.go | 77 +++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/cmd/yardstick-server/main.go b/cmd/yardstick-server/main.go index 5722db3..8bc3b33 100644 --- a/cmd/yardstick-server/main.go +++ b/cmd/yardstick-server/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "flag" "fmt" "log" @@ -28,11 +29,33 @@ type EchoResponse struct { var alphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`) var transport string var port int +var authHeader string +var authValue string func validateAlphanumeric(input string) bool { return alphanumericRegex.MatchString(input) } +func checkAuth(r *http.Request) error { + if authHeader == "" { + return nil + } + if r.Header.Get(authHeader) != authValue { + return errors.New("unauthorized") + } + return nil +} + +func authWrapper(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := checkAuth(r); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + func echoHandler(_ context.Context, _ *mcp.CallToolRequest, params EchoRequest) (*mcp.CallToolResult, EchoResponse, error) { if !validateAlphanumeric(params.Input) { return &mcp.CallToolResult{ @@ -98,7 +121,7 @@ func main() { }, nil) // Mount the SSE handler at /sse - it will handle both GET (SSE stream) and POST (messages) requests - http.Handle("/sse", handler) + http.Handle("/sse", authWrapper(handler)) // Create server with timeouts to address G114 gosec issue srv := &http.Server{ @@ -116,7 +139,7 @@ func main() { return server }, nil) - http.Handle("/mcp", handler) + http.Handle("/mcp", authWrapper(handler)) // Create server with timeouts to address G114 gosec issue srv := &http.Server{ @@ -150,4 +173,7 @@ func parseConfig() { port = intValue } } + + authHeader = os.Getenv("AUTH_HEADER") + authValue = os.Getenv("AUTH_VALUE") } diff --git a/cmd/yardstick-server/main_test.go b/cmd/yardstick-server/main_test.go index 26fb6d8..9d79f50 100644 --- a/cmd/yardstick-server/main_test.go +++ b/cmd/yardstick-server/main_test.go @@ -2,6 +2,7 @@ package main import ( "context" + "net/http" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -111,3 +112,79 @@ func TestEchoResponseCreation(t *testing.T) { response := EchoResponse{Output: "test123"} assert.Equal(t, "test123", response.Output) } + +func TestCheckAuth_HeaderAuth(t *testing.T) { + // Save original values + origHeader := authHeader + origValue := authValue + defer func() { + authHeader = origHeader + authValue = origValue + }() + + // Set auth config + authHeader = "X-Auth-Token" + authValue = "secret123" + + // Create request with correct header + req, err := http.NewRequest(http.MethodGet, "/test", nil) + assert.NoError(t, err) + req.Header.Set("X-Auth-Token", "secret123") + + // Should pass authentication + err = checkAuth(req) + assert.NoError(t, err) +} + +func TestCheckAuth_HeaderAuth_Fail(t *testing.T) { + // Save original values + origHeader := authHeader + origValue := authValue + defer func() { + authHeader = origHeader + authValue = origValue + }() + + // Set auth config + authHeader = "X-Auth-Token" + authValue = "secret123" + + // Test with wrong header value + req, err := http.NewRequest(http.MethodGet, "/test", nil) + assert.NoError(t, err) + req.Header.Set("X-Auth-Token", "wrongvalue") + + err = checkAuth(req) + assert.Error(t, err) + assert.Equal(t, "unauthorized", err.Error()) + + // Test with missing header + req2, err := http.NewRequest(http.MethodGet, "/test", nil) + assert.NoError(t, err) + + err = checkAuth(req2) + assert.Error(t, err) + assert.Equal(t, "unauthorized", err.Error()) +} + +func TestCheckAuth_Disabled(t *testing.T) { + // Save original values + origHeader := authHeader + origValue := authValue + defer func() { + authHeader = origHeader + authValue = origValue + }() + + // Auth disabled when authHeader is empty + authHeader = "" + authValue = "" + + // Create request without any auth header + req, err := http.NewRequest(http.MethodGet, "/test", nil) + assert.NoError(t, err) + + // Should pass since auth is disabled + err = checkAuth(req) + assert.NoError(t, err) +}