Skip to content

Commit b8dc82d

Browse files
feat: Tool Handler Middleware (#123)
* feat: add tool handler middleware capability * docs: add WithRecovery middleware to the calculator mcp server example * docs: add tool handler middleware section to the readme
1 parent 6b923f6 commit b8dc82d

File tree

3 files changed

+87
-15
lines changed

3 files changed

+87
-15
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ func main() {
122122
"1.0.0",
123123
server.WithResourceCapabilities(true, true),
124124
server.WithLogging(),
125+
server.WithRecovery(),
125126
)
126127

127128
// Add a calculator tool
@@ -522,6 +523,12 @@ initialization.
522523
Add the `Hooks` to the server at the time of creation using the
523524
`server.WithHooks` option.
524525

526+
### Tool Handler Middleware
527+
528+
Add middleware to tool call handlers using the `server.WithToolHandlerMiddleware` option. Middlewares can be registered on server creation and are applied on every tool call.
529+
530+
A recovery middleware option is available to recover from panics in a tool call and can be added to the server with the `server.WithRecovery` option.
531+
525532
## Contributing
526533

527534
<details>

server/server.go

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (
4141
// ToolHandlerFunc handles tool calls with given arguments.
4242
type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
4343

44+
// ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc.
45+
type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc
46+
4447
// ServerTool combines a Tool with its ToolHandlerFunc.
4548
type ServerTool struct {
4649
Tool mcp.Tool
@@ -138,20 +141,21 @@ type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCN
138141
// MCPServer implements a Model Control Protocol server that can handle various types of requests
139142
// including resources, prompts, and tools.
140143
type MCPServer struct {
141-
mu sync.RWMutex // Add mutex for protecting shared resources
142-
name string
143-
version string
144-
instructions string
145-
resources map[string]resourceEntry
146-
resourceTemplates map[string]resourceTemplateEntry
147-
prompts map[string]mcp.Prompt
148-
promptHandlers map[string]PromptHandlerFunc
149-
tools map[string]ServerTool
150-
notificationHandlers map[string]NotificationHandlerFunc
151-
capabilities serverCapabilities
152-
paginationLimit *int
153-
sessions sync.Map
154-
hooks *Hooks
144+
mu sync.RWMutex // Add mutex for protecting shared resources
145+
name string
146+
version string
147+
instructions string
148+
resources map[string]resourceEntry
149+
resourceTemplates map[string]resourceTemplateEntry
150+
prompts map[string]mcp.Prompt
151+
promptHandlers map[string]PromptHandlerFunc
152+
tools map[string]ServerTool
153+
toolHandlerMiddlewares []ToolHandlerMiddleware
154+
notificationHandlers map[string]NotificationHandlerFunc
155+
capabilities serverCapabilities
156+
paginationLimit *int
157+
sessions sync.Map
158+
hooks *Hooks
155159
}
156160

157161
// serverKey is the context key for storing the server instance
@@ -291,6 +295,30 @@ func WithResourceCapabilities(subscribe, listChanged bool) ServerOption {
291295
}
292296
}
293297

298+
// WithToolHandlerMiddleware allows adding a middleware for the
299+
// tool handler call chain.
300+
func WithToolHandlerMiddleware(
301+
toolHandlerMiddleware ToolHandlerMiddleware,
302+
) ServerOption {
303+
return func(s *MCPServer) {
304+
s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware)
305+
}
306+
}
307+
308+
// WithRecovery adds a middleware that recovers from panics in tool handlers.
309+
func WithRecovery() ServerOption {
310+
return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc {
311+
return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) {
312+
defer func() {
313+
if r := recover(); r != nil {
314+
err = fmt.Errorf("panic recovered in %s tool handler: %v", request.Params.Name, r)
315+
}
316+
}()
317+
return next(ctx, request)
318+
}
319+
})
320+
}
321+
294322
// WithHooks allows adding hooks that will be called before or after
295323
// either [all] requests or before / after specific request methods, or else
296324
// prior to returning an error to the client.
@@ -801,7 +829,11 @@ func (s *MCPServer) handleToolCall(
801829
}
802830
}
803831

804-
result, err := tool.Handler(ctx, request)
832+
finalHandler := tool.Handler
833+
for i := len(s.toolHandlerMiddlewares) - 1; i >= 0; i-- {
834+
finalHandler = s.toolHandlerMiddlewares[i](finalHandler)
835+
}
836+
result, err := finalHandler(ctx, request)
805837
if err != nil {
806838
return nil, &requestError{
807839
id: id,

server/server_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,3 +1345,36 @@ func TestMCPServer_WithHooks(t *testing.T) {
13451345
assert.IsType(t, afterPingData[0].msg, onSuccessData[0].msg, "OnSuccess message should be same type as AfterPing message")
13461346
assert.IsType(t, afterPingData[0].res, onSuccessData[0].res, "OnSuccess result should be same type as AfterPing result")
13471347
}
1348+
1349+
func TestMCPServer_WithRecover(t *testing.T) {
1350+
panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1351+
panic("test panic")
1352+
}
1353+
1354+
server := NewMCPServer(
1355+
"test-server",
1356+
"1.0.0",
1357+
WithRecovery(),
1358+
)
1359+
1360+
server.AddTool(
1361+
mcp.NewTool("panic-tool"),
1362+
panicToolHandler,
1363+
)
1364+
1365+
response := server.HandleMessage(context.Background(), []byte(`{
1366+
"jsonrpc": "2.0",
1367+
"id": 4,
1368+
"method": "tools/call",
1369+
"params": {
1370+
"name": "panic-tool"
1371+
}
1372+
}`))
1373+
1374+
errorResponse, ok := response.(mcp.JSONRPCError)
1375+
1376+
require.True(t, ok)
1377+
assert.Equal(t, mcp.INTERNAL_ERROR, errorResponse.Error.Code)
1378+
assert.Equal(t, "panic recovered in panic-tool tool handler: test panic", errorResponse.Error.Message)
1379+
assert.Nil(t, errorResponse.Error.Data)
1380+
}

0 commit comments

Comments
 (0)