Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions cmd/yardstick-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"log"
Expand All @@ -28,11 +29,33 @@ type EchoResponse struct {
var alphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
var transport string
var port int
var authHeader string
var authValue string

func validateAlphanumeric(input string) bool {
return alphanumericRegex.MatchString(input)
}

func checkAuth(r *http.Request) error {
if authHeader == "" {
return nil
}
if r.Header.Get(authHeader) != authValue {
return errors.New("unauthorized")
}
return nil
}

func authWrapper(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := checkAuth(r); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}

func echoHandler(_ context.Context, _ *mcp.CallToolRequest, params EchoRequest) (*mcp.CallToolResult, EchoResponse, error) {
if !validateAlphanumeric(params.Input) {
return &mcp.CallToolResult{
Expand Down Expand Up @@ -98,7 +121,7 @@ func main() {
}, nil)

// Mount the SSE handler at /sse - it will handle both GET (SSE stream) and POST (messages) requests
http.Handle("/sse", handler)
http.Handle("/sse", authWrapper(handler))

// Create server with timeouts to address G114 gosec issue
srv := &http.Server{
Expand All @@ -116,7 +139,7 @@ func main() {
return server
}, nil)

http.Handle("/mcp", handler)
http.Handle("/mcp", authWrapper(handler))

// Create server with timeouts to address G114 gosec issue
srv := &http.Server{
Expand Down Expand Up @@ -150,4 +173,7 @@ func parseConfig() {
port = intValue
}
}

authHeader = os.Getenv("AUTH_HEADER")
authValue = os.Getenv("AUTH_VALUE")
}
77 changes: 77 additions & 0 deletions cmd/yardstick-server/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"net/http"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"
Expand Down Expand Up @@ -111,3 +112,79 @@ func TestEchoResponseCreation(t *testing.T) {
response := EchoResponse{Output: "test123"}
assert.Equal(t, "test123", response.Output)
}

func TestCheckAuth_HeaderAuth(t *testing.T) {
// Save original values
origHeader := authHeader
origValue := authValue
defer func() {
authHeader = origHeader
authValue = origValue
}()

// Set auth config
authHeader = "X-Auth-Token"
authValue = "secret123"

// Create request with correct header
req, err := http.NewRequest(http.MethodGet, "/test", nil)
assert.NoError(t, err)
req.Header.Set("X-Auth-Token", "secret123")

// Should pass authentication
err = checkAuth(req)
assert.NoError(t, err)
}

func TestCheckAuth_HeaderAuth_Fail(t *testing.T) {
// Save original values
origHeader := authHeader
origValue := authValue
defer func() {
authHeader = origHeader
authValue = origValue
}()

// Set auth config
authHeader = "X-Auth-Token"
authValue = "secret123"

// Test with wrong header value
req, err := http.NewRequest(http.MethodGet, "/test", nil)
assert.NoError(t, err)
req.Header.Set("X-Auth-Token", "wrongvalue")

err = checkAuth(req)
assert.Error(t, err)
assert.Equal(t, "unauthorized", err.Error())

// Test with missing header
req2, err := http.NewRequest(http.MethodGet, "/test", nil)
assert.NoError(t, err)

err = checkAuth(req2)
assert.Error(t, err)
assert.Equal(t, "unauthorized", err.Error())
}

func TestCheckAuth_Disabled(t *testing.T) {
// Save original values
origHeader := authHeader
origValue := authValue
defer func() {
authHeader = origHeader
authValue = origValue
}()

// Auth disabled when authHeader is empty
authHeader = ""
authValue = ""

// Create request without any auth header
req, err := http.NewRequest(http.MethodGet, "/test", nil)
assert.NoError(t, err)

// Should pass since auth is disabled
err = checkAuth(req)
assert.NoError(t, err)
}