diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a756689b..63a29077 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -79,6 +79,11 @@ jobs: args: --features multitenant,analytics,geoblock,functional_tests cache: { sharedKey: "tests" } rustc: stable + - name: "Single-tenant functional tests" + cmd: test + args: --features functional_tests + cache: { sharedKey: "tests" } + rustc: stable include: - os: ubuntu-latest sccache-path: /home/runner/.cache/sccache @@ -138,7 +143,7 @@ jobs: run: | sccache --stop-server || true sccache --start-server - + - name: Install lld and llvm run: sudo apt-get install -y lld llvm diff --git a/Cargo.lock b/Cargo.lock index 14838ee5..5295f501 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,27 @@ version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "async-channel" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" +dependencies = [ + "concurrent-queue", + "event-listener", + "futures-core", +] + [[package]] name = "async-recursion" version = "1.0.5" @@ -892,6 +913,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "concurrent-queue" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "const-oid" version = "0.9.5" @@ -1098,6 +1128,25 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" +[[package]] +name = "deadpool" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "421fe0f90f2ab22016f32a9881be5134fdd71c65298917084b0c7477cbc3856e" +dependencies = [ + "async-trait", + "deadpool-runtime", + "num_cpus", + "retain_mut", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63dfa964fe2a66f3fde91fc70b267fe193d822c7e603e2a675a49a7f46ad3f49" + [[package]] name = "der" version = "0.7.8" @@ -1241,6 +1290,7 @@ dependencies = [ "tracing-subscriber", "uuid", "wc", + "wiremock", ] [[package]] @@ -1511,6 +1561,21 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +[[package]] +name = "futures-lite" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" +dependencies = [ + "fastrand 1.9.0", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + [[package]] name = "futures-macro" version = "0.3.28" @@ -1534,6 +1599,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.28" @@ -1757,6 +1828,27 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" +[[package]] +name = "http-types" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" +dependencies = [ + "anyhow", + "async-channel", + "base64 0.13.1", + "futures-lite", + "http", + "infer", + "pin-project-lite", + "rand 0.7.3", + "serde", + "serde_json", + "serde_qs", + "serde_urlencoded", + "url", +] + [[package]] name = "httparse" version = "1.8.0" @@ -1901,6 +1993,12 @@ dependencies = [ "hashbrown 0.14.1", ] +[[package]] +name = "infer" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" + [[package]] name = "instant" version = "0.1.12" @@ -2495,6 +2593,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.11.2" @@ -3049,6 +3153,12 @@ dependencies = [ "winreg", ] +[[package]] +name = "retain_mut" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0" + [[package]] name = "ring" version = "0.16.20" @@ -3273,6 +3383,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_qs" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -4103,6 +4224,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] @@ -4144,6 +4266,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" +[[package]] +name = "waker-fn" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" + [[package]] name = "want" version = "0.3.1" @@ -4398,6 +4526,28 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "wiremock" +version = "0.5.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "079aee011e8a8e625d16df9e785de30a6b77f80a6126092d76a57375f96448da" +dependencies = [ + "assert-json-diff", + "async-trait", + "base64 0.21.4", + "deadpool", + "futures", + "futures-timer", + "http-types", + "hyper", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "xmlparser" version = "0.13.5" diff --git a/Cargo.toml b/Cargo.toml index 4ae02443..7a557a5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,6 +91,7 @@ ipnet = "2.5" cerberus = { git = "https://github.com/WalletConnect/cerberus.git", tag = "v0.5.0" } async-recursion = "1.0.4" tap = "1.0.1" +wiremock = "0.5.21" [dev-dependencies] serial_test = "1.0" diff --git a/src/handlers/push_message.rs b/src/handlers/push_message.rs index 4be26da7..b2ac425d 100644 --- a/src/handlers/push_message.rs +++ b/src/handlers/push_message.rs @@ -49,13 +49,13 @@ pub struct PushMessageBody { pub async fn handler( #[cfg(feature = "analytics")] SecureClientIp(client_ip): SecureClientIp, - Path((tenant_id, id)): Path<(String, String)>, + Path((tenant_id, client_id)): Path<(String, String)>, StateExtractor(state): StateExtractor>, headers: HeaderMap, RequireValidSignature(Json(body)): RequireValidSignature>, ) -> Result { let res = handler_internal( - Path((tenant_id.clone(), id.clone())), + Path((tenant_id.clone(), client_id.clone())), StateExtractor(state.clone()), headers.clone(), RequireValidSignature(Json(body.clone())), @@ -116,7 +116,7 @@ pub async fn handler( debug!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, ip = %client_ip, "loaded geo data" ); @@ -133,9 +133,9 @@ pub async fn handler( Ok(response) } -#[instrument(name = "push_message_internal", skip_all, fields(tenant_id = tenant_id, client_id = id, id = body.id))] +#[instrument(name = "push_message_internal", skip_all, fields(tenant_id = tenant_id, client_id = client_id, id = body.id))] pub async fn handler_internal( - Path((tenant_id, id)): Path<(String, String)>, + Path((tenant_id, client_id)): Path<(String, String)>, StateExtractor(state): StateExtractor>, headers: HeaderMap, RequireValidSignature(Json(body)): RequireValidSignature>, @@ -146,7 +146,7 @@ pub async fn handler_internal( #[cfg(feature = "analytics")] let (flags, encrypted) = (body.payload.clone().flags, body.payload.is_encrypted()); - let client = match state.client_store.get_client(&tenant_id, &id).await { + let client = match state.client_store.get_client(&tenant_id, &client_id).await { Ok(c) => Ok(c), Err(StoreError::NotFound(_, _)) => Err(ClientNotFound), Err(e) => Err(Store(e)), @@ -161,7 +161,7 @@ pub async fn handler_internal( country: None, continent: None, project_id: tenant_id.clone().into(), - client_id: id.clone().into(), + client_id: client_id.clone().into(), topic: topic.clone(), push_provider: "unknown".into(), encrypted, @@ -182,7 +182,7 @@ pub async fn handler_internal( country: None, continent: None, project_id: tenant_id.clone().into(), - client_id: id.clone().into(), + client_id: client_id.clone().into(), topic, push_provider: client.push_type.as_str().into(), encrypted, @@ -199,14 +199,14 @@ pub async fn handler_internal( increment_counter!(state.metrics, received_notifications); - let id = id + let client_id = client_id .trim_start_matches(DECENTRALIZED_IDENTIFIER_PREFIX) .to_string(); debug!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, "fetched client to send notification" ); @@ -214,7 +214,7 @@ pub async fn handler_internal( warn!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, "client tenant id does not match request tenant id" ); @@ -224,7 +224,7 @@ pub async fn handler_internal( warn!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, "client tenant id has not been set, allowing request to continue" ); } else { @@ -265,13 +265,13 @@ pub async fn handler_internal( if let Ok(notification) = state .notification_store - .get_notification(&body.id, &tenant_id) + .get_notification(&body.id, &client_id, &tenant_id) .await { warn!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, last_recieved_at = %notification.last_received_at, "notification has already been received" @@ -293,7 +293,7 @@ pub async fn handler_internal( let notification = state .notification_store - .create_or_update_notification(&body.id, &tenant_id, &id, &body.payload) + .create_or_update_notification(&body.id, &tenant_id, &client_id, &body.payload) .await .tap_err(|e| warn!("error create_or_update_notification: {e:?}")) .map_err(|e| (Error::Store(e), analytics.clone()))?; @@ -301,7 +301,7 @@ pub async fn handler_internal( info!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, "stored notification", ); @@ -312,7 +312,7 @@ pub async fn handler_internal( warn!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, last_recieved_at = %notification.last_received_at, "notification has already been processed" @@ -341,7 +341,7 @@ pub async fn handler_internal( debug!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, "fetched tenant" ); @@ -358,7 +358,7 @@ pub async fn handler_internal( debug!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, push_type = client.push_type.as_str(), "fetched provider" @@ -372,14 +372,14 @@ pub async fn handler_internal( Error::BadDeviceToken(_) => { state .client_store - .delete_client(&tenant_id, &id) + .delete_client(&tenant_id, &client_id) .await .map_err(|e| (Error::Store(e), analytics.clone()))?; increment_counter!(state.metrics, client_suspensions); warn!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, push_type = client.push_type.as_str(), "client has been deleted due to a bad device token" @@ -396,7 +396,7 @@ pub async fn handler_internal( warn!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, push_type = client.push_type.as_str(), "tenant has been suspended due to invalid provider credentials" @@ -413,7 +413,7 @@ pub async fn handler_internal( warn!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, push_type = client.push_type.as_str(), "tenant has been suspended due to invalid provider credentials" @@ -429,7 +429,7 @@ pub async fn handler_internal( info!( %request_id, %tenant_id, - client_id = %id, + client_id = %client_id, notification_id = %notification.id, push_type = client.push_type.as_str(), "sent notification" diff --git a/src/providers/noop.rs b/src/providers/noop.rs index 0ee0cc05..3e424b7a 100644 --- a/src/providers/noop.rs +++ b/src/providers/noop.rs @@ -1,6 +1,7 @@ use { crate::{handlers::push_message::MessagePayload, providers::PushProvider}, async_trait::async_trait, + reqwest::Url, std::collections::HashMap, tracing::span, }; @@ -32,6 +33,10 @@ impl PushProvider for NoopProvider { let notifications = self.notifications.get_mut(&token).unwrap(); notifications.append(&mut vec![payload]); + if let Ok(url) = token.parse::() { + assert!(reqwest::get(url).await?.status().is_success()); + } + Ok(()) } } diff --git a/src/stores/notification.rs b/src/stores/notification.rs index 9dbdea8e..e876f4ec 100644 --- a/src/stores/notification.rs +++ b/src/stores/notification.rs @@ -29,7 +29,12 @@ pub trait NotificationStore { client_id: &str, payload: &MessagePayload, ) -> stores::Result; - async fn get_notification(&self, id: &str, tenant_id: &str) -> stores::Result; + async fn get_notification( + &self, + id: &str, + client_id: &str, + tenant_id: &str, + ) -> stores::Result; async fn delete_notification(&self, id: &str, tenant_id: &str) -> stores::Result<()>; } @@ -62,11 +67,20 @@ RETURNING *;", } } - async fn get_notification(&self, id: &str, tenant_id: &str) -> stores::Result { + async fn get_notification( + &self, + id: &str, + client_id: &str, + tenant_id: &str, + ) -> stores::Result { let res = sqlx::query_as::( - "SELECT * FROM public.notifications WHERE id = $1 and tenant_id = $2", + " + SELECT * + FROM public.notifications + WHERE id = $1 AND client_id = $2 AND tenant_id = $3", ) .bind(id) + .bind(client_id) .bind(tenant_id) .fetch_one(self) .await; diff --git a/tests/functional/singletenant/push.rs b/tests/functional/singletenant/push.rs index eb490273..fc2714b5 100644 --- a/tests/functional/singletenant/push.rs +++ b/tests/functional/singletenant/push.rs @@ -4,6 +4,7 @@ use { push_message::{MessagePayload, PushMessageBody}, register_client::RegisterBody, }, + hyper::StatusCode, relay_rpc::{ auth::{ ed25519_dalek::Keypair, @@ -13,11 +14,10 @@ use { }, test_context::test_context, uuid::Uuid, + wiremock::{http::Method, matchers::method, Mock, MockServer, ResponseTemplate}, }; -#[test_context(EchoServerContext)] -#[tokio::test] -async fn test_push(ctx: &mut EchoServerContext) { +async fn create_client(ctx: &mut EchoServerContext) -> (ClientId, MockServer) { let mut rng = StdRng::from_entropy(); let keypair = Keypair::generate(&mut rng); @@ -33,10 +33,21 @@ async fn test_push(ctx: &mut EchoServerContext) { .unwrap() .to_string(); + let mock_server = { + let mock_server = MockServer::start().await; + Mock::given(method(Method::Get)) + .respond_with(ResponseTemplate::new(StatusCode::OK)) + .expect(1) + .mount(&mock_server) + .await; + mock_server + }; + let token = mock_server.uri(); + let payload = RegisterBody { client_id: client_id.clone(), push_type: "noop".to_string(), - token: "test".to_string(), + token: token.clone(), }; // Register client @@ -54,6 +65,14 @@ async fn test_push(ctx: &mut EchoServerContext) { "Response was not successful" ); + (client_id, mock_server) +} + +#[test_context(EchoServerContext)] +#[tokio::test] +async fn test_push(ctx: &mut EchoServerContext) { + let (client_id, _mock_server) = create_client(ctx).await; + // Push let push_message_id = Uuid::new_v4().to_string(); let topic = Uuid::new_v4().to_string(); @@ -106,3 +125,58 @@ async fn test_push(ctx: &mut EchoServerContext) { "Response was not successful" ); } + +#[test_context(EchoServerContext)] +#[tokio::test] +async fn test_push_multiple_clients(ctx: &mut EchoServerContext) { + let (client_id1, _mock_server1) = create_client(ctx).await; + let (client_id2, _mock_server2) = create_client(ctx).await; + + // Push + let push_message_id = Uuid::new_v4().to_string(); + let topic = Uuid::new_v4().to_string(); + let blob = Uuid::new_v4().to_string(); + let push_message_payload = MessagePayload { + topic: topic.into(), + blob: blob.to_string(), + flags: 0, + }; + let payload = PushMessageBody { + id: push_message_id.clone(), + payload: push_message_payload, + }; + + // Push client 1 + let client = reqwest::Client::new(); + let response = client + .post(format!( + "http://{}/clients/{}", + ctx.server.public_addr, + client_id1.clone() + )) + .json(&payload) + .send() + .await + .expect("Call failed"); + assert!( + response.status().is_success(), + "Response was not successful" + ); + + // Push client 2 + let client = reqwest::Client::new(); + let response = client + .post(format!( + "http://{}/clients/{}", + ctx.server.public_addr, + client_id2.clone() + )) + .json(&payload) + .send() + .await + .expect("Call failed"); + assert!( + response.status().is_success(), + "Response was not successful" + ); +} diff --git a/tests/functional/stores/client.rs b/tests/functional/stores/client.rs index f17c3011..a420916a 100644 --- a/tests/functional/stores/client.rs +++ b/tests/functional/stores/client.rs @@ -96,7 +96,7 @@ async fn client_upsert_token(ctx: &mut StoreContext) { .unwrap(); let get_notification_result = ctx .notifications - .get_notification(¬ification_id, TENANT_ID) + .get_notification(¬ification_id, &client_id, TENANT_ID) .await .unwrap(); assert_eq!(get_notification_result.client_id, client_id); @@ -159,7 +159,7 @@ async fn client_upsert_id(ctx: &mut StoreContext) { .unwrap(); let get_notification_result = ctx .notifications - .get_notification(¬ification_id, TENANT_ID) + .get_notification(¬ification_id, &client_id, TENANT_ID) .await .unwrap(); assert_eq!(get_notification_result.client_id, client_id); @@ -226,7 +226,7 @@ async fn client_create_same_id_and_token(ctx: &mut StoreContext) { .unwrap(); let get_notification_result = ctx .notifications - .get_notification(¬ification_id, TENANT_ID) + .get_notification(¬ification_id, &client_id, TENANT_ID) .await .unwrap(); assert_eq!(get_notification_result.client_id, client_id); diff --git a/tests/functional/stores/notification.rs b/tests/functional/stores/notification.rs index d30a9492..45731954 100644 --- a/tests/functional/stores/notification.rs +++ b/tests/functional/stores/notification.rs @@ -12,7 +12,7 @@ use { test_context::test_context, }; -pub async fn get_client(client_store: &ClientStoreArc) -> String { +pub async fn create_client(client_store: &ClientStoreArc) -> String { let id = format!("id-{}", gen_id()); let token = format!("token-{}", gen_id()); @@ -30,8 +30,8 @@ pub async fn get_client(client_store: &ClientStoreArc) -> String { #[test_context(StoreContext)] #[tokio::test] -async fn notification_creation(ctx: &mut StoreContext) { - let client_id = get_client(&ctx.clients).await; +async fn notification(ctx: &mut StoreContext) { + let client_id = create_client(&ctx.clients).await; let res = ctx .notifications @@ -42,5 +42,30 @@ async fn notification_creation(ctx: &mut StoreContext) { }) .await; - assert!(res.is_ok()) + assert!(res.is_ok()); +} + +#[test_context(StoreContext)] +#[tokio::test] +async fn notification_multiple_clients_same_payload(ctx: &mut StoreContext) { + let message_id = gen_id(); + let payload = MessagePayload { + topic: String::new(), + flags: 0, + blob: "example-payload".to_string(), + }; + + let client_id = create_client(&ctx.clients).await; + let res = ctx + .notifications + .create_or_update_notification(&message_id, TENANT_ID, &client_id, &payload) + .await; + assert!(res.is_ok()); + + let client_id = create_client(&ctx.clients).await; + let res = ctx + .notifications + .create_or_update_notification(&message_id, TENANT_ID, &client_id, &payload) + .await; + assert!(res.is_ok()); }