diff --git a/rust/rpc.rs b/rust/rpc.rs index 2e6ff836..4d815bd4 100644 --- a/rust/rpc.rs +++ b/rust/rpc.rs @@ -158,7 +158,7 @@ where message = outgoing_rx.next() => { if let Some(message) = message { outgoing_line.clear(); - serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?; + serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?; log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line)); outgoing_line.push(b'\n'); outgoing_bytes.write_all(&outgoing_line).await.ok(); @@ -190,7 +190,7 @@ where result: ResponseResult::Error(err), }; - serde_json::to_writer(&mut outgoing_line, &error_response)?; + serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?; log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line)); outgoing_line.push(b'\n'); outgoing_bytes.write_all(&outgoing_line).await.ok(); @@ -325,6 +325,34 @@ pub enum OutgoingMessage { }, } +/// Either [`OutgoingMessage`] or [`IncomingMessage`] with `"jsonrpc": "2.0"` specified as +/// [required by JSON-RPC 2.0 Specification][1]. +/// +/// [1]: https://www.jsonrpc.org/specification#compatibility +#[derive(Debug, Serialize, Deserialize)] +pub struct JsonRpcMessage { + jsonrpc: &'static str, + #[serde(flatten)] + message: M, +} + +impl JsonRpcMessage { + /// Used version of [JSON-RPC protocol]. + /// + /// [JSON-RPC]: https://www.jsonrpc.org + pub const VERSION: &'static str = "2.0"; + + /// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned + /// [`JsonRpcMessage`]. + #[must_use] + pub fn wrap(message: M) -> Self { + Self { + jsonrpc: Self::VERSION, + message, + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(rename_all = "snake_case")] pub enum ResponseResult { diff --git a/rust/rpc_tests.rs b/rust/rpc_tests.rs index 65a79e66..f6773739 100644 --- a/rust/rpc_tests.rs +++ b/rust/rpc_tests.rs @@ -612,49 +612,54 @@ async fn test_full_conversation_flow() { async fn test_notification_wire_format() { use crate::{ AgentNotification, AgentSide, CancelNotification, ClientNotification, ClientSide, - ContentBlock, SessionNotification, SessionUpdate, TextContent, rpc::OutgoingMessage, + ContentBlock, SessionNotification, SessionUpdate, TextContent, + rpc::{JsonRpcMessage, OutgoingMessage}, }; use serde_json::{Value, json}; // Test client -> agent notification wire format - let outgoing_msg = OutgoingMessage::::Notification { - method: "cancel", - params: Some(ClientNotification::CancelNotification(CancelNotification { - session_id: SessionId("test-123".into()), - })), - }; + let outgoing_msg = + JsonRpcMessage::wrap(OutgoingMessage::::Notification { + method: "cancel", + params: Some(ClientNotification::CancelNotification(CancelNotification { + session_id: SessionId("test-123".into()), + })), + }); let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap(); assert_eq!( serialized, json!({ + "jsonrpc": "2.0", "method": "cancel", "params": { "sessionId": "test-123" - } + }, }) ); // Test agent -> client notification wire format - let outgoing_msg = OutgoingMessage::::Notification { - method: "sessionUpdate", - params: Some(AgentNotification::SessionNotification( - SessionNotification { - session_id: SessionId("test-456".into()), - update: SessionUpdate::AgentMessageChunk { - content: ContentBlock::Text(TextContent { - annotations: None, - text: "Hello".to_string(), - }), + let outgoing_msg = + JsonRpcMessage::wrap(OutgoingMessage::::Notification { + method: "sessionUpdate", + params: Some(AgentNotification::SessionNotification( + SessionNotification { + session_id: SessionId("test-456".into()), + update: SessionUpdate::AgentMessageChunk { + content: ContentBlock::Text(TextContent { + annotations: None, + text: "Hello".to_string(), + }), + }, }, - }, - )), - }; + )), + }); let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap(); assert_eq!( serialized, json!({ + "jsonrpc": "2.0", "method": "sessionUpdate", "params": { "sessionId": "test-456",