Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions rust/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ where
message = outgoing_rx.next() => {
if let Some(message) = message {
outgoing_line.clear();
serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&message)).map_err(Error::into_internal_error)?;
log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
outgoing_line.push(b'\n');
outgoing_bytes.write_all(&outgoing_line).await.ok();
Expand Down Expand Up @@ -190,7 +190,7 @@ where
result: ResponseResult::Error(err),
};

serde_json::to_writer(&mut outgoing_line, &error_response)?;
serde_json::to_writer(&mut outgoing_line, &JsonRpcMessage::wrap(&error_response))?;
log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
outgoing_line.push(b'\n');
outgoing_bytes.write_all(&outgoing_line).await.ok();
Expand Down Expand Up @@ -325,6 +325,34 @@ pub enum OutgoingMessage<Local: Side, Remote: Side> {
},
}

/// Either [`OutgoingMessage`] or [`IncomingMessage`] with `"jsonrpc": "2.0"` specified as
/// [required by JSON-RPC 2.0 Specification][1].
///
/// [1]: https://www.jsonrpc.org/specification#compatibility
#[derive(Debug, Serialize, Deserialize)]
pub struct JsonRpcMessage<M> {
jsonrpc: &'static str,
#[serde(flatten)]
message: M,
}

impl<M> JsonRpcMessage<M> {
/// Used version of [JSON-RPC protocol].
///
/// [JSON-RPC]: https://www.jsonrpc.org
pub const VERSION: &'static str = "2.0";

/// Wraps the provided [`OutgoingMessage`] or [`IncomingMessage`] into a versioned
/// [`JsonRpcMessage`].
#[must_use]
pub fn wrap(message: M) -> Self {
Self {
jsonrpc: Self::VERSION,
message,
}
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ResponseResult<Res> {
Expand Down
47 changes: 26 additions & 21 deletions rust/rpc_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -612,49 +612,54 @@ async fn test_full_conversation_flow() {
async fn test_notification_wire_format() {
use crate::{
AgentNotification, AgentSide, CancelNotification, ClientNotification, ClientSide,
ContentBlock, SessionNotification, SessionUpdate, TextContent, rpc::OutgoingMessage,
ContentBlock, SessionNotification, SessionUpdate, TextContent,
rpc::{JsonRpcMessage, OutgoingMessage},
};
use serde_json::{Value, json};

// Test client -> agent notification wire format
let outgoing_msg = OutgoingMessage::<ClientSide, AgentSide>::Notification {
method: "cancel",
params: Some(ClientNotification::CancelNotification(CancelNotification {
session_id: SessionId("test-123".into()),
})),
};
let outgoing_msg =
JsonRpcMessage::wrap(OutgoingMessage::<ClientSide, AgentSide>::Notification {
method: "cancel",
params: Some(ClientNotification::CancelNotification(CancelNotification {
session_id: SessionId("test-123".into()),
})),
});

let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
assert_eq!(
serialized,
json!({
"jsonrpc": "2.0",
"method": "cancel",
"params": {
"sessionId": "test-123"
}
},
})
);

// Test agent -> client notification wire format
let outgoing_msg = OutgoingMessage::<AgentSide, ClientSide>::Notification {
method: "sessionUpdate",
params: Some(AgentNotification::SessionNotification(
SessionNotification {
session_id: SessionId("test-456".into()),
update: SessionUpdate::AgentMessageChunk {
content: ContentBlock::Text(TextContent {
annotations: None,
text: "Hello".to_string(),
}),
let outgoing_msg =
JsonRpcMessage::wrap(OutgoingMessage::<AgentSide, ClientSide>::Notification {
method: "sessionUpdate",
params: Some(AgentNotification::SessionNotification(
SessionNotification {
session_id: SessionId("test-456".into()),
update: SessionUpdate::AgentMessageChunk {
content: ContentBlock::Text(TextContent {
annotations: None,
text: "Hello".to_string(),
}),
},
},
},
)),
};
)),
});

let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
assert_eq!(
serialized,
json!({
"jsonrpc": "2.0",
"method": "sessionUpdate",
"params": {
"sessionId": "test-456",
Expand Down