Skip to content

Commit aca2336

Browse files
authored
feat: add message_type to MCPMessage trait (#26)
* add message_type * update comments * refactor MCPMessage Trait
1 parent 80b26ce commit aca2336

File tree

6 files changed

+153
-30
lines changed

6 files changed

+153
-30
lines changed

Diff for: src/generated_schema/2024_11_05/mcp_schema.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
/// ----------------------------------------------------------------------------
2-
/// This file is auto-generated by mcp-schema-gen v0.1.3.
2+
/// This file is auto-generated by mcp-schema-gen v0.1.4.
33
/// WARNING:
44
/// It is not recommended to modify this file directly. You are free to
55
/// modify or extend the implementations as needed, but please do so at your own risk.
66
///
77
/// Generated from : <https://github.com/modelcontextprotocol/specification.git>
88
/// Hash : 63e1dbb75456b359b9ed8b27d21f4ac68cbb753e
9-
/// Generated at : 2025-02-17 18:04:41
9+
/// Generated at : 2025-02-18 08:30:28
1010
/// ----------------------------------------------------------------------------
1111
///
1212
/// MCP Protocol Version

Diff for: src/generated_schema/2024_11_05/schema_utils.rs

+71-23
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,31 @@ use serde_json::{json, Value};
44
use std::hash::{Hash, Hasher};
55
use std::{fmt::Display, str::FromStr};
66

7-
#[derive(Debug)]
7+
#[derive(Debug, PartialEq)]
88
pub enum MessageTypes {
99
Request,
1010
Response,
1111
Notification,
1212
Error,
1313
}
14+
/// Implements the `Display` trait for the `MessageTypes` enum,
15+
/// allowing it to be converted into a human-readable string.
16+
impl Display for MessageTypes {
17+
/// Formats the `MessageTypes` enum variant as a string.
18+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19+
write!(
20+
f,
21+
"{}",
22+
// Match the current enum variant and return a corresponding string
23+
match self {
24+
MessageTypes::Request => "Request",
25+
MessageTypes::Response => "Response",
26+
MessageTypes::Notification => "Notification",
27+
MessageTypes::Error => "Error",
28+
}
29+
)
30+
}
31+
}
1432

1533
/// A utility function used internally to detect the message type from the payload.
1634
/// This function is used when deserializing a `ClientMessage` into strongly-typed structs that represent the specific message received.
@@ -38,12 +56,18 @@ fn detect_message_type(value: &serde_json::Value) -> MessageTypes {
3856
MessageTypes::Request
3957
}
4058

41-
pub trait MCPMessage {
59+
/// Represents a generic MCP (Model Content Protocol) message.
60+
/// This trait defines methods to classify and extract information from messages.
61+
pub trait RPCMessage {
62+
fn request_id(&self) -> Option<&RequestId>;
63+
}
64+
65+
pub trait MCPMessage: RPCMessage {
66+
fn message_type(&self) -> MessageTypes;
4267
fn is_response(&self) -> bool;
4368
fn is_request(&self) -> bool;
4469
fn is_notification(&self) -> bool;
4570
fn is_error(&self) -> bool;
46-
fn request_id(&self) -> Option<&RequestId>;
4771
}
4872

4973
//*******************************//
@@ -94,6 +118,22 @@ pub enum ClientMessage {
94118
Error(JsonrpcError),
95119
}
96120

121+
impl RPCMessage for ClientMessage {
122+
// Retrieves the request ID associated with the message, if applicable
123+
fn request_id(&self) -> Option<&RequestId> {
124+
match self {
125+
// If the message is a request, return the associated request ID
126+
ClientMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id),
127+
// Notifications do not have request IDs
128+
ClientMessage::Notification(_) => None,
129+
// If the message is a response, return the associated request ID
130+
ClientMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id),
131+
// If the message is an error, return the associated request ID
132+
ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id),
133+
}
134+
}
135+
}
136+
97137
// Implementing the `MCPMessage` trait for `ClientMessage`
98138
impl MCPMessage for ClientMessage {
99139
// Returns true if the message is a response type
@@ -116,17 +156,13 @@ impl MCPMessage for ClientMessage {
116156
matches!(self, ClientMessage::Error(_))
117157
}
118158

119-
// Retrieves the request ID associated with the message, if applicable
120-
fn request_id(&self) -> Option<&RequestId> {
159+
/// Determines the type of the message and returns the corresponding `MessageTypes` variant.
160+
fn message_type(&self) -> MessageTypes {
121161
match self {
122-
// If the message is a request, return the associated request ID
123-
ClientMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id),
124-
// Notifications do not have request IDs
125-
ClientMessage::Notification(_) => None,
126-
// If the message is a response, return the associated request ID
127-
ClientMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id),
128-
// If the message is an error, return the associated request ID
129-
ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id),
162+
ClientMessage::Request(_) => MessageTypes::Request,
163+
ClientMessage::Notification(_) => MessageTypes::Notification,
164+
ClientMessage::Response(_) => MessageTypes::Response,
165+
ClientMessage::Error(_) => MessageTypes::Error,
130166
}
131167
}
132168
}
@@ -464,6 +500,22 @@ pub enum ServerMessage {
464500
Error(JsonrpcError),
465501
}
466502

