Skip to content
Merged
1 change: 1 addition & 0 deletions packages/types/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export * from "./experiment.js"
export * from "./global-settings.js"
export * from "./history.js"
export * from "./ipc.js"
export * from "./mcp.js"
export * from "./message.js"
export * from "./mode.js"
export * from "./model.js"
Expand Down
31 changes: 31 additions & 0 deletions packages/types/src/mcp.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { z } from "zod"

/**
* McpExecutionStatus
*/

export const mcpExecutionStatusSchema = z.discriminatedUnion("status", [
z.object({
executionId: z.string(),
status: z.literal("started"),
serverName: z.string(),
toolName: z.string(),
}),
z.object({
executionId: z.string(),
status: z.literal("output"),
response: z.string(),
}),
z.object({
executionId: z.string(),
status: z.literal("completed"),
response: z.string().optional(),
}),
z.object({
executionId: z.string(),
status: z.literal("error"),
error: z.string().optional(),
}),
])

export type McpExecutionStatus = z.infer<typeof mcpExecutionStatusSchema>
269 changes: 269 additions & 0 deletions src/core/tools/__tests__/useMcpToolTool.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
import { useMcpToolTool } from "../useMcpToolTool"
import { Task } from "../../task/Task"
import { ToolUse } from "../../../shared/tools"
import { formatResponse } from "../../prompts/responses"

// Mock dependencies
jest.mock("../../prompts/responses", () => ({
formatResponse: {
toolResult: jest.fn((result: string) => `Tool result: ${result}`),
toolError: jest.fn((error: string) => `Tool error: ${error}`),
invalidMcpToolArgumentError: jest.fn((server: string, tool: string) => `Invalid args for ${server}:${tool}`),
},
}))

jest.mock("../../../i18n", () => ({
t: jest.fn((key: string, params?: any) => {
if (key === "mcp:errors.invalidJsonArgument" && params?.toolName) {
return `Roo tried to use ${params.toolName} with an invalid JSON argument. Retrying...`
}
return key
}),
}))

