-
Notifications
You must be signed in to change notification settings - Fork 432
/
Copy pathws_test_suite.rs
140 lines (123 loc) · 4.74 KB
/
ws_test_suite.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#![cfg(not(windows))]
use std::{net::SocketAddr, sync::Arc};
use anyhow::anyhow;
use axum::{Extension, Router, routing::get};
use futures::{SinkExt, StreamExt};
use juniper::{
EmptyMutation, LocalBoxFuture, RootNode,
http::tests::{WsIntegration, WsIntegrationMessage, graphql_transport_ws, graphql_ws},
tests::fixtures::starwars::schema::{Database, Query, Subscription},
};
use juniper_axum::subscriptions;
use juniper_graphql_ws::ConnectionConfig;
use tokio::{
net::{TcpListener, TcpStream},
time::timeout,
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
/// Schema under test: the Star Wars fixture's query and subscription roots, with no mutations.
type Schema = RootNode<'static, Query, EmptyMutation<Database>, Subscription>;
/// axum application exposing the `/subscriptions` route under test.
///
/// `Clone` so that `WsIntegration::run` (which only gets `&self`) can hand an
/// owned copy to the inherent, `self`-consuming `run`.
#[derive(Clone)]
struct TestApp(Router);
impl TestApp {
    /// Builds the test application for the given subscription `protocol`.
    ///
    /// `"graphql-ws"` mounts the legacy `graphql-ws` handler at `/subscriptions`;
    /// any other value mounts the `graphql-transport-ws` handler there. The
    /// schema is made available to the handlers through an [`Extension`] layer.
    fn new(protocol: &'static str) -> Self {
        let schema = Schema::new(Query, EmptyMutation::new(), Subscription);
        let router = Router::new();
        let router = if protocol == "graphql-ws" {
            router.route(
                "/subscriptions",
                get(subscriptions::graphql_ws::<Arc<Schema>>(
                    ConnectionConfig::new(Database::new()),
                )),
            )
        } else {
            router.route(
                "/subscriptions",
                get(subscriptions::graphql_transport_ws::<Arc<Schema>>(
                    ConnectionConfig::new(Database::new()),
                )),
            )
        };
        Self(router.layer(Extension(Arc::new(schema))))
    }

    /// Serves the app on an ephemeral loopback port, connects a WebSocket
    /// client to `/subscriptions`, and plays `messages` against it in order.
    ///
    /// # Errors
    ///
    /// Returns the first failure encountered: the listener could not be bound,
    /// the client could not connect, a message could not be sent, or a received
    /// message did not match its expectation.
    async fn run(self, messages: Vec<WsIntegrationMessage>) -> Result<(), anyhow::Error> {
        // Bind to loopback only: the server is only reached from this process,
        // and `ws://0.0.0.0:<port>` is not a portable address to connect to.
        let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .map_err(|e| anyhow!("Could not bind test listener: {e}"))?;
        let addr = listener
            .local_addr()
            .map_err(|e| anyhow!("Could not read test listener address: {e}"))?;

        let server = tokio::spawn(async move {
            axum::serve(listener, self.0).await.unwrap();
        });

        let result = Self::exchange(addr, messages).await;

        // `run` may be called repeatedly on the same runtime; stop this server
        // rather than leaving it accepting connections until the runtime exits.
        server.abort();
        result
    }

    /// Connects to the server at `addr` and processes `messages` sequentially,
    /// stopping at the first failure.
    async fn exchange(
        addr: SocketAddr,
        messages: Vec<WsIntegrationMessage>,
    ) -> Result<(), anyhow::Error> {
        let (mut websocket, _) = connect_async(format!("ws://{addr}/subscriptions"))
            .await
            .map_err(|e| anyhow!("Could not connect to ws://{addr}/subscriptions: {e}"))?;
        for msg in messages {
            Self::process_message(&mut websocket, msg).await?;
        }
        Ok(())
    }

    /// Performs a single scripted step on `websocket`.
    ///
    /// `Send` serializes the JSON value and sends it as a text frame. `Expect`
    /// waits up to the given duration for the next frame. A text frame must
    /// parse as JSON equal to the expected value. A close frame with a payload
    /// is compared as `{"code": <u16>, "description": <reason>}`. Any other
    /// outcome is an error.
    async fn process_message(
        websocket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
        message: WsIntegrationMessage,
    ) -> Result<(), anyhow::Error> {
        match message {
            WsIntegrationMessage::Send(msg) => websocket
                .send(Message::Text(msg.to_string().into()))
                .await
                .map_err(|e| anyhow!("Could not send message: {e}")),
            WsIntegrationMessage::Expect(expected, duration) => {
                let received = timeout(duration, websocket.next()).await.map_err(|e| {
                    anyhow!("Timed out receiving message after {duration:?}. Elapsed: {e}")
                })?;
                match received {
                    None => Err(anyhow!("No message received")),
                    Some(Err(e)) => Err(anyhow!("WebSocket error: {e}")),
                    Some(Ok(Message::Text(json))) => {
                        let actual: serde_json::Value = serde_json::from_str(&json)
                            .map_err(|e| anyhow!("Cannot deserialize received message: {e}"))?;
                        Self::ensure_expected(&expected, &actual)
                    }
                    Some(Ok(Message::Close(Some(frame)))) => {
                        let actual = serde_json::json!({
                            "code": u16::from(frame.code),
                            "description": *frame.reason,
                        });
                        Self::ensure_expected(&expected, &actual)
                    }
                    Some(Ok(msg)) => Err(anyhow!("Received non-text message: {msg:?}")),
                }
            }
        }
    }

    /// Returns an error describing both values unless `actual == expected`.
    fn ensure_expected(
        expected: &serde_json::Value,
        actual: &serde_json::Value,
    ) -> Result<(), anyhow::Error> {
        if actual == expected {
            Ok(())
        } else {
            Err(anyhow!(
                "Expected message: {expected}. \
                 Received message: {actual}",
            ))
        }
    }
}
impl WsIntegration for TestApp {
    /// Adapts the shared-suite entry point to the inherent `TestApp::run`.
    ///
    /// The inherent method consumes the app, so an owned clone is handed to it
    /// and the resulting future is boxed and pinned for the trait signature.
    fn run(
        &self,
        messages: Vec<WsIntegrationMessage>,
    ) -> LocalBoxFuture<Result<(), anyhow::Error>> {
        let owned_app = self.clone();
        let pending = owned_app.run(messages);
        Box::pin(pending)
    }
}
/// Runs the shared `graphql-ws` protocol suite against the axum integration.
#[tokio::test]
async fn test_graphql_ws_integration() {
    graphql_ws::run_test_suite(&TestApp::new("graphql-ws")).await;
}
/// Runs the shared `graphql-transport-ws` protocol suite against the axum integration.
#[tokio::test]
async fn test_graphql_transport_integration() {
    graphql_transport_ws::run_test_suite(&TestApp::new("graphql-transport-ws")).await;
}