Skip to content

feat: client-side streamable-http transport supports continuously listening #317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 194 additions & 107 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,24 @@ import (
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/util"
)

type StreamableHTTPCOption func(*StreamableHTTP)

// WithContinuousListening enables receiving server-to-client notifications when no request is in flight.
// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification),
// you should enable this option.
//
// It will establish a standalone long-live GET HTTP connection to the server.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
// NOTICE: Even enabled, the server may not support this feature.
func WithContinuousListening() StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.getListeningEnabled = true
}
}

func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.headers = headers
Expand All @@ -40,6 +54,12 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
}
}

func WithLogger(logger util.Logger) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
sc.logger = logger
}
}

// WithOAuth enables OAuth authentication for the client.
func WithOAuth(config OAuthConfig) StreamableHTTPCOption {
return func(sc *StreamableHTTP) {
Expand All @@ -57,19 +77,22 @@ func WithOAuth(config OAuthConfig) StreamableHTTPCOption {
//
// The current implementation does not support the following features:
// - batching
// - continuously listening for server notifications when no request is in flight
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
// - resuming stream
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
// - server -> client request
type StreamableHTTP struct {
serverURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
serverURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
logger util.Logger
getListeningEnabled bool

sessionID atomic.Value // string

initialized chan struct{}
initializedOnce sync.Once

notificationHandler func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex

Expand All @@ -88,10 +111,12 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str
}

smc := &StreamableHTTP{
serverURL: parsedURL,
httpClient: &http.Client{},
headers: make(map[string]string),
closed: make(chan struct{}),
serverURL: parsedURL,
httpClient: &http.Client{},
headers: make(map[string]string),
closed: make(chan struct{}),
logger: util.DefaultLogger(),
initialized: make(chan struct{}),
}
smc.sessionID.Store("") // set initial value to simplify later usage

Expand All @@ -111,7 +136,18 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str

// Start initiates the HTTP connection to the server.
func (c *StreamableHTTP) Start(ctx context.Context) error {
// For Streamable HTTP, we don't need to establish a persistent connection
// For Streamable HTTP, we don't need to establish a persistent connection by default
if c.getListeningEnabled {
go func() {
select {
case <-c.initialized:
c.listenForever()
case <-c.closed:
return
}
}()
}

return nil
}

Expand Down Expand Up @@ -178,77 +214,26 @@ func (c *StreamableHTTP) SendRequest(
request JSONRPCRequest,
) (*JSONRPCResponse, error) {

// Create a combined context that could be canceled when the client is closed
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-c.closed:
cancel()
case <-newCtx.Done():
// The original context was canceled, no need to do anything
}
}()
ctx = newCtx

// Marshal request
requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
sessionID := c.sessionID.Load()
if sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID.(string))
}
for k, v := range c.headers {
req.Header.Set(k, v)
}

// Add OAuth authorization if configured
if c.oauthHandler != nil {
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
if err != nil {
// If we get an authorization error, return a specific error that can be handled by the client
if err.Error() == "no valid token available, authorization required" {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
}
}
return nil, fmt.Errorf("failed to get authorization header: %w", err)
}
req.Header.Set("Authorization", authHeader)
}

if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
if errors.Is(err, errSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
// If the request is initialize, should not return a SessionTerminated error
// It should be a genuine endpoint-routing issue.
// ( Fall through to return StatusCode checking. )
} else {
return nil, fmt.Errorf("failed to send request: %w", err)
}
}

// Send request
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()

// Check if we got an error response
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
// handle session closed
if resp.StatusCode == http.StatusNotFound {
c.sessionID.CompareAndSwap(sessionID, "")
return nil, fmt.Errorf("session terminated (404). need to re-initialize")
}

// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
Expand All @@ -272,6 +257,10 @@ func (c *StreamableHTTP) SendRequest(
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
c.sessionID.Store(sessionID)
}

c.initializedOnce.Do(func() {
close(c.initialized)
})
}

