diff --git a/server/sse.go b/server/sse.go index 3f171a55..41699573 100644 --- a/server/sse.go +++ b/server/sse.go @@ -227,7 +227,9 @@ func WithSSEEndpoint(endpoint string) SSEOption { } } -// WithHTTPServer sets the HTTP server instance +// WithHTTPServer sets the HTTP server instance. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. func WithHTTPServer(srv *http.Server) SSEOption { return func(s *SSEServer) { s.srv = srv diff --git a/server/streamable_http.go b/server/streamable_http.go index b13577a8..8a989cb0 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -73,6 +73,15 @@ func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption { } } +// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. +func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.httpServer = srv + } +} + // WithLogger sets the logger for the server func WithLogger(logger util.Logger) StreamableHTTPOption { return func(s *StreamableHTTPServer) { @@ -155,15 +164,24 @@ func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) // s.Start(":8080") func (s *StreamableHTTPServer) Start(addr string) error { s.mu.Lock() - mux := http.NewServeMux() - mux.Handle(s.endpointPath, s) - s.httpServer = &http.Server{ - Addr: addr, - Handler: mux, + if s.httpServer == nil { + mux := http.NewServeMux() + mux.Handle(s.endpointPath, s) + s.httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + } else { + if s.httpServer.Addr == "" { + s.httpServer.Addr = addr + } else if s.httpServer.Addr != addr { + return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr) + } } + srv := s.httpServer s.mu.Unlock() - return s.httpServer.ListenAndServe() + return srv.ListenAndServe() } // Shutdown gracefully stops the server, closing all active sessions diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 9f48eade..7474464f 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -670,6 +670,56 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { }) } +func TestStreamableHTTPServer_WithOptions(t *testing.T) { + t.Run("WithStreamableHTTPServer sets httpServer field", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + customServer := &http.Server{Addr: ":9999"} + httpServer := NewStreamableHTTPServer(mcpServer, WithStreamableHTTPServer(customServer)) + + if httpServer.httpServer != customServer { + t.Errorf("Expected httpServer to be set to custom server instance, got %v", httpServer.httpServer) + } + }) + + t.Run("Start with conflicting address returns error", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + customServer := &http.Server{Addr: ":9999"} + httpServer := NewStreamableHTTPServer(mcpServer, WithStreamableHTTPServer(customServer)) + + err := httpServer.Start(":8888") + if err == nil { + t.Error("Expected error for conflicting address, got nil") + } else if !strings.Contains(err.Error(), "conflicting listen address") { + t.Errorf("Expected error message to contain 'conflicting listen address', got '%s'", err.Error()) + } + }) + + t.Run("Options consistency test", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + endpointPath := "/test-mcp" + customServer := &http.Server{} + + // Options to test + options := []StreamableHTTPOption{ + WithEndpointPath(endpointPath), + WithStreamableHTTPServer(customServer), + } + + // Apply options multiple times and verify consistency + for i := 0; i < 10; i++ { + server := NewStreamableHTTPServer(mcpServer, options...) + + if server.endpointPath != endpointPath { + t.Errorf("Expected endpointPath %s, got %s", endpointPath, server.endpointPath) + } + + if server.httpServer != customServer { + t.Errorf("Expected httpServer to match, got %v", server.httpServer) + } + } + }) +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))