503+
impl RPCMessage for ServerMessage {
504+
// Retrieves the request ID associated with the message, if applicable
505+
fn request_id(&self) -> Option<&RequestId> {
506+
match self {
507+
// If the message is a request, return the associated request ID
508+
ServerMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id),
509+
// Notifications do not have request IDs
510+
ServerMessage::Notification(_) => None,
511+
// If the message is a response, return the associated request ID
512+
ServerMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id),
513+
// If the message is an error, return the associated request ID
514+
ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id),
515+
}
516+
}
517+
}
518+
467519
// Implementing the `MCPMessage` trait for `ServerMessage`
468520
impl MCPMessage for ServerMessage {
469521
// Returns true if the message is a response type
@@ -486,17 +538,13 @@ impl MCPMessage for ServerMessage {
486538
matches!(self, ServerMessage::Error(_))
487539
}
488540

489-
// Retrieves the request ID associated with the message, if applicable
490-
fn request_id(&self) -> Option<&RequestId> {
541+
/// Determines the type of the message and returns the corresponding `MessageTypes` variant.
542+
fn message_type(&self) -> MessageTypes {
491543
match self {
492-
// If the message is a request, return the associated request ID
493-
ServerMessage::Request(client_jsonrpc_request) => Some(&client_jsonrpc_request.id),
494-
// Notifications do not have request IDs
495-
ServerMessage::Notification(_) => None,
496-
// If the message is a response, return the associated request ID
497-
ServerMessage::Response(client_jsonrpc_response) => Some(&client_jsonrpc_response.id),
498-
// If the message is an error, return the associated request ID
499-
ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id),
544+
ServerMessage::Request(_) => MessageTypes::Request,
545+
ServerMessage::Notification(_) => MessageTypes::Notification,
546+
ServerMessage::Response(_) => MessageTypes::Response,
547+
ServerMessage::Error(_) => MessageTypes::Error,
500548
}
501549
}
502550
}

Diff for: src/generated_schema/draft/mcp_schema.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
/// ----------------------------------------------------------------------------
2-
/// This file is auto-generated by mcp-schema-gen v0.1.3.
2+
/// This file is auto-generated by mcp-schema-gen v0.1.4.
33
/// WARNING:
44
/// It is not recommended to modify this file directly. You are free to
55
/// modify or extend the implementations as needed, but please do so at your own risk.
66
///
77
/// Generated from : <https://github.com/modelcontextprotocol/specification.git>
88
/// Hash : 63e1dbb75456b359b9ed8b27d21f4ac68cbb753e
9-
/// Generated at : 2025-02-17 18:04:41
9+
/// Generated at : 2025-02-18 08:30:28
1010
/// ----------------------------------------------------------------------------
1111
///
1212
/// MCP Protocol Version

Diff for: src/generated_schema/draft/schema_utils.rs

+40-1
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,31 @@ use serde_json::{json, Value};
44
use std::hash::{Hash, Hasher};
55
use std::{fmt::Display, str::FromStr};
66

7-
#[derive(Debug)]
7+
#[derive(Debug, PartialEq)]
88
pub enum MessageTypes {
99
Request,
1010
Response,
1111
Notification,
1212
Error,
1313
}
14+
/// Implements the `Display` trait for the `MessageTypes` enum,
15+
/// allowing it to be converted into a human-readable string.
16+
impl Display for MessageTypes {
17+
/// Formats the `MessageTypes` enum variant as a string.
18+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19+
write!(
20+
f,
21+
"{}",
22+
// Match the current enum variant and return a corresponding string
23+
match self {
24+
MessageTypes::Request => "Request",
25+
MessageTypes::Response => "Response",
26+
MessageTypes::Notification => "Notification",
27+
MessageTypes::Error => "Error",
28+
}
29+
)
30+
}
31+
}
1432

1533
/// A utility function used internally to detect the message type from the payload.
1634
/// This function is used when deserializing a `ClientMessage` into strongly-typed structs that represent the specific message received.
@@ -44,6 +62,7 @@ pub trait MCPMessage {
4462
fn is_notification(&self) -> bool;
4563
fn is_error(&self) -> bool;
4664
fn request_id(&self) -> Option<&RequestId>;
65+
fn message_type(&self) -> MessageTypes;
4766
}
4867

