Skip to content

feat: client-side streamable-http transport supports continuously listening #317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 143 additions & 31 deletions client/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package client
import (
"context"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"sync"
"testing"
"time"

"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// SafeMap is a thread-safe map wrapper

func TestHTTPClient(t *testing.T) {
hooks := &server.Hooks{}
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
Expand Down Expand Up @@ -47,30 +52,46 @@ func TestHTTPClient(t *testing.T) {
return nil, fmt.Errorf("failed to send notification: %w", err)
}

return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: "notification sent successfully",
},
},
}, nil
return mcp.NewToolResultText("notification sent successfully"), nil
},
)

addServerToolfunc := func(name string) {
mcpServer.AddTool(
mcp.NewTool(name),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
server := server.ServerFromContext(ctx)
server.SendNotificationToAllClients("helloToEveryone", map[string]any{
"message": "hello",
})
return mcp.NewToolResultText("done"), nil
},
)
}

testServer := server.NewTestStreamableHTTPServer(mcpServer)
defer testServer.Close()

initRequest := mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "test-client2",
Version: "1.0.0",
},
},
}

t.Run("Can receive notification from server", func(t *testing.T) {
client, err := NewStreamableHttpClient(testServer.URL)
if err != nil {
t.Fatalf("create client failed %v", err)
return
}

notificationNum := 0
notificationNum := NewSafeMap()
client.OnNotification(func(notification mcp.JSONRPCNotification) {
notificationNum += 1
notificationNum.Increment(notification.Method)
})

ctx := context.Background()
Expand All @@ -81,31 +102,122 @@ func TestHTTPClient(t *testing.T) {
}

// Initialize
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
}

_, err = client.Initialize(ctx, initRequest)
if err != nil {
t.Fatalf("Failed to initialize: %v\n", err)
}

request := mcp.CallToolRequest{}
request.Params.Name = "notify"
result, err := client.CallTool(ctx, request)
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}
t.Run("Can receive notifications related to the request", func(t *testing.T) {
request := mcp.CallToolRequest{}
request.Params.Name = "notify"
result, err := client.CallTool(ctx, request)
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}

if len(result.Content) != 1 {
t.Errorf("Expected 1 content item, got %d", len(result.Content))
}
if len(result.Content) != 1 {
t.Errorf("Expected 1 content item, got %d", len(result.Content))
}

if n := notificationNum.Get("notifications/progress"); n != 1 {
t.Errorf("Expected 1 progross notification item, got %d", n)
}
if n := notificationNum.Len(); n != 1 {
t.Errorf("Expected 1 type of notification, got %d", n)
}
})

t.Run("Can not receive global notifications from server by default", func(t *testing.T) {
addServerToolfunc("hello1")
time.Sleep(time.Millisecond * 50)

helloNotifications := notificationNum.Get("hello1")
if helloNotifications != 0 {
t.Errorf("Expected 0 notification item, got %d", helloNotifications)
}
})

t.Run("Can receive global notifications from server when WithContinuousListening enabled", func(t *testing.T) {

client, err := NewStreamableHttpClient(testServer.URL,
transport.WithContinuousListening())
if err != nil {
t.Fatalf("create client failed %v", err)
return
}
defer client.Close()

notificationNum := NewSafeMap()
client.OnNotification(func(notification mcp.JSONRPCNotification) {
notificationNum.Increment(notification.Method)
})

ctx := context.Background()

if err := client.Start(ctx); err != nil {
t.Fatalf("Failed to start client: %v", err)
return
}

// Initialize
_, err = client.Initialize(ctx, initRequest)
if err != nil {
t.Fatalf("Failed to initialize: %v\n", err)
}

// can receive normal notification
request := mcp.CallToolRequest{}
request.Params.Name = "notify"
_, err = client.CallTool(ctx, request)
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}

if n := notificationNum.Get("notifications/progress"); n != 1 {
t.Errorf("Expected 1 progross notification item, got %d", n)
}
if n := notificationNum.Len(); n != 1 {
t.Errorf("Expected 1 type of notification, got %d", n)
}

// can receive global notification
addServerToolfunc("hello2")
time.Sleep(time.Millisecond * 50) // wait for the notification to be sent as upper action is async

n := notificationNum.Get("notifications/tools/list_changed")
if n != 1 {
t.Errorf("Expected 1 notification item, got %d, %v", n, notificationNum)
}
})

if notificationNum != 1 {
t.Errorf("Expected 1 notification item, got %d", notificationNum)
}
})
}

type SafeMap struct {
mu sync.RWMutex
data map[string]int
}

func NewSafeMap() *SafeMap {
return &SafeMap{
data: make(map[string]int),
}
}

func (sm *SafeMap) Increment(key string) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.data[key]++
}

func (sm *SafeMap) Get(key string) int {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.data[key]
}

func (sm *SafeMap) Len() int {
sm.mu.RLock()
defer sm.mu.RUnlock()
return len(sm.data)
}
Loading
Loading