Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/core/MCPServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
ListToolsRequestSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { ToolLoader } from "./toolLoader.js";
import { BaseTool } from "../tools/BaseTool.js";
import { ToolProtocol } from "../tools/BaseTool.js";
import { readFileSync } from "fs";
import { join, dirname } from "path";
import { logger } from "./Logger.js";
Expand All @@ -17,7 +17,7 @@ export interface MCPServerConfig {

export class MCPServer {
private server: Server;
private toolsMap: Map<string, BaseTool> = new Map();
private toolsMap: Map<string, ToolProtocol> = new Map();
private toolLoader: ToolLoader;
private serverName: string;
private serverVersion: string;
Expand Down Expand Up @@ -106,14 +106,22 @@ export class MCPServer {
).join(", ")}`
);
}
return tool.toolCall(request);

const toolRequest = {
params: request.params,
method: "tools/call" as const,
};

return tool.toolCall(toolRequest);
});
}

async start() {
try {
const tools = await this.toolLoader.loadTools();
this.toolsMap = new Map(tools.map((tool: BaseTool) => [tool.name, tool]));
this.toolsMap = new Map(
tools.map((tool: ToolProtocol) => [tool.name, tool])
);

const transport = new StdioServerTransport();
await this.server.connect(transport);
Expand Down
8 changes: 4 additions & 4 deletions src/core/toolLoader.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BaseTool } from "../tools/BaseTool.js";
import { ToolProtocol } from "../tools/BaseTool.js";
import { join, dirname } from "path";
import { promises as fs } from "fs";
import { logger } from "./Logger.js";
Expand Down Expand Up @@ -29,7 +29,7 @@ export class ToolLoader {
return !isExcluded;
}

private validateTool(tool: any): tool is BaseTool {
private validateTool(tool: any): tool is ToolProtocol {
const isValid = Boolean(
tool &&
typeof tool.name === "string" &&
Expand All @@ -46,7 +46,7 @@ export class ToolLoader {
return isValid;
}

async loadTools(): Promise<BaseTool[]> {
async loadTools(): Promise<ToolProtocol[]> {
try {
logger.debug(`Attempting to load tools from: ${this.TOOLS_DIR}`);

Expand All @@ -66,7 +66,7 @@ export class ToolLoader {
const files = await fs.readdir(this.TOOLS_DIR);
logger.debug(`Found files in directory: ${files.join(", ")}`);

const tools: BaseTool[] = [];
const tools: ToolProtocol[] = [];

for (const file of files) {
if (!this.isToolFile(file)) {
Expand Down
7 changes: 6 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
export { MCPServer, type MCPServerConfig } from "./core/MCPServer.js";
export { BaseTool, BaseToolImplementation } from "./tools/BaseTool.js";
export {
MCPTool,
type ToolProtocol,
type ToolInputSchema,
type ToolInput,
} from "./tools/BaseTool.js";
export { ToolLoader } from "./core/toolLoader.js";
117 changes: 101 additions & 16 deletions src/tools/BaseTool.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,117 @@
import {
CallToolRequestSchema,
Tool,
} from "@modelcontextprotocol/sdk/types.js";
import { z } from "zod";
import { Tool as SDKTool } from "@modelcontextprotocol/sdk/types.js";

export interface BaseTool {
export type ToolInputSchema<T> = {
[K in keyof T]: {
type: z.ZodType<T[K]>;
description: string;
};
};

export type ToolInput<T extends ToolInputSchema<any>> = {
[K in keyof T]: z.infer<T[K]["type"]>;
};

export interface ToolProtocol extends SDKTool {
name: string;
toolDefinition: Tool;
toolCall(request: z.infer<typeof CallToolRequestSchema>): Promise<any>;
description: string;
toolDefinition: {
name: string;
description: string;
inputSchema: {
type: "object";
properties?: Record<string, unknown>;
};
};
toolCall(request: {
params: { name: string; arguments?: Record<string, unknown> };
}): Promise<{
content: Array<{ type: string; text: string }>;
}>;
}

export abstract class BaseToolImplementation implements BaseTool {
export abstract class MCPTool<TInput extends Record<string, any> = {}>
implements ToolProtocol
{
abstract name: string;
abstract toolDefinition: Tool;
abstract toolCall(
request: z.infer<typeof CallToolRequestSchema>
): Promise<any>;
abstract description: string;
protected abstract schema: ToolInputSchema<TInput>;
[key: string]: unknown;

protected createSuccessResponse(data: any) {
get inputSchema(): { type: "object"; properties?: Record<string, unknown> } {
return {
type: "object" as const,
properties: Object.fromEntries(
Object.entries(this.schema).map(([key, schema]) => [
key,
{
type: this.getJsonSchemaType(schema.type),
description: schema.description,
},
])
),
};
}

get toolDefinition() {
return {
name: this.name,
description: this.description,
inputSchema: this.inputSchema,
};
}

protected abstract execute(input: TInput): Promise<unknown>;

async toolCall(request: {
params: { name: string; arguments?: Record<string, unknown> };
}) {
try {
const args = request.params.arguments || {};
const validatedInput = await this.validateInput(args);
const result = await this.execute(validatedInput);
return this.createSuccessResponse(result);
} catch (error) {
return this.createErrorResponse(error as Error);
}
}

private async validateInput(args: Record<string, unknown>): Promise<TInput> {
const zodSchema = z.object(
Object.fromEntries(
Object.entries(this.schema).map(([key, schema]) => [key, schema.type])
)
);

return zodSchema.parse(args) as TInput;
}

private getJsonSchemaType(zodType: z.ZodType<any>): string {
if (zodType instanceof z.ZodString) return "string";
if (zodType instanceof z.ZodNumber) return "number";
if (zodType instanceof z.ZodBoolean) return "boolean";
if (zodType instanceof z.ZodArray) return "array";
if (zodType instanceof z.ZodObject) return "object";
return "string";
}

protected createSuccessResponse(data: unknown) {
return {
content: [{ type: "text", text: JSON.stringify(data) }],
};
}

protected createErrorResponse(error: Error | string) {
const message = error instanceof Error ? error.message : error;
protected createErrorResponse(error: Error) {
return {
content: [{ type: "error", text: message }],
content: [{ type: "error", text: error.message }],
};
}

protected async fetch<T>(url: string, init?: RequestInit): Promise<T> {
const response = await fetch(url, init);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return response.json();
}
}