Skip to content

Commit 013c047

Browse files
committed
Use correct mutex
1 parent 5378d0f commit 013c047

File tree

4 files changed

+42
-43
lines changed

4 files changed

+42
-43
lines changed

client/transport/stdio_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ func compileTestServer(outputPath string) error {
3232
func TestStdio(t *testing.T) {
3333
// Compile mock server
3434
mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server")
35-
// Add .exe suffix on Windows
36-
if runtime.GOOS == "windows" {
37-
mockServerPath += ".exe"
38-
}
35+
// Add .exe suffix on Windows
36+
if runtime.GOOS == "windows" {
37+
mockServerPath += ".exe"
38+
}
3939
if err := compileTestServer(mockServerPath); err != nil {
4040
t.Fatalf("Failed to compile mock server: %v", err)
4141
}

server/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,9 +421,9 @@ func (s *MCPServer) AddResource(
421421

422422
// RemoveResource removes a resource from the server
423423
func (s *MCPServer) RemoveResource(uri string) {
424-
s.mu.Lock()
424+
s.resourcesMu.Lock()
425425
delete(s.resources, uri)
426-
s.mu.Unlock()
426+
s.resourcesMu.Unlock()
427427

428428
// Send notification to all initialized sessions if listChanged capability is enabled
429429
if s.capabilities.resources != nil && s.capabilities.resources.listChanged {

server/sse.go

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,20 @@ var _ ClientSession = (*sseSession)(nil)
5454
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
5555
// It provides real-time communication capabilities over HTTP using the SSE protocol.
5656
type SSEServer struct {
57-
server *MCPServer
58-
baseURL string
59-
basePath string
60-
useFullURLForMessageEndpoint bool
61-
messageEndpoint string
62-
sseEndpoint string
63-
sessions sync.Map
64-
srv *http.Server
65-
contextFunc SSEContextFunc
57+
server *MCPServer
58+
baseURL string
59+
basePath string
60+
useFullURLForMessageEndpoint bool
61+
messageEndpoint string
62+
sseEndpoint string
63+
sessions sync.Map
64+
srv *http.Server
65+
contextFunc SSEContextFunc
6666

6767
keepAlive bool
6868
keepAliveInterval time.Duration
69-
70-
mu sync.RWMutex
69+
70+
mu sync.RWMutex
7171
}
7272

7373
// SSEOption defines a function type for configuring SSEServer
@@ -161,12 +161,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
161161
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
162162
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
163163
s := &SSEServer{
164-
server: server,
165-
sseEndpoint: "/sse",
166-
messageEndpoint: "/message",
167-
useFullURLForMessageEndpoint: true,
168-
keepAlive: false,
169-
keepAliveInterval: 10 * time.Second,
164+
server: server,
165+
sseEndpoint: "/sse",
166+
messageEndpoint: "/message",
167+
useFullURLForMessageEndpoint: true,
168+
keepAlive: false,
169+
keepAliveInterval: 10 * time.Second,
170170
}
171171

172172
// Apply all options
@@ -310,7 +310,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
310310
}()
311311
}
312312

313-
314313
// Send the initial endpoint event
315314
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID))
316315
flusher.Flush()

server/sse_test.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import (
55
"bytes"
66
"context"
77
"encoding/json"
8-
"io"
98
"fmt"
9+
"io"
1010
"math/rand"
1111
"net/http"
1212
"net/http/httptest"
@@ -744,7 +744,7 @@ func TestSSEServer(t *testing.T) {
744744

745745
t.Run("Client receives and can respond to ping messages", func(t *testing.T) {
746746
mcpServer := NewMCPServer("test", "1.0.0")
747-
testServer := NewTestServer(mcpServer,
747+
testServer := NewTestServer(mcpServer,
748748
WithKeepAlive(true),
749749
WithKeepAliveInterval(50*time.Millisecond),
750750
)
@@ -757,73 +757,73 @@ func TestSSEServer(t *testing.T) {
757757
defer sseResp.Body.Close()
758758

759759
reader := bufio.NewReader(sseResp.Body)
760-
760+
761761
var messageURL string
762762
var pingID float64
763-
763+
764764
for {
765765
line, err := reader.ReadString('\n')
766766
if err != nil {
767767
t.Fatalf("Failed to read SSE event: %v", err)
768768
}
769-
769+
770770
if strings.HasPrefix(line, "event: endpoint") {
771771
dataLine, err := reader.ReadString('\n')
772772
if err != nil {
773773
t.Fatalf("Failed to read endpoint data: %v", err)
774774
}
775775
messageURL = strings.TrimSpace(strings.TrimPrefix(dataLine, "data: "))
776-
776+
777777
_, err = reader.ReadString('\n')
778778
if err != nil {
779779
t.Fatalf("Failed to read blank line: %v", err)
780780
}
781781
}
782-
782+
783783
if strings.HasPrefix(line, "event: message") {
784784
dataLine, err := reader.ReadString('\n')
785785
if err != nil {
786786
t.Fatalf("Failed to read message data: %v", err)
787787
}
788-
788+
789789
pingData := strings.TrimSpace(strings.TrimPrefix(dataLine, "data:"))
790790
var pingMsg mcp.JSONRPCRequest
791791
if err := json.Unmarshal([]byte(pingData), &pingMsg); err != nil {
792792
t.Fatalf("Failed to parse ping message: %v", err)
793793
}
794-
794+
795795
if pingMsg.Method == "ping" {
796796
pingID = pingMsg.ID.(float64)
797797
t.Logf("Received ping with ID: %f", pingID)
798798
break // We got the ping, exit the loop
799799
}
800-
800+
801801
_, err = reader.ReadString('\n')
802802
if err != nil {
803803
t.Fatalf("Failed to read blank line: %v", err)
804804
}
805805
}
806-
806+
807807
if messageURL != "" && pingID != 0 {
808808
break
809809
}
810810
}
811-
811+
812812
if messageURL == "" {
813813
t.Fatal("Did not receive message endpoint URL")
814814
}
815-
815+
816816
pingResponse := map[string]any{
817817
"jsonrpc": "2.0",
818818
"id": pingID,
819819
"result": map[string]any{},
820820
}
821-
821+
822822
requestBody, err := json.Marshal(pingResponse)
823823
if err != nil {
824824
t.Fatalf("Failed to marshal ping response: %v", err)
825825
}
826-
826+
827827
resp, err := http.Post(
828828
messageURL,
829829
"application/json",
@@ -833,22 +833,22 @@ func TestSSEServer(t *testing.T) {
833833
t.Fatalf("Failed to send ping response: %v", err)
834834
}
835835
defer resp.Body.Close()
836-
836+
837837
if resp.StatusCode != http.StatusAccepted {
838838
t.Errorf("Expected status 202 for ping response, got %d", resp.StatusCode)
839839
}
840-
840+
841841
body, err := io.ReadAll(resp.Body)
842842
if err != nil {
843843
t.Fatalf("Failed to read response body: %v", err)
844844
}
845-
845+
846846
if len(body) > 0 {
847847
var response map[string]any
848848
if err := json.Unmarshal(body, &response); err != nil {
849849
t.Fatalf("Failed to parse response body: %v", err)
850850
}
851-
851+
852852
if response["error"] != nil {
853853
t.Errorf("Expected no error in response, got %v", response["error"])
854854
}

0 commit comments

Comments
 (0)