diff --git a/Cargo.toml b/Cargo.toml index e73fbfa..f7d7935 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ socks = ["dep:socks"] [dependencies] bitcoin = "0.29.2" crossbeam-channel = "0.5.7" +intmap = "2.0.0" log = "0.4.17" mio = { version = "0.8.6", features = ["net", "os-poll"] } slab = "0.4.8" diff --git a/src/peer.rs b/src/peer.rs index 1ef7fe9..79f62cf 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -1,26 +1,15 @@ -/// Unique peer identifier. The user should assume these can be reused by different peers as -/// peers come and go, i.e. they are not assigned just once for the lifetime of the process. +/// Unique peer identifier. These are unique for the lifetime of the process and strictly +/// incrementing for each new connection. Even if the same peer (in terms of socket address) +/// connects multiple times, a new `PeerId` instance will be issued for each connection. #[derive(Debug, Clone, Hash, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct PeerId(pub usize); +pub struct PeerId(pub u64); impl PeerId { - pub fn value(&self) -> usize { + pub fn value(&self) -> u64 { self.0 } } -impl From for PeerId { - fn from(token: mio::Token) -> Self { - Self(token.0) - } -} - -impl From for mio::Token { - fn from(id: PeerId) -> Self { - Self(id.0) - } -} - impl std::fmt::Display for PeerId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!("{}", self.0)) diff --git a/src/reactor.rs b/src/reactor.rs index bf0ab06..5de9f94 100644 --- a/src/reactor.rs +++ b/src/reactor.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use bitcoin::network::message::RawNetworkMessage; +use intmap::IntMap; use mio::net::{TcpListener, TcpStream}; use mio::{Events, Interest, Poll, Registry, Token, Waker}; use slab::Slab; @@ -35,7 +36,9 @@ pub enum Command { pub enum Event { /// The reactor attempted to connect to a remote peer. ConnectedTo { + /// The socket address that was connected to. addr: SocketAddr, + /// The result of the connection attempt. result: io::Result, }, /// The reactor received a connection from a remote peer. @@ -220,6 +223,12 @@ impl Handle { } } +/// Contains a stream along with its peer id. +struct Entry { + stream: MessageStream, + peer_id: PeerId, +} + /// Runs the reactor in a loop until an error is produced or a shutdown command is received. fn run( Reactor { @@ -248,10 +257,13 @@ fn run( }) .collect::>>()?; - let mut streams: Slab> = Slab::with_capacity(16); + let mut streams: Slab = Slab::with_capacity(16); let mut events = Events::with_capacity(1024); let mut read_buf = [0; 1024 * 1024]; let mut last_maintenance = Instant::now(); + let mut token_map: IntMap = IntMap::new(); + let mut next_peer_id: u64 = 0; + let mut remove_stale: Vec = Vec::with_capacity(16); loop { poll.poll(&mut events, Some(Duration::from_secs(5)))?; @@ -275,6 +287,8 @@ fn run( let peer = add_stream( poll.registry(), &mut streams, + &mut token_map, + &mut next_peer_id, stream, config.stream_config.clone(), )?; @@ -295,12 +309,18 @@ fn run( let _ = sender.send(Event::ConnectedTo { addr, result }); } - Command::Disconnect(peer) => match streams.try_remove(peer.value()) { - Some(mut stream) => { - poll.registry().deregister(stream.inner_mut())?; + Command::Disconnect(peer) => { + if token_map.contains_key(peer.value()) { + let mut entry = remove_stream( + poll.registry(), + &mut streams, + &mut token_map, + peer, + )?; - let _ = write(&mut stream, now); - let _ = stream.inner_mut().shutdown(std::net::Shutdown::Both); + let _ = write(&mut entry.stream, now); + let _ = + entry.stream.inner_mut().shutdown(std::net::Shutdown::Both); log::info!("peer {peer}: disconnected"); @@ -308,40 +328,40 @@ fn run( peer, reason: DisconnectReason::Requested, }); - } - None => { + } else { let _ = sender.send(Event::NoPeer(peer)); log::warn!("disconnect: peer {} not found", peer.value()); } - }, + } - Command::Message(peer, message) => { - match streams.get_mut(peer.value()) { - Some(stream) => { - if stream.queue_message(&message) { - poll.registry().reregister( - stream.inner_mut(), - peer.into(), - Interest::READABLE | Interest::WRITABLE, - )?; - } else { - let _ = sender - .send(Event::SendBufferFull { peer, message }); - log::warn!("send buffer for peer {peer} is full"); - } + Command::Message(peer, message) => match token_map.get(peer.value()) { + Some(token) => { + let entry = streams.get_mut(token.0).expect("must exist here"); + + if entry.stream.queue_message(&message) { + poll.registry().reregister( + entry.stream.inner_mut(), + *token, + Interest::READABLE | Interest::WRITABLE, + )?; + } else { + let _ = + sender.send(Event::SendBufferFull { peer, message }); + log::warn!("send buffer for peer {peer} is full"); } + } - None => { - let _ = sender.send(Event::NoPeer(peer)); - log::warn!("message: peer {} not found", peer.value()); - } + None => { + let _ = sender.send(Event::NoPeer(peer)); + log::warn!("message: peer {} not found", peer.value()); } - } + }, Command::Shutdown => { - for (id, mut stream) in streams { - let _ = write(&mut stream, now); - let r = stream.inner_mut().shutdown(std::net::Shutdown::Both); + for (id, mut entry) in streams { + let _ = write(&mut entry.stream, now); + let r = + entry.stream.inner_mut().shutdown(std::net::Shutdown::Both); log::debug!("shut down stream {}: {:?}", id, r); } @@ -365,6 +385,8 @@ fn run( let peer = add_stream( poll.registry(), &mut streams, + &mut token_map, + &mut next_peer_id, stream, config.stream_config.clone(), )?; @@ -382,10 +404,10 @@ fn run( } } - (token, Some(stream)) => { - let peer = token.into(); + (token, Some(entry)) => { + let peer = entry.peer_id; - if !stream.is_ready() { + if !entry.stream.is_ready() { log::trace!("peer: {peer}: stream not ready"); continue; } @@ -394,10 +416,10 @@ fn run( log::trace!("peer {peer}: readable"); 'read: loop { - let read_result = stream.read(&mut read_buf); + let read_result = entry.stream.read(&mut read_buf); 'decode: loop { - match stream.receive_message() { + match entry.stream.receive_message() { Ok(message) => { log::debug!("peer {peer}: rx message: {}", message.cmd()); @@ -410,7 +432,12 @@ fn run( ) => { log::info!("peer {peer}: codec violation"); - remove_stream(poll.registry(), &mut streams, peer)?; + remove_stream( + poll.registry(), + &mut streams, + &mut token_map, + peer, + )?; let _ = sender.send(Event::Disconnected { peer, @@ -428,7 +455,12 @@ fn run( Ok(0) => { log::debug!("peer {peer}: peer left"); - remove_stream(poll.registry(), &mut streams, peer)?; + remove_stream( + poll.registry(), + &mut streams, + &mut token_map, + peer, + )?; let _ = sender.send(Event::Disconnected { peer, @@ -445,7 +477,12 @@ fn run( Err(err) => { log::warn!("peer {peer}: IO error: {err}"); - remove_stream(poll.registry(), &mut streams, peer)?; + remove_stream( + poll.registry(), + &mut streams, + &mut token_map, + peer, + )?; let _ = sender.send(Event::Disconnected { peer, @@ -461,11 +498,14 @@ fn run( if event.is_writable() { log::trace!("peer {peer}: writable"); - match write(stream, now) { + match write(&mut entry.stream, now) { Ok(()) => { - let interests = choose_interest(stream); - poll.registry() - .reregister(stream.inner_mut(), token, interests)?; + let interests = choose_interest(&entry.stream); + poll.registry().reregister( + entry.stream.inner_mut(), + token, + interests, + )?; } Err(err) if would_block(&err) => {} @@ -473,7 +513,7 @@ fn run( Err(err) => { log::warn!("peer {peer}: IO error: {err}"); - remove_stream(poll.registry(), &mut streams, peer)?; + remove_stream(poll.registry(), &mut streams, &mut token_map, peer)?; let _ = sender.send(Event::Disconnected { peer, @@ -495,28 +535,27 @@ fn run( } // stale stream removal - streams.retain(|token, stream| { - if stream.is_write_stale(now) { - let peer = PeerId(token); - log::info!("removing stale peer {peer}"); + remove_stale.extend( + streams + .iter() + .filter_map(|(_, entry)| entry.stream.is_write_stale(now).then_some(entry.peer_id)), + ); - poll.registry().deregister(stream.inner_mut()).unwrap(); + for peer in remove_stale.drain(..) { + log::info!("removing stale peer {peer}"); - let _ = sender.send(Event::Disconnected { - peer, - reason: DisconnectReason::WriteStale, - }); + remove_stream(poll.registry(), &mut streams, &mut token_map, peer)?; - false - } else { - true - } - }); + let _ = sender.send(Event::Disconnected { + peer, + reason: DisconnectReason::WriteStale, + }); + } // periodic buffer resize if (now - last_maintenance).as_secs() > 30 { - for (_, stream) in &mut streams { - stream.resize_buffers(); + for (_, entry) in &mut streams { + entry.stream.resize_buffers(); } last_maintenance = now; @@ -617,29 +656,44 @@ fn write(stream: &mut MessageStream, now: Instant) -> io::Result<()> /// Registers a peer with the poll and adds him to the stream list. fn add_stream( registry: &Registry, - streams: &mut Slab>, + streams: &mut Slab, + token_map: &mut IntMap, + next_peer_id: &mut u64, mut stream: TcpStream, stream_cfg: message_stream::StreamConfig, ) -> std::io::Result { let token = Token(streams.vacant_key()); + let peer_id = *next_peer_id; registry.register(&mut stream, token, Interest::READABLE)?; - streams.insert(MessageStream::new(stream, stream_cfg)); - Ok(token.into()) + let prev_mapping = token_map.insert(peer_id, token); + assert!(prev_mapping.is_none()); + + streams.insert(Entry { + stream: MessageStream::new(stream, stream_cfg), + peer_id: PeerId(peer_id), + }); + + *next_peer_id += 1; + + Ok(PeerId(peer_id)) } /// Deregisters a peer from the poll and removes him from the stream list. fn remove_stream( registry: &Registry, - streams: &mut Slab>, + streams: &mut Slab, + token_map: &mut IntMap, peer: PeerId, -) -> std::io::Result<()> { - let mut stream = streams.remove(peer.value()); +) -> std::io::Result { + let token = token_map.remove(peer.0).expect("must exist here"); - registry.deregister(stream.inner_mut())?; + let mut entry = streams.remove(token.0); - Ok(()) + registry.deregister(entry.stream.inner_mut())?; + + Ok(entry) } /// Checks if the token is associated with the server (connection listener). @@ -696,7 +750,4 @@ mod test { assert!(is_listener(3, Token(WAKE_TOKEN.0 - 3))); assert!(!is_listener(3, Token(WAKE_TOKEN.0 - 4))); } - - #[test] - fn connection_slot_open() {} } diff --git a/tests/interact.rs b/tests/interact.rs index 0a9ef48..9772c92 100644 --- a/tests/interact.rs +++ b/tests/interact.rs @@ -115,7 +115,7 @@ fn many_to_one_interleaved() { client_reactor.run(); let (client_peer, _) = connect(&client, &server, server_addr); - assert_eq!(i, client_peer.0); + assert_eq!(i as u64, client_peer.0); client }) @@ -125,7 +125,7 @@ fn many_to_one_interleaved() { for (i, c) in clients.iter().enumerate() { message(c, PeerId(0), NetworkMessage::Ping(nonce)); let ping_from_peer = expect_ping(server.receive().unwrap(), nonce); - assert_eq!(ping_from_peer, PeerId(i)); + assert_eq!(ping_from_peer, PeerId(i as u64)); message(&server, ping_from_peer, NetworkMessage::Pong(nonce)); expect_pong(c.receive().unwrap(), nonce); } @@ -163,7 +163,7 @@ fn many_to_one_bulk() { client_reactor.run(); let (client_peer, _) = connect(&client, &server, server_addr); - assert_eq!(i, client_peer.0); + assert_eq!(i as u64, client_peer.0); client }) @@ -231,6 +231,38 @@ fn very_large() { } } +/// Starts one server and sequentally connects and disconnects clients to it. Verifies that each +/// peer id from the perspective of the server is unique, i.e. that there is no peer id reuse. +#[test] +fn peer_id_increments() { + let _ = env_logger::builder().is_test(true).try_init(); + + let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8106).into(); + + let server_config = Config { + bind_addr: vec![server_addr], + ..Default::default() + }; + + let (server_reactor, server) = Reactor::new(server_config).unwrap(); + server_reactor.run(); + + std::thread::sleep(std::time::Duration::from_millis(10)); + + let n_clients = 5; + + for i in 0..n_clients { + let (client_reactor, client) = Reactor::new(Config::default()).unwrap(); + client_reactor.run(); + + let (client_peer, _) = connect(&client, &server, server_addr); + assert_eq!(i as u64, client_peer.0); + + let client_that_left = disconnect(&client, &server, PeerId(0)); + assert_eq!(client_that_left, client_peer); + } +} + #[allow(dead_code)] struct Scaffold { server: Handle, @@ -295,6 +327,33 @@ fn connect(client: &Handle, server: &Handle, server_addr: SocketAddr) -> (PeerId (client_peer, server_peer) } +fn disconnect(client: &Handle, server: &Handle, server_peer: PeerId) -> PeerId { + client.send(Command::Disconnect(server_peer)).unwrap(); + + let client_that_left = match server.receive().unwrap() { + peerlink::Event::Disconnected { + peer, + reason: DisconnectReason::Left, + } => { + println!("server: client has disconnected"); + peer + } + _ => panic!(), + }; + + match client.receive().unwrap() { + peerlink::Event::Disconnected { + peer, + reason: DisconnectReason::Requested, + } if peer == server_peer => { + println!("client: disconnected from server"); + } + _ => panic!(), + }; + + client_that_left +} + fn message(handle: &Handle, peer: PeerId, message: NetworkMessage) { handle .send(Command::Message(