From 3458e78233437c9afd71b9195424659e62aa08f3 Mon Sep 17 00:00:00 2001 From: Alex Andru Date: Mon, 9 Dec 2024 22:44:00 +0100 Subject: [PATCH] feat: add MCPTool abstraction --- src/core/MCPServer.ts | 16 ++++-- src/core/toolLoader.ts | 8 +-- src/index.ts | 7 ++- src/tools/BaseTool.ts | 117 +++++++++++++++++++++++++++++++++++------ 4 files changed, 123 insertions(+), 25 deletions(-) diff --git a/src/core/MCPServer.ts b/src/core/MCPServer.ts index 70423e0..33ba7f6 100644 --- a/src/core/MCPServer.ts +++ b/src/core/MCPServer.ts @@ -5,7 +5,7 @@ import { ListToolsRequestSchema, } from "@modelcontextprotocol/sdk/types.js"; import { ToolLoader } from "./toolLoader.js"; -import { BaseTool } from "../tools/BaseTool.js"; +import { ToolProtocol } from "../tools/BaseTool.js"; import { readFileSync } from "fs"; import { join, dirname } from "path"; import { logger } from "./Logger.js"; @@ -17,7 +17,7 @@ export interface MCPServerConfig { export class MCPServer { private server: Server; - private toolsMap: Map = new Map(); + private toolsMap: Map = new Map(); private toolLoader: ToolLoader; private serverName: string; private serverVersion: string; @@ -106,14 +106,22 @@ export class MCPServer { ).join(", ")}` ); } - return tool.toolCall(request); + + const toolRequest = { + params: request.params, + method: "tools/call" as const, + }; + + return tool.toolCall(toolRequest); }); } async start() { try { const tools = await this.toolLoader.loadTools(); - this.toolsMap = new Map(tools.map((tool: BaseTool) => [tool.name, tool])); + this.toolsMap = new Map( + tools.map((tool: ToolProtocol) => [tool.name, tool]) + ); const transport = new StdioServerTransport(); await this.server.connect(transport); diff --git a/src/core/toolLoader.ts b/src/core/toolLoader.ts index 7949d6b..e645d8e 100644 --- a/src/core/toolLoader.ts +++ b/src/core/toolLoader.ts @@ -1,4 +1,4 @@ -import { BaseTool } from "../tools/BaseTool.js"; +import { ToolProtocol } from "../tools/BaseTool.js"; import { join, dirname } from "path"; import { promises as fs } from "fs"; import { logger } from "./Logger.js"; @@ -29,7 +29,7 @@ export class ToolLoader { return !isExcluded; } - private validateTool(tool: any): tool is BaseTool { + private validateTool(tool: any): tool is ToolProtocol { const isValid = Boolean( tool && typeof tool.name === "string" && @@ -46,7 +46,7 @@ export class ToolLoader { return isValid; } - async loadTools(): Promise { + async loadTools(): Promise { try { logger.debug(`Attempting to load tools from: ${this.TOOLS_DIR}`); @@ -66,7 +66,7 @@ export class ToolLoader { const files = await fs.readdir(this.TOOLS_DIR); logger.debug(`Found files in directory: ${files.join(", ")}`); - const tools: BaseTool[] = []; + const tools: ToolProtocol[] = []; for (const file of files) { if (!this.isToolFile(file)) { diff --git a/src/index.ts b/src/index.ts index 6d404c6..7f99980 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,3 +1,8 @@ export { MCPServer, type MCPServerConfig } from "./core/MCPServer.js"; -export { BaseTool, BaseToolImplementation } from "./tools/BaseTool.js"; +export { + MCPTool, + type ToolProtocol, + type ToolInputSchema, + type ToolInput, +} from "./tools/BaseTool.js"; export { ToolLoader } from "./core/toolLoader.js"; diff --git a/src/tools/BaseTool.ts b/src/tools/BaseTool.ts index bbf08b1..5663c64 100644 --- a/src/tools/BaseTool.ts +++ b/src/tools/BaseTool.ts @@ -1,32 +1,117 @@ -import { - CallToolRequestSchema, - Tool, -} from "@modelcontextprotocol/sdk/types.js"; import { z } from "zod"; +import { Tool as SDKTool } from "@modelcontextprotocol/sdk/types.js"; -export interface BaseTool { +export type ToolInputSchema = { + [K in keyof T]: { + type: z.ZodType; + description: string; + }; +}; + +export type ToolInput> = { + [K in keyof T]: z.infer; +}; + +export interface ToolProtocol extends SDKTool { name: string; - toolDefinition: Tool; - toolCall(request: z.infer): Promise; + description: string; + toolDefinition: { + name: string; + description: string; + inputSchema: { + type: "object"; + properties?: Record; + }; + }; + toolCall(request: { + params: { name: string; arguments?: Record }; + }): Promise<{ + content: Array<{ type: string; text: string }>; + }>; } -export abstract class BaseToolImplementation implements BaseTool { +export abstract class MCPTool = {}> + implements ToolProtocol +{ abstract name: string; - abstract toolDefinition: Tool; - abstract toolCall( - request: z.infer - ): Promise; + abstract description: string; + protected abstract schema: ToolInputSchema; + [key: string]: unknown; - protected createSuccessResponse(data: any) { + get inputSchema(): { type: "object"; properties?: Record } { + return { + type: "object" as const, + properties: Object.fromEntries( + Object.entries(this.schema).map(([key, schema]) => [ + key, + { + type: this.getJsonSchemaType(schema.type), + description: schema.description, + }, + ]) + ), + }; + } + + get toolDefinition() { + return { + name: this.name, + description: this.description, + inputSchema: this.inputSchema, + }; + } + + protected abstract execute(input: TInput): Promise; + + async toolCall(request: { + params: { name: string; arguments?: Record }; + }) { + try { + const args = request.params.arguments || {}; + const validatedInput = await this.validateInput(args); + const result = await this.execute(validatedInput); + return this.createSuccessResponse(result); + } catch (error) { + return this.createErrorResponse(error as Error); + } + } + + private async validateInput(args: Record): Promise { + const zodSchema = z.object( + Object.fromEntries( + Object.entries(this.schema).map(([key, schema]) => [key, schema.type]) + ) + ); + + return zodSchema.parse(args) as TInput; + } + + private getJsonSchemaType(zodType: z.ZodType): string { + if (zodType instanceof z.ZodString) return "string"; + if (zodType instanceof z.ZodNumber) return "number"; + if (zodType instanceof z.ZodBoolean) return "boolean"; + if (zodType instanceof z.ZodArray) return "array"; + if (zodType instanceof z.ZodObject) return "object"; + return "string"; + } + + protected createSuccessResponse(data: unknown) { return { content: [{ type: "text", text: JSON.stringify(data) }], }; } - protected createErrorResponse(error: Error | string) { - const message = error instanceof Error ? error.message : error; + protected createErrorResponse(error: Error) { return { - content: [{ type: "error", text: message }], + content: [{ type: "error", text: error.message }], }; } + + protected async fetch(url: string, init?: RequestInit): Promise { + const response = await fetch(url, init); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + return response.json(); + } }