diff --git a/Cargo.lock b/Cargo.lock index 9c07851e..37b1362c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -258,7 +258,7 @@ dependencies = [ "clap", "criterion-plot", "is-terminal", - "itertools 0.10.5", + "itertools", "num-traits", "once_cell", "oorandom", @@ -279,7 +279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools 0.10.5", + "itertools", ] [[package]] @@ -737,15 +737,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.9" @@ -1375,7 +1366,6 @@ dependencies = [ "http", "http-body", "hyper", - "itertools 0.11.0", "serde", "serde_json", "thiserror", diff --git a/socketioxide/Cargo.toml b/socketioxide/Cargo.toml index 02df44b5..11d8a174 100644 --- a/socketioxide/Cargo.toml +++ b/socketioxide/Cargo.toml @@ -29,7 +29,6 @@ tower = { version = "0.4.13", default-features = false } http = "0.2.9" http-body = "0.4.5" thiserror = "1.0.40" -itertools = "0.11.0" # Extensions dashmap = { version = "5.4.0", optional = true } diff --git a/socketioxide/benches/packet_decode.rs b/socketioxide/benches/packet_decode.rs index 9f0fd378..b15e4228 100644 --- a/socketioxide/benches/packet_decode.rs +++ b/socketioxide/benches/packet_decode.rs @@ -3,18 +3,15 @@ use engineioxide::sid::Sid; use socketioxide::{Packet, PacketData, ProtocolVersion}; fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet connect on /", |b| { - let packet: String = Packet::connect( - black_box("/").to_string(), - black_box(Sid::ZERO), - ProtocolVersion::V5, - ) - .try_into() - .unwrap(); + let packet: String = + Packet::connect(black_box("/"), black_box(Sid::ZERO), ProtocolVersion::V5) + .try_into() + .unwrap(); b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); c.bench_function("Decode packet connect on /custom_nsp", |b| { let packet: String = Packet::connect( - black_box("/custom_nsp").to_string(), + black_box("/custom_nsp"), black_box(Sid::ZERO), ProtocolVersion::V5, ) @@ -27,21 +24,18 @@ fn criterion_benchmark(c: &mut Criterion) { const BINARY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; c.bench_function("Decode packet event on /", |b| { let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::event( - black_box("/").to_string(), - black_box("event").to_string(), - black_box(data.clone()), - ) - .try_into() - .unwrap(); + let packet: String = + Packet::event(black_box("/"), black_box("event"), black_box(data.clone())) + .try_into() + .unwrap(); b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); c.bench_function("Decode packet event on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::event( - black_box("custom_nsp").to_string(), - black_box("event").to_string(), + black_box("custom_nsp"), + black_box("event"), black_box(data.clone()), ) .try_into() @@ -51,11 +45,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet event with ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); - let packet: Packet = Packet::event( - black_box("/").to_string(), - black_box("event").to_string(), - black_box(data.clone()), - ); + let packet: Packet = + Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); match packet.inner { PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), _ => panic!("Wrong packet type"), @@ -67,8 +58,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet event with ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event( - black_box("/custom_nsp").to_string(), - black_box("event").to_string(), + black_box("/custom_nsp"), + black_box("event"), black_box(data.clone()), ); match packet.inner { @@ -82,20 +73,16 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); - let packet: String = Packet::ack( - black_box("/").to_string(), - black_box(data.clone()), - black_box(0), - ) - .try_into() - .unwrap(); + let packet: String = Packet::ack(black_box("/"), black_box(data.clone()), black_box(0)) + .try_into() + .unwrap(); b.iter(|| Packet::try_from(packet.clone()).unwrap()) }); c.bench_function("Decode packet ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::ack( - black_box("/custom_nsp").to_string(), + black_box("/custom_nsp"), black_box(data.clone()), black_box(0), ) @@ -107,8 +94,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet binary event (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_event( - black_box("/").to_string(), - black_box("event").to_string(), + black_box("/"), + black_box("event"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), ) @@ -120,8 +107,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet binary event (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_event( - black_box("/custom_nsp").to_string(), - black_box("event").to_string(), + black_box("/custom_nsp"), + black_box("event"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), ) @@ -133,7 +120,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet binary ack (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_ack( - black_box("/").to_string(), + black_box("/"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), black_box(0), @@ -146,7 +133,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Decode packet binary ack (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet: String = Packet::bin_ack( - black_box("/custom_nsp").to_string(), + black_box("/custom_nsp"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), black_box(0), diff --git a/socketioxide/benches/packet_encode.rs b/socketioxide/benches/packet_encode.rs index 271633f4..f70a340e 100644 --- a/socketioxide/benches/packet_encode.rs +++ b/socketioxide/benches/packet_encode.rs @@ -3,18 +3,14 @@ use engineioxide::sid::Sid; use socketioxide::{Packet, PacketData, ProtocolVersion}; fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet connect on /", |b| { - let packet = Packet::connect( - black_box("/").to_string(), - black_box(Sid::ZERO), - ProtocolVersion::V5, - ); + let packet = Packet::connect(black_box("/"), black_box(Sid::ZERO), ProtocolVersion::V5); b.iter(|| { let _: String = packet.clone().try_into().unwrap(); }) }); c.bench_function("Encode packet connect on /custom_nsp", |b| { let packet = Packet::connect( - black_box("/custom_nsp").to_string(), + black_box("/custom_nsp"), black_box(Sid::ZERO), ProtocolVersion::V5, ); @@ -27,11 +23,7 @@ fn criterion_benchmark(c: &mut Criterion) { const BINARY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; c.bench_function("Encode packet event on /", |b| { let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::event( - black_box("/").to_string(), - black_box("event").to_string(), - black_box(data.clone()), - ); + let packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); b.iter(|| { let _: String = packet.clone().try_into().unwrap(); }) @@ -40,8 +32,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet event on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event( - black_box("custom_nsp").to_string(), - black_box("event").to_string(), + black_box("custom_nsp"), + black_box("event"), black_box(data.clone()), ); b.iter(|| { @@ -51,11 +43,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet event with ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::event( - black_box("/").to_string(), - black_box("event").to_string(), - black_box(data.clone()), - ); + let packet = Packet::event(black_box("/"), black_box("event"), black_box(data.clone())); match packet.inner { PacketData::Event(_, _, mut ack) => ack.insert(black_box(0)), _ => panic!("Wrong packet type"), @@ -68,8 +56,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet event with ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::event( - black_box("/custom_nsp").to_string(), - black_box("event").to_string(), + black_box("/custom_nsp"), + black_box("event"), black_box(data.clone()), ); match packet.inner { @@ -83,11 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet ack on /", |b| { let data = serde_json::to_value(DATA).unwrap(); - let packet = Packet::ack( - black_box("/").to_string(), - black_box(data.clone()), - black_box(0), - ); + let packet = Packet::ack(black_box("/"), black_box(data.clone()), black_box(0)); b.iter(|| { let _: String = packet.clone().try_into().unwrap(); }) @@ -96,7 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet ack on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::ack( - black_box("/custom_nsp").to_string(), + black_box("/custom_nsp"), black_box(data.clone()), black_box(0), ); @@ -108,8 +92,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet binary event (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_event( - black_box("/").to_string(), - black_box("event").to_string(), + black_box("/"), + black_box("event"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), ); @@ -121,8 +105,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet binary event (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_event( - black_box("/custom_nsp").to_string(), - black_box("event").to_string(), + black_box("/custom_nsp"), + black_box("event"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), ); @@ -134,7 +118,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet binary ack (b64) on /", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_ack( - black_box("/").to_string(), + black_box("/"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), black_box(0), @@ -147,7 +131,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("Encode packet binary ack (b64) on /custom_nsp", |b| { let data = serde_json::to_value(DATA).unwrap(); let packet = Packet::bin_ack( - black_box("/custom_nsp").to_string(), + black_box("/custom_nsp"), black_box(data.clone()), black_box(vec![BINARY.to_vec().clone()]), black_box(0), diff --git a/socketioxide/src/adapter.rs b/socketioxide/src/adapter.rs index 4f1ac227..792d38d7 100644 --- a/socketioxide/src/adapter.rs +++ b/socketioxide/src/adapter.rs @@ -4,6 +4,7 @@ //! Other adapters can be made to share the state between multiple servers. use std::{ + borrow::Cow, collections::{HashMap, HashSet}, convert::Infallible, sync::{Arc, RwLock, Weak}, @@ -15,7 +16,6 @@ use futures::{ stream::{self, BoxStream}, StreamExt, }; -use itertools::Itertools; use serde::de::DeserializeOwned; use crate::{ @@ -28,7 +28,7 @@ use crate::{ }; /// A room identifier -pub type Room = String; +pub type Room = Cow<'static, str>; /// Flags that can be used to modify the behavior of the broadcast methods. #[derive(Clone, Debug, Hash, PartialEq, Eq)] @@ -47,9 +47,9 @@ pub struct BroadcastOptions { /// The flags to apply to the broadcast. pub flags: HashSet, /// The rooms to broadcast to. - pub rooms: Vec, + pub rooms: HashSet, /// The rooms to exclude from the broadcast. - pub except: Vec, + pub except: HashSet, /// The socket id of the sender. pub sid: Option, } @@ -57,8 +57,8 @@ impl BroadcastOptions { pub fn new(sid: Option) -> Self { Self { flags: HashSet::new(), - rooms: Vec::new(), - except: Vec::new(), + rooms: HashSet::new(), + except: HashSet::new(), sid, } } @@ -94,7 +94,7 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { /// Broadcast the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses. fn broadcast_with_ack( &self, - packet: Packet, + packet: Packet<'static>, opts: BroadcastOptions, ) -> Result, AckError>>, BroadcastError>; @@ -208,7 +208,7 @@ impl Adapter for LocalAdapter { fn broadcast_with_ack( &self, - packet: Packet, + packet: Packet<'static>, opts: BroadcastOptions, ) -> Result, AckError>>, BroadcastError> { let duration = opts.flags.iter().find_map(|flag| match flag { @@ -241,7 +241,7 @@ impl Adapter for LocalAdapter { } //TODO: make this operation O(1) - fn socket_rooms(&self, sid: Sid) -> Result, Infallible> { + fn socket_rooms(&self, sid: Sid) -> Result>, Infallible> { let rooms_map = self.rooms.read().unwrap(); Ok(rooms_map .iter() @@ -300,7 +300,6 @@ impl LocalAdapter { .iter() .filter_map(|room| rooms_map.get(room)) .flatten() - .unique() .filter(|sid| { !except.contains(*sid) && (!opts.flags.contains(&BroadcastFlags::Broadcast) @@ -321,7 +320,7 @@ impl LocalAdapter { } } - fn get_except_sids(&self, except: &Vec) -> HashSet { + fn get_except_sids(&self, except: &HashSet) -> HashSet { let mut except_sids = HashSet::new(); let rooms_map = self.rooms.read().unwrap(); for room in except { @@ -337,6 +336,12 @@ impl LocalAdapter { mod test { use super::*; + macro_rules! hash_set { + {$($v: expr),* $(,)?} => { + std::collections::HashSet::from([$($v,)*]) + }; + } + #[tokio::test] async fn test_server_count() { let ns = Namespace::new_dummy([]); @@ -412,7 +417,7 @@ mod test { adapter.add_all(socket, ["room1"]).unwrap(); let mut opts = BroadcastOptions::new(Some(socket)); - opts.rooms = vec!["room1".to_string()]; + opts.rooms = hash_set!["room1".into()]; adapter.add_sockets(opts, "room2").unwrap(); let rooms_map = adapter.rooms.read().unwrap(); @@ -429,7 +434,7 @@ mod test { adapter.add_all(socket, ["room1"]).unwrap(); let mut opts = BroadcastOptions::new(Some(socket)); - opts.rooms = vec!["room1".to_string()]; + opts.rooms = hash_set!["room1".into()]; adapter.add_sockets(opts, "room2").unwrap(); { @@ -441,7 +446,7 @@ mod test { } let mut opts = BroadcastOptions::new(Some(socket)); - opts.rooms = vec!["room1".to_string()]; + opts.rooms = hash_set!["room1".into()]; adapter.del_sockets(opts, "room2").unwrap(); { @@ -498,7 +503,7 @@ mod test { .unwrap(); let mut opts = BroadcastOptions::new(Some(socket0)); - opts.rooms = vec!["room5".to_string()]; + opts.rooms = hash_set!["room5".into()]; match adapter.disconnect_socket(opts) { // todo it returns Ok, in previous commits it also returns Ok Err(BroadcastError::SendError(_)) | Ok(_) => {} @@ -531,15 +536,15 @@ mod test { // socket 2 is the sender let mut opts = BroadcastOptions::new(Some(socket2)); - opts.rooms = vec!["room1".to_string()]; - opts.except = vec!["room2".to_string()]; + opts.rooms = hash_set!["room1".into()]; + opts.except = hash_set!["room2".into()]; let sockets = adapter.fetch_sockets(opts).unwrap(); assert_eq!(sockets.len(), 1); assert_eq!(sockets[0].id, socket1); let mut opts = BroadcastOptions::new(Some(socket2)); opts.flags.insert(BroadcastFlags::Broadcast); - opts.except = vec!["room2".to_string()]; + opts.except = hash_set!["room2".into()]; let sockets = adapter.fetch_sockets(opts).unwrap(); assert_eq!(sockets.len(), 1); diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index b0060cf7..969c0610 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; @@ -21,7 +22,7 @@ use crate::{ProtocolVersion, Socket}; #[derive(Debug)] pub struct Client { pub(crate) config: Arc, - ns: RwLock>>>, + ns: RwLock, Arc>>>, } impl Client { @@ -58,7 +59,7 @@ impl Client { fn sock_connect( &self, auth: Option, - ns_path: String, + ns_path: &str, esocket: &Arc>, ) -> Result<(), Error> { #[cfg(feature = "tracing")] @@ -93,7 +94,7 @@ impl Client { } /// Cache-in the socket data until all the binary payloads are received - fn sock_recv_bin_packet(&self, socket: &EIoSocket, packet: Packet) { + fn sock_recv_bin_packet(&self, socket: &EIoSocket, packet: Packet<'static>) { socket .data .partial_bin_packet @@ -132,7 +133,7 @@ impl Client { } /// Add a new namespace handler - pub fn add_ns(&self, path: String, callback: C) + pub fn add_ns(&self, path: Cow<'static, str>, callback: C) where C: Fn(Arc>, V) -> F + Send + Sync + 'static, F: Future + Send + 'static, @@ -171,7 +172,7 @@ impl Client { pub struct SocketData { /// Partial binary packet that is being received /// Stored here until all the binary payloads are received - pub partial_bin_packet: Mutex>, + pub partial_bin_packet: Mutex>>, /// Channel used to notify the socket that it has been connected to a namespace #[cfg(feature = "v5")] @@ -245,7 +246,7 @@ impl EngineIoHandler for Client { let res: Result<(), Error> = match packet.inner { PacketData::Connect(auth) => self - .sock_connect(auth, packet.ns, &socket) + .sock_connect(auth, &packet.ns, &socket) .map_err(Into::into), PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) => { self.sock_recv_bin_packet(&socket, packet); diff --git a/socketioxide/src/handler.rs b/socketioxide/src/handler.rs index b27cd011..aee4986c 100644 --- a/socketioxide/src/handler.rs +++ b/socketioxide/src/handler.rs @@ -142,7 +142,7 @@ impl AckSender { /// Send the ack response to the client. pub fn send(self, data: impl Serialize) -> Result<(), AckSenderError> { if let Some(ack_id) = self.ack_id { - let ns = self.socket.ns().clone(); + let ns = self.socket.ns(); let data = match serde_json::to_value(&data) { Err(err) => { return Err(AckSenderError::SendError { diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index 6f4a8017..d9858332 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, time::Duration}; +use std::{borrow::Cow, sync::Arc, time::Duration}; use engineioxide::{ config::{EngineIoConfig, EngineIoConfigBuilder, TransportType}, @@ -55,7 +55,6 @@ impl Default for SocketIoConfig { pub struct SocketIoBuilder { config: SocketIoConfig, engine_config_builder: EngineIoConfigBuilder, - req_path: String, } impl SocketIoBuilder { @@ -63,8 +62,7 @@ impl SocketIoBuilder { pub fn new() -> Self { Self { config: SocketIoConfig::default(), - engine_config_builder: EngineIoConfigBuilder::new(), - req_path: "/socket.io".to_string(), + engine_config_builder: EngineIoConfigBuilder::new().req_path("/socket.io".to_string()), } } @@ -73,7 +71,7 @@ impl SocketIoBuilder { /// Defaults to "/socket.io". #[inline] pub fn req_path(mut self, req_path: String) -> Self { - self.req_path = req_path; + self.engine_config_builder = self.engine_config_builder.req_path(req_path); self } @@ -163,7 +161,7 @@ impl SocketIoBuilder { /// /// The layer can be used as a tower layer pub fn build_layer_with_adapter(mut self) -> (SocketIoLayer, SocketIo) { - self.config.engine_config = self.engine_config_builder.req_path(self.req_path).build(); + self.config.engine_config = self.engine_config_builder.build(); let (layer, client) = SocketIoLayer::from_config(Arc::new(self.config)); (layer, SocketIo(client)) @@ -190,7 +188,7 @@ impl SocketIoBuilder { pub fn build_svc_with_adapter( mut self, ) -> (SocketIoService, SocketIo) { - self.config.engine_config = self.engine_config_builder.req_path(self.req_path).build(); + self.config.engine_config = self.engine_config_builder.build(); let (svc, client) = SocketIoService::with_config_inner(NotFoundService, Arc::new(self.config)); @@ -215,7 +213,7 @@ impl SocketIoBuilder { mut self, svc: S, ) -> (SocketIoService, SocketIo) { - self.config.engine_config = self.engine_config_builder.req_path(self.req_path).build(); + self.config.engine_config = self.engine_config_builder.build(); let (svc, client) = SocketIoService::with_config_inner(svc, Arc::new(self.config)); (svc, SocketIo(client)) @@ -320,7 +318,7 @@ impl SocketIo { /// /// ``` #[inline] - pub fn ns(&self, path: impl Into, callback: C) + pub fn ns(&self, path: impl Into>, callback: C) where C: Fn(Arc>, V) -> F + Send + Sync + 'static, F: Future + Send + 'static, @@ -569,7 +567,7 @@ impl SocketIo { #[inline] pub fn emit( &self, - event: impl Into, + event: impl Into>, data: impl serde::Serialize, ) -> Result<(), serde_json::Error> { self.get_default_op().emit(event, data) @@ -608,7 +606,7 @@ impl SocketIo { #[inline] pub fn emit_with_ack( &self, - event: impl Into, + event: impl Into>, data: impl serde::Serialize, ) -> Result, AckError>>, BroadcastError> { self.get_default_op().emit_with_ack(event, data) diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index 42dc63f7..f91bb9a2 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, collections::HashMap, sync::{Arc, RwLock}, }; @@ -17,14 +18,14 @@ use futures::Future; use serde::de::DeserializeOwned; pub struct Namespace { - pub path: String, + pub path: Cow<'static, str>, pub(crate) adapter: A, handler: BoxedNamespaceHandler, sockets: RwLock>>>, } impl Namespace { - pub fn new(path: String, callback: C) -> Arc + pub fn new(path: Cow<'static, str>, callback: C) -> Arc where C: Fn(Arc>, V) -> F + Send + Sync + 'static, F: Future + Send + 'static, @@ -52,7 +53,7 @@ impl Namespace { self.sockets.write().unwrap().insert(sid, socket.clone()); let protocol = esocket.protocol.into(); - if let Err(_e) = socket.send(Packet::connect(self.path.clone(), socket.id, protocol)) { + if let Err(_e) = socket.send(Packet::connect(&self.path, socket.id, protocol)) { #[cfg(feature = "tracing")] tracing::debug!("error sending connect packet: {:?}, closing conn", _e); esocket.close(engineioxide::DisconnectReason::PacketParsingError); @@ -78,7 +79,7 @@ impl Namespace { pub fn recv(&self, sid: Sid, packet: PacketData) -> Result<(), Error> { match packet { PacketData::Connect(_) => unreachable!("connect packets should be handled before"), - PacketData::ConnectError(_) => Ok(()), + PacketData::ConnectError => Err(Error::InvalidPacketType), packet => self.get_socket(sid)?.recv(packet), } } @@ -113,7 +114,7 @@ impl Namespace { #[cfg(test)] impl Namespace { pub fn new_dummy(sockets: [Sid; S]) -> Arc { - let ns = Namespace::new("/".to_string(), |_, _: ()| async {}); + let ns = Namespace::new(Cow::Borrowed("/"), |_, _: ()| async {}); for sid in sockets { ns.sockets .write() diff --git a/socketioxide/src/operators.rs b/socketioxide/src/operators.rs index f9b7816b..2f9e3493 100644 --- a/socketioxide/src/operators.rs +++ b/socketioxide/src/operators.rs @@ -1,8 +1,8 @@ +use std::borrow::Cow; use std::{sync::Arc, time::Duration}; use engineioxide::sid::Sid; use futures::stream::BoxStream; -use itertools::Itertools; use serde::de::DeserializeOwned; use crate::errors::BroadcastError; @@ -25,28 +25,61 @@ pub trait RoomParam: 'static { impl RoomParam for Room { type IntoIter = std::iter::Once; + #[inline(always)] fn into_room_iter(self) -> Self::IntoIter { std::iter::once(self) } } +impl RoomParam for String { + type IntoIter = std::iter::Once; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + std::iter::once(Cow::Owned(self)) + } +} +impl RoomParam for Vec { + type IntoIter = std::iter::Map, fn(String) -> Room>; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter().map(Cow::Owned) + } +} +impl RoomParam for Vec<&'static str> { + type IntoIter = std::iter::Map, fn(&'static str) -> Room>; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter().map(Cow::Borrowed) + } +} + impl RoomParam for Vec { type IntoIter = std::vec::IntoIter; + #[inline(always)] fn into_room_iter(self) -> Self::IntoIter { self.into_iter() } } impl RoomParam for &'static str { type IntoIter = std::iter::Once; + #[inline(always)] fn into_room_iter(self) -> Self::IntoIter { - std::iter::once(self.to_string()) + std::iter::once(Cow::Borrowed(self)) } } impl RoomParam for [&'static str; COUNT] { type IntoIter = std::iter::Map, fn(&'static str) -> Room>; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter().map(Cow::Borrowed) + } +} +impl RoomParam for [String; COUNT] { + type IntoIter = std::iter::Map, fn(String) -> Room>; + #[inline(always)] fn into_room_iter(self) -> Self::IntoIter { - self.into_iter().map(|s| s.to_string()) + self.into_iter().map(Cow::Owned) } } @@ -88,7 +121,7 @@ impl Operators { /// }); /// }); pub fn to(mut self, rooms: impl RoomParam) -> Self { - self.opts.rooms.extend(rooms.into_room_iter().unique()); + self.opts.rooms.extend(rooms.into_room_iter()); self.opts.flags.insert(BroadcastFlags::Broadcast); self } @@ -114,7 +147,7 @@ impl Operators { /// }); /// }); pub fn within(mut self, rooms: impl RoomParam) -> Self { - self.opts.rooms.extend(rooms.into_room_iter().unique()); + self.opts.rooms.extend(rooms.into_room_iter()); self } @@ -138,7 +171,7 @@ impl Operators { /// }); /// }); pub fn except(mut self, rooms: impl RoomParam) -> Self { - self.opts.except.extend(rooms.into_room_iter().unique()); + self.opts.except.extend(rooms.into_room_iter()); self.opts.flags.insert(BroadcastFlags::Broadcast); self } @@ -240,7 +273,7 @@ impl Operators { /// }); pub fn emit( mut self, - event: impl Into, + event: impl Into>, data: impl serde::Serialize, ) -> Result<(), serde_json::Error> { let packet = self.get_packet(event, data)?; @@ -277,7 +310,7 @@ impl Operators { /// }); pub fn emit_with_ack( mut self, - event: impl Into, + event: impl Into>, data: impl serde::Serialize, ) -> Result, AckError>>, BroadcastError> { let packet = self.get_packet(event, data)?; @@ -356,16 +389,16 @@ impl Operators { /// Create a packet with the given event and data. fn get_packet( &mut self, - event: impl Into, + event: impl Into>, data: impl serde::Serialize, - ) -> Result { - let ns = self.ns.clone(); + ) -> Result, serde_json::Error> { + let ns = self.ns.path.clone(); let data = serde_json::to_value(data)?; let packet = if self.binary.is_empty() { - Packet::event(ns.path.clone(), event.into(), data) + Packet::event(ns, event.into(), data) } else { let binary = std::mem::take(&mut self.binary); - Packet::bin_event(ns.path.clone(), event.into(), data, binary) + Packet::bin_event(ns, event.into(), data, binary) }; Ok(packet) } diff --git a/socketioxide/src/packet.rs b/socketioxide/src/packet.rs index 02bce2d7..baeddbb4 100644 --- a/socketioxide/src/packet.rs +++ b/socketioxide/src/packet.rs @@ -1,5 +1,6 @@ +use std::borrow::Cow; + use crate::ProtocolVersion; -use itertools::{Itertools, PeekingNext}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::{json, Value}; @@ -9,15 +10,15 @@ use engineioxide::sid::Sid; /// The socket.io packet type. /// Each packet has a type and a namespace #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Packet { - pub inner: PacketData, - pub ns: String, +pub struct Packet<'a> { + pub inner: PacketData<'a>, + pub ns: Cow<'a, str>, } -impl Packet { +impl<'a> Packet<'a> { /// Send a connect packet with a default payload for v5 and no payload for v4 pub fn connect( - ns: String, + ns: &'a str, #[allow(unused_variables)] sid: Sid, #[allow(unused_variables)] protocol: ProtocolVersion, ) -> Self { @@ -42,72 +43,123 @@ impl Packet { /// Sends a connect packet without payload. #[cfg(feature = "v4")] - fn connect_v4(ns: String) -> Self { + fn connect_v4(ns: &'a str) -> Self { Self { inner: PacketData::Connect(None), - ns, + ns: Cow::Borrowed(ns), } } /// Sends a connect packet with payload. #[cfg(feature = "v5")] - fn connect_v5(ns: String, sid: Sid) -> Self { + fn connect_v5(ns: &'a str, sid: Sid) -> Self { let val = serde_json::to_string(&ConnectPacket { sid }).unwrap(); Self { inner: PacketData::Connect(Some(val)), - ns, + ns: Cow::Borrowed(ns), } } - pub fn disconnect(ns: String) -> Self { + pub fn disconnect(ns: &'a str) -> Self { Self { inner: PacketData::Disconnect, - ns, + ns: Cow::Borrowed(ns), } } } -impl Packet { - pub fn invalid_namespace(ns: String) -> Self { +impl<'a> Packet<'a> { + pub fn invalid_namespace(ns: &'a str) -> Self { Self { - inner: PacketData::ConnectError(ConnectErrorPacket { - message: "Invalid namespace".to_string(), - }), - ns, + inner: PacketData::ConnectError, + ns: Cow::Borrowed(ns), } } - pub fn event(ns: String, e: String, data: Value) -> Self { + pub fn event(ns: impl Into>, e: impl Into>, data: Value) -> Self { Self { - inner: PacketData::Event(e, data, None), - ns, + inner: PacketData::Event(e.into(), data, None), + ns: ns.into(), } } - pub fn bin_event(ns: String, e: String, data: Value, bin: Vec>) -> Self { + pub fn bin_event( + ns: impl Into>, + e: impl Into>, + data: Value, + bin: Vec>, + ) -> Self { debug_assert!(!bin.is_empty()); let packet = BinaryPacket::outgoing(data, bin); Self { - inner: PacketData::BinaryEvent(e, packet, None), - ns, + inner: PacketData::BinaryEvent(e.into(), packet, None), + ns: ns.into(), } } - pub fn ack(ns: String, data: Value, ack: i64) -> Self { + pub fn ack(ns: &'a str, data: Value, ack: i64) -> Self { Self { inner: PacketData::EventAck(data, ack), - ns, + ns: Cow::Borrowed(ns), } } - pub fn bin_ack(ns: String, data: Value, bin: Vec>, ack: i64) -> Self { + pub fn bin_ack(ns: &'a str, data: Value, bin: Vec>, ack: i64) -> Self { debug_assert!(!bin.is_empty()); let packet = BinaryPacket::outgoing(data, bin); Self { inner: PacketData::BinaryAck(packet, ack), - ns, + ns: Cow::Borrowed(ns), } } + + /// Get the max size the packet could have when serialized + /// This is used to pre-allocate a buffer for the packet + /// + /// #### Disclaimer: The size does not include serialized `Value` size + fn get_size_hint(&self) -> usize { + use PacketData::*; + const PACKET_INDEX_SIZE: usize = 1; + const BINARY_PUNCTUATION_SIZE: usize = 2; + const ACK_PUNCTUATION_SIZE: usize = 1; + const NS_PUNCTUATION_SIZE: usize = 1; + + let data_size = match &self.inner { + Connect(Some(data)) => data.len(), + Connect(None) => 0, + Disconnect => 0, + Event(_, _, Some(ack)) => { + ack.checked_ilog10().unwrap_or(0) as usize + ACK_PUNCTUATION_SIZE + } + Event(_, _, None) => 0, + BinaryEvent(_, bin, None) => { + bin.payload_count.checked_ilog10().unwrap_or(0) as usize + BINARY_PUNCTUATION_SIZE + } + BinaryEvent(_, bin, Some(ack)) => { + ack.checked_ilog10().unwrap_or(0) as usize + + bin.payload_count.checked_ilog10().unwrap_or(0) as usize + + ACK_PUNCTUATION_SIZE + + BINARY_PUNCTUATION_SIZE + } + EventAck(_, ack) => ack.checked_ilog10().unwrap_or(0) as usize + ACK_PUNCTUATION_SIZE, + BinaryAck(bin, ack) => { + ack.checked_ilog10().unwrap_or(0) as usize + + bin.payload_count.checked_ilog10().unwrap_or(0) as usize + + ACK_PUNCTUATION_SIZE + + BINARY_PUNCTUATION_SIZE + } + ConnectError => 31, + }; + + let nsp_size = if self.ns == "/" { + 0 + } else if self.ns.starts_with('/') { + self.ns.len() + NS_PUNCTUATION_SIZE + } else { + self.ns.len() + NS_PUNCTUATION_SIZE + 1 // (1 for the leading slash) + }; + data_size + nsp_size + PACKET_INDEX_SIZE + } } /// | Type | ID | Usage | @@ -120,13 +172,13 @@ impl Packet { /// | BINARY_EVENT | 5 | Used to [send binary data](#sending-and-receiving-data) to the other side. | /// | BINARY_ACK | 6 | Used to [acknowledge](#acknowledgement) an event (the response includes binary data). | #[derive(Debug, Clone, PartialEq, Eq)] -pub enum PacketData { +pub enum PacketData<'a> { Connect(Option), Disconnect, - Event(String, Value, Option), + Event(Cow<'a, str>, Value, Option), EventAck(Value, i64), - ConnectError(ConnectErrorPacket), - BinaryEvent(String, BinaryPacket, Option), + ConnectError, + BinaryEvent(Cow<'a, str>, BinaryPacket, Option), BinaryAck(BinaryPacket, i64), } @@ -137,16 +189,16 @@ pub struct BinaryPacket { payload_count: usize, } -impl PacketData { - fn index(&self) -> u8 { +impl<'a> PacketData<'a> { + fn index(&self) -> char { match self { - PacketData::Connect(_) => 0, - PacketData::Disconnect => 1, - PacketData::Event(_, _, _) => 2, - PacketData::EventAck(_, _) => 3, - PacketData::ConnectError(_) => 4, - PacketData::BinaryEvent(_, _, _) => 5, - PacketData::BinaryAck(_, _) => 6, + PacketData::Connect(_) => '0', + PacketData::Disconnect => '1', + PacketData::Event(_, _, _) => '2', + PacketData::EventAck(_, _) => '3', + PacketData::ConnectError => '4', + PacketData::BinaryEvent(_, _, _) => '5', + PacketData::BinaryAck(_, _) => '6', } } @@ -227,81 +279,93 @@ impl BinaryPacket { } } -impl TryInto for Packet { +impl<'a> TryInto for Packet<'a> { type Error = serde_json::Error; - fn try_into(self) -> Result { - let mut res = self.inner.index().to_string(); - - // Add the ns if it is not the default one and the packet is not binary - // In case of bin packet, we should first add the payload count before ns - if !self.ns.is_empty() && self.ns != "/" && !self.inner.is_binary() { - res.push_str(&format!("{},", self.ns)); - } + fn try_into(mut self) -> Result { + use PacketData::*; - match self.inner { - PacketData::Connect(data) => res.push_str(&data.unwrap_or_default()), - PacketData::Disconnect => (), - PacketData::Event(event, data, ack) => { - if let Some(ack) = ack { - res.push_str(&ack.to_string()); - } + // Serialize the data if there is any + // pre-serializing allows to preallocate the buffer + let data = match &mut self.inner { + Event(e, data, _) | BinaryEvent(e, BinaryPacket { data, .. }, _) => { // Expand the packet if it is an array -> ["event", ...data] let packet = match data { - Value::Array(mut v) => { - v.insert(0, Value::String(event)); + Value::Array(ref mut v) => { + v.insert(0, Value::String(e.to_string())); serde_json::to_string(&v) } - _ => serde_json::to_string(&(event, data)), + _ => serde_json::to_string(&(e, data)), }?; - res.push_str(&packet) + Some(packet) } - PacketData::EventAck(data, ack) => { - res.push_str(&ack.to_string()); + EventAck(data, _) | BinaryAck(BinaryPacket { data, .. }, _) => { // Enforce that the packet is an array -> [data] let packet = match data { Value::Array(_) => serde_json::to_string(&data), - Value::Null => serde_json::to_string::<[(); 0]>(&[]), + Value::Null => Ok("[]".to_string()), _ => serde_json::to_string(&[data]), }?; - res.push_str(&packet) + Some(packet) + } + _ => None, + }; + + let capacity = self.get_size_hint() + data.as_ref().map(|d| d.len()).unwrap_or(0); + let mut res = String::with_capacity(capacity); + res.push(self.inner.index()); + + // Add the ns if it is not the default one and the packet is not binary + // In case of bin packet, we should first add the payload count before ns + let push_nsp = |res: &mut String| { + if !self.ns.is_empty() && self.ns != "/" { + if !self.ns.starts_with('/') { + res.push('/'); + } + res.push_str(&self.ns); + res.push(','); + } + }; + + if !self.inner.is_binary() { + push_nsp(&mut res); + } + + match self.inner { + PacketData::Connect(Some(data)) => res.push_str(&data), + PacketData::Disconnect | PacketData::Connect(None) => (), + PacketData::Event(_, _, ack) => { + if let Some(ack) = ack { + res.push_str(&ack.to_string()); + } + + res.push_str(&data.unwrap()) } - PacketData::ConnectError(data) => res.push_str(&serde_json::to_string(&data)?), - PacketData::BinaryEvent(event, bin, ack) => { + PacketData::EventAck(_, ack) => { + res.push_str(&ack.to_string()); + res.push_str(&data.unwrap()) + } + PacketData::ConnectError => res.push_str("{\"message\":\"Invalid namespace\"}"), + PacketData::BinaryEvent(_, bin, ack) => { res.push_str(&bin.payload_count.to_string()); res.push('-'); - if !self.ns.is_empty() && self.ns != "/" { - res.push_str(&format!("{},", self.ns)); - } + + push_nsp(&mut res); if let Some(ack) = ack { res.push_str(&ack.to_string()); } - // Expand the packet if it is an array -> ["event", ...data] - let packet = match bin.data { - Value::Array(mut v) => { - v.insert(0, Value::String(event)); - serde_json::to_string(&v) - } - _ => serde_json::to_string(&(event, bin.data)), - }?; - res.push_str(&packet) + res.push_str(&data.unwrap()) } PacketData::BinaryAck(packet, ack) => { res.push_str(&packet.payload_count.to_string()); res.push('-'); - if !self.ns.is_empty() && self.ns != "/" { - res.push_str(&format!("{},", self.ns)); - } + + push_nsp(&mut res); + res.push_str(&ack.to_string()); - // Enforce that the packet is an array -> [data] - let packet = match packet.data { - Value::Array(_) => serde_json::to_string(&packet.data), - Value::Null => serde_json::to_string::<[(); 0]>(&[]), - _ => serde_json::to_string(&[packet.data]), - }?; - res.push_str(&packet) + res.push_str(&data.unwrap()) } }; Ok(res) @@ -347,65 +411,76 @@ fn deserialize_packet(data: &str) -> Result, serd /// [<# of binary attachments>-][,][][JSON-stringified payload without binary] /// + binary attachments extracted /// ``` -impl TryFrom for Packet { +impl<'a> TryFrom for Packet<'a> { type Error = Error; fn try_from(value: String) -> Result { - let mut chars = value.chars(); - let index = chars.next().ok_or(Error::InvalidPacketType)?; - - let attachments: u8 = if index == '5' || index == '6' { - chars - .take_while_ref(|c| *c != '-') - .collect::() - .parse() - .unwrap_or(0) + // It is possible to parse the packet from a byte slice because separators are only ASCII + let chars = value.as_bytes(); + let mut i = 1; + let index = (b'0'..=b'6') + .contains(&chars[0]) + .then_some(chars[0]) + .ok_or(Error::InvalidPacketType)?; + + // Move the cursor to skip the payload count if it is a binary packet + if index == b'5' || index == b'6' { + while chars.get(i) != Some(&b'-') { + i += 1; + } + i += 1; + } + + let start_index = i; + // Custom nsps will start with a slash + let ns = if chars.get(i) == Some(&b'/') { + loop { + match chars.get(i) { + Some(b',') => { + i += 1; + break Cow::Owned(value[start_index..i - 1].to_string()); + } + // It maybe possible depending on clients that ns does not end with a comma + // if it is the end of the packet + // e.g `1/custom` + None => { + break Cow::Owned(value[start_index..i].to_string()); + } + Some(_) => i += 1, + } + } } else { - 0 + Cow::Borrowed("/") }; - // If there are attachments, skip the `-` separator - chars.peeking_next(|c| attachments > 0 && !c.is_ascii_digit()); - - let mut ns: String = chars - .take_while_ref(|c| *c != ',' && *c != '{' && *c != '[' && !c.is_ascii_digit()) - .collect(); - - // If there is a namespace, skip the `,` separator - if !ns.is_empty() { - chars.next(); - } - if !ns.starts_with('/') { - ns.insert(0, '/'); - } - - let ack: Option = chars - .take_while_ref(|c| c.is_ascii_digit()) - .collect::() - .parse() - .ok(); + let start_index = i; + let ack: Option = loop { + match chars.get(i) { + Some(c) if c.is_ascii_digit() => i += 1, + Some(b'[') | Some(b'{') if i > start_index => { + break value[start_index..i].parse().ok() + } + _ => break None, + } + }; - let data = chars.as_str(); + let data = &value[i..]; let inner = match index { - '0' => PacketData::Connect((!data.is_empty()).then(|| data.to_string())), - '1' => PacketData::Disconnect, - '2' => { + b'0' => PacketData::Connect((!data.is_empty()).then(|| data.to_string())), + b'1' => PacketData::Disconnect, + b'2' => { let (event, payload) = deserialize_event_packet(data)?; - PacketData::Event(event, payload, ack) + PacketData::Event(event.into(), payload, ack) } - '3' => { + b'3' => { let packet = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; PacketData::EventAck(packet, ack.ok_or(Error::InvalidPacketType)?) } - '4' => { - let payload = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; - PacketData::ConnectError(payload) - } - '5' => { + b'5' => { let (event, payload) = deserialize_event_packet(data)?; - PacketData::BinaryEvent(event, BinaryPacket::incoming(payload), ack) + PacketData::BinaryEvent(event.into(), BinaryPacket::incoming(payload), ack) } - '6' => { + b'6' => { let packet = deserialize_packet(data)?.ok_or(Error::InvalidPacketType)?; PacketData::BinaryAck( BinaryPacket::incoming(packet), @@ -420,16 +495,11 @@ impl TryFrom for Packet { } /// Connect packet sent by the client -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConnectPacket { sid: Sid, } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct ConnectErrorPacket { - message: String, -} - #[cfg(test)] mod test { use serde_json::json; @@ -442,18 +512,12 @@ mod test { let payload = format!("0{}", json!({"sid": sid})); let packet = Packet::try_from(payload).unwrap(); - assert_eq!( - Packet::connect("/".to_string(), sid, ProtocolVersion::V5), - packet - ); + assert_eq!(Packet::connect("/", sid, ProtocolVersion::V5), packet); let payload = format!("0/admin™,{}", json!({"sid": sid})); let packet = Packet::try_from(payload).unwrap(); - assert_eq!( - Packet::connect("/admin™".to_string(), sid, ProtocolVersion::V5), - packet - ); + assert_eq!(Packet::connect("/admin™", sid, ProtocolVersion::V5), packet); } #[test] @@ -462,13 +526,13 @@ mod test { let sid = Sid::new(); let payload = format!("0{}", json!({"sid": sid})); - let packet: String = Packet::connect("/".to_string(), sid, ProtocolVersion::V5) + let packet: String = Packet::connect("/", sid, ProtocolVersion::V5) .try_into() .unwrap(); assert_eq!(packet, payload); let payload = format!("0/admin™,{}", json!({"sid": sid})); - let packet: String = Packet::connect("/admin™".to_string(), sid, ProtocolVersion::V5) + let packet: String = Packet::connect("/admin™", sid, ProtocolVersion::V5) .try_into() .unwrap(); assert_eq!(packet, payload); @@ -480,23 +544,21 @@ mod test { fn packet_decode_disconnect() { let payload = "1".to_string(); let packet = Packet::try_from(payload).unwrap(); - assert_eq!(Packet::disconnect("/".to_string()), packet); + assert_eq!(Packet::disconnect("/"), packet); let payload = "1/admin™,".to_string(); let packet = Packet::try_from(payload).unwrap(); - assert_eq!(Packet::disconnect("/admin™".to_string()), packet); + assert_eq!(Packet::disconnect("/admin™"), packet); } #[test] fn packet_encode_disconnect() { let payload = "1".to_string(); - let packet: String = Packet::disconnect("/".to_string()).try_into().unwrap(); + let packet: String = Packet::disconnect("/").try_into().unwrap(); assert_eq!(packet, payload); let payload = "1/admin™,".to_string(); - let packet: String = Packet::disconnect("/admin™".to_string()) - .try_into() - .unwrap(); + let packet: String = Packet::disconnect("/admin™").try_into().unwrap(); assert_eq!(packet, payload); } @@ -507,11 +569,7 @@ mod test { let packet = Packet::try_from(payload).unwrap(); assert_eq!( - Packet::event( - "/".to_string(), - "event".to_string(), - json!([{"data": "value"}]) - ), + Packet::event("/", "event", json!([{"data": "value"}])), packet ); @@ -519,11 +577,7 @@ mod test { let payload = format!("21{}", json!(["event", { "data": "value" }])); let packet = Packet::try_from(payload).unwrap(); - let mut comparison_packet = Packet::event( - "/".to_string(), - "event".to_string(), - json!([{"data": "value"}]), - ); + let mut comparison_packet = Packet::event("/", "event", json!([{"data": "value"}])); comparison_packet.inner.set_ack_id(1); assert_eq!(packet, comparison_packet); @@ -532,11 +586,7 @@ mod test { let packet = Packet::try_from(payload).unwrap(); assert_eq!( - Packet::event( - "/admin™".to_string(), - "event".to_string(), - json!([{"data": "value™"}]) - ), + Packet::event("/admin™", "event", json!([{"data": "value™"}])), packet ); @@ -545,11 +595,7 @@ mod test { let mut packet = Packet::try_from(payload).unwrap(); packet.inner.set_ack_id(1); - let mut comparison_packet = Packet::event( - "/admin™".to_string(), - "event".to_string(), - json!([{"data": "value™"}]), - ); + let mut comparison_packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); comparison_packet.inner.set_ack_id(1); assert_eq!(packet, comparison_packet); @@ -558,23 +604,15 @@ mod test { #[test] fn packet_encode_event() { let payload = format!("2{}", json!(["event", { "data": "value™" }])); - let packet: String = Packet::event( - "/".to_string(), - "event".to_string(), - json!({ "data": "value™" }), - ) - .try_into() - .unwrap(); + let packet: String = Packet::event("/", "event", json!({ "data": "value™" })) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with ack ID let payload = format!("21{}", json!(["event", { "data": "value™" }])); - let mut packet = Packet::event( - "/".to_string(), - "event".to_string(), - json!({ "data": "value™" }), - ); + let mut packet = Packet::event("/", "event", json!({ "data": "value™" })); packet.inner.set_ack_id(1); let packet: String = packet.try_into().unwrap(); @@ -582,23 +620,15 @@ mod test { // Encode with NS let payload = format!("2/admin™,{}", json!(["event", { "data": "value™" }])); - let packet: String = Packet::event( - "/admin™".to_string(), - "event".to_string(), - json!([{"data": "value™"}]), - ) - .try_into() - .unwrap(); + let packet: String = Packet::event("/admin™", "event", json!([{"data": "value™"}])) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with NS and ack ID let payload = format!("2/admin™,1{}", json!(["event", { "data": "value™" }])); - let mut packet = Packet::event( - "/admin™".to_string(), - "event".to_string(), - json!([{"data": "value™"}]), - ); + let mut packet = Packet::event("/admin™", "event", json!([{"data": "value™"}])); packet.inner.set_ack_id(1); let packet: String = packet.try_into().unwrap(); assert_eq!(packet, payload); @@ -610,84 +640,55 @@ mod test { let payload = "354[\"data\"]".to_string(); let packet = Packet::try_from(payload).unwrap(); - assert_eq!(Packet::ack("/".to_string(), json!(["data"]), 54), packet); + assert_eq!(Packet::ack("/", json!(["data"]), 54), packet); let payload = "3/admin™,54[\"data\"]".to_string(); let packet = Packet::try_from(payload).unwrap(); - assert_eq!( - Packet::ack("/admin™".to_string(), json!(["data"]), 54), - packet - ); + assert_eq!(Packet::ack("/admin™", json!(["data"]), 54), packet); } #[test] fn packet_encode_event_ack() { let payload = "354[\"data\"]".to_string(); - let packet: String = Packet::ack("/".to_string(), json!("data"), 54) - .try_into() - .unwrap(); + let packet: String = Packet::ack("/", json!("data"), 54).try_into().unwrap(); assert_eq!(packet, payload); let payload = "3/admin™,54[\"data\"]".to_string(); - let packet: String = Packet::ack("/admin™".to_string(), json!("data"), 54) + let packet: String = Packet::ack("/admin™", json!("data"), 54) .try_into() .unwrap(); assert_eq!(packet, payload); } - // ConnectError(ConnectErrorPacket), - #[test] - fn packet_decode_connect_error() { - let payload = format!("4{}", json!({ "message": "Invalid namespace" })); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!(Packet::invalid_namespace("/".to_string()), packet); - - let payload = format!("4/admin™,{}", json!({ "message": "Invalid namespace" })); - let packet = Packet::try_from(payload).unwrap(); - - assert_eq!(Packet::invalid_namespace("/admin™".to_string()), packet); - } #[test] fn packet_encode_connect_error() { let payload = format!("4{}", json!({ "message": "Invalid namespace" })); - let packet: String = Packet::invalid_namespace("/".to_string()) - .try_into() - .unwrap(); + let packet: String = Packet::invalid_namespace("/").try_into().unwrap(); assert_eq!(packet, payload); let payload = format!("4/admin™,{}", json!({ "message": "Invalid namespace" })); - let packet: String = Packet::invalid_namespace("/admin™".to_string()) - .try_into() - .unwrap(); + let packet: String = Packet::invalid_namespace("/admin™").try_into().unwrap(); assert_eq!(packet, payload); } + // BinaryEvent(String, BinaryPacket, Option), #[test] fn packet_encode_binary_event() { let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); let payload = format!("51-{}", json); - let packet: String = Packet::bin_event( - "/".to_string(), - "event".to_string(), - json!({ "data": "value™" }), - vec![vec![1]], - ) - .try_into() - .unwrap(); + let packet: String = + Packet::bin_event("/", "event", json!({ "data": "value™" }), vec![vec![1]]) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with ack ID let payload = format!("51-254{}", json); - let mut packet = Packet::bin_event( - "/".to_string(), - "event".to_string(), - json!({ "data": "value™" }), - vec![vec![1]], - ); + let mut packet = + Packet::bin_event("/", "event", json!({ "data": "value™" }), vec![vec![1]]); packet.inner.set_ack_id(254); let packet: String = packet.try_into().unwrap(); @@ -696,8 +697,8 @@ mod test { // Encode with NS let payload = format!("51-/admin™,{}", json); let packet: String = Packet::bin_event( - "/admin™".to_string(), - "event".to_string(), + "/admin™", + "event", json!([{"data": "value™"}]), vec![vec![1]], ) @@ -709,8 +710,8 @@ mod test { // Encode with NS and ack ID let payload = format!("51-/admin™,254{}", json); let mut packet = Packet::bin_event( - "/admin™".to_string(), - "event".to_string(), + "/admin™", + "event", json!([{"data": "value™"}]), vec![vec![1]], ); @@ -724,7 +725,7 @@ mod test { let json = json!(["event", { "data": "value™" }, { "_placeholder": true, "num": 0}]); let comparison_packet = |ack, ns: &'static str| Packet { inner: PacketData::BinaryEvent( - "event".to_string(), + "event".into(), BinaryPacket { bin: vec![vec![1]], data: json!([{"data": "value™"}]), @@ -780,27 +781,18 @@ mod test { let json = json!([{ "data": "value™" }, { "_placeholder": true, "num": 0}]); let payload = format!("61-54{}", json); - let packet: String = Packet::bin_ack( - "/".to_string(), - json!({ "data": "value™" }), - vec![vec![1]], - 54, - ) - .try_into() - .unwrap(); + let packet: String = Packet::bin_ack("/", json!({ "data": "value™" }), vec![vec![1]], 54) + .try_into() + .unwrap(); assert_eq!(packet, payload); // Encode with NS let payload = format!("61-/admin™,54{}", json); - let packet: String = Packet::bin_ack( - "/admin™".to_string(), - json!({ "data": "value™" }), - vec![vec![1]], - 54, - ) - .try_into() - .unwrap(); + let packet: String = + Packet::bin_ack("/admin™", json!({ "data": "value™" }), vec![vec![1]], 54) + .try_into() + .unwrap(); assert_eq!(packet, payload); } @@ -839,4 +831,50 @@ mod test { assert_eq!(packet, comparison_packet(54, "/admin™")); } + + #[test] + fn packet_size_hint() { + let sid = Sid::new(); + let len = serde_json::to_string(&ConnectPacket { sid }).unwrap().len(); + let packet = Packet::connect("/", sid, ProtocolVersion::V5); + assert_eq!(packet.get_size_hint(), len + 1); + + let packet = Packet::connect("/admin", sid, ProtocolVersion::V5); + assert_eq!(packet.get_size_hint(), len + 8); + + let packet = Packet::connect("admin", sid, ProtocolVersion::V4); + assert_eq!(packet.get_size_hint(), 8); + + let packet = Packet::disconnect("/"); + assert_eq!(packet.get_size_hint(), 1); + + let packet = Packet::disconnect("/admin"); + assert_eq!(packet.get_size_hint(), 8); + + let packet = Packet::event("/", "event", json!({ "data": "value™" })); + assert_eq!(packet.get_size_hint(), 1); + + let packet = Packet::event("/admin", "event", json!({ "data": "value™" })); + assert_eq!(packet.get_size_hint(), 8); + + let packet = Packet::ack("/", json!("data"), 54); + assert_eq!(packet.get_size_hint(), 3); + + let packet = Packet::ack("/admin", json!("data"), 54); + assert_eq!(packet.get_size_hint(), 10); + + let packet = Packet::bin_event("/", "event", json!({ "data": "value™" }), vec![vec![1]]); + assert_eq!(packet.get_size_hint(), 3); + + let packet = Packet::bin_event( + "/admin", + "event", + json!({ "data": "value™" }), + vec![vec![1]], + ); + assert_eq!(packet.get_size_hint(), 10); + + let packet = Packet::bin_ack("/", json!("data"), vec![vec![1]], 54); + assert_eq!(packet.get_size_hint(), 5); + } } diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index 3b7e64ab..478b9850 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -46,20 +46,17 @@ impl SocketIoService { /// Create a new [`EngineIoService`] with a custom inner service and a custom config. pub fn with_config_inner(inner: S, config: Arc) -> (Self, Arc>) { - let client = Arc::new(Client::new(config.clone())); - let svc = - EngineIoService::with_config_inner(inner, client.clone(), config.engine_config.clone()); + let engine_config = config.engine_config.clone(); + let client = Arc::new(Client::new(config)); + let svc = EngineIoService::with_config_inner(inner, client.clone(), engine_config); (Self { engine_svc: svc }, client) } /// Create a new [`EngineIoService`] with a custom inner service and an existing client /// It is mainly used with a [`SocketIoLayer`](crate::layer::SocketIoLayer) that owns the client pub fn with_client(inner: S, client: Arc>) -> Self { - let svc = EngineIoService::with_config_inner( - inner, - client.clone(), - client.config.engine_config.clone(), - ); + let engine_config = client.config.engine_config.clone(); + let svc = EngineIoService::with_config_inner(inner, client, engine_config); Self { engine_svc: svc } } } diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 3db94e42..f2035572 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, collections::HashMap, fmt::Debug, sync::Mutex, @@ -242,7 +243,7 @@ impl Socket { event: impl Into, data: impl Serialize, ) -> Result<(), serde_json::Error> { - let ns = self.ns.path.clone(); + let ns = self.ns(); let data = serde_json::to_value(data)?; if let Err(_e) = self.send(Packet::event(ns, event.into(), data)) { #[cfg(feature = "tracing")] @@ -276,9 +277,9 @@ impl Socket { where V: DeserializeOwned + Send + Sync + 'static, { - let ns = self.ns.path.clone(); + let ns = self.ns(); let data = serde_json::to_value(data)?; - let packet = Packet::event(ns, event.into(), data); + let packet = Packet::event(Cow::Borrowed(ns), event.into(), data); self.send_with_ack(packet, None).await } @@ -459,7 +460,7 @@ impl Socket { /// /// It will also call the disconnect handler if it is set. pub fn disconnect(self: Arc) -> Result<(), SendError> { - self.send(Packet::disconnect(self.ns.path.clone()))?; + self.send(Packet::disconnect(&self.ns.path))?; self.close(DisconnectReason::ServerNSDisconnect)?; Ok(()) } @@ -476,7 +477,7 @@ impl Socket { } /// Get the current namespace path. - pub fn ns(&self) -> &String { + pub fn ns(&self) -> &str { &self.ns.path } @@ -499,9 +500,9 @@ impl Socket { Ok(()) } - pub(crate) async fn send_with_ack( + pub(crate) async fn send_with_ack<'a, V: DeserializeOwned>( &self, - mut packet: Packet, + mut packet: Packet<'a>, timeout: Option, ) -> Result, AckError> { let (tx, rx) = oneshot::channel(); @@ -529,9 +530,9 @@ impl Socket { // Receive data from client: pub(crate) fn recv(self: Arc, packet: PacketData) -> Result<(), Error> { match packet { - PacketData::Event(e, data, ack) => self.recv_event(e, data, ack), + PacketData::Event(e, data, ack) => self.recv_event(&e, data, ack), PacketData::EventAck(data, ack_id) => self.recv_ack(data, ack_id), - PacketData::BinaryEvent(e, packet, ack) => self.recv_bin_event(e, packet, ack), + PacketData::BinaryEvent(e, packet, ack) => self.recv_bin_event(&e, packet, ack), PacketData::BinaryAck(packet, ack) => self.recv_bin_ack(packet, ack), PacketData::Disconnect => self .close(DisconnectReason::ClientNSDisconnect) @@ -540,8 +541,8 @@ impl Socket { } } - fn recv_event(self: Arc, e: String, data: Value, ack: Option) -> Result<(), Error> { - if let Some(handler) = self.message_handlers.read().unwrap().get(&e) { + fn recv_event(self: Arc, e: &str, data: Value, ack: Option) -> Result<(), Error> { + if let Some(handler) = self.message_handlers.read().unwrap().get(e) { handler.call(self.clone(), data, vec![], ack)?; } Ok(()) @@ -549,11 +550,11 @@ impl Socket { fn recv_bin_event( self: Arc, - e: String, + e: &str, packet: BinaryPacket, ack: Option, ) -> Result<(), Error> { - if let Some(handler) = self.message_handlers.read().unwrap().get(&e) { + if let Some(handler) = self.message_handlers.read().unwrap().get(e) { handler.call(self.clone(), packet.data, packet.bin, ack)?; } Ok(())