-
Notifications
You must be signed in to change notification settings - Fork 571
feat: client-side streamable-http transport supports continuously listening #317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 12 commits
472f442
292bcea
b7cf1a3
6449b15
6a05fc6
0a1a9e9
6435882
8d3f236
c0f4403
cc540bb
42ba0ff
928d9ea
21316a0
b6ca548
4313aa1
1f5efb5
a6ad665
50f9c47
32f36b9
f8b7dce
c706c93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,10 +17,24 @@ import ( | |
"time" | ||
|
||
"github.com/mark3labs/mcp-go/mcp" | ||
"github.com/mark3labs/mcp-go/util" | ||
) | ||
|
||
type StreamableHTTPCOption func(*StreamableHTTP) | ||
|
||
// WithContinuousListening enables receiving server-to-client notifications when no request is in flight. | ||
// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification), | ||
// you should enable this option. | ||
// | ||
// It will establish a standalone long-live GET HTTP connection to the server. | ||
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server | ||
// NOTICE: Even enabled, the server may not support this feature. | ||
func WithContinuousListening() StreamableHTTPCOption { | ||
return func(sc *StreamableHTTP) { | ||
sc.getListeningEnabled = true | ||
} | ||
} | ||
|
||
func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption { | ||
return func(sc *StreamableHTTP) { | ||
sc.headers = headers | ||
|
@@ -40,6 +54,12 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { | |
} | ||
} | ||
|
||
func WithLogger(logger util.Logger) StreamableHTTPCOption { | ||
return func(sc *StreamableHTTP) { | ||
sc.logger = logger | ||
} | ||
} | ||
|
||
// WithOAuth enables OAuth authentication for the client. | ||
func WithOAuth(config OAuthConfig) StreamableHTTPCOption { | ||
return func(sc *StreamableHTTP) { | ||
|
@@ -57,19 +77,22 @@ func WithOAuth(config OAuthConfig) StreamableHTTPCOption { | |
// | ||
// The current implementation does not support the following features: | ||
// - batching | ||
// - continuously listening for server notifications when no request is in flight | ||
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) | ||
// - resuming stream | ||
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) | ||
// - server -> client request | ||
type StreamableHTTP struct { | ||
serverURL *url.URL | ||
httpClient *http.Client | ||
headers map[string]string | ||
headerFunc HTTPHeaderFunc | ||
serverURL *url.URL | ||
httpClient *http.Client | ||
headers map[string]string | ||
headerFunc HTTPHeaderFunc | ||
logger util.Logger | ||
getListeningEnabled bool | ||
|
||
sessionID atomic.Value // string | ||
|
||
initialized chan struct{} | ||
initializedOnce sync.Once | ||
|
||
notificationHandler func(mcp.JSONRPCNotification) | ||
notifyMu sync.RWMutex | ||
|
||
|
@@ -88,10 +111,12 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str | |
} | ||
|
||
smc := &StreamableHTTP{ | ||
serverURL: parsedURL, | ||
httpClient: &http.Client{}, | ||
headers: make(map[string]string), | ||
closed: make(chan struct{}), | ||
serverURL: parsedURL, | ||
httpClient: &http.Client{}, | ||
headers: make(map[string]string), | ||
closed: make(chan struct{}), | ||
logger: util.DefaultLogger(), | ||
initialized: make(chan struct{}), | ||
} | ||
smc.sessionID.Store("") // set initial value to simplify later usage | ||
|
||
|
@@ -111,7 +136,18 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str | |
|
||
// Start initiates the HTTP connection to the server. | ||
func (c *StreamableHTTP) Start(ctx context.Context) error { | ||
// For Streamable HTTP, we don't need to establish a persistent connection | ||
// For Streamable HTTP, we don't need to establish a persistent connection by default | ||
if c.getListeningEnabled { | ||
go func() { | ||
select { | ||
case <-c.initialized: | ||
c.listenForever() | ||
case <-c.closed: | ||
return | ||
} | ||
}() | ||
} | ||
|
||
return nil | ||
} | ||
|
||
|
@@ -178,77 +214,26 @@ func (c *StreamableHTTP) SendRequest( | |
request JSONRPCRequest, | ||
) (*JSONRPCResponse, error) { | ||
|
||
// Create a combined context that could be canceled when the client is closed | ||
newCtx, cancel := context.WithCancel(ctx) | ||
defer cancel() | ||
go func() { | ||
select { | ||
case <-c.closed: | ||
cancel() | ||
case <-newCtx.Done(): | ||
// The original context was canceled, no need to do anything | ||
} | ||
}() | ||
ctx = newCtx | ||
|
||
// Marshal request | ||
requestBody, err := json.Marshal(request) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to marshal request: %w", err) | ||
} | ||
|
||
// Create HTTP request | ||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) | ||
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create request: %w", err) | ||
} | ||
|
||
// Set headers | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Accept", "application/json, text/event-stream") | ||
sessionID := c.sessionID.Load() | ||
if sessionID != "" { | ||
req.Header.Set(headerKeySessionID, sessionID.(string)) | ||
} | ||
for k, v := range c.headers { | ||
req.Header.Set(k, v) | ||
} | ||
|
||
// Add OAuth authorization if configured | ||
if c.oauthHandler != nil { | ||
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) | ||
if err != nil { | ||
// If we get an authorization error, return a specific error that can be handled by the client | ||
if err.Error() == "no valid token available, authorization required" { | ||
return nil, &OAuthAuthorizationRequiredError{ | ||
Handler: c.oauthHandler, | ||
} | ||
} | ||
return nil, fmt.Errorf("failed to get authorization header: %w", err) | ||
} | ||
req.Header.Set("Authorization", authHeader) | ||
} | ||
|
||
if c.headerFunc != nil { | ||
for k, v := range c.headerFunc(ctx) { | ||
req.Header.Set(k, v) | ||
if errors.Is(err, errSessionTerminated) && request.Method == string(mcp.MethodInitialize) { | ||
// If the request is initialize, should not return a SessionTerminated error | ||
// It should be a genuine endpoint-routing issue. | ||
// ( Fall through to return StatusCode checking. ) | ||
} else { | ||
return nil, fmt.Errorf("failed to send request: %w", err) | ||
} | ||
} | ||
|
||
// Send request | ||
resp, err := c.httpClient.Do(req) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to send request: %w", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
// Check if we got an error response | ||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { | ||
// handle session closed | ||
if resp.StatusCode == http.StatusNotFound { | ||
c.sessionID.CompareAndSwap(sessionID, "") | ||
return nil, fmt.Errorf("session terminated (404). need to re-initialize") | ||
} | ||
|
||
// Handle OAuth unauthorized error | ||
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { | ||
|
@@ -272,6 +257,10 @@ func (c *StreamableHTTP) SendRequest( | |
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { | ||
c.sessionID.Store(sessionID) | ||
} | ||
|
||
c.initializedOnce.Do(func() { | ||
close(c.initialized) | ||
}) | ||
} | ||
|
||
// Handle different response types | ||
|
@@ -300,6 +289,78 @@ func (c *StreamableHTTP) SendRequest( | |
} | ||
} | ||
|
||
func (c *StreamableHTTP) sendHTTP( | ||
ctx context.Context, | ||
method string, | ||
body io.Reader, | ||
acceptType string, | ||
) (resp *http.Response, err error) { | ||
// Create a combined context that could be canceled when the client is closed | ||
newCtx, cancel := context.WithCancel(ctx) | ||
defer cancel() | ||
go func() { | ||
select { | ||
case <-c.closed: | ||
cancel() | ||
case <-newCtx.Done(): | ||
// The original context was canceled, no need to do anything | ||
} | ||
}() | ||
ctx = newCtx | ||
|
||
// Create HTTP request | ||
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create request: %w", err) | ||
} | ||
|
||
// Set headers | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Accept", acceptType) | ||
sessionID := c.sessionID.Load().(string) | ||
if sessionID != "" { | ||
req.Header.Set(headerKeySessionID, sessionID) | ||
} | ||
for k, v := range c.headers { | ||
req.Header.Set(k, v) | ||
} | ||
|
||
// Add OAuth authorization if configured | ||
if c.oauthHandler != nil { | ||
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) | ||
if err != nil { | ||
// If we get an authorization error, return a specific error that can be handled by the client | ||
if err.Error() == "no valid token available, authorization required" { | ||
return nil, &OAuthAuthorizationRequiredError{ | ||
Handler: c.oauthHandler, | ||
} | ||
} | ||
return nil, fmt.Errorf("failed to get authorization header: %w", err) | ||
} | ||
req.Header.Set("Authorization", authHeader) | ||
} | ||
|
||
if c.headerFunc != nil { | ||
for k, v := range c.headerFunc(ctx) { | ||
req.Header.Set(k, v) | ||
} | ||
} | ||
|
||
// Send request | ||
resp, err = c.httpClient.Do(req) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to send request: %w", err) | ||
} | ||
|
||
// universal handling for session terminated | ||
if resp.StatusCode == http.StatusNotFound { | ||
c.sessionID.CompareAndSwap(sessionID, "") | ||
return nil, errSessionTerminated | ||
} | ||
|
||
return resp, nil | ||
} | ||
|
||
// handleSSEResponse processes an SSE stream for a specific request. | ||
// It returns the final result for the request once received, or an error. | ||
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
@@ -417,44 +478,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. | |
} | ||
|
||
// Create HTTP request | ||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) | ||
if err != nil { | ||
return fmt.Errorf("failed to create request: %w", err) | ||
} | ||
|
||
// Set headers | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Accept", "application/json, text/event-stream") | ||
if sessionID := c.sessionID.Load(); sessionID != "" { | ||
req.Header.Set(headerKeySessionID, sessionID.(string)) | ||
} | ||
for k, v := range c.headers { | ||
req.Header.Set(k, v) | ||
} | ||
|
||
// Add OAuth authorization if configured | ||
if c.oauthHandler != nil { | ||
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) | ||
if err != nil { | ||
// If we get an authorization error, return a specific error that can be handled by the client | ||
if errors.Is(err, ErrOAuthAuthorizationRequired) { | ||
return &OAuthAuthorizationRequiredError{ | ||
Handler: c.oauthHandler, | ||
} | ||
} | ||
return fmt.Errorf("failed to get authorization header: %w", err) | ||
} | ||
req.Header.Set("Authorization", authHeader) | ||
} | ||
|
||
if c.headerFunc != nil { | ||
for k, v := range c.headerFunc(ctx) { | ||
req.Header.Set(k, v) | ||
} | ||
} | ||
|
||
// Send request | ||
resp, err := c.httpClient.Do(req) | ||
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") | ||
if err != nil { | ||
return fmt.Errorf("failed to send request: %w", err) | ||
} | ||
|
@@ -498,3 +522,66 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { | |
func (c *StreamableHTTP) IsOAuthEnabled() bool { | ||
return c.oauthHandler != nil | ||
} | ||
|
||
func (c *StreamableHTTP) listenForever() { | ||
c.logger.Infof("listening to server forever") | ||
for { | ||
err := c.createGETConnectionToServer() | ||
if errors.Is(err, errGetMethodNotAllowed) { | ||
// server does not support listening | ||
c.logger.Errorf("server does not support listening") | ||
return | ||
} | ||
|
||
select { | ||
case <-c.closed: | ||
return | ||
default: | ||
} | ||
|
||
if err != nil { | ||
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) | ||
} | ||
time.Sleep(retryInterval) | ||
} | ||
} | ||
|
||
var ( | ||
errSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize") | ||
errGetMethodNotAllowed = fmt.Errorf("GET method not allowed") | ||
|
||
retryInterval = 1 * time.Second // a variable is convenient for testing | ||
) | ||
|
||
func (c *StreamableHTTP) createGETConnectionToServer() error { | ||
|
||
ctx := context.Background() // the sendHTTP will be automatically canceled when the client is closed | ||
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") | ||
if err != nil { | ||
return fmt.Errorf("failed to send request: %w", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
// Check if we got an error response | ||
if resp.StatusCode == http.StatusMethodNotAllowed { | ||
return errGetMethodNotAllowed | ||
} | ||
|
||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { | ||
body, _ := io.ReadAll(resp.Body) | ||
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) | ||
} | ||
|
||
// handle SSE response | ||
contentType := resp.Header.Get("Content-Type") | ||
if contentType != "text/event-stream" { | ||
return fmt.Errorf("unexpected content type: %s", contentType) | ||
} | ||
|
||
_, err = c.handleSSEResponse(ctx, resp.Body) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We currently ignore the response here. This read_stream is initialized during client setup, and it's the same stream shared by both the GET and POST handlers. That said, the intended behavior here is still somewhat unclear—the MCP spec doesn't explicitly define whether GET responses must be surfaced to the client, so it's possible the current handling is valid, but worth clarifying. |
||
if err != nil { | ||
return fmt.Errorf("failed to handle SSE response: %w", err) | ||
} | ||
|
||
return nil | ||
} |
Uh oh!
There was an error while loading. Please reload this page.