4968
//*******************************//
@@ -129,6 +148,16 @@ impl MCPMessage for ClientMessage {
129148
ClientMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id),
130149
}
131150
}
151+
152+
/// Determines the type of the message and returns the corresponding `MessageTypes` variant.
153+
fn message_type(&self) -> MessageTypes {
154+
match self {
155+
ClientMessage::Request(_) => MessageTypes::Request,
156+
ClientMessage::Notification(_) => MessageTypes::Notification,
157+
ClientMessage::Response(_) => MessageTypes::Response,
158+
ClientMessage::Error(_) => MessageTypes::Error,
159+
}
160+
}
132161
}
133162

134163
//**************************//
@@ -499,6 +528,16 @@ impl MCPMessage for ServerMessage {
499528
ServerMessage::Error(jsonrpc_error) => Some(&jsonrpc_error.id),
500529
}
501530
}
531+
532+
/// Determines the type of the message and returns the corresponding `MessageTypes` variant.
533+
fn message_type(&self) -> MessageTypes {
534+
match self {
535+
ServerMessage::Request(_) => MessageTypes::Request,
536+
ServerMessage::Notification(_) => MessageTypes::Notification,
537+
ServerMessage::Response(_) => MessageTypes::Response,
538+
ServerMessage::Error(_) => MessageTypes::Error,
539+
}
540+
}
502541
}
503542

504543
impl FromStr for ServerMessage {

Diff for: tests/miscellaneous.rs

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#[path = "common/common.rs"]
2+
pub mod common;
3+
4+
mod miscellaneous_tests {
5+
use rust_mcp_schema::schema_utils::*;
6+
7+
#[test]
8+
fn test_display_request() {
9+
assert_eq!(MessageTypes::Request.to_string(), "Request");
10+
}
11+
12+
#[test]
13+
fn test_display_response() {
14+
assert_eq!(MessageTypes::Response.to_string(), "Response");
15+
}
16+
17+
#[test]
18+
fn test_display_notification() {
19+
assert_eq!(MessageTypes::Notification.to_string(), "Notification");
20+
}
21+
22+
#[test]
23+
fn test_display_error() {
24+
assert_eq!(MessageTypes::Error.to_string(), "Error");
25+
}
26+
}

Diff for: tests/test_serialize.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ mod test_serialize {
4141
assert!(!message.is_response());
4242
assert!(!message.is_notification());
4343
assert!(!message.is_error());
44+
assert!(message.message_type() == MessageTypes::Request);
4445
assert!(
4546
matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15))
4647
);
@@ -228,6 +229,7 @@ mod test_serialize {
228229
assert!(message.is_response());
229230
assert!(!message.is_notification());
230231
assert!(!message.is_error());
232+
assert!(message.message_type() == MessageTypes::Response);
231233
assert!(
232234
matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15))
233235
);
@@ -279,6 +281,8 @@ mod test_serialize {
279281
assert!(message.is_response());
280282
assert!(!message.is_notification());
281283
assert!(!message.is_error());
284+
assert!(message.message_type() == MessageTypes::Response);
285+
282286
assert!(
283287
matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15))
284288
);
@@ -415,7 +419,7 @@ mod test_serialize {
415419
assert!(!message.is_response());
416420
assert!(message.is_notification());
417421
assert!(!message.is_error());
418-
422+
assert!(message.message_type() == MessageTypes::Notification);
419423
assert!(message.request_id().is_none());
420424

421425
assert!(matches!(message, ClientMessage::Notification(client_message)
@@ -509,7 +513,7 @@ mod test_serialize {
509513
assert!(!message.is_response());
510514
assert!(message.is_notification());
511515
assert!(!message.is_error());
512-
516+
assert!(message.message_type() == MessageTypes::Notification);
513517
assert!(message.request_id().is_none());
514518

515519
assert!(matches!(message, ServerMessage::Notification(client_message)
@@ -560,6 +564,8 @@ mod test_serialize {
560564
assert!(!message.is_response());
561565
assert!(!message.is_notification());
562566
assert!(!message.is_error());
567+
assert!(message.message_type() == MessageTypes::Request);
568+
563569
assert!(
564570
matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15))
565571
);
@@ -609,6 +615,8 @@ mod test_serialize {
609615
assert!(!message.is_response());
610616
assert!(!message.is_notification());
611617
assert!(message.is_error());
618+
assert!(message.message_type() == MessageTypes::Error);
619+
612620
assert!(
613621
matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15))
614622
);
@@ -628,6 +636,8 @@ mod test_serialize {
628636
assert!(!message.is_response());
629637
assert!(!message.is_notification());
630638
assert!(message.is_error());
639+
assert!(message.message_type() == MessageTypes::Error);
640+
631641
assert!(
632642
matches!(message.request_id(), Some(request_id) if matches!(request_id , RequestId::Integer(r) if *r == 15))
633643
);

0 commit comments

Comments
 (0)