// Handle different response types
Expand Down Expand Up @@ -300,6 +289,78 @@ func (c *StreamableHTTP) SendRequest(
}
}

func (c *StreamableHTTP) sendHTTP(
ctx context.Context,
method string,
body io.Reader,
acceptType string,
) (resp *http.Response, err error) {
// Create a combined context that could be canceled when the client is closed
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-c.closed:
cancel()
case <-newCtx.Done():
// The original context was canceled, no need to do anything
}
}()
ctx = newCtx

// Create HTTP request
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", acceptType)
sessionID := c.sessionID.Load().(string)
if sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID)
}
for k, v := range c.headers {
req.Header.Set(k, v)
}

// Add OAuth authorization if configured
if c.oauthHandler != nil {
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
if err != nil {
// If we get an authorization error, return a specific error that can be handled by the client
if err.Error() == "no valid token available, authorization required" {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
}
}
return nil, fmt.Errorf("failed to get authorization header: %w", err)
}
req.Header.Set("Authorization", authHeader)
}

if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
}
}

// Send request
resp, err = c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}

// universal handling for session terminated
if resp.StatusCode == http.StatusNotFound {
c.sessionID.CompareAndSwap(sessionID, "")
return nil, errSessionTerminated
}

return resp, nil
}

// handleSSEResponse processes an SSE stream for a specific request.
// It returns the final result for the request once received, or an error.
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When readSSE#ReadString occur error, reader will be closed and won't receive notification anymore, but handleSSEResponse will wait ctx.Done(). Now, client will ignore all the notification, am I right?
image

image

Expand Down Expand Up @@ -417,44 +478,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
}

// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}

// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
if sessionID := c.sessionID.Load(); sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID.(string))
}
for k, v := range c.headers {
req.Header.Set(k, v)
}

// Add OAuth authorization if configured
if c.oauthHandler != nil {
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
if err != nil {
// If we get an authorization error, return a specific error that can be handled by the client
if errors.Is(err, ErrOAuthAuthorizationRequired) {
return &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
}
}
return fmt.Errorf("failed to get authorization header: %w", err)
}
req.Header.Set("Authorization", authHeader)
}

if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
}
}

// Send request
resp, err := c.httpClient.Do(req)
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
Expand Down Expand Up @@ -498,3 +522,66 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
func (c *StreamableHTTP) IsOAuthEnabled() bool {
return c.oauthHandler != nil
}

func (c *StreamableHTTP) listenForever() {
c.logger.Infof("listening to server forever")
for {
err := c.createGETConnectionToServer()
if errors.Is(err, errGetMethodNotAllowed) {
// server does not support listening
c.logger.Errorf("server does not support listening")
return
}

select {
case <-c.closed:
return
default:
}

if err != nil {
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
}
time.Sleep(retryInterval)
}
}

var (
errSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize")
errGetMethodNotAllowed = fmt.Errorf("GET method not allowed")

retryInterval = 1 * time.Second // a variable is convenient for testing
)

func (c *StreamableHTTP) createGETConnectionToServer() error {

ctx := context.Background() // the sendHTTP will be automatically canceled when the client is closed
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()

// Check if we got an error response
if resp.StatusCode == http.StatusMethodNotAllowed {
return errGetMethodNotAllowed
}

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
}

// handle SSE response
contentType := resp.Header.Get("Content-Type")
if contentType != "text/event-stream" {
return fmt.Errorf("unexpected content type: %s", contentType)
}

_, err = c.handleSSEResponse(ctx, resp.Body)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We currently ignore the response here.
However, if you refer to the Python SDK implementation, you'll see that server responses are actively written to a read_stream—a memory stream used for receiving messages.

This read_stream is initialized during client setup, and it's the same stream shared by both the GET and POST handlers.

That said, the intended behavior here is still somewhat unclear—the MCP spec doesn't explicitly define whether GET responses must be surfaced to the client, so it's possible the current handling is valid, but worth clarifying.

if err != nil {
return fmt.Errorf("failed to handle SSE response: %w", err)
}

return nil
}
Loading