From caeea65f181c2d2108517a8e9268be174fa3fd91 Mon Sep 17 00:00:00 2001 From: Rob Date: Tue, 4 Jun 2024 17:04:50 -0400 Subject: [PATCH 01/31] delegate sending --- cdn-broker/benches/broadcast.rs | 2 +- cdn-broker/benches/direct.rs | 2 +- cdn-broker/src/connections/mod.rs | 67 +----- cdn-broker/src/lib.rs | 2 +- cdn-broker/src/tasks/broker/handler.rs | 36 ++-- cdn-broker/src/tasks/broker/sender.rs | 58 ++--- cdn-broker/src/tasks/broker/sync.rs | 27 ++- cdn-broker/src/tasks/user/handler.rs | 16 +- cdn-broker/src/tasks/user/sender.rs | 53 +---- cdn-broker/src/tests/broadcast.rs | 2 +- cdn-broker/src/tests/direct.rs | 2 +- cdn-broker/src/tests/mod.rs | 26 +-- cdn-client/src/lib.rs | 10 + cdn-client/src/retry.rs | 29 ++- cdn-marshal/src/handlers.rs | 10 +- cdn-proto/benches/protocols.rs | 6 +- cdn-proto/src/connection/auth/broker.rs | 10 +- cdn-proto/src/connection/auth/marshal.rs | 6 +- cdn-proto/src/connection/auth/user.rs | 8 +- cdn-proto/src/connection/protocols/memory.rs | 192 ++++------------- cdn-proto/src/connection/protocols/mod.rs | 211 +++++++++++++++---- cdn-proto/src/connection/protocols/quic.rs | 104 +-------- cdn-proto/src/connection/protocols/tcp.rs | 113 ++-------- cdn-proto/src/def.rs | 1 - tests/src/tests/double_connect.rs | 14 +- tests/src/tests/subscribe.rs | 3 +- 26 files changed, 393 insertions(+), 617 deletions(-) diff --git a/cdn-broker/benches/broadcast.rs b/cdn-broker/benches/broadcast.rs index 4688303..51f1336 100644 --- a/cdn-broker/benches/broadcast.rs +++ b/cdn-broker/benches/broadcast.rs @@ -5,7 +5,7 @@ use std::time::Duration; use cdn_broker::reexports::tests::{TestDefinition, TestRun}; use cdn_broker::{assert_received, send_message_as}; -use cdn_proto::connection::{protocols::Connection as _, Bytes}; +use cdn_proto::connection::Bytes; use cdn_proto::def::TestTopic; use cdn_proto::message::{Broadcast, Message}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; diff --git a/cdn-broker/benches/direct.rs b/cdn-broker/benches/direct.rs index d94edbe..a5b9e8e 100644 --- a/cdn-broker/benches/direct.rs +++ b/cdn-broker/benches/direct.rs @@ -5,7 +5,7 @@ use std::time::Duration; use cdn_broker::reexports::tests::{TestDefinition, TestRun}; use cdn_broker::{assert_received, send_message_as}; -use cdn_proto::connection::{protocols::Connection as _, Bytes}; +use cdn_proto::connection::Bytes; use cdn_proto::def::TestTopic; use cdn_proto::message::{Direct, Message}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; diff --git a/cdn-broker/src/connections/mod.rs b/cdn-broker/src/connections/mod.rs index b6049e3..f7e1ff2 100644 --- a/cdn-broker/src/connections/mod.rs +++ b/cdn-broker/src/connections/mod.rs @@ -4,8 +4,7 @@ use std::collections::{HashMap, HashSet}; use cdn_proto::{ - connection::UserPublicKey, - def::{Connection, RunDef}, + connection::{protocols::Connection, UserPublicKey}, discovery::BrokerIdentifier, message::Topic, mnemonic, @@ -23,14 +22,14 @@ mod direct; type TaskMap = HashMap; -pub struct Connections { +pub struct Connections { // Our identity. Used for versioned vector conflict resolution. identity: BrokerIdentifier, // The current users connected to us, along with their running tasks - users: HashMap, TaskMap)>, + users: HashMap, // The current brokers connected to us, along with their running tasks - brokers: HashMap, TaskMap)>, + brokers: HashMap, // The versioned vector for looking up where direct messages should go direct_map: DirectMap, @@ -38,7 +37,7 @@ pub struct Connections { broadcast_map: BroadcastMap, } -impl Connections { +impl Connections { /// Create a new `Connections`. Requires an identity for /// version vector conflict resolution. pub fn new(identity: BrokerIdentifier) -> Self { @@ -60,12 +59,12 @@ impl Connections { pub fn get_broker_connection( &self, broker_identifier: &BrokerIdentifier, - ) -> Option> { + ) -> Option { self.brokers.get(broker_identifier).map(|(c, _)| c.clone()) } /// Get the connection for a given user public key (cloned) - pub fn get_user_connection(&self, user: &UserPublicKey) -> Option> { + pub fn get_user_connection(&self, user: &UserPublicKey) -> Option { self.users.get(user).map(|(c, _)| c.clone()) } @@ -74,54 +73,6 @@ impl Connections { self.brokers.keys().cloned().collect() } - /// Add a task to the list of tasks for a broker along with a unique ID - /// This is used to cancel the task if the broker disconnects. - pub fn add_broker_task( - &mut self, - broker_identifier: &BrokerIdentifier, - id: u128, - handle: AbortHandle, - ) { - if let Some((_, handles)) = self.brokers.get_mut(broker_identifier) { - // If the broker exists, add the handle to the map of tasks - handles.insert(id, handle); - } else { - // Otherwise, cancel the task - handle.abort(); - } - } - - /// Add a task to the list of tasks for a user along with a unique ID - /// This is used to cancel the task if the user disconnects. - /// TODO: macro this? - pub fn add_user_task(&mut self, user: &UserPublicKey, id: u128, handle: AbortHandle) { - if let Some((_, handles)) = self.users.get_mut(user) { - // If the user exists, add the handle to the map of tasks - handles.insert(id, handle); - } else { - // Otherwise, cancel the task - handle.abort(); - } - } - - /// Remove a task from the list of tasks for a broker. - /// Does not abort the task. - pub fn remove_broker_task(&mut self, broker_identifier: &BrokerIdentifier, id: u128) { - if let Some((_, handles)) = self.brokers.get_mut(broker_identifier) { - // If the broker exists, remove the handle from the map of tasks - handles.remove(&id); - } - } - - /// Remove a task from the list of tasks for a user. - /// Does not abort the task. - pub fn remove_user_task(&mut self, user: &UserPublicKey, id: u128) { - if let Some((_, handles)) = self.users.get_mut(user) { - // If the broker exists, remove the handle from the map of tasks - handles.remove(&id); - } - } - /// Get all users and brokers interested in a list of topics. pub fn get_interested_by_topic( &self, @@ -222,7 +173,7 @@ impl Connections { pub fn add_broker( &mut self, broker_identifier: BrokerIdentifier, - connection: Connection, + connection: Connection, handle: AbortHandle, ) { // Increment the metric for the number of brokers connected @@ -243,7 +194,7 @@ impl Connections { pub fn add_user( &mut self, user_public_key: &UserPublicKey, - connection: Connection, + connection: Connection, topics: &[Topic], handle: AbortHandle, ) { diff --git a/cdn-broker/src/lib.rs b/cdn-broker/src/lib.rs index 4f7edcc..1da443d 100644 --- a/cdn-broker/src/lib.rs +++ b/cdn-broker/src/lib.rs @@ -75,7 +75,7 @@ struct Inner { /// The connections that currently exist. We use this everywhere we need to update connection /// state or send messages. - connections: Arc>>, + connections: Arc>, } /// The main `Broker` struct. We instantiate this when we want to run a broker. diff --git a/cdn-broker/src/tasks/broker/handler.rs b/cdn-broker/src/tasks/broker/handler.rs index 8f5e41f..b5b3905 100644 --- a/cdn-broker/src/tasks/broker/handler.rs +++ b/cdn-broker/src/tasks/broker/handler.rs @@ -4,8 +4,8 @@ use std::{sync::Arc, time::Duration}; use cdn_proto::{ authenticate_with_broker, bail, - connection::{auth::broker::BrokerAuth, protocols::Connection as _, Bytes, UserPublicKey}, - def::{Connection, RunDef}, + connection::{auth::broker::BrokerAuth, protocols::Connection, Bytes, UserPublicKey}, + def::RunDef, discovery::BrokerIdentifier, error::{Error, Result}, message::{Message, Topic}, @@ -20,7 +20,7 @@ impl Inner { /// This function is the callback for handling a broker (private) connection. pub async fn handle_broker_connection( self: Arc, - mut connection: Connection, + mut connection: Connection, is_outbound: bool, ) { // Depending on which way the direction came in, we will want to authenticate with a different @@ -80,13 +80,13 @@ impl Inner { .abort_handle(); // Send a full topic sync - if let Err(err) = self.full_topic_sync(&broker_identifier) { + if let Err(err) = self.full_topic_sync(&broker_identifier).await { error!("failed to perform full topic sync: {err}"); return; }; // Send a full user sync - if let Err(err) = self.full_user_sync(&broker_identifier) { + if let Err(err) = self.full_user_sync(&broker_identifier).await { error!("failed to perform full user sync: {err}"); return; }; @@ -94,10 +94,10 @@ impl Inner { // If we have `strong-consistency` enabled, send partials #[cfg(feature = "strong-consistency")] { - if let Err(err) = self.partial_topic_sync() { + if let Err(err) = self.partial_topic_sync().await { error!("failed to perform partial topic sync: {err}"); } - if let Err(err) = self.partial_user_sync() { + if let Err(err) = self.partial_user_sync().await { error!("failed to perform partial user sync: {err}"); } } @@ -112,7 +112,7 @@ impl Inner { pub async fn broker_receive_loop( self: &Arc, broker_identifier: &BrokerIdentifier, - connection: Connection, + connection: Connection, ) -> Result<()> { loop { // Receive a message from the broker @@ -126,14 +126,16 @@ impl Inner { Message::Direct(ref direct) => { let user_public_key = UserPublicKey::from(direct.recipient.clone()); - self.handle_direct_message(&user_public_key, raw_message, true); + self.handle_direct_message(&user_public_key, raw_message, true) + .await; } // If we receive a broadcast message from a broker, we want to send it to all interested users Message::Broadcast(ref broadcast) => { let topics = broadcast.topics.clone(); - self.handle_broadcast_message(&topics, &raw_message, true); + self.handle_broadcast_message(&topics, &raw_message, true) + .await; } // If we receive a subscribe message from a broker, we add them as "interested" locally. @@ -169,7 +171,7 @@ impl Inner { } /// This function handles direct messages from users and brokers. - pub fn handle_direct_message( + pub async fn handle_direct_message( self: &Arc, user_public_key: &UserPublicKey, message: Bytes, @@ -193,7 +195,7 @@ impl Inner { ); // Send the message to the user - self.try_send_to_user(user_public_key, message); + self.try_send_to_user(user_public_key, message).await; } else { // Otherwise, send the message to the broker (but only if we are not told to send to the user only) if !to_user_only { @@ -205,14 +207,14 @@ impl Inner { ); // Send the message to the broker - self.try_send_to_broker(&broker_identifier, message); + self.try_send_to_broker(&broker_identifier, message).await; } } } } /// This function handles broadcast messages from users and brokers. - pub fn handle_broadcast_message( + pub async fn handle_broadcast_message( self: &Arc, topics: &[Topic], message: &Bytes, @@ -235,12 +237,14 @@ impl Inner { // Send the message to all interested brokers for broker_identifier in interested_brokers { - self.try_send_to_broker(&broker_identifier, message.clone()); + self.try_send_to_broker(&broker_identifier, message.clone()) + .await; } // Send the message to all interested users for user_public_key in interested_users { - self.try_send_to_user(&user_public_key, message.clone()); + self.try_send_to_user(&user_public_key, message.clone()) + .await; } } } diff --git a/cdn-broker/src/tasks/broker/sender.rs b/cdn-broker/src/tasks/broker/sender.rs index 1ddb23a..9f11faa 100644 --- a/cdn-broker/src/tasks/broker/sender.rs +++ b/cdn-broker/src/tasks/broker/sender.rs @@ -1,9 +1,6 @@ use std::sync::Arc; -use cdn_proto::connection::protocols::Connection; use cdn_proto::{connection::Bytes, def::RunDef, discovery::BrokerIdentifier}; -use tokio::spawn; -use tokio::sync::Notify; use tracing::error; use crate::Inner; @@ -11,7 +8,7 @@ use crate::Inner; impl Inner { /// Attempts to asynchronously send a message to a broker. /// If it fails, the broker is removed from the list of connections. - pub fn try_send_to_broker( + pub async fn try_send_to_broker( self: &Arc, broker_identifier: &BrokerIdentifier, message: Bytes, @@ -28,56 +25,29 @@ impl Inner { let self_ = self.clone(); let broker_identifier_ = broker_identifier.clone(); - // Create a random handle identifier - let handle_identifier = rand::random::(); - - // To notify the sender when the task has been added - let notify = Arc::new(Notify::const_new()); - let notified = notify.clone(); - // Send the message - let send_handle = spawn(async move { - if let Err(e) = connection.send_message_raw(message).await { - error!("failed to send message to broker: {:?}", e); - - // Remove the broker if we failed to send the message - self_ - .connections - .write() - .remove_broker(&broker_identifier_, "failed to send message"); - } else { - // Wait for the sender to add the task to the list - notified.notified().await; - - // If we successfully sent the message, remove the task from the list - self_ - .connections - .write() - .remove_broker_task(&broker_identifier_, handle_identifier); - }; - }) - .abort_handle(); - - // Add the send handle to the list of tasks for the broker - self.connections.write().add_broker_task( - broker_identifier, - handle_identifier, - send_handle, - ); - - // Notify the sender that the task has been added - notify.notify_one(); + if let Err(e) = connection.send_message_raw(message).await { + error!("failed to send message to broker: {:?}", e); + + // Remove the broker if we failed to send the message + self_ + .connections + .write() + .remove_broker(&broker_identifier_, "failed to send message"); + } } } /// Attempts to asynchronously send a message to all brokers. /// If it fails, the failing broker is removed from the list of connections. - pub fn try_send_to_brokers(self: &Arc, message: &Bytes) { + pub async fn try_send_to_brokers(self: &Arc, message: &Bytes) { // Get the optional connection let brokers = self.connections.read().get_broker_identifiers(); for broker in brokers { - self.clone().try_send_to_broker(&broker, message.clone()); + self.clone() + .try_send_to_broker(&broker, message.clone()) + .await; } } } diff --git a/cdn-broker/src/tasks/broker/sync.rs b/cdn-broker/src/tasks/broker/sync.rs index 3764592..ee6c628 100644 --- a/cdn-broker/src/tasks/broker/sync.rs +++ b/cdn-broker/src/tasks/broker/sync.rs @@ -40,12 +40,13 @@ impl Inner { /// # Errors /// - If we fail to serialize the message /// - If we fail to send the message - pub fn full_user_sync(self: &Arc, broker: &BrokerIdentifier) -> Result<()> { + pub async fn full_user_sync(self: &Arc, broker: &BrokerIdentifier) -> Result<()> { // Get full user sync map let full_sync_map = self.connections.read().get_full_user_sync(); // Serialize and send the message to the broker - self.try_send_to_broker(broker, prepare_sync_message!(full_sync_map)); + self.try_send_to_broker(broker, prepare_sync_message!(full_sync_map)) + .await; Ok(()) } @@ -55,7 +56,7 @@ impl Inner { /// /// # Errors /// - If we fail to serialize the message - pub fn partial_user_sync(self: &Arc) -> Result<()> { + pub async fn partial_user_sync(self: &Arc) -> Result<()> { // Get partial user sync map let partial_sync_map = self.connections.write().get_partial_user_sync(); @@ -68,7 +69,7 @@ impl Inner { let raw_message = prepare_sync_message!(partial_sync_map); // Send to all brokers - self.try_send_to_brokers(&raw_message); + self.try_send_to_brokers(&raw_message).await; Ok(()) } @@ -78,7 +79,10 @@ impl Inner { /// /// # Errors /// - if we fail to serialize the message - pub fn full_topic_sync(self: &Arc, broker_identifier: &BrokerIdentifier) -> Result<()> { + pub async fn full_topic_sync( + self: &Arc, + broker_identifier: &BrokerIdentifier, + ) -> Result<()> { // Get full list of topics let topics = self.connections.read().get_full_topic_sync(); @@ -90,7 +94,8 @@ impl Inner { Serialize, "failed to serialize topics" )), - ); + ) + .await; Ok(()) } @@ -100,7 +105,7 @@ impl Inner { /// /// # Errors /// - If we fail to serialize the message - pub fn partial_topic_sync(self: &Arc) -> Result<()> { + pub async fn partial_topic_sync(self: &Arc) -> Result<()> { // Get partial list of topics let (additions, removals) = self.connections.write().get_partial_topic_sync(); @@ -114,7 +119,7 @@ impl Inner { )); // Send to all brokers - self.try_send_to_brokers(&raw_subscribe_message); + self.try_send_to_brokers(&raw_subscribe_message).await; } // If we have some removals, @@ -127,7 +132,7 @@ impl Inner { )); // Send to all brokers - self.try_send_to_brokers(&raw_unsubscribe_message); + self.try_send_to_brokers(&raw_unsubscribe_message).await; } Ok(()) @@ -138,12 +143,12 @@ impl Inner { pub async fn run_sync_task(self: Arc) { loop { // Perform user sync - if let Err(err) = self.partial_user_sync() { + if let Err(err) = self.partial_user_sync().await { error!("failed to perform partial user sync: {err}"); }; // Perform topic sync - if let Err(err) = self.partial_topic_sync() { + if let Err(err) = self.partial_topic_sync().await { error!("failed to perform partial topic sync: {err}"); }; diff --git a/cdn-broker/src/tasks/user/handler.rs b/cdn-broker/src/tasks/user/handler.rs index 2315bb6..13775e4 100644 --- a/cdn-broker/src/tasks/user/handler.rs +++ b/cdn-broker/src/tasks/user/handler.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use std::time::Duration; -use cdn_proto::connection::{protocols::Connection as _, UserPublicKey}; -use cdn_proto::def::{Connection, RunDef, Topic as _}; +use cdn_proto::connection::{protocols::Connection, UserPublicKey}; +use cdn_proto::def::{RunDef, Topic as _}; use cdn_proto::error::{Error, Result}; use cdn_proto::{connection::auth::broker::BrokerAuth, message::Message, mnemonic}; use tokio::spawn; @@ -16,7 +16,7 @@ use crate::Inner; impl Inner { /// This function handles a user (public) connection. - pub async fn handle_user_connection(self: Arc, connection: Connection) { + pub async fn handle_user_connection(self: Arc, connection: Connection) { // Verify (authenticate) the connection. Needs to happen within 5 seconds // TODO: make this stateless (e.g. separate subscribe message on connect) let Ok(Ok((public_key, mut topics))) = timeout( @@ -72,12 +72,12 @@ impl Inner { #[cfg(feature = "strong-consistency")] { // Send partial topic data - if let Err(err) = self.partial_topic_sync() { + if let Err(err) = self.partial_topic_sync().await { error!("failed to perform partial topic sync: {err}"); } // Send partial user data - if let Err(err) = self.partial_user_sync() { + if let Err(err) = self.partial_user_sync().await { error!("failed to perform partial user sync: {err}"); } } @@ -88,7 +88,7 @@ impl Inner { pub async fn user_receive_loop( self: &Arc, public_key: &UserPublicKey, - connection: Connection, + connection: Connection, ) -> Result<()> { loop { // Receive a message from the user @@ -102,7 +102,7 @@ impl Inner { Message::Direct(ref direct) => { let user_public_key = UserPublicKey::from(direct.recipient.clone()); - self.handle_direct_message(&user_public_key, raw_message, false); + self.handle_direct_message(&user_public_key, raw_message, false).await; } // If we get a broadcast message from a user, send it to both brokers and users. @@ -111,7 +111,7 @@ impl Inner { let mut topics = broadcast.topics.clone(); Def::Topic::prune(&mut topics)?; - self.handle_broadcast_message(&topics, &raw_message, false); + self.handle_broadcast_message(&topics, &raw_message, false).await; } // Subscribe messages from users will just update the state locally diff --git a/cdn-broker/src/tasks/user/sender.rs b/cdn-broker/src/tasks/user/sender.rs index 0353217..7fd2058 100644 --- a/cdn-broker/src/tasks/user/sender.rs +++ b/cdn-broker/src/tasks/user/sender.rs @@ -1,62 +1,27 @@ use std::sync::Arc; -use cdn_proto::connection::protocols::Connection; use cdn_proto::connection::UserPublicKey; use cdn_proto::{connection::Bytes, def::RunDef}; -use tokio::spawn; -use tokio::sync::Notify; use tracing::error; use crate::Inner; impl Inner { - pub fn try_send_to_user(self: &Arc, user: &UserPublicKey, message: Bytes) { + pub async fn try_send_to_user(self: &Arc, user: &UserPublicKey, message: Bytes) { // Get the optional connection let connection = self.connections.read().get_user_connection(user); // If the connection exists, if let Some(connection) = connection { - // Clone what we need - let self_ = self.clone(); - let user_ = user.clone(); - - // Create a random handle identifier - let handle_identifier = rand::random::(); - - // To notify the sender when the task has been added - let notify = Arc::new(Notify::const_new()); - let notified = notify.clone(); - // Send the message - let send_handle = spawn(async move { - if let Err(e) = connection.send_message_raw(message).await { - error!("failed to send message to user: {:?}", e); - - // Remove the broker if we failed to send the message - self_ - .connections - .write() - .remove_user(user_, "failed to send message"); - } else { - // Wait for the sender to add the task to the list - notified.notified().await; - - // If we successfully sent the message, remove the task from the list - self_ - .connections - .write() - .remove_user_task(&user_, handle_identifier); - }; - }) - .abort_handle(); - - // Add the send handle to the list of tasks for the broker - self.connections - .write() - .add_user_task(user, handle_identifier, send_handle); - - // Notify the sender that the task has been added - notify.notify_one(); + if let Err(e) = connection.send_message_raw(message).await { + error!("failed to send message to user: {:?}", e); + + // Remove the broker if we failed to send the message + self.connections + .write() + .remove_user(user.clone(), "failed to send message"); + } } } } diff --git a/cdn-broker/src/tests/broadcast.rs b/cdn-broker/src/tests/broadcast.rs index c2d867b..92eca01 100644 --- a/cdn-broker/src/tests/broadcast.rs +++ b/cdn-broker/src/tests/broadcast.rs @@ -4,7 +4,7 @@ use std::time::Duration; use cdn_proto::{ - connection::{protocols::Connection, Bytes}, + connection::Bytes, def::TestTopic, message::{Broadcast, Message}, }; diff --git a/cdn-broker/src/tests/direct.rs b/cdn-broker/src/tests/direct.rs index de2f166..15b503f 100644 --- a/cdn-broker/src/tests/direct.rs +++ b/cdn-broker/src/tests/direct.rs @@ -4,7 +4,7 @@ use std::time::Duration; use cdn_proto::{ - connection::{protocols::Connection, Bytes}, + connection::Bytes, def::TestTopic, message::{Direct, Message}, }; diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index c58c28d..b0e059f 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -5,10 +5,7 @@ use std::sync::Arc; use cdn_proto::{ - connection::protocols::{ - memory::{Memory, MemoryConnection}, - Connection, - }, + connection::protocols::{memory::Memory, Connection}, crypto::{rng::DeterministicRng, signature::KeyPair}, def::TestingRunDef, discovery::BrokerIdentifier, @@ -29,9 +26,9 @@ use crate::{connections::DirectMap, Broker, Config}; /// An actor is a [user/broker] that we inject to test message send functionality. pub struct InjectedActor { /// The in-memory sender that sends to the broker under test - pub sender: MemoryConnection, + pub sender: Connection, /// The in-memory receiver that receives from the broker under test - pub receiver: MemoryConnection, + pub receiver: Connection, } /// This lets us send a message as a particular network actor. It just helps @@ -59,7 +56,9 @@ macro_rules! assert_received { // Make sure we haven't received this message (no, $actor: expr) => { assert!( - $actor.receiver.receiver.0.is_empty(), + timeout(Duration::from_millis(100), $actor.receiver.recv_message()) + .await + .is_err(), "wasn't supposed to receive a message but did" ) }; @@ -67,20 +66,23 @@ macro_rules! assert_received { // Make sure we have received the message in a timeframe of 50ms (yes, $actor: expr, $message:expr) => { // Receive the message with a timeout - let Ok(message) = - timeout(Duration::from_millis(50), $actor.receiver.receiver.0.recv()).await + let Ok(message) = timeout( + Duration::from_millis(50), + $actor.receiver.recv_message_raw(), + ) + .await else { panic!("timed out trying to receive message"); }; // Assert the message is the correct one assert!( - message - == Ok(Bytes::from_unchecked( + message.unwrap() + == Bytes::from_unchecked( $message .serialize() .expect("failed to re-serialize message") - )), + ), "was supposed to receive a message but did not" ) }; diff --git a/cdn-client/src/lib.rs b/cdn-client/src/lib.rs index d901ddd..144af47 100644 --- a/cdn-client/src/lib.rs +++ b/cdn-client/src/lib.rs @@ -160,4 +160,14 @@ impl Client { pub async fn send_message(&self, message: Message) -> Result<()> { self.0.send_message(message).await } + + /// Flushes the connection, ensuring that all messages are sent. + /// This is useful for ensuring that messages are sent before a + /// connection is closed. + /// + /// # Errors + /// - if the connection is already closed + pub async fn flush(&self) -> Result<()> { + self.0.flush().await + } } diff --git a/cdn-client/src/retry.rs b/cdn-client/src/retry.rs index 59843e7..d4ecb15 100644 --- a/cdn-client/src/retry.rs +++ b/cdn-client/src/retry.rs @@ -10,10 +10,10 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; use cdn_proto::{ connection::{ auth::user::UserAuth, - protocols::{Connection as _, Protocol as _}, + protocols::{Connection, Protocol as _}, }, crypto::signature::KeyPair, - def::{Connection, ConnectionDef, Protocol, Scheme}, + def::{ConnectionDef, Protocol, Scheme}, error::{Error, Result}, message::{Message, Topic}, }; @@ -45,7 +45,7 @@ pub struct Inner { use_local_authority: bool, /// The underlying connection - connection: Arc>>>, + connection: Arc>>, /// The keypair to use when authenticating pub keypair: KeyPair>, @@ -63,7 +63,7 @@ impl Inner { /// # Errors /// - If the connection failed /// - If authentication failed - async fn connect(self: &Arc) -> Result> { + async fn connect(self: &Arc) -> Result { // Make the connection to the marshal let connection = bail!( Protocol::::connect(&self.endpoint, self.use_local_authority).await, @@ -263,4 +263,25 @@ impl Retry { // If we failed to receive a message, kick off reconnection logic try_with_reconnect!(self, out) } + + /// Flushes the connection, ensuring that all messages are sent. + /// + /// # Errors + /// - If we are in the middle of reconnecting + /// - If the connection is closed + pub async fn flush(&self) -> Result<()> { + // Check if we're (probably) reconnecting or not + if let Ok(connection_guard) = self.inner.connection.try_read() { + // We're not reconnecting, try to send the message + // Initialize the connection if it does not yet exist + connection_guard + .get_or_try_init(|| self.inner.connect()) + .await? + .flush() + .await + } else { + // We are reconnecting, return an error + Err(Error::Connection("reconnection in progress".to_string())) + } + } } diff --git a/cdn-marshal/src/handlers.rs b/cdn-marshal/src/handlers.rs index 4872e19..c6b4a3a 100644 --- a/cdn-marshal/src/handlers.rs +++ b/cdn-marshal/src/handlers.rs @@ -1,8 +1,8 @@ use std::time::Duration; use cdn_proto::{ - connection::{auth::marshal::MarshalAuth, protocols::Connection as _}, - def::{Connection, RunDef}, + connection::{auth::marshal::MarshalAuth, protocols::Connection}, + def::RunDef, mnemonic, }; use tokio::time::timeout; @@ -13,7 +13,7 @@ use crate::Marshal; impl Marshal { /// Handles a user's connection, including authentication. pub async fn handle_connection( - connection: Connection, + connection: Connection, mut discovery_client: R::DiscoveryClientType, ) { // Verify (authenticate) the connection @@ -26,7 +26,7 @@ impl Marshal { info!(id = mnemonic(&user_public_key), "user authenticated"); } - // Finish the connection, ensuring all data was sent - connection.finish().await; + // Flush the connection, ensuring all data was sent + let _ = connection.flush().await; } } diff --git a/cdn-proto/benches/protocols.rs b/cdn-proto/benches/protocols.rs index c14556b..a1e86d8 100644 --- a/cdn-proto/benches/protocols.rs +++ b/cdn-proto/benches/protocols.rs @@ -16,8 +16,8 @@ use tokio::{join, runtime::Runtime, spawn}; /// Transfer a message `raw_message` from `conn1` to `conn2.` This is the primary /// function used for testing network protocol speed. async fn transfer>( - conn1: Proto::Connection, - conn2: Proto::Connection, + conn1: Connection, + conn2: Connection, raw_message: Bytes, ) { // Send from the first connection @@ -44,7 +44,7 @@ async fn transfer>( /// to test. fn set_up_bench>( message_size: usize, -) -> (Runtime, Proto::Connection, Proto::Connection, Bytes) { +) -> (Runtime, Connection, Connection, Bytes) { // Create new tokio runtime let benchmark_runtime = tokio::runtime::Runtime::new().expect("failed to create Tokio runtime"); diff --git a/cdn-proto/src/connection/auth/broker.rs b/cdn-proto/src/connection/auth/broker.rs index 4fd1703..5776558 100644 --- a/cdn-proto/src/connection/auth/broker.rs +++ b/cdn-proto/src/connection/auth/broker.rs @@ -9,9 +9,9 @@ use tracing::error; use crate::{ bail, - connection::protocols::Connection as _, + connection::protocols::Connection, crypto::signature::SignatureScheme, - def::{Connection, PublicKey, RunDef, Scheme}, + def::{PublicKey, RunDef, Scheme}, discovery::{BrokerIdentifier, DiscoveryClient}, error::{Error, Result}, fail_verification_with_message, @@ -69,7 +69,7 @@ impl BrokerAuth { /// - If authentication fails /// - If our connection fails pub async fn verify_user( - connection: &Connection, + connection: &Connection, #[cfg(not(feature = "global-permits"))] broker_identifier: &BrokerIdentifier, discovery_client: &mut R::DiscoveryClientType, ) -> Result<(UserPublicKey, Vec)> { @@ -152,7 +152,7 @@ impl BrokerAuth { /// - If we fail to authenticate /// - If we have a connection failure pub async fn authenticate_with_broker( - connection: &Connection, + connection: &Connection, keypair: &KeyPair>, ) -> Result { // Get the current timestamp, which we sign to avoid replay attacks @@ -231,7 +231,7 @@ impl BrokerAuth { /// # Errors /// - If verification has failed pub async fn verify_broker( - connection: &Connection, + connection: &Connection, our_identifier: &BrokerIdentifier, our_public_key: &PublicKey, ) -> Result<()> { diff --git a/cdn-proto/src/connection/auth/marshal.rs b/cdn-proto/src/connection/auth/marshal.rs index e09454e..fb77bb0 100644 --- a/cdn-proto/src/connection/auth/marshal.rs +++ b/cdn-proto/src/connection/auth/marshal.rs @@ -9,8 +9,8 @@ use tracing::error; use crate::{ bail, - connection::protocols::Connection as _, - def::{Connection, PublicKey, RunDef, Scheme}, + connection::protocols::Connection, + def::{PublicKey, RunDef, Scheme}, discovery::DiscoveryClient, error::{Error, Result}, fail_verification_with_message, @@ -35,7 +35,7 @@ impl MarshalAuth { /// - If authentication fails /// - If our connection fails pub async fn verify_user( - connection: &Connection, + connection: &Connection, discovery_client: &mut R::DiscoveryClientType, ) -> Result { // Receive the signed message from the user diff --git a/cdn-proto/src/connection/auth/user.rs b/cdn-proto/src/connection/auth/user.rs index 216e29a..a823440 100644 --- a/cdn-proto/src/connection/auth/user.rs +++ b/cdn-proto/src/connection/auth/user.rs @@ -6,12 +6,12 @@ use std::{ time::{SystemTime, UNIX_EPOCH}, }; -use crate::connection::protocols::Connection as _; +use crate::connection::protocols::Connection; use crate::crypto::signature::Serializable; use crate::{ bail, crypto::signature::{KeyPair, SignatureScheme}, - def::{Connection, ConnectionDef, Scheme}, + def::{ConnectionDef, Scheme}, error::{Error, Result}, message::{AuthenticateWithKey, AuthenticateWithPermit, Message, Topic}, }; @@ -29,7 +29,7 @@ impl UserAuth { /// - If we fail authentication /// - If our connection fails pub async fn authenticate_with_marshal( - connection: &Connection, + connection: &Connection, keypair: &KeyPair>, ) -> Result<(String, u64)> { // Get the current timestamp, which we sign to avoid replay attacks @@ -103,7 +103,7 @@ impl UserAuth { /// - If authentication fails /// - If our connection fails pub async fn authenticate_with_broker( - connection: &Connection, + connection: &Connection, permit: u64, subscribed_topics: HashSet, ) -> Result<()> { diff --git a/cdn-proto/src/connection/protocols/memory.rs b/cdn-proto/src/connection/protocols/memory.rs index bbbae46..2e451a7 100644 --- a/cdn-proto/src/connection/protocols/memory.rs +++ b/cdn-proto/src/connection/protocols/memory.rs @@ -1,30 +1,27 @@ //! The memory protocol is a completely in-memory channel-based protocol. //! It can only be used intra-process. -use std::{ - collections::HashMap, - sync::{Arc, OnceLock}, -}; +use std::{collections::HashMap, sync::OnceLock}; use async_trait::async_trait; use kanal::{unbounded_async, AsyncReceiver, AsyncSender}; use rustls::{Certificate, PrivateKey}; -use tokio::{sync::RwLock, task::spawn_blocking}; +use tokio::{ + io::{duplex, DuplexStream}, + sync::RwLock, + task::spawn_blocking, +}; use super::{Connection, Listener, Protocol, UnfinalizedConnection}; +use crate::connection::middleware::NoMiddleware; #[cfg(feature = "metrics")] -use crate::connection::metrics::{BYTES_RECV, BYTES_SENT}; use crate::{ bail, - connection::{middleware::Middleware, Bytes}, + connection::middleware::Middleware, error::{Error, Result}, - message::Message, }; -type SenderChannel = AsyncSender; -type ReceiverChannel = AsyncReceiver; - -type ChannelExchange = (AsyncSender, AsyncReceiver); +type ChannelExchange = (AsyncSender, AsyncReceiver); /// A global list of listeners that are initialized later. This is to help /// connections find listeners. @@ -36,8 +33,6 @@ pub struct Memory; #[async_trait] impl Protocol for Memory { - type Connection = MemoryConnection; - type UnfinalizedConnection = UnfinalizedMemoryConnection; type Listener = MemoryListener; @@ -45,10 +40,7 @@ impl Protocol for Memory { /// /// # Errors /// - If the listener is not listening - async fn connect( - remote_endpoint: &str, - _use_local_authority: bool, - ) -> Result { + async fn connect(remote_endpoint: &str, _use_local_authority: bool) -> Result { // If the peer is not listening, return an error // Get or initialize the channels as a static value let listeners = LISTENERS.get_or_init(RwLock::default).read().await; @@ -60,8 +52,8 @@ impl Protocol for Memory { )); }; - // Create a channel for sending messages and receiving them - let (send_to_us, receive_from_them) = unbounded_async(); + // Create a duplex stream to send and receive bytes + let (send_to_us, receive_from_them) = duplex(8192); // Send our channel to them bail!( @@ -77,11 +69,11 @@ impl Protocol for Memory { "failed to connect to remote endpoint" ); - // Return the conmunication channels - Ok(MemoryConnection { - sender: Arc::from(MemorySenderRef(send_to_them)), - receiver: Arc::from(MemoryReceiverRef(receive_from_them)), - }) + // Convert the streams into a `Connection` + let connection = Connection::from_streams::<_, _, M>(send_to_them, receive_from_them); + + // Return our connection + Ok(connection) } /// Binds to a local endpoint. The bind endpoint should be numeric. @@ -111,114 +103,22 @@ impl Protocol for Memory { } } -#[derive(Clone)] -pub struct MemoryConnection { - pub sender: Arc, - pub receiver: Arc, -} - -#[derive(Clone)] -pub struct MemorySenderRef(AsyncSender); - -#[async_trait] -impl Connection for MemoryConnection { - /// Send an (unserialized) message over the stream. - /// - /// # Errors - /// If we fail to send or serialize the message - async fn send_message(&self, message: Message) -> Result<()> { - // Serialize the message - let raw_message = Bytes::from_unchecked(bail!( - message.serialize(), - Serialize, - "failed to serialize message" - )); - - // Add to our metrics, if desired - #[cfg(feature = "metrics")] - BYTES_SENT.add(raw_message.len() as f64); - - // Send the now-raw message - self.send_message_raw(raw_message).await - } - - /// Send a pre-serialized message over the connection. - /// - /// # Errors - /// - If we fail to deliver the message. This usually means a connection problem. - async fn send_message_raw(&self, raw_message: Bytes) -> Result<()> { - // Send the message over the channel - bail!( - self.sender.0.send(raw_message).await, - Connection, - "failed to send message over connection" - ); - - Ok(()) - } - - /// Receives a single message from our channel and deserializes - /// it. - /// - /// # Errors - /// - If the other side of the channel is closed - /// - If we fail deserialization - async fn recv_message(&self) -> Result { - // Receive raw message - let raw_message = self.recv_message_raw().await?; - - // Deserialize and return the message - Ok(bail!( - Message::deserialize(&raw_message), - Deserialize, - "failed to deserialize message" - )) - } - - /// Receives a single message from our channel without - /// deserializing≥ - /// - /// # Errors - /// - If the other side of the channel is closed - /// - If we fail deserialization - async fn recv_message_raw(&self) -> Result { - // Receive a message from the channel - let raw_message = bail!( - self.receiver.0.recv().await, - Connection, - "failed to receive message from connection" - ); - - // Add to our metrics, if desired - #[cfg(feature = "metrics")] - BYTES_RECV.add(raw_message.len() as f64); - - Ok(raw_message) - } - - /// Finish the connection, sending any remaining data. - /// Is a no-op for memory connections. - async fn finish(&self) {} -} - -#[derive(Clone)] -pub struct MemoryReceiverRef(pub AsyncReceiver); - /// A connection that has yet to be finalized. Allows us to keep accepting /// connections while we process this one. pub struct UnfinalizedMemoryConnection { - bytes_sender: SenderChannel, - bytes_receiver: ReceiverChannel, + send_stream: DuplexStream, + receive_stream: DuplexStream, } #[async_trait] -impl UnfinalizedConnection for UnfinalizedMemoryConnection { +impl UnfinalizedConnection for UnfinalizedMemoryConnection { /// Prepares the `MemoryConnection` for usage by `Arc()ing` things. - async fn finalize(self) -> Result { - Ok(MemoryConnection { - sender: Arc::from(MemorySenderRef(self.bytes_sender)), - receiver: Arc::from(MemoryReceiverRef(self.bytes_receiver)), - }) + async fn finalize(self) -> Result { + // Convert the streams into a `Connection` + let connection = Connection::from_streams::<_, _, M>(self.send_stream, self.receive_stream); + + // Return our connection + Ok(connection) } } @@ -226,8 +126,8 @@ impl UnfinalizedConnection for UnfinalizedMemoryConnection { /// so we can remove on drop. pub struct MemoryListener { bind_endpoint: String, - receive_new_connection: AsyncReceiver, - send_to_new_connection: AsyncSender, + receive_new_connection: AsyncReceiver, + send_to_new_connection: AsyncSender, } #[async_trait] @@ -247,7 +147,7 @@ impl Listener for MemoryListener { ); // Create our bytes sender - let (send_bytes_to_us, bytes_receiver) = unbounded_async(); + let (send_bytes_to_us, bytes_receiver) = duplex(8192); // Send the remote connection our channel bail!( @@ -256,10 +156,10 @@ impl Listener for MemoryListener { "failed to finalize connection" ); - // Return this as unfinalized + // Return our unfinalized connection Ok(UnfinalizedMemoryConnection { - bytes_sender, - bytes_receiver, + send_stream: bytes_sender, + receive_stream: bytes_receiver, }) } } @@ -278,32 +178,16 @@ impl Drop for MemoryListener { } } -/// If we drop the sender, we want to close the channel -impl Drop for MemorySenderRef { - fn drop(&mut self) { - self.0.close(); - } -} - -/// If we drop the receiver, we want to close the channel -impl Drop for MemoryReceiverRef { - fn drop(&mut self) { - self.0.close(); - } -} - impl Memory { /// Generate a testing pair of channels for sending and receiving in memory. /// This is particularly useful for tests. #[must_use] - pub fn gen_testing_connection() -> MemoryConnection { - // Create channels - let (sender, receiver) = unbounded_async(); - - MemoryConnection { - sender: Arc::from(MemorySenderRef(sender)), - receiver: Arc::from(MemoryReceiverRef(receiver)), - } + pub fn gen_testing_connection() -> Connection { + // Create our channels + let (sender, receiver) = duplex(8192); + + // Convert the streams into a `Connection` + Connection::from_streams::<_, _, NoMiddleware>(sender, receiver) } } diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index b39c37d..5394656 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -1,11 +1,14 @@ //! This module defines connections, listeners, and their implementations. -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use async_trait::async_trait; +use kanal::{AsyncReceiver, AsyncSender}; use rustls::{Certificate, PrivateKey}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, + sync::oneshot, + task::AbortHandle, time::timeout, }; @@ -27,16 +30,14 @@ pub mod tcp; /// The `Protocol` trait lets us be generic over a connection type (Tcp, Quic, etc). #[async_trait] pub trait Protocol: Send + Sync + 'static { - type Connection: Connection + Send + Sync + Clone; - - type UnfinalizedConnection: UnfinalizedConnection + Send + Sync; + type UnfinalizedConnection: UnfinalizedConnection + Send + Sync; type Listener: Listener + Send + Sync; /// Connect to a remote endpoint, returning an instance of `Self`. /// /// # Errors /// Errors if we fail to connect or if we fail to bind to the interface we want. - async fn connect(remote_endpoint: &str, use_local_authority: bool) -> Result; + async fn connect(remote_endpoint: &str, use_local_authority: bool) -> Result; /// Bind to the local endpoint, returning an instance of `Listener`. /// @@ -50,58 +51,193 @@ pub trait Protocol: Send + Sync + 'static { } #[async_trait] -pub trait Connection { +pub trait Listener { + /// Accept an unfinalized connection from the local, bound socket. + /// Returns a connection or an error if we encountered one. + /// + /// # Errors + /// If we fail to accept a connection + async fn accept(&self) -> Result; +} + +#[async_trait] +pub trait UnfinalizedConnection { + /// Finalize an incoming connection. This is separated so we can prevent + /// actors who are slow from clogging up the incoming connection by offloading + /// it to a separate task. + async fn finalize(self) -> Result; +} + +/// A connection to a remote endpoint. +#[derive(Clone)] +pub struct Connection(Arc); + +/// A message to send over the channel, either a raw message or a flush message. +/// The flush message is used to ensure that all messages are sent before we close the connection. +enum BytesOrFlush { + Bytes(Bytes), + Flush(oneshot::Sender<()>), +} + +/// A reference to a delegated connection, containing the sender and +/// receiver channels. +#[derive(Clone)] +pub struct ConnectionRef { + sender: AsyncSender, + receiver: AsyncReceiver, + + tasks: Arc>, +} + +impl Drop for ConnectionRef { + fn drop(&mut self) { + // Cancel all tasks + for task in self.tasks.iter() { + task.abort(); + } + } +} + +impl Connection { + /// Converts a set of writer and reader streams into a connection. + /// Under the hood, this spawns sending and receiving tasks. + fn from_streams< + W: AsyncWriteExt + Unpin + Send + 'static, + R: AsyncReadExt + Unpin + Send + 'static, + M: Middleware, + >( + mut writer: W, + mut reader: R, + ) -> Self { + // Create the channels that will be used to send and receive messages + let (send_to_caller, receive_from_task) = kanal::unbounded_async(); + let (send_to_task, receive_from_caller) = kanal::unbounded_async(); + + // Spawn the task that receives from the caller and sends to the stream + let sender_task = tokio::spawn(async move { + // While we can successfully receive messages from the caller, + while let Ok(message) = receive_from_caller.recv().await { + match message { + BytesOrFlush::Bytes(message) => { + // Write the message to the stream + if let Err(_) = write_length_delimited(&mut writer, message).await { + receive_from_caller.close(); + return; + }; + } + BytesOrFlush::Flush(result_sender) => { + // Acknowledge that we've finished successfully + let _ = result_sender.send(()); + } + } + } + }) + .abort_handle(); + + // Spawn the task that receives from the stream and sends to the caller + let receiver_task = tokio::spawn(async move { + // While we can successfully read messages from the stream, + while let Ok(message) = read_length_delimited::(&mut reader).await { + if let Err(_) = send_to_caller.send(message).await { + send_to_caller.close(); + return; + }; + } + }) + .abort_handle(); + + // Return the connection + Self(Arc::new(ConnectionRef { + sender: send_to_task, + receiver: receive_from_task, + tasks: Arc::from(vec![sender_task, receiver_task]), + })) + } + /// Send an (unserialized) message over the stream. /// /// # Errors /// If we fail to send or serialize the message - async fn send_message(&self, message: Message) -> Result<()>; + pub async fn send_message(&self, message: Message) -> Result<()> { + // Serialize our message + let raw_message = Bytes::from_unchecked(bail!( + message.serialize(), + Serialize, + "failed to serialize message" + )); + + // Send the message in its raw form + self.send_message_raw(raw_message).await + } /// Send a pre-serialized message over the connection. /// /// # Errors /// - If we fail to deliver the message. This usually means a connection problem. - async fn send_message_raw(&self, raw_message: Bytes) -> Result<()>; + pub async fn send_message_raw(&self, raw_message: Bytes) -> Result<()> { + // Send the message + bail!( + self.0.sender.send(BytesOrFlush::Bytes(raw_message)).await, + Connection, + "failed to send message" + ); + + Ok(()) + } - /// Receives message over the stream and deserializes it. + /// Receives a message from the stream and deserializes it. /// /// # Errors /// - if we fail to receive the message /// - if we fail deserialization - async fn recv_message(&self) -> Result; + pub async fn recv_message(&self) -> Result { + // Receive the raw message + let raw_message = self.recv_message_raw().await?; + + // Deserialize and return the message + Ok(bail!( + Message::deserialize(&raw_message), + Deserialize, + "failed to deserialize message" + )) + } /// Receives a message over the stream without deserializing. /// /// # Errors /// - if we fail to receive the message - async fn recv_message_raw(&self) -> Result; + pub async fn recv_message_raw(&self) -> Result { + // Receive and return the message + Ok(bail!( + self.0.receiver.recv().await, + Connection, + "failed to send message" + )) + } - /// Gracefully finish the connection, sending any remaining data. - async fn finish(&self); -} + pub async fn flush(&self) -> Result<()> { + // Create notifier to wait for the flush message to be acknowledged + let (flush_sender, flush_receiver) = oneshot::channel(); -#[async_trait] -pub trait Listener { - /// Accept an unfinalized connection from the local, bound socket. - /// Returns a connection or an error if we encountered one. - /// - /// # Errors - /// If we fail to accept a connection - async fn accept(&self) -> Result; -} - -#[async_trait] -pub trait UnfinalizedConnection { - /// Finalize an incoming connection. This is separated so we can prevent - /// actors who are slow from clogging up the incoming connection by offloading - /// it to a separate task. - async fn finalize(self) -> Result; + // Send the flush message + bail!( + self.0.sender.send(BytesOrFlush::Flush(flush_sender)).await, + Connection, + "failed to flush connection" + ); + + // Wait to receive the result + match flush_receiver.await { + Ok(()) => Ok(()), + _ => Err(Error::Connection("failed to flush connection".to_string())), + } + } } /// Read a length-delimited (serialized) message from a stream. /// Has a bounds check for if the message is too big async fn read_length_delimited( - mut stream: R, + stream: &mut R, ) -> Result { // Read the message size from the stream let message_size = bail!( @@ -132,9 +268,6 @@ async fn read_length_delimited( "failed to read message" ); - // Drop the stream since we're done with it - drop(stream); - // Add to our metrics, if desired #[cfg(feature = "metrics")] metrics::BYTES_RECV.add(message_size as f64); @@ -144,7 +277,7 @@ async fn read_length_delimited( /// Write a length-delimited (serialized) message to a stream. async fn write_length_delimited( - mut stream: W, + stream: &mut W, message: Bytes, ) -> Result<()> { // Get the length of the message @@ -176,9 +309,6 @@ async fn write_length_delimited( "failed to send message" ); - // Drop the stream since we're done with it - drop(stream); - // Increment the number of bytes we've sent by this amount #[cfg(feature = "metrics")] metrics::BYTES_SENT.add(message_len as f64); @@ -191,7 +321,7 @@ pub mod tests { use anyhow::Result; use tokio::{join, spawn, task::JoinHandle}; - use super::{Connection, Listener, Protocol, UnfinalizedConnection}; + use super::{Listener, Protocol, UnfinalizedConnection}; use crate::{ connection::middleware::NoMiddleware, crypto::tls::{generate_cert_from_ca, LOCAL_CA_CERT, LOCAL_CA_KEY}, @@ -255,7 +385,8 @@ pub mod tests { // Send our message connection.send_message(new_connection_to_listener).await?; - connection.finish().await; + // Flush the connection, ensuring the message is sent + connection.flush().await?; Ok(()) }); diff --git a/cdn-proto/src/connection/protocols/quic.rs b/cdn-proto/src/connection/protocols/quic.rs index fa4bf45..84c3537 100644 --- a/cdn-proto/src/connection/protocols/quic.rs +++ b/cdn-proto/src/connection/protocols/quic.rs @@ -2,7 +2,6 @@ //! connection that implements our message framing and connection //! logic. -use std::marker::PhantomData; use std::time::Duration; use std::{ net::{SocketAddr, ToSocketAddrs}, @@ -13,21 +12,15 @@ use async_trait::async_trait; use quinn::{ClientConfig, Connecting, Endpoint, ServerConfig, TransportConfig, VarInt}; use rustls::{Certificate, PrivateKey, RootCertStore}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::Mutex; use tokio::time::timeout; -use super::{ - read_length_delimited, write_length_delimited, Connection, Listener, Protocol, - UnfinalizedConnection, -}; +use super::{Connection, Listener, Protocol, UnfinalizedConnection}; use crate::connection::middleware::Middleware; -use crate::connection::Bytes; use crate::crypto::tls::{LOCAL_CA_CERT, PROD_CA_CERT}; use crate::parse_endpoint; use crate::{ bail, bail_option, error::{Error, Result}, - message::Message, }; /// The `Quic` protocol. We use this to define commonalities between QUIC @@ -37,15 +30,10 @@ pub struct Quic; #[async_trait] impl Protocol for Quic { - type Connection = QuicConnection; - type UnfinalizedConnection = UnfinalizedQuicConnection; type Listener = QuicListener; - async fn connect( - remote_endpoint: &str, - use_local_authority: bool, - ) -> Result> { + async fn connect(remote_endpoint: &str, use_local_authority: bool) -> Result { // Parse the endpoint let remote_endpoint = bail_option!( bail!( @@ -137,12 +125,8 @@ impl Protocol for Quic { "failed to bootstrap connection" ); - // Create connection - let connection = QuicConnection { - sender: Arc::from(Mutex::from(sender)), - receiver: Arc::from(Mutex::from(receiver)), - pd: PhantomData, - }; + // Convert the streams into a `Connection` + let connection = Connection::from_streams::<_, _, M>(sender, receiver); Ok(connection) } @@ -185,85 +169,17 @@ impl Protocol for Quic { } } -#[derive(Clone)] -pub struct QuicConnection { - sender: Arc>, - receiver: Arc>, - pd: PhantomData, -} - -#[async_trait] -impl Connection for QuicConnection { - /// Send an unserialized message over the stream. - /// - /// # Errors - /// If we fail to send or serialize the message - async fn send_message(&self, message: Message) -> Result<()> { - // Serialize our message - let raw_message = Bytes::from_unchecked(bail!( - message.serialize(), - Serialize, - "failed to serialize message" - )); - - // Send the message in its raw form - self.send_message_raw(raw_message).await - } - - /// Send a pre-serialized message over the connection. - /// - /// # Errors - /// - If we fail to deliver the message. This usually means a connection problem. - async fn send_message_raw(&self, raw_message: Bytes) -> Result<()> { - // Write the message length-delimited - write_length_delimited(&mut *self.sender.lock().await, raw_message).await - } - - /// Receives a single message over the stream and deserializes - /// it. - /// - /// # Errors - /// - if we fail to receive the message - /// - if we fail to deserialize the message - async fn recv_message(&self) -> Result { - // Receive the raw message - let raw_message = self.recv_message_raw().await?; - - // Deserialize and return the message - Ok(bail!( - Message::deserialize(&raw_message), - Deserialize, - "failed to deserialize message" - )) - } - - /// Receives a single message over the stream and deserializes - /// it. - /// - /// # Errors - /// - if we fail to receive the message - async fn recv_message_raw(&self) -> Result { - // Receive the length-delimited message - read_length_delimited::<_, M>(&mut *self.receiver.lock().await).await - } - - /// Gracefully finish the connection, sending any remaining data. - async fn finish(&self) { - let _ = self.sender.lock().await.finish().await; - } -} - /// A connection that has yet to be finalized. Allows us to keep accepting /// connections while we process this one. pub struct UnfinalizedQuicConnection(Connecting); #[async_trait] -impl UnfinalizedConnection> for UnfinalizedQuicConnection { +impl UnfinalizedConnection for UnfinalizedQuicConnection { /// Finalize the connection by awaiting on `Connecting` and cloning the connection. /// /// # Errors /// If we to finalize our connection. - async fn finalize(self) -> Result> { + async fn finalize(self) -> Result { // Await on the `Connecting` to obtain `Connection` let connection = bail!(self.0.await, Connection, "failed to finalize connection"); @@ -281,12 +197,8 @@ impl UnfinalizedConnection> for UnfinalizedQuic "failed to bootstrap connection" ); - // Create connection - let connection = QuicConnection { - sender: Arc::from(Mutex::from(sender)), - receiver: Arc::from(Mutex::from(receiver)), - pd: PhantomData, - }; + // Create a sender and receiver + let connection = Connection::from_streams::<_, _, M>(sender, receiver); // Clone and return the connection Ok(connection) diff --git a/cdn-proto/src/connection/protocols/tcp.rs b/cdn-proto/src/connection/protocols/tcp.rs index 03d24b0..85579fa 100644 --- a/cdn-proto/src/connection/protocols/tcp.rs +++ b/cdn-proto/src/connection/protocols/tcp.rs @@ -2,31 +2,20 @@ //! connection that implements our message framing and connection //! logic. -use std::marker::PhantomData; use std::net::SocketAddr; +use std::net::ToSocketAddrs; use std::time::Duration; -use std::{net::ToSocketAddrs, sync::Arc}; use async_trait::async_trait; use rustls::{Certificate, PrivateKey}; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::sync::Mutex; +use tokio::net::{TcpSocket, TcpStream}; use tokio::time::timeout; -use tokio::{ - io::AsyncWriteExt, - net::{TcpSocket, TcpStream}, -}; -use super::{ - read_length_delimited, write_length_delimited, Connection, Listener, Protocol, - UnfinalizedConnection, -}; +use super::{Connection, Listener, Protocol, UnfinalizedConnection}; use crate::connection::middleware::Middleware; use crate::{ bail, bail_option, - connection::Bytes, error::{Error, Result}, - message::Message, parse_endpoint, }; @@ -37,8 +26,6 @@ pub struct Tcp; #[async_trait] impl Protocol for Tcp { - type Connection = TcpConnection; - type Listener = TcpListener; type UnfinalizedConnection = UnfinalizedTcpConnection; @@ -47,7 +34,7 @@ impl Protocol for Tcp { /// /// # Errors /// Errors if we fail to connect or if we fail to bind to the interface we want. - async fn connect(remote_endpoint: &str, _use_local_authority: bool) -> Result + async fn connect(remote_endpoint: &str, _use_local_authority: bool) -> Result where Self: Sized, { @@ -83,11 +70,11 @@ impl Protocol for Tcp { // Split the connection and create our wrapper let (receiver, sender) = stream.into_split(); - Ok(TcpConnection { - sender: Arc::from(Mutex::from(sender)), - receiver: Arc::from(Mutex::from(receiver)), - pd: PhantomData, - }) + + // Convert the streams into a `Connection` + let connection = Connection::from_streams::<_, _, M>(sender, receiver); + + Ok(connection) } /// Binds to a local endpoint. Does not use a TLS configuration. @@ -112,95 +99,25 @@ impl Protocol for Tcp { } } -#[derive(Clone)] -pub struct TcpConnection { - sender: Arc>, - receiver: Arc>, - pd: PhantomData, -} - -#[async_trait] -impl Connection for TcpConnection { - /// Send an unserialized message over the stream. - /// - /// # Errors - /// If we fail to send or serialize the message - async fn send_message(&self, message: Message) -> Result<()> { - // Serialize our message - let raw_message = Bytes::from_unchecked(bail!( - message.serialize(), - Serialize, - "failed to serialize message" - )); - - // Send the message in its raw form - self.send_message_raw(raw_message).await - } - - /// Send a pre-serialized message over the connection. - /// - /// # Errors - /// - If we fail to deliver the message. This usually means a connection problem. - async fn send_message_raw(&self, raw_message: Bytes) -> Result<()> { - // Write the message length-delimited - write_length_delimited(&mut *self.sender.lock().await, raw_message).await - } - - /// Gracefully finish the connection, sending any remaining data. - /// This is done by sending two empty messages. - async fn finish(&self) { - let _ = self.sender.lock().await.flush().await; - } - - /// Receives a single message over the stream and deserializes - /// it. - /// - /// # Errors - /// - if we fail to receive the message - /// - if we fail to deserialize the message - async fn recv_message(&self) -> Result { - // Receive the raw message - let raw_message = self.recv_message_raw().await?; - - // Deserialize and return the message - Ok(bail!( - Message::deserialize(&raw_message), - Deserialize, - "failed to deserialize message" - )) - } - - /// Receives a single message over the stream and deserializes - /// it. - /// - /// # Errors - /// - if we fail to receive the message - async fn recv_message_raw(&self) -> Result { - // Receive the length-delimited message - read_length_delimited::<_, M>(&mut *self.receiver.lock().await).await - } -} - /// A connection that has yet to be finalized. Allows us to keep accepting /// connections while we process this one. pub struct UnfinalizedTcpConnection(TcpStream); #[async_trait] -impl UnfinalizedConnection> for UnfinalizedTcpConnection { +impl UnfinalizedConnection for UnfinalizedTcpConnection { /// Finalize the connection by splitting it into a sender and receiver side. /// Conssumes `Self`. /// /// # Errors /// Does not actually error, but satisfies trait bounds. - async fn finalize(self) -> Result> { + async fn finalize(self) -> Result { // Split the connection and create our wrapper let (receiver, sender) = self.0.into_split(); - Ok(TcpConnection { - sender: Arc::from(Mutex::from(sender)), - receiver: Arc::from(Mutex::from(receiver)), - pd: PhantomData, - }) + // Convert the streams into a `Connection` + let connection = Connection::from_streams::<_, _, M>(sender, receiver); + + Ok(connection) } } diff --git a/cdn-proto/src/def.rs b/cdn-proto/src/def.rs index 012cc0c..ec16ac9 100644 --- a/cdn-proto/src/def.rs +++ b/cdn-proto/src/def.rs @@ -125,4 +125,3 @@ pub type PublicKey = as SignatureScheme>::PublicKey; pub type Protocol = ::Protocol; pub type Middleware = ::Middleware; pub type Listener = as ProtocolType>>::Listener; -pub type Connection = as ProtocolType>>::Connection; diff --git a/tests/src/tests/double_connect.rs b/tests/src/tests/double_connect.rs index 7221f3a..7210d56 100644 --- a/tests/src/tests/double_connect.rs +++ b/tests/src/tests/double_connect.rs @@ -34,10 +34,13 @@ async fn test_double_connect_same_broker() { sleep(Duration::from_millis(50)).await; // Attempt to send a message, should fail - assert!(client1 - .send_direct_message(&keypair_from_seed(1).1, b"hello direct".to_vec()) - .await - .is_err()); + assert!( + client1 + .send_direct_message(&keypair_from_seed(1).1, b"hello direct".to_vec()) + .await + .is_err() + || client1.flush().await.is_err() + ); // The second client to connect should have succeeded client2 @@ -121,7 +124,8 @@ async fn test_double_connect_different_broker() { client1 .send_direct_message(&keypair_from_seed(1).1, b"hello direct".to_vec()) .await - .is_err(), + .is_err() + || client1.flush().await.is_err(), "second client connected when it shouldn't have" ); } diff --git a/tests/src/tests/subscribe.rs b/tests/src/tests/subscribe.rs index 1ca6566..5a91a80 100644 --- a/tests/src/tests/subscribe.rs +++ b/tests/src/tests/subscribe.rs @@ -139,7 +139,8 @@ async fn test_invalid_subscribe() { client .send_broadcast_message(vec![1], b"hello invalid".to_vec()) .await - .is_err(), + .is_err() + || client.flush().await.is_err(), "sent message but should've been disconnected" ); From 7358e4bd1b2883851d3a1fe8f572e9ec3b31c849 Mon Sep 17 00:00:00 2001 From: Rob Date: Tue, 4 Jun 2024 17:08:23 -0400 Subject: [PATCH 02/31] clippy and fmt --- cdn-broker/src/tasks/user/handler.rs | 6 ++++-- cdn-client/src/lib.rs | 2 +- cdn-client/src/retry.rs | 2 +- cdn-proto/benches/protocols.rs | 10 +++------- cdn-proto/src/connection/protocols/mod.rs | 4 ++-- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/cdn-broker/src/tasks/user/handler.rs b/cdn-broker/src/tasks/user/handler.rs index 13775e4..78ab026 100644 --- a/cdn-broker/src/tasks/user/handler.rs +++ b/cdn-broker/src/tasks/user/handler.rs @@ -102,7 +102,8 @@ impl Inner { Message::Direct(ref direct) => { let user_public_key = UserPublicKey::from(direct.recipient.clone()); - self.handle_direct_message(&user_public_key, raw_message, false).await; + self.handle_direct_message(&user_public_key, raw_message, false) + .await; } // If we get a broadcast message from a user, send it to both brokers and users. @@ -111,7 +112,8 @@ impl Inner { let mut topics = broadcast.topics.clone(); Def::Topic::prune(&mut topics)?; - self.handle_broadcast_message(&topics, &raw_message, false).await; + self.handle_broadcast_message(&topics, &raw_message, false) + .await; } // Subscribe messages from users will just update the state locally diff --git a/cdn-client/src/lib.rs b/cdn-client/src/lib.rs index 144af47..4561376 100644 --- a/cdn-client/src/lib.rs +++ b/cdn-client/src/lib.rs @@ -164,7 +164,7 @@ impl Client { /// Flushes the connection, ensuring that all messages are sent. /// This is useful for ensuring that messages are sent before a /// connection is closed. - /// + /// /// # Errors /// - if the connection is already closed pub async fn flush(&self) -> Result<()> { diff --git a/cdn-client/src/retry.rs b/cdn-client/src/retry.rs index d4ecb15..8f36667 100644 --- a/cdn-client/src/retry.rs +++ b/cdn-client/src/retry.rs @@ -265,7 +265,7 @@ impl Retry { } /// Flushes the connection, ensuring that all messages are sent. - /// + /// /// # Errors /// - If we are in the middle of reconnecting /// - If the connection is closed diff --git a/cdn-proto/benches/protocols.rs b/cdn-proto/benches/protocols.rs index a1e86d8..2b98830 100644 --- a/cdn-proto/benches/protocols.rs +++ b/cdn-proto/benches/protocols.rs @@ -15,11 +15,7 @@ use tokio::{join, runtime::Runtime, spawn}; /// Transfer a message `raw_message` from `conn1` to `conn2.` This is the primary /// function used for testing network protocol speed. -async fn transfer>( - conn1: Connection, - conn2: Connection, - raw_message: Bytes, -) { +async fn transfer(conn1: Connection, conn2: Connection, raw_message: Bytes) { // Send from the first connection let conn1_jh = spawn(async move { conn1 @@ -113,7 +109,7 @@ fn bench_quic(c: &mut Criterion) { group.throughput(Throughput::Bytes(*size as u64)); group.bench_function(BenchmarkId::from_parameter(size), |b| { b.to_async(&runtime).iter(|| { - transfer::( + transfer( black_box(conn1.clone()), black_box(conn2.clone()), black_box(message.clone()), @@ -139,7 +135,7 @@ fn bench_tcp(c: &mut Criterion) { group.throughput(Throughput::Bytes(*size as u64)); group.bench_function(BenchmarkId::from_parameter(size), |b| { b.to_async(&runtime).iter(|| { - transfer::( + transfer( black_box(conn1.clone()), black_box(conn2.clone()), black_box(message.clone()), diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index 5394656..745613c 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -120,7 +120,7 @@ impl Connection { match message { BytesOrFlush::Bytes(message) => { // Write the message to the stream - if let Err(_) = write_length_delimited(&mut writer, message).await { + if write_length_delimited(&mut writer, message).await.is_err() { receive_from_caller.close(); return; }; @@ -138,7 +138,7 @@ impl Connection { let receiver_task = tokio::spawn(async move { // While we can successfully read messages from the stream, while let Ok(message) = read_length_delimited::(&mut reader).await { - if let Err(_) = send_to_caller.send(message).await { + if send_to_caller.send(message).await.is_err() { send_to_caller.close(); return; }; From 73b9bfbf3928eaccb4faa9a70dc2446b81afbce5 Mon Sep 17 00:00:00 2001 From: Rob Date: Wed, 5 Jun 2024 10:05:39 -0400 Subject: [PATCH 03/31] close channel on drop --- cdn-proto/src/connection/protocols/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index 745613c..72219e0 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -94,6 +94,8 @@ impl Drop for ConnectionRef { // Cancel all tasks for task in self.tasks.iter() { task.abort(); + self.sender.close(); + self.receiver.close(); } } } From 46d5dd3088db7ba5da2ee3b270841ca2bc07e16d Mon Sep 17 00:00:00 2001 From: Rob Date: Wed, 5 Jun 2024 13:19:41 -0400 Subject: [PATCH 04/31] move channel closes --- cdn-proto/src/connection/protocols/mod.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index 72219e0..4e1440e 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -91,11 +91,13 @@ pub struct ConnectionRef { impl Drop for ConnectionRef { fn drop(&mut self) { - // Cancel all tasks + // Close the channels + self.sender.close(); + self.receiver.close(); + + // Abort all tasks for task in self.tasks.iter() { task.abort(); - self.sender.close(); - self.receiver.close(); } } } @@ -128,7 +130,7 @@ impl Connection { }; } BytesOrFlush::Flush(result_sender) => { - // Acknowledge that we've finished successfully + // Acknowledge that we've processed up to this point let _ = result_sender.send(()); } } @@ -213,7 +215,7 @@ impl Connection { Ok(bail!( self.0.receiver.recv().await, Connection, - "failed to send message" + "failed to receive message" )) } From db997e692d58b6eebfe8cc936e3812d8d4466ade Mon Sep 17 00:00:00 2001 From: Rob Date: Wed, 5 Jun 2024 15:10:30 -0400 Subject: [PATCH 05/31] merge --- cdn-proto/src/connection/protocols/memory.rs | 2 +- cdn-proto/src/connection/protocols/mod.rs | 2 +- cdn-proto/src/connection/protocols/tcp.rs | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cdn-proto/src/connection/protocols/memory.rs b/cdn-proto/src/connection/protocols/memory.rs index f0cae14..e6a3294 100644 --- a/cdn-proto/src/connection/protocols/memory.rs +++ b/cdn-proto/src/connection/protocols/memory.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::OnceLock}; use async_trait::async_trait; use kanal::{unbounded_async, AsyncReceiver, AsyncSender}; -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio::{ io::{duplex, DuplexStream}, sync::RwLock, diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index afee591..a3c8a05 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -4,7 +4,7 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; use kanal::{AsyncReceiver, AsyncSender}; -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, sync::oneshot, diff --git a/cdn-proto/src/connection/protocols/tcp.rs b/cdn-proto/src/connection/protocols/tcp.rs index b5e2995..49404bf 100644 --- a/cdn-proto/src/connection/protocols/tcp.rs +++ b/cdn-proto/src/connection/protocols/tcp.rs @@ -7,7 +7,8 @@ use std::net::ToSocketAddrs; use std::time::Duration; use async_trait::async_trait; -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::CertificateDer; +use rustls::pki_types::PrivateKeyDer; use tokio::net::{TcpSocket, TcpStream}; use tokio::time::timeout; From 4ba4714e4816a629ae931c9b5f06ee0e1b004429 Mon Sep 17 00:00:00 2001 From: Rob Date: Wed, 5 Jun 2024 17:22:55 -0400 Subject: [PATCH 06/31] move out bidirectional open --- cdn-proto/src/connection/protocols/quic.rs | 81 ++++++++++++++++------ 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/cdn-proto/src/connection/protocols/quic.rs b/cdn-proto/src/connection/protocols/quic.rs index e547c7f..b7e3d53 100644 --- a/cdn-proto/src/connection/protocols/quic.rs +++ b/cdn-proto/src/connection/protocols/quic.rs @@ -112,22 +112,14 @@ impl Protocol for Quic { ); // Open an outgoing bidirectional stream - let (mut sender, receiver) = bail!( + let (sender, receiver) = bail!( bail!( - timeout(Duration::from_secs(5), connection.open_bi()).await, + timeout(Duration::from_secs(5), open_bi(&connection)).await, Connection, - "timed out accepting stream" + "timed out opening bidirectional stream" ), Connection, - "failed to accept bidirectional stream" - ); - - // Write a `u8` to bootstrap the connection (make the sender aware of our - // outbound stream request) - bail!( - sender.write_u8(0).await, - Connection, - "failed to bootstrap connection" + "failed to open bidirectional stream" ); // Convert the streams into a `Connection` @@ -189,19 +181,16 @@ impl UnfinalizedConnection for UnfinalizedQuicConnection { let connection = bail!(self.0.await, Connection, "failed to finalize connection"); // Accept an incoming bidirectional stream - let (sender, mut receiver) = bail!( - connection.accept_bi().await, + let (sender, receiver) = bail!( + bail!( + timeout(Duration::from_secs(5), accept_bi(&connection)).await, + Connection, + "timed out accepting bidirectional stream" + ), Connection, "failed to accept bidirectional stream" ); - // Read the `u8` required to bootstrap the connection - bail!( - receiver.read_u8().await, - Connection, - "failed to bootstrap connection" - ); - // Create a sender and receiver let connection = Connection::from_streams::<_, _, M>(sender, receiver); @@ -233,6 +222,56 @@ impl Listener for QuicListener { } } +/// A helper function for opening a new connection and atomically +/// writing to it to bootstrap it. +/// +/// # Errors +/// - If we fail to open a bidirectional stream +/// - If we fail to write to the stream +async fn open_bi(connection: &quinn::Connection) -> Result<(quinn::SendStream, quinn::RecvStream)> { + // Open a bidirectional stream + let (mut sender, receiver) = bail!( + connection.open_bi().await, + Connection, + "failed to open unidirectional stream" + ); + + // Write a `u8` to bootstrap the connection + bail!( + sender.write_u8(0).await, + Connection, + "failed to write `u8` to unidirectional stream" + ); + + Ok((sender, receiver)) +} + +/// A helper function for accepting a new connection and atomically +/// reading from it to bootstrap it. +/// +/// # Errors +/// - If we fail to accept a bidirectional stream +/// - If we fail to read from the stream +async fn accept_bi( + connection: &quinn::Connection, +) -> Result<(quinn::SendStream, quinn::RecvStream)> { + // Accept an incoming bidirectional stream + let (sender, mut receiver) = bail!( + connection.accept_bi().await, + Connection, + "failed to accept bidirectional stream" + ); + + // Read the `u8` required to bootstrap the connection + bail!( + receiver.read_u8().await, + Connection, + "failed to read `u8` from bidirectional stream" + ); + + Ok((sender, receiver)) +} + #[cfg(test)] mod tests { use anyhow::{anyhow, Result}; From c07b3cd409e0e75850bb615d73d224646993e186 Mon Sep 17 00:00:00 2001 From: Rob Date: Wed, 5 Jun 2024 17:32:12 -0400 Subject: [PATCH 07/31] single cancel for handlers --- cdn-broker/src/connections/mod.rs | 33 +++++++++++-------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/cdn-broker/src/connections/mod.rs b/cdn-broker/src/connections/mod.rs index f7e1ff2..006ae70 100644 --- a/cdn-broker/src/connections/mod.rs +++ b/cdn-broker/src/connections/mod.rs @@ -20,16 +20,14 @@ use self::broadcast::BroadcastMap; mod broadcast; mod direct; -type TaskMap = HashMap; - pub struct Connections { // Our identity. Used for versioned vector conflict resolution. identity: BrokerIdentifier, // The current users connected to us, along with their running tasks - users: HashMap, + users: HashMap, // The current brokers connected to us, along with their running tasks - brokers: HashMap, + brokers: HashMap, // The versioned vector for looking up where direct messages should go direct_map: DirectMap, @@ -183,10 +181,7 @@ impl Connections { // Remove the old broker if it exists self.remove_broker(&broker_identifier, "already existed"); - self.brokers.insert( - broker_identifier, - (connection, HashMap::from([(0, handle)])), - ); + self.brokers.insert(broker_identifier, (connection, handle)); } /// Insert a user into our map. Updates the versioned vector that @@ -206,10 +201,8 @@ impl Connections { self.remove_user(user_public_key.clone(), "already existed"); // Add to our map. Remove the old one if it exists - self.users.insert( - user_public_key.clone(), - (connection, HashMap::from([(0, handle)])), - ); + self.users + .insert(user_public_key.clone(), (connection, handle)); // Insert into our direct map self.direct_map @@ -225,15 +218,13 @@ impl Connections { /// from our broadcast map, in case they were subscribed to any topics. pub fn remove_broker(&mut self, broker_identifier: &BrokerIdentifier, reason: &str) { // Remove from broker list, cancelling the previous task if it exists - if let Some(task_handles) = self.brokers.remove(broker_identifier).map(|(_, h)| h) { + if let Some((_, task)) = self.brokers.remove(broker_identifier) { // Decrement the metric for the number of brokers connected metrics::NUM_BROKERS_CONNECTED.dec(); error!(id = %broker_identifier, reason = reason, "broker disconnected"); - // Cancel all tasks - for (_, handle) in task_handles { - handle.abort(); - } + // Cancel the task + task.abort(); }; // Remove from all topics @@ -249,7 +240,7 @@ impl Connections { /// to send us messages for a disconnected user. pub fn remove_user(&mut self, user_public_key: UserPublicKey, reason: &str) { // Remove from user list, returning the previous handle if it exists - if let Some(task_handles) = self.users.remove(&user_public_key).map(|(_, h)| h) { + if let Some((_, task)) = self.users.remove(&user_public_key) { // Decrement the metric for the number of users connected metrics::NUM_USERS_CONNECTED.dec(); warn!( @@ -258,10 +249,8 @@ impl Connections { "user disconnected" ); - // Cancel all tasks - for (_, handle) in task_handles { - handle.abort(); - } + // Cancel the task + task.abort(); }; // Remove from user topics From 93a6bb31a593297aab577816ac03ed3fc3d5aca9 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 6 Jun 2024 10:13:19 -0400 Subject: [PATCH 08/31] soft close --- cdn-client/src/lib.rs | 6 +-- cdn-client/src/retry.rs | 6 +-- cdn-marshal/src/handlers.rs | 4 +- cdn-proto/src/connection/protocols/memory.rs | 6 ++- cdn-proto/src/connection/protocols/mod.rs | 54 ++++++++++++++------ cdn-proto/src/connection/protocols/quic.rs | 15 +++++- cdn-proto/src/connection/protocols/tcp.rs | 6 +++ tests/src/tests/double_connect.rs | 4 +- tests/src/tests/subscribe.rs | 2 +- 9 files changed, 73 insertions(+), 30 deletions(-) diff --git a/cdn-client/src/lib.rs b/cdn-client/src/lib.rs index 4561376..1c7374d 100644 --- a/cdn-client/src/lib.rs +++ b/cdn-client/src/lib.rs @@ -161,13 +161,13 @@ impl Client { self.0.send_message(message).await } - /// Flushes the connection, ensuring that all messages are sent. + /// Soft close the connection, ensuring that all messages are sent. /// This is useful for ensuring that messages are sent before a /// connection is closed. /// /// # Errors /// - if the connection is already closed - pub async fn flush(&self) -> Result<()> { - self.0.flush().await + pub async fn soft_close(&self) -> Result<()> { + self.0.soft_close().await } } diff --git a/cdn-client/src/retry.rs b/cdn-client/src/retry.rs index 8f36667..32a7b15 100644 --- a/cdn-client/src/retry.rs +++ b/cdn-client/src/retry.rs @@ -264,12 +264,12 @@ impl Retry { try_with_reconnect!(self, out) } - /// Flushes the connection, ensuring that all messages are sent. + /// Soft close the connection, ensuring that all messages are sent. /// /// # Errors /// - If we are in the middle of reconnecting /// - If the connection is closed - pub async fn flush(&self) -> Result<()> { + pub async fn soft_close(&self) -> Result<()> { // Check if we're (probably) reconnecting or not if let Ok(connection_guard) = self.inner.connection.try_read() { // We're not reconnecting, try to send the message @@ -277,7 +277,7 @@ impl Retry { connection_guard .get_or_try_init(|| self.inner.connect()) .await? - .flush() + .soft_close() .await } else { // We are reconnecting, return an error diff --git a/cdn-marshal/src/handlers.rs b/cdn-marshal/src/handlers.rs index c6b4a3a..6408cea 100644 --- a/cdn-marshal/src/handlers.rs +++ b/cdn-marshal/src/handlers.rs @@ -26,7 +26,7 @@ impl Marshal { info!(id = mnemonic(&user_public_key), "user authenticated"); } - // Flush the connection, ensuring all data was sent - let _ = connection.flush().await; + // Soft close the connection, ensuring all data was sent + let _ = connection.soft_close().await; } } diff --git a/cdn-proto/src/connection/protocols/memory.rs b/cdn-proto/src/connection/protocols/memory.rs index e6a3294..bb5125a 100644 --- a/cdn-proto/src/connection/protocols/memory.rs +++ b/cdn-proto/src/connection/protocols/memory.rs @@ -12,7 +12,7 @@ use tokio::{ task::spawn_blocking, }; -use super::{Connection, Listener, Protocol, UnfinalizedConnection}; +use super::{Connection, Listener, Protocol, SoftClose, UnfinalizedConnection}; use crate::connection::middleware::NoMiddleware; #[cfg(feature = "metrics")] use crate::{ @@ -191,6 +191,10 @@ impl Memory { } } +/// Soft closing is a no-op for memory connections. +#[async_trait] +impl SoftClose for DuplexStream {} + #[cfg(test)] mod tests { use anyhow::Result; diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index a3c8a05..4a0d74e 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -72,18 +72,19 @@ pub trait UnfinalizedConnection { #[derive(Clone)] pub struct Connection(Arc); -/// A message to send over the channel, either a raw message or a flush message. -/// The flush message is used to ensure that all messages are sent before we close the connection. -enum BytesOrFlush { +/// A message to send over the channel, either a raw message or a soft close. +/// Soft close is used to indicate that the connection should be closed after +/// all messages have been sent. +enum BytesOrSoftClose { Bytes(Bytes), - Flush(oneshot::Sender<()>), + SoftClose(oneshot::Sender<()>), } /// A reference to a delegated connection, containing the sender and /// receiver channels. #[derive(Clone)] pub struct ConnectionRef { - sender: AsyncSender, + sender: AsyncSender, receiver: AsyncReceiver, tasks: Arc>, @@ -102,11 +103,19 @@ impl Drop for ConnectionRef { } } +/// Implement a soft close for all types that implement `SoftClose`. +/// This allows us to soft close a connection, allowing all messages to be sent +/// before closing. +#[async_trait] +trait SoftClose { + async fn soft_close(&mut self) {} +} + impl Connection { /// Converts a set of writer and reader streams into a connection. /// Under the hood, this spawns sending and receiving tasks. fn from_streams< - W: AsyncWriteExt + Unpin + Send + 'static, + W: AsyncWriteExt + Unpin + Send + SoftClose + 'static, R: AsyncReadExt + Unpin + Send + 'static, M: Middleware, >( @@ -122,14 +131,17 @@ impl Connection { // While we can successfully receive messages from the caller, while let Ok(message) = receive_from_caller.recv().await { match message { - BytesOrFlush::Bytes(message) => { + BytesOrSoftClose::Bytes(message) => { // Write the message to the stream if write_length_delimited(&mut writer, message).await.is_err() { receive_from_caller.close(); return; }; } - BytesOrFlush::Flush(result_sender) => { + BytesOrSoftClose::SoftClose(result_sender) => { + // Soft close the writer, allowing it to finish sending + writer.soft_close().await; + // Acknowledge that we've processed up to this point let _ = result_sender.send(()); } @@ -181,7 +193,10 @@ impl Connection { pub async fn send_message_raw(&self, raw_message: Bytes) -> Result<()> { // Send the message bail!( - self.0.sender.send(BytesOrFlush::Bytes(raw_message)).await, + self.0 + .sender + .send(BytesOrSoftClose::Bytes(raw_message)) + .await, Connection, "failed to send message" ); @@ -219,19 +234,26 @@ impl Connection { )) } - pub async fn flush(&self) -> Result<()> { + /// Soft close the connection, allowing all messages to be sent before closing. + /// + /// # Errors + /// - If we fail to soft close the connection + pub async fn soft_close(&self) -> Result<()> { // Create notifier to wait for the flush message to be acknowledged - let (flush_sender, flush_receiver) = oneshot::channel(); + let (soft_close_sender, soft_close_receiver) = oneshot::channel(); - // Send the flush message + // Send the soft close message bail!( - self.0.sender.send(BytesOrFlush::Flush(flush_sender)).await, + self.0 + .sender + .send(BytesOrSoftClose::SoftClose(soft_close_sender)) + .await, Connection, "failed to flush connection" ); // Wait to receive the result - match flush_receiver.await { + match soft_close_receiver.await { Ok(()) => Ok(()), _ => Err(Error::Connection("failed to flush connection".to_string())), } @@ -389,8 +411,8 @@ pub mod tests { // Send our message connection.send_message(new_connection_to_listener).await?; - // Flush the connection, ensuring the message is sent - connection.flush().await?; + // Soft close the connection, allowing all messages to be sent + connection.soft_close().await?; Ok(()) }); diff --git a/cdn-proto/src/connection/protocols/quic.rs b/cdn-proto/src/connection/protocols/quic.rs index b7e3d53..28d19c9 100644 --- a/cdn-proto/src/connection/protocols/quic.rs +++ b/cdn-proto/src/connection/protocols/quic.rs @@ -9,13 +9,13 @@ use std::{ }; use async_trait::async_trait; -use quinn::{ClientConfig, Endpoint, Incoming, ServerConfig, TransportConfig, VarInt}; +use quinn::{ClientConfig, Endpoint, Incoming, SendStream, ServerConfig, TransportConfig, VarInt}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::RootCertStore; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::time::timeout; -use super::{Connection, Listener, Protocol, UnfinalizedConnection}; +use super::{Connection, Listener, Protocol, SoftClose, UnfinalizedConnection}; use crate::connection::middleware::Middleware; use crate::crypto::tls::{LOCAL_CA_CERT, PROD_CA_CERT}; use crate::parse_endpoint; @@ -272,6 +272,17 @@ async fn accept_bi( Ok((sender, receiver)) } +#[async_trait] +impl SoftClose for SendStream { + /// Soft close the stream by shutting down the write side and waiting for the + /// read side to close (with a timeout of 3 seconds). + async fn soft_close(&mut self) { + if self.finish().is_ok() { + let _ = timeout(Duration::from_secs(3), self.stopped()).await; + } + } +} + #[cfg(test)] mod tests { use anyhow::{anyhow, Result}; diff --git a/cdn-proto/src/connection/protocols/tcp.rs b/cdn-proto/src/connection/protocols/tcp.rs index 49404bf..3370720 100644 --- a/cdn-proto/src/connection/protocols/tcp.rs +++ b/cdn-proto/src/connection/protocols/tcp.rs @@ -9,9 +9,11 @@ use std::time::Duration; use async_trait::async_trait; use rustls::pki_types::CertificateDer; use rustls::pki_types::PrivateKeyDer; +use tokio::net::tcp::OwnedWriteHalf; use tokio::net::{TcpSocket, TcpStream}; use tokio::time::timeout; +use super::SoftClose; use super::{Connection, Listener, Protocol, UnfinalizedConnection}; use crate::connection::middleware::Middleware; use crate::{ @@ -146,6 +148,10 @@ impl Listener for TcpListener { } } +/// Soft closing is a no-op for TCP connections. +#[async_trait] +impl SoftClose for OwnedWriteHalf {} + #[cfg(test)] mod tests { use anyhow::{anyhow, Result}; diff --git a/tests/src/tests/double_connect.rs b/tests/src/tests/double_connect.rs index 7210d56..4324d42 100644 --- a/tests/src/tests/double_connect.rs +++ b/tests/src/tests/double_connect.rs @@ -39,7 +39,7 @@ async fn test_double_connect_same_broker() { .send_direct_message(&keypair_from_seed(1).1, b"hello direct".to_vec()) .await .is_err() - || client1.flush().await.is_err() + || client1.soft_close().await.is_err() ); // The second client to connect should have succeeded @@ -125,7 +125,7 @@ async fn test_double_connect_different_broker() { .send_direct_message(&keypair_from_seed(1).1, b"hello direct".to_vec()) .await .is_err() - || client1.flush().await.is_err(), + || client1.soft_close().await.is_err(), "second client connected when it shouldn't have" ); } diff --git a/tests/src/tests/subscribe.rs b/tests/src/tests/subscribe.rs index 5a91a80..6b1a612 100644 --- a/tests/src/tests/subscribe.rs +++ b/tests/src/tests/subscribe.rs @@ -140,7 +140,7 @@ async fn test_invalid_subscribe() { .send_broadcast_message(vec![1], b"hello invalid".to_vec()) .await .is_err() - || client.flush().await.is_err(), + || client.soft_close().await.is_err(), "sent message but should've been disconnected" ); From 2ca8f568efc0552ea3d0e5032ad86de18568b7fc Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 6 Jun 2024 10:15:13 -0400 Subject: [PATCH 09/31] fmt --- cdn-proto/src/connection/protocols/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index 4a0d74e..3486315 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -235,7 +235,7 @@ impl Connection { } /// Soft close the connection, allowing all messages to be sent before closing. - /// + /// /// # Errors /// - If we fail to soft close the connection pub async fn soft_close(&self) -> Result<()> { From 6a53bd2a0df938421d4ef8432e6eec12f60f0ffa Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 6 Jun 2024 10:25:45 -0400 Subject: [PATCH 10/31] trusted middleware for client connections --- cdn-proto/benches/protocols.rs | 4 ++-- cdn-proto/src/connection/middleware/mod.rs | 2 +- cdn-proto/src/connection/protocols/tcp.rs | 1 - cdn-proto/src/def.rs | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cdn-proto/benches/protocols.rs b/cdn-proto/benches/protocols.rs index 2b98830..7c9947c 100644 --- a/cdn-proto/benches/protocols.rs +++ b/cdn-proto/benches/protocols.rs @@ -2,7 +2,7 @@ use cdn_proto::{ connection::{ - middleware::NoMiddleware, + middleware::TrustedMiddleware, protocols::{quic::Quic, tcp::Tcp, Connection, Listener, Protocol, UnfinalizedConnection}, Bytes, }, @@ -38,7 +38,7 @@ async fn transfer(conn1: Connection, conn2: Connection, raw_message: Bytes) { /// Set up our protocol benchmarks, including async runtime, given the message size /// to test. -fn set_up_bench>( +fn set_up_bench>( message_size: usize, ) -> (Runtime, Connection, Connection, Bytes) { // Create new tokio runtime diff --git a/cdn-proto/src/connection/middleware/mod.rs b/cdn-proto/src/connection/middleware/mod.rs index 3b967cf..96995ea 100644 --- a/cdn-proto/src/connection/middleware/mod.rs +++ b/cdn-proto/src/connection/middleware/mod.rs @@ -7,7 +7,7 @@ pub mod pool; lazy_static! { /// A global semaphore that prevents the server from allocating too much memory at once. - static ref MEMORY_POOL: MemoryPool = MemoryPool::new(u32::MAX as usize); + static ref MEMORY_POOL: MemoryPool = MemoryPool::new((u32::MAX / 4) as usize); } /// A trait that defines middleware for a connection. diff --git a/cdn-proto/src/connection/protocols/tcp.rs b/cdn-proto/src/connection/protocols/tcp.rs index 639e6be..3370720 100644 --- a/cdn-proto/src/connection/protocols/tcp.rs +++ b/cdn-proto/src/connection/protocols/tcp.rs @@ -172,4 +172,3 @@ mod tests { super_test_connection::(format!("127.0.0.1:{port}")).await } } - diff --git a/cdn-proto/src/def.rs b/cdn-proto/src/def.rs index ec16ac9..4d09d47 100644 --- a/cdn-proto/src/def.rs +++ b/cdn-proto/src/def.rs @@ -88,14 +88,14 @@ impl ConnectionDef for ProductionUserConnection { } /// The production client connection configuration. -/// Uses BLS signatures, QUIC, and no middleware. +/// Uses BLS signatures, QUIC, and trusted middleware. /// Differs from `ProductionUserConnection` in that this is used by /// the client, not the broker. pub struct ProductionClientConnection; impl ConnectionDef for ProductionClientConnection { type Scheme = Scheme<::User>; type Protocol = Protocol<::User>; - type Middleware = NoMiddleware; + type Middleware = TrustedMiddleware; } /// The testing run configuration. From e72666c3d6f30077902405fdb28de60f8ed6d7aa Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 6 Jun 2024 14:58:44 -0400 Subject: [PATCH 11/31] move middleware --- cdn-broker/src/binaries/bad-broker.rs | 1 + cdn-broker/src/binaries/broker.rs | 8 ++ cdn-broker/src/lib.rs | 17 +++- cdn-broker/src/reexports.rs | 3 - cdn-broker/src/tasks/broker/heartbeat.rs | 2 +- cdn-broker/src/tasks/broker/listener.rs | 5 +- cdn-broker/src/tasks/user/listener.rs | 5 +- cdn-broker/src/tests/mod.rs | 2 +- cdn-client/src/reexports.rs | 3 - cdn-client/src/retry.rs | 9 +- cdn-marshal/src/binaries/marshal.rs | 8 ++ cdn-marshal/src/lib.rs | 21 ++++- cdn-proto/benches/protocols.rs | 10 +-- cdn-proto/src/connection/middleware/mod.rs | 91 ++++++++++++-------- cdn-proto/src/connection/protocols/memory.rs | 20 +++-- cdn-proto/src/connection/protocols/mod.rs | 42 +++++---- cdn-proto/src/connection/protocols/quic.rs | 16 ++-- cdn-proto/src/connection/protocols/tcp.rs | 16 ++-- cdn-proto/src/def.rs | 13 +-- tests/src/tests/mod.rs | 2 + 20 files changed, 190 insertions(+), 104 deletions(-) diff --git a/cdn-broker/src/binaries/bad-broker.rs b/cdn-broker/src/binaries/bad-broker.rs index 6583d9f..0f5c939 100644 --- a/cdn-broker/src/binaries/bad-broker.rs +++ b/cdn-broker/src/binaries/bad-broker.rs @@ -68,6 +68,7 @@ async fn main() -> Result<()> { public_advertise_endpoint: format!("local_ip:{public_port}"), private_bind_endpoint: format!("0.0.0.0:{private_port}"), private_advertise_endpoint: format!("local_ip:{private_port}"), + global_memory_pool_size: None, }; // Create new `Broker` diff --git a/cdn-broker/src/binaries/broker.rs b/cdn-broker/src/binaries/broker.rs index cba18b6..32ec5c9 100644 --- a/cdn-broker/src/binaries/broker.rs +++ b/cdn-broker/src/binaries/broker.rs @@ -52,6 +52,13 @@ struct Args { /// The seed for broker key generation #[arg(short, long, default_value_t = 0)] key_seed: u64, + + /// The size of the global memory pool (in bytes). This is the maximum number of bytes that + /// can be allocated at once for all connections. A connection will block if it + /// tries to allocate more than this amount until some memory is freed. + /// Default is 1GB. + #[arg(long, default_value_t = 1_073_741_824)] + global_memory_pool_size: usize, } #[tokio::main] @@ -96,6 +103,7 @@ async fn main() -> Result<()> { public_advertise_endpoint: args.public_advertise_endpoint, private_bind_endpoint: args.private_bind_endpoint, private_advertise_endpoint: args.private_advertise_endpoint, + global_memory_pool_size: Some(args.global_memory_pool_size), }; // Create new `Broker` diff --git a/cdn-broker/src/lib.rs b/cdn-broker/src/lib.rs index 5538584..51c9722 100644 --- a/cdn-broker/src/lib.rs +++ b/cdn-broker/src/lib.rs @@ -19,7 +19,7 @@ use std::{ mod metrics; use cdn_proto::{ bail, - connection::protocols::Protocol as _, + connection::{middleware::Middleware, protocols::Protocol as _}, crypto::tls::{generate_cert_from_ca, load_ca}, def::{Listener, Protocol, RunDef, Scheme}, discovery::{BrokerIdentifier, DiscoveryClient}, @@ -60,6 +60,12 @@ pub struct Config { /// An optional TLS CA key path. If not specified, will use the local one. pub ca_key_path: Option, + + /// The size of the global memory pool (in bytes). This is the maximum number of bytes that + /// can be allocated at once for all connections. A connection will block if it + /// tries to allocate more than this amount until some memory is freed. + /// Default is 1GB. + pub global_memory_pool_size: Option, } /// The broker `Inner` that we use to share common data between broker tasks. @@ -76,6 +82,9 @@ struct Inner { /// The connections that currently exist. We use this everywhere we need to update connection /// state or send messages. connections: Arc>, + + /// The shared middleware that we use for all connections. + middleware: Middleware, } /// The main `Broker` struct. We instantiate this when we want to run a broker. @@ -117,6 +126,8 @@ impl Broker { discovery_endpoint, ca_cert_path, ca_key_path, + + global_memory_pool_size, } = config; // Get the local IP address so we can replace in @@ -203,6 +214,9 @@ impl Broker { }) .transpose()?; + // Create the globally shared middleware + let middleware = Middleware::new(global_memory_pool_size, None); + // Create and return `Self` as wrapping an `Inner` (with things that we need to share) Ok(Self { inner: Arc::from(Inner { @@ -210,6 +224,7 @@ impl Broker { identity: identity.clone(), keypair, connections: Arc::from(RwLock::from(Connections::new(identity))), + middleware, }), metrics_bind_endpoint, user_listener, diff --git a/cdn-broker/src/reexports.rs b/cdn-broker/src/reexports.rs index e00cd24..948b59b 100644 --- a/cdn-broker/src/reexports.rs +++ b/cdn-broker/src/reexports.rs @@ -6,9 +6,6 @@ pub mod connection { pub use cdn_proto::connection::protocols::quic::Quic; pub use cdn_proto::connection::protocols::tcp::Tcp; } - pub use cdn_proto::connection::middleware::{ - Middleware, NoMiddleware, TrustedMiddleware, UntrustedMiddleware, - }; } pub mod discovery { diff --git a/cdn-broker/src/tasks/broker/heartbeat.rs b/cdn-broker/src/tasks/broker/heartbeat.rs index 96a517b..2e80c17 100644 --- a/cdn-broker/src/tasks/broker/heartbeat.rs +++ b/cdn-broker/src/tasks/broker/heartbeat.rs @@ -84,7 +84,7 @@ impl Inner { let connection = // Our TCP protocol is unsecured, so the cert we use does not matter. // Time out is at protocol level - match Protocol::::connect(&to_connect_endpoint, true).await + match Protocol::::connect(&to_connect_endpoint, true, inner.middleware.clone()).await { Ok(connection) => connection, Err(err) => { diff --git a/cdn-broker/src/tasks/broker/listener.rs b/cdn-broker/src/tasks/broker/listener.rs index b023269..367bbd7 100644 --- a/cdn-broker/src/tasks/broker/listener.rs +++ b/cdn-broker/src/tasks/broker/listener.rs @@ -28,7 +28,10 @@ impl Inner { let inner = self.clone(); spawn(async move { // Finalize the connection - let Ok(connection) = unfinalized_connection.finalize().await else { + let Ok(connection) = unfinalized_connection + .finalize(inner.middleware.clone()) + .await + else { return; }; diff --git a/cdn-broker/src/tasks/user/listener.rs b/cdn-broker/src/tasks/user/listener.rs index 7dbe777..3ea7028 100644 --- a/cdn-broker/src/tasks/user/listener.rs +++ b/cdn-broker/src/tasks/user/listener.rs @@ -28,7 +28,10 @@ impl Inner { let inner = self.clone(); spawn(async move { // Finalize the connection - let Ok(connection) = unfinalized_connection.finalize().await else { + let Ok(connection) = unfinalized_connection + .finalize(inner.middleware.clone()) + .await + else { return; }; diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index b0e059f..7a03ed7 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -133,7 +133,7 @@ impl TestDefinition { public_key, private_key, }, - + global_memory_pool_size: None, ca_cert_path: None, ca_key_path: None, }; diff --git a/cdn-client/src/reexports.rs b/cdn-client/src/reexports.rs index 34936f8..bbb4f01 100644 --- a/cdn-client/src/reexports.rs +++ b/cdn-client/src/reexports.rs @@ -6,9 +6,6 @@ pub mod connection { pub use cdn_proto::connection::protocols::quic::Quic; pub use cdn_proto::connection::protocols::tcp::Tcp; } - pub use cdn_proto::connection::middleware::{ - Middleware, NoMiddleware, TrustedMiddleware, UntrustedMiddleware, - }; } pub mod discovery { diff --git a/cdn-client/src/retry.rs b/cdn-client/src/retry.rs index 32a7b15..3323020 100644 --- a/cdn-client/src/retry.rs +++ b/cdn-client/src/retry.rs @@ -10,6 +10,7 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; use cdn_proto::{ connection::{ auth::user::UserAuth, + middleware::Middleware, protocols::{Connection, Protocol as _}, }, crypto::signature::KeyPair, @@ -64,9 +65,13 @@ impl Inner { /// - If the connection failed /// - If authentication failed async fn connect(self: &Arc) -> Result { + // Create the middleware we will use for all connections + let middleware = Middleware::new(None, Some(1)); + // Make the connection to the marshal let connection = bail!( - Protocol::::connect(&self.endpoint, self.use_local_authority).await, + Protocol::::connect(&self.endpoint, self.use_local_authority, middleware.clone()) + .await, Connection, "failed to connect to endpoint" ); @@ -80,7 +85,7 @@ impl Inner { // Make the connection to the broker let connection = bail!( - Protocol::::connect(&broker_endpoint, self.use_local_authority).await, + Protocol::::connect(&broker_endpoint, self.use_local_authority, middleware).await, Connection, "failed to connect to broker" ); diff --git a/cdn-marshal/src/binaries/marshal.rs b/cdn-marshal/src/binaries/marshal.rs index 40805d3..34c8c0a 100644 --- a/cdn-marshal/src/binaries/marshal.rs +++ b/cdn-marshal/src/binaries/marshal.rs @@ -34,6 +34,13 @@ struct Args { /// If not provided, a local, pinned CA is used #[arg(long)] ca_key_path: Option, + + /// The size of the global memory pool (in bytes). This is the maximum number of bytes that + /// can be allocated at once for all connections. A connection will block if it + /// tries to allocate more than this amount until some memory is freed. + /// Default is 1GB. + #[arg(long, default_value_t = 1_073_741_824)] + global_memory_pool_size: usize, } #[tokio::main] @@ -60,6 +67,7 @@ async fn main() -> Result<()> { metrics_bind_endpoint: args.metrics_bind_endpoint, ca_cert_path: args.ca_cert_path, ca_key_path: args.ca_key_path, + global_memory_pool_size: Some(args.global_memory_pool_size), }; // Create new `Marshal` from the config diff --git a/cdn-marshal/src/lib.rs b/cdn-marshal/src/lib.rs index 84f1bbf..febdcc7 100644 --- a/cdn-marshal/src/lib.rs +++ b/cdn-marshal/src/lib.rs @@ -14,7 +14,10 @@ mod handlers; use cdn_proto::{ bail, - connection::protocols::{Listener as _, Protocol as _, UnfinalizedConnection}, + connection::{ + middleware::Middleware, + protocols::{Listener as _, Protocol as _, UnfinalizedConnection}, + }, crypto::tls::{generate_cert_from_ca, load_ca}, def::{Listener, Protocol, RunDef}, discovery::DiscoveryClient, @@ -42,6 +45,11 @@ pub struct Config { /// The endpoint to bind to for externalizing metrics (in `IP:port` form). If not provided, /// metrics are not exposed. pub metrics_bind_endpoint: Option, + + /// The size of the global memory pool (in bytes). This is the maximum number of bytes that + /// can be allocated at once for all connections. A connection will block if it + /// tries to allocate more than this amount until some memory is freed. + pub global_memory_pool_size: Option, } /// A connection `Marshal`. The user authenticates with it, receiving a permit @@ -57,6 +65,9 @@ pub struct Marshal { /// The endpoint to bind to for externalizing metrics (in `IP:port` form). If not provided, /// metrics are not exposed. metrics_bind_endpoint: Option, + + // The middleware to use for the connection + middleware: Middleware, } impl Marshal { @@ -73,6 +84,7 @@ impl Marshal { metrics_bind_endpoint, ca_cert_path, ca_key_path, + global_memory_pool_size, } = config; // Conditionally load CA cert and key in @@ -112,11 +124,15 @@ impl Marshal { }) .transpose()?; + // Create the middleware + let middleware = Middleware::new(global_memory_pool_size, None); + // Create `Self` from the `Listener` Ok(Self { listener: Arc::from(listener), metrics_bind_endpoint, discovery_client, + middleware, }) } @@ -143,9 +159,10 @@ impl Marshal { // Create a task to handle the connection let discovery_client = self.discovery_client.clone(); + let middleware = self.middleware.clone(); spawn(async move { // Finalize the connection - let Ok(connection) = unfinalized_connection.finalize().await else { + let Ok(connection) = unfinalized_connection.finalize(middleware).await else { return; }; diff --git a/cdn-proto/benches/protocols.rs b/cdn-proto/benches/protocols.rs index 7c9947c..d8bcb2a 100644 --- a/cdn-proto/benches/protocols.rs +++ b/cdn-proto/benches/protocols.rs @@ -2,7 +2,7 @@ use cdn_proto::{ connection::{ - middleware::TrustedMiddleware, + middleware::Middleware, protocols::{quic::Quic, tcp::Tcp, Connection, Listener, Protocol, UnfinalizedConnection}, Bytes, }, @@ -38,9 +38,7 @@ async fn transfer(conn1: Connection, conn2: Connection, raw_message: Bytes) { /// Set up our protocol benchmarks, including async runtime, given the message size /// to test. -fn set_up_bench>( - message_size: usize, -) -> (Runtime, Connection, Connection, Bytes) { +fn set_up_bench(message_size: usize) -> (Runtime, Connection, Connection, Bytes) { // Create new tokio runtime let benchmark_runtime = tokio::runtime::Runtime::new().expect("failed to create Tokio runtime"); @@ -66,13 +64,13 @@ fn set_up_bench>( // Finalize the connection unfinalized_connection - .finalize() + .finalize(Middleware::none()) .await .expect("failed to finalize connection") }); // Attempt to connect - let conn1 = Proto::connect(&format!("127.0.0.1:{port}"), true) + let conn1 = Proto::connect(&format!("127.0.0.1:{port}"), true, Middleware::none()) .await .expect("failed to connect to listener"); diff --git a/cdn-proto/src/connection/middleware/mod.rs b/cdn-proto/src/connection/middleware/mod.rs index 96995ea..618e9be 100644 --- a/cdn-proto/src/connection/middleware/mod.rs +++ b/cdn-proto/src/connection/middleware/mod.rs @@ -1,47 +1,66 @@ -use async_trait::async_trait; -use lazy_static::lazy_static; +use pool::MemoryPool; -use self::pool::{AllocationPermit, MemoryPool}; +use self::pool::AllocationPermit; pub mod pool; -lazy_static! { - /// A global semaphore that prevents the server from allocating too much memory at once. - static ref MEMORY_POOL: MemoryPool = MemoryPool::new((u32::MAX / 4) as usize); +/// Shared middleware for all connections. +#[derive(Clone)] +pub struct Middleware { + /// The global memory pool to check with before allocating. + global_memory_pool: Option, + + /// Per connection, the size of the channel buffer. + connection_message_pool_size: Option, } -/// A trait that defines middleware for a connection. -#[async_trait] -pub trait Middleware: 'static + Send + Sync + Clone { - async fn allocate_message_bytes(num_bytes: u32) -> Option { - // Acquire and return a permit for the number of bytes requested - let permit = MEMORY_POOL - .alloc(num_bytes) - .await - .expect("required semaphore has been dropped"); - - Some(permit) +impl Middleware { + /// Create a new middleware with a global memory pool of `global_memory_pool_size` bytes + /// and a connection message pool size of `connection_message_pool_size` bytes. + /// + /// If the global memory pool is not set, it will not be used. + /// If the connection message pool size is not set, an unbounded channel will be used. + pub fn new( + global_memory_pool_size: Option, + connection_message_pool_size: Option, + ) -> Self { + // Create a new global memory pool if the size is set, otherwise set it to `None`. + Self { + global_memory_pool: global_memory_pool_size.map(|size| MemoryPool::new(size)), + connection_message_pool_size, + } } -} -/// Middleware that does not do anything -#[derive(Clone)] -pub struct NoMiddleware; -#[async_trait] -impl Middleware for NoMiddleware { - async fn allocate_message_bytes(_num_bytes: u32) -> Option { - None + /// Create a new middleware with no global memory pool and no connection message pool size. + /// This means an unbounded channel will be used for connections and no global memory pool + /// will be checked. + pub fn none() -> Self { + // Create a new middleware with no global memory pool and no connection message pool size. + Self { + global_memory_pool: None, + connection_message_pool_size: None, + } } -} -/// Middleware for untrusted connections -#[derive(Clone)] -pub struct UntrustedMiddleware; -#[async_trait] -impl Middleware for UntrustedMiddleware {} + /// Allocate a permit for a message of `num_bytes` bytes. + /// If the global memory pool is not set, this will return `None`. + pub async fn allocate_message_bytes(&self, num_bytes: u32) -> Option { + if let Some(pool) = &self.global_memory_pool { + // If the global memory pool is set, allocate a permit + Some( + pool.alloc(num_bytes) + .await + .expect("required semaphore has been dropped"), + ) + } else { + // If the global memory pool is not set, return `None` + None + } + } -/// Middleware for trusted connections -#[derive(Clone)] -pub struct TrustedMiddleware; -#[async_trait] -impl Middleware for TrustedMiddleware {} + /// Get the connection message pool size, if set. + pub fn connection_message_pool_size(&self) -> Option { + // Return the connection message pool size + self.connection_message_pool_size + } +} diff --git a/cdn-proto/src/connection/protocols/memory.rs b/cdn-proto/src/connection/protocols/memory.rs index bb5125a..7c617f3 100644 --- a/cdn-proto/src/connection/protocols/memory.rs +++ b/cdn-proto/src/connection/protocols/memory.rs @@ -13,7 +13,6 @@ use tokio::{ }; use super::{Connection, Listener, Protocol, SoftClose, UnfinalizedConnection}; -use crate::connection::middleware::NoMiddleware; #[cfg(feature = "metrics")] use crate::{ bail, @@ -32,7 +31,7 @@ static LISTENERS: OnceLock>> = OnceLock: pub struct Memory; #[async_trait] -impl Protocol for Memory { +impl Protocol for Memory { type UnfinalizedConnection = UnfinalizedMemoryConnection; type Listener = MemoryListener; @@ -40,7 +39,11 @@ impl Protocol for Memory { /// /// # Errors /// - If the listener is not listening - async fn connect(remote_endpoint: &str, _use_local_authority: bool) -> Result { + async fn connect( + remote_endpoint: &str, + _use_local_authority: bool, + middleware: Middleware, + ) -> Result { // If the peer is not listening, return an error // Get or initialize the channels as a static value let listeners = LISTENERS.get_or_init(RwLock::default).read().await; @@ -70,7 +73,7 @@ impl Protocol for Memory { ); // Convert the streams into a `Connection` - let connection = Connection::from_streams::<_, _, M>(send_to_them, receive_from_them); + let connection = Connection::from_streams(send_to_them, receive_from_them, middleware); // Return our connection Ok(connection) @@ -111,11 +114,12 @@ pub struct UnfinalizedMemoryConnection { } #[async_trait] -impl UnfinalizedConnection for UnfinalizedMemoryConnection { +impl UnfinalizedConnection for UnfinalizedMemoryConnection { /// Prepares the `MemoryConnection` for usage by `Arc()ing` things. - async fn finalize(self) -> Result { + async fn finalize(self, middleware: Middleware) -> Result { // Convert the streams into a `Connection` - let connection = Connection::from_streams::<_, _, M>(self.send_stream, self.receive_stream); + let connection = + Connection::from_streams(self.send_stream, self.receive_stream, middleware); // Return our connection Ok(connection) @@ -187,7 +191,7 @@ impl Memory { let (sender, receiver) = duplex(8192); // Convert the streams into a `Connection` - Connection::from_streams::<_, _, NoMiddleware>(sender, receiver) + Connection::from_streams(sender, receiver, Middleware::none()) } } diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index 3486315..12f1c1c 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -29,15 +29,19 @@ pub mod tcp; /// The `Protocol` trait lets us be generic over a connection type (Tcp, Quic, etc). #[async_trait] -pub trait Protocol: Send + Sync + 'static { - type UnfinalizedConnection: UnfinalizedConnection + Send + Sync; +pub trait Protocol: Send + Sync + 'static { + type UnfinalizedConnection: UnfinalizedConnection + Send + Sync; type Listener: Listener + Send + Sync; /// Connect to a remote endpoint, returning an instance of `Self`. /// /// # Errors /// Errors if we fail to connect or if we fail to bind to the interface we want. - async fn connect(remote_endpoint: &str, use_local_authority: bool) -> Result; + async fn connect( + remote_endpoint: &str, + use_local_authority: bool, + middleware: Middleware, + ) -> Result; /// Bind to the local endpoint, returning an instance of `Listener`. /// @@ -61,11 +65,11 @@ pub trait Listener { } #[async_trait] -pub trait UnfinalizedConnection { +pub trait UnfinalizedConnection { /// Finalize an incoming connection. This is separated so we can prevent /// actors who are slow from clogging up the incoming connection by offloading /// it to a separate task. - async fn finalize(self) -> Result; + async fn finalize(self, middleware: Middleware) -> Result; } /// A connection to a remote endpoint. @@ -117,14 +121,19 @@ impl Connection { fn from_streams< W: AsyncWriteExt + Unpin + Send + SoftClose + 'static, R: AsyncReadExt + Unpin + Send + 'static, - M: Middleware, >( mut writer: W, mut reader: R, + middleware: Middleware, ) -> Self { - // Create the channels that will be used to send and receive messages - let (send_to_caller, receive_from_task) = kanal::unbounded_async(); - let (send_to_task, receive_from_caller) = kanal::unbounded_async(); + // Create the channels that will be used to send and receive messages. + // Conditionally create bounded channels if the user specifies a size + let ((send_to_caller, receive_from_task), (send_to_task, receive_from_caller)) = + if let Some(size) = middleware.connection_message_pool_size() { + (kanal::bounded_async(size), kanal::bounded_async(size)) + } else { + (kanal::unbounded_async(), kanal::unbounded_async()) + }; // Spawn the task that receives from the caller and sends to the stream let sender_task = tokio::spawn(async move { @@ -153,7 +162,7 @@ impl Connection { // Spawn the task that receives from the stream and sends to the caller let receiver_task = tokio::spawn(async move { // While we can successfully read messages from the stream, - while let Ok(message) = read_length_delimited::(&mut reader).await { + while let Ok(message) = read_length_delimited::(&mut reader, &middleware).await { if send_to_caller.send(message).await.is_err() { send_to_caller.close(); return; @@ -262,8 +271,9 @@ impl Connection { /// Read a length-delimited (serialized) message from a stream. /// Has a bounds check for if the message is too big -async fn read_length_delimited( +async fn read_length_delimited( stream: &mut R, + middleware: &Middleware, ) -> Result { // Read the message size from the stream let message_size = bail!( @@ -278,7 +288,7 @@ async fn read_length_delimited( } // Acquire the allocation if necessary - let permit = M::allocate_message_bytes(message_size).await; + let permit = middleware.allocate_message_bytes(message_size).await; // Create buffer of the proper size let mut buffer = vec![0; usize::try_from(message_size).expect(">= 32 bit system")]; @@ -349,7 +359,7 @@ pub mod tests { use super::{Listener, Protocol, UnfinalizedConnection}; use crate::{ - connection::middleware::NoMiddleware, + connection::middleware::Middleware, crypto::tls::{generate_cert_from_ca, LOCAL_CA_CERT, LOCAL_CA_KEY}, message::{Direct, Message}, }; @@ -362,7 +372,7 @@ pub mod tests { /// /// # Errors /// If the connection failed - pub async fn test_connection>(bind_endpoint: String) -> Result<()> { + pub async fn test_connection(bind_endpoint: String) -> Result<()> { // Generate cert signed by local CA let (cert, key) = generate_cert_from_ca(LOCAL_CA_CERT, LOCAL_CA_KEY)?; @@ -387,7 +397,7 @@ pub mod tests { let unfinalized_connection = listener.accept().await?; // Finalize the connection - let connection = unfinalized_connection.finalize().await?; + let connection = unfinalized_connection.finalize(Middleware::none()).await?; // Send our message connection.send_message(listener_to_new_connection_).await?; @@ -402,7 +412,7 @@ pub mod tests { // Spawn a task to connect and send and receive the message let new_connection_jh: JoinHandle> = spawn(async move { // Connect to the remote - let connection = P::connect(bind_endpoint.as_str(), true).await?; + let connection = P::connect(bind_endpoint.as_str(), true, Middleware::none()).await?; // Receive a message, assert it's the correct one let message = connection.recv_message().await?; diff --git a/cdn-proto/src/connection/protocols/quic.rs b/cdn-proto/src/connection/protocols/quic.rs index 28d19c9..97bc577 100644 --- a/cdn-proto/src/connection/protocols/quic.rs +++ b/cdn-proto/src/connection/protocols/quic.rs @@ -30,11 +30,15 @@ use crate::{ pub struct Quic; #[async_trait] -impl Protocol for Quic { +impl Protocol for Quic { type UnfinalizedConnection = UnfinalizedQuicConnection; type Listener = QuicListener; - async fn connect(remote_endpoint: &str, use_local_authority: bool) -> Result { + async fn connect( + remote_endpoint: &str, + use_local_authority: bool, + middleware: Middleware, + ) -> Result { // Parse the endpoint let remote_endpoint = bail_option!( bail!( @@ -123,7 +127,7 @@ impl Protocol for Quic { ); // Convert the streams into a `Connection` - let connection = Connection::from_streams::<_, _, M>(sender, receiver); + let connection = Connection::from_streams::<_, _>(sender, receiver, middleware); Ok(connection) } @@ -171,12 +175,12 @@ impl Protocol for Quic { pub struct UnfinalizedQuicConnection(Incoming); #[async_trait] -impl UnfinalizedConnection for UnfinalizedQuicConnection { +impl UnfinalizedConnection for UnfinalizedQuicConnection { /// Finalize the connection by awaiting on `Connecting` and cloning the connection. /// /// # Errors /// If we to finalize our connection. - async fn finalize(self) -> Result { + async fn finalize(self, middleware: Middleware) -> Result { // Await on the `Connecting` to obtain `Connection` let connection = bail!(self.0.await, Connection, "failed to finalize connection"); @@ -192,7 +196,7 @@ impl UnfinalizedConnection for UnfinalizedQuicConnection { ); // Create a sender and receiver - let connection = Connection::from_streams::<_, _, M>(sender, receiver); + let connection = Connection::from_streams(sender, receiver, middleware); // Clone and return the connection Ok(connection) diff --git a/cdn-proto/src/connection/protocols/tcp.rs b/cdn-proto/src/connection/protocols/tcp.rs index 3370720..51a4363 100644 --- a/cdn-proto/src/connection/protocols/tcp.rs +++ b/cdn-proto/src/connection/protocols/tcp.rs @@ -28,7 +28,7 @@ use crate::{ pub struct Tcp; #[async_trait] -impl Protocol for Tcp { +impl Protocol for Tcp { type Listener = TcpListener; type UnfinalizedConnection = UnfinalizedTcpConnection; @@ -37,7 +37,11 @@ impl Protocol for Tcp { /// /// # Errors /// Errors if we fail to connect or if we fail to bind to the interface we want. - async fn connect(remote_endpoint: &str, _use_local_authority: bool) -> Result + async fn connect( + remote_endpoint: &str, + _use_local_authority: bool, + middleware: Middleware, + ) -> Result where Self: Sized, { @@ -75,7 +79,7 @@ impl Protocol for Tcp { let (receiver, sender) = stream.into_split(); // Convert the streams into a `Connection` - let connection = Connection::from_streams::<_, _, M>(sender, receiver); + let connection = Connection::from_streams(sender, receiver, middleware); Ok(connection) } @@ -107,18 +111,18 @@ impl Protocol for Tcp { pub struct UnfinalizedTcpConnection(TcpStream); #[async_trait] -impl UnfinalizedConnection for UnfinalizedTcpConnection { +impl UnfinalizedConnection for UnfinalizedTcpConnection { /// Finalize the connection by splitting it into a sender and receiver side. /// Conssumes `Self`. /// /// # Errors /// Does not actually error, but satisfies trait bounds. - async fn finalize(self) -> Result { + async fn finalize(self, middleware: Middleware) -> Result { // Split the connection and create our wrapper let (receiver, sender) = self.0.into_split(); // Convert the streams into a `Connection` - let connection = Connection::from_streams::<_, _, M>(sender, receiver); + let connection = Connection::from_streams(sender, receiver, middleware); Ok(connection) } diff --git a/cdn-proto/src/def.rs b/cdn-proto/src/def.rs index 4d09d47..6d6f316 100644 --- a/cdn-proto/src/def.rs +++ b/cdn-proto/src/def.rs @@ -3,9 +3,6 @@ use jf_signature::bls_over_bn254::BLSOverBN254CurveSignatureScheme as BLS; use num_enum::{IntoPrimitive, TryFromPrimitive}; -use crate::connection::middleware::{ - Middleware as MiddlewareType, NoMiddleware, TrustedMiddleware, UntrustedMiddleware, -}; use crate::connection::protocols::memory::Memory; use crate::connection::protocols::{quic::Quic, tcp::Tcp, Protocol as ProtocolType}; use crate::crypto::signature::SignatureScheme; @@ -55,8 +52,7 @@ pub trait RunDef: 'static { /// This trait defines the connection configuration for a single CDN component. pub trait ConnectionDef: 'static { type Scheme: SignatureScheme; - type Protocol: ProtocolType; - type Middleware: MiddlewareType; + type Protocol: ProtocolType; } /// The production run configuration. @@ -75,7 +71,6 @@ pub struct ProductionBrokerConnection; impl ConnectionDef for ProductionBrokerConnection { type Scheme = BLS; type Protocol = Tcp; - type Middleware = TrustedMiddleware; } /// The production user connection configuration. @@ -84,7 +79,6 @@ pub struct ProductionUserConnection; impl ConnectionDef for ProductionUserConnection { type Scheme = BLS; type Protocol = Quic; - type Middleware = UntrustedMiddleware; } /// The production client connection configuration. @@ -95,7 +89,6 @@ pub struct ProductionClientConnection; impl ConnectionDef for ProductionClientConnection { type Scheme = Scheme<::User>; type Protocol = Protocol<::User>; - type Middleware = TrustedMiddleware; } /// The testing run configuration. @@ -114,7 +107,6 @@ pub struct TestingConnection; impl ConnectionDef for TestingConnection { type Scheme = BLS; type Protocol = Memory; - type Middleware = NoMiddleware; } // Type aliases to automatically disambiguate usage @@ -123,5 +115,4 @@ pub type PublicKey = as SignatureScheme>::PublicKey; // Type aliases to automatically disambiguate usage pub type Protocol = ::Protocol; -pub type Middleware = ::Middleware; -pub type Listener = as ProtocolType>>::Listener; +pub type Listener = as ProtocolType>::Listener; diff --git a/tests/src/tests/mod.rs b/tests/src/tests/mod.rs index f5906b4..8ed0297 100644 --- a/tests/src/tests/mod.rs +++ b/tests/src/tests/mod.rs @@ -70,6 +70,7 @@ async fn new_broker(key: u64, public_ep: &str, private_ep: &str, discovery_ep: & private_bind_endpoint: private_ep.to_string(), public_advertise_endpoint: public_ep.to_string(), public_bind_endpoint: public_ep.to_string(), + global_memory_pool_size: None, }; // Create broker @@ -91,6 +92,7 @@ async fn new_marshal(ep: &str, discovery_ep: &str) { metrics_bind_endpoint: None, ca_cert_path: None, ca_key_path: None, + global_memory_pool_size: None, }; // Create a new marshal From 107a56b023fbc0d88e8911447d6d24ddece94216 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 6 Jun 2024 14:59:07 -0400 Subject: [PATCH 12/31] clippy --- cdn-proto/src/connection/middleware/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cdn-proto/src/connection/middleware/mod.rs b/cdn-proto/src/connection/middleware/mod.rs index 618e9be..f541751 100644 --- a/cdn-proto/src/connection/middleware/mod.rs +++ b/cdn-proto/src/connection/middleware/mod.rs @@ -26,7 +26,7 @@ impl Middleware { ) -> Self { // Create a new global memory pool if the size is set, otherwise set it to `None`. Self { - global_memory_pool: global_memory_pool_size.map(|size| MemoryPool::new(size)), + global_memory_pool: global_memory_pool_size.map(MemoryPool::new), connection_message_pool_size, } } From 10758992c0ed85ed21bb88a98f7ebfa4510288d9 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 6 Jun 2024 15:16:51 -0400 Subject: [PATCH 13/31] conditional dep for console --- cdn-broker/src/binaries/bad-broker.rs | 1 + cdn-broker/src/binaries/broker.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/cdn-broker/src/binaries/bad-broker.rs b/cdn-broker/src/binaries/bad-broker.rs index 0f5c939..2be2568 100644 --- a/cdn-broker/src/binaries/bad-broker.rs +++ b/cdn-broker/src/binaries/bad-broker.rs @@ -9,6 +9,7 @@ use clap::Parser; use jf_signature::{bls_over_bn254::BLSOverBN254CurveSignatureScheme as BLS, SignatureScheme}; use rand::{rngs::StdRng, SeedableRng}; use tokio::{spawn, time::sleep}; +#[cfg(not(tokio_unstable))] use tracing_subscriber::EnvFilter; #[derive(Parser, Debug)] diff --git a/cdn-broker/src/binaries/broker.rs b/cdn-broker/src/binaries/broker.rs index 32ec5c9..9b5e0da 100644 --- a/cdn-broker/src/binaries/broker.rs +++ b/cdn-broker/src/binaries/broker.rs @@ -5,6 +5,7 @@ use cdn_proto::{crypto::signature::KeyPair, def::ProductionRunDef, error::Result use clap::Parser; use jf_signature::{bls_over_bn254::BLSOverBN254CurveSignatureScheme as BLS, SignatureScheme}; use rand::{rngs::StdRng, SeedableRng}; +#[cfg(not(tokio_unstable))] use tracing_subscriber::EnvFilter; #[derive(Parser, Debug)] From 9aa5370de18f86a7ed6e7de8bdfef8df552f3df4 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 6 Jun 2024 15:17:05 -0400 Subject: [PATCH 14/31] change log to info --- process-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/process-compose.yaml b/process-compose.yaml index bfdae42..6b49e70 100644 --- a/process-compose.yaml +++ b/process-compose.yaml @@ -1,7 +1,7 @@ version: "0.5" environment: - - RUST_LOG=debug + - RUST_LOG=info processes: redis: From 07be57404c10e31135ac92d9b687cd4a4e89af81 Mon Sep 17 00:00:00 2001 From: Rob Date: Fri, 7 Jun 2024 09:35:06 -0400 Subject: [PATCH 15/31] package updates --- Cargo.lock | 87 +++++++++++++++++--------------------------- cdn-proto/Cargo.toml | 4 +- 2 files changed, 35 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c20b1a0..3859617 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,9 +94,9 @@ dependencies = [ [[package]] name = "anstyle-query" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a64c907d4e79225ac72e2a354c9ce84d50ebb4586dee56c82b3ee73004f537f5" +checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" dependencies = [ "windows-sys 0.52.0", ] @@ -414,7 +414,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a" dependencies = [ "concurrent-queue", - "event-listener-strategy 0.5.2", + "event-listener-strategy", "futures-core", "pin-project-lite", ] @@ -450,9 +450,9 @@ dependencies = [ [[package]] name = "async-io" -version = "2.3.2" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcccb0f599cfa2f8ace422d3555572f47424da5648a4382a9dd0310ff8210884" +checksum = "0d6baa8f0178795da0e71bc42c9e5d13261aac7ee549853162e66a241ba17964" dependencies = [ "async-lock", "cfg-if", @@ -469,12 +469,12 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "event-listener 4.0.3", - "event-listener-strategy 0.4.0", + "event-listener 5.3.1", + "event-listener-strategy", "pin-project-lite", ] @@ -768,9 +768,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.98" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" [[package]] name = "cdn-broker" @@ -894,9 +894,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.4" +version = "4.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +checksum = "a9689a29b593160de5bc4aacab7b5d54fb52231de70122626c178e6a368994c7" dependencies = [ "clap_builder", "clap_derive", @@ -904,9 +904,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "2e5387378c84f6faa26890ebf9f0a92989f8873d4d380467bcd0d8d8620424df" dependencies = [ "anstream", "anstyle", @@ -916,9 +916,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.4" +version = "4.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -928,9 +928,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" [[package]] name = "colorchoice" @@ -1311,17 +1311,6 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" -[[package]] -name = "event-listener" -version = "4.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - [[package]] name = "event-listener" version = "5.3.1" @@ -1333,16 +1322,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "event-listener-strategy" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" -dependencies = [ - "event-listener 4.0.3", - "pin-project-lite", -] - [[package]] name = "event-listener-strategy" version = "0.5.2" @@ -1758,9 +1737,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.28" +version = "0.14.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" dependencies = [ "bytes", "futures-channel", @@ -2476,9 +2455,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "piper" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "464db0c665917b13ebb5d453ccdec4add5658ee1adc7affc7677615356a8afaf" +checksum = "ae1d5c74c9876f070d3e8fd503d748c7d974c3e48da8f41350fa5222ef9b4391" dependencies = [ "atomic-waker", "fastrand", @@ -2542,9 +2521,9 @@ dependencies = [ [[package]] name = "polling" -version = "3.7.0" +version = "3.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645493cf344456ef24219d02a768cf1fb92ddf8c92161679ae3d91b91a637be3" +checksum = "5e6a007746f34ed64099e88783b0ae369eaa3da6392868ba262e2af9b8fbaea1" dependencies = [ "cfg-if", "concurrent-queue", @@ -2609,9 +2588,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.84" +version = "1.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" dependencies = [ "unicode-ident", ] @@ -2819,9 +2798,9 @@ dependencies = [ [[package]] name = "redis" -version = "0.24.0" +version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd" +checksum = "e0d7a6955c7511f60f3ba9e86c6d02b3c3f144f8c24b288d1f4e18074ab8bbec" dependencies = [ "arc-swap", "async-trait", @@ -3519,9 +3498,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "symbolic-common" -version = "12.8.0" +version = "12.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cccfffbc6bb3bb2d3a26cd2077f4d055f6808d266f9d4d158797a4c60510dfe" +checksum = "71297dc3e250f7dbdf8adb99e235da783d690f5819fdeb4cce39d9cfb0aca9f1" dependencies = [ "debugid", "memmap2", @@ -3531,9 +3510,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.8.0" +version = "12.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76a99812da4020a67e76c4eb41f08c87364c14170495ff780f30dd519c221a68" +checksum = "424fa2c9bf2c862891b9cfd354a752751a6730fd838a4691e7f6c2c7957b9daf" dependencies = [ "cpp_demangle", "rustc-demangle", diff --git a/cdn-proto/Cargo.toml b/cdn-proto/Cargo.toml index 99ee396..ede673f 100644 --- a/cdn-proto/Cargo.toml +++ b/cdn-proto/Cargo.toml @@ -26,12 +26,12 @@ harness = false [dependencies] -redis = { version = "0.24.0", default-features = false, features = [ +redis = { version = "0.25", default-features = false, features = [ "connection-manager", "tokio-comp", ] } -sqlx = { version = "0.7.3", default-features = false, features = [ +sqlx = { version = "0.7", default-features = false, features = [ "sqlite", "macros", "migrate", From 5b13eec3c234bfc93688efc0f9eea0e8fa5aeb35 Mon Sep 17 00:00:00 2001 From: Rob Date: Fri, 7 Jun 2024 11:47:38 -0400 Subject: [PATCH 16/31] more benchmark macros --- cdn-broker/src/tests/mod.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index 7a03ed7..aac6048 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -41,6 +41,13 @@ macro_rules! send_message_as { .await .expect("failed to send message"); }; + + // Send a message to all actors in a vector + (all, $all: expr, $message: expr) => { + for actor in &$all { + send_message_as!(actor, $message); + } + }; } #[macro_export] @@ -53,6 +60,13 @@ macro_rules! assert_received { } }; + // Make sure everyone in the vector has received this message + (yes, all, $all: expr, $message:expr) => { + for actor in &$all { + assert_received!(yes, actor, $message); + } + }; + // Make sure we haven't received this message (no, $actor: expr) => { assert!( From 49971754af01e0e9748076130a195031fd6ac4c4 Mon Sep 17 00:00:00 2001 From: Rob Date: Fri, 7 Jun 2024 13:10:10 -0400 Subject: [PATCH 17/31] minor updates to testing harness --- cdn-broker/src/tests/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index aac6048..f58dfd6 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -24,6 +24,7 @@ mod direct; use crate::{connections::DirectMap, Broker, Config}; /// An actor is a [user/broker] that we inject to test message send functionality. +#[derive(Clone)] pub struct InjectedActor { /// The in-memory sender that sends to the broker under test pub sender: Connection, @@ -173,8 +174,7 @@ impl TestDefinition { // For each user, for (i, topics) in users.iter().enumerate() { // Extrapolate identifier - #[allow(clippy::cast_possible_truncation)] - let identifier: Arc> = Arc::from(vec![i as u8]); + let identifier: Arc> = Arc::from(i.to_be_bytes().to_vec()); // Generate a testing pair of memory network channels let connection1 = Memory::gen_testing_connection(); From fe2fc82335c555c92984dac3bd6028eb859ab012 Mon Sep 17 00:00:00 2001 From: Rob Date: Fri, 7 Jun 2024 13:21:15 -0400 Subject: [PATCH 18/31] Revert "minor updates to testing harness" This reverts commit 49971754af01e0e9748076130a195031fd6ac4c4. --- cdn-broker/src/tests/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index f58dfd6..aac6048 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -24,7 +24,6 @@ mod direct; use crate::{connections::DirectMap, Broker, Config}; /// An actor is a [user/broker] that we inject to test message send functionality. -#[derive(Clone)] pub struct InjectedActor { /// The in-memory sender that sends to the broker under test pub sender: Connection, @@ -174,7 +173,8 @@ impl TestDefinition { // For each user, for (i, topics) in users.iter().enumerate() { // Extrapolate identifier - let identifier: Arc> = Arc::from(i.to_be_bytes().to_vec()); + #[allow(clippy::cast_possible_truncation)] + let identifier: Arc> = Arc::from(vec![i as u8]); // Generate a testing pair of memory network channels let connection1 = Memory::gen_testing_connection(); From 2d7902dc123fbcdd32e4378765e076a180bbac3d Mon Sep 17 00:00:00 2001 From: Rob Date: Fri, 7 Jun 2024 18:19:05 -0400 Subject: [PATCH 19/31] vastly improve testing infra --- cdn-broker/benches/broadcast.rs | 22 +- cdn-broker/benches/direct.rs | 46 ++- cdn-broker/src/reexports.rs | 2 +- cdn-broker/src/tests/broadcast.rs | 50 +++- cdn-broker/src/tests/direct.rs | 72 +++-- cdn-broker/src/tests/mod.rs | 465 ++++++++++++++++++------------ cdn-proto/src/def.rs | 24 +- tests/src/tests/mod.rs | 11 +- 8 files changed, 419 insertions(+), 273 deletions(-) diff --git a/cdn-broker/benches/broadcast.rs b/cdn-broker/benches/broadcast.rs index 51f1336..bc9e191 100644 --- a/cdn-broker/benches/broadcast.rs +++ b/cdn-broker/benches/broadcast.rs @@ -3,8 +3,9 @@ use std::time::Duration; -use cdn_broker::reexports::tests::{TestDefinition, TestRun}; +use cdn_broker::reexports::tests::{TestBroker, TestDefinition, TestRun, TestUser}; use cdn_broker::{assert_received, send_message_as}; +use cdn_proto::connection::protocols::memory::Memory; use cdn_proto::connection::Bytes; use cdn_proto::def::TestTopic; use cdn_proto::message::{Broadcast, Message}; @@ -49,11 +50,14 @@ fn bench_broadcast_user(c: &mut Criterion) { // Set up our broker under test let run = benchmark_runtime.block_on(async move { let run_definition = TestDefinition { - connected_users: vec![vec![TestTopic::Global as u8], vec![TestTopic::Global as u8]], + connected_users: vec![ + TestUser::with_index(0, vec![TestTopic::Global.into()]), + TestUser::with_index(1, vec![TestTopic::Global.into()]), + ], connected_brokers: vec![], }; - run_definition.into_run().await + run_definition.into_run::().await }); // Benchmark @@ -71,14 +75,18 @@ fn bench_broadcast_broker(c: &mut Criterion) { // Set up our broker under test let run = benchmark_runtime.block_on(async move { let run_definition = TestDefinition { - connected_users: vec![vec![]], + connected_users: vec![TestUser::with_index(0, vec![])], connected_brokers: vec![ - (vec![], vec![TestTopic::Global as u8]), - (vec![], vec![TestTopic::Global as u8]), + TestBroker { + connected_users: vec![TestUser::with_index(1, vec![TestTopic::Global.into()])], + }, + TestBroker { + connected_users: vec![TestUser::with_index(2, vec![TestTopic::Global.into()])], + }, ], }; - run_definition.into_run().await + run_definition.into_run::().await }); // Benchmark diff --git a/cdn-broker/benches/direct.rs b/cdn-broker/benches/direct.rs index a5b9e8e..184e1fd 100644 --- a/cdn-broker/benches/direct.rs +++ b/cdn-broker/benches/direct.rs @@ -3,8 +3,9 @@ use std::time::Duration; -use cdn_broker::reexports::tests::{TestDefinition, TestRun}; -use cdn_broker::{assert_received, send_message_as}; +use cdn_broker::reexports::tests::{TestBroker, TestDefinition, TestRun, TestUser}; +use cdn_broker::{assert_received, at_index, send_message_as}; +use cdn_proto::connection::protocols::memory::Memory; use cdn_proto::connection::Bytes; use cdn_proto::def::TestTopic; use cdn_proto::message::{Direct, Message}; @@ -15,7 +16,7 @@ use pprof::criterion::{Output, PProfProfiler}; async fn direct_user_to_self(run: &TestRun) { // Allocate a rather large message let message = Message::Direct(Direct { - recipient: vec![0], + recipient: at_index![0], message: vec![0; 10000], }); @@ -29,7 +30,7 @@ async fn direct_user_to_self(run: &TestRun) { async fn direct_user_to_user(run: &TestRun) { // Allocate a rather large message let message = Message::Direct(Direct { - recipient: vec![1], + recipient: at_index![1], message: vec![0; 10000], }); @@ -43,7 +44,7 @@ async fn direct_user_to_user(run: &TestRun) { async fn direct_user_to_broker(run: &TestRun) { // Allocate a rather large message let message = Message::Direct(Direct { - recipient: vec![2], + recipient: at_index![2], message: vec![0; 10000], }); @@ -57,7 +58,7 @@ async fn direct_user_to_broker(run: &TestRun) { async fn direct_broker_to_user(run: &TestRun) { // Allocate a rather large message let message = Message::Direct(Direct { - recipient: vec![0], + recipient: at_index![0], message: vec![0; 10000], }); @@ -76,11 +77,11 @@ fn bench_direct_user_to_self(c: &mut Criterion) { // Set up our broker under test let run = benchmark_runtime.block_on(async move { let run_definition = TestDefinition { - connected_users: vec![vec![TestTopic::Global as u8]], + connected_users: vec![TestUser::with_index(0, vec![TestTopic::Global as u8])], connected_brokers: vec![], }; - run_definition.into_run().await + run_definition.into_run::().await }); // Run the benchmark @@ -99,11 +100,14 @@ fn bench_direct_user_to_user(c: &mut Criterion) { // Set up our broker under test let run = benchmark_runtime.block_on(async move { let run_definition = TestDefinition { - connected_users: vec![vec![TestTopic::Global as u8], vec![TestTopic::Global as u8]], + connected_users: vec![ + TestUser::with_index(0, vec![TestTopic::Global as u8]), + TestUser::with_index(1, vec![TestTopic::Global as u8]), + ], connected_brokers: vec![], }; - run_definition.into_run().await + run_definition.into_run::().await }); // Run the benchmark @@ -122,11 +126,16 @@ fn bench_direct_user_to_broker(c: &mut Criterion) { // Set up our broker under test let run = benchmark_runtime.block_on(async move { let run_definition = TestDefinition { - connected_users: vec![vec![TestTopic::Global as u8], vec![TestTopic::Global as u8]], - connected_brokers: vec![(vec![2], vec![TestTopic::Global as u8])], + connected_users: vec![ + TestUser::with_index(0, vec![TestTopic::Global as u8]), + TestUser::with_index(1, vec![TestTopic::Global as u8]), + ], + connected_brokers: vec![TestBroker { + connected_users: vec![TestUser::with_index(2, vec![TestTopic::Global as u8])], + }], }; - run_definition.into_run().await + run_definition.into_run::().await }); // Run the benchmark @@ -145,11 +154,16 @@ fn bench_direct_broker_to_user(c: &mut Criterion) { // Set up our broker under test let run = benchmark_runtime.block_on(async move { let run_definition = TestDefinition { - connected_users: vec![vec![TestTopic::Global as u8], vec![TestTopic::Global as u8]], - connected_brokers: vec![(vec![2], vec![TestTopic::Global as u8])], + connected_users: vec![ + TestUser::with_index(0, vec![TestTopic::Global as u8]), + TestUser::with_index(1, vec![TestTopic::Global as u8]), + ], + connected_brokers: vec![TestBroker { + connected_users: vec![TestUser::with_index(0, vec![TestTopic::Global as u8])], + }], }; - run_definition.into_run().await + run_definition.into_run::().await }); // Run the benchmark diff --git a/cdn-broker/src/reexports.rs b/cdn-broker/src/reexports.rs index 948b59b..120bd97 100644 --- a/cdn-broker/src/reexports.rs +++ b/cdn-broker/src/reexports.rs @@ -29,5 +29,5 @@ pub mod error { /// This is not guarded by `![cfg(test)]` because we use the same functions /// when doing benchmarks. pub mod tests { - pub use crate::tests::{TestDefinition, TestRun}; + pub use crate::tests::{TestBroker, TestDefinition, TestRun, TestUser}; } diff --git a/cdn-broker/src/tests/broadcast.rs b/cdn-broker/src/tests/broadcast.rs index 92eca01..c044528 100644 --- a/cdn-broker/src/tests/broadcast.rs +++ b/cdn-broker/src/tests/broadcast.rs @@ -4,13 +4,13 @@ use std::time::Duration; use cdn_proto::{ - connection::Bytes, + connection::{protocols::memory::Memory, Bytes}, def::TestTopic, message::{Broadcast, Message}, }; use tokio::time::{sleep, timeout}; -use super::TestDefinition; +use super::{TestBroker, TestDefinition, TestUser}; use crate::{assert_received, send_message_as}; /// Test sending a broadcast message from a user. @@ -22,19 +22,28 @@ async fn test_broadcast_user() { // This run definition: 3 brokers, 6 users let run_definition = TestDefinition { connected_users: vec![ - vec![TestTopic::Global as u8, TestTopic::DA as u8], - vec![TestTopic::DA as u8], - vec![TestTopic::Global as u8], + TestUser::with_index(0, vec![TestTopic::Global.into(), TestTopic::DA.into()]), + TestUser::with_index(1, vec![TestTopic::DA.into()]), + TestUser::with_index(2, vec![TestTopic::Global.into()]), ], connected_brokers: vec![ - (vec![3], vec![TestTopic::DA as u8]), - (vec![4], vec![TestTopic::Global as u8, TestTopic::DA as u8]), - (vec![5], vec![]), + TestBroker { + connected_users: vec![TestUser::with_index(3, vec![TestTopic::DA.into()])], + }, + TestBroker { + connected_users: vec![TestUser::with_index( + 4, + vec![TestTopic::Global.into(), TestTopic::DA.into()], + )], + }, + TestBroker { + connected_users: vec![TestUser::with_index(5, vec![])], + }, ], }; // Start the run - let run = run_definition.into_run().await; + let run = run_definition.into_run::().await; // We need a little time for our subscribe messages to propagate sleep(Duration::from_millis(25)).await; @@ -88,19 +97,28 @@ async fn test_broadcast_broker() { // This run definition: 3 brokers, 6 users let run_definition = TestDefinition { connected_users: vec![ - vec![TestTopic::Global as u8, TestTopic::DA as u8], - vec![TestTopic::DA as u8], - vec![TestTopic::Global as u8], + TestUser::with_index(0, vec![TestTopic::Global.into(), TestTopic::DA.into()]), + TestUser::with_index(1, vec![TestTopic::DA.into()]), + TestUser::with_index(2, vec![TestTopic::Global.into()]), ], connected_brokers: vec![ - (vec![3], vec![TestTopic::DA as u8]), - (vec![4], vec![TestTopic::Global as u8, TestTopic::DA as u8]), - (vec![5], vec![]), + TestBroker { + connected_users: vec![TestUser::with_index(3, vec![TestTopic::DA.into()])], + }, + TestBroker { + connected_users: vec![TestUser::with_index( + 4, + vec![TestTopic::Global.into(), TestTopic::DA.into()], + )], + }, + TestBroker { + connected_users: vec![TestUser::with_index(5, vec![])], + }, ], }; // Start the run - let run = run_definition.into_run().await; + let run = run_definition.into_run::().await; // We need a little time for our subscribe messages to propagate sleep(Duration::from_millis(25)).await; diff --git a/cdn-broker/src/tests/direct.rs b/cdn-broker/src/tests/direct.rs index 15b503f..9c74653 100644 --- a/cdn-broker/src/tests/direct.rs +++ b/cdn-broker/src/tests/direct.rs @@ -4,14 +4,14 @@ use std::time::Duration; use cdn_proto::{ - connection::Bytes, + connection::{protocols::memory::Memory, Bytes}, def::TestTopic, message::{Direct, Message}, }; use tokio::time::{sleep, timeout}; -use super::TestDefinition; -use crate::{assert_received, send_message_as}; +use super::{TestBroker, TestDefinition, TestUser}; +use crate::{assert_received, at_index, send_message_as}; /// This test tests that: /// 1. A user sending a message to itself on a broker has it delivered @@ -23,22 +23,28 @@ async fn test_direct_user_to_user() { // This run definition: 3 brokers, 6 users let run_definition = TestDefinition { connected_users: vec![ - vec![TestTopic::Global as u8], - vec![TestTopic::Global as u8, TestTopic::DA as u8], + TestUser::with_index(0, vec![TestTopic::Global.into()]), + TestUser::with_index(1, vec![TestTopic::DA.into()]), ], connected_brokers: vec![ - (vec![2], vec![TestTopic::DA as u8]), - (vec![3], vec![]), - (vec![4], vec![]), + TestBroker { + connected_users: vec![TestUser::with_index(2, vec![TestTopic::DA.into()])], + }, + TestBroker { + connected_users: vec![TestUser::with_index(3, vec![])], + }, + TestBroker { + connected_users: vec![TestUser::with_index(4, vec![])], + }, ], }; // Start the run - let run = run_definition.into_run().await; + let run = run_definition.into_run::().await; // Send a message from user_0 to itself let message = Message::Direct(Direct { - recipient: vec![0], + recipient: at_index![0], message: b"test direct 0".to_vec(), }); @@ -54,7 +60,7 @@ async fn test_direct_user_to_user() { // Create a message that user_1 will use to send to user_0 let message = Message::Direct(Direct { - recipient: vec![1], + recipient: at_index![1], message: b"test direct 1".to_vec(), }); @@ -78,30 +84,36 @@ async fn test_direct_user_to_broker() { // This run definition: 3 brokers, 6 users let run_definition = TestDefinition { connected_users: vec![ - vec![TestTopic::Global as u8], - vec![TestTopic::Global as u8, TestTopic::DA as u8], + TestUser::with_index(0, vec![TestTopic::Global.into()]), + TestUser::with_index(1, vec![TestTopic::Global.into(), TestTopic::DA.into()]), ], connected_brokers: vec![ - (vec![3], vec![TestTopic::DA as u8]), - (vec![2], vec![]), - (vec![4], vec![]), + TestBroker { + connected_users: vec![TestUser::with_index(2, vec![])], + }, + TestBroker { + connected_users: vec![TestUser::with_index(3, vec![TestTopic::DA.into()])], + }, + TestBroker { + connected_users: vec![TestUser::with_index(4, vec![])], + }, ], }; // Start the run - let run = run_definition.into_run().await; + let run = run_definition.into_run::().await; // Send a message as a user to another user that another broker owns (user_0 to user_2) let message = Message::Direct(Direct { - recipient: vec![2], + recipient: at_index![2], message: b"test direct 2".to_vec(), }); // Send the message as user_0 send_message_as!(run.connected_users[0], message); - // Assert broker_1 received it - assert_received!(yes, run.connected_brokers[1], message); + // Assert broker_0 received it + assert_received!(yes, run.connected_brokers[0], message); // Assert no one else got it, and we didn't get it again assert_received!(no, all, run.connected_users); @@ -117,23 +129,29 @@ async fn test_direct_broker_to_user() { // This run definition: 3 brokers, 6 users let run_definition = TestDefinition { connected_users: vec![ - vec![TestTopic::Global as u8], - vec![TestTopic::Global as u8, TestTopic::DA as u8], + TestUser::with_index(0, vec![TestTopic::Global.into()]), + TestUser::with_index(1, vec![TestTopic::Global.into(), TestTopic::DA.into()]), ], connected_brokers: vec![ - (vec![3], vec![TestTopic::DA as u8]), - (vec![2], vec![]), - (vec![4], vec![]), + TestBroker { + connected_users: vec![TestUser::with_index(2, vec![])], + }, + TestBroker { + connected_users: vec![TestUser::with_index(3, vec![TestTopic::DA.into()])], + }, + TestBroker { + connected_users: vec![TestUser::with_index(4, vec![])], + }, ], }; // Start the run - let run = run_definition.into_run().await; + let run = run_definition.into_run::().await; // Send a message as a broker through the test broker to a user that we own // Tests that broker_1 -> test_broker should not come back to us. let message = Message::Direct(Direct { - recipient: vec![2], + recipient: at_index![2], message: b"test direct 2".to_vec(), }); diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index aac6048..e5eb053 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -5,8 +5,16 @@ use std::sync::Arc; use cdn_proto::{ - connection::protocols::{memory::Memory, Connection}, - crypto::{rng::DeterministicRng, signature::KeyPair}, + connection::{ + middleware::Middleware, + protocols::{Connection, Listener, Protocol, UnfinalizedConnection}, + UserPublicKey, + }, + crypto::{ + rng::DeterministicRng, + signature::KeyPair, + tls::{generate_cert_from_ca, LOCAL_CA_CERT, LOCAL_CA_KEY}, + }, def::TestingRunDef, discovery::BrokerIdentifier, message::{Message, Topic}, @@ -15,29 +23,20 @@ use jf_signature::{bls_over_bn254::BLSOverBN254CurveSignatureScheme as BLS, Sign use rand::{rngs::StdRng, RngCore, SeedableRng}; use tokio::spawn; +use crate::{connections::DirectMap, Broker, Config}; + #[cfg(test)] mod broadcast; #[cfg(test)] mod direct; -use crate::{connections::DirectMap, Broker, Config}; - -/// An actor is a [user/broker] that we inject to test message send functionality. -pub struct InjectedActor { - /// The in-memory sender that sends to the broker under test - pub sender: Connection, - /// The in-memory receiver that receives from the broker under test - pub receiver: Connection, -} - /// This lets us send a message as a particular network actor. It just helps /// readability. #[macro_export] macro_rules! send_message_as { ($obj:expr, $message: expr) => { - $obj.sender - .send_message($message.clone()) + $obj.send_message($message.clone()) .await .expect("failed to send message"); }; @@ -70,7 +69,7 @@ macro_rules! assert_received { // Make sure we haven't received this message (no, $actor: expr) => { assert!( - timeout(Duration::from_millis(100), $actor.receiver.recv_message()) + timeout(Duration::from_millis(100), $actor.recv_message()) .await .is_err(), "wasn't supposed to receive a message but did" @@ -80,11 +79,7 @@ macro_rules! assert_received { // Make sure we have received the message in a timeframe of 50ms (yes, $actor: expr, $message:expr) => { // Receive the message with a timeout - let Ok(message) = timeout( - Duration::from_millis(50), - $actor.receiver.recv_message_raw(), - ) - .await + let Ok(message) = timeout(Duration::from_millis(50), $actor.recv_message_raw()).await else { panic!("timed out trying to receive message"); }; @@ -102,206 +97,294 @@ macro_rules! assert_received { }; } +/// Get the public key of a user at a particular index +#[macro_export] +macro_rules! at_index { + ($index: expr) => { + ($index as usize).to_le_bytes().to_vec() + }; +} + +/// A test user is a user that will be connected to the broker under test. +pub struct TestUser { + /// The public key of the user + pub public_key: UserPublicKey, + + /// The topics the user is subscribed to + pub subscribed_topics: Vec, +} + +impl TestUser { + /// Create a new test user with a particular index and subscribed topics + pub fn with_index(index: usize, subscribed_topics: Vec) -> Self { + let public_key = Arc::new(at_index!(index)); + Self { + public_key, + subscribed_topics, + } + } +} + +/// A test broker is a broker that will be connected to the broker under test. +pub struct TestBroker { + /// The users connected to this broker + pub connected_users: Vec, +} + +impl TestBroker { + /// Create a new test broker with a set of connected users + pub fn new(connected_users: Vec) -> Self { + Self { connected_users } + } +} + /// This is what we use to describe tests. These are the [brokers/users] connected /// _DIRECTLY_ to the broker under test, along with the topics they're subscribed to, /// and the user index they are responsible for. A connected user has the same "identity" /// as its index in the `connected_users` vector. pub struct TestDefinition { - pub connected_brokers: Vec<(Vec, Vec)>, - pub connected_users: Vec>, + pub connected_brokers: Vec, + pub connected_users: Vec, } /// A `TestRun` is converted from a `TestDefinition`. It contains actors with their -/// sending and receiving channels so we can pretend to be talking to the broker. +/// connections so we can pretend to be talking to the broker. pub struct TestRun { - /// The connected brokers and their handles - pub connected_brokers: Vec, + /// The connected brokers and their connections + pub connected_brokers: Vec, - /// The connected users and their handles - pub connected_users: Vec, + /// The connected users and their connections + pub connected_users: Vec, } -impl TestDefinition { - /// Creates a new broker under test. This configures and starts a local broker - /// who will be deterministically tested. - async fn new_broker_under_test() -> Broker { - // Create a key for our broker [under test] - let (private_key, public_key) = BLS::key_gen(&(), &mut DeterministicRng(0)).unwrap(); - - // Create a temporary SQLite file for the broker's discovery endpoint - let temp_dir = std::env::temp_dir(); - let discovery_endpoint = temp_dir - .join(format!("test-{}.sqlite", StdRng::from_entropy().next_u64())) - .to_string_lossy() - .into(); - - // Build the broker's config - let broker_config: Config = Config { - metrics_bind_endpoint: None, - public_advertise_endpoint: String::new(), - public_bind_endpoint: String::new(), - private_advertise_endpoint: String::new(), - private_bind_endpoint: String::new(), - discovery_endpoint, - keypair: KeyPair { - public_key, - private_key, - }, - global_memory_pool_size: None, - ca_cert_path: None, - ca_key_path: None, - }; +/// Generate `n` connection pairs for a given protocol +async fn gen_connection_pairs(num: usize) -> Vec<(Connection, Connection)> { + // Generate cert signed by local CA + let (cert, key) = + generate_cert_from_ca(LOCAL_CA_CERT, LOCAL_CA_KEY).expect("failed to generate cert"); - // Create the broker - Broker::new(broker_config) + // Get random port to bind to + let bind_endpoint = format!( + "127.0.0.1:{}", + portpicker::pick_unused_port().expect("failed to get unused port") + ); + + // Create the listener + let listener = P::bind(bind_endpoint.as_str(), cert, key) + .await + .expect("failed to bind"); + + // Create the list of connection pairs we will return + let mut connection_pairs = Vec::new(); + + for _ in 0..num { + // Spawn a task to connect the user to the broker + let bind_endpoint_ = bind_endpoint.clone(); + let unfinalized_outgoing_connection = + spawn(async move { P::connect(&bind_endpoint_, true, Middleware::none()).await }); + + // Accept the connection from the user + let incoming_connection = listener + .accept() + .await + .expect("failed to accept connection") + .finalize(Middleware::none()) .await - .expect("failed to create broker") + .expect("failed to finalize connection"); + + // Finalize the outgoing connection + let outgoing_connection = unfinalized_outgoing_connection + .await + .expect("failed to connect to broker") + .expect("failed to connect to broker"); + + // Add the connection pair to the list + connection_pairs.push((incoming_connection, outgoing_connection)); } - /// This is a helper function to inject users from our `TestDefinition` into the broker under test. - /// It creates sending and receiving channels, spawns a receive loop on the broker, - /// and adds the user to the internal state. - /// - /// Then, it sends subscription messages to the broker for the topics described in `TestDefinition` - fn inject_users( - broker_under_test: &Broker, - users: &[Vec], - ) -> Vec { - // Return this at the end, our running list of users - let mut injected_users: Vec = Vec::new(); - - // For each user, - for (i, topics) in users.iter().enumerate() { - // Extrapolate identifier - #[allow(clippy::cast_possible_truncation)] - let identifier: Arc> = Arc::from(vec![i as u8]); - - // Generate a testing pair of memory network channels - let connection1 = Memory::gen_testing_connection(); - let connection2 = Memory::gen_testing_connection(); - - // Create our user object - let injected_user = InjectedActor { - sender: connection1.clone(), - receiver: connection2.clone(), - }; - - // Spawn our user receiver in the broker under test - let inner = broker_under_test.inner.clone(); - let identifier_ = identifier.clone(); - let receive_handle = - spawn(async move { inner.user_receive_loop(&identifier_, connection1).await }) - .abort_handle(); - - // Inject our user into the connections - broker_under_test.inner.connections.write().add_user( - &identifier, - connection2, - topics, - receive_handle, - ); - - // Add to our running total - injected_users.push(injected_user); - } + connection_pairs +} +/// Create a new broker under test. All test users and brokers will be connected to this broker. +async fn new_broker_under_test() -> Broker> { + // Create a key for our broker [under test] + let (private_key, public_key) = BLS::key_gen(&(), &mut DeterministicRng(0)).unwrap(); + + // Create a temporary SQLite file for the broker's discovery endpoint + let temp_dir = std::env::temp_dir(); + let discovery_endpoint = temp_dir + .join(format!("test-{}.sqlite", StdRng::from_entropy().next_u64())) + .to_string_lossy() + .into(); + + // Build the broker's config + let broker_config = Config { + metrics_bind_endpoint: None, + public_advertise_endpoint: String::new(), + public_bind_endpoint: String::new(), + private_advertise_endpoint: String::new(), + private_bind_endpoint: String::new(), + discovery_endpoint, + keypair: KeyPair { + public_key, + private_key, + }, + global_memory_pool_size: None, + ca_cert_path: None, + ca_key_path: None, + }; + + // Create and return the broker + Broker::new(broker_config) + .await + .expect("failed to create broker") +} - injected_users +/// This is a helper function to inject users from our `TestDefinition` into the broker under test. +/// It creates the relevant connections, spawns a receive loop on the broker, and adds the user to +/// the internal state. +/// +/// After that, it sends subscription messages to the broker for the topics described in `TestDefinition` +async fn inject_users( + broker_under_test: &Broker>, + users: Vec, +) -> Vec { + // Generate a set of connected pairs, one for each user + // incoming (listener), outgoing (connect) + let mut connection_pairs = gen_connection_pairs::(users.len()).await; + + // Create the list of users we will return + let mut connected_users = Vec::new(); + + // For each user, + for user in users { + // Pop the next connection + let (incoming_connection, outgoing_connection) = connection_pairs + .pop() + .expect("not enough connections spawned"); + + // Spawn a task to handle the user inside of the broker + let inner = broker_under_test.inner.clone(); + let user_public_key = user.public_key.clone(); + let incoming_connection_ = incoming_connection.clone(); + let receive_handle = spawn(async move { + inner + .user_receive_loop(&user_public_key, incoming_connection_) + .await + }) + .abort_handle(); + + // Inject our user into the connections + broker_under_test.inner.connections.write().add_user( + &user.public_key.clone(), + incoming_connection, + &user.subscribed_topics, + receive_handle, + ); + + // Add our connection with our user so we can return it + connected_users.push(outgoing_connection); } - /// This is a helper function to inject brokers from our `TestDefinition` into the broker under test. - /// It creates sending and receiving channels, spawns a receive loop on the broker, - /// and adds the broker to the internal state. - /// - /// Then, it sends subscription messages to the broker for the topics described in `TestDefinition`, - /// and syncs the users up so the broker knows where to send messages. - async fn inject_brokers( - broker_under_test: &Broker, - brokers: Vec<(Vec, Vec)>, - ) -> Vec { - // Return this at the end, our running list of brokers - let mut injected_brokers: Vec = Vec::new(); - - // For each broker, - for (i, broker) in brokers.iter().enumerate() { - // Create our identifier - let identifier: BrokerIdentifier = format!("{i}/{i}") - .try_into() - .expect("failed to create broker identifier"); - - // Generate a testing pair of memory network channels - let connection1 = Memory::gen_testing_connection(); - let connection2 = Memory::gen_testing_connection(); - - // Create our broker object - let injected_broker = InjectedActor { - sender: connection1.clone(), - receiver: connection2.clone(), - }; - - // Spawn our receiver in the broker under test - let inner = broker_under_test.inner.clone(); - let identifier_ = identifier.clone(); - let receive_handle = spawn(async move { - inner - .broker_receive_loop(&identifier_, connection1) - .await - .unwrap(); - }) - .abort_handle(); - - // Inject our broker into the connections - broker_under_test.inner.connections.write().add_broker( - identifier.clone(), - connection2, - receive_handle, - ); - - // Send our subscriptions to it - let subscribe_message = Message::Subscribe(broker.1.clone()); - send_message_as!(injected_broker, subscribe_message); - - // Create a map of our users - let mut user_map = DirectMap::new(identifier.clone()); - - for user in broker.0.clone() { - user_map.insert(Arc::from(vec![user]), identifier.clone()); - } - - // Sync the map to the broker under test - let user_sync_message = Message::UserSync( - rkyv::to_bytes::<_, 256>(&user_map.diff()) - .expect("failed to serialize map") - .to_vec(), - ); - send_message_as!(injected_broker, user_sync_message); - - // Add to our running total - injected_brokers.push(injected_broker); + connected_users +} + +/// This is a helper function to inject brokers from our `TestDefinition` into the broker under test. +/// It creates the relevant connections, spawns a receive loop on the broker, and adds the broker to +/// the internal state. +/// +/// After that, it sends subscription messages to the broker for the topics described in `TestDefinition`, +/// and syncs the users up so the broker knows where to send messages. +async fn inject_brokers( + broker_under_test: &Broker>, + brokers: Vec, +) -> Vec { + // Generate a set of connected pairs, one for each broker + // incoming (listener), outgoing (connect) + let mut connection_pairs = gen_connection_pairs::(brokers.len()).await; + + // Create the list of brokers we will return + let mut connected_brokers = Vec::new(); + + // For each broker + for (i, broker) in brokers.into_iter().enumerate() { + // Create an identifier for the broker + let identifier: BrokerIdentifier = format!("{i}/{i}") + .try_into() + .expect("failed to create broker identifier"); + + // Pop the next connection + let (incoming_connection, outgoing_connection) = connection_pairs + .pop() + .expect("not enough connections spawned"); + + // Spawn a task to handle the broker inside of the broker under test + let inner = broker_under_test.inner.clone(); + let identifier_ = identifier.clone(); + let incoming_connection_ = incoming_connection.clone(); + let receive_handle = spawn(async move { + inner + .broker_receive_loop(&identifier_, incoming_connection_) + .await + }) + .abort_handle(); + + // Inject the broker into our connections + broker_under_test.inner.connections.write().add_broker( + identifier.clone(), + incoming_connection, + receive_handle, + ); + + // Aggregate the topics we should be subscribed to + let mut topics = Vec::new(); + for user in broker.connected_users.iter() { + topics.extend(user.subscribed_topics.clone()); } - injected_brokers + // Send our subscriptions to it + let subscribe_message = Message::Subscribe(topics); + send_message_as!(outgoing_connection, subscribe_message); + + // Create a map of our users + let mut user_map = DirectMap::new(identifier.clone()); + for user in broker.connected_users { + user_map.insert(Arc::from(user.public_key), identifier.clone()); + } + + // Sync the map to the broker under test + let user_sync_message = Message::UserSync( + rkyv::to_bytes::<_, 256>(&user_map.diff()) + .expect("failed to serialize map") + .to_vec(), + ); + send_message_as!(outgoing_connection, user_sync_message); + + // Add our connection with our broker so we can return it + connected_brokers.push(outgoing_connection); } - /// This is the conversion from a `TestDefinition` into a `Run`. Implicitly, the broker is started - /// and all sending and receiving operations on that broker start. - pub async fn into_run(self) -> TestRun { - // Create a new `Run`, which we will be returning + connected_brokers +} + +impl TestDefinition { + /// Start the test run, connecting all users and brokers to the broker under test. + pub async fn into_run(self) -> TestRun { + // Create the `Run` we will return let mut run = TestRun { - connected_users: vec![], - connected_brokers: vec![], + connected_users: Vec::new(), + connected_brokers: Vec::new(), }; - // Create our broker under test - let broker_under_test = Self::new_broker_under_test().await; + // Create a new broker under test with the provided protocols + let broker_under_test = new_broker_under_test::().await; - // Inject our brokers - run.connected_brokers = - Self::inject_brokers(&broker_under_test, self.connected_brokers).await; + // Inject the users into the broker under test + run.connected_users = inject_users(&broker_under_test, self.connected_users).await; - // Inject our users - run.connected_users = Self::inject_users(&broker_under_test, &self.connected_users); + // Inject the brokers into the broker under test + run.connected_brokers = inject_brokers(&broker_under_test, self.connected_brokers).await; - // Return our injected brokers and users + // Return the run run } } diff --git a/cdn-proto/src/def.rs b/cdn-proto/src/def.rs index 6d6f316..316742c 100644 --- a/cdn-proto/src/def.rs +++ b/cdn-proto/src/def.rs @@ -1,9 +1,9 @@ //! Compile-time run configuration for all CDN components. +use std::marker::PhantomData; use jf_signature::bls_over_bn254::BLSOverBN254CurveSignatureScheme as BLS; use num_enum::{IntoPrimitive, TryFromPrimitive}; -use crate::connection::protocols::memory::Memory; use crate::connection::protocols::{quic::Quic, tcp::Tcp, Protocol as ProtocolType}; use crate::crypto::signature::SignatureScheme; use crate::discovery::embedded::Embedded; @@ -92,21 +92,25 @@ impl ConnectionDef for ProductionClientConnection { } /// The testing run configuration. -/// Uses in-memory protocols and an embedded discovery client. -pub struct TestingRunDef; -impl RunDef for TestingRunDef { - type Broker = TestingConnection; - type User = TestingConnection; +/// Uses generic protocols and an embedded discovery client. +pub struct TestingRunDef { + pd: PhantomData<(B, U)>, +} +impl RunDef for TestingRunDef { + type Broker = TestingConnection; + type User = TestingConnection; type DiscoveryClientType = Embedded; type Topic = TestTopic; } /// The testing connection configuration. -/// Uses BLS signatures, in-memory protocols, and no middleware. -pub struct TestingConnection; -impl ConnectionDef for TestingConnection { +/// Uses BLS signatures, generic protocols, and no middleware. +pub struct TestingConnection { + pd: PhantomData

, +} +impl ConnectionDef for TestingConnection

{ type Scheme = BLS; - type Protocol = Memory; + type Protocol = P; } // Type aliases to automatically disambiguate usage diff --git a/tests/src/tests/mod.rs b/tests/src/tests/mod.rs index 8ed0297..77401a2 100644 --- a/tests/src/tests/mod.rs +++ b/tests/src/tests/mod.rs @@ -2,6 +2,7 @@ use cdn_broker::{Broker, Config as BrokerConfig}; use cdn_client::{Client, Config as ClientConfig}; use cdn_marshal::{Config as MarshalConfig, Marshal}; use cdn_proto::{ + connection::protocols::memory::Memory, crypto::signature::{KeyPair, Serializable, SignatureScheme}, def::{TestingConnection, TestingRunDef}, discovery::{embedded::Embedded, BrokerIdentifier, DiscoveryClient}, @@ -57,7 +58,7 @@ async fn new_broker(key: u64, public_ep: &str, private_ep: &str, discovery_ep: & let (private_key, public_key) = keypair_from_seed(key); // Create config - let config: BrokerConfig = BrokerConfig { + let config: BrokerConfig> = BrokerConfig { ca_cert_path: None, ca_key_path: None, discovery_endpoint: discovery_ep.to_string(), @@ -74,7 +75,7 @@ async fn new_broker(key: u64, public_ep: &str, private_ep: &str, discovery_ep: & }; // Create broker - let broker = Broker::::new(config) + let broker = Broker::>::new(config) .await .expect("failed to create broker"); @@ -96,7 +97,7 @@ async fn new_marshal(ep: &str, discovery_ep: &str) { }; // Create a new marshal - let marshal = Marshal::::new(config) + let marshal = Marshal::>::new(config) .await .expect("failed to create marshal"); @@ -106,7 +107,7 @@ async fn new_marshal(ep: &str, discovery_ep: &str) { /// Create a new client, supplying it with the given topics and marshal /// endpoint. `Key` is a deterministic, seeded keypair. -fn new_client(key: u64, topics: Vec, marshal_ep: &str) -> Client { +fn new_client(key: u64, topics: Vec, marshal_ep: &str) -> Client> { // Generate keypair let (private_key, public_key) = keypair_from_seed(key); @@ -122,7 +123,7 @@ fn new_client(key: u64, topics: Vec, marshal_ep: &str) -> Client::new(config) + Client::>::new(config) } /// Create a new database client with the given endpoint and identity. From ff68ff29d6f8fa88e07678fbe95b851912fbe1ab Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 09:37:43 -0400 Subject: [PATCH 20/31] add latency calculation --- cdn-proto/src/connection/metrics.rs | 6 +++++- cdn-proto/src/connection/middleware/pool.rs | 16 +++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/cdn-proto/src/connection/metrics.rs b/cdn-proto/src/connection/metrics.rs index 7ae4ddb..8165b82 100644 --- a/cdn-proto/src/connection/metrics.rs +++ b/cdn-proto/src/connection/metrics.rs @@ -1,7 +1,7 @@ //! Feature-gated connection specific metrics use lazy_static::lazy_static; -use prometheus::{register_gauge, Gauge}; +use prometheus::{register_gauge, register_histogram, Gauge, Histogram}; lazy_static! { // The total number of bytes sent @@ -11,4 +11,8 @@ lazy_static! { // The total number of bytes received pub static ref BYTES_RECV: Gauge = register_gauge!("total_bytes_recv", "the total number of bytes received").unwrap(); + + // Per-message latency + pub static ref LATENCY: Histogram = + register_histogram!("message_latency", "message delivery latency").unwrap(); } diff --git a/cdn-proto/src/connection/middleware/pool.rs b/cdn-proto/src/connection/middleware/pool.rs index bd63dc3..612348f 100644 --- a/cdn-proto/src/connection/middleware/pool.rs +++ b/cdn-proto/src/connection/middleware/pool.rs @@ -7,12 +7,14 @@ //! receive a message, we await on allocating it. When we are done sending it out to everyone, //! we drop the `Parc`, allowing for re-allocation. -use std::{ops::Deref, sync::Arc}; +use std::{ops::Deref, sync::Arc, time::Instant}; use anyhow::Result; use derivative::Derivative; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use crate::connection::metrics; + /// A global memory arena that tracks but does not allocate memory. /// Allows for asynchronous capping of memory usage. #[derive(Clone)] @@ -28,7 +30,15 @@ impl MemoryPool { /// An acquired permit that allows for allocation of a memory region /// of a particular size. #[allow(dead_code)] -pub struct AllocationPermit(OwnedSemaphorePermit); +pub struct AllocationPermit(OwnedSemaphorePermit, Instant); + +/// When dropped, log the time of allocation to deallocation +/// as latency. +impl Drop for AllocationPermit { + fn drop(&mut self) { + metrics::LATENCY.observe(self.1.elapsed().as_secs_f64()); + } +} impl MemoryPool { /// Asynchronously allocate `n` bytes from the global pool, waiting @@ -39,7 +49,7 @@ impl MemoryPool { pub async fn alloc(&self, n: u32) -> Result { // Acquire many permits to the underlying semaphore let permit = self.0.clone().acquire_many_owned(n).await?; - Ok(AllocationPermit(permit)) + Ok(AllocationPermit(permit, Instant::now())) } } From ec31d162f2b44fbe594be2de8954b62bc8395f98 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 10:25:16 -0400 Subject: [PATCH 21/31] update process compose --- process-compose.yaml | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/process-compose.yaml b/process-compose.yaml index 6b49e70..7c92802 100644 --- a/process-compose.yaml +++ b/process-compose.yaml @@ -11,7 +11,7 @@ processes: command: cargo run --bin marshal -- -d "redis://:changeme!@localhost:6379" broker_0: - command: cargo run --bin broker -- -d "redis://:changeme!@localhost:6379" + command: cargo run --bin broker -- -d "redis://:changeme!@localhost:6379" --metrics-bind-endpoint localhost:9090 broker_1: command: cargo run --bin broker --release -- @@ -21,9 +21,6 @@ processes: --private-advertise-endpoint local_ip:1741 -d "redis://:changeme!@localhost:6379" - client_0: - command: cargo run --bin client --release -- -m "127.0.0.1:1737" - # Uncomment the following lines to run misbehaving processes and the Tokio console # broker_tokio_console: @@ -45,8 +42,8 @@ processes: # marshal_0: # condition: process_started - # bad_sender: - # command: cargo run --bin bad-sender -- -m "127.0.0.1:1737" - # depends_on: - # marshal_0: - # condition: process_started \ No newline at end of file + bad_sender: + command: cargo run --bin bad-sender -- -m "127.0.0.1:1737" + depends_on: + marshal_0: + condition: process_started \ No newline at end of file From c5069af38dd7bb2b1124e1c61b6b4239950af832 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 10:28:30 -0400 Subject: [PATCH 22/31] add metrics to second broker --- process-compose.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/process-compose.yaml b/process-compose.yaml index 7c92802..9303999 100644 --- a/process-compose.yaml +++ b/process-compose.yaml @@ -16,9 +16,10 @@ processes: broker_1: command: cargo run --bin broker --release -- --public-bind-endpoint 0.0.0.0:1740 - --public-advertise-endpoint local_ip:1740 + --public-advertise-endpoint local_ip:1740 # local_ip is a special value that will be replaced with the host's local IP address --private-bind-endpoint 0.0.0.0:1741 --private-advertise-endpoint local_ip:1741 + --metrics-bind-endpoint localhost:9091 -d "redis://:changeme!@localhost:6379" # Uncomment the following lines to run misbehaving processes and the Tokio console From 46321a54b2a8b3fe3d34c39997bee53eb7328b54 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 10:55:29 -0400 Subject: [PATCH 23/31] add per-message latency calculation --- cdn-client/src/binaries/bad-sender.rs | 6 +++++- cdn-proto/src/connection/metrics.rs | 6 +++++- cdn-proto/src/metrics.rs | 29 ++++++++++++++++++++++++++- process-compose.yaml | 17 ++++++++-------- 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/cdn-client/src/binaries/bad-sender.rs b/cdn-client/src/binaries/bad-sender.rs index e764f5d..bd9bd55 100644 --- a/cdn-client/src/binaries/bad-sender.rs +++ b/cdn-client/src/binaries/bad-sender.rs @@ -19,6 +19,10 @@ struct Args { /// The remote marshal endpoint to connect to, including the port. #[arg(short, long)] marshal_endpoint: String, + + /// The size of the messages to be sent to the broker + #[arg(long, default_value = "9000000")] + message_size: u32, } #[tokio::main] @@ -56,7 +60,7 @@ async fn main() { // Create a client, specifying the BLS signature algorithm // and the `QUIC` protocol. let client = Client::::new(config); - let message = vec![0u8; 10000]; + let message = vec![0u8; args.message_size as usize]; // In a loop, loop { diff --git a/cdn-proto/src/connection/metrics.rs b/cdn-proto/src/connection/metrics.rs index 8165b82..af55332 100644 --- a/cdn-proto/src/connection/metrics.rs +++ b/cdn-proto/src/connection/metrics.rs @@ -14,5 +14,9 @@ lazy_static! { // Per-message latency pub static ref LATENCY: Histogram = - register_histogram!("message_latency", "message delivery latency").unwrap(); + register_histogram!("latency", "message delivery latency").unwrap(); + + // The per-message latency over the last 30 seconds + pub static ref RUNNING_LATENCY: Gauge = + register_gauge!("running_latency", "average message delivery latency over the last 30s").unwrap(); } diff --git a/cdn-proto/src/metrics.rs b/cdn-proto/src/metrics.rs index 5390273..667ff47 100644 --- a/cdn-proto/src/metrics.rs +++ b/cdn-proto/src/metrics.rs @@ -1,12 +1,18 @@ //! A simple metrics server that allows us to serve process metrics as-needed. -use std::net::SocketAddr; +use std::{net::SocketAddr, time::Duration}; +use tokio::time::sleep; use tracing::error; use warp::Filter; +use crate::connection::metrics; + /// Start the metrics server that should run forever on a particular port pub async fn serve_metrics(bind_endpoint: SocketAddr) { + // Spawn an additional task to calculate the running latency + tokio::spawn(running_latency_calculator()); + // The `/metrics` route is standard for Prometheus deployments let route = warp::path("metrics").map(|| { // Gather all metrics, encode them, and return them. @@ -25,3 +31,24 @@ pub async fn serve_metrics(bind_endpoint: SocketAddr) { // Serve the route on the specified port warp::serve(route).run(bind_endpoint).await; } + +/// A simple latency calculator that calculates the running latency every 30s +/// and sets the corresponding `RUNNING_LATENCY` gauge. +pub async fn running_latency_calculator() { + // Initialize the values to 0 + let mut latency_sum = 0.0; + let mut latency_count = 0; + + // Start calculating the latency + loop { + // Sleep for 30s + sleep(Duration::from_secs(30)).await; + + // Calculate the running latency by subtracting the previous sum and count + latency_sum = metrics::LATENCY.get_sample_sum() - latency_sum; + latency_count = metrics::LATENCY.get_sample_count() - latency_count; + + // Set the running latency + metrics::RUNNING_LATENCY.set(latency_sum / latency_count as f64); + } +} diff --git a/process-compose.yaml b/process-compose.yaml index 9303999..09953c2 100644 --- a/process-compose.yaml +++ b/process-compose.yaml @@ -13,15 +13,22 @@ processes: broker_0: command: cargo run --bin broker -- -d "redis://:changeme!@localhost:6379" --metrics-bind-endpoint localhost:9090 + # Note: `local_ip` is a special value that will be replaced with the host's local IP address broker_1: command: cargo run --bin broker --release -- --public-bind-endpoint 0.0.0.0:1740 - --public-advertise-endpoint local_ip:1740 # local_ip is a special value that will be replaced with the host's local IP address + --public-advertise-endpoint local_ip:1740 --private-bind-endpoint 0.0.0.0:1741 --private-advertise-endpoint local_ip:1741 --metrics-bind-endpoint localhost:9091 -d "redis://:changeme!@localhost:6379" + heavy_load: + command: cargo run --bin bad-sender -- -m "127.0.0.1:1737" + depends_on: + marshal_0: + condition: process_started + # Uncomment the following lines to run misbehaving processes and the Tokio console # broker_tokio_console: @@ -41,10 +48,4 @@ processes: # command: cargo run --bin bad-connector -- -m "127.0.0.1:1737" # depends_on: # marshal_0: - # condition: process_started - - bad_sender: - command: cargo run --bin bad-sender -- -m "127.0.0.1:1737" - depends_on: - marshal_0: - condition: process_started \ No newline at end of file + # condition: process_started \ No newline at end of file From 4ee374ade30df8e6714147840b5558c7134fd82e Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 10:58:32 -0400 Subject: [PATCH 24/31] clippy --- cdn-broker/src/connections/mod.rs | 2 +- cdn-broker/src/tests/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cdn-broker/src/connections/mod.rs b/cdn-broker/src/connections/mod.rs index 006ae70..41f99dc 100644 --- a/cdn-broker/src/connections/mod.rs +++ b/cdn-broker/src/connections/mod.rs @@ -155,7 +155,7 @@ impl Connections { let differences = (added.copied().collect(), removed.copied().collect()); // Set the previous to the new one - *previous = now.clone(); + previous.clone_from(&now); // Return the differences differences diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index e5eb053..a782d7d 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -348,7 +348,7 @@ async fn inject_brokers( // Create a map of our users let mut user_map = DirectMap::new(identifier.clone()); for user in broker.connected_users { - user_map.insert(Arc::from(user.public_key), identifier.clone()); + user_map.insert(user.public_key, identifier.clone()); } // Sync the map to the broker under test From d8206f86a924acf1cff09290419db83c1a270c70 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 13:35:25 -0400 Subject: [PATCH 25/31] implement tcp+tls --- Cargo.lock | 12 + cdn-proto/Cargo.toml | 5 +- cdn-proto/src/connection/protocols/mod.rs | 8 + cdn-proto/src/connection/protocols/quic.rs | 23 +- cdn-proto/src/connection/protocols/tcp_tls.rs | 260 ++++++++++++++++++ cdn-proto/src/crypto/tls.rs | 34 ++- 6 files changed, 319 insertions(+), 23 deletions(-) create mode 100644 cdn-proto/src/connection/protocols/tcp_tls.rs diff --git a/Cargo.lock b/Cargo.lock index 3859617..339aff8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -854,6 +854,7 @@ dependencies = [ "sqlx", "thiserror", "tokio", + "tokio-rustls", "tracing", "url", "warp", @@ -3761,6 +3762,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.15" diff --git a/cdn-proto/Cargo.toml b/cdn-proto/Cargo.toml index ede673f..bc95460 100644 --- a/cdn-proto/Cargo.toml +++ b/cdn-proto/Cargo.toml @@ -49,7 +49,7 @@ thiserror = "1" quinn = { version = "0.11", default-features = false, features = [ "rustls", "runtime-tokio", - "ring" + "ring", ] } jf-signature.workspace = true ark-serialize = "0.4" @@ -58,6 +58,7 @@ url = "2" tracing.workspace = true pem = "3" rustls = { version = "0.23", default-features = false } +tokio-rustls = { version = "0.26", default-features = false } async-trait = "0.1" warp = { version = "0.3", default-features = false } anyhow = "1" @@ -66,4 +67,4 @@ rkyv.workspace = true mnemonic = "1" rcgen.workspace = true derivative.workspace = true -num_enum = "0.7" \ No newline at end of file +num_enum = "0.7" diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index 12f1c1c..c45fadb 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -26,6 +26,7 @@ use crate::connection::metrics; pub mod memory; pub mod quic; pub mod tcp; +pub mod tcp_tls; /// The `Protocol` trait lets us be generic over a connection type (Tcp, Quic, etc). #[async_trait] @@ -146,6 +147,13 @@ impl Connection { receive_from_caller.close(); return; }; + + // Flush the writer + // Is a no-op for everything but TCP+TLS + if writer.flush().await.is_err() { + receive_from_caller.close(); + return; + }; } BytesOrSoftClose::SoftClose(result_sender) => { // Soft close the writer, allowing it to finish sending diff --git a/cdn-proto/src/connection/protocols/quic.rs b/cdn-proto/src/connection/protocols/quic.rs index 97bc577..e551425 100644 --- a/cdn-proto/src/connection/protocols/quic.rs +++ b/cdn-proto/src/connection/protocols/quic.rs @@ -11,13 +11,12 @@ use std::{ use async_trait::async_trait; use quinn::{ClientConfig, Endpoint, Incoming, SendStream, ServerConfig, TransportConfig, VarInt}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use rustls::RootCertStore; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::time::timeout; use super::{Connection, Listener, Protocol, SoftClose, UnfinalizedConnection}; use crate::connection::middleware::Middleware; -use crate::crypto::tls::{LOCAL_CA_CERT, PROD_CA_CERT}; +use crate::crypto::tls::generate_root_certificate_store; use crate::parse_endpoint; use crate::{ bail, bail_option, @@ -62,24 +61,8 @@ impl Protocol for Quic { "failed to bind to local endpoint" ); - // Pick which authority to trust based on whether or not we have requested - // to use the local one - let root_ca = if use_local_authority { - LOCAL_CA_CERT - } else { - PROD_CA_CERT - }; - - // Parse the provided CA in `.PEM` format - let root_ca = bail!(pem::parse(root_ca), Parse, "failed to parse PEM file").into_contents(); - - // Create root certificate store and add our CA - let mut root_cert_store = RootCertStore::empty(); - bail!( - root_cert_store.add(CertificateDer::from(root_ca)), - File, - "failed to add certificate to root store" - ); + // Generate root certificate store based on the local authority + let root_cert_store = generate_root_certificate_store(use_local_authority)?; // Create config from the root store let mut config: ClientConfig = bail!( diff --git a/cdn-proto/src/connection/protocols/tcp_tls.rs b/cdn-proto/src/connection/protocols/tcp_tls.rs new file mode 100644 index 0000000..980071e --- /dev/null +++ b/cdn-proto/src/connection/protocols/tcp_tls.rs @@ -0,0 +1,260 @@ +//! This file defines and implements a thin wrapper around a TCP +//! + TLS connection that implements our message framing and connection +//! logic. + +use std::net::SocketAddr; +use std::net::ToSocketAddrs; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use rustls::pki_types::CertificateDer; +use rustls::pki_types::PrivateKeyDer; +use rustls::pki_types::ServerName; +use rustls::ClientConfig; +use rustls::ServerConfig; +use tokio::io::WriteHalf; +use tokio::net::TcpListener; +use tokio::net::{TcpSocket, TcpStream}; +use tokio::time::timeout; +use tokio_rustls::TlsAcceptor; +use tokio_rustls::TlsConnector; + +use super::SoftClose; +use super::{Connection, Listener, Protocol, UnfinalizedConnection}; +use crate::connection::middleware::Middleware; +use crate::crypto::tls::generate_root_certificate_store; +use crate::{ + bail, bail_option, + error::{Error, Result}, + parse_endpoint, +}; + +/// The `Tcp` protocol. We use this to define commonalities between TCP +/// listeners, connections, etc. +#[derive(Clone, PartialEq, Eq)] +pub struct TcpTls; + +#[async_trait] +impl Protocol for TcpTls { + type Listener = TcpTlsListener; + type UnfinalizedConnection = UnfinalizedTcpTlsConnection; + + /// Connect to a remote endpoint, returning an instance of `Self`. + /// With TCP, this requires just connecting to the remote endpoint. + /// + /// # Errors + /// Errors if we fail to connect or if we fail to bind to the interface we want. + async fn connect( + remote_endpoint: &str, + use_local_authority: bool, + middleware: Middleware, + ) -> Result + where + Self: Sized, + { + // Parse the socket endpoint + let remote_endpoint = bail_option!( + bail!( + remote_endpoint.to_socket_addrs(), + Parse, + "failed to parse remote endpoint" + ) + .next(), + Connection, + "did not find suitable address for endpoint" + ); + + // Create a new TCP socket + let socket = bail!( + TcpSocket::new_v4(), + Connection, + "failed to bind to local socket" + ); + + // Generate root certificate store based on the local authority + let root_cert_store = generate_root_certificate_store(use_local_authority)?; + + // Create `rustls` config from the root store + let config: ClientConfig = ClientConfig::builder() + .with_root_certificates(root_cert_store) + // this just means no mTLS + .with_no_client_auth(); + + // Create a new TLS connector from the config + let tls_connector = TlsConnector::from(Arc::new(config)); + let espresso_san = bail!( + ServerName::try_from("espresso"), + Connection, + "failed to parse server name \"espresso\"" + ); + + // Connect the stream to the local socket + let stream = bail!( + bail!( + timeout(Duration::from_secs(5), socket.connect(remote_endpoint)).await, + Connection, + "timed out connecting to tcp endpoint" + ), + Connection, + "failed to connect to tcp endpoint" + ); + + // Wrap the stream in the TLS connection + let stream = bail!( + bail!( + timeout( + Duration::from_secs(5), + tls_connector.connect(espresso_san, stream) + ) + .await, + Connection, + "timed out attempting tls handshake" + ), + Connection, + "failed to perform tls handshake" + ); + + // Split the connection and create our wrapper + let (receiver, sender) = tokio::io::split(stream); + + // Convert the streams into a `Connection` + let connection = Connection::from_streams(sender, receiver, middleware); + + Ok(connection) + } + + /// Binds to a local endpoint. Does not use a TLS configuration. + /// + /// # Errors + /// - If we cannot bind to the local interface + /// - If we cannot parse the bind endpoint + async fn bind( + bind_endpoint: &str, + certificate: CertificateDer<'static>, + key: PrivateKeyDer<'static>, + ) -> Result { + // Create server configuration from the loaded certificate + let server_config = bail!( + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![certificate], key), + Connection, + "failed to create tls server configuration" + ); + + // Create a new TLS acceptor from the server configuration + let tls_acceptor = TlsAcceptor::from(Arc::new(server_config)); + + // Parse the bind endpoint + let bind_endpoint: SocketAddr = parse_endpoint!(bind_endpoint); + + // Try to bind to the local endpoint + let tcp_listener = bail!( + TcpListener::bind(bind_endpoint).await, + Connection, + "failed to bind to local endpoint" + ); + + // Return the listener and TLS acceptor + Ok(TcpTlsListener { + tcp_listener, + tls_acceptor, + }) + } +} + +/// A connection that has yet to be finalized. Allows us to keep accepting +/// connections while we process this one. +pub struct UnfinalizedTcpTlsConnection { + tcp_stream: TcpStream, + tls_acceptor: TlsAcceptor, +} + +#[async_trait] +impl UnfinalizedConnection for UnfinalizedTcpTlsConnection { + /// Finalize the connection by splitting it into a sender and receiver side. + /// Conssumes `Self`. + /// + /// # Errors + /// Does not actually error, but satisfies trait bounds. + async fn finalize(self, middleware: Middleware) -> Result { + // Wrap the stream in the TLS connection + let stream = bail!( + bail!( + timeout( + Duration::from_secs(5), + self.tls_acceptor.accept(self.tcp_stream) + ) + .await, + Connection, + "timed out attempting tls handshake" + ), + Connection, + "failed to perform tls handshake" + ); + + // Split the connection and create our wrapper + let (receiver, sender) = tokio::io::split(stream); + + // Convert the streams into a `Connection` + let connection = Connection::from_streams(sender, receiver, middleware); + + Ok(connection) + } +} + +/// The listener struct. Needed to receive messages over TCP. Is a light +/// wrapper around `tokio::net::TcpListener` and `tokio_rustls::TlsAcceptor`. +pub struct TcpTlsListener { + tcp_listener: TcpListener, + tls_acceptor: TlsAcceptor, +} + +#[async_trait] +impl Listener for TcpTlsListener { + /// Accept an unfinalized connection from the listener. + /// + /// # Errors + /// - If we fail to accept a connection from the listener. + async fn accept(&self) -> Result { + // Try to accept a connection from the underlying endpoint + // Split into reader and writer half + let connection = bail!( + self.tcp_listener.accept().await, + Connection, + "failed to accept connection" + ); + + // Return the unfinalized connection + Ok(UnfinalizedTcpTlsConnection { + tcp_stream: connection.0, + tls_acceptor: self.tls_acceptor.clone(), + }) + } +} + +/// Soft closing is a no-op for TCP connections. +#[async_trait] +impl SoftClose for WriteHalf {} + +#[cfg(test)] +mod tests { + use anyhow::{anyhow, Result}; + + use super::super::tests::test_connection as super_test_connection; + use super::TcpTls; + + #[tokio::test] + /// Test connection establishment, listening for connections, and message + /// sending and receiving. Just proxies to the super traits' function + pub async fn test_connection() -> Result<()> { + // Get random, available port + let Some(port) = portpicker::pick_unused_port() else { + return Err(anyhow!("no unused ports")); + }; + + // Test using the super's function + super_test_connection::(format!("127.0.0.1:{port}")).await + } +} diff --git a/cdn-proto/src/crypto/tls.rs b/cdn-proto/src/crypto/tls.rs index f26276f..f9b4999 100644 --- a/cdn-proto/src/crypto/tls.rs +++ b/cdn-proto/src/crypto/tls.rs @@ -2,7 +2,10 @@ //! way to skip server verification. use rcgen::{CertificateParams, Ia5String, IsCa, KeyPair, SanType}; -use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; +use rustls::{ + pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, + RootCertStore, +}; use crate::{ bail, @@ -115,3 +118,32 @@ pub fn load_ca( Ok((LOCAL_CA_CERT.to_string(), LOCAL_CA_KEY.to_string())) } } + +/// Generate a root certificate store based on whether or not we want to use the +/// local authority. +/// +/// # Errors +/// - If we fail to parse the provided CA certificate +/// - If we fail to add the certificate to the root store +pub fn generate_root_certificate_store(use_local_authority: bool) -> Result { + // Pick which authority to trust based on whether or not we have requested + // to use the local one + let root_ca = if use_local_authority { + LOCAL_CA_CERT + } else { + PROD_CA_CERT + }; + + // Parse the provided CA in `.PEM` format + let root_ca = bail!(pem::parse(root_ca), Parse, "failed to parse PEM file").into_contents(); + + // Create root certificate store and add our CA + let mut root_cert_store = RootCertStore::empty(); + bail!( + root_cert_store.add(CertificateDer::from(root_ca)), + File, + "failed to add certificate to root store" + ); + + Ok(root_cert_store) +} From 18f79958ebb3874449d9f0420caa581b45887563 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 13:59:21 -0400 Subject: [PATCH 26/31] clippy lints --- cdn-broker/src/tests/mod.rs | 2 +- cdn-proto/src/connection/middleware/mod.rs | 7 +++++-- cdn-proto/src/connection/protocols/mod.rs | 13 ++++++------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/cdn-broker/src/tests/mod.rs b/cdn-broker/src/tests/mod.rs index a782d7d..a2dc25d 100644 --- a/cdn-broker/src/tests/mod.rs +++ b/cdn-broker/src/tests/mod.rs @@ -337,7 +337,7 @@ async fn inject_brokers( // Aggregate the topics we should be subscribed to let mut topics = Vec::new(); - for user in broker.connected_users.iter() { + for user in &broker.connected_users { topics.extend(user.subscribed_topics.clone()); } diff --git a/cdn-proto/src/connection/middleware/mod.rs b/cdn-proto/src/connection/middleware/mod.rs index f541751..0e9632f 100644 --- a/cdn-proto/src/connection/middleware/mod.rs +++ b/cdn-proto/src/connection/middleware/mod.rs @@ -34,7 +34,7 @@ impl Middleware { /// Create a new middleware with no global memory pool and no connection message pool size. /// This means an unbounded channel will be used for connections and no global memory pool /// will be checked. - pub fn none() -> Self { + pub const fn none() -> Self { // Create a new middleware with no global memory pool and no connection message pool size. Self { global_memory_pool: None, @@ -44,6 +44,9 @@ impl Middleware { /// Allocate a permit for a message of `num_bytes` bytes. /// If the global memory pool is not set, this will return `None`. + /// + /// # Panics + /// - If the required semaphore has been dropped. This should never happen pub async fn allocate_message_bytes(&self, num_bytes: u32) -> Option { if let Some(pool) = &self.global_memory_pool { // If the global memory pool is set, allocate a permit @@ -59,7 +62,7 @@ impl Middleware { } /// Get the connection message pool size, if set. - pub fn connection_message_pool_size(&self) -> Option { + pub const fn connection_message_pool_size(&self) -> Option { // Return the connection message pool size self.connection_message_pool_size } diff --git a/cdn-proto/src/connection/protocols/mod.rs b/cdn-proto/src/connection/protocols/mod.rs index c45fadb..4d46aca 100644 --- a/cdn-proto/src/connection/protocols/mod.rs +++ b/cdn-proto/src/connection/protocols/mod.rs @@ -130,11 +130,10 @@ impl Connection { // Create the channels that will be used to send and receive messages. // Conditionally create bounded channels if the user specifies a size let ((send_to_caller, receive_from_task), (send_to_task, receive_from_caller)) = - if let Some(size) = middleware.connection_message_pool_size() { - (kanal::bounded_async(size), kanal::bounded_async(size)) - } else { - (kanal::unbounded_async(), kanal::unbounded_async()) - }; + middleware.connection_message_pool_size().map_or_else( + || (kanal::unbounded_async(), kanal::unbounded_async()), + |size| (kanal::bounded_async(size), kanal::bounded_async(size)), + ); // Spawn the task that receives from the caller and sends to the stream let sender_task = tokio::spawn(async move { @@ -314,7 +313,7 @@ async fn read_length_delimited( // Add to our metrics, if desired #[cfg(feature = "metrics")] - metrics::BYTES_RECV.add(message_size as f64); + metrics::BYTES_RECV.add(f64::from(message_size)); Ok(Bytes::from(buffer, permit)) } @@ -355,7 +354,7 @@ async fn write_length_delimited( // Increment the number of bytes we've sent by this amount #[cfg(feature = "metrics")] - metrics::BYTES_SENT.add(message_len as f64); + metrics::BYTES_SENT.add(f64::from(message_len)); Ok(()) } From 21f2ee31fb3756791525037172ea4ada4a3c7ec5 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 14:03:53 -0400 Subject: [PATCH 27/31] allow lint for test --- cdn-broker/src/connections/broadcast/relational_map.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cdn-broker/src/connections/broadcast/relational_map.rs b/cdn-broker/src/connections/broadcast/relational_map.rs index 980dfa1..05e7d60 100644 --- a/cdn-broker/src/connections/broadcast/relational_map.rs +++ b/cdn-broker/src/connections/broadcast/relational_map.rs @@ -3,6 +3,7 @@ use std::{ hash::Hash, }; + /// A relational, bidirectional multimap that relates keys to a set of values, /// and values to a set of keys. pub struct RelationalMap { @@ -111,6 +112,8 @@ impl Relatio } #[cfg(test)] +// Makes tests more readable +#[allow(clippy::unnecessary_get_then_check)] pub mod tests { use super::RelationalMap; From 61900cb427387a45de779080ec1b5a17e575c73f Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 14:05:43 -0400 Subject: [PATCH 28/31] fmt --- cdn-broker/src/connections/broadcast/relational_map.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/cdn-broker/src/connections/broadcast/relational_map.rs b/cdn-broker/src/connections/broadcast/relational_map.rs index 05e7d60..8dea57a 100644 --- a/cdn-broker/src/connections/broadcast/relational_map.rs +++ b/cdn-broker/src/connections/broadcast/relational_map.rs @@ -3,7 +3,6 @@ use std::{ hash::Hash, }; - /// A relational, bidirectional multimap that relates keys to a set of values, /// and values to a set of keys. pub struct RelationalMap { From 7b48e28aa1a766135ae435a438872a7499dff587 Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 14:08:47 -0400 Subject: [PATCH 29/31] change prometheus label --- cdn-proto/src/connection/metrics.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cdn-proto/src/connection/metrics.rs b/cdn-proto/src/connection/metrics.rs index af55332..cb4862e 100644 --- a/cdn-proto/src/connection/metrics.rs +++ b/cdn-proto/src/connection/metrics.rs @@ -18,5 +18,5 @@ lazy_static! { // The per-message latency over the last 30 seconds pub static ref RUNNING_LATENCY: Gauge = - register_gauge!("running_latency", "average message delivery latency over the last 30s").unwrap(); + register_gauge!("running_latency", "average tail latency over the last 30s").unwrap(); } From 8e408f6265d1bd68d70e82239dcea784ec998edc Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 14:19:09 -0400 Subject: [PATCH 30/31] update dependencies --- Cargo.lock | 298 +++++++++++++++++++++++++++++++++++++----- cdn-broker/Cargo.toml | 2 +- 2 files changed, 264 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 339aff8..5968213 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -895,9 +895,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.6" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9689a29b593160de5bc4aacab7b5d54fb52231de70122626c178e6a368994c7" +checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" dependencies = [ "clap_builder", "clap_derive", @@ -905,9 +905,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.6" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e5387378c84f6faa26890ebf9f0a92989f8873d4d380467bcd0d8d8620424df" +checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" dependencies = [ "anstream", "anstyle", @@ -964,9 +964,9 @@ dependencies = [ [[package]] name = "console-api" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd326812b3fd01da5bb1af7d340d0d555fd3d4b641e7f1dfcf5962a902952787" +checksum = "a257c22cd7e487dd4a13d413beabc512c5052f0bc048db0da6a84c3d8a6142fd" dependencies = [ "futures-core", "prost", @@ -977,9 +977,9 @@ dependencies = [ [[package]] name = "console-subscriber" -version = "0.2.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7481d4c57092cd1c19dd541b92bdce883de840df30aa5d03fd48a3935c01842e" +checksum = "31c4cc54bae66f7d9188996404abdf7fdfa23034ef8e43478c8810828abad758" dependencies = [ "console-api", "crossbeam-channel", @@ -987,6 +987,7 @@ dependencies = [ "futures-task", "hdrhistogram", "humantime", + "prost", "prost-types", "serde", "serde_json", @@ -1720,9 +1721,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.8.0" +version = "1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +checksum = "9f3935c160d00ac752e09787e6e6bfc26494c2183cc922f1bc678a60d4733bc2" [[package]] name = "httpdate" @@ -1772,14 +1773,134 @@ dependencies = [ "tokio-io-timeout", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f8ac670d7422d7f76b32e17a5db556510825b29ec9154f235977c9caba61036" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "idna" -version = "0.5.0" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "4716a3a0933a1d01c2f72450e89596eb51dd34ef3c211ccd875acdf1f8fe47ed" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "icu_normalizer", + "icu_properties", + "smallvec", + "utf8_iter", ] [[package]] @@ -2050,6 +2171,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "local-ip-address" version = "0.6.1" @@ -2673,9 +2800,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904e3d3ba178131798c6d9375db2b13b34337d489b089fc5ba0825a2ff1bee73" +checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad" dependencies = [ "bytes", "pin-project-lite", @@ -2690,9 +2817,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e974563a4b1c2206bbc61191ca4da9c22e4308b4c455e8906751cc7828393f08" +checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" dependencies = [ "bytes", "rand", @@ -2707,9 +2834,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4f0def2590301f4f667db5a77f9694fb004f82796dc1a8b1508fafa3d0e8b72" +checksum = "9096629c45860fc7fb143e125eb826b5e721e10be3263160c7d60ca832cf8c46" dependencies = [ "libc", "once_cell", @@ -2839,14 +2966,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.6", - "regex-syntax 0.8.3", + "regex-automata 0.4.7", + "regex-syntax 0.8.4", ] [[package]] @@ -2860,13 +2987,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.3", + "regex-syntax 0.8.4", ] [[package]] @@ -2877,9 +3004,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "rend" @@ -3686,6 +3813,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -3816,9 +3953,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.10.2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" +checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" dependencies = [ "async-stream", "async-trait", @@ -4003,9 +4140,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.0" +version = "2.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "f7c25da092f0a868cdf09e8674cd3b7ef3a7d92a24253e663a2fb85e2496de56" dependencies = [ "form_urlencoded", "idna", @@ -4018,11 +4155,23 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" @@ -4361,6 +4510,18 @@ dependencies = [ "memchr", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "wyz" version = "0.5.1" @@ -4397,6 +4558,30 @@ dependencies = [ "time", ] +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.34" @@ -4417,6 +4602,27 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", + "synstructure", +] + [[package]] name = "zeroize" version = "1.8.1" @@ -4436,3 +4642,25 @@ dependencies = [ "quote", "syn 2.0.66", ] + +[[package]] +name = "zerovec" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb2cc8827d6c0994478a15c53f374f46fbd41bea663d809b14744bc42e6b109c" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97cf56601ee5052b4417d90c8755c6683473c926039908196cf35d99f893ebe7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] diff --git a/cdn-broker/Cargo.toml b/cdn-broker/Cargo.toml index 197493d..adaf8f9 100644 --- a/cdn-broker/Cargo.toml +++ b/cdn-broker/Cargo.toml @@ -39,7 +39,7 @@ path = "src/binaries/bad-broker.rs" # This dependency is used for the Tokio console [target.'cfg(tokio_unstable)'.dependencies] -console-subscriber = "0.2" +console-subscriber = "0.3" [dependencies] From 07b67b660b4876970f19bcba4f695312a531890e Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 10 Jun 2024 14:37:39 -0400 Subject: [PATCH 31/31] only calculate latency if count difference != 0 --- cdn-proto/src/metrics.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cdn-proto/src/metrics.rs b/cdn-proto/src/metrics.rs index 667ff47..26f60a7 100644 --- a/cdn-proto/src/metrics.rs +++ b/cdn-proto/src/metrics.rs @@ -48,7 +48,9 @@ pub async fn running_latency_calculator() { latency_sum = metrics::LATENCY.get_sample_sum() - latency_sum; latency_count = metrics::LATENCY.get_sample_count() - latency_count; - // Set the running latency - metrics::RUNNING_LATENCY.set(latency_sum / latency_count as f64); + // Set the running latency if the new count is not 0 + if latency_count != 0 { + metrics::RUNNING_LATENCY.set(latency_sum / latency_count as f64); + } } }