diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index b677188c2..cb28929b8 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -36,6 +36,8 @@ import { PreRequestValidatorService } from './services/preRequestValidatorServic import { ProviderContext } from './services/providerContext'; import { RequestContext } from './services/requestContext'; import { ResponseService } from './services/responseService'; +import { McpService } from './services/mcpService'; +import { log } from 'console'; function constructRequestBody( requestContext: RequestContext, @@ -352,6 +354,18 @@ export async function tryPost( requestContext.params = hookSpan.getContext().request.json; } + // Initialize MCP service if needed + const mcpService = requestContext.shouldHandleMcp() + ? new McpService(requestContext) + : null; + + if (mcpService) { + await mcpService.init(); + // Add MCP tools to the request + const mcpTools = mcpService.tools; + requestContext.addMcpTools(mcpTools); + } + // Attach the body of the request if (!providerContext.hasRequestHandler(requestContext)) { requestContext.transformToProviderRequestAndSave(); @@ -437,7 +451,10 @@ export async function tryPost( hookSpan.id, providerContext, hooksService, - logObject + logObject, + responseService, + cacheResponseObject, + mcpService || undefined ); const { response, originalResponseJson: mappedOriginalResponseJson } = @@ -456,10 +473,7 @@ export async function tryPost( originalResponseJson, }); - logObject - .updateRequestContext(requestContext, fetchOptions.headers) - .addResponse(response, mappedOriginalResponseJson) - .log(); + // The log is handled inside the recursiveAfterRequestHookHandler function return response; } @@ -1122,7 +1136,10 @@ export async function recursiveAfterRequestHookHandler( hookSpanId: string, providerContext: ProviderContext, hooksService: HooksService, - logObject: LogObjectBuilder + logObject: LogObjectBuilder, + responseService: ResponseService, + cacheResponseObject: CacheResponseObject, + mcpService?: McpService ): Promise<{ mappedResponse: Response; retryCount: number; @@ -1131,11 +1148,7 @@ export async function recursiveAfterRequestHookHandler( }> { const { honoContext: c, - providerOption, isStreaming: isStreamingMode, - params: gatewayParams, - endpoint: fn, - strictOpenAiCompliance, requestTimeout, retryConfig: retry, } = requestContext; @@ -1160,35 +1173,82 @@ export async function recursiveAfterRequestHookHandler( retry.useRetryAfterHeader )); - // Check if sync hooks are available - // This will be used to determine if we need to parse the response body or simply passthrough the response as is - const areSyncHooksAvailable = hooksService.areSyncHooksAvailable; - const { - response: mappedResponse, + response: currentResponse, responseJson: mappedResponseJson, originalResponseJson, - } = await responseHandler( - response, - isStreamingMode, - providerOption, - fn, - url, - false, - gatewayParams, - strictOpenAiCompliance, - c.req.url, - areSyncHooksAvailable - ); + } = await responseService.create({ + response: response, + responseTransformer: requestContext.endpoint, + isResponseAlreadyMapped: false, + cache: { + isCacheHit: false, + cacheStatus: cacheResponseObject.cacheStatus, + cacheKey: cacheResponseObject.cacheKey, + }, + retryAttempt: retryCount || 0, + createdAt, + }); + + logObject + .updateRequestContext(requestContext, options.headers) + .addResponse(currentResponse, originalResponseJson) + .log(); + + if ( + mcpService && + logObject && + !isStreamingMode && + mappedResponseJson?.choices?.[0]?.message?.tool_calls?.[0] + ) { + const mcpResult = await handleMcpToolCalls( + requestContext, + mappedResponseJson, + mcpService + ); + if (mcpResult.success) { + // Construct the base object for the request + const fetchOptions: RequestInit = await constructRequest( + providerContext, + requestContext + ); + + // Recurse with updated conversation + return recursiveAfterRequestHookHandler( + requestContext, + fetchOptions, + 0, // Reset retry attempts for new LLM request + hookSpanId, + providerContext, + hooksService, + logObject, + responseService, + cacheResponseObject, + mcpService + ); + } else { + // MCP failed, log and continue with current response + console.warn( + 'MCP processing failed, returning current response:', + mcpResult.error + ); + } + } const arhResponse = await afterRequestHookHandler( c, - mappedResponse, + currentResponse, mappedResponseJson, hookSpanId, retryAttemptsMade ); + logObject + .updateRequestContext(requestContext, options.headers) + .addResponse(arhResponse, originalResponseJson) + .addExecutionTime(createdAt) + .log(); + const remainingRetryCount = (retry?.attempts || 0) - (retryCount || 0) - retryAttemptsMade; @@ -1197,13 +1257,6 @@ export async function recursiveAfterRequestHookHandler( ); if (remainingRetryCount > 0 && !retrySkipped && isRetriableStatusCode) { - // Log the request here since we're about to retry - logObject - .updateRequestContext(requestContext, options.headers) - .addResponse(arhResponse, originalResponseJson) - .addExecutionTime(createdAt) - .log(); - return recursiveAfterRequestHookHandler( requestContext, options, @@ -1211,7 +1264,9 @@ export async function recursiveAfterRequestHookHandler( hookSpanId, providerContext, hooksService, - logObject + logObject, + responseService, + cacheResponseObject ); } @@ -1285,3 +1340,127 @@ export async function beforeRequestHookHandler( transformedBody: isTransformed ? span.getContext().request.json : null, }; } + +/** + * Handles MCP tool calls for a given request context and response JSON. + * This function processes tool calls from the response and executes them using the MCP service. + * It updates the request context with the tool responses and transforms the request to the provider's format. + * + * @param requestContext - The request context containing the conversation and parameters + * @param responseJson - The response JSON containing tool calls + * @param mcpService - The MCP service for executing tool calls + * @returns { success: boolean; error?: string } - The result of the MCP tool calls + */ +async function handleMcpToolCalls( + requestContext: RequestContext, + responseJson: any, + mcpService: McpService +): Promise<{ success: boolean; error?: string }> { + if (requestContext.endpoint !== 'chatComplete') { + return { + success: false, + error: 'MCP tool calls are only supported for /chat/completions endpoint', + }; + } + + const logsService = new LogsService(requestContext.honoContext); + + try { + const toolCalls = responseJson.choices[0].message.tool_calls; + const conversation: any[] = [...(requestContext.params.messages || [])]; + + const { mcpToolsMap, nonMcpToolsMap } = mcpService.findMCPTools(toolCalls); + + if (nonMcpToolsMap.size > 0) { + return { + success: false, + error: 'Exiting, since some tool calls are not MCP tools', + }; + } + + const mcpTools = Array.from(mcpToolsMap.values()); + + // Add assistant's response with tool calls to conversation + conversation.push(responseJson.choices[0].message); + + // Execute all tool calls in parallel for better performance + const toolCallPromises = mcpTools.map(async (toolCall: any) => { + const start = new Date().getTime(); + + try { + const toolResult = await mcpService.executeTool( + toolCall.function.name, + JSON.parse(toolCall.function.arguments) + ); + + const toolResponse = { + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(toolResult), + }; + + const toolCallSpan = logsService.createExecuteToolSpan( + toolCall, + toolResult.content, + start, + new Date().getTime(), + requestContext.traceId + ); + + logsService.addRequestLog(toolCallSpan); + + return toolResponse; + } catch (toolError: any) { + if (toolError.message.includes('MCP_SERVER_TOOL_NOT_FOUND')) { + throw new Error('MCP_SERVER_TOOL_NOT_FOUND'); + } + + console.error( + `MCP tool call failed for ${toolCall.function.name}:`, + toolError + ); + + const errorResponse = { + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify({ + error: 'Tool execution failed', + details: toolError.message, + }), + }; + + const toolCallSpan = logsService.createExecuteToolSpan( + toolCall, + { error: toolError.message }, + start, + new Date().getTime(), + requestContext.traceId + ); + + logsService.addRequestLog(toolCallSpan); + + return errorResponse; + } + }); + + // Wait for all tool calls to complete + let toolResponses = await Promise.all(toolCallPromises); + toolResponses = toolResponses.filter((response: any) => response !== null); + + if (toolResponses.length === 0) { + return { success: false, error: 'No tool responses received' }; + } + + // Add all tool responses to conversation + conversation.push(...toolResponses); + + // Update the existing context + requestContext.updateMessages(conversation); + requestContext.transformToProviderRequestAndSave(); + + return { success: true }; + } catch (error: any) { + console.warn('Error in handleMcpToolCalls:', error); + return { success: false, error: error.message }; + } +} diff --git a/src/handlers/services/logsService.ts b/src/handlers/services/logsService.ts index d4587634f..5604ebefb 100644 --- a/src/handlers/services/logsService.ts +++ b/src/handlers/services/logsService.ts @@ -61,7 +61,7 @@ export interface LogObject { } export interface otlpSpanObject { - type: 'otlp_span'; + type: 'otel'; traceId: string; spanId: string; parentSpanId: string; @@ -90,6 +90,11 @@ export interface otlpSpanObject { }[]; } +function capitaliseSentence(str: string) { + // First letter of each word to uppercase + return str.replace(/\b\w/g, (char) => char.toUpperCase()); +} + export class LogsService { constructor(private honoContext: Context) {} @@ -103,16 +108,16 @@ export class LogsService { spanId?: string ) { return { - type: 'otlp_span', + type: 'otel', traceId: traceId, spanId: spanId ?? crypto.randomUUID(), parentSpanId: parentSpanId, - name: `execute_tool ${toolCall.function.name}`, + name: capitaliseSentence(toolCall.function.name.replaceAll('_', ' ')), kind: 'SPAN_KIND_INTERNAL', startTimeUnixNano: startTimeUnixNano, endTimeUnixNano: endTimeUnixNano, status: { - code: 'STATUS_CODE_OK', + code: 200, }, attributes: [ { diff --git a/src/handlers/services/mcpService.ts b/src/handlers/services/mcpService.ts new file mode 100644 index 000000000..48968a30a --- /dev/null +++ b/src/handlers/services/mcpService.ts @@ -0,0 +1,888 @@ +import { McpServer, McpServerConfig, ToolCall } from '../../types/requestBody'; +import { RequestContext } from './requestContext'; +import { GatewayError } from '../../errors/GatewayError'; + +// services/mcpService.ts +export class McpService { + private mcpConnections = new Map(); + private mcpTools = new Map(); + private mcpToolToServerMap = new Map(); + + constructor(private requestContext: RequestContext) {} + + async init(): Promise { + const mcpServers: McpServer[] = this.requestContext.mcpServers; + if (!mcpServers) { + return; + } + this.validateServerObjects(mcpServers); + for (const server of mcpServers) { + try { + const client = await this.connectToMcpServer(server); + if (client) { + this.mcpConnections.set(server.server_label, client); + let tools = await client.listTools(); + // console.log('MCP tools', tools); + if (server.allowed_tools && server.allowed_tools.length) { + tools = tools.filter((tool) => + server.allowed_tools!.includes(tool.name) + ); + } + const llmTools = this.transformToolsForLLM( + server.server_label, + tools + ); + this.mcpTools.set(server.server_label, llmTools); + } + } catch (error) { + console.error( + `Error connecting to MCP server ${server.server_url}:`, + error + ); + throw new GatewayError( + `Error connecting to MCP server \`${server.server_url}\`.` + ); + } + } + return; + } + + private validateServerObjects(servers: McpServer[]): void { + if (!servers || servers.length === 0) { + return; + } + + // Pre-compile regex patterns for better performance + const labelRegex = /^[a-zA-Z][a-zA-Z0-9-_]*$/; + const urlRegex = /^https?:\/\/[^\s/$.?#].[^\s]*$/i; + + // Private IP ranges and localhost patterns + const privatePatterns = [ + /localhost/i, + /127\.0\.0\.1/, + /::1/, + /0\.0\.0\.0/, + // Additional private IP ranges for comprehensive SSRF protection + /10\.\d{1,3}\.\d{1,3}\.\d{1,3}/, // 10.0.0.0/8 + /172\.(1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}/, // 172.16.0.0/12 + /192\.168\.\d{1,3}\.\d{1,3}/, // 192.168.0.0/16 + /169\.254\.\d{1,3}\.\d{1,3}/, // 169.254.0.0/16 (link-local) + ]; + + const seenLabels = new Set(); + + for (const server of servers) { + // Validate required fields exist + if (!server.server_label) { + throw new GatewayError( + 'MCP_SERVER_LABEL_NOT_FOUND: MCP server label not found' + ); + } + + if (!server.server_url) { + throw new GatewayError( + 'MCP_SERVER_URL_NOT_FOUND: MCP server URL not found' + ); + } + + // Validate label format + if (!labelRegex.test(server.server_label)) { + throw new GatewayError( + 'MCP_SERVER_LABEL_INVALID: MCP server label must start with a letter and can only contain letters, numbers, hyphens and underscores' + ); + } + + // Check label uniqueness (O(1) lookup instead of O(n)) + if (seenLabels.has(server.server_label)) { + throw new GatewayError( + 'MCP_SERVER_LABEL_NOT_UNIQUE: MCP server label must be unique' + ); + } + seenLabels.add(server.server_label); + + // Validate URL format first (fail fast) + if (!urlRegex.test(server.server_url)) { + throw new GatewayError( + 'MCP_SERVER_URL_INVALID: MCP server URL must be a valid URL' + ); + } + + // Check for SSRF vulnerabilities + if (privatePatterns.some((pattern) => pattern.test(server.server_url))) { + throw new GatewayError( + 'MCP_SERVER_URL_INVALID: MCP server URL must not hit private IPs or localhost' + ); + } + } + } + + get tools(): LLMFunction[] { + return Array.from(this.mcpTools.values()).flat(); + } + + /** + * Find MCP tools and non-MCP tools from a list of tool calls + * based on the MCP tools loaded in the MCP service + * @param toolCalls - Tool calls to find MCP tools for + * @returns - MCP tools and non-MCP tools + */ + findMCPTools(toolCalls: ToolCall[]): { + mcpToolsMap: Map; + nonMcpToolsMap: Map; + } { + let mcpToolsMap: Map = new Map(), + nonMcpToolsMap: Map = new Map(); + const mcpToolNames = this.tools.map((tool) => tool.function.name); + toolCalls.forEach((toolCall: ToolCall) => { + if (mcpToolNames.includes(toolCall.function.name)) { + mcpToolsMap.set(toolCall.function.name, toolCall); + } else { + nonMcpToolsMap.set(toolCall.function.name, toolCall); + } + }); + return { mcpToolsMap, nonMcpToolsMap }; + } + + async executeTool( + functionName: string, + toolArgs: any + ): Promise { + const serverName = this.mcpToolToServerMap.get(functionName); + if (!serverName) { + throw new Error( + `MCP_SERVER_TOOL_NOT_FOUND: MCP server not found for tool ${functionName}` + ); + } + const client = this.mcpConnections.get(serverName); + + if (!client || !this.mcpTools.has(serverName)) { + throw new Error( + `MCP_SERVER_TOOL_NOT_FOUND: MCP server ${serverName} not found or tool name not loaded in the mcp server` + ); + } + + const toolName = functionName.substring(serverName.length + 1); + return await client.executeTool(toolName, toolArgs); + } + + private transformToolsForLLM( + servername: string, + mcpTools: Tool[] + ): LLMFunction[] { + return mcpTools.map((tool) => { + const functionName = `${servername}_${tool.name}`; + this.mcpToolToServerMap.set(functionName, servername); + return { + type: 'function' as const, + function: { + name: functionName, + description: tool.description, + parameters: { + type: 'object' as const, + properties: tool.inputSchema.properties || {}, + required: tool.inputSchema.required || [], + // Preserve any additional schema properties like additionalProperties, etc. + ...Object.fromEntries( + Object.entries(tool.inputSchema).filter( + ([key]) => !['type', 'properties', 'required'].includes(key) + ) + ), + }, + }, + }; + }); + } + + private async connectToMcpServer( + server: McpServer + ): Promise { + const client = new MinimalMCPClient(server.server_url, server.headers); + const result = await client.initialize(); + if (!result.capabilities.tools) { + throw new Error(`MCP server ${server.server_url} does not support tools`); + return null; + } + return client; + } + + async [Symbol.asyncDispose](): Promise { + // console.log('Disposing MCP service...'); + + for (const [name, client] of this.mcpConnections) { + try { + // console.log('Closing MCP connection to', name); + await Promise.race([ + client.close(), + new Promise((_, reject) => + setTimeout(() => reject(new Error('Close timeout')), 10000) + ), + ]); + } catch (error) { + console.error(`Error closing MCP connection to ${name}:`, error); + // Continue closing other connections even if one fails + } + } + + this.mcpConnections.clear(); + this.mcpTools.clear(); + + // console.log('MCP service disposed'); + } +} + +// LLM Function Call format (OpenAI-style) +export interface LLMFunction { + type: 'function'; + function: { + name: string; + description?: string; + parameters: { + type: 'object'; + properties?: Record; + required?: string[]; + [key: string]: any; + }; + }; +} + +// Minimal MCP Client for fetching tools from remote servers +// Supports both StreamableHTTP and SSE transports + +interface JSONRPCRequest { + jsonrpc: '2.0'; + id: string | number; + method: string; + params?: any; +} + +interface JSONRPCResponse { + jsonrpc: '2.0'; + id: string | number; + result?: any; + error?: { + code: number; + message: string; + data?: any; + }; +} + +interface Tool { + name: string; + description?: string; + inputSchema: { + type: 'object'; + properties?: Record; + required?: string[]; + }; + outputSchema?: { + type: 'object'; + properties?: Record; + required?: string[]; + }; +} + +interface InitializeResult { + protocolVersion: string; + capabilities: { + tools?: any; + [key: string]: any; + }; + serverInfo: { + name: string; + version: string; + }; + instructions?: string; +} + +interface ListToolsResult { + tools: Tool[]; + nextCursor?: string; +} + +interface ToolExecutionResult { + content?: Array<{ + type: 'text' | 'image' | 'audio' | 'resource'; + text?: string; + data?: string; + mimeType?: string; + resource?: any; + }>; + structuredContent?: Record; + isError?: boolean; +} + +class MinimalMCPClient { + private url: URL; + private headers: Record; + private messageId = 0; + private sessionId?: string; + private isSSE = false; + private sseEndpoint?: URL; + private eventSource?: EventSource; + private abortController?: AbortController; // Add this + private streamReader?: ReadableStreamDefaultReader; // Add this + private pendingRequests = new Map< + string | number, + { + resolve: (value: any) => void; + reject: (error: Error) => void; + } + >(); + private sseConnectionResolve?: () => void; + private sseConnectionReject?: (error: Error) => void; + + constructor( + serverUrl: string, + headers?: Record, + options?: { messageEndpoint?: string } + ) { + this.url = new URL(serverUrl); + this.headers = headers || {}; + + // Check if this looks like an SSE endpoint + this.isSSE = serverUrl.includes('/sse') || serverUrl.includes('sse'); + + // If custom message endpoint provided, use it + if (options?.messageEndpoint) { + this.sseEndpoint = new URL(options.messageEndpoint); + } + } + + private getNextMessageId(): number { + return ++this.messageId; + } + + private getAuthHeaders(): HeadersInit { + const headers: HeadersInit = { + ...(this.headers || {}), + }; + + if (!this.isSSE) { + headers['Content-Type'] = 'application/json'; + headers['Accept'] = 'application/json, text/event-stream'; + } + + return headers; + } + + private async initializeSSE(): Promise { + if (!this.isSSE) return; + + return new Promise((resolve, reject) => { + // Set up a timeout for the SSE connection + const timeout = setTimeout(() => { + reject(new Error('Timeout waiting for SSE endpoint from server')); + }, 10000); // 10 second timeout + + // Store resolve/reject to call when endpoint is received + const originalResolve = resolve; + const originalReject = reject; + + this.sseConnectionResolve = () => { + clearTimeout(timeout); + originalResolve(); + }; + + this.sseConnectionReject = (error: Error) => { + clearTimeout(timeout); + originalReject(error); + }; + + // Start the SSE connection + this.establishSSEConnection().catch(this.sseConnectionReject); + }); + } + + private async establishSSEConnection(): Promise { + const headers = new Headers(this.getAuthHeaders()); + headers.set('Accept', 'text/event-stream'); + + const response = await fetch(this.url, { + method: 'GET', + headers, + }); + + if (!response.ok) { + throw new Error( + `Failed to establish SSE connection: HTTP ${response.status}: ${response.statusText}` + ); + } + + // The server should send an 'endpoint' event with the POST URL + // We'll wait for this in the stream parser + this.sseEndpoint = undefined; + + // Parse SSE stream for endpoint information and responses + this.parseSSEStream(response); + } + + private parseSSEStream(response: Response): void { + if (!response.body) return; + + // Create abort controller for this stream + this.abortController = new AbortController(); + + const reader = response.body + .pipeThrough(new TextDecoderStream()) + .getReader(); + + let buffer = ''; + let currentEvent = { + event: '', + data: '', + id: '', + }; + + const processStream = async () => { + try { + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + // Check if aborted + if (this.abortController?.signal.aborted) { + // console.log('SSE stream processing aborted'); + break; + } + + buffer += value; + + // Process line by line + let lineEnd; + while ((lineEnd = buffer.indexOf('\n')) !== -1) { + const line = buffer.slice(0, lineEnd); + buffer = buffer.slice(lineEnd + 1); + + // Remove \r if present (for \r\n line endings) + const cleanLine = line.replace(/\r$/, ''); + + if (cleanLine === '') { + // Empty line = end of event, dispatch it + if (currentEvent.data || currentEvent.event) { + this.handleSSEEvent(currentEvent.event, currentEvent.data); + } + // Reset for next event + currentEvent = { event: '', data: '', id: '' }; + } else if (cleanLine.startsWith('event: ')) { + currentEvent.event = cleanLine.slice(7); + } else if (cleanLine.startsWith('data: ')) { + // Multiple data lines should be joined with \n + if (currentEvent.data) { + currentEvent.data += '\n' + cleanLine.slice(6); + } else { + currentEvent.data = cleanLine.slice(6); + } + } else if (cleanLine.startsWith('id: ')) { + currentEvent.id = cleanLine.slice(4); + } + // Ignore other fields like retry, etc. + } + } + } catch (error) { + if (!this.abortController?.signal.aborted) { + console.error('SSE stream error:', error); + if (this.sseConnectionReject) { + this.sseConnectionReject(error as Error); + this.sseConnectionResolve = undefined; + this.sseConnectionReject = undefined; + } + } + } finally { + try { + reader.releaseLock(); + } catch (e) { + // Reader might already be released + } + } + }; + + processStream(); + } + + private handleSSEEvent(eventType: string, data: string): void { + if (eventType === 'endpoint') { + // Server is telling us the POST endpoint URL (usually includes sessionId) + try { + this.sseEndpoint = new URL(data, this.url); + // console.log('SSE POST endpoint received:', this.sseEndpoint.href); + + // Extract session ID from the endpoint URL if present + const sessionId = this.sseEndpoint.searchParams.get('sessionId'); + if (sessionId) { + this.sessionId = sessionId; + // console.log('Session ID extracted:', sessionId); + } + + // Resolve the SSE connection promise now that we have the endpoint + if (this.sseConnectionResolve) { + this.sseConnectionResolve(); + this.sseConnectionResolve = undefined; + this.sseConnectionReject = undefined; + } + } catch (error) { + console.warn('Invalid endpoint URL from SSE:', data, error); + if (this.sseConnectionReject) { + this.sseConnectionReject(new Error(`Invalid endpoint URL: ${data}`)); + this.sseConnectionResolve = undefined; + this.sseConnectionReject = undefined; + } + } + return; + } + + // Handle JSON-RPC responses (default event type or 'message') + if (!eventType || eventType === 'message') { + try { + // Try to parse as JSON + const jsonResponse: JSONRPCResponse = JSON.parse(data); + // console.log( + // 'Parsed JSON-RPC response:', + // jsonResponse.id, + // jsonResponse.error ? 'ERROR' : 'SUCCESS' + // ); + + const pending = this.pendingRequests.get(jsonResponse.id); + + if (pending) { + this.pendingRequests.delete(jsonResponse.id); + + if (jsonResponse.error) { + pending.reject( + new Error( + `MCP Error ${jsonResponse.error.code}: ${jsonResponse.error.message}` + ) + ); + } else { + pending.resolve(jsonResponse.result); + } + } else { + console.warn( + 'Received response for unknown request ID:', + jsonResponse.id + ); + } + } catch (error) { + console.error('Failed to parse JSON from SSE data:', error); + console.error('Raw data (first 500 chars):', data.substring(0, 500)); + console.error( + 'Raw data (last 100 chars):', + data.substring(Math.max(0, data.length - 100)) + ); + } + } + } + + private async sendRequest(request: JSONRPCRequest): Promise { + if (this.isSSE) { + return this.sendSSERequest(request); + } else { + return this.sendDirectRequest(request); + } + } + + private async sendSSERequest(request: JSONRPCRequest): Promise { + if (!this.sseEndpoint) { + throw new Error('SSE POST endpoint not yet received from server'); + } + + return new Promise((resolve, reject) => { + // Store the pending request + this.pendingRequests.set(request.id, { resolve, reject }); + + // Prepare POST URL - session ID should already be in the endpoint URL + const postUrl = new URL(this.sseEndpoint!.href); + + // Send POST request to the endpoint (session ID is in the URL) + const headers = new Headers(this.getAuthHeaders()); + headers.set('Content-Type', 'application/json'); + + fetch(postUrl, { + method: 'POST', + headers, + body: JSON.stringify(request), + }).catch((error) => { + this.pendingRequests.delete(request.id); + reject(new Error(`Failed to send SSE request: ${error.message}`)); + }); + + // Set timeout for the request + setTimeout(() => { + if (this.pendingRequests.has(request.id)) { + this.pendingRequests.delete(request.id); + reject(new Error('SSE request timeout')); + } + }, 30000); // 30 second timeout + }); + } + + private async sendDirectRequest(request: JSONRPCRequest): Promise { + const headers = new Headers(this.getAuthHeaders()); + + // Include session ID if we have one + if (this.sessionId) { + headers.set('mcp-session-id', this.sessionId); + } + + // IMPORTANT: Add Accept header for both JSON and SSE + headers.set('Accept', 'application/json, text/event-stream'); + + const response = await fetch(this.url, { + method: 'POST', + headers, + body: JSON.stringify(request), + }); + + // Capture session ID from response if present + const newSessionId = response.headers.get('mcp-session-id'); + if (newSessionId) { + this.sessionId = newSessionId; + } + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const contentType = response.headers.get('content-type'); + + if (contentType?.includes('application/json')) { + // Direct JSON response + const jsonResponse: JSONRPCResponse = await response.json(); + + if (jsonResponse.error) { + throw new Error( + `MCP Error ${jsonResponse.error.code}: ${jsonResponse.error.message}` + ); + } + + return jsonResponse.result; + } else if (contentType?.includes('text/event-stream')) { + // SSE response - we need to parse the stream + return this.parseSSEResponse(response); + } else { + throw new Error(`Unexpected content type: ${contentType}`); + } + } + + private async parseSSEResponse(response: Response): Promise { + if (!response.body) { + throw new Error('No response body for SSE stream'); + } + + const requestId = this.messageId; // Store the current request ID for matching + + return new Promise((resolve, reject) => { + const reader = response + .body!.pipeThrough(new TextDecoderStream()) + .getReader(); + + const processStream = async () => { + try { + let buffer = ''; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + buffer += value; + + // Process line by line + let lineEnd; + while ((lineEnd = buffer.indexOf('\n')) !== -1) { + const line = buffer.slice(0, lineEnd); + buffer = buffer.slice(lineEnd + 1); + + // Remove \r if present (for \r\n line endings) + const cleanLine = line.replace(/\r$/, ''); + + if (cleanLine.startsWith('data: ')) { + const data = cleanLine.slice(6); + try { + const jsonResponse: JSONRPCResponse = JSON.parse(data); + + // Check if this response matches our request + if (jsonResponse.id === requestId) { + if (jsonResponse.error) { + reject( + new Error( + `MCP Error ${jsonResponse.error.code}: ${jsonResponse.error.message}` + ) + ); + } else { + resolve(jsonResponse.result); + } + return; // Exit the stream processing + } + + // If it's not our response, it might be a notification + // Pass it to the message handler if available + // if (this.onmessage && (!jsonResponse.id || jsonResponse.id !== requestId)) { + // this.onmessage(jsonResponse); + // } + } catch (e) { + // Ignore parsing errors for non-JSON data lines + continue; + } + } + } + } + + // If we reach here without getting our response, it's an error + reject(new Error('No matching response received from SSE stream')); + } catch (error) { + reject(error); + } finally { + try { + reader.releaseLock(); + } catch (e) { + // Reader might already be released + } + } + }; + + processStream(); + }); + } + + async initialize(): Promise { + // Initialize SSE connection if needed + if (this.isSSE) { + await this.initializeSSE(); + } + + const request: JSONRPCRequest = { + jsonrpc: '2.0', + id: this.getNextMessageId(), + method: 'initialize', + params: { + protocolVersion: '2025-03-26', + capabilities: {}, + clientInfo: { + name: 'minimal-mcp-client', + version: '1.0.0', + }, + }, + }; + + const result = await this.sendRequest(request); + + // Send initialized notification + await this.sendNotification('notifications/initialized'); + + return result; + } + + private async sendNotification(method: string, params?: any): Promise { + const notification = { + jsonrpc: '2.0' as const, + method, + params, + }; + + if (this.isSSE && this.sseEndpoint) { + // Send via SSE POST endpoint (session ID already in URL) + const headers = new Headers(this.getAuthHeaders()); + headers.set('Content-Type', 'application/json'); + + const response = await fetch(this.sseEndpoint, { + method: 'POST', + headers, + body: JSON.stringify(notification), + }); + + if (!response.ok) { + throw new Error( + `Failed to send SSE notification: HTTP ${response.status}` + ); + } + } else { + // Send via direct HTTP + const headers = new Headers(this.getAuthHeaders()); + + if (this.sessionId) { + headers.set('mcp-session-id', this.sessionId); + } + + const response = await fetch(this.url, { + method: 'POST', + headers, + body: JSON.stringify(notification), + }); + + if (!response.ok) { + throw new Error(`Failed to send notification: HTTP ${response.status}`); + } + } + } + + async listTools(): Promise { + const request: JSONRPCRequest = { + jsonrpc: '2.0', + id: this.getNextMessageId(), + method: 'tools/list', + }; + + const result: ListToolsResult = await this.sendRequest(request); + return result.tools; + } + + async executeTool( + name: string, + args?: Record + ): Promise { + const request: JSONRPCRequest = { + jsonrpc: '2.0', + id: this.getNextMessageId(), + method: 'tools/call', + params: { + name, + arguments: args || {}, + }, + }; + + const result = await this.sendRequest(request); + return result; + } + + async close(): Promise { + // Clean up pending requests + for (const [id, pending] of this.pendingRequests) { + pending.reject(new Error('Connection closed')); + } + this.pendingRequests.clear(); + + // Close EventSource if we have one + if (this.eventSource) { + this.eventSource.close(); + this.eventSource = undefined; + } + + // For SSE connections, we need to abort any ongoing fetch operations + if (this.isSSE && this.abortController) { + // console.log('Aborting SSE connection...'); + this.abortController.abort(); + } + + // Attempt to terminate session if we have a session ID + if (this.sessionId && !this.isSSE) { + try { + const headers = new Headers(this.getAuthHeaders()); + headers.set('mcp-session-id', this.sessionId); + + // Add a timeout to prevent hanging + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 5000); // 5 second timeout + + await fetch(this.url, { + method: 'DELETE', + headers, + signal: controller.signal, + }); + + clearTimeout(timeoutId); + } catch (error) { + // Ignore errors when terminating - server might not support it + console.warn('Failed to terminate session:', error); + } + } + } +} diff --git a/src/handlers/services/requestContext.ts b/src/handlers/services/requestContext.ts index 6378dd503..a4569a216 100644 --- a/src/handlers/services/requestContext.ts +++ b/src/handlers/services/requestContext.ts @@ -6,12 +6,16 @@ import { Options, Params, RetrySettings, + McpServer, + McpTool, + McpServerConfig, } from '../../types/requestBody'; import { endpointStrings } from '../../providers/types'; import { HEADER_KEYS, RETRY_STATUS_CODES } from '../../globals'; import { HookObject } from '../../middlewares/hooks/types'; import { HooksManager } from '../../middlewares/hooks'; import { transformToProviderRequest } from '../../services/transformToProviderRequest'; +import { LLMFunction } from './mcpService'; export class RequestContext { private _params: Params | null = null; @@ -227,4 +231,82 @@ export class RequestContext { requestOptions, ]); } + + shouldHandleMcp(): boolean { + // MCP applies only to chatComplete requests + if (this.endpoint !== 'chatComplete') return false; + + const { mcp_servers = [], tools = [] } = this.params ?? {}; + + if (mcp_servers.length > 0) return true; + + return tools.some((tool) => tool.type === 'mcp'); + } + + get mcpServers(): McpServer[] { + const { mcp_servers = [], tools = [] } = this.params ?? {}; + if (mcp_servers.length === 0 && tools.length === 0) return []; + + const mcpServers: McpServer[] = []; + + if (mcp_servers) { + for (const srv of mcp_servers) { + // Build the one object you actually need + const entry: McpServer = { + server_url: srv.url, + server_label: srv.name, + }; + + // Optional pieces, added only when present — no throw-away spreads + const tc = srv.tool_configuration; + if (tc?.allowed_tools) entry.allowed_tools = tc.allowed_tools; + + if (srv.authorization_token) { + entry.headers = { + Authorization: `Bearer ${srv.authorization_token}`, + }; + } + + mcpServers.push(entry); + } + } + + if (tools) { + for (const tool of tools) { + if (tool.type !== 'mcp') continue; + + //typecast tool to McpTool + const mcpTool = tool as McpTool; + + const entry: McpServer = { + server_url: mcpTool.server_url, + server_label: mcpTool.server_label, + }; + if (mcpTool.allowed_tools) entry.allowed_tools = mcpTool.allowed_tools; + if (mcpTool.require_approval) + entry.require_approval = mcpTool.require_approval; + if (mcpTool.headers) entry.headers = mcpTool.headers; + + mcpServers.push(entry); + } + } + + return mcpServers; + } + + addMcpTools(mcpTools: LLMFunction[]) { + if (mcpTools.length > 0) { + let newParams = { ...this.params }; + // Remove any existing tool with type `mcp` + newParams.tools = [...(this.params.tools || []), ...mcpTools]; + newParams.tools = newParams.tools?.filter((tool) => tool.type !== 'mcp'); + this.params = newParams; + } + } + + updateMessages(messages: any[]) { + let newParams = { ...this.params }; + newParams.messages = messages; + this.params = newParams; + } } diff --git a/src/middlewares/log/index.ts b/src/middlewares/log/index.ts index 57cbd9a2e..ebc6c3bb2 100644 --- a/src/middlewares/log/index.ts +++ b/src/middlewares/log/index.ts @@ -56,32 +56,41 @@ async function processLog(c: Context, start: number) { return; } - try { - const response = requestOptionsArray[0].requestParams.stream - ? { message: 'The response was a stream.' } - : await c.res.clone().json(); - - const responseString = JSON.stringify(response); - if (responseString.length > MAX_RESPONSE_LENGTH) { - requestOptionsArray[0].response = - responseString.substring(0, MAX_RESPONSE_LENGTH) + '...'; - } else { - requestOptionsArray[0].response = response; + for (const requestOption of requestOptionsArray) { + if (requestOption.type === 'otel') { + console.log('otel', JSON.stringify(requestOption)); + continue; } - } catch (error) { - console.error('Error processing log:', error); - } - await broadcastLog( - JSON.stringify({ - time: new Date().toLocaleString(), - method: c.req.method, - endpoint: c.req.url.split(':8787')[1], - status: c.res.status, - duration: ms, - requestOptions: requestOptionsArray, - }) - ); + console.log(requestOption.type || 'requestOption', requestOption); + + try { + const response = requestOption.requestParams.stream + ? { message: 'The response was a stream.' } + : await c.res.clone().json(); + + const responseString = JSON.stringify(response); + if (responseString.length > MAX_RESPONSE_LENGTH) { + requestOption.response = + responseString.substring(0, MAX_RESPONSE_LENGTH) + '...'; + } else { + requestOption.response = response; + } + } catch (error) { + console.error('Error processing log:', error); + } + + await broadcastLog( + JSON.stringify({ + time: new Date().toLocaleString(), + method: c.req.method, + endpoint: c.req.url.split(':8787')[1], + status: c.res.status, + duration: ms, + requestOptions: requestOption, + }) + ); + } } export const logger = () => { diff --git a/src/providers/anthropic/chatComplete.ts b/src/providers/anthropic/chatComplete.ts index 49aa47695..2ed82c4db 100644 --- a/src/providers/anthropic/chatComplete.ts +++ b/src/providers/anthropic/chatComplete.ts @@ -5,6 +5,8 @@ import { ContentType, SYSTEM_MESSAGE_ROLES, PromptCache, + McpTool, + Tool, } from '../../types/requestBody'; import { ChatCompletionResponse, @@ -353,7 +355,11 @@ export const AnthropicChatCompleteConfig: ProviderConfig = { transform: (params: Params) => { let tools: AnthropicTool[] = []; if (params.tools) { - params.tools.forEach((tool) => { + params.tools.forEach((tool: Tool | McpTool) => { + if (tool.type === 'mcp') { + return; + } + tool = tool as Tool; if (tool.function) { tools.push({ name: tool.function.name, diff --git a/src/providers/bedrock/chatComplete.ts b/src/providers/bedrock/chatComplete.ts index 5eb0b51ce..b45ea951a 100644 --- a/src/providers/bedrock/chatComplete.ts +++ b/src/providers/bedrock/chatComplete.ts @@ -10,6 +10,7 @@ import { ToolCall, SYSTEM_MESSAGE_ROLES, ContentType, + Tool, } from '../../types/requestBody'; import { ChatCompletionResponse, @@ -320,6 +321,10 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { | { cachePoint: { type: string } } > = []; params.tools?.forEach((tool) => { + if (tool.type === 'mcp') { + return; + } + tool = tool as Tool; if (tool.function) { tools.push({ toolSpec: { diff --git a/src/providers/bedrock/uploadFileUtils.ts b/src/providers/bedrock/uploadFileUtils.ts index dfbc95717..ef02d7ba4 100644 --- a/src/providers/bedrock/uploadFileUtils.ts +++ b/src/providers/bedrock/uploadFileUtils.ts @@ -4,6 +4,7 @@ import { Message, MESSAGE_ROLES, Params, + Tool, } from '../../types/requestBody'; import { ChatCompletionResponse, @@ -226,6 +227,10 @@ const BedrockAnthropicChatCompleteConfig: ProviderConfig = { const tools: AnthropicTool[] = []; if (params.tools) { params.tools.forEach((tool) => { + if (tool.type === 'mcp') { + return; + } + tool = tool as Tool; if (tool.function) { tools.push({ name: tool.function.name, diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts index df048b801..07bcfd6b5 100644 --- a/src/types/requestBody.ts +++ b/src/types/requestBody.ts @@ -360,7 +360,7 @@ export type ToolChoice = ToolChoiceObject | 'none' | 'auto' | 'required'; */ export interface Tool extends PromptCache { /** The name of the function. */ - type: string; + type: Exclude; /** A description of the function. */ function: Function; // this is used to support tools like computer, web_search, etc. @@ -395,7 +395,7 @@ export interface Params { context?: string; examples?: Examples[]; top_k?: number; - tools?: Tool[]; + tools?: (Tool | McpTool)[]; tool_choice?: ToolChoice; response_format?: { type: 'json_object' | 'text' | 'json_schema'; @@ -431,9 +431,40 @@ export interface Params { // Embeddings specific dimensions?: number; parameters?: any; + mcp_servers?: McpServerConfig[]; [key: string]: any; } +export interface McpServerConfig { + type: 'url' | 'local'; + url: string; + name: string; + authorization_token: string; + tool_configuration: { + enabled: boolean; + allowed_tools: string[]; + }; +} + +// A type of tool that is an MCP server +export interface McpTool { + type: 'mcp'; + server_url: string; + server_label: string; + allowed_tools?: string[]; + require_approval?: 'never' | 'always'; + headers?: Record; +} + +// Used to store MCP servers in the request context +export interface McpServer { + server_url: string; + server_label: string; + allowed_tools?: string[]; + require_approval?: 'never' | 'always'; + headers?: Record; +} + interface Examples { input?: Message; output?: Message; diff --git a/tests/unit/src/handlers/services/logsService.test.ts b/tests/unit/src/handlers/services/logsService.test.ts index ffdbc73c7..33fd81134 100644 --- a/tests/unit/src/handlers/services/logsService.test.ts +++ b/tests/unit/src/handlers/services/logsService.test.ts @@ -60,7 +60,7 @@ describe('LogsService', () => { ); expect(result).toEqual({ - type: 'otlp_span', + type: 'otel', traceId: 'trace-123', spanId: 'span-789', parentSpanId: 'parent-456',