describe("useMcpToolTool", () => {
let mockTask: Partial<Task>
let mockAskApproval: jest.Mock
let mockHandleError: jest.Mock
let mockPushToolResult: jest.Mock
let mockRemoveClosingTag: jest.Mock
let mockProviderRef: any

beforeEach(() => {
mockAskApproval = jest.fn()
mockHandleError = jest.fn()
mockPushToolResult = jest.fn()
mockRemoveClosingTag = jest.fn((tag: string, value?: string) => value || "")

mockProviderRef = {
deref: jest.fn().mockReturnValue({
getMcpHub: jest.fn().mockReturnValue({
callTool: jest.fn(),
}),
postMessageToWebview: jest.fn(),
}),
}

mockTask = {
consecutiveMistakeCount: 0,
recordToolError: jest.fn(),
sayAndCreateMissingParamError: jest.fn(),
say: jest.fn(),
ask: jest.fn(),
lastMessageTs: 123456789,
providerRef: mockProviderRef,
}
})

describe("parameter validation", () => {
it("should handle missing server_name", async () => {
const block: ToolUse = {
type: "tool_use",
name: "use_mcp_tool",
params: {
tool_name: "test_tool",
arguments: "{}",
},
partial: false,
}

mockTask.sayAndCreateMissingParamError = jest.fn().mockResolvedValue("Missing server_name error")

await useMcpToolTool(
mockTask as Task,
block,
mockAskApproval,
mockHandleError,
mockPushToolResult,
mockRemoveClosingTag,
)

expect(mockTask.consecutiveMistakeCount).toBe(1)
expect(mockTask.recordToolError).toHaveBeenCalledWith("use_mcp_tool")
expect(mockTask.sayAndCreateMissingParamError).toHaveBeenCalledWith("use_mcp_tool", "server_name")
expect(mockPushToolResult).toHaveBeenCalledWith("Missing server_name error")
})

it("should handle missing tool_name", async () => {
const block: ToolUse = {
type: "tool_use",
name: "use_mcp_tool",
params: {
server_name: "test_server",
arguments: "{}",
},
partial: false,
}

mockTask.sayAndCreateMissingParamError = jest.fn().mockResolvedValue("Missing tool_name error")

await useMcpToolTool(
mockTask as Task,
block,
mockAskApproval,
mockHandleError,
mockPushToolResult,
mockRemoveClosingTag,
)

expect(mockTask.consecutiveMistakeCount).toBe(1)
expect(mockTask.recordToolError).toHaveBeenCalledWith("use_mcp_tool")
expect(mockTask.sayAndCreateMissingParamError).toHaveBeenCalledWith("use_mcp_tool", "tool_name")
expect(mockPushToolResult).toHaveBeenCalledWith("Missing tool_name error")
})

it("should handle invalid JSON arguments", async () => {
const block: ToolUse = {
type: "tool_use",
name: "use_mcp_tool",
params: {
server_name: "test_server",
tool_name: "test_tool",
arguments: "invalid json",
},
partial: false,
}

await useMcpToolTool(
mockTask as Task,
block,
mockAskApproval,
mockHandleError,
mockPushToolResult,
mockRemoveClosingTag,
)

expect(mockTask.consecutiveMistakeCount).toBe(1)
expect(mockTask.recordToolError).toHaveBeenCalledWith("use_mcp_tool")
expect(mockTask.say).toHaveBeenCalledWith("error", expect.stringContaining("invalid JSON argument"))
expect(mockPushToolResult).toHaveBeenCalledWith("Tool error: Invalid args for test_server:test_tool")
})
})

describe("partial requests", () => {
it("should handle partial requests", async () => {
const block: ToolUse = {
type: "tool_use",
name: "use_mcp_tool",
params: {
server_name: "test_server",
tool_name: "test_tool",
arguments: "{}",
},
partial: true,
}

mockTask.ask = jest.fn().mockResolvedValue(true)

await useMcpToolTool(
mockTask as Task,
block,
mockAskApproval,
mockHandleError,
mockPushToolResult,
mockRemoveClosingTag,
)

expect(mockTask.ask).toHaveBeenCalledWith("use_mcp_server", expect.stringContaining("use_mcp_tool"), true)
})
})

describe("successful execution", () => {
it("should execute tool successfully with valid parameters", async () => {
const block: ToolUse = {
type: "tool_use",
name: "use_mcp_tool",
params: {
server_name: "test_server",
tool_name: "test_tool",
arguments: '{"param": "value"}',
},
partial: false,
}

mockAskApproval.mockResolvedValue(true)

const mockToolResult = {
content: [{ type: "text", text: "Tool executed successfully" }],
isError: false,
}

mockProviderRef.deref.mockReturnValue({
getMcpHub: () => ({
callTool: jest.fn().mockResolvedValue(mockToolResult),
}),
postMessageToWebview: jest.fn(),
})

await useMcpToolTool(
mockTask as Task,
block,
mockAskApproval,
mockHandleError,
mockPushToolResult,
mockRemoveClosingTag,
)

expect(mockTask.consecutiveMistakeCount).toBe(0)
expect(mockAskApproval).toHaveBeenCalled()
expect(mockTask.say).toHaveBeenCalledWith("mcp_server_request_started")
expect(mockTask.say).toHaveBeenCalledWith("mcp_server_response", "Tool executed successfully")
expect(mockPushToolResult).toHaveBeenCalledWith("Tool result: Tool executed successfully")
})

it("should handle user rejection", async () => {
const block: ToolUse = {
type: "tool_use",
name: "use_mcp_tool",
params: {
server_name: "test_server",
tool_name: "test_tool",
arguments: "{}",
},
partial: false,
}

mockAskApproval.mockResolvedValue(false)

await useMcpToolTool(
mockTask as Task,
block,
mockAskApproval,
mockHandleError,
mockPushToolResult,
mockRemoveClosingTag,
)

expect(mockTask.say).not.toHaveBeenCalledWith("mcp_server_request_started")
expect(mockPushToolResult).not.toHaveBeenCalled()
})
})

describe("error handling", () => {
it("should handle unexpected errors", async () => {
const block: ToolUse = {
type: "tool_use",
name: "use_mcp_tool",
params: {
server_name: "test_server",
tool_name: "test_tool",
},
partial: false,
}

const error = new Error("Unexpected error")
mockAskApproval.mockRejectedValue(error)

await useMcpToolTool(
mockTask as Task,
block,
mockAskApproval,
mockHandleError,
mockPushToolResult,
mockRemoveClosingTag,
)

expect(mockHandleError).toHaveBeenCalledWith("executing MCP tool", error)
})
})
})
Loading
Loading