diff --git a/src/crates/services-integrations/src/mcp/protocol/transport.rs b/src/crates/services-integrations/src/mcp/protocol/transport.rs index ab07a0a23..763c78efc 100644 --- a/src/crates/services-integrations/src/mcp/protocol/transport.rs +++ b/src/crates/services-integrations/src/mcp/protocol/transport.rs @@ -23,7 +23,7 @@ impl MCPTransport { } } - async fn next_request_id(&self) -> u64 { + pub async fn next_request_id(&self) -> u64 { let mut id = self.request_id.lock().await; *id += 1; *id @@ -40,6 +40,16 @@ impl MCPTransport { Ok(id) } + pub async fn send_request_with_id( + &self, + id: u64, + method: String, + params: Option, + ) -> MCPRuntimeResult<()> { + let request = MCPRequest::new(Value::Number(id.into()), method, params); + self.send_message(MCPMessage::Request(request)).await + } + pub async fn send_notification( &self, method: String, diff --git a/src/crates/services-integrations/src/mcp/server/connection.rs b/src/crates/services-integrations/src/mcp/server/connection.rs index 4d442b409..0cb8f0da5 100644 --- a/src/crates/services-integrations/src/mcp/server/connection.rs +++ b/src/crates/services-integrations/src/mcp/server/connection.rs @@ -186,14 +186,22 @@ impl MCPConnection { ) -> MCPRuntimeResult { match &self.transport { TransportType::Local(transport) => { - let request_id = transport.send_request(method.clone(), params).await?; - + let request_id = transport.next_request_id().await; let (tx, rx) = oneshot::channel(); { let mut pending = self.pending_requests.write().await; pending.insert(request_id, tx); } + if let Err(error) = transport + .send_request_with_id(request_id, method.clone(), params) + .await + { + let mut pending = self.pending_requests.write().await; + pending.remove(&request_id); + return Err(error); + } + let response = if let Some(request_timeout) = self.request_timeout { tokio::time::timeout(request_timeout, rx) .await @@ -234,7 +242,15 @@ impl MCPConnection { let response = self .send_request_and_wait(request.method.clone(), request.params) .await?; - parse_response_result(&response) + let result = parse_response_result(&response)?; + + if let TransportType::Local(transport) = &self.transport { + transport + .send_notification("notifications/initialized".to_string(), None) + .await?; + } + + Ok(result) } TransportType::Remote(transport) => { transport.initialize(client_name, client